import argparse
import csv
import os
import sys
import time
import zlib

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image

from sam2.build_sam import build_sam2_video_predictor

print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())

def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay *mask* on *ax* as a semi-transparent (alpha 0.6) RGBA image.

    NOTE: ``show_mask`` is defined a second time further down in this module;
    that later definition is the one bound once the module has been imported.

    Args:
        mask: Array whose last two dimensions are (H, W); it is reshaped to
            (H, W, 1), so it must contain exactly H*W elements.
        ax: Matplotlib axes to draw on.
        obj_id: Selects a ``tab10`` colour. ``None`` -> index 0; a ``str``
            (e.g. a class name) -> a stable CRC32-derived index, so the same
            name gets the same colour in every run; anything else is used
            directly as the colormap index.
        random_color: If True, ignore *obj_id* and pick a random RGB colour.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        if obj_id is None:
            cmap_idx = 0
        elif isinstance(obj_id, str):
            # A raw string is not a valid colormap index; map it to 0..9 with a
            # process-independent hash (built-in hash() of str is salted per run).
            cmap_idx = zlib.crc32(obj_id.encode("utf-8")) % 10
        else:
            cmap_idx = obj_id
        color = np.array([*plt.get_cmap("tab10")(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    ax.imshow(mask.reshape(h, w, 1) * color.reshape(1, 1, -1))

import cv2

def show_mask_with_color(mask, ax, display_color):
    """Paint every positive pixel of *mask* with *display_color* on *ax*.

    Args:
        mask: Mask that must become 2-D after dropping singleton axes
            (e.g. (1, H, W) is accepted).
        ax: Matplotlib axes to draw on.
        display_color: Value assigned to each foreground pixel of an (H, W, 4)
            float32 overlay, typically an RGBA array; background stays fully
            transparent (all zeros).

    Raises:
        ValueError: if the squeezed mask is not 2-D.
    """
    flat_mask = np.squeeze(mask)
    if flat_mask.ndim != 2:
        raise ValueError(f"Expected a 2D mask after squeezing, but got shape {flat_mask.shape}")
    rgba_layer = np.zeros(flat_mask.shape + (4,), dtype=np.float32)
    foreground = flat_mask > 0
    rgba_layer[foreground] = display_color
    ax.imshow(rgba_layer)

def show_points(coords, labels, ax, marker_size=200):
    """Scatter prompt points on *ax*: label 1 as green stars, label 0 as red stars.

    Args:
        coords: (N, 2) array of (x, y) point coordinates.
        labels: Length-N array of point labels (1 = positive, 0 = negative).
        ax: Matplotlib axes to draw on.
        marker_size: Marker area passed to ``scatter`` as ``s``.
    """
    for wanted, colour in ((1, 'green'), (0, 'red')):
        chosen = coords[labels == wanted]
        ax.scatter(
            chosen[:, 0],
            chosen[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax, edge_color='green'):
    """Draw an unfilled, 2-px-wide rectangle for an ``[x0, y0, x1, y1]`` box on *ax*.

    NOTE: ``show_box`` is defined again further down in this module with the
    same signature; that later definition is the one bound at import time.

    Args:
        box: Sequence whose first four entries are x0, y0, x1, y1.
        ax: Matplotlib axes to draw on.
        edge_color: Outline colour (defaults to the previous hard-coded green).
    """
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=edge_color, facecolor=(0, 0, 0, 0), lw=2))

def plot_seg_results(
    image_path,
    results_dict,
    show_bboxes=True,
    output_path=None,
    fig_size=(9, 6),
):
    """
    Plots the first image along with all segmentation masks from `results_dict`.
    Optionally overlays bounding boxes. If `output_path` is given, saves the result
    instead of displaying.

    Args:
        image_path (str): Path to the first image (frame) in your video.
        results_dict (dict): The dictionary storing segmentation results per class,
                             e.g. {
                                 "kidney": {
                                     "frame_idx": 0,
                                     "bboxes": [...],
                                     "out_mask_logits": [...],
                                 },
                                 "clamps": {
                                     ...
                                 },
                                 ...
                             }
                             "bboxes" may be a list or an array of
                             [x1, y1, x2, y2] boxes; "out_mask_logits" may hold
                             tensors (anything with ``.cpu()``) or array-likes.
                             Missing or None entries are treated as empty.
        show_bboxes (bool): Whether to overlay the bounding boxes on the image.
        output_path (str or None): If provided, the figure is saved there instead
                                   of using plt.show().
        fig_size (tuple): Figure size for matplotlib.

    Returns:
        None. Prints a warning and returns early if `image_path` is not a file.
    """

    if not os.path.isfile(image_path):
        print(f"Warning: image {image_path} not found.")
        return

    plt.figure(figsize=fig_size)
    # Close the file handle once the image has been drawn.
    with Image.open(image_path) as orig_image:
        plt.imshow(orig_image)
    ax = plt.gca()

    for class_name, res in results_dict.items():
        bboxes = res.get("bboxes")
        # len() instead of truthiness so numpy arrays of boxes also work.
        if show_bboxes and bboxes is not None and len(bboxes) > 0:
            for bbox in bboxes:
                show_box(bbox, ax)

        # Overlay the segmentation masks (logit > 0 => foreground).
        out_mask_logits_list = res.get("out_mask_logits")
        if out_mask_logits_list is None:
            out_mask_logits_list = []
        for mask_logit in out_mask_logits_list:
            if hasattr(mask_logit, "cpu"):
                # Tensor input: threshold, then move to host memory for numpy.
                mask_np = (mask_logit > 0.0).cpu().numpy()
            else:
                mask_np = (np.asarray(mask_logit) > 0.0).astype(np.uint8)
            show_mask(mask_np, ax, obj_id=class_name)

    plt.axis("off")
    plt.xticks([])
    plt.yticks([])

    if output_path is not None:
        plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
        plt.close()
        print(f"Saved visualization to '{output_path}'")
    else:
        plt.show()

def show_box(box, ax, edge_color='green'):
    """Outline a bounding box on a matplotlib axis.

    Args:
        box: Corner coordinates ``[x1, y1, x2, y2]`` (only the first four
            entries are read).
        ax: Axes to draw onto.
        edge_color: Colour of the 2-px outline; the interior is fully transparent.
    """
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor=edge_color,
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)
    
import numpy as np

def dice_coefficient(pred, gt, smooth=1e-6):
    """Smoothed Dice score: (2*|P*G| + smooth) / (|P| + |G| + smooth).

    Both inputs are cast to float32 and multiplied element-wise, so they are
    expected to be 0/1 masks of broadcast-compatible shape. ``smooth`` makes
    two empty masks score 1.0 instead of dividing by zero.
    """
    p = pred.astype(np.float32)
    g = gt.astype(np.float32)
    overlap = (p * g).sum()
    denominator = p.sum() + g.sum() + smooth
    return (2.0 * overlap + smooth) / denominator

def iou_score(pred, gt, smooth=1e-6):
    """Smoothed intersection-over-union of two masks.

    Intersection is the sum of the element-wise product and union the sum of
    the element-wise maximum, which is the set IoU for 0/1 masks.
    """
    p = pred.astype(np.float32)
    g = gt.astype(np.float32)
    overlap = (p * g).sum()
    combined = np.maximum(p, g).sum()
    return (overlap + smooth) / (combined + smooth)

def precision_score(pred, gt, smooth=1e-6):
    """Smoothed precision TP / (TP + FP) for 0/1 masks (``pred`` is the prediction)."""
    p = pred.astype(np.float32)
    g = gt.astype(np.float32)
    true_pos = (p * g).sum()
    false_pos = (p * (1 - g)).sum()
    return (true_pos + smooth) / (true_pos + false_pos + smooth)

def recall_score(pred, gt, smooth=1e-6):
    """Smoothed recall TP / (TP + FN) for 0/1 masks (``gt`` is the reference)."""
    p = pred.astype(np.float32)
    g = gt.astype(np.float32)
    true_pos = (p * g).sum()
    false_neg = ((1 - p) * g).sum()
    return (true_pos + smooth) / (true_pos + false_neg + smooth)

def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay *mask* on *ax* as a semi-transparent (alpha 0.6) RGBA image.

    Args:
        mask: Array whose last two dimensions are (H, W); it is reshaped to
            (H, W, 1), so it must contain exactly H*W elements.
        ax: Matplotlib axes to draw on.
        obj_id: Selects a ``tab10`` colour. An ``int`` is used directly as the
            colormap index; a ``str`` (e.g. a class name) is mapped to a stable
            index via CRC32, so a class keeps the same colour across runs;
            anything else uses index 0.
        random_color: If True, ignore *obj_id* and pick a random RGB colour.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        if isinstance(obj_id, int):
            cmap_idx = obj_id
        elif isinstance(obj_id, str):
            # Built-in hash() of str is salted per process (PYTHONHASHSEED), which
            # gave a class a different colour on every run; CRC32 is deterministic.
            cmap_idx = zlib.crc32(obj_id.encode("utf-8")) % 10
        else:
            cmap_idx = 0
        color = np.array([*plt.get_cmap("tab10")(cmap_idx)[:3], 0.6])

    h, w = mask.shape[-2:]
    mask_reshaped = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_reshaped)

