| |
| import os |
| import time |
| import uuid |
| import base64 |
| import subprocess |
| from pathlib import Path |
| from typing import Optional |
|
|
| import flask |
| from flask import request, jsonify |
| import requests |
| from bs4 import BeautifulSoup |
|
|
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| import torch |
|
|
| |
| |
| |
# Runtime configuration; every value can be overridden through an
# environment variable of the same (upper-case) name.
_env = os.environ

MODEL_ID = _env.get("MODEL_ID", "HuggingFaceTB/SmolLM2-360M-Instruct")
PORT = int(_env.get("PORT", 7860))
# Storage directory used by the file endpoints (and the scratch scripts of
# /run_code); created at import time if it does not exist yet.
FILES_DIR = Path(_env.get("FILES_DIR", "engine_files"))
FILES_DIR.mkdir(parents=True, exist_ok=True)

# Generation settings applied when a request leaves a parameter unset.
# DO_SAMPLE counts as enabled only for the (case-insensitive) string "true".
GEN_DEFAULTS = dict(
    max_new_tokens=int(_env.get("MAX_NEW_TOKENS", 512)),
    do_sample=_env.get("DO_SAMPLE", "true").lower() == "true",
    temperature=float(_env.get("TEMPERATURE", 0.6)),
    top_p=float(_env.get("TOP_P", 0.9)),
    repetition_penalty=float(_env.get("REPETITION_PENALTY", 1.05)),
)

# Token budget (prompt + generated tokens) the service assumes for the model.
MODEL_CONTEXT_TOKENS = int(_env.get("MODEL_CONTEXT_TOKENS", 4096))
|
|
| |
| |
| |
# Flask application object; every route below registers on it.
app = flask.Flask(__name__)
# Reference point for the uptime reported by /health.
_start_time = time.time()

print(f"🔄 Loading model {MODEL_ID} ... (this may take a while the first time)")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# bfloat16 weights on CUDA, full float32 on CPU.
_has_cuda = torch.cuda.is_available()
dtype = torch.bfloat16 if _has_cuda else torch.float32
device = torch.device("cuda" if _has_cuda else "cpu")

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=dtype, low_cpu_mem_usage=True)
model.to(device)

print(f"✅ Model loaded: {MODEL_ID} on {device} (dtype={dtype})")
|
|
| |
| |
| |
def safe_filename(name: str) -> str:
    """
    Reduce *name* to a conservative file name for storage in FILES_DIR.

    Only alphanumeric characters and ``.``, ``_``, ``-`` and space are kept
    (so path separators disappear), then surrounding whitespace is stripped.
    If nothing usable remains -- an empty result, or one made only of dots
    such as ``"."`` / ``".."``, which would name a directory rather than a
    file -- a random UUID string is returned instead.
    """
    safe = "".join(c for c in name if c.isalnum() or c in "._- ").strip()
    # "." and ".." survive the character filter but refer to FILES_DIR itself
    # or its parent; treat them (and the empty string) as unusable.
    if not safe.strip("."):
        safe = str(uuid.uuid4())
    return safe
|
|
def _truncate_prompt_for_context(prompt: str, max_new_tokens: int) -> str:
    """
    Trim *prompt* from the front so that it leaves room for generation.

    The prompt budget is MODEL_CONTEXT_TOKENS minus ``max_new_tokens`` minus a
    32-token safety margin, floored at 32 tokens (so with a very large
    ``max_new_tokens`` the total can still exceed the window). A prompt that
    fits is returned unchanged; otherwise only its trailing tokens -- the most
    recent content -- are kept and decoded back into text.
    """
    safety_margin = 32
    budget = MODEL_CONTEXT_TOKENS - max_new_tokens - safety_margin
    if budget < 32:
        budget = 32  # always leave the model some of the prompt

    # Counted without special tokens; the margin is presumably meant to absorb
    # any the tokenizer adds at generation time -- TODO confirm for this model.
    token_ids = tokenizer.encode(prompt, add_special_tokens=False)
    if len(token_ids) > budget:
        tail = token_ids[len(token_ids) - budget:]
        return tokenizer.decode(tail, clean_up_tokenization_spaces=True)
    return prompt
