"""Inference demo for the M4CXR chest-X-ray VLM (Deepnoid/M4CXR-TNNLS).

Runs five tasks against one example radiograph using chain-of-thought (CoT)
prompting: single-image report generation, multi-image report generation,
multi-study (prior + follow-up) report generation, visual grounding, and
report generation followed by one-sentence summarization.
"""
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

from interface import load_image_from_url, do_generate

# Label set queried in the first (classification) turn of each CoT prompt.
# NOTE(review): "pleural Effusion" is capitalized unlike the other labels —
# confirm against the model card before normalizing; left byte-identical here.
findings = "enlarged cardiomediastinum, cardiomegaly, lung opacity, lung lesion, edema, consolidation, pneumonia, atelectasis, pneumothorax, pleural Effusion, pleural other, fracture, support devices"

# CoT prompt templates: one tuple of sequential user turns per task.
# Turn 1 asks for findings classification; later turns build on the prior
# conversation (report generation, then optionally summarization).
templates = {
    "single-image": (
        "radiology image: Which of the following findings are present in the radiology image? Findings: {findings}",
        "Based on the previous conversation, provide a description of the findings in the radiology image.",
    ),
    "multi-image": (
        "radiology images: {images} Which of the following findings are present in the radiology images? Findings: {findings}",
        "Based on the previous conversation, provide a description of the findings in the radiology images.",
    ),
    "multi-study": (
        "prior radiology images: {prior_images}, prior radiology report: {prior_report} follow-up images: {images}, The radiology studies are given in chronological order. Which of the following findings are present in the current follow-up radiology images? Findings: {findings}",
        "Based on the previous conversation, provide a description of the findings in the current follow-up radiology images.",
    ),
    "visual-grounding": (
        "radiology image: Provide the bounding box coordinate of the region this phrase describes: {phrase}",
    ),
    "summarize": (
        "radiology image: Which of the following findings are present in the radiology image? Findings: {findings}",
        "Based on the previous conversation, provide a description of the findings in the radiology image.",
        "Summarize the description in one concise sentence.",
    ),
}


def do_generate_multi_turn(
    sequential_questions, images, model, processor, generation_config
):
    """Run a multi-turn conversation, generating one assistant reply per question.

    Each question is appended to the chat history as a user turn, the full
    history so far is rendered through the processor's chat template, and the
    model's reply is appended back as an assistant turn before the next
    question is asked — so later turns condition on earlier answers (CoT).

    Args:
        sequential_questions: ordered user turns (already `.format()`-ed).
        images: images passed through to ``do_generate`` for every turn.
        model / processor / generation_config: HF objects forwarded as-is.

    Returns:
        The complete chat history: alternating user/assistant message dicts.
    """
    chats = []
    for question in sequential_questions:
        chats.append({"role": "user", "content": question})
        # mini-batch size 1: the whole conversation so far becomes one prompt
        prompts = [processor.apply_chat_template(chats, tokenize=False)]
        outputs = do_generate(prompts, images, model, processor, generation_config)
        chats.append({"role": "assistant", "content": outputs[0]})
    return chats


if __name__ == "__main__":
    # Setup constants
    device = torch.device("cuda")
    dtype = torch.bfloat16
    # NOTE(review): `do_sample` is assigned but never used — sampling behavior
    # comes from `generation_config` below; confirm before deleting.
    do_sample = False

    # Load processor and model
    processor = AutoProcessor.from_pretrained(
        "Deepnoid/M4CXR-TNNLS", trust_remote_code=True
    )
    generation_config = GenerationConfig.from_pretrained("Deepnoid/M4CXR-TNNLS")
    model = AutoModelForCausalLM.from_pretrained(
        "Deepnoid/M4CXR-TNNLS",
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map=device,
    )

    # example image (normal PA chest radiograph)
    image = load_image_from_url(
        "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
    )

    # Task 1: single-image medical report generation (CoT prompting)
    images = [image]
    questions = list(templates["single-image"])
    questions[0] = questions[0].format(findings=findings)
    chats = do_generate_multi_turn(
        questions, images, model, processor, generation_config
    )
    print("=" * 5, "single-image medical report generation", "=" * 5)
    print(chats)

    # Task 2: multi-image medical report generation (CoT prompting)
    # NOTE(review): joining empty strings yields only spaces for "{images}" —
    # an image-placeholder token was presumably stripped from this snippet;
    # verify the expected token against `interface.do_generate` / the model card.
    images = [image, image, image]
    image_tokens = " ".join("" for _ in images)
    questions = list(templates["multi-image"])
    questions[0] = questions[0].format(images=image_tokens, findings=findings)
    chats = do_generate_multi_turn(
        questions, images, model, processor, generation_config
    )
    print("=" * 5, "multi-image medical report generation", "=" * 5)
    print(chats)

    # Task 3: multi-study medical report generation (CoT prompting)
    # Prior study (with its report) plus follow-up images, concatenated in
    # chronological order; the prompt asks about the follow-up study only.
    prior_images = [image, image]
    prior_image_tokens = " ".join("" for _ in prior_images)
    prior_report = "The lungs are clear. There is no pneumothorax."
    follow_up_images = [image, image, image]
    follow_up_image_tokens = " ".join("" for _ in follow_up_images)
    images = prior_images + follow_up_images
    questions = list(templates["multi-study"])
    questions[0] = questions[0].format(
        prior_images=prior_image_tokens,
        prior_report=prior_report,
        images=follow_up_image_tokens,
        findings=findings,
    )
    chats = do_generate_multi_turn(
        questions, images, model, processor, generation_config
    )
    print("=" * 5, "multi-study medical report generation", "=" * 5)
    print(chats)

    # Task 4: visual grounding (single turn: bounding box for a phrase)
    images = [image]
    phrase = "right lower lobe"
    questions = list(templates["visual-grounding"])
    questions[0] = questions[0].format(phrase=phrase)
    chats = do_generate_multi_turn(
        questions, images, model, processor, generation_config
    )
    print("=" * 5, "visual grounding", "=" * 5)
    print(chats)

    # Task 5: summarize (report generation, then one-sentence summary)
    images = [image]
    questions = list(templates["summarize"])
    questions[0] = questions[0].format(findings=findings)
    chats = do_generate_multi_turn(
        questions, images, model, processor, generation_config
    )
    print("=" * 5, "medical report generation & summarize", "=" * 5)
    print(chats)