import time
import os
import cv2
import csv
import argparse
import numpy as np
from tqdm import tqdm


from benchmark import (
    MAE, Fmeasure, WeightedFmeasure, Emeasure, SmeasureStrict,
    Tmeasure, IoU, Dice, HD95, ASD
)


def load_data_into_memory(gt_root, pred_root, max_samples=None):
    """
    Pre-load (prediction, ground-truth) image pairs into memory.

    Loading everything up front keeps disk I/O latency out of the timed
    section of the efficiency benchmark.

    Args:
        gt_root: Directory containing the ground-truth images.
        pred_root: Directory containing the predictions. A prediction is
            looked up under the same file name as the ground truth first,
            then under the same stem with a ``.png`` extension.
        max_samples: Maximum number of pairs to load. ``None`` (or any other
            falsy value) means "no limit".

    Returns:
        A list of ``(pred, gt)`` tuples:

        * ``[]`` when ``gt_root`` does not exist or is not a directory;
        * 50 pairs of random 352x352 arrays (values on a 0-255 scale) when
          ``gt_root`` contains no supported image file, so the benchmark can
          still be run as a simulation;
        * otherwise the grayscale images returned by ``cv2.imread`` for every
          ground truth that has a readable matching prediction.
    """
    print(f"Pre-loading data from {gt_root} ...")

    if not os.path.exists(gt_root):
        print(f"Error: GT path does not exist: {gt_root}")
        return []
    # A regular file passes the existence check but would make os.listdir
    # raise NotADirectoryError, so reject it explicitly.
    if not os.path.isdir(gt_root):
        print(f"Error: GT path is not a directory: {gt_root}")
        return []

    # os.listdir returns entries in arbitrary order; sorting makes the loaded
    # set (and any max_samples subset of it) reproducible across runs.
    names = sorted(
        f for f in os.listdir(gt_root)
        if f.lower().endswith(('.png', '.jpg', '.bmp', '.tif'))
    )

    data_cache = []

    # If no images found, generate random data for simulation (Robustness)
    if not names:
        print("Warning: No images found. Generating random noise for simulation...")
        for _ in range(50):
            pred = np.random.rand(352, 352) * 255
            gt = (np.random.rand(352, 352) > 0.5).astype(np.float64) * 255
            data_cache.append((pred, gt))
        return data_cache

    for name in tqdm(names):
        if max_samples and len(data_cache) >= max_samples:
            break

        gt_path = os.path.join(gt_root, name)

        # Prefer a prediction with the identical file name; fall back to the
        # same stem saved as PNG. Ground truths without a prediction are skipped.
        pred_path = os.path.join(pred_root, name)
        if not os.path.exists(pred_path):
            pred_path = os.path.join(pred_root, os.path.splitext(name)[0] + '.png')
            if not os.path.exists(pred_path):
                continue

        # Read in grayscale. Normalization happens inside the metric class.
        gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
        pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals an unreadable file by returning None.
        if gt is None or pred is None:
            continue

        data_cache.append((pred, gt))

    print(f"Data loaded. Total samples: {len(data_cache)}\n")
    return data_cache


def benchmark_single_metric(name, metric_class, data_cache, loop_factor=1):
    """
    Measure the per-sample evaluation time of a single metric.

    Args:
        name: Display name of the metric. It selects the constructor
            arguments used to instantiate ``metric_class``.
        metric_class: Class whose instances expose ``step(pred, gt)``.
        data_cache: Sequence of ``(pred, gt)`` pairs already held in memory.
        loop_factor: Number of full passes over ``data_cache``; more passes
            give a more precise timing.

    Returns:
        A ``(avg_ms, fps)`` tuple: mean time per ``step`` call in
        milliseconds and the corresponding number of calls per second.

    Raises:
        ValueError: If ``data_cache`` is empty or ``loop_factor`` is smaller
            than 1, i.e. there would be nothing to time.
    """
    # Fail early with a clear message instead of a ZeroDivisionError when
    # no frame would be processed.
    if len(data_cache) == 0:
        raise ValueError("data_cache is empty: nothing to benchmark")
    if loop_factor < 1:
        raise ValueError(f"loop_factor must be >= 1, got {loop_factor}")

    # 1. Instantiate the metric with specific parameters matching the paper
    init_kwargs_by_name = {
        "wFm": {"beta": 1},
        "Fmeasure": {"beta": 0.3},
        "Sm": {"alpha": 0.5},
        "Tm": {"alpha": 0.8},  # Default setting in paper
    }
    metric = metric_class(**init_kwargs_by_name.get(name, {}))

    # 2. Warm-up (avoid cold start latency); this call is not timed.
    warmup_pred, warmup_gt = data_cache[0]
    metric.step(warmup_pred, warmup_gt)

    # 3. Timed section
    total_frames = 0
    start_time = time.perf_counter()
    for _ in range(loop_factor):
        for pred, gt in data_cache:
            metric.step(pred, gt)
            total_frames += 1
    total_duration = time.perf_counter() - start_time

    # 4. Calculate Stats
    avg_ms = (total_duration / total_frames) * 1000
    # Guard against a zero reading from the timer for extremely cheap metrics.
    fps = total_frames / total_duration if total_duration > 0 else float("inf")

    return avg_ms, fps