|
|
def generate_from_model(prompt: str,
                        max_new_tokens: Optional[int] = None,
                        do_sample: Optional[bool] = None,
                        temperature: Optional[float] = None,
                        top_p: Optional[float] = None,
                        repetition_penalty: Optional[float] = None) -> str:
    """
    Generate a continuation of *prompt* with the module-level model.

    Any argument left as ``None`` falls back to GEN_DEFAULTS. Values that
    arrive straight from JSON are coerced: numbers via int()/float(), and a
    string ``do_sample`` is read like the DO_SAMPLE env var (only "true",
    case-insensitive, enables sampling).

    ``max_new_tokens`` is clamped to ``[1, MODEL_CONTEXT_TOKENS - 64]`` so a
    request cannot ask for more tokens than the configured context can hold
    (64 = the 32-token minimum prompt plus the 32-token margin reserved by
    _truncate_prompt_for_context). A non-positive temperature switches to
    greedy decoding, since transformers refuses it when sampling.

    Returns the decoded *full* sequence -- the (possibly front-truncated)
    prompt followed by the completion -- with special tokens removed; callers
    strip the prompt themselves.
    """
    if max_new_tokens is None:
        new_tokens = GEN_DEFAULTS["max_new_tokens"]
    else:
        new_tokens = int(max_new_tokens)
    new_tokens = max(1, min(new_tokens, MODEL_CONTEXT_TOKENS - 64))

    if do_sample is None:
        sample = GEN_DEFAULTS["do_sample"]
    elif isinstance(do_sample, str):
        # bool("false") would be True; parse strings the way DO_SAMPLE is parsed.
        sample = do_sample.strip().lower() == "true"
    else:
        sample = bool(do_sample)

    temp = float(temperature) if temperature is not None else GEN_DEFAULTS["temperature"]
    nucleus = float(top_p) if top_p is not None else GEN_DEFAULTS["top_p"]
    penalty = (float(repetition_penalty) if repetition_penalty is not None
               else GEN_DEFAULTS["repetition_penalty"])
    if temp <= 0:
        sample = False  # temperature 0 conventionally means "greedy"

    prompt = _truncate_prompt_for_context(prompt, new_tokens)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                       max_length=MODEL_CONTEXT_TOKENS).to(device)

    gen_kwargs = {
        "max_new_tokens": new_tokens,
        "do_sample": sample,
        "repetition_penalty": penalty,
        "pad_token_id": tokenizer.eos_token_id,
    }
    # temperature/top_p only affect sampling; passing them to greedy decoding
    # just triggers warnings.
    if sample:
        gen_kwargs["temperature"] = temp
        gen_kwargs["top_p"] = nucleus

    with torch.no_grad():
        out = model.generate(**inputs, **gen_kwargs)
    return tokenizer.decode(out[0], skip_special_tokens=True)
