import os
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import argparse
import torch.nn.functional as F
import csv
import time
from skimage import io, transform
import matplotlib.patches as mpatches
from segment_anything import sam_model_registry

#See segment_masks_get_metrics.py for setting up your CLASSES dictionary
def _rgba(r, g, b, alpha=0.6):
    """Return an RGBA display colour, scaling 0-255 channel values to [0, 1]."""
    return np.array([r / 255, g / 255, b / 255, alpha])


# Per-class colour table; iteration order is the order classes are processed.
#   target_color:  exact RGB value that marks the class in the GT colour masks.
#   display_color: RGBA colour used to draw the predicted mask of the class.
CLASSES = {
    "kidney": {
        "target_color": np.array([255, 55, 0]),
        "display_color": _rgba(0, 0, 255),        # blue
    },
    "small_intestine": {
        "target_color": np.array([124, 155, 5]),
        "display_color": _rgba(255, 69, 0),       # orange
    },
    "instrument-shaft": {
        "target_color": np.array([0, 255, 0]),
        "display_color": _rgba(255, 0, 0),        # red
    },
    "instrument-clasper": {
        "target_color": np.array([0, 255, 255]),
        "display_color": _rgba(255, 255, 0),      # yellow
    },
    "instrument-wrist": {
        "target_color": np.array([125, 255, 12]),
        "display_color": _rgba(128, 0, 128),      # purple
    },
    "clamps": {
        "target_color": np.array([0, 255, 125]),
        "display_color": _rgba(0, 100, 0),        # dark green
    },
}

# Announce the colour used for each class in the saved overlays
# (one line per class, followed by a blank line).
print("Color mappings for classes:")
print("\n".join(
    f"  {class_label} => {colour_name}"
    for class_label, colour_name in (
        ("kidney", "Blue"),
        ("small_intestine", "Orange"),
        ("instrument-shaft", "Red"),
        ("instrument-clasper", "Yellow"),
        ("instrument-wrist", "Purple"),
        ("clamps", "Dark Green"),
    )
) + "\n")

# Visualization helpers
def show_mask(mask, ax, color=None):
    """Draw a mask on a matplotlib axes as a single translucent colour layer.

    Args:
        mask: array whose last two dimensions are (H, W); it must reshape to (H, W, 1).
        ax: axes-like object with an ``imshow`` method.
        color: RGBA array; defaults to semi-transparent blue.
    """
    rgba = np.array([0/255, 0/255, 255/255, 0.6]) if color is None else color
    height, width = mask.shape[-2:]
    # Broadcasting (H, W, 1) * (1, 1, C) paints every mask pixel with the colour.
    layer = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(layer)

def show_box(box, ax, edge_color=None):
    """Outline an [x_min, y_min, x_max, y_max] box on *ax* (green by default, unfilled)."""
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = 'green' if edge_color is None else edge_color
    ax.add_patch(
        plt.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            edgecolor=outline,
            facecolor=(0, 0, 0, 0),  # fully transparent interior
            lw=2,
        )
    )

#Segmentation metrics helper functions
def dice_coefficient(pred, gt, smooth=1e-6):
    """Return the smoothed Dice coefficient 2|P∩G| / (|P| + |G|) of two 0/1 arrays.

    *smooth* keeps the ratio defined (≈1) when both masks are empty.
    """
    p, g = pred.astype(np.float32), gt.astype(np.float32)
    overlap = np.sum(p * g)
    denominator = np.sum(p) + np.sum(g) + smooth
    return (2.0 * overlap + smooth) / denominator

def iou_score(pred, gt, smooth=1e-6):
    """Return the smoothed intersection-over-union |P∩G| / |P∪G| of two 0/1 arrays."""
    p, g = pred.astype(np.float32), gt.astype(np.float32)
    overlap = np.sum(p * g)
    # Element-wise max is the union for 0/1 masks.
    combined = np.sum(np.maximum(p, g))
    return (overlap + smooth) / (combined + smooth)

