| | """inference.py - Code generation model wrapper for smolagents""" |
| | import torch |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| |
|
| |
|
class CodeModel:
    def __init__(self, model_id: str, device: str | None = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Code models often ship without a pad token; fall back to EOS so generate() can pad.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device, dtype=dtype)
        self.model.eval()

    def generate(self, prompt: str, max_new_tokens: int = 512, temperature: float = 0.7) -> str:
        """Complete a raw text prompt and return only the newly generated text."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Slice off the prompt tokens so only the completion is decoded.
        new_tokens = outputs[0, inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)

    def chat(self, messages: list[dict], max_new_tokens: int = 256) -> str:
        """Generate a response using the tokenizer's chat template."""
        text = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the tokens generated after the templated prompt.
        new_tokens = outputs[0, inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)


if __name__ == "__main__":
    import os

    # Prefer a locally fine-tuned checkpoint if one exists, otherwise load the base model.
    model_id = "checkpoint" if os.path.exists("checkpoint") else "AutomatedScientist/pynb-73m-base"
    model = CodeModel(model_id)

    result = model.generate("Write a Python function to calculate factorial")
    print("Generated code:")
    print(result)

    messages = [
        {"role": "system", "content": "You are a helpful coding assistant."},
        {"role": "user", "content": "Write a function to reverse a string"},
    ]
    response = model.chat(messages)
    print("\nChat response:")
    print(response)
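
    # Hedged sketch: the module docstring targets smolagents, so this shows one way
    # the same checkpoint could drive an agent. The smolagents import and the
    # TransformersModel / CodeAgent usage below are assumptions about that library's
    # API, not something this file defines; the block is skipped if it isn't installed.
    try:
        from smolagents import CodeAgent, TransformersModel

        agent_model = TransformersModel(model_id=model_id, max_new_tokens=256)
        agent = CodeAgent(tools=[], model=agent_model)
        print("\nAgent run:")
        print(agent.run("Write a function that checks whether a number is prime."))
    except ImportError:
        pass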