|
|
| |
| |
| |
|
|
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: uptime, device, model id and (if psutil is usable) host memory."""
    seconds_up = int(time.time() - _start_time)

    def _memory_stats():
        # psutil is optional; any failure importing or querying it is reported
        # as an info message instead of an error.
        try:
            import psutil
            return psutil.virtual_memory()._asdict()
        except Exception:
            return {"info": "psutil not installed or unavailable"}

    return jsonify(
        status="ok",
        uptime_seconds=seconds_up,
        device=str(device),
        model_id=MODEL_ID,
        memory=_memory_stats(),
    )
|
|
@app.route("/model_info", methods=["GET"])
def model_info():
    """Describe the loaded model: id, device, dtype and basic tokenizer facts."""
    info = {"model_id": MODEL_ID, "device": str(device), "dtype": str(dtype)}
    # getattr with a default so tokenizers lacking these attributes report null.
    info["vocab_size"] = getattr(tokenizer, "vocab_size", None)
    info["tokenizer_fast"] = getattr(tokenizer, "is_fast", None)
    return jsonify(info)
|
|
| |
@app.route("/chat", methods=["POST"])
def chat():
    """
    POST JSON:
    {
    "message": "text",
    "max_new_tokens": 256, # optional
    "do_sample": true/false, # optional
    "temperature": 0.7, # optional
    "top_p": 0.9, # optional
    "repetition_penalty": 1.05 # optional
    }

    ``"prompt"`` is accepted as an alias for ``"message"``. Returns
    ``{"reply": ...}``; 400 when the body is not a JSON object or no message
    is given, 500 (with the error text) for any other failure.
    """
    marker = "Assistant:"
    try:
        body = request.get_json(force=True, silent=True)
        if not isinstance(body, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400
        msg = (body.get("message") or body.get("prompt") or "").strip()
        if not msg:
            return jsonify({"error": "No message provided"}), 400

        # NOTE(review): an instruct model usually expects its own chat template
        # (tokenizer.apply_chat_template); this plain format is kept as-is.
        prompt = f"User: {msg}\n{marker}"

        full = generate_from_model(
            prompt,
            max_new_tokens=body.get("max_new_tokens"),
            do_sample=body.get("do_sample"),
            temperature=body.get("temperature"),
            top_p=body.get("top_p"),
            repetition_penalty=body.get("repetition_penalty"),
        )

        # The decoded output echoes the prompt, and msg may itself contain
        # "Assistant:", so skip every marker the prompt contributed instead of
        # splitting on the first one.
        prompt_markers = prompt.count(marker)
        parts = full.split(marker, prompt_markers)
        if len(parts) > prompt_markers:
            reply = parts[prompt_markers].strip()
        elif marker in full:
            # Fewer markers than the prompt had (it was front-truncated): the
            # last one is the best guess for where the reply starts.
            reply = full.rsplit(marker, 1)[1].strip()
        else:
            reply = full.replace(prompt, "").strip()

        return jsonify({"reply": reply})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
| |
@app.route("/search", methods=["POST"])
def search():
    """
    POST JSON:
    { "q": "your query", "top_k": 5 }

    Scrape DuckDuckGo's HTML endpoint and return up to ``top_k`` results as
    ``{"title", "url", "snippet"}``. DuckDuckGo redirect links are unwrapped
    to the target URL so it can be fetched directly (e.g. via /fetch_url).
    400 for a non-object body or missing ``q``; 500 with the error text
    otherwise.
    """
    from urllib.parse import parse_qs, urlparse

    def _target_url(href):
        # NOTE(review): DuckDuckGo's HTML results typically link through a
        # protocol-relative redirect (//duckduckgo.com/l/?uddg=<target>&...),
        # which requests cannot fetch; return the embedded target instead.
        if not href:
            return href
        parsed = urlparse(href)
        is_ddg = not parsed.netloc or parsed.netloc.endswith("duckduckgo.com")
        if is_ddg and parsed.path.startswith("/l/"):
            target = parse_qs(parsed.query).get("uddg")
            if target:
                return target[0]
        if href.startswith("//"):
            return "https:" + href
        return href

    try:
        data = request.get_json(force=True, silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400
        q = (data.get("q") or "").strip()
        if not q:
            return jsonify({"error": "Query 'q' missing"}), 400
        # A negative top_k would otherwise slice results off the end.
        top_k = max(int(data.get("top_k", 5)), 0)

        url = "https://html.duckduckgo.com/html/"
        r = requests.post(url, data={"q": q}, timeout=10)
        r.raise_for_status()
        soup = BeautifulSoup(r.text, "html.parser")

        results = []
        for a in soup.select("a.result__a")[:top_k]:
            title = a.get_text().strip()
            href = _target_url(a.get("href"))

            # The title link's direct parent is usually just the title heading;
            # look for the snippet in the enclosing result container instead
            # (verify against current DuckDuckGo markup).
            container = (a.find_parent(class_="result__body")
                         or a.find_parent(class_="result")
                         or a.parent)
            snippet = ""
            if container is not None:
                s = container.select_one("a.result__snippet") or container.select_one(".result__snippet")
                if s:
                    snippet = s.get_text().strip()
            results.append({"title": title, "url": href, "snippet": snippet})
        return jsonify({"query": q, "results": results})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
| |
@app.route("/fetch_url", methods=["POST"])
def fetch_url():
    """
    POST JSON: { "url": "https://...", "max_chars": 10000 }

    Fetch *url* and return at most ``max_chars`` characters of its body,
    followed by a "...[truncated]" marker when it was cut. The body is
    streamed and the download stops once enough bytes for ``max_chars``
    characters have arrived, so a huge response cannot exhaust memory.
    Bodies without a declared encoding are decoded as UTF-8.

    SECURITY NOTE(review): any caller-supplied URL is fetched, including
    internal addresses (SSRF); restrict it if this service is exposed.
    """
    try:
        data = request.get_json(force=True, silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400
        url = data.get("url", "")
        if not url:
            return jsonify({"error": "url missing"}), 400
        # A negative limit would otherwise slice characters off the end.
        max_chars = max(int(data.get("max_chars", 10000)), 0)
        # No encoding needs more than 4 bytes per character, so this many
        # bytes always decodes to more than max_chars characters.
        byte_limit = max_chars * 4 + 4

        with requests.get(url, timeout=10, stream=True) as r:
            r.raise_for_status()
            body = bytearray()
            for chunk in r.iter_content(chunk_size=65536):
                body += chunk
                if len(body) > byte_limit:
                    break
            encoding = r.encoding or "utf-8"

        try:
            text = body.decode(encoding, errors="replace")
        except LookupError:  # unknown charset advertised by the server
            text = body.decode("utf-8", errors="replace")
        if len(text) > max_chars:
            text = text[:max_chars] + "\n\n...[truncated]"
        return jsonify({"url": url, "content": text})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
| |
@app.route("/summarize", methods=["POST"])
def summarize():
    """
    POST JSON: { "text": "...", "max_new_tokens": 200 }

    Returns ``{"summary": ...}``; 400 when the body is not a JSON object or
    ``text`` is missing, 500 (with the error text) for any other failure.

    NOTE(review): generate_from_model keeps only the tail of an over-long
    prompt, so for very long text the instruction line is dropped and the
    model sees just the end of the text followed by "Summary:".
    """
    marker = "Summary:"
    try:
        data = request.get_json(force=True, silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400
        text = (data.get("text") or "").strip()
        if not text:
            return jsonify({"error": "text missing"}), 400
        max_new_tokens = int(data.get("max_new_tokens", GEN_DEFAULTS["max_new_tokens"]))
        prompt = f"Summarize the following text concisely and clearly:\n\n{text}\n\n{marker}"
        out = generate_from_model(prompt, max_new_tokens=max_new_tokens)

        # The decoded output echoes the prompt, and the input text may itself
        # contain "Summary:", so skip every marker the prompt contributed.
        prompt_markers = prompt.count(marker)
        parts = out.split(marker, prompt_markers)
        if len(parts) > prompt_markers:
            summary = parts[prompt_markers].strip()
        elif marker in out:
            # Prompt was front-truncated: its last marker is the best guess.
            summary = out.rsplit(marker, 1)[1].strip()
        else:
            summary = out.replace(prompt, "").strip()
        return jsonify({"summary": summary})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
| |
@app.route("/run_code", methods=["POST"])
def run_code():
    """
    POST JSON: { "code": "print('hi')", "timeout": 8 }
    Returns stdout, stderr, exit_code

    The code is written to a per-request ``job_<uuid>.py`` script in FILES_DIR,
    run with the ``python3`` found on PATH, and the script is removed afterwards
    (also after a timeout or error) so job files do not accumulate in FILES_DIR
    or show up in /list_files.

    SECURITY NOTE(review): this executes arbitrary caller-supplied code with the
    server's privileges and no sandbox; only expose it to trusted clients.
    """
    tmp_file = None
    try:
        data = request.get_json(force=True, silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400
        code = data.get("code", "")
        if not code:
            return jsonify({"error": "code missing"}), 400
        timeout = float(data.get("timeout", 8))
        job_id = str(uuid.uuid4())
        tmp_file = FILES_DIR / f"job_{job_id}.py"
        tmp_file.write_text(code, encoding="utf-8")

        # Argument list, no shell: the path is never shell-interpreted.
        proc = subprocess.run(
            ["python3", str(tmp_file)],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        return jsonify({
            "stdout": proc.stdout,
            "stderr": proc.stderr,
            "exit_code": proc.returncode,
            "job_id": job_id,
        })
    except subprocess.TimeoutExpired as te:
        return jsonify({"error": "timeout", "detail": str(te)}), 500
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        if tmp_file is not None:
            try:
                # missing_ok: write_text may have failed before creating it.
                tmp_file.unlink(missing_ok=True)
            except OSError:
                # Best-effort cleanup; never let it replace the response.
                pass
|
|
| |
@app.route("/create_file", methods=["POST"])
def create_file():
    """
    POST JSON: { "filename": "name.txt", "content": "...", "encode_base64": false }

    Write ``content`` into FILES_DIR under a sanitized version of ``filename``
    (a random ``file_<uuid>.txt`` when the key is absent), overwriting any
    existing file. With ``encode_base64`` truthy the content is base64-decoded
    and written as bytes; otherwise it is written as UTF-8 text.
    Returns the stored path and the sanitized name; 500 on any error.
    """
    try:
        payload = request.get_json(force=True)
        requested_name = payload.get("filename", f"file_{uuid.uuid4()}.txt")
        stored_name = safe_filename(requested_name)
        content = payload.get("content", "")
        as_base64 = bool(payload.get("encode_base64", False))
        target = FILES_DIR / stored_name

        if not as_base64:
            target.write_text(content, encoding="utf-8")
        else:
            target.write_bytes(base64.b64decode(content))
        return jsonify({"path": str(target), "filename": stored_name})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
| |
@app.route("/list_files", methods=["GET"])
def list_files():
    """Return name, size in bytes and path of each regular file directly in FILES_DIR."""
    entries = [
        {"name": entry.name, "size": entry.stat().st_size, "path": str(entry)}
        for entry in FILES_DIR.iterdir()
        if entry.is_file()
    ]
    return jsonify({"files": entries})
|
|
| |
@app.route("/download_file", methods=["POST"])
def download_file():
    """
    POST JSON: { "filename": "name.txt", "as_base64": false }

    Return a stored file's contents: base64-encoded when ``as_base64`` is
    truthy, otherwise decoded as UTF-8 with undecodable bytes replaced.

    ``filename`` is resolved and must land inside FILES_DIR: absolute paths,
    ``..`` segments and symlinks pointing elsewhere are reported as
    "file not found" (404), as are directories. 400 when the body is not a
    JSON object or ``filename`` is missing; 500 on other errors.
    """
    try:
        data = request.get_json(force=True, silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400
        filename = data.get("filename", "")
        if not filename or not isinstance(filename, str):
            return jsonify({"error": "filename missing"}), 400

        base = FILES_DIR.resolve()
        # FILES_DIR / "/etc/passwd" or "../x" would escape the storage
        # directory; resolve (following symlinks) and require containment.
        path = (base / filename).resolve()
        if base not in path.parents or not path.is_file():
            return jsonify({"error": "file not found"}), 404

        if bool(data.get("as_base64", False)):
            encoded = base64.b64encode(path.read_bytes()).decode()
            return jsonify({"filename": filename, "content_base64": encoded})
        text = path.read_text(encoding="utf-8", errors="replace")
        return jsonify({"filename": filename, "content": text})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
| |
@app.route("/delete_file", methods=["POST"])
def delete_file():
    """
    POST JSON: { "filename": "name.txt" }

    Delete a stored file from FILES_DIR and return ``{"deleted": filename}``.

    ``filename`` is resolved and must land inside FILES_DIR: absolute paths,
    ``..`` segments and symlinks pointing elsewhere are reported as
    "file not found" (404), as are directories. 400 when the body is not a
    JSON object or ``filename`` is missing; 500 on other errors.
    """
    try:
        data = request.get_json(force=True, silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400
        filename = data.get("filename", "")
        if not filename or not isinstance(filename, str):
            return jsonify({"error": "filename missing"}), 400

        base = FILES_DIR.resolve()
        # Without this check "../app.py" or an absolute path would delete
        # arbitrary files the server can write to.
        path = (base / filename).resolve()
        if base not in path.parents or not path.is_file():
            return jsonify({"error": "file not found"}), 404
        path.unlink()
        return jsonify({"deleted": filename})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
| |
@app.route("/ask_search", methods=["POST"])
def ask_search():
    """
    POST JSON: { "q": "question", "top_k": 3, "max_new_tokens": 300 }
    Returns search results + LLM synthesized answer

    Searches DuckDuckGo's HTML endpoint and lists up to ``top_k`` results
    (redirect links unwrapped to their target URL). Pages are then fetched in
    result order until 3 have been read successfully; the visible text of
    each (first 4000 characters, prefixed with its URL so the model can cite
    it) goes into the prompt. The question is placed right before "Answer:"
    because generate_from_model keeps only the tail of an over-long prompt.
    400 for a non-object body or missing ``q``; 500 with the error text
    otherwise.
    """
    from urllib.parse import parse_qs, urlparse

    marker = "Answer:"
    max_pages = 3      # pages whose text is given to the model
    page_chars = 4000  # characters kept from each page

    def _target_url(href):
        # NOTE(review): DuckDuckGo's HTML results typically link through a
        # protocol-relative redirect (//duckduckgo.com/l/?uddg=<target>&...),
        # which requests cannot fetch; return the embedded target instead.
        if not href:
            return href
        parsed = urlparse(href)
        is_ddg = not parsed.netloc or parsed.netloc.endswith("duckduckgo.com")
        if is_ddg and parsed.path.startswith("/l/"):
            target = parse_qs(parsed.query).get("uddg")
            if target:
                return target[0]
        if href.startswith("//"):
            return "https:" + href
        return href

    try:
        data = request.get_json(force=True, silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400
        q = (data.get("q") or "").strip()
        if not q:
            return jsonify({"error": "q missing"}), 400
        # A negative top_k would otherwise slice results off the end.
        top_k = max(int(data.get("top_k", 3)), 0)

        search_resp = requests.post("https://html.duckduckgo.com/html/", data={"q": q}, timeout=10)
        search_resp.raise_for_status()
        soup = BeautifulSoup(search_resp.text, "html.parser")

        results = []
        snippets = []
        for a in soup.select("a.result__a")[:top_k]:
            title = a.get_text().strip()
            href = _target_url(a.get("href"))
            results.append({"title": title, "url": href})
            if not href or len(snippets) >= max_pages:
                continue
            try:
                page = requests.get(href, timeout=5)
                page.raise_for_status()
            except (requests.RequestException, ValueError):
                # Best effort: an unreachable page is skipped, the rest still count.
                continue
            # Raw HTML starts with <head>, scripts and styles; give the model
            # the page's visible text instead.
            page_text = BeautifulSoup(page.text, "html.parser").get_text(" ", strip=True)
            snippets.append(f"Source: {href}\n{page_text[:page_chars]}")

        combined = "\n\n---\n\n".join(snippets)
        prompt = (
            "Use the following snippets from web pages to answer the question "
            "(be concise and cite urls where useful):\n\n"
            f"{combined}\n\n"
            f"Question: {q}\n\n"
            f"{marker}"
        )
        max_new_tokens = int(data.get("max_new_tokens", GEN_DEFAULTS["max_new_tokens"]))
        answer = generate_from_model(prompt, max_new_tokens=max_new_tokens)

        # Page text can contain "Answer:" too, so skip every marker the prompt
        # contributed instead of splitting on the first one.
        prompt_markers = prompt.count(marker)
        parts = answer.split(marker, prompt_markers)
        if len(parts) > prompt_markers:
            answer_text = parts[prompt_markers].strip()
        elif marker in answer:
            # Prompt was front-truncated: its last marker is the best guess.
            answer_text = answer.rsplit(marker, 1)[1].strip()
        else:
            answer_text = answer.replace(prompt, "").strip()
        return jsonify({"query": q, "search_results": results, "answer": answer_text})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
| |
| |
| |
if __name__ == "__main__":
    # Startup banner listing the available endpoints and their JSON fields.
    _endpoint_lines = (
        " /health (GET)",
        " /model_info (GET)",
        " /chat (POST) -> {message}",
        " /search (POST) -> {q, top_k}",
        " /fetch_url (POST) -> {url, max_chars}",
        " /summarize (POST) -> {text}",
        " /run_code (POST) -> {code, timeout}",
        " /create_file (POST) -> {filename, content}",
        " /list_files (GET)",
        " /download_file (POST) -> {filename, as_base64}",
        " /delete_file (POST) -> {filename}",
        " /ask_search (POST) -> {q, top_k}",
    )
    print("Engine ready. Endpoints:")
    for _line in _endpoint_lines:
        print(_line)
    # NOTE(review): threaded=True lets concurrent requests call model.generate
    # on the shared model at the same time; confirm memory headroom for that.
    app.run(host="0.0.0.0", port=PORT, threaded=True)