def precision_score(pred, gt, smooth=1e-6):
    """Return the smoothed precision TP / (TP + FP) of a 0/1 prediction against a 0/1 target."""
    p, g = pred.astype(np.float32), gt.astype(np.float32)
    true_pos = np.sum(p * g)
    false_pos = np.sum(p * (1 - g))
    return (true_pos + smooth) / (true_pos + false_pos + smooth)

def recall_score(pred, gt, smooth=1e-6):
    """Return the smoothed recall TP / (TP + FN) of a 0/1 prediction against a 0/1 target."""
    p, g = pred.astype(np.float32), gt.astype(np.float32)
    true_pos = np.sum(p * g)
    false_neg = np.sum((1 - p) * g)
    return (true_pos + smooth) / (true_pos + false_neg + smooth)

# Bounding box for simple targets (a single bounding box suffices)
def compute_bbox_from_mask_simple(mask_path, target_color_rgb):
    """Return one tight box around every pixel of *target_color_rgb* in a colour mask file.

    Args:
        mask_path: path of the GT colour mask image (read with OpenCV, converted to RGB).
        target_color_rgb: RGB triple that marks the class.

    Returns:
        [min_x, min_y, max_x, max_y] with inclusive max coordinates, or None if
        the file cannot be read (a warning is printed) or has no such pixels.
    """
    bgr = cv2.imread(mask_path)
    if bgr is None:
        print(f"Warning: Could not read mask image {mask_path}")
        return None
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    hit = ((rgb[:, :, 0] == target_color_rgb[0])
           & (rgb[:, :, 1] == target_color_rgb[1])
           & (rgb[:, :, 2] == target_color_rgb[2]))
    rows, cols = np.nonzero(hit)
    if rows.size == 0:
        return None
    return [cols.min(), rows.min(), cols.max(), rows.max()]

def compute_bboxes_from_mask(mask_path, target_color_rgb, min_area=10):
    """Return one box per 8-connected blob of *target_color_rgb* in a colour mask file.

    Blobs smaller than *min_area* pixels are ignored. Each box is
    [left, top, left + width, top + height]; the right/bottom values are one
    past the last pixel (exclusive), unlike compute_bbox_from_mask_simple,
    which returns inclusive max coordinates.

    Returns an empty list if the file cannot be read (a warning is printed)
    or no blob is large enough.
    """
    bgr = cv2.imread(mask_path)
    if bgr is None:
        print(f"Warning: Could not read mask image {mask_path}")
        return []
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # connectedComponentsWithStats needs an 8-bit single-channel image.
    binary = ((rgb[:, :, 0] == target_color_rgb[0])
              & (rgb[:, :, 1] == target_color_rgb[1])
              & (rgb[:, :, 2] == target_color_rgb[2])).astype(np.uint8)
    n_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

    boxes = []
    for row in stats[1:n_labels]:  # row 0 is the background component
        if row[cv2.CC_STAT_AREA] >= min_area:
            x0 = row[cv2.CC_STAT_LEFT]
            y0 = row[cv2.CC_STAT_TOP]
            boxes.append([x0, y0, x0 + row[cv2.CC_STAT_WIDTH], y0 + row[cv2.CC_STAT_HEIGHT]])
    return boxes

def compute_bbox_from_predicted_mask(mask, margin=10, image_shape=None):
    """Return the bounding box of the non-zero pixels of *mask*, grown by *margin*.

    Args:
        mask: 2-D mask, or a higher-rank mask whose extra dimensions are singleton.
        margin: pixels added on every side.
        image_shape: if given, (H, W, ...) used to clamp the box to [0, W] x [0, H].

    Returns:
        [x_min, y_min, x_max, y_max], or None when the mask is empty.
    """
    if mask.ndim > 2:
        mask = mask.squeeze()
    rows, cols = np.where(mask)
    if rows.size == 0 or cols.size == 0:
        return None

    x0, x1 = np.min(cols) - margin, np.max(cols) + margin
    y0, y1 = np.min(rows) - margin, np.max(rows) + margin
    if image_shape is None:
        return [x0, y0, x1, y1]

    H, W = image_shape[:2]
    return [max(0, x0), max(0, y0), min(W, x1), min(H, y1)]

