| import io |
|
|
| import requests |
| import torch |
| from PIL import Image |
| from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig |
|
|
|
|
def load_image_from_url(url, timeout=10):
    """Download an image over HTTP(S) and return it as a PIL image.

    Args:
        url (str): URL of the image to fetch.
        timeout (float): Seconds to wait for the server before aborting.
            ``requests.get`` has no default timeout and can otherwise hang
            indefinitely, so a bounded default is supplied here.

    Returns:
        PIL.Image.Image or None: The decoded image, or ``None`` if the
        request failed. Errors are printed rather than raised so the demo
        stays best-effort.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))
    except requests.exceptions.RequestException as e:
        print(f"Error loading image: {e}")
        return None
|
|
|
|
def do_generate(prompts, images, model, processor, generation_config):
    """Run batched multimodal generation and decode the results.

    Args:
        prompts (List[str]): List of prompt texts for entire batch
        images (List[str or PIL.Image]): Paths or PIL.Image of images for entire batch
        model (MllmForConditionalGeneration): MllmForConditionalGeneration
        processor (MllmProcessor): MllmProcessor
        generation_config (GenerationConfig): generation configurations

    Returns:
        outputs (List[str]): Generated responses for entire batch
    """
    batch = processor(texts=prompts, images=images)

    # Align every tensor with the model: cast float32 tensors to the model's
    # working dtype (e.g. bfloat16), then move everything onto its device.
    prepared = {}
    for key, value in batch.items():
        if value.dtype == torch.float:
            value = value.to(model.dtype)
        prepared[key] = value.to(model.device)

    # No gradients needed for pure inference.
    with torch.inference_mode():
        generated = model.generate(**prepared, generation_config=generation_config)

    return processor.batch_decode(generated, skip_special_tokens=True)
|
|
|
|
| if __name__ == "__main__": |
| |
| device = torch.device("cuda") |
| dtype = torch.bfloat16 |
| do_sample = False |
|
|
| |
| processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR-TNNLS", trust_remote_code=True) |
| generation_config = GenerationConfig.from_pretrained("Deepnoid/M4CXR-TNNLS") |
| model = AutoModelForCausalLM.from_pretrained( |
| "Deepnoid/M4CXR-TNNLS", |
| trust_remote_code=True, |
| torch_dtype=dtype, |
| device_map=device, |
| ) |
|
|
| |
| images = [ |
| load_image_from_url( |
| "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg" |
| ), |
| load_image_from_url( |
| "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg" |
| ), |
| ] |
|
|
| |
| questions = [ |
| "radiology image: <image> What is the view of this chest X-ray?", |
| "radiology image: <image> Provide a description of the findings in the radiology image.", |
| ] |
|
|
| |
| prompts = [] |
| for question in questions: |
| chats = [{"role": "user", "content": question}] |
| prompt = processor.apply_chat_template(chats, tokenize=False) |
| prompts.append(prompt) |
|
|
| |
| generation_config.do_sample = do_sample |
| outputs = do_generate(prompts, images, model, processor, generation_config) |
| print(outputs) |
|
|