# M4CXR-TNNLS / task_examples.py
# Author: Jayden Park
# (commit 789ad27 — "Update README and add requirements.txt")
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from interface import load_image_from_url, do_generate
# Comma-separated label set the model is asked to classify against (the 13
# CheXpert-style chest X-ray findings plus support devices).
# NOTE(review): "pleural Effusion" is capitalized unlike every other label —
# presumably this matches the model's training prompts; confirm before "fixing".
findings = "enlarged cardiomediastinum, cardiomegaly, lung opacity, lung lesion, edema, consolidation, pneumonia, atelectasis, pneumothorax, pleural Effusion, pleural other, fracture, support devices"

# Per-task prompt templates. Each value is a tuple of user turns that are sent
# to the model sequentially (chain-of-thought style: classify findings first,
# then describe them). "<image>" marks where the processor splices in an image;
# {findings}, {images}, {prior_images}, {prior_report}, {phrase} are filled via
# str.format before generation.
templates = {
    "single-image": (
        "radiology image: <image> Which of the following findings are present in the radiology image? Findings: {findings}",
        "Based on the previous conversation, provide a description of the findings in the radiology image.",
    ),
    "multi-image": (
        "radiology images: {images} Which of the following findings are present in the radiology images? Findings: {findings}",
        "Based on the previous conversation, provide a description of the findings in the radiology images.",
    ),
    "multi-study": (
        "prior radiology images: {prior_images}, prior radiology report: {prior_report} follow-up images: {images}, The radiology studies are given in chronological order. Which of the following findings are present in the current follow-up radiology images? Findings: {findings}",
        "Based on the previous conversation, provide a description of the findings in the current follow-up radiology images.",
    ),
    # Single-turn task: trailing comma keeps this a 1-tuple.
    "visual-grounding": (
        "radiology image: <image> Provide the bounding box coordinate of the region this phrase describes: {phrase}",
    ),
    # Three turns: classify, describe, then condense to one sentence.
    "summarize": (
        "radiology image: <image> Which of the following findings are present in the radiology image? Findings: {findings}",
        "Based on the previous conversation, provide a description of the findings in the radiology image.",
        "Summarize the description in one concise sentence.",
    ),
}
def do_generate_multi_turn(
    sequential_questions, images, model, processor, generation_config
):
    """Drive a multi-turn chat with the model.

    Each question is asked in order; the full conversation so far (user turns
    plus the model's earlier answers) is re-rendered into the prompt for every
    generation step, so later turns see the previous context.

    Returns the complete chat history as a list of
    ``{"role": ..., "content": ...}`` dicts.
    """
    history = []
    for user_text in sequential_questions:
        history.append({"role": "user", "content": user_text})
        # Render the whole conversation into one prompt string and generate
        # with a mini-batch of exactly one prompt.
        rendered = processor.apply_chat_template(history, tokenize=False)
        replies = do_generate(
            [rendered], images, model, processor, generation_config
        )
        history.append({"role": "assistant", "content": replies[0]})
    return history
if __name__ == "__main__":
    # Setup constants.
    device = torch.device("cuda")
    dtype = torch.bfloat16
    # NOTE(review): the original set `do_sample = False` here but never used
    # it; sampling behavior comes entirely from the loaded GenerationConfig.

    # Load processor, generation config, and model (bfloat16, placed on GPU).
    processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR-TNNLS", trust_remote_code=True)
    generation_config = GenerationConfig.from_pretrained("Deepnoid/M4CXR-TNNLS")
    model = AutoModelForCausalLM.from_pretrained(
        "Deepnoid/M4CXR-TNNLS",
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map=device,
    )

    # Example chest X-ray image, reused for every task below.
    image = load_image_from_url(
        "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
    )

    def run_task(title, questions, images):
        # Run one multi-turn task and print its banner plus chat transcript.
        chats = do_generate_multi_turn(
            questions, images, model, processor, generation_config
        )
        print("=" * 5, title, "=" * 5)
        print(chats)

    # Task 1: single-image medical report generation (CoT prompting).
    questions = list(templates["single-image"])
    questions[0] = questions[0].format(findings=findings)
    run_task("single-image medical report generation", questions, [image])

    # Task 2: multi-image report generation — one <image> token per image.
    images = [image, image, image]
    image_tokens = " ".join("<image>" for _ in images)
    questions = list(templates["multi-image"])
    questions[0] = questions[0].format(images=image_tokens, findings=findings)
    run_task("multi-image medical report generation", questions, images)

    # Task 3: multi-study report generation — prior study (images + report)
    # followed by the current follow-up images, in chronological order.
    prior_images = [image, image]
    prior_image_tokens = " ".join("<image>" for _ in prior_images)
    prior_report = "The lungs are clear. There is no pneumothorax."
    follow_up_images = [image, image, image]
    follow_up_image_tokens = " ".join("<image>" for _ in follow_up_images)
    questions = list(templates["multi-study"])
    questions[0] = questions[0].format(
        prior_images=prior_image_tokens,
        prior_report=prior_report,
        images=follow_up_image_tokens,
        findings=findings,
    )
    # Image list order must match the template: priors first, then follow-ups.
    run_task(
        "multi-study medical report generation",
        questions,
        prior_images + follow_up_images,
    )

    # Task 4: visual grounding — bounding box for a described region.
    questions = list(templates["visual-grounding"])
    questions[0] = questions[0].format(phrase="right lower lobe")
    run_task("visual grounding", questions, [image])

    # Task 5: report generation followed by a one-sentence summary.
    questions = list(templates["summarize"])
    questions[0] = questions[0].format(findings=findings)
    run_task("medical report generation & summarize", questions, [image])