def clean_mask(mask, kernel_size=3, min_area=50):
    """
    Smooth a 0/1 mask with a Gaussian blur and remove small specks.

    Args:
        mask: binary mask of shape (H, W) or (1, H, W); any value > 0 counts as foreground.
        kernel_size: Gaussian kernel side length; OpenCV requires it to be a
            positive odd integer.
        min_area: 8-connected components with fewer pixels than this (after
            smoothing) are dropped.

    Returns:
        uint8 array of shape (H, W) holding 0/1.
    """
    if mask.ndim == 3 and mask.shape[0] == 1:
        mask = np.squeeze(mask, axis=0)

    mask_255 = (mask > 0).astype(np.uint8) * 255

    # Blur, then re-threshold at the midpoint to smooth the mask outline.
    blur = cv2.GaussianBlur(mask_255, (kernel_size, kernel_size), 0)
    _, thresh = cv2.threshold(blur, 127, 255, cv2.THRESH_BINARY)

    _, labels, stats, _ = cv2.connectedComponentsWithStats(thresh, connectivity=8)
    # keep[label] says whether that component survives. Looking it up for every
    # pixel at once replaces one full-image comparison per component.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False  # label 0 is the background
    return keep[labels].astype(np.uint8)

#MedSAM inference
@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, orig_H, orig_W):
    """Decode a binary mask for box prompt(s) given in 1024x1024 model coordinates.

    Args:
        medsam_model: SAM-style model exposing ``prompt_encoder`` and ``mask_decoder``.
        img_embed: image embedding tensor produced by the model's image encoder.
        box_1024: one box [x0, y0, x1, y1] or a (B, 4) batch of boxes.
        orig_H, orig_W: size the probability map is resized to.

    Returns:
        uint8 0/1 numpy array: the sigmoid probabilities thresholded at 0.5,
        with singleton dimensions squeezed away ((H, W) for a single box).
    """
    boxes = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if boxes.ndim == 1:
        boxes = boxes[None, :]  # a single box becomes a (1, 4) batch

    encoder = medsam_model.prompt_encoder
    sparse, dense = encoder(points=None, boxes=boxes, masks=None)
    logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse,
        dense_prompt_embeddings=dense,
        multimask_output=False,
    )

    probs = F.interpolate(
        torch.sigmoid(logits), size=(orig_H, orig_W), mode="bilinear", align_corners=False
    )
    probs = probs.squeeze().cpu().numpy()
    return (probs > 0.5).astype(np.uint8)

def compute_averages(all_metrics):
    """
    Summarise per-image metrics into dataset-level averages.

    all_metrics[dataset] is a list of per-image dicts shaped like
        {"image_name": ..., <class_name>: {dice, iou, precision, recall[, gt_area]},
         "aggregated": {dice, iou, precision, recall}}
    and the result is
        {dataset: {"macro": {...}, "weighted": {...}, "aggregated": {...}}}
    with keys "dice", "iou", "precision", "recall" in every inner dict:

    * aggregated: plain mean of the per-image "aggregated" entries;
    * macro: mean over classes of each class's per-image mean;
    * weighted: the same per-class means weighted by the class's summed
      "gt_area" (entries without "gt_area" add zero area).

    An average with nothing to divide by is reported as 0.
    """
    metric_names = ("dice", "iou", "precision", "recall")

    def zero_scores():
        return {name: 0 for name in metric_names}

    summary = {}
    for dataset, per_image in all_metrics.items():
        agg_total, agg_seen = zero_scores(), 0
        per_class = {}  # class name -> running metric sums plus "count" and "gt_area"

        for entry in per_image:
            for label, scores in entry.items():
                if label == "image_name":
                    continue
                if label == "aggregated":
                    for name in metric_names:
                        agg_total[name] += scores[name]
                    agg_seen += 1
                    continue
                running = per_class.setdefault(label, {**zero_scores(), "count": 0, "gt_area": 0})
                for name in metric_names:
                    running[name] += scores[name]
                running["count"] += 1
                if "gt_area" in scores:
                    running["gt_area"] += scores["gt_area"]

        aggregated = zero_scores()
        if agg_seen > 0:
            aggregated = {name: agg_total[name] / agg_seen for name in metric_names}

        macro_total, weighted_total = zero_scores(), zero_scores()
        n_classes, area_total = 0, 0
        for running in per_class.values():
            seen = running["count"]
            if seen == 0:
                continue
            area = running["gt_area"]
            for name in metric_names:
                class_mean = running[name] / seen
                macro_total[name] += class_mean
                weighted_total[name] += class_mean * area
            n_classes += 1
            area_total += area

        macro = zero_scores()
        if n_classes > 0:
            macro = {name: macro_total[name] / n_classes for name in metric_names}
        weighted = zero_scores()
        if area_total > 0:
            weighted = {name: weighted_total[name] / area_total for name in metric_names}

        summary[dataset] = {"macro": macro, "weighted": weighted, "aggregated": aggregated}

    return summary