def compute_bboxes_from_mask(mask_img, target_color_rgb, min_area=5):
    """Return one box per 8-connected region of an exact colour in a mask image.

    Args:
        mask_img: BGR image as returned by ``cv2.imread`` (``None`` when the
            read failed).
        target_color_rgb: (R, G, B) colour identifying the class; pixels must
            match all three channels exactly.
        min_area: Components with fewer pixels than this are ignored.

    Returns:
        List of ``[left, top, left + width, top + height]`` boxes built from
        OpenCV component stats, in label order. Empty if ``mask_img`` is None.
    """
    if mask_img is None:
        # Only the (missing) image object is available here, not its path.
        print("Warning: mask image is None (it could not be read); no boxes computed.")
        return []
    mask_rgb = cv2.cvtColor(mask_img, cv2.COLOR_BGR2RGB)
    binary_mask = ((mask_rgb[:, :, 0] == target_color_rgb[0]) &
                   (mask_rgb[:, :, 1] == target_color_rgb[1]) &
                   (mask_rgb[:, :, 2] == target_color_rgb[2])).astype(np.uint8)
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)
    bboxes = []
    # Label 0 is the background component, so start at 1.
    for label in range(1, num_labels):
        if stats[label, cv2.CC_STAT_AREA] < min_area:
            continue
        left = stats[label, cv2.CC_STAT_LEFT]
        top = stats[label, cv2.CC_STAT_TOP]
        width = stats[label, cv2.CC_STAT_WIDTH]
        height = stats[label, cv2.CC_STAT_HEIGHT]
        bboxes.append([left, top, left + width, top + height])
    return bboxes

