| | import argparse |
| | import os |
| | import pprint |
| | import yaml |
| | from typing import Tuple, List, Optional, Dict |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.amp import autocast |
| | from torch.amp import GradScaler |
| | from tqdm import tqdm |
| | import random |
| | import torch.backends.cudnn as cudnn |
| | import cv2 |
| | from torch.utils.data import DataLoader |
| | import time |
| |
|
| | from src.wireseghr.model import WireSegHR |
| | from src.wireseghr.model.minmax import MinMaxLuminance |
| | from src.wireseghr.data.dataset import WireSegDataset |
| | from src.wireseghr.model.label_downsample import downsample_label_maxpool |
| | from src.wireseghr.data.sampler import BalancedPatchSampler |
| | from src.wireseghr.metrics import compute_metrics |
| | from infer import _coarse_forward, _tiled_fine_forward |
| | from pathlib import Path |
| |
|
| |
|
class SizeBatchSampler:
    """Batch sampler yielding index batches whose samples share an exact (H, W).

    Grouping indices by full resolution lets the DataLoader prefetch while
    preserving `_prepare_batch()`'s assumption that every item in a batch has
    the same full-image size. The trailing remainder of each size bin (fewer
    than `batch_size` items) is dropped.
    """

    def __init__(self, dset: WireSegDataset, batch_size: int):
        self.dset = dset
        self.batch_size = batch_size

        # Total number of full batches across all size bins (remainders dropped).
        self._len = sum(
            len(idxs) // self.batch_size for idxs in self.dset.size_bins.values()
        )

    def __len__(self) -> int:
        return self._len

    def __iter__(self):
        bins = self.dset.size_bins
        sizes = list(bins)
        # Shuffle bin visiting order and, within each bin, the sample order.
        random.shuffle(sizes)
        for size in sizes:
            indices = list(bins[size])
            random.shuffle(indices)

            n_full = len(indices) - (len(indices) % self.batch_size)
            for start in range(0, n_full, self.batch_size):
                yield indices[start : start + self.batch_size]
| |
|
| |
|
def collate_train(batch: List[Dict]):
    """Collate function that returns lists of numpy arrays to match existing pipeline."""
    imgs = []
    masks = []
    for sample in batch:
        imgs.append(sample["image"])
        masks.append(sample["mask"])
    return imgs, masks