def main():
    """
    Command-line entry point: time every metric and report the results.

    Parses ``--gt``, ``--pred``, ``--loop`` and ``--output``, pre-loads the
    image pairs, benchmarks each metric in turn, prints a table sorted from
    fastest to slowest and writes the same rows to a CSV file.

    A metric that raises during benchmarking is reported on the console and
    left out of the results; the remaining metrics are still measured.
    Exits through ``parser.error`` (status 2) when ``--loop`` is below 1.
    """
    parser = argparse.ArgumentParser(description="Computational Efficiency Benchmark Script")
    parser.add_argument('--gt', type=str, required=True, help='Path to Ground Truth folder')
    parser.add_argument('--pred', type=str, required=True, help='Path to Prediction folder')
    parser.add_argument('--loop', type=int, default=1, help='Loop factor to increase timing precision (default: 1)')
    parser.add_argument('--output', type=str, default='efficiency_results.csv', help='Output CSV file path')
    args = parser.parse_args()

    # With zero (or negative) passes no frame is timed and every metric
    # would fail with a division error, so reject the value up front.
    if args.loop < 1:
        parser.error("--loop must be a positive integer")

    # 1. Load Data
    data_cache = load_data_into_memory(args.gt, args.pred)
    if not data_cache:
        print("No samples available; nothing to benchmark.")
        return

    # 2. Define Metrics to Test
    # (Display Name, Class)
    metrics_to_test = [
        ("MAE", MAE),
        ("IoU", IoU),
        ("Dice", Dice),
        ("Fmeasure", Fmeasure),
        ("wFm", WeightedFmeasure),
        ("Sm", SmeasureStrict),
        ("Em", Emeasure),
        ("Tm", Tmeasure),  # Our proposed metric
        ("95HD", HD95),
        ("ASD", ASD),
    ]

    print(f"{'Metric':<12} | {'Status':<20}")
    print("-" * 40)

    # 3. Run Benchmark
    timings = []  # (name, avg_ms, fps) with unrounded values
    for name, cls in metrics_to_test:
        print(f"Benchmarking {name}...", end="\r")
        try:
            avg_ms, fps = benchmark_single_metric(name, cls, data_cache, args.loop)
        except Exception as e:
            # Deliberately broad: one failing metric must not abort the run.
            print(f"{name:<12} | Error: {e}")
        else:
            timings.append((name, avg_ms, fps))
            print(f"{name:<12} | Done (FPS: {fps:.1f})   ")

    # 4. Sort and Display Results
    # Sort by FPS descending (faster first). Sorting on the raw value keeps
    # slow metrics correctly ordered even when they round to the same FPS.
    timings.sort(key=lambda item: item[2], reverse=True)
    results = [
        {"Metric": name, "Time(ms)": round(avg_ms, 3), "FPS": round(fps, 1)}
        for name, avg_ms, fps in timings
    ]

    print("\n" + "=" * 15 + " Benchmark Results " + "=" * 15)
    print(f"{'Metric':<10} | {'Time (ms/img)':<15} | {'FPS':<10}")
    print("-" * 45)
    for row in results:
        print(f"{row['Metric']:<10} | {row['Time(ms)']:<15} | {row['FPS']:<10}")
    print("=" * 45)

    # 5. Save to CSV
    try:
        # newline='' is required by the csv module; the explicit encoding
        # makes the output independent of the platform default.
        with open(args.output, mode='w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=["Metric", "Time(ms)", "FPS"])
            writer.writeheader()
            writer.writerows(results)
    except (OSError, csv.Error) as e:
        print(f"Failed to save CSV: {e}")
    else:
        print(f"\nResults saved to {args.output}")


# Run the benchmark only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()