def compute_macro_micro_aggregated(per_class_metrics, aggregated_frame_metrics):
    """Reduce per-frame metrics to macro, micro and aggregated averages.

    Args:
        per_class_metrics: ``{class_name: {frame_idx: metrics}}`` where each
            ``metrics`` dict has ``dice``, ``iou``, ``precision``, ``recall``
            and ``gt_area`` entries.
        aggregated_frame_metrics: ``{frame_idx: metrics}`` with ``dice``,
            ``iou``, ``precision`` and ``recall`` entries.

    Returns:
        ``{"macro": {...}, "micro": {...}, "aggregated": {...}}``, each inner
        dict holding ``dice``, ``iou``, ``precision`` and ``recall``:

        * macro: each class is first averaged over its frames, then the class
          means are averaged with equal weight (classes with no frames skipped);
        * micro: the same class means, weighted by each class's summed
          ``gt_area``;
        * aggregated: plain mean of ``aggregated_frame_metrics`` over frames.

        A group with nothing to average (or zero total GT area, for micro)
        reports integer 0 for every metric.
    """
    metric_names = ("dice", "iou", "precision", "recall")

    def mean_of(rows):
        # Left-to-right running sums (same order as a plain loop), then divide.
        running = dict.fromkeys(metric_names, 0.0)
        for row in rows:
            for name in metric_names:
                running[name] += row[name]
        return {name: running[name] / len(rows) for name in metric_names}

    # Per-class mean over frames, paired with the class's total GT area.
    class_means = {}
    for class_name, frames in per_class_metrics.items():
        frame_rows = list(frames.values())
        if not frame_rows:
            continue
        area = 0.0
        for row in frame_rows:
            area += row["gt_area"]
        class_means[class_name] = (mean_of(frame_rows), area)

    empty = {name: 0 for name in metric_names}

    macro = dict(empty)
    micro = dict(empty)
    if class_means:
        macro = mean_of([means for means, _ in class_means.values()])

        weighted = dict.fromkeys(metric_names, 0.0)
        area_total = 0.0
        for means, area in class_means.values():
            for name in metric_names:
                weighted[name] += means[name] * area
            area_total += area
        if area_total > 0:
            micro = {name: weighted[name] / area_total for name in metric_names}

    aggregated = dict(empty)
    if aggregated_frame_metrics:
        aggregated = mean_of(list(aggregated_frame_metrics.values()))

    return {"macro": macro, "micro": micro, "aggregated": aggregated}

