| import torch |
|
|
| from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig |
| from interface import load_image_from_url, do_generate |
|
|
# CheXpert-style finding labels as one comma-separated string; substituted into
# the {findings} slot of the prompt templates below.
# NOTE(review): "pleural Effusion" keeps its odd capitalization — presumably the
# exact label string the model was trained with; confirm before normalizing.
findings = "enlarged cardiomediastinum, cardiomegaly, lung opacity, lung lesion, edema, consolidation, pneumonia, atelectasis, pneumothorax, pleural Effusion, pleural other, fracture, support devices"
|
|
# Prompt templates, one entry per supported task. Each value is a tuple of
# user-turn strings that are all appended to a single chat before generation
# (see do_generate_multi_turn). Placeholders: <image> is the literal image
# token expected by the processor; {findings}, {images}, {prior_images},
# {prior_report}, and {phrase} are filled via str.format by the caller.
templates: dict[str, tuple[str, ...]] = {
    # One image, two turns: classify findings, then describe them.
    "single-image": (
        "radiology image: <image> Which of the following findings are present in the radiology image? Findings: {findings}",
        "Based on the previous conversation, provide a description of the findings in the radiology image.",
    ),
    # Several images of one study; {images} receives the joined <image> tokens.
    "multi-image": (
        "radiology images: {images} Which of the following findings are present in the radiology images? Findings: {findings}",
        "Based on the previous conversation, provide a description of the findings in the radiology images.",
    ),
    # Prior study (images + report text) followed by current follow-up images.
    "multi-study": (
        "prior radiology images: {prior_images}, prior radiology report: {prior_report} follow-up images: {images}, The radiology studies are given in chronological order. Which of the following findings are present in the current follow-up radiology images? Findings: {findings}",
        "Based on the previous conversation, provide a description of the findings in the current follow-up radiology images.",
    ),
    # Single turn: request a bounding box for the described region.
    "visual-grounding": (
        "radiology image: <image> Provide the bounding box coordinate of the region this phrase describes: {phrase}",
    ),
    # Three turns: classify, describe, then summarize in one sentence.
    "summarize": (
        "radiology image: <image> Which of the following findings are present in the radiology image? Findings: {findings}",
        "Based on the previous conversation, provide a description of the findings in the radiology image.",
        "Summarize the description in one concise sentence.",
    ),
}
|
|
|
|
def do_generate_multi_turn(
    sequential_questions, images, model, processor, generation_config
):
    """Run one generation pass over a stacked multi-turn user prompt.

    Every string in *sequential_questions* is added to the chat as a user
    turn up front, the whole conversation is rendered once through the
    processor's chat template, and the single model response is appended
    as the assistant turn.

    Returns the chat history — a list of ``{"role", "content"}`` dicts —
    including the generated assistant message.
    """
    # All user turns go into the conversation before the model answers once.
    chats = [{"role": "user", "content": q} for q in sequential_questions]

    # do_generate takes a batch of prompts; we submit a batch of one.
    rendered = processor.apply_chat_template(chats, tokenize=False)
    outputs = do_generate([rendered], images, model, processor, generation_config)

    chats.append({"role": "assistant", "content": outputs[0]})
    return chats
|
|
|
|
if __name__ == "__main__":
    # Demo driver for the M4CXR chest X-ray VLM: runs five prompt styles
    # (single-image, multi-image, multi-study, visual grounding, summarize)
    # against one sample radiograph and prints each chat transcript.
    # Requires a CUDA device and network access to the Hugging Face Hub.
    device = torch.device("cuda")
    dtype = torch.bfloat16
    # (Removed dead local `do_sample = False` — it was never read; decoding
    # behavior comes entirely from the Hub-provided generation_config.)

    # Load the processor (tokenizer + image handling), the model's default
    # generation settings, and the weights from the Hub.
    processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR-TNNLS", trust_remote_code=True)
    generation_config = GenerationConfig.from_pretrained("Deepnoid/M4CXR-TNNLS")
    model = AutoModelForCausalLM.from_pretrained(
        "Deepnoid/M4CXR-TNNLS",
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map=device,
    )

    # Sample input: a normal posteroanterior chest radiograph (reused for
    # every demo below).
    image = load_image_from_url(
        "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
    )

    # --- single-image report generation -------------------------------------
    # The template already contains its one <image> placeholder.
    images = [image]
    questions = list(templates["single-image"])
    questions[0] = questions[0].format(findings=findings)
    chats = do_generate_multi_turn(
        questions, images, model, processor, generation_config
    )
    print("=" * 5, "single-image medical report generation", "=" * 5)
    print(chats)

    # --- multi-image report generation --------------------------------------
    # One <image> token per supplied image, joined into the {images} slot.
    images = [image, image, image]
    image_tokens = " ".join("<image>" for _ in images)
    questions = list(templates["multi-image"])
    questions[0] = questions[0].format(images=image_tokens, findings=findings)
    chats = do_generate_multi_turn(
        questions, images, model, processor, generation_config
    )
    print("=" * 5, "multi-image medical report generation", "=" * 5)
    print(chats)

    # --- multi-study report generation --------------------------------------
    # Prior study (images + report text) followed by the follow-up study.
    # The image list is prior images then follow-ups — presumably matching the
    # order of <image> tokens in the rendered prompt (confirm in do_generate).
    prior_images = [image, image]
    prior_image_tokens = " ".join("<image>" for _ in prior_images)

    prior_report = "The lungs are clear. There is no pneumothorax."

    follow_up_images = [image, image, image]
    follow_up_image_tokens = " ".join("<image>" for _ in follow_up_images)
    images = prior_images + follow_up_images

    questions = list(templates["multi-study"])
    questions[0] = questions[0].format(
        prior_images=prior_image_tokens,
        prior_report=prior_report,
        images=follow_up_image_tokens,
        findings=findings,
    )
    chats = do_generate_multi_turn(
        questions, images, model, processor, generation_config
    )
    print("=" * 5, "multi-study medical report generation", "=" * 5)
    print(chats)

    # --- visual grounding ----------------------------------------------------
    # Single turn: ask for the bounding box of a described region.
    images = [image]
    phrase = "right lower lobe"
    questions = list(templates["visual-grounding"])
    questions[0] = questions[0].format(phrase=phrase)
    chats = do_generate_multi_turn(
        questions, images, model, processor, generation_config
    )
    print("=" * 5, "visual grounding", "=" * 5)
    print(chats)

    # --- report generation + summarization -----------------------------------
    # Same flow as single-image, plus a third turn asking for a one-sentence
    # summary of the generated description.
    images = [image]
    questions = list(templates["summarize"])
    questions[0] = questions[0].format(findings=findings)
    chats = do_generate_multi_turn(
        questions, images, model, processor, generation_config
    )
    print("=" * 5, "medical report generation & summarize", "=" * 5)
    print(chats)
|