# Classes whose GT mask may contain several disconnected objects: each connected
# component becomes its own box prompt (see compute_bboxes_from_mask).
_MULTI_INSTANCE_CLASSES = frozenset(
    {"instrument-shaft", "instrument-clasper", "instrument-wrist", "clamps"}
)
_METRIC_KEYS = ("dice", "iou", "precision", "recall")


def _list_files(directory):
    """Return the sorted full paths of all non-directory entries of *directory*."""
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if not os.path.isdir(os.path.join(directory, name))
    )


def _color_mask(mask_rgb, target_color):
    """Return a uint8 0/1 mask of the pixels of an RGB mask exactly equal to *target_color*."""
    return ((mask_rgb[:, :, 0] == target_color[0]) &
            (mask_rgb[:, :, 1] == target_color[1]) &
            (mask_rgb[:, :, 2] == target_color[2])).astype(np.uint8)


def _segmentation_metrics(pred, gt):
    """Return Dice, IoU, precision and recall of 0/1 mask *pred* against 0/1 mask *gt*."""
    return {
        "dice": dice_coefficient(pred, gt),
        "iou": iou_score(pred, gt),
        "precision": precision_score(pred, gt),
        "recall": recall_score(pred, gt),
    }


def _preprocess_image(orig_image, device):
    """Resize an RGB uint8 image to 1024x1024, min-max scale it to [0, 1] and
    return it as a (1, 3, 1024, 1024) float tensor on *device*."""
    img_1024 = transform.resize(orig_image, (1024, 1024), order=3,
                                preserve_range=True, anti_aliasing=True).astype(np.uint8)
    img_1024 = (img_1024 - img_1024.min()) / np.clip(
        img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None)
    return torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)


def _segment_box(medsam_model, image_embedding, bbox, orig_H, orig_W, do_clean):
    """Segment one pixel-space [x0, y0, x1, y1] box prompt.

    Returns {"mask": uint8 0/1 prediction at the original image size,
             "box": its bounding box grown by 10 px and clamped to the image, or None}.
    """
    bbox_1024 = (np.array(bbox) / np.array([orig_W, orig_H, orig_W, orig_H])) * 1024
    seg = medsam_inference(medsam_model, image_embedding, bbox_1024, orig_H, orig_W)
    if do_clean:
        # Clean once here so the saved figure AND the metrics use the same mask.
        seg = clean_mask(seg, kernel_size=25, min_area=50)
    box = compute_bbox_from_predicted_mask(seg, margin=10, image_shape=(orig_H, orig_W))
    return {"mask": seg, "box": box}