def _load_gt_masks(mask_dir):
    """Read every regular file in *mask_dir*, sorted by filename, as a BGR image.

    Raises:
        ValueError: if ``cv2.imread`` cannot decode a file.
    """
    # NOTE(review): masks are matched to video frames purely by list position.
    # Video frames are sorted numerically (see main) but mask files are sorted
    # lexicographically, which only agrees when mask filenames are zero-padded.
    # Confirm this holds for your dataset.
    mask_paths = sorted(
        os.path.join(mask_dir, name)
        for name in os.listdir(mask_dir)
        if os.path.isfile(os.path.join(mask_dir, name))
    )
    masks = []
    for path in mask_paths:
        mask = cv2.imread(path)
        if mask is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not read ground-truth mask '{path}'")
        masks.append(mask)
    return masks


def _split_masks_by_class(gt_masks, classes):
    """Return ``{class_name: [uint8 0/1 mask per frame]}`` by exact RGB colour match."""
    class_masks = {class_name: [] for class_name in classes}
    for gt_mask in gt_masks:
        mask_rgb = cv2.cvtColor(gt_mask, cv2.COLOR_BGR2RGB)
        for class_name, class_info in classes.items():
            target = class_info["target_color"]
            binary = ((mask_rgb[:, :, 0] == target[0]) &
                      (mask_rgb[:, :, 1] == target[1]) &
                      (mask_rgb[:, :, 2] == target[2])).astype(np.uint8)
            class_masks[class_name].append(binary)
    return class_masks


def _find_first_frame_with_class(gt_masks, target_color, min_area):
    """Return ``(frame_idx, bboxes)`` for the first mask with a large-enough region, else ``(None, [])``."""
    for idx, gt_mask in enumerate(gt_masks):
        bboxes = compute_bboxes_from_mask(gt_mask, target_color, min_area=min_area)
        if bboxes:
            return idx, bboxes
    return None, []


