import io

import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig


def load_image_from_url(url, timeout=30):
    """Download an image over HTTP(S) and decode it with PIL.

    Args:
        url (str): URL of the image to fetch.
        timeout (float): Seconds to wait for the server response.
            Without a timeout, ``requests.get`` can block forever on an
            unresponsive host, so a bounded default is supplied.

    Returns:
        PIL.Image.Image or None: The decoded image on success; ``None``
        on any request failure (the error is printed, not raised).
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))
    except requests.exceptions.RequestException as e:
        # Best-effort loader: report and signal failure with None rather
        # than propagating, matching how the caller consumes the result.
        print(f"Error loading image: {e}")
        return None


def do_generate(prompts, images, model, processor, generation_config):
    """The interface for generation

    Args:
        prompts (List[str]): List of prompt texts for entire batch
        images (List[str or PIL.Image]): Paths or PIL.Image of images for entire batch
        model (MllmForConditionalGeneration): MllmForConditionalGeneration
        processor (MllmProcessor): MllmProcessor
        generation_config (GenerationConfig): generation configurations

    Returns:
        outputs (List[str]): Generated responses for entire batch
    """
    # Tokenize text and preprocess images into model-ready tensors.
    inputs = processor(texts=prompts, images=images)

    # Cast float tensors to the model's working dtype (e.g. bfloat16).
    # NOTE(review): this matches torch.float (float32) only; float16/float64
    # inputs would pass through unchanged — presumably the processor always
    # emits float32, but confirm against the processor implementation.
    inputs = {
        k: v.to(model.dtype) if v.dtype == torch.float else v
        for k, v in inputs.items()
    }
    # Move everything onto the model's device before generation.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Batched decoding; inference_mode disables autograd bookkeeping.
    with torch.inference_mode():
        res = model.generate(**inputs, generation_config=generation_config)

    # Convert generated token ids back to strings, dropping special tokens.
    outputs = processor.batch_decode(res, skip_special_tokens=True)
    return outputs


if __name__ == "__main__":
    # Setup constants
    device = torch.device("cuda")
    dtype = torch.bfloat16
    do_sample = False

    # Load processor, generation config, and model from the Hugging Face hub.
    processor = AutoProcessor.from_pretrained(
        "Deepnoid/M4CXR-TNNLS", trust_remote_code=True
    )
    generation_config = GenerationConfig.from_pretrained("Deepnoid/M4CXR-TNNLS")
    model = AutoModelForCausalLM.from_pretrained(
        "Deepnoid/M4CXR-TNNLS",
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map=device,
    )

    # Prepare images (one per question; same sample X-ray used twice here).
    images = [
        load_image_from_url(
            "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
        ),
        load_image_from_url(
            "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
        ),
    ]

    # Separate question list
    questions = [
        "radiology image: What is the view of this chest X-ray?",
        "radiology image: Provide a description of the findings in the radiology image.",
    ]

    # Build prompts with the processor's chat template.
    prompts = []
    for question in questions:
        chats = [{"role": "user", "content": question}]
        prompt = processor.apply_chat_template(chats, tokenize=False)
        prompts.append(prompt)

    # Generate responses (greedy decoding since do_sample=False).
    generation_config.do_sample = do_sample
    outputs = do_generate(prompts, images, model, processor, generation_config)
    print(outputs)