| | from fastapi.middleware.cors import CORSMiddleware |
| | from fastapi import FastAPI, HTTPException, File, UploadFile, Form |
| | from fastapi.responses import JSONResponse, FileResponse |
| | from pydantic import BaseModel |
| | from typing import Optional |
| | import subprocess |
| | import os |
| | import logging |
| | from inference_transform import process_smiles, process_pdb, process_sdf, extract_and_convert_to_sdf, is_valid_smiles |
| |
|
| | |
# Root logging at INFO for the whole process, plus a logger for this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
| |
|
# FastAPI application; CORS is opened to every origin, method and header.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# risky pairing — browsers reject "*" for credentialed requests and Starlette
# may instead reflect the caller's Origin. Narrow allow_origins if this API
# ever relies on cookies or auth.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
| |
|
# Path of the SDF file served by /download_sdf. Defaults to the bundled example
# conformer; set CHEMISTRAL_SDF_FILE_PATH to serve a different file without
# editing the code.
sdf_file_path = os.environ.get(
    "CHEMISTRAL_SDF_FILE_PATH",
    "/root/CHEMISTral7Bv0.3/example/Conformer3D_COMPOUND_CID_240.sdf",
)
| |
|
# Pydantic model holding the same fields (and defaults) that /predict_base and
# /predict accept as multipart form fields.
# NOTE(review): not referenced by any endpoint in this file — the endpoints
# read these values via Form(...) instead; remove if nothing else uses it.
class InferenceRequest(BaseModel):
    prompt: str  # required; no default
    max_tokens: int = 256
    temperature: float = 1.0
| |
|
@app.post("/predict_base")
async def predict_base(
    prompt: str = Form(...),
    max_tokens: int = Form(256),
    temperature: float = Form(1.0),
    file: Optional[UploadFile] = File(None),
):
    """Run the base model (no LoRA adapter) on ``prompt`` and return its output.

    The prompt is optionally augmented before inference:

    * with an uploaded ``.pdb`` / ``.sdf`` file (extension matched
      case-insensitively), the string returned by ``process_pdb`` /
      ``process_sdf`` for the saved upload is appended; uploads with any
      other extension are logged and ignored;
    * without an upload, the result of ``extract_and_convert_to_sdf(prompt)``
      is appended when truthy, and a ``ValueError`` from it is logged and
      ignored.

    Inference runs ``mistral_chat_script.py`` in a subprocess, with the prompt
    passed as one command-line argument.

    Returns:
        A dict with ``"response"`` (the script's stdout, stripped) and
        ``"sdf_file_path"`` (the module-level ``sdf_file_path``).

    Raises:
        HTTPException: 400 if the uploaded file has no usable filename;
            500 with the script's stderr as detail if it exits non-zero;
            500 with the error text for any other failure.
    """
    # Function-scope imports keep this endpoint self-contained.
    import asyncio
    import functools
    import tempfile

    try:
        # Uploads go into a private per-request directory that is deleted
        # (with its contents) after inference. Concurrent requests therefore
        # cannot clobber each other's files, and /tmp does not fill up.
        with tempfile.TemporaryDirectory() as upload_dir:
            if file:
                # basename() drops any directory part of the client-supplied
                # name, so e.g. "../../etc/x" cannot escape upload_dir.
                filename = os.path.basename(file.filename or "")
                if not filename:
                    raise HTTPException(status_code=400, detail="Uploaded file has no filename.")
                file_path = os.path.join(upload_dir, filename)
                with open(file_path, "wb") as f:
                    # Async read: does not block the event loop on the upload.
                    f.write(await file.read())
                lowered = filename.lower()
                if lowered.endswith(".pdb"):
                    prompt += f" {process_pdb(file_path)}"
                elif lowered.endswith(".sdf"):
                    prompt += f" {process_sdf(file_path)}"
                else:
                    logger.info("Ignoring uploaded file with unsupported extension: %s", filename)
            else:
                try:
                    sdf_file = extract_and_convert_to_sdf(prompt)
                    if sdf_file:
                        prompt += f" {sdf_file}"
                except ValueError as e:
                    logger.info(str(e))

            # The prompt is a single argv element and no shell is involved, so
            # it cannot inject shell syntax.
            # NOTE(review): a prompt starting with "-" may still be parsed as an
            # option by the chat script's CLI parser — confirm how it handles that.
            command = [
                "python",
                "/root/CHEMISTral7Bv0.3/mistral_chat_script.py",
                "/root/mistral_models/7B-v0.3/",
                prompt,
                f"--max_tokens={max_tokens}",
                f"--temperature={temperature}",
                "--instruct",
            ]
            logger.info("Running command: %s", " ".join(command))
            # subprocess.run blocks for the whole inference. Run it in the
            # default thread pool so the event loop keeps serving other requests.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                functools.partial(subprocess.run, command, capture_output=True, text=True),
            )

        if result.returncode != 0:
            logger.error("Command failed with return code %s", result.returncode)
            logger.error("stderr: %s", result.stderr)
            raise HTTPException(status_code=500, detail=result.stderr)

        return {
            "response": result.stdout.strip(),
            "sdf_file_path": sdf_file_path,
        }
    except HTTPException:
        # Already carries the intended status and detail. Re-wrapping would
        # replace the detail with str() of the exception and log a spurious
        # traceback.
        raise
    except Exception as e:
        logger.exception("Exception occurred during inference.")
        raise HTTPException(status_code=500, detail=str(e)) from e