def _add_box_prompts(predictor, inference_state, gt_masks, classes, video_dir, frame_names):
    """Prompt SAM 2 with one box per GT region, in the first frame each class appears.

    Every box gets its own sequentially assigned object id, and a preview of each
    prompt is displayed with ``plt.show()``.

    Returns:
        ``{class_name: [obj_id, ...]}`` for every class found in some GT mask.
    """
    class_to_obj_ids = {}
    next_obj_id = 0
    for class_name, class_info in classes.items():
        print(f"Processing class '{class_name}'...")
        ann_frame_idx, bboxes = _find_first_frame_with_class(
            gt_masks, class_info["target_color"], min_area=100
        )
        if ann_frame_idx is None:
            print(f"Warning: Class '{class_name}' not shown in this video.")
            continue
        print(f"  Found class '{class_name}' in frame {ann_frame_idx}.")

        class_obj_ids = []
        # One SAM 2 object per discrete region found in this frame.
        for bbox in bboxes:
            obj_id = next_obj_id
            next_obj_id += 1
            class_obj_ids.append(obj_id)

            predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=ann_frame_idx,
                obj_id=obj_id,
                box=bbox,
            )
            print(f"  Added object id {obj_id} for class '{class_name}' in frame {ann_frame_idx} with bbox: {bbox}")

            plt.figure(figsize=(9, 6))
            plt.title(f"Frame {ann_frame_idx} | Class: {class_name} | Obj ID: {obj_id}")
            with Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])) as frame_image:
                plt.imshow(frame_image)
            show_box(bbox, plt.gca())
            plt.axis("off")
            plt.show()

        class_to_obj_ids[class_name] = class_obj_ids
    return class_to_obj_ids


def _save_overlays(video_dir, frame_names, video_segments, obj_id_to_class, classes,
                   save_folder, stride=1):
    """Save ``frame_<idx>.png`` per frame with each tracked object drawn in its class colour."""
    plt.close("all")
    for out_frame_idx in range(0, len(frame_names), stride):
        with Image.open(os.path.join(video_dir, frame_names[out_frame_idx])) as frame_img:
            plt.imshow(frame_img)
        # .get(): propagation may not yield every frame (e.g. frames before the
        # earliest prompt); such frames are saved without overlays instead of
        # raising KeyError.
        for out_obj_id, out_mask in video_segments.get(out_frame_idx, {}).items():
            class_name = obj_id_to_class.get(out_obj_id)
            if class_name is None:
                continue
            show_mask_with_color(out_mask, plt.gca(), classes[class_name]["display_color"])
        save_path = os.path.join(save_folder, f"frame_{out_frame_idx}.png")
        plt.axis('off')
        plt.xticks([])
        plt.yticks([])
        plt.gca().set_position([0, 0, 1, 1])
        plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
        print(f"Saved overlay for frame {out_frame_idx} to {save_path}")
        plt.close()


def _binary_metrics(pred, gt):
    """Dice/IoU/precision/recall of two masks; all 0.0 when either mask is empty."""
    if np.sum(pred) == 0 or np.sum(gt) == 0:
        return {"dice": 0.0, "iou": 0.0, "precision": 0.0, "recall": 0.0}
    return {
        "dice": dice_coefficient(pred, gt),
        "iou": iou_score(pred, gt),
        "precision": precision_score(pred, gt),
        "recall": recall_score(pred, gt),
    }


def _mean_metrics(metric_dicts):
    """``np.mean`` of dice/iou/precision/recall over a non-empty list of metric dicts."""
    return {
        key: np.mean([m[key] for m in metric_dicts])
        for key in ("dice", "iou", "precision", "recall")
    }