| |
|
| |
|
def main():
    """Entry point: train WireSegHR from a YAML config.

    Per iteration: draw a size-homogeneous batch, run the coarse head on a
    downsampled 6-channel view, feed its conditioning map into the fine head
    on balanced patches, and optimize the sum of both cross-entropy losses
    under optional AMP. Periodically evaluates, checkpoints, and dumps test
    visualisations; runs a full test evaluation at the end.
    """
    parser = argparse.ArgumentParser(description="WireSegHR training (skeleton)")
    parser.add_argument(
        "--config", type=str, default="configs/default.yaml", help="Path to YAML config"
    )
    args = parser.parse_args()

    # Resolve a relative config path against the current working directory.
    cfg_path = args.config
    if not Path(cfg_path).is_absolute():
        cfg_path = str(Path.cwd() / cfg_path)

    with open(cfg_path, "r") as f:
        cfg = yaml.safe_load(f)

    print("[WireSegHR][train] Loaded config from:", cfg_path)
    pprint.pprint(cfg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[WireSegHR][train] Device: {device}")

    # --- Config unpacking (fails fast on missing/invalid keys). ---
    coarse_train = int(cfg["coarse"]["train_size"])
    coarse_test = int(cfg["coarse"]["test_size"])
    patch_size = int(cfg["fine"]["patch_size"])
    overlap = int(cfg["fine"]["overlap"])
    eval_patch_size = int(cfg["inference"]["fine_patch_size"])
    eval_cfg = cfg.get("eval", {})
    eval_fine_batch = int(eval_cfg.get("fine_batch", 16))
    assert eval_fine_batch >= 1
    eval_max_samples = int(eval_cfg.get("max_samples", 16))
    assert eval_max_samples >= 1
    iters = int(cfg["optim"]["iters"])
    batch_size = int(cfg["optim"]["batch_size"])
    base_lr = float(cfg["optim"]["lr"])
    weight_decay = float(cfg["optim"]["weight_decay"])
    power = float(cfg["optim"]["power"])  # poly LR-decay exponent
    precision = str(cfg["optim"].get("precision", "fp32")).lower()
    assert precision in ("fp32", "fp16", "bf16")

    # AMP only on CUDA and only for reduced-precision modes.
    amp_enabled = (device.type == "cuda") and (precision in ("fp16", "bf16"))

    if amp_enabled:
        # Guard against GPUs lacking hardware support for the chosen dtype.
        cc_major, cc_minor = torch.cuda.get_device_capability()
        if precision == "fp16":
            assert cc_major >= 7, (
                f"fp16 requires Volta (SM 7.0)+; current SM {cc_major}.{cc_minor}"
            )
        elif precision == "bf16":
            assert cc_major >= 8, (
                f"bf16 requires Ampere (SM 8.0)+; current SM {cc_major}.{cc_minor}"
            )
    # None for fp32 (autocast is disabled in that case anyway).
    amp_dtype = (
        torch.float16
        if precision == "fp16"
        else (torch.bfloat16 if precision == "bf16" else None)
    )

    seed = int(cfg.get("seed", 42))
    out_dir = cfg.get("out_dir", "runs/wireseghr")
    eval_interval = int(cfg["eval_interval"])
    ckpt_interval = int(cfg["ckpt_interval"])
    os.makedirs(out_dir, exist_ok=True)
    set_seed(seed)

    # --- Data: size-grouped train loader plus optional val/test sets. ---
    train_images = cfg["data"]["train_images"]
    train_masks = cfg["data"]["train_masks"]
    dset = WireSegDataset(train_images, train_masks, split="train")

    loader_cfg = cfg.get("loader", {})
    num_workers = int(loader_cfg.get("num_workers", 4))
    prefetch_factor = int(loader_cfg.get("prefetch_factor", 2))
    pin_memory = bool(loader_cfg.get("pin_memory", True))
    # persistent_workers is only valid with worker processes.
    persistent_workers = (
        bool(loader_cfg.get("persistent_workers", True)) if num_workers > 0 else False
    )
    batch_sampler = SizeBatchSampler(dset, batch_size)
    loader_kwargs = dict(
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers,
        collate_fn=collate_train,
    )
    if num_workers > 0:
        # prefetch_factor may only be passed when num_workers > 0.
        loader_kwargs["prefetch_factor"] = prefetch_factor
    train_loader = DataLoader(dset, **loader_kwargs)

    val_images = cfg["data"].get("val_images", None)
    val_masks = cfg["data"].get("val_masks", None)
    test_images = cfg["data"].get("test_images", None)
    test_masks = cfg["data"].get("test_masks", None)
    dset_val = (
        WireSegDataset(val_images, val_masks, split="val")
        if val_images and val_masks
        else None
    )
    dset_test = (
        WireSegDataset(test_images, test_masks, split="test")
        if test_images and test_masks
        else None
    )
    # Patch sampler biased toward wire pixels; min/max luminance is optional.
    sampler = BalancedPatchSampler(patch_size=patch_size, min_wire_ratio=0.01)
    minmax = (
        MinMaxLuminance(kernel=cfg["minmax"]["kernel"])
        if cfg["minmax"]["enable"]
        else None
    )

    # Inference-time settings reused by eval/visualisation helpers.
    prob_thresh = float(cfg["inference"]["prob_threshold"])
    mm_enable = bool(cfg["minmax"]["enable"])
    mm_kernel = int(cfg["minmax"]["kernel"])

    # --- Model & optimizer. ---
    pretrained_flag = bool(cfg.get("pretrained", False))
    # in_channels=6: RGB + luminance min + luminance max + conditioning channel.
    model = WireSegHR(
        backbone=cfg["backbone"], in_channels=6, pretrained=pretrained_flag
    )
    model = model.to(device)

    optim = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
    # GradScaler only matters for fp16; bf16/fp32 run it disabled (no-op).
    scaler = GradScaler("cuda", enabled=(device.type == "cuda" and precision == "fp16"))
    ce = nn.CrossEntropyLoss()

    # --- Optional resume from checkpoint. ---
    start_step = 0
    best_f1 = -1.0
    resume_path = cfg.get("resume", None)
    if resume_path and Path(resume_path).is_file():
        print(f"[WireSegHR][train] Resuming from {resume_path}")
        start_step, best_f1 = _load_checkpoint(
            resume_path, model, optim, scaler, device
        )

    # --- Training loop (iteration-based; loader recycled on exhaustion). ---
    model.train()
    step = start_step
    pbar = tqdm(total=iters - step, initial=0, desc="Train", ncols=100)
    data_iter = iter(train_loader)
    while step < iters:
        optim.zero_grad(set_to_none=True)
        try:
            imgs, masks = next(data_iter)
        except StopIteration:
            # Epoch boundary: restart the loader (re-shuffles via the sampler).
            data_iter = iter(train_loader)
            imgs, masks = next(data_iter)
        batch = _prepare_batch(
            imgs, masks, coarse_train, patch_size, sampler, minmax, device
        )

        # Coarse pass on the downsampled 6-channel input.
        with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_enabled):
            logits_coarse, cond_map = model.forward_coarse(
                batch["x_coarse"]
            )

        # Fine pass on patches, conditioned on the coarse feature map.
        B, _, hc4, wc4 = cond_map.shape
        x_fine = _build_fine_inputs(batch, cond_map, device)
        with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_enabled):
            logits_fine = model.forward_fine(x_fine)

        # Targets are max-pool downsampled to each head's logit resolution.
        y_coarse = _build_coarse_targets(batch["mask_full"], hc4, wc4, device)
        y_fine = _build_fine_targets(
            batch["mask_patches"],
            logits_fine.shape[2],
            logits_fine.shape[3],
            device,
        )

        loss_coarse = ce(logits_coarse, y_coarse)
        loss_fine = ce(logits_fine, y_fine)
        loss = loss_coarse + loss_fine

        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()

        # Poly LR schedule, recomputed every iteration from the current step.
        lr = base_lr * ((1.0 - float(step) / float(iters)) ** power)
        for pg in optim.param_groups:
            pg["lr"] = lr

        if step % 50 == 0:
            print(f"[Iter {step}/{iters}] lr={lr:.6e}")

        # --- Periodic eval / checkpointing / test visuals.
        # Note: 0 % eval_interval == 0, so this also fires on the first step.
        if (step % eval_interval == 0) and (dset_val is not None):
            # Free this step's activations before the memory-hungry eval pass.
            del (
                x_fine,
                logits_coarse,
                cond_map,
                logits_fine,
                y_coarse,
                y_fine,
                loss_coarse,
                loss_fine,
                loss,
            )
            torch.cuda.empty_cache()
            model.eval()
            print(
                f"[WireSegHR][train] Eval starting... val_size={len(dset_val)} max={eval_max_samples} patch={eval_patch_size} overlap={overlap} stride={eval_patch_size - overlap} fine_batch={eval_fine_batch}",
                flush=True,
            )
            val_stats = validate(
                model,
                dset_val,
                coarse_test,
                device,
                amp_enabled,
                amp_dtype,
                prob_thresh,
                mm_enable,
                mm_kernel,
                eval_patch_size,
                overlap,
                eval_fine_batch,
                eval_max_samples,
            )
            print(
                f"[Val @ {step}][Fine] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
            )
            print(
                f"[Val @ {step}][Coarse] IoU={val_stats['iou_coarse']:.4f} F1={val_stats['f1_coarse']:.4f} P={val_stats['precision_coarse']:.4f} R={val_stats['recall_coarse']:.4f}"
            )

            # Track the best fine-branch F1 and keep that checkpoint.
            if val_stats["f1"] > best_f1:
                best_f1 = val_stats["f1"]
                _save_checkpoint(
                    str(Path(out_dir) / "best.pt"),
                    step,
                    model,
                    optim,
                    scaler,
                    best_f1,
                )

            # Periodic numbered checkpoint (in addition to best.pt).
            if ckpt_interval > 0 and (step % ckpt_interval == 0):
                _save_checkpoint(
                    str(Path(out_dir) / f"ckpt_{step}.pt"),
                    step,
                    model,
                    optim,
                    scaler,
                    best_f1,
                )

            # Coarse-only prediction dumps for quick visual inspection.
            if dset_test is not None:
                save_test_visuals(
                    model,
                    dset_test,
                    coarse_test,
                    device,
                    str(Path(out_dir) / f"test_vis_{step}"),
                    amp_enabled,
                    mm_enable,
                    mm_kernel,
                    prob_thresh,
                    max_samples=8,
                )
            model.train()

        step += 1
        pbar.update(1)

    # Final checkpoint after the loop ends.
    _save_checkpoint(
        str(Path(out_dir) / f"ckpt_{iters}.pt"), step, model, optim, scaler, best_f1
    )

    # --- Full test-set evaluation + visualisations. ---
    if dset_test is not None:
        torch.cuda.empty_cache()
        model.eval()
        print(
            f"[WireSegHR][train] Final test starting... test_size={len(dset_test)} patch={eval_patch_size} overlap={overlap} stride={eval_patch_size - overlap} fine_batch={eval_fine_batch}",
            flush=True,
        )
        # max_images = len(dset_test): evaluate every test image.
        test_stats = validate(
            model,
            dset_test,
            coarse_test,
            device,
            amp_enabled,
            amp_dtype,
            prob_thresh,
            mm_enable,
            mm_kernel,
            eval_patch_size,
            overlap,
            eval_fine_batch,
            len(dset_test),
        )
        print(
            f"[Test Final][Fine] IoU={test_stats['iou']:.4f} F1={test_stats['f1']:.4f} P={test_stats['precision']:.4f} R={test_stats['recall']:.4f}"
        )
        print(
            f"[Test Final][Coarse] IoU={test_stats['iou_coarse']:.4f} F1={test_stats['f1_coarse']:.4f} P={test_stats['precision_coarse']:.4f} R={test_stats['recall_coarse']:.4f}"
        )

        final_out = Path(out_dir) / f"final_vis_{step}"
        final_out.mkdir(parents=True, exist_ok=True)

        # Persist final metrics next to the visual dumps.
        with open(final_out / "metrics.yaml", "w") as f:
            yaml.safe_dump({**test_stats, "step": step}, f, sort_keys=False)

        save_final_visuals(
            model,
            dset_test,
            coarse_test,
            device,
            str(final_out),
            amp_enabled,
            amp_dtype,
            prob_thresh,
            mm_enable,
            mm_kernel,
            eval_patch_size,
            overlap,
            eval_fine_batch,
        )
        model.train()

    print("[WireSegHR][train] Done.")
