# M4CXR-TNNLS / interface.py
# Author: Jayden Park
# Commit: "Update README and add requirements.txt" (789ad27)
import io
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
def load_image_from_url(url, timeout=10):
    """Download an image over HTTP(S) and return it as a PIL image.

    Args:
        url (str): URL of the image to download.
        timeout (float): Seconds to wait for the server before giving up.
            Prevents the request from hanging indefinitely (the original
            call had no timeout, so a stalled server blocked forever).

    Returns:
        PIL.Image.Image or None: The decoded image, or ``None`` when the
        download fails — the error is printed, preserving the original
        best-effort behavior for callers.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Turn HTTP error statuses (4xx/5xx) into RequestException subclasses.
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))
    except requests.exceptions.RequestException as e:
        print(f"Error loading image: {e}")
        return None
def do_generate(prompts, images, model, processor, generation_config):
    """Run batched multimodal generation and decode the results.

    Args:
        prompts (List[str]): Prompt texts, one per batch element.
        images (List[str or PIL.Image]): Image paths or PIL images for the batch.
        model (MllmForConditionalGeneration): The multimodal model.
        processor (MllmProcessor): Processor that tokenizes text and
            preprocesses images, and decodes generated token ids.
        generation_config (GenerationConfig): Decoding configuration.

    Returns:
        List[str]: Generated responses, one per batch element.
    """
    # Tokenize prompts and preprocess images into model-ready tensors.
    batch = processor(texts=prompts, images=images)
    # Cast float tensors to the model's dtype, then move every tensor onto
    # the model's device — a single pass over the processor outputs.
    batch = {
        name: (tensor.to(model.dtype) if tensor.dtype == torch.float else tensor).to(
            model.device
        )
        for name, tensor in batch.items()
    }
    # Generate token ids for the whole batch without tracking gradients.
    with torch.inference_mode():
        token_ids = model.generate(**batch, generation_config=generation_config)
    # Convert the generated token ids back into text.
    return processor.batch_decode(token_ids, skip_special_tokens=True)
if __name__ == "__main__":
    # Runtime constants. Fall back to CPU when no GPU is present so the demo
    # still runs (slowly) on CPU-only machines instead of crashing at load.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16
    do_sample = False

    # Load processor, generation config, and model from the Hugging Face Hub.
    processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR-TNNLS", trust_remote_code=True)
    generation_config = GenerationConfig.from_pretrained("Deepnoid/M4CXR-TNNLS")
    model = AutoModelForCausalLM.from_pretrained(
        "Deepnoid/M4CXR-TNNLS",
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map=device,
    )

    # Prepare images: the same demo chest X-ray is used once per question.
    image_url = "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
    images = [load_image_from_url(image_url), load_image_from_url(image_url)]
    # Fail fast if any download failed; otherwise the processor would receive
    # None and raise a far less helpful error downstream.
    if any(image is None for image in images):
        raise RuntimeError("Failed to download one or more demo images.")

    # Separate question list; each prompt carries one <image> placeholder.
    questions = [
        "radiology image: <image> What is the view of this chest X-ray?",
        "radiology image: <image> Provide a description of the findings in the radiology image.",
    ]
    # Build prompts with the model's chat template.
    prompts = [
        processor.apply_chat_template([{"role": "user", "content": question}], tokenize=False)
        for question in questions
    ]

    # Generate and print the responses for the whole batch.
    generation_config.do_sample = do_sample
    outputs = do_generate(prompts, images, model, processor, generation_config)
    print(outputs)