def _evaluate_frames(video_segments, class_masks, obj_id_to_class, excluded_frames):
    """Score every propagated, non-excluded frame per class and with all classes merged.

    Returns:
        ``(per_class_metrics, aggregated_frame_metrics)`` where
        ``per_class_metrics = {class: {frame_idx: {dice, iou, precision, recall, gt_area}}}``
        and ``aggregated_frame_metrics = {frame_idx: {dice, iou, precision, recall}}``.
        A class is omitted for a frame when both its GT and prediction are empty.
    """
    per_class_metrics = {}
    aggregated_frame_metrics = {}

    for frame_idx in sorted(video_segments):
        if frame_idx in excluded_frames:
            continue
        # obj_id -> predicted binary mask (numpy array)
        frame_preds = video_segments[frame_idx]
        aggregated_pred = aggregated_gt = None

        for class_name, gt_frames in class_masks.items():
            per_class_metrics.setdefault(class_name, {})
            if frame_idx >= len(gt_frames):
                continue
            gt = gt_frames[frame_idx]

            # Union (pixel-wise max) of every predicted object of this class.
            merged_pred = np.zeros_like(gt, dtype=np.uint8)
            for obj_id, pred_mask in frame_preds.items():
                if obj_id_to_class.get(obj_id) == class_name:
                    merged_pred = np.maximum(merged_pred, pred_mask)

            gt_area = np.sum(gt)
            if gt_area == 0 and np.sum(merged_pred) == 0:
                # Class absent from both GT and prediction: nothing to score.
                continue

            per_class_metrics[class_name][frame_idx] = {
                **_binary_metrics(merged_pred, gt),
                "gt_area": float(gt_area),
            }

            # Aggregated view: all classes merged into one binary segmentation.
            if aggregated_pred is None:
                aggregated_pred = merged_pred.copy()
                aggregated_gt = gt.copy()
            else:
                aggregated_pred = np.maximum(aggregated_pred, merged_pred)
                aggregated_gt = np.maximum(aggregated_gt, gt)

        if aggregated_pred is None:
            continue
        if np.sum(aggregated_pred) == 0 and np.sum(aggregated_gt) == 0:
            aggregated_frame_metrics[frame_idx] = {
                "dice": 1.0, "iou": 1.0, "precision": 1.0, "recall": 1.0,
            }
        else:
            aggregated_frame_metrics[frame_idx] = _binary_metrics(aggregated_pred, aggregated_gt)

    return per_class_metrics, aggregated_frame_metrics