| |
|
| |
|
| |
|
def _prepare_batch(
    imgs: List[np.ndarray],
    masks: List[np.ndarray],
    coarse_train: int,
    patch_size: int,
    sampler: BalancedPatchSampler,
    minmax: Optional[MinMaxLuminance],
    device: torch.device,
):
    """Build the coarse inputs and fine-branch patches for one training batch.

    For each (image, mask) pair: computes luminance min/max maps, a resized
    6-channel coarse input (RGB + ymin + ymax + a zero placeholder for the
    conditioning channel), and one balanced patch crop for the fine branch.

    Args:
        imgs: uint8 HxWx3 images; all items must share the same (H, W).
        masks: HxW masks aligned with `imgs`.
        coarse_train: side length of the square coarse input.
        patch_size: side length of the square fine patch.
        sampler: picks a (y0, x0) patch origin biased toward wire pixels.
        minmax: when not None, enables the min/max luminance channels.
            NOTE(review): only used as an on/off flag here — the pooling
            kernel below is hard-coded to 6 and `minmax.kernel` is ignored;
            confirm this matches the inference-side min/max computation.
        device: target device for the returned tensors.

    Returns:
        Dict with the stacked coarse input, per-sample patch arrays
        (RGB/mask/ymin/ymax as numpy), patch origins, the shared full size,
        and the original full-resolution masks.
    """
    B = len(imgs)
    assert B == len(masks)

    # All samples in a batch share one full resolution (SizeBatchSampler
    # guarantees this); verify anyway.
    full_h = imgs[0].shape[0]
    full_w = imgs[0].shape[1]
    for im, m in zip(imgs, masks):
        assert im.shape[0] == full_h and im.shape[1] == full_w
        assert m.shape[0] == full_h and m.shape[1] == full_w

    xs_coarse = []
    patches_rgb = []
    patches_mask = []
    patches_min = []
    patches_max = []
    yx_list: List[tuple[int, int]] = []

    for img, mask in zip(imgs, masks):
        # Normalize to [0, 1] and move to a (1, 3, H, W) tensor on `device`.
        imgf = img.astype(np.float32) / 255.0
        t_img = (
            torch.from_numpy(np.transpose(imgf, (2, 0, 1))).unsqueeze(0).to(device)
        )

        # BT.601 luma from RGB.
        y_t = (
            0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
        )
        if minmax is not None:
            # 6x6 min/max pooling at stride 1; the asymmetric (2, 3)
            # replicate padding keeps output size equal to input for the
            # even kernel. Min-pool is implemented as a negated max-pool.
            y_p = F.pad(y_t, (2, 3, 2, 3), mode="replicate")
            y_max_full = F.max_pool2d(y_p, kernel_size=6, stride=1)
            y_min_full = -F.max_pool2d(-y_p, kernel_size=6, stride=1)
        else:
            # Disabled: both extra channels degenerate to the plain luma.
            y_min_full = y_t
            y_max_full = y_t

        # Coarse 6-channel input: resized RGB + ymin + ymax + a zeroed
        # slot where inference would insert the conditioning channel.
        rgb_coarse_t = F.interpolate(
            t_img,
            size=(coarse_train, coarse_train),
            mode="bilinear",
            align_corners=False,
        )[0]
        y_min_c_t = F.interpolate(
            y_min_full,
            size=(coarse_train, coarse_train),
            mode="bilinear",
            align_corners=False,
        )[0]
        y_max_c_t = F.interpolate(
            y_max_full,
            size=(coarse_train, coarse_train),
            mode="bilinear",
            align_corners=False,
        )[0]
        zeros_coarse = torch.zeros(1, coarse_train, coarse_train, device=device)
        c_t = torch.cat(
            [rgb_coarse_t, y_min_c_t, y_max_c_t, zeros_coarse], dim=0
        )
        xs_coarse.append(c_t)

        # Balanced patch crop for the fine branch.
        # NOTE(review): assumes full_h/full_w >= patch_size — confirm the
        # dataset guarantees this, otherwise the crops come out short.
        y0, x0 = sampler.sample(imgf, mask)
        patch_rgb = imgf[y0 : y0 + patch_size, x0 : x0 + patch_size, :]
        patch_mask = mask[y0 : y0 + patch_size, x0 : x0 + patch_size]
        patches_rgb.append(patch_rgb)
        patches_mask.append(patch_mask)
        # Crop the min/max maps to the same window; back to numpy so the
        # batch dict stays framework-agnostic for downstream helpers.
        ymin_patch = (
            y_min_full[0, 0, y0 : y0 + patch_size, x0 : x0 + patch_size]
            .detach()
            .cpu()
            .numpy()
        )
        ymax_patch = (
            y_max_full[0, 0, y0 : y0 + patch_size, x0 : x0 + patch_size]
            .detach()
            .cpu()
            .numpy()
        )
        patches_min.append(ymin_patch)
        patches_max.append(ymax_patch)
        yx_list.append((y0, x0))

    # (B, 6, coarse_train, coarse_train)
    x_coarse = torch.stack(xs_coarse, dim=0)

    return {
        "x_coarse": x_coarse,
        "full_h": full_h,
        "full_w": full_w,
        "rgb_patches": patches_rgb,
        "mask_patches": patches_mask,
        "ymin_patches": patches_min,
        "ymax_patches": patches_max,
        "patch_yx": yx_list,
        "mask_full": masks,
    }
