| import os |
| import numpy as np |
| import torch |
| import matplotlib.pyplot as plt |
| from PIL import Image |
| import cv2 |
|
|
# Pick the compute device, falling back to MPS and then CPU so that `device`
# is always defined before it is inspected below.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

if device.type == "cuda":
    # Enter a bfloat16 autocast context for the remainder of the process;
    # it is deliberately never exited (script-style global setting).
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # TF32 matmul/cuDNN kernels need compute capability >= 8 (Ampere+).
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

# Fixed seed so the random mask colors drawn with np.random are reproducible.
np.random.seed(3)
|
|
def show_anns(anns, borders=True):
    """Overlay generated masks on the current matplotlib axes.

    Masks are painted in descending order of ``'area'``, so later (smaller)
    masks overwrite earlier (larger) ones where they overlap. Each mask gets
    a random RGB color from ``np.random`` at alpha 0.5. When ``borders`` is
    true, the simplified external contours of each mask are also drawn in
    blue at alpha 0.4.

    Args:
        anns: Sequence of annotation dicts. Each needs an ``'area'`` entry
            (used for sorting) and a ``'segmentation'`` mask whose first two
            dimensions are (height, width).
        borders: Whether to draw mask outlines.

    Returns:
        None. Draws on ``plt.gca()``; returns immediately if ``anns`` is
        empty.
    """
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=lambda x: x['area'], reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    height, width = sorted_anns[0]['segmentation'].shape[:2]
    # RGBA canvas: start fully transparent, masks paint over it.
    img = np.ones((height, width, 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        # Coerce to bool so a 0/1 integer mask selects pixels rather than
        # being interpreted as integer row indices by fancy indexing.
        m = np.asarray(ann['segmentation'], dtype=bool)
        color_mask = np.concatenate([np.random.random(3), [0.5]])
        img[m] = color_mask
        if borders:
            # cv2 comes from the file-level import.
            contours, _ = cv2.findContours(
                m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
            )
            # Simplify each contour polygon before drawing.
            contours = [
                cv2.approxPolyDP(contour, epsilon=0.01, closed=True)
                for contour in contours
            ]
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1)

    ax.imshow(img)
|
|
|
|
| from sam2.build_sam import build_sam2 |
| from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator |
|
|
# Build the SAM2 model from the `sam2_hiera_b+` config and the
# `sam2_hiera_base_plus` checkpoint, then wrap it in the automatic mask
# generator. The checkpoint path is relative to the current working directory.
model_cfg = "configs/sam2/sam2_hiera_b+.yaml"
sam2_checkpoint = "./checkpoints/sam2_hiera_base_plus.pt"
sam2 = build_sam2(
    model_cfg,
    sam2_checkpoint,
    device=device,
    apply_postprocessing=False,
)
mask_generator = SAM2AutomaticMaskGenerator(sam2)
|
|
|
|
# Load the input image as an RGB array; the context manager closes the file.
with Image.open('/home/yuqian_fu/Projects/sam2/DSCF0669.JPG') as pil_image:
    image = np.array(pil_image.convert("RGB"))
masks = mask_generator.generate(image)

save_dir = "/home/yuqian_fu/Projects/sam2/results"
i = 8  # stem of the output file name: <save_dir>/<i>.png
# NOTE(review): index 1 is arbitrary; the order of `masks` comes from the
# generator -- confirm this is the intended mask.
mask_index = 1
if len(masks) <= mask_index:
    raise RuntimeError(
        f"Expected at least {mask_index + 1} masks, got {len(masks)}"
    )

# cv2.imwrite does not create directories; it just returns False.
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, str(i) + ".png")
ann = masks[mask_index]["segmentation"]
# Map the mask to a 0/255 single-channel image so the PNG is viewable.
binary_mask = (ann.astype(np.uint8)) * 255
if not cv2.imwrite(save_path, binary_mask):
    raise OSError(f"cv2.imwrite failed to write {save_path}")
| |
|
|
|
|