| |
|
@app.post("/predict")
async def predict_alternative(
    prompt: str = Form(...),
    max_tokens: int = Form(256),
    temperature: float = Form(1.0),
    file: Optional[UploadFile] = File(None),
):
    """Run the LoRA-adapted model on ``prompt`` and respond with an SDF file.

    The prompt is optionally augmented before inference:

    * with an uploaded ``.pdb`` / ``.sdf`` file (extension matched
      case-insensitively), the string returned by ``process_pdb`` /
      ``process_sdf`` for the saved upload is appended; uploads with any
      other extension are logged and ignored;
    * without an upload, the result of ``extract_and_convert_to_sdf(prompt)``
      is appended when truthy, and a ``ValueError`` from it is logged and
      ignored.

    Inference runs ``mistral_chat_script.py`` with the checkpoint-300 LoRA
    weights in a subprocess, with the prompt passed as one command-line
    argument.

    Returns:
        A ``FileResponse`` streaming the file at the module-level
        ``sdf_file_path`` as ``Conformer3D_COMPOUND_CID_240.sdf``.

    Raises:
        HTTPException: 400 if the uploaded file has no usable filename;
            500 with the script's stderr as detail if it exits non-zero;
            500 if the SDF file is missing or any other failure occurs.
    """
    # Function-scope imports keep this endpoint self-contained.
    import asyncio
    import functools
    import tempfile

    try:
        # Uploads go into a private per-request directory that is deleted
        # (with its contents) after inference. Concurrent requests therefore
        # cannot clobber each other's files, and /tmp does not fill up.
        with tempfile.TemporaryDirectory() as upload_dir:
            if file:
                # basename() drops any directory part of the client-supplied
                # name, so e.g. "../../etc/x" cannot escape upload_dir.
                filename = os.path.basename(file.filename or "")
                if not filename:
                    raise HTTPException(status_code=400, detail="Uploaded file has no filename.")
                file_path = os.path.join(upload_dir, filename)
                with open(file_path, "wb") as f:
                    f.write(await file.read())
                lowered = filename.lower()
                if lowered.endswith(".pdb"):
                    prompt += f" {process_pdb(file_path)}"
                elif lowered.endswith(".sdf"):
                    prompt += f" {process_sdf(file_path)}"
                else:
                    logger.info("Ignoring uploaded file with unsupported extension: %s", filename)
            else:
                try:
                    sdf_file = extract_and_convert_to_sdf(prompt)
                    if sdf_file:
                        prompt += f" {sdf_file}"
                except ValueError as e:
                    logger.info(str(e))

            # The prompt is a single argv element and no shell is involved, so
            # it cannot inject shell syntax.
            # NOTE(review): a prompt starting with "-" may still be parsed as an
            # option by the chat script's CLI parser — confirm how it handles that.
            command = [
                "python",
                "/root/CHEMISTral7Bv0.3/mistral_chat_script.py",
                "/root/mistral_models/7B-v0.3/",
                prompt,
                f"--max_tokens={max_tokens}",
                f"--temperature={temperature}",
                "--instruct",
                "--lora_path=/root/CHEMISTral7Bv0.3/runs/checkpoints/checkpoint_000300/consolidated/lora.safetensors",
            ]
            logger.info("Running command: %s", " ".join(command))
            # subprocess.run blocks for the whole inference. Run it in the
            # default thread pool so the event loop keeps serving other requests.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                functools.partial(subprocess.run, command, capture_output=True, text=True),
            )

        if result.returncode != 0:
            logger.error("Command failed with return code %s", result.returncode)
            logger.error("stderr: %s", result.stderr)
            raise HTTPException(status_code=500, detail=result.stderr)

        # NOTE(review): the model's stdout is not returned to the client; this
        # endpoint always serves the fixed file at sdf_file_path.
        # FileResponse only stats the path while the response is being sent,
        # after this handler has returned. Check here so a missing file gives
        # a clear error instead of failing mid-response.
        if not os.path.isfile(sdf_file_path):
            logger.error("SDF file not found at %s", sdf_file_path)
            raise HTTPException(status_code=500, detail="SDF file not found.")

        return FileResponse(
            sdf_file_path,
            media_type="chemical/x-mdl-sdfile",
            filename="Conformer3D_COMPOUND_CID_240.sdf",
        )
    except HTTPException:
        # Already carries the intended status and detail. Re-wrapping would
        # replace the detail with str() of the exception and log a spurious
        # traceback.
        raise
    except Exception as e:
        logger.exception("Exception occurred during inference.")
        raise HTTPException(status_code=500, detail=str(e)) from e
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
@app.get("/download_sdf")
async def download_sdf():
    """Send the file at the module-level ``sdf_file_path`` as an SDF download.

    The client receives it as ``Conformer3D_COMPOUND_CID_240.sdf`` with media
    type ``chemical/x-mdl-sdfile``.

    Raises:
        HTTPException: 500 if no file exists at ``sdf_file_path``.
    """
    # FileResponse only stats the path while the response is being sent, after
    # this handler returns, so a try/except around its construction can never
    # catch a missing file. Check explicitly instead.
    if not os.path.isfile(sdf_file_path):
        logger.error("SDF file not found at %s", sdf_file_path)
        raise HTTPException(status_code=500, detail="SDF file not found.")
    # Explicit media type: the guessed type for ".sdf" otherwise depends on
    # the host's mimetypes configuration.
    return FileResponse(
        path=sdf_file_path,
        media_type="chemical/x-mdl-sdfile",
        filename="Conformer3D_COMPOUND_CID_240.sdf",
    )
| |
|
if __name__ == "__main__":
    # Serve the app directly when executed as a script. uvicorn is imported
    # lazily because only this entry path needs it.
    from uvicorn import run as serve

    serve(app, port=8000, host="0.0.0.0")
| |
|