def _predict_classes(medsam_model, image_embedding, mask_path, image_name,
                     orig_H, orig_W, do_clean):
    """Prompt MedSAM with boxes derived from the GT colour mask for each class in CLASSES.

    Returns {class_name: {"instances": [_segment_box result, ...]}} for every
    class that yielded at least one box; single-box classes get one instance.
    """
    class_results = {}
    for class_name, class_info in CLASSES.items():
        target_color = class_info["target_color"]
        if class_name in _MULTI_INSTANCE_CLASSES:
            bboxes = compute_bboxes_from_mask(mask_path, target_color_rgb=target_color, min_area=10)
        else:
            bbox = compute_bbox_from_mask_simple(mask_path, target_color_rgb=target_color)
            bboxes = [] if bbox is None else [bbox]
        if not bboxes:
            print(f"No {class_name} pixels found in mask for image {image_name}; skipping.")
            continue
        class_results[class_name] = {"instances": [
            _segment_box(medsam_model, image_embedding, box, orig_H, orig_W, do_clean)
            for box in bboxes
        ]}
    return class_results


def _save_prediction_figure(orig_image, class_results, save_path, show_bboxes=False):
    """Draw every predicted instance mask (and optionally its box) over the image and save it."""
    plt.figure(figsize=(10, 10))
    plt.imshow(orig_image)
    ax = plt.gca()
    for class_name, res in class_results.items():
        display_color = CLASSES[class_name]["display_color"]
        for inst in res["instances"]:
            show_mask(inst["mask"], ax, color=display_color)
            if show_bboxes and inst["box"] is not None:
                show_box(inst["box"], ax, edge_color=display_color)
    plt.axis('off')
    plt.xticks([])
    plt.yticks([])
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close()


def _read_gt_mask(mask_path, orig_H, orig_W):
    """Read a GT colour mask as RGB.

    Returns None (after printing a warning) when the file cannot be read or its
    size differs from the image, since per-pixel metrics and the overlay need
    matching shapes.
    """
    gt_mask = cv2.imread(mask_path)
    if gt_mask is None:
        print(f"Warning: Could not load GT mask from {mask_path}")
        return None
    if gt_mask.shape[:2] != (orig_H, orig_W):
        print(f"Warning: GT mask {mask_path} is {gt_mask.shape[1]}x{gt_mask.shape[0]} "
              f"but the image is {orig_W}x{orig_H}; skipping metrics and GT overlay.")
        return None
    return cv2.cvtColor(gt_mask, cv2.COLOR_BGR2RGB)


def _score_image(class_results, gt_mask_rgb):
    """Per-class and aggregated (all classes as one foreground) metrics for one image.

    All instance masks of a class are unioned before scoring, so each class is
    compared as a whole against its full GT mask (scoring a single instance
    against the whole-class GT would cap Dice/recall below 1 whenever the class
    has several components).
    """
    H, W = gt_mask_rgb.shape[:2]
    image_metrics = {}
    binary_pred = np.zeros((H, W), dtype=bool)
    for class_name, res in class_results.items():
        gt_binary = _color_mask(gt_mask_rgb, CLASSES[class_name]["target_color"])
        class_pred = np.zeros((H, W), dtype=bool)
        for inst in res["instances"]:
            class_pred |= inst["mask"] > 0.5
        image_metrics[class_name] = {
            **_segmentation_metrics(class_pred.astype(np.uint8), gt_binary),
            "gt_area": float(np.sum(gt_binary)),
        }
        binary_pred |= class_pred

    # GT foreground is the union of every class in CLASSES, not only the predicted ones.
    binary_gt = np.zeros((H, W), dtype=bool)
    for class_info in CLASSES.values():
        binary_gt |= _color_mask(gt_mask_rgb, class_info["target_color"]).astype(bool)
    image_metrics["aggregated"] = _segmentation_metrics(
        binary_pred.astype(np.uint8), binary_gt.astype(np.uint8))
    return image_metrics


def _write_per_image_csv(all_metrics, path):
    """Write one CSV row per (dataset, image, class) with the four metrics."""
    with open(path, "w", newline="") as csvfile:
        fieldnames = ["dataset", "image_name", "class", *_METRIC_KEYS]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for dataset, images_metrics in all_metrics.items():
            for im in images_metrics:
                for cls, metrics in im.items():
                    if cls == "image_name":
                        continue
                    row = {"dataset": dataset, "image_name": im["image_name"], "class": cls}
                    row.update({key: metrics[key] for key in _METRIC_KEYS})
                    writer.writerow(row)