def main(args: argparse.Namespace) -> None:
    """Segment a video with SAM 2 using box prompts derived from GT masks, then score it.

    Steps:
      1. Load the colour-coded GT masks in ``args.mask_path`` and split them into
         one binary mask per class in ``CLASSES``.
      2. For every class, find the first GT frame with a region of at least 100 px
         and add one SAM 2 box prompt per such region (each previewed with
         ``plt.show()``).
      3. Propagate through the JPEG frames in ``args.video_dir`` and save one
         colour overlay per frame into ``args.save_folder``.
      4. Compute per-class and all-classes-merged Dice / IoU / precision / recall,
         skipping ``excluded_frames``, print them, and write macro / micro /
         aggregated averages to ``averaged_metrics.csv`` in the current working
         directory.

    Args:
        args: Parsed command-line arguments; reads ``mask_path``, ``video_dir``,
            ``save_folder``, ``checkpoint``, ``cfg`` and ``device``.
    """
    # See segment_masks_sam1.py for how to prepare the CLASSES dictionary. It should
    # be based on the labels.json for the masks in your dataset.
    CLASSES = {
        "kidney": {
            "target_color": np.array([255, 55, 0]),
            "display_color": np.array([0/255, 0/255, 255/255, 0.6])
        },
        "small_intestine": {
            "target_color": np.array([124, 155, 5]),
            "display_color": np.array([255/255, 69/255, 0/255, 0.6])
        },
        "instrument-shaft": {
            "target_color": np.array([0, 255, 0]),
            "display_color": np.array([255/255, 0/255, 0/255, 0.6])
        },
        "instrument-clasper": {
            "target_color": np.array([0, 255, 255]),
            "display_color": np.array([255/255, 255/255, 0/255, 0.6])
        },
        "instrument-wrist": {
            "target_color": np.array([125, 255, 12]),
            "display_color": np.array([128/255, 0/255, 128/255, 0.6])
        },
        "clamps": {
            "target_color": np.array([0, 255, 125]),
            "display_color": np.array([0, 100/255, 0, 0.6])
        },
    }

    print("Color mappings for classes:")
    print("  kidney => Blue")
    print("  small_intestine => Orange")
    print("  instrument-shaft => Red")
    print("  instrument-clasper => Yellow")
    print("  instrument-wrist => Purple")
    print("  clamps => Dark Green\n")

    device = torch.device(args.device)

    gt_masks = _load_gt_masks(args.mask_path)
    class_masks = _split_masks_by_class(gt_masks, CLASSES)

    video_dir = args.video_dir
    frame_names = [
        p for p in os.listdir(video_dir)
        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
    ]
    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

    os.makedirs(args.save_folder, exist_ok=True)

    # torch.autocast expects a device *type* ("cuda"/"cpu"), not e.g. "cuda:1".
    with torch.autocast(device.type, dtype=torch.bfloat16):
        # Turn on TF32 for Ampere (compute capability >= 8) GPUs:
        # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        # Query the GPU actually selected, and skip the query entirely on CPU.
        if device.type == "cuda" and torch.cuda.get_device_properties(device).major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        predictor = build_sam2_video_predictor(args.cfg, args.checkpoint, device=args.device)
        inference_state = predictor.init_state(video_path=video_dir)
        predictor.reset_state(inference_state)

        class_to_obj_ids = _add_box_prompts(
            predictor, inference_state, gt_masks, CLASSES, video_dir, frame_names
        )
        print("Final mapping from class names to object IDs:")
        print(class_to_obj_ids)
        if not class_to_obj_ids:
            print("Warning: none of the classes was found in the ground-truth masks; nothing to segment.")
            return

        print('running through videos')
        # Per-frame segmentation results: {frame_idx: {obj_id: binary mask}}.
        video_segments = {}
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }

    obj_id_to_class = {
        obj_id: class_name
        for class_name, obj_ids in class_to_obj_ids.items()
        for obj_id in obj_ids
    }
    print("Object ID to Class mapping:")
    print(obj_id_to_class)

    # Stride 1: every frame in the video gets an overlay.
    _save_overlays(video_dir, frame_names, video_segments, obj_id_to_class, CLASSES,
                   args.save_folder, stride=1)

    # Frames to exclude from evaluation, i.e. the testing images.
    excluded_frames = {1, 28}
    excluded_desc = " and ".join(str(f) for f in sorted(excluded_frames))

    per_class_metrics, aggregated_frame_metrics = _evaluate_frames(
        video_segments, class_masks, obj_id_to_class, excluded_frames
    )

    # Excluded frames were already skipped by _evaluate_frames.
    average_metrics_per_class = {
        class_name: _mean_metrics(list(frame_dict.values()))
        for class_name, frame_dict in per_class_metrics.items()
        if frame_dict
    }

    print("Per-frame metrics (per class):")
    for class_name, frame_dict in per_class_metrics.items():
        print(f"Class '{class_name}':")
        for frame_idx, metrics_dict in sorted(frame_dict.items()):
            print(f"  Frame {frame_idx}: Dice: {metrics_dict['dice']:.3f}, IoU: {metrics_dict['iou']:.3f}, Precision: {metrics_dict['precision']:.3f}, Recall: {metrics_dict['recall']:.3f}")

    print(f"\nAverage metrics per class (excluding frames {excluded_desc}):")
    for class_name, avg_metrics in average_metrics_per_class.items():
        print(f"Class '{class_name}': Dice: {avg_metrics['dice']:.3f}, IoU: {avg_metrics['iou']:.3f}, Precision: {avg_metrics['precision']:.3f}, Recall: {avg_metrics['recall']:.3f}")

    print(f"\nAggregated overall metrics (binary segmentation across all classes, excluding frames {excluded_desc}):")
    if aggregated_frame_metrics:
        overall = _mean_metrics(list(aggregated_frame_metrics.values()))
        print(f"Dice: {overall['dice']:.3f}, IoU: {overall['iou']:.3f}, Precision: {overall['precision']:.3f}, Recall: {overall['recall']:.3f}")
    else:
        # Previously this formatted None with :.3f and crashed.
        print("No evaluated frames; aggregated metrics are unavailable.")

    results_summary = compute_macro_micro_aggregated(per_class_metrics, aggregated_frame_metrics)

    print("\n=== Overall Macro Averages ===")
    print(f"Dice={results_summary['macro']['dice']:.4f}, IoU={results_summary['macro']['iou']:.4f}, "
          f"Prec={results_summary['macro']['precision']:.4f}, Rec={results_summary['macro']['recall']:.4f}")

    print("\n=== Overall Micro (Weighted) Averages ===")
    print(f"Dice={results_summary['micro']['dice']:.4f}, IoU={results_summary['micro']['iou']:.4f}, "
          f"Prec={results_summary['micro']['precision']:.4f}, Rec={results_summary['micro']['recall']:.4f}")

    print("\n=== Overall Aggregated (Binary OR) Averages ===")
    print(f"Dice={results_summary['aggregated']['dice']:.4f}, IoU={results_summary['aggregated']['iou']:.4f}, "
          f"Prec={results_summary['aggregated']['precision']:.4f}, Rec={results_summary['aggregated']['recall']:.4f}")

    # Column order: macro_*, micro_*, aggregated_* x (dice, iou, precision, recall).
    row_data = {
        f"{group}_{metric}": results_summary[group][metric]
        for group in ("macro", "micro", "aggregated")
        for metric in ("dice", "iou", "precision", "recall")
    }
    with open("averaged_metrics.csv", "w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=list(row_data))
        writer.writeheader()
        writer.writerow(row_data)
    print("Saved macro, micro, aggregated averages to 'averaged_metrics.csv'.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="SAM 2 video segmentation with box prompts derived from GT masks. "
                    "Computes segmentation metrics and saves results to CSV. "
                    "Excludes specified frames of the video and computes aggregated binary metrics."
    )
    parser.add_argument("--video-dir", type=str, required=True,
                        help="folder containing jpg images corresponding to frames in a video.")
    parser.add_argument("--save-folder", type=str, required=True,
                        help="where to save the SAM 2 masks.")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to the SAM 2 checkpoint, e.g., checkpoints/sam2_hiera_tiny.pt")
    parser.add_argument("--cfg", type=str, required=True,
                        help="Path to model config")
    # NOTE: --data, --iteration, --image and --clean-mask are parsed but never read
    # by main(); they are kept for command-line compatibility.
    parser.add_argument("--data", type=str, required=True,
                        help="Path to the data folder (with subfolders for images, masks, etc.)")
    # Not required: with required=True the default below could never be used.
    parser.add_argument("--mask-path", type=str, default='../endovis_sttr_downsized/masks_separated',
                        help="Path to the folder of colour-coded ground-truth masks (one per frame)")
    parser.add_argument("--iteration", type=int, required=True,
                        help="Chosen iteration number")
    parser.add_argument("--image", action="store_true",
                        help="If set, compute image embeddings on the fly (default for MedSAM)")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run the model on (e.g., 'cuda', 'cuda:1' or 'cpu')")
    parser.add_argument("--clean-mask", action="store_true",
                        help="If set, postprocess the predicted masks (smooth + remove small specks)")
    args = parser.parse_args()
    # perf_counter is the monotonic clock intended for measuring elapsed time.
    start = time.perf_counter()
    main(args)
    print('Processing time:', time.perf_counter() - start)

# Sample usage (note: --save-folder is required, so the first example below fails without it):
'''
python segment_masks_sam2.py --video-dir endovis_sttr_downsized/vid --checkpoint checkpoints/sam2_hiera_tiny.pt --cfg /configs/sam2/sam2_hiera_t.yaml --data ../../output/endovis/pulling --mask-path endovis_sttr_downsized/masks_separated --iteration 6000 --device "cuda:1"
'''

'''
python segment_masks_sam2.py --video-dir endovis_sttr_downsized/vid --checkpoint checkpoints/sam2_hiera_tiny.pt --cfg /configs/sam2/sam2_hiera_t.yaml --data ../../output/endovis/pulling --mask-path endovis_sttr_downsized/masks_separated --iteration 6000 --device "cuda:1" --save-folder 'SAM2_tiny_segments' 
'''

# python segment_masks_sam2.py --video-dir ../../../endovis18_seq1.2.2/vid --checkpoint checkpoints/sam2_hiera_tiny.pt --cfg /configs/sam2/sam2_hiera_t.yaml --data ../../output/endovis/pulling --mask-path endovis_sttr_downsized/masks_separated --iteration 6000 --device "cuda:1" --save-folder 'SAM2_tiny_segments_endovis2' 