| |
|
| |
|
def _build_fine_inputs(
    batch, cond_map: torch.Tensor, device: torch.device
) -> torch.Tensor:
    """Assemble the fine branch's 6-channel input for every patch in `batch`.

    Each patch input is RGB (3) + luminance min/max (2) + the coarse
    conditioning map cropped to the patch's footprint and bilinearly resized
    to patch resolution (1).

    Args:
        batch: dict produced by `_prepare_batch()`.
        cond_map: (B, C, hc4, wc4) conditioning features from the coarse
            head; its spatial grid maps linearly onto the full image.
            NOTE(review): C is assumed to be 1 — only then does the final
            `squeeze(1)` + concat yield 6 channels; confirm against
            `WireSegHR.forward_coarse`.
        device: target device for the stacked result.

    Returns:
        (B, 6, P, P) tensor where P is the patch side length.
    """
    B = cond_map.shape[0]
    P = batch["rgb_patches"][0].shape[0]
    full_h, full_w = batch["full_h"], batch["full_w"]
    hc4, wc4 = cond_map.shape[2], cond_map.shape[3]

    xs: List[torch.Tensor] = []
    for i in range(B):
        rgb = batch["rgb_patches"][i]
        ymin = batch["ymin_patches"][i]
        ymax = batch["ymax_patches"][i]
        y0, x0 = batch["patch_yx"][i]

        # Map the patch's full-res rectangle onto the cond_map grid:
        # floor-divide the start, ceil-divide the end (via +den-1), so the
        # crop fully covers the patch footprint.
        y1, x1 = y0 + P, x0 + P
        y0c = (y0 * hc4) // full_h
        y1c = ((y1 * hc4) + full_h - 1) // full_h
        x0c = (x0 * wc4) // full_w
        x1c = ((x1 * wc4) + full_w - 1) // full_w
        cond_sub = cond_map[i : i + 1, :, y0c:y1c, x0c:x1c].float()
        # Resize the cropped conditioning window up to patch resolution.
        cond_patch = F.interpolate(
            cond_sub, size=(P, P), mode="bilinear", align_corners=False
        ).squeeze(1)

        # Numpy patch arrays -> CHW float tensors on `device`.
        rgb_t = (
            torch.from_numpy(np.transpose(rgb, (2, 0, 1))).to(device).float()
        )
        ymin_t = torch.from_numpy(ymin)[None, ...].to(device).float()
        ymax_t = torch.from_numpy(ymax)[None, ...].to(device).float()
        x = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch], dim=0)
        xs.append(x)
    x_fine = torch.stack(xs, dim=0)
    return x_fine