def _print_dataset_averages(all_metrics):
    """Print, per dataset, the mean of each metric over the images in which a class appears."""
    print("\nAverage Metrics per Dataset:")
    for dataset, images_metrics in all_metrics.items():
        sum_metrics = {}
        for im in images_metrics:
            for cls, metrics in im.items():
                if cls == "image_name":
                    continue
                sums = sum_metrics.setdefault(
                    cls, {"dice": 0, "iou": 0, "precision": 0, "recall": 0, "count": 0})
                for key in _METRIC_KEYS:
                    sums[key] += metrics[key]
                sums["count"] += 1
        print(f"Dataset: {dataset}")
        for cls, vals in sum_metrics.items():
            n = vals["count"]
            print(f"  Class {cls}: Dice = {vals['dice'] / n:.4f}, IoU = {vals['iou'] / n:.4f}, "
                  f"Precision = {vals['precision'] / n:.4f}, Recall = {vals['recall'] / n:.4f}")


def _write_averages_csv(all_avg_results, path):
    """Write the macro / weighted / aggregated averages from compute_averages, one row per dataset."""
    kinds = ("macro", "weighted", "aggregated")
    fieldnames = ["dataset"] + [f"{kind}_{key}" for kind in kinds for key in _METRIC_KEYS]
    with open(path, "w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for dataset, row in all_avg_results.items():
            record = {"dataset": dataset}
            for kind in kinds:
                for key in _METRIC_KEYS:
                    record[f"{kind}_{key}"] = row[kind][key]
            writer.writerow(record)


def main(args: argparse.Namespace) -> None:
    """Segment rendered images with MedSAM using GT-derived box prompts and score them.

    For each split in ("test", "video", "novel_views"), images are read from
    ``<data>/<split>/ours_<iteration>/renders`` and paired, by sorted order, with
    GT colour masks in ``.../all_masks``. Prediction figures go to
    ``<data>/seg_med_<iteration>/<split>/`` and GT overlays to
    ``<data>/gt_overlay_<iteration>/<split>/``. Metrics are written to
    ``segmentation_metrics.csv`` and ``averaged_metrics.csv`` in the current
    working directory. With ``--clean-mask``, predictions are post-processed
    before both plotting and scoring.
    """
    device = args.device
    print("Loading MedSAM model...")
    # Use the requested architecture (previously hard-coded to "vit_b").
    medsam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    medsam_model = medsam_model.to(device)
    medsam_model.eval()

    # Define segmentation output directory
    seg_path = os.path.join(args.data, f"seg_med_{args.iteration}")
    os.makedirs(seg_path, exist_ok=True)
    gt_overlay_path = os.path.join(os.path.dirname(seg_path), "gt_overlay_" + str(args.iteration))
    os.makedirs(gt_overlay_path, exist_ok=True)

    all_metrics = {}
    # Image stems (per split) left out of the metrics; they are still segmented and overlaid.
    exclude_images = {"video": ["00001", "00028"]}

    for dataset in ["test", "video", "novel_views"]:
        split_dir = os.path.join(args.data, dataset, f"ours_{args.iteration}")
        image_dir = os.path.join(split_dir, "renders")
        if not os.path.exists(image_dir):
            print(f"Directory {image_dir} does not exist!")
            continue
        images = _list_files(image_dir)

        mask_dir = os.path.join(split_dir, "all_masks")
        use_mask_prompt = os.path.exists(mask_dir)
        mask_files = []
        if use_mask_prompt:
            mask_files = _list_files(mask_dir)
            if len(mask_files) != len(images):
                print("Warning: Number of mask files does not match number of images.")
        else:
            print(f"No mask directory found at {mask_dir}.")

        output_dir = os.path.join(seg_path, dataset)
        os.makedirs(output_dir, exist_ok=True)
        gt_dataset_dir = os.path.join(gt_overlay_path, dataset)
        os.makedirs(gt_dataset_dir, exist_ok=True)
        excluded = exclude_images.get(dataset, [])

        for i, image_path in enumerate(tqdm(images, desc=f"Processing {dataset} images")):
            image_name = os.path.splitext(os.path.basename(image_path))[0]
            print(f"Processing '{image_path}' ...")
            orig_image = cv2.imread(image_path)
            if orig_image is None:
                print(f"Could not read image {image_path}")
                continue
            orig_image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
            orig_H, orig_W, _ = orig_image.shape

            # Check for prompts before running the expensive image encoder.
            if not use_mask_prompt:
                print("No mask prompts available; skipping segmentation for image", image_name)
                continue
            if i >= len(mask_files):
                # Sorted-order pairing ran out of masks; indexing would raise IndexError.
                print(f"Warning: No mask file for image {image_name}; skipping.")
                continue
            mask_path = mask_files[i]

            with torch.no_grad():
                image_embedding = medsam_model.image_encoder(_preprocess_image(orig_image, device))

            class_results = _predict_classes(medsam_model, image_embedding, mask_path,
                                             image_name, orig_H, orig_W, args.clean_mask)

            seg_save_path = os.path.join(output_dir, image_name + '.png')
            _save_prediction_figure(orig_image, class_results, seg_save_path)
            print(f"Saved segmentation result to '{seg_save_path}'")

            # Read the GT mask once for both the metrics and the GT overlay.
            gt_mask_rgb = _read_gt_mask(mask_path, orig_H, orig_W)

            image_metrics = {}
            if image_name in excluded:
                print(f"Excluding image {image_name} from metrics.")
            elif gt_mask_rgb is not None:
                image_metrics = _score_image(class_results, gt_mask_rgb)
                print(f"Metrics for image {image_name}: {image_metrics}")
            all_metrics.setdefault(dataset, []).append({"image_name": image_name, **image_metrics})

            if gt_mask_rgb is not None:
                overlay = cv2.addWeighted(orig_image, 0.5, gt_mask_rgb, 0.5, 0)
                gt_save_path = os.path.join(gt_dataset_dir, image_name + '.png')
                plt.imsave(gt_save_path, overlay)
                print(f"Saved GT overlay to '{gt_save_path}'")

    # Compute and save. Define your own output paths here.
    _write_per_image_csv(all_metrics, "segmentation_metrics.csv")
    _print_dataset_averages(all_metrics)
    _write_averages_csv(compute_averages(all_metrics), "averaged_metrics.csv")

if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description=(
            "MedSAM segmentation on a dataset with dynamic prompts from GT masks. "
            "Computes segmentation metrics and saves results to CSV. "
            "Excludes specified images from 'video' dataset and computes aggregated binary metrics."
        )
    )

    # Required inputs.
    cli.add_argument("--model-type", type=str, required=True,
                     help="Type of model to load, e.g., 'vit_b' for MedSAM")
    cli.add_argument("--checkpoint", type=str, required=True,
                     help="Path to the MedSAM checkpoint, e.g., medsam_vit_b.pth")
    cli.add_argument("--data", type=str, required=True,
                     help="Path to the data folder (with subfolders for images, masks, etc.)")
    cli.add_argument("--iteration", type=int, required=True,
                     help="Chosen iteration number")

    # Optional switches.
    cli.add_argument("--image", action="store_true",
                     help="If set, compute image embeddings on the fly (default for MedSAM)")
    cli.add_argument("--device", type=str, default="cuda",
                     help="Device to run the model on (e.g., 'cuda' or 'cpu')")
    cli.add_argument("--clean-mask", action="store_true",
                     help="If set, postprocess the predicted masks (smooth + remove small specks)")

    cli_args = cli.parse_args()

    # Wall-clock timing of the whole run.
    t_start = time.time()
    main(cli_args)
    t_end = time.time()
    print('Processing time:', t_end - t_start)

#sample usage: 
# python segment_masks_medsam.py --checkpoint checkpoints/medsam_vit_b_latest.pth --model-type vit_b --data ../../output/endovis/6000_mlp --iteration 6000 --device "cuda:1"