| |
|
| |
|
def _build_coarse_targets(
    masks: List[np.ndarray], out_h: int, out_w: int, device: torch.device
) -> torch.Tensor:
    """Downsample full-resolution masks to (out_h, out_w) int64 CE targets.

    Uses max-pool label downsampling so thin structures survive the
    resolution drop. Returns a (B, out_h, out_w) tensor on `device`.
    """
    targets = [
        torch.from_numpy(downsample_label_maxpool(m, out_h, out_w).astype(np.int64))
        for m in masks
    ]
    return torch.stack(targets, dim=0).to(device)
| |
|
| |
|
def _build_fine_targets(
    mask_patches: List[np.ndarray], out_h: int, out_w: int, device: torch.device
) -> torch.Tensor:
    """Downsample cropped mask patches to the fine head's logit resolution.

    The transform is identical to `_build_coarse_targets` (max-pool label
    downsample, int64, stacked on `device`), so delegate to it instead of
    duplicating the loop.
    """
    return _build_coarse_targets(mask_patches, out_h, out_w, device)
| |
|
| |
|
def set_seed(seed: int):
    """Seed Python, NumPy and torch RNGs for best-effort reproducibility.

    cuDNN stays in benchmark (non-deterministic) mode: autotuned kernels
    are preferred over bitwise-repeatable results.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Throughput over determinism.
    cudnn.benchmark = True
    cudnn.deterministic = False
| |
|
| |
|
| | def _save_checkpoint( |
| | path: str, |
| | step: int, |
| | model: nn.Module, |
| | optim: torch.optim.Optimizer, |
| | scaler: GradScaler, |
| | best_f1: float, |
| | ): |
| | Path(path).parent.mkdir(parents=True, exist_ok=True) |
| | state = { |
| | "step": step, |
| | "model": model.state_dict(), |
| | "optim": optim.state_dict(), |
| | "scaler": scaler.state_dict(), |
| | "best_f1": best_f1, |
| | } |
| | torch.save(state, path) |
| | print(f"[WireSegHR][train] Saved checkpoint: {path}") |
| |
|
| |
|
| | def _load_checkpoint( |
| | path: str, |
| | model: nn.Module, |
| | optim: torch.optim.Optimizer, |
| | scaler: GradScaler, |
| | device: torch.device, |
| | ) -> Tuple[int, float]: |
| | ckpt = torch.load(path, map_location=device) |
| | model.load_state_dict(ckpt["model"]) |
| | optim.load_state_dict(ckpt["optim"]) |
| | try: |
| | scaler.load_state_dict(ckpt["scaler"]) |
| | except Exception: |
| | pass |
| | step = int(ckpt.get("step", 0)) |
| | best_f1 = float(ckpt.get("best_f1", -1.0)) |
| | return step, best_f1 |
| |
|
| |
|
@torch.no_grad()
def validate(
    model: WireSegHR,
    dset_val: WireSegDataset,
    coarse_size: int,
    device: torch.device,
    amp_flag: bool,
    amp_dtype,
    prob_thresh: float,
    minmax_enable: bool,
    minmax_kernel: int,
    fine_patch_size: int,
    fine_overlap: int,
    fine_batch: int,
    max_images: int,
) -> Dict[str, float]:
    """Evaluate coarse and tiled-fine predictions on up to `max_images` samples.

    Draws a random subset of `dset_val`, runs the coarse pass and the
    overlap-tiled fine pass per image, binarizes probabilities at
    `prob_thresh`, and averages IoU/F1/precision/recall over the subset.

    Returns:
        Dict with fine metrics under "iou"/"f1"/"precision"/"recall" and
        coarse metrics under the same keys suffixed "_coarse".
    """
    model = model.to(device)
    metrics_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
    coarse_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
    n = 0
    t0 = time.time()
    total_tiles = 0
    # Random subset via the global RNG (seeded in set_seed()).
    target_n = min(len(dset_val), max_images)
    idxs = random.sample(range(len(dset_val)), k=target_n)
    print(
        f"[Eval] Started: N={target_n}/{len(dset_val)} coarse={coarse_size} patch={fine_patch_size} overlap={fine_overlap} stride={fine_patch_size - fine_overlap} fine_batch={fine_batch}",
        flush=True,
    )
    for j, i in enumerate(idxs):
        # Progress heartbeat every other image.
        if (j % 2) == 0:
            print(f"[Eval] Running... {j}/{target_n}", flush=True)
        item = dset_val[i]
        img = item["image"].astype(np.float32) / 255.0
        mask = item["mask"].astype(np.uint8)
        H, W = mask.shape

        # Coarse pass: full-size probability map + conditioning features.
        prob_up, cond_map, t_img, y_min_full, y_max_full = _coarse_forward(
            model,
            img,
            coarse_size,
            minmax_enable,
            int(minmax_kernel),
            device,
            amp_flag,
            amp_dtype,
        )

        # Coarse metrics on the binarized upsampled probability.
        pred_coarse = (prob_up > prob_thresh).to(torch.uint8).cpu().numpy()
        m_c = compute_metrics(pred_coarse, mask)
        for k in coarse_sum:
            coarse_sum[k] += m_c[k]

        # Fine pass over overlapping tiles at full resolution.
        prob_full = _tiled_fine_forward(
            model,
            t_img,
            cond_map,
            y_min_full,
            y_max_full,
            int(fine_patch_size),
            int(fine_overlap),
            int(fine_batch),
            device,
            amp_flag,
            amp_dtype,
        )

        # Re-derive the tile grid purely for throughput reporting
        # (assumes the same gridding as _tiled_fine_forward).
        # NOTE(review): ys[-1]/xs[-1] raise IndexError when H < P or
        # W < P — confirm eval images are always at least patch-sized.
        P = int(fine_patch_size)
        stride = P - int(fine_overlap)
        ys = list(range(0, H - P + 1, stride))
        if ys[-1] != (H - P):
            ys.append(H - P)
        xs = list(range(0, W - P + 1, stride))
        if xs[-1] != (W - P):
            xs.append(W - P)
        total_tiles += len(ys) * len(xs)
        pred_fine = (prob_full > prob_thresh).to(torch.uint8).cpu().numpy()
        m_f = compute_metrics(pred_fine, mask)
        for k in metrics_sum:
            metrics_sum[k] += m_f[k]
        n += 1
    # Average per-image metrics (macro average over the subset).
    if n > 0:
        for k in metrics_sum:
            metrics_sum[k] /= n
        for k in coarse_sum:
            coarse_sum[k] /= n
    dt = time.time() - t0
    tp_img = (n / dt) if dt > 0 else 0.0
    tp_tile = (total_tiles / dt) if dt > 0 else 0.0
    print(
        f"[Eval] Done in {dt:.2f}s | imgs={n}, tiles={total_tiles}, imgs/s={tp_img:.2f}, tiles/s={tp_tile:.2f}",
        flush=True,
    )
    out = {k: v for k, v in metrics_sum.items()}
    out.update(
        {
            "iou_coarse": coarse_sum["iou"],
            "f1_coarse": coarse_sum["f1"],
            "precision_coarse": coarse_sum["precision"],
            "recall_coarse": coarse_sum["recall"],
        }
    )
    return out
| |
|
| |
|
@torch.no_grad()
def save_test_visuals(
    model: WireSegHR,
    dset_test: WireSegDataset,
    coarse_size: int,
    device: torch.device,
    out_dir: str,
    amp_flag: bool,
    minmax_enable: bool,
    minmax_kernel: int,
    prob_thresh: float,
    max_samples: int = 8,
):
    """Dump coarse-only predictions for the first few test images.

    Writes `<i>_input.jpg` (BGR input) and `<i>_pred.png` (coarse
    probability binarized at `prob_thresh`, values 0/255) under `out_dir`.
    Coarse pass only — no fine tiling.
    """
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    count = min(max_samples, len(dset_test))
    for idx in range(count):
        sample = dset_test[idx]
        img = sample["image"].astype(np.float32) / 255.0
        # Coarse pass; dtype=None here (matches the original call).
        prob_up, _cond_map, _t_img, _ymin, _ymax = _coarse_forward(
            model,
            img,
            int(coarse_size),
            bool(minmax_enable),
            int(minmax_kernel),
            device,
            bool(amp_flag),
            None,
        )
        binarized = ((prob_up > prob_thresh).to(torch.uint8) * 255).cpu().numpy()

        bgr = (img[..., ::-1] * 255.0).astype(np.uint8)  # RGB -> BGR for cv2
        cv2.imwrite(str(out_path / f"{idx:03d}_input.jpg"), bgr)
        cv2.imwrite(str(out_path / f"{idx:03d}_pred.png"), binarized)
| |
|
| |
|
@torch.no_grad()
def save_final_visuals(
    model: WireSegHR,
    dset_test: WireSegDataset,
    coarse_size: int,
    device: torch.device,
    out_dir: str,
    amp_flag: bool,
    amp_dtype,
    prob_thresh: float,
    minmax_enable: bool,
    minmax_kernel: int,
    fine_patch_size: int,
    fine_overlap: int,
    fine_batch: int,
):
    """Dump coarse AND tiled-fine predictions for every test image.

    For each image writes `<i>_input.jpg`, `<i>_coarse_pred.png` and
    `<i>_fine_pred.png` under `out_dir` (both masks binarized at
    `prob_thresh`, values 0/255).
    """
    root = Path(out_dir)
    root.mkdir(parents=True, exist_ok=True)
    for idx in range(len(dset_test)):
        sample = dset_test[idx]
        img = sample["image"].astype(np.float32) / 255.0

        # Coarse pass: low-res probability upsampled to full size plus the
        # conditioning features needed by the fine pass.
        prob_up, cond_map, t_img, y_min_full, y_max_full = _coarse_forward(
            model,
            img,
            int(coarse_size),
            bool(minmax_enable),
            int(minmax_kernel),
            device,
            bool(amp_flag),
            amp_dtype,
        )
        coarse_mask = ((prob_up > prob_thresh).to(torch.uint8) * 255).cpu().numpy()

        # Fine pass: overlap-tiled full-resolution refinement.
        prob_full = _tiled_fine_forward(
            model,
            t_img,
            cond_map,
            y_min_full,
            y_max_full,
            int(fine_patch_size),
            int(fine_overlap),
            int(fine_batch),
            device,
            bool(amp_flag),
            amp_dtype,
        )
        fine_mask = ((prob_full > prob_thresh).to(torch.uint8) * 255).cpu().numpy()

        stem = f"{idx:03d}"
        bgr = (img[..., ::-1] * 255.0).astype(np.uint8)  # RGB -> BGR for cv2
        cv2.imwrite(str(root / f"{stem}_input.jpg"), bgr)
        cv2.imwrite(str(root / f"{stem}_coarse_pred.png"), coarse_mask)
        cv2.imwrite(str(root / f"{stem}_fine_pred.png"), fine_mask)
| |
|
| |
|
# Script entry point.
if __name__ == "__main__":
    main()
| |
|