"""
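Evaluate model selection between two baseline checkpoints (engine A vs. engine B)
by comparing their joint log-likelihoods on synthetic or real BAV data.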
"""

from __future__ import annotations

import argparse
import json
import math
import random
import time
from pathlib import Path
from typing import Any, Callable, Dict, List, Sequence

import hydra
import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf

from src.data.bav_real_data import BavTrueDataloader
from src.data.bav_samplerv2 import BavSampler
from src.data.utils import SamplePermutationHelper
from src.models.benchmarks.tnp import TNP
from src.utils import DataAttr

# Helpers


def string2bool(b):
    """Parse a command-line boolean value (for use as an argparse ``type=``).

    Args:
        b: A ``bool`` (returned unchanged) or a string such as ``"yes"``,
            ``"true"``, ``"t"``, ``"y"``, ``"1"`` / ``"no"``, ``"false"``,
            ``"f"``, ``"n"``, ``"0"`` (case-insensitive, surrounding
            whitespace ignored).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``b`` is not a recognized boolean
            string, so argparse reports a usage error instead of the option
            silently becoming ``None``.
    """
    if isinstance(b, bool):
        return b
    value = b.strip().lower()
    if value in ("yes", "true", "t", "y", "1"):
        return True
    if value in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {b!r}")


def _format_seconds(x: float) -> str:
    if x < 1e-6:
        return f"{x*1e9:.1f} ns"
    if x < 1e-3:
        return f"{x*1e6:.1f} µs"
    if x < 1.0:
        return f"{x*1e3:.2f} ms"
    if x < 60.0:
        return f"{x:.3f} s"
    return f"{x/60.0:.2f} min"


def _expand_batch_for_mc(batch: DataAttr, n_mc: int):
    """Return ``batch`` expanded for ``n_mc`` Monte Carlo repetitions.

    Thin wrapper over ``SamplePermutationHelper.repeat_and_permute_batch``;
    only the first element of the pair it returns is kept.
    """
    repeated, _unused = SamplePermutationHelper.repeat_and_permute_batch(batch, n_mc)
    return repeated


def _build_mc_test_batches(batches: Sequence[DataAttr], n_mc: int) -> List[DataAttr]:
    """Return a new list with every batch expanded for ``n_mc`` MC repetitions."""
    expanded = []
    for single in batches:
        expanded.append(_expand_batch_for_mc(single, n_mc=n_mc))
    return expanded


def _build_baseline_model(
    checkpoint_path: str, device: str | torch.device = "cpu"
) -> tuple[TNP, Any]:
    """Load a baseline model and its config from a training checkpoint.

    The checkpoint is indexed as a mapping with a ``"config"`` entry (turned
    into an OmegaConf config whose ``"model"`` node is instantiated by Hydra)
    and a ``"model_state_dict"`` entry.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: Device used as ``map_location`` when loading and as the
            target of ``model.to``. Accepts a string or a ``torch.device``
            (callers in this file pass the latter).

    Returns:
        ``(model, config)``: the model in eval mode on ``device`` and the
        OmegaConf config it was built from.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = OmegaConf.create(checkpoint["config"])
    model = hydra.utils.instantiate(config["model"])

    state_dict = checkpoint["model_state_dict"]
    # Keys starting with "_orig_mod." presumably come from a torch.compile-d
    # model. When any key has that prefix, every "_orig_mod." occurrence is
    # removed from all keys, so the names match the freshly built module.
    if any(key.startswith("_orig_mod.") for key in state_dict.keys()):
        state_dict = {
            key.replace("_orig_mod.", ""): value for key, value in state_dict.items()
        }
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model, config


# Metrics helpers
def _prec(tp: int, fp: int) -> float:
    return tp / (tp + fp) if (tp + fp) > 0 else 0.0


def _rec(tp: int, fn: int) -> float:
    return tp / (tp + fn) if (tp + fn) > 0 else 0.0


def _f1(p: float, r: float) -> float:
    return 2 * p * r / (p + r) if (p + r) > 0 else 0.0


def _compute_selection_metrics(
    cm: np.ndarray,
    label_a: str,
    label_b: str,
    a_times: List[float],
    b_times: List[float],
    total_time: float,
) -> Dict[str, Any]:
    """Summarize a 2x2 model-selection confusion matrix plus timing data.

    Args:
        cm: Confusion matrix, rows = true dataset, columns = predicted engine;
            index 0 corresponds to ``label_a`` and index 1 to ``label_b``.
        label_a: Display label for index 0, used as a key in per-class dicts.
        label_b: Display label for index 1, used as a key in per-class dicts.
        a_times: Per-call inference times (seconds) of engine A.
        b_times: Per-call inference times (seconds) of engine B.
        total_time: Duration of the whole evaluation, in seconds.

    Returns:
        A JSON-serializable dict with overall accuracy, macro F1, per-class
        precision/recall/F1, per-true-dataset accuracy and timing summaries.
        Undefined accuracies and timing statistics are ``nan``; undefined
        precision/recall/F1 are ``0.0``.
    """
    nan = float("nan")
    n_total = cm.sum()
    n_true_a = cm[0].sum()
    n_true_b = cm[1].sum()

    accuracy = nan
    if n_total:
        accuracy = (cm[0, 0] + cm[1, 1]) / n_total
    accuracy_a = cm[0, 0] / n_true_a if n_true_a > 0 else nan
    accuracy_b = cm[1, 1] / n_true_b if n_true_b > 0 else nan

    # Cell naming: <true>_as_<predicted>.
    a_as_a, a_as_b = int(cm[0, 0]), int(cm[0, 1])
    b_as_a, b_as_b = int(cm[1, 0]), int(cm[1, 1])
    prec_a, rec_a = _prec(a_as_a, b_as_a), _rec(a_as_a, a_as_b)
    prec_b, rec_b = _prec(b_as_b, a_as_b), _rec(b_as_b, b_as_a)
    f1_a = _f1(prec_a, rec_a)
    f1_b = _f1(prec_b, rec_b)

    def _mean_std(samples: List[float]):
        # Population (ddof=0) standard deviation; nan pair for no samples.
        arr = np.array(samples, dtype=float)
        if arr.size == 0:
            return nan, nan
        return float(arr.mean()), float(arr.std(ddof=0))

    a_mean, a_std = _mean_std(a_times)
    b_mean, b_std = _mean_std(b_times)
    throughput = float(n_total / total_time) if total_time > 0 else nan

    return {
        "labels": {"row_col_0": label_a, "row_col_1": label_b},
        "accuracy": float(accuracy),
        "macro_f1": float((f1_a + f1_b) / 2.0),
        "per_class": {
            label_a: {
                "precision": float(prec_a),
                "recall": float(rec_a),
                "f1": float(f1_a),
            },
            label_b: {
                "precision": float(prec_b),
                "recall": float(rec_b),
                "f1": float(f1_b),
            },
        },
        "per_data_accuracy": {label_a: float(accuracy_a), label_b: float(accuracy_b)},
        "timing": {
            "per_call_engine_A_mean_s": a_mean,
            "per_call_engine_A_std_s": a_std,
            "per_call_engine_B_mean_s": b_mean,
            "per_call_engine_B_std_s": b_std,
            "total_inference_time": float(total_time),
            "throughput_samples_per_sec": throughput,
        },
    }


class ModelSelectionEvaluator:
    """Compare two checkpoints (engine A vs. engine B) on synthetic BAV data.

    For each true dataset class (sampled with ``RHO_A=4/3`` for A and
    ``RHO_A=1`` for B), every sample is expanded into ``n_mc`` MC
    repetitions. Both engines score it with ``eval_log_joint_likelihood``, and
    the engine with the higher mean joint log-likelihood is the prediction.
    Results go into a 2x2 confusion matrix (rows = true dataset, cols =
    predicted engine). They are written to ``save_dir`` as JSON, CSV, a
    ``.npy`` confusion matrix and a text report.
    """

    def __init__(
        self,
        ckpt_a: str,
        ckpt_b: str,
        label_a: str = "rho=4/3",
        label_b: str = "rho=1",
        n_data: int = 64,
        n_points: int = 400,
        n_mc: int = 10,
        device: str = "cpu",
        save_dir: str = "./eval_results/model_selection",
    ) -> None:
        """Build the test data, load both checkpoints and set up the metrics dict.

        Args:
            ckpt_a: Checkpoint path for engine A.
            ckpt_b: Checkpoint path for engine B.
            label_a: Display label for class/engine index 0.
            label_b: Display label for class/engine index 1.
            n_data: Number of synthetic samples generated per true class.
            n_points: Number of target points per sample.
            n_mc: Number of MC repetitions per sample.
            device: Torch device string for data and models.
            save_dir: Output directory; created if it does not exist.
        """
        self.device = torch.device(device)
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.ckpt_a = ckpt_a
        self.ckpt_b = ckpt_b
        self.label_a = label_a
        self.label_b = label_b
        self.n_data = n_data
        self.n_points = n_points
        self.n_mc = n_mc
        # Timing follows device automatically
        self.timer = "cuda" if self.device.type == "cuda" else "cpu"

        # Data
        self._build_data()

        # Models and inference engines
        self._load_models_and_inference()

        # Metrics scaffold
        self._initialize_metrics()

    def _load_models_and_inference(self) -> None:
        """Load both checkpoints; the loaded models are used directly as engines."""
        self.model_a, self.cfg_a = _build_baseline_model(
            self.ckpt_a, device=self.device
        )
        self.model_b, self.cfg_b = _build_baseline_model(
            self.ckpt_b, device=self.device
        )

        self.engine_a = self.model_a
        self.engine_b = self.model_b

    def _build_data(self) -> None:
        """Generate ``n_data`` single-sample test batches per class and MC-expand them."""
        sampler_a = BavSampler(RHO_A=4.0 / 3.0, device=self.device)
        sampler_b = BavSampler(RHO_A=1.0, device=self.device)
        data_a = [
            sampler_a.generate_test_batch(1, num_target=self.n_points)
            for _ in range(self.n_data)
        ]
        data_b = [
            sampler_b.generate_test_batch(1, num_target=self.n_points)
            for _ in range(self.n_data)
        ]
        # The generated batches may still be on the CPU (original note: "hacky
        # bugfix on data in cpu"), so move them explicitly. Compare the device
        # *type*: torch.device("cuda:0") != torch.device("cuda"), so an equality
        # check would skip the move for indexed devices.
        if self.device.type != "cpu":
            data_a = [batch.to(self.device) for batch in data_a]
            data_b = [batch.to(self.device) for batch in data_b]
        self.data_a_mc = _build_mc_test_batches(data_a, n_mc=self.n_mc)
        self.data_b_mc = _build_mc_test_batches(data_b, n_mc=self.n_mc)

    def _initialize_metrics(self) -> None:
        """Create the metrics dict that ``run_evaluation`` fills and ``save_results`` writes."""
        self.metrics: Dict[str, Any] = {
            "model_stats": {
                "checkpoint_a": self.ckpt_a,
                "checkpoint_b": self.ckpt_b,
                "device": str(self.device),
                "timer": self.timer,
            },
            "evaluation": {
                "n_data_per_class": self.n_data,
                "n_points": self.n_points,
                "n_mc": self.n_mc,
                "confusion_matrix": None,
            },
            "timing": {
                "per_call_engine_A_times": [],
                "per_call_engine_B_times": [],
                "total_inference_time": 0.0,
            },
            "labels": {"row_col_0": self.label_a, "row_col_1": self.label_b},
        }

    def _time_call(self, fn: Callable, *args, **kwargs):
        """Call ``fn(*args, **kwargs)`` and return ``(result, elapsed_seconds)``.

        When ``self.timer == "cuda"`` and CUDA is available, pending GPU work is
        synchronized first and the call is timed with CUDA events. Otherwise
        ``time.perf_counter`` is used.
        """
        use_cuda = (self.timer == "cuda") and torch.cuda.is_available()
        if use_cuda:
            torch.cuda.synchronize()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            out = fn(*args, **kwargs)
            end.record()
            end.synchronize()
            # elapsed_time reports milliseconds.
            return out, start.elapsed_time(end) / 1000.0
        t0 = time.perf_counter()
        out = fn(*args, **kwargs)
        return out, (time.perf_counter() - t0)

    def _evaluate_dataset(self, engine_a, engine_b, batches: Sequence["DataAttr"]):
        """Score every batch with both engines and tally the predictions.

        The prediction is column 1 (engine B) only when B's mean joint
        log-likelihood is strictly higher. Ties and nan comparisons go to
        column 0 (engine A).

        Returns:
            ``(cm_row, a_times, b_times, batch_stats, all_joint_values)``:
            prediction counts per column, per-call durations (seconds) for
            each engine, per-batch mean/std/prediction/diff dicts, and the raw
            joint values per batch.
        """
        cm_row = np.zeros((2,), dtype=np.int64)
        a_times: List[float] = []
        b_times: List[float] = []
        batch_stats: List[Dict[str, Any]] = []
        all_joint_values: List[Dict[str, Any]] = []

        n_batch = len(batches)
        # Print progress for at most ~100 batches (always including the last).
        step = max(1, math.ceil(n_batch / 100))
        for i, batch in enumerate(batches):
            joint_a, dt_a = self._time_call(
                engine_a.eval_log_joint_likelihood,
                batch.xc,
                batch.yc,
                batch.xt,
                batch.yt,
            )
            joint_b, dt_b = self._time_call(
                engine_b.eval_log_joint_likelihood,
                batch.xc,
                batch.yc,
                batch.xt,
                batch.yt,
            )
            a_times.append(dt_a)
            b_times.append(dt_b)
            joint_a_mean = joint_a.mean()
            joint_b_mean = joint_b.mean()
            # torch's default std is the unbiased estimator, so it is nan when
            # the output holds a single value (presumably n_mc == 1).
            joint_a_std = joint_a.std()
            joint_b_std = joint_b.std()
            diff = joint_a_mean - joint_b_mean
            pred_col = 1 if joint_b_mean > joint_a_mean else 0
            cm_row[pred_col] += 1
            # collect per-batch joint stats
            batch_stats.append(
                {
                    "batch_index": i,
                    "engine_A_joint_mean": float(joint_a_mean.item()),
                    "engine_A_joint_std": float(joint_a_std.item()),
                    "engine_B_joint_mean": float(joint_b_mean.item()),
                    "engine_B_joint_std": float(joint_b_std.item()),
                    "pred_col": int(pred_col),
                    "diff": float(diff.item()),
                }
            )
            all_joint_values.append(
                {
                    "batch_index": i,
                    "engine_A_joint_values": joint_a.cpu().numpy().tolist(),
                    "engine_B_joint_values": joint_b.cpu().numpy().tolist(),
                }
            )
            # print timing information
            if (i + 1) % step == 0 or (i + 1) == n_batch:
                print(
                    f"Batch {i} timing - Engine A: {dt_a:.4f}s, Engine B: {dt_b:.4f}s",
                    flush=True,
                )
                print(
                    f"Batch {i} stats - Engine A: {joint_a_mean:.4f} ± {joint_a_std:.4f}, Engine B: {joint_b_mean:.4f} ± {joint_b_std:.4f}",
                    flush=True,
                )
        return cm_row, a_times, b_times, batch_stats, all_joint_values

    def run_evaluation(self):
        """Evaluate both true datasets with both engines, save outputs, return ``(cm, metrics)``."""
        cm = np.zeros((2, 2), dtype=np.int64)
        # perf_counter is monotonic, unlike time.time(), so it suits durations.
        eval_start = time.perf_counter()
        with torch.no_grad():
            cm0, a0, b0, stats_a, _ = self._evaluate_dataset(
                self.engine_a, self.engine_b, self.data_a_mc
            )
            cm1, a1, b1, stats_b, _ = self._evaluate_dataset(
                self.engine_a, self.engine_b, self.data_b_mc
            )
        cm[0] += cm0
        cm[1] += cm1
        all_a_times: List[float] = a0 + a1
        all_b_times: List[float] = b0 + b1
        total_time = time.perf_counter() - eval_start

        self.metrics["timing"]["per_call_engine_A_times"] = all_a_times
        self.metrics["timing"]["per_call_engine_B_times"] = all_b_times
        self.metrics["timing"]["total_inference_time"] = total_time
        self.metrics["evaluation"]["confusion_matrix"] = cm.tolist()
        # Save per-batch joint stats by true dataset
        self.metrics["evaluation"]["batch_joint_stats"] = {
            self.label_a: stats_a,
            self.label_b: stats_b,
        }

        m = _compute_selection_metrics(
            cm, self.label_a, self.label_b, all_a_times, all_b_times, total_time
        )
        self.metrics["evaluation"].update(
            {
                "accuracy": m["accuracy"],
                "macro_f1": m["macro_f1"],
                "per_class": m["per_class"],
                "per_data_accuracy": m["per_data_accuracy"],
                "throughput_samples_per_sec": m["timing"]["throughput_samples_per_sec"],
            }
        )

        self.save_results()
        print(f"\nEvaluation complete in {total_time:.2f}s")
        print(f"Confusion matrix (rows=true, cols=pred):\n{cm}")
        print(f"Accuracy: {m['accuracy']:.3f} | MacroF1: {m['macro_f1']:.3f}")
        return cm, m

    def save_results(self) -> None:
        """Write metrics JSON, timing/stat CSVs, text reports and the confusion matrix."""
        metrics_path = self.save_dir / "model_selection_metrics.json"
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(self.metrics, f, indent=2)
        print(f"Saved metrics to {metrics_path}")
        self._save_timing_csv()
        self._save_batch_joint_stats_csv()
        self._generate_report()
        cm = np.array(self.metrics["evaluation"]["confusion_matrix"])  # type: ignore[index]
        np.save(self.save_dir / "confusion_matrix.npy", cm)

    def _save_timing_csv(self) -> None:
        """Write per-call engine A/B times (seconds), one row per evaluated batch."""
        import csv

        csv_path = self.save_dir / "timing_selection.csv"
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["sample_idx", "engine_A_time_s", "engine_B_time_s"])
            for i, (ta, tb) in enumerate(
                zip(
                    self.metrics["timing"]["per_call_engine_A_times"],
                    self.metrics["timing"]["per_call_engine_B_times"],
                )
            ):
                writer.writerow([i, ta, tb])
        print(f"Saved timing data to {csv_path}")

    def _save_batch_joint_stats_csv(self) -> None:
        """Write per-batch joint log-likelihood stats, grouped by true dataset label."""
        import csv

        csv_path = self.save_dir / "batch_joint_stats.csv"
        evalm = self.metrics.get("evaluation", {})
        batch_stats = evalm.get("batch_joint_stats", {})
        # Explicit UTF-8: labels are user-supplied and may be non-ASCII.
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(
                [
                    "true_dataset",
                    "batch_index",
                    "engine_A_joint_mean",
                    "engine_A_joint_std",
                    "engine_B_joint_mean",
                    "engine_B_joint_std",
                    "pred_label",
                ]
            )
            for dataset_label, stats in batch_stats.items():
                for s in stats:
                    pred = s.get("pred_col", 0)
                    pred_label = self.label_a if pred == 0 else self.label_b
                    writer.writerow(
                        [
                            dataset_label,
                            s.get("batch_index", ""),
                            s.get("engine_A_joint_mean", ""),
                            s.get("engine_A_joint_std", ""),
                            s.get("engine_B_joint_mean", ""),
                            s.get("engine_B_joint_std", ""),
                            pred_label,
                        ]
                    )
        print(f"Saved batch joint stats to {csv_path}")

    def _generate_report(self) -> None:
        """Write the human-readable report to ``model_selection_report.txt`` and ``result.txt``."""
        report_path = self.save_dir / "model_selection_report.txt"
        cm = np.array(self.metrics["evaluation"]["confusion_matrix"])  # type: ignore[index]
        timing = self.metrics["timing"]
        evalm = self.metrics["evaluation"]
        # Build content lines
        lines: List[str] = []

        def w(s: str = ""):
            lines.append(s)

        w("=" * 80)
        w("ACE MODEL SELECTION REPORT")
        w("=" * 80)
        w("")
        w("MODELS")
        w("-" * 40)
        w(f"Engine A: {self.ckpt_a}")
        w(f"Engine B: {self.ckpt_b}")
        w(f"Device: {self.metrics['model_stats']['device']}")
        w(f"Timer: {self.metrics['model_stats']['timer']}")
        w("")
        w("EVALUATION SETTINGS")
        w("-" * 40)
        w(f"n_data_per_class: {self.n_data}")
        w(f"n_points_per_sample: {self.n_points}")
        w(f"n_mc_per_sample: {self.n_mc}")
        w("")
        w("CONFUSION MATRIX (rows=true, cols=pred)")
        w("-" * 40)
        w(str(cm))
        w("")
        w("SUMMARY METRICS")
        w("-" * 40)
        w(f"Accuracy: {evalm['accuracy']:.6f}")
        w(f"Macro F1: {evalm['macro_f1']:.6f}")
        w(f"Per-data accuracy: {evalm['per_data_accuracy']}")
        w("")
        w("TIMING")
        w("-" * 40)
        a_times = np.array(timing["per_call_engine_A_times"])  # type: ignore[index]
        b_times = np.array(timing["per_call_engine_B_times"])  # type: ignore[index]
        a_mean = a_times.mean() if a_times.size else float("nan")
        a_std = a_times.std() if a_times.size else float("nan")
        b_mean = b_times.mean() if b_times.size else float("nan")
        b_std = b_times.std() if b_times.size else float("nan")
        w(
            f"Engine A per-call: mean={_format_seconds(a_mean)}, std={_format_seconds(a_std)}"
        )
        w(
            f"Engine B per-call: mean={_format_seconds(b_mean)}, std={_format_seconds(b_std)}"
        )
        w(f"Total evaluation time: {_format_seconds(timing['total_inference_time'])}")
        w(f"Throughput (samples/s): {evalm.get('throughput_samples_per_sec', 0.0):.2f}")
        w("")
        # Batch-level joint stats section
        w("JOINT LOG-LIKELIHOOD STATS")
        w("-" * 40)
        batch_stats = evalm.get("batch_joint_stats", {})
        for true_label, stats in batch_stats.items():
            w(f"True dataset: {true_label}")
            if stats:
                a_means = np.array(
                    [s["engine_A_joint_mean"] for s in stats], dtype=float
                )
                a_stds = np.array([s["engine_A_joint_std"] for s in stats], dtype=float)
                b_means = np.array(
                    [s["engine_B_joint_mean"] for s in stats], dtype=float
                )
                b_stds = np.array([s["engine_B_joint_std"] for s in stats], dtype=float)
                w(
                    f"  Engine A (across batches): mean={a_means.mean():.6f}, std_of_means={a_means.std():.6f}, avg_batch_std={a_stds.mean():.6f}"
                )
                w(
                    f"  Engine B (across batches): mean={b_means.mean():.6f}, std_of_means={b_means.std():.6f}, avg_batch_std={b_stds.mean():.6f}"
                )
            else:
                w("  No batch stats available")
            w("  Per-batch (idx | A mean±std | B mean±std | pred)")
            header = f"    {'idx':>4} | {'A mean±std':>22} | {'B mean±std':>22} | {'pred':>4}"
            w(header)
            for s in stats:
                a_str = (
                    f"{s['engine_A_joint_mean']:.6f} ± {s['engine_A_joint_std']:.6f}"
                )
                b_str = (
                    f"{s['engine_B_joint_mean']:.6f} ± {s['engine_B_joint_std']:.6f}"
                )
                pred = s.get("pred_col", 0)
                pred_label = self.label_a if pred == 0 else self.label_b
                w(
                    f"    {s['batch_index']:>4} | {a_str:>22} | {b_str:>22} | {pred_label:>4}"
                )
            w("")
        w("=" * 80)

        content = "\n".join(lines)
        # The report contains non-ASCII characters ("±", "µs"), so write UTF-8
        # explicitly instead of relying on the platform's locale encoding.
        with open(report_path, "w", encoding="utf-8") as f:
            f.write(content)
        # Also save a duplicate to result.txt for convenience
        result_path = self.save_dir / "result.txt"
        with open(result_path, "w", encoding="utf-8") as f:
            f.write(content)
        print(f"Saved evaluation report to {report_path}")
        print(f"Saved duplicate report to {result_path}")


class ModelSelectionEvaluatorRealData(ModelSelectionEvaluator):
    """Model selection on real BAV data loaded with ``BavTrueDataloader``.

    Reuses the checkpoint loading and per-batch scoring of
    ``ModelSelectionEvaluator``. The data has no known true class, so the same
    prediction counts are placed in both confusion-matrix rows. Outputs are a
    scatter plot of per-sample log-likelihood differences plus seed-suffixed
    CSV files, all written to ``save_dir``.
    """

    def __init__(
        self,
        ckpt_a,
        ckpt_b,
        label_a="rho=4/3",
        label_b="rho=1",
        data_path="data/bav_real",
        n_mc=10,
        device="cpu",
        save_dir="./eval_results/model_selection_real",
        seed=42,
    ) -> None:
        """Load real data and both checkpoints, then set up the metrics dict.

        The base-class ``__init__`` is intentionally not called, because it
        would generate synthetic data.

        Args:
            ckpt_a: Checkpoint path for engine A.
            ckpt_b: Checkpoint path for engine B.
            label_a: Display label for engine A.
            label_b: Display label for engine B.
            data_path: Path passed to ``BavTrueDataloader``.
            n_mc: Number of MC repetitions per real-data sample.
            device: Torch device string for data and models.
            save_dir: Output directory; created if it does not exist.
            seed: Only used to suffix the output CSV file names.

        Raises:
            ValueError: If the data loader returns no samples.
        """
        self.seed = seed
        self.data_path = data_path
        self.ckpt_a = ckpt_a
        self.ckpt_b = ckpt_b
        self.label_a = label_a
        self.label_b = label_b
        self.n_mc = n_mc
        self.device = torch.device(device)
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        # Timing follows device automatically
        self.timer = "cuda" if self.device.type == "cuda" else "cpu"

        self._build_data()
        if not self.data:
            # Fail fast with a clear message instead of an IndexError below.
            raise ValueError(f"No real data samples loaded from {self.data_path!r}")
        self._load_models_and_inference()
        self.n_data = len(self.data)  # number of real data samples
        self.n_points = self.data[0].xt.shape[1]

        self._initialize_metrics()

    def _build_data(self) -> None:
        """Load real data with ``BavTrueDataloader`` and MC-expand every sample."""
        data_loader = BavTrueDataloader(self.data_path, device=self.device)
        data = data_loader.load_data()
        self.data = _build_mc_test_batches(data, n_mc=self.n_mc)

    def run_evaluation(self):
        """Score every real-data sample with both engines and save the outputs.

        Prints the timing and confusion matrix, then writes the difference
        plot and two seed-suffixed CSVs (per-batch stats and raw per-MC joint
        values) into ``save_dir``. The base class's JSON/report outputs are
        not produced.
        """
        cm = np.zeros((2, 2), dtype=np.int64)
        eval_start = time.perf_counter()
        with torch.no_grad():
            cm0, a0, b0, batch_stats, all_joint_values = self._evaluate_dataset(
                self.engine_a, self.engine_b, self.data
            )
        cm[0] += cm0
        cm[1] += cm0  # both rows are true real data
        total_time = time.perf_counter() - eval_start

        # pretty print cm
        print(f"\nEvaluation complete in {total_time:.2f}s")
        print(f"Confusion matrix (rows=true, cols=pred):\n{cm}")

        self._save_ll_difference_plot(batch_stats)
        self._save_real_batch_stats_csv(batch_stats, a0, b0)
        self._save_real_joint_values_csv(all_joint_values)

    def _save_ll_difference_plot(self, batch_stats: List[Dict[str, Any]]) -> Path:
        """Scatter each sample's ``diff`` (A - B), coloured by which engine wins."""
        diffs = [s["diff"] for s in batch_stats]
        pos_idx = [i for i, d in enumerate(diffs) if d > 0]
        neg_idx = [i for i, d in enumerate(diffs) if d < 0]
        # Exact ties were previously dropped from the plot; show them separately.
        tie_idx = [i for i, d in enumerate(diffs) if d == 0]

        fig, ax = plt.subplots(figsize=(10, 4))
        ax.scatter(
            pos_idx,
            [diffs[i] for i in pos_idx],
            color="green",
            label=f"{self.label_a} is better",
        )
        ax.scatter(
            neg_idx,
            [diffs[i] for i in neg_idx],
            color="blue",
            label=f"{self.label_b} is better",
        )
        if tie_idx:
            ax.scatter(
                tie_idx, [diffs[i] for i in tie_idx], color="gray", label="tie"
            )
        # Horizontal line at 0
        ax.axhline(0, color="red", linestyle="--")

        ax.set_xlabel("Sample Index (15 participants, each splited into 2, 30 total)")
        ax.set_title(f"LL Difference ({self.label_a} - {self.label_b})")
        ax.legend()
        # Save with the other outputs rather than the current working directory.
        plot_path = self.save_dir / "joint_loglikelihood_differences.png"
        fig.savefig(plot_path, dpi=300)
        plt.close(fig)
        print(f"Saved LL difference plot to {plot_path}")
        return plot_path

    def _save_real_batch_stats_csv(
        self,
        batch_stats: List[Dict[str, Any]],
        a_times: List[float],
        b_times: List[float],
    ) -> Path:
        """Write one row per batch with joint stats, per-call times and prediction."""
        import csv

        csv_path = self.save_dir / f"real_data_batch_stats_seed{self.seed}.csv"
        # Map batch index -> timing so each row gets the matching durations.
        a_time_map = dict(enumerate(a_times))
        b_time_map = dict(enumerate(b_times))
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(
                [
                    "batch_index",
                    "engine_A_joint_mean",
                    "engine_A_joint_std",
                    "engine_B_joint_mean",
                    "engine_B_joint_std",
                    "engine_A_time_s",
                    "engine_B_time_s",
                    "pred_col",
                    "pred_label",
                    "diff",
                ]
            )
            for s in batch_stats:
                idx = int(s.get("batch_index", 0))
                pred = int(s.get("pred_col", 0))
                pred_label = self.label_a if pred == 0 else self.label_b
                writer.writerow(
                    [
                        idx,
                        s.get("engine_A_joint_mean", ""),
                        s.get("engine_A_joint_std", ""),
                        s.get("engine_B_joint_mean", ""),
                        s.get("engine_B_joint_std", ""),
                        a_time_map.get(idx, ""),
                        b_time_map.get(idx, ""),
                        pred,
                        pred_label,
                        s.get("diff", ""),
                    ]
                )
        print(f"Saved batch stats CSV to {csv_path}")
        return csv_path

    def _save_real_joint_values_csv(self, all_joint_values: List[Dict[str, Any]]) -> Path:
        """Write every raw joint value, one row per (batch, MC index) pair."""
        import csv

        csv_path = self.save_dir / f"real_data_all_joint_values_seed{self.seed}.csv"
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(
                [
                    "batch_index",
                    "mc_index",
                    "engine_A_joint_values",
                    "engine_B_joint_values",
                ]
            )
            for s in all_joint_values:
                a_vals = s.get("engine_A_joint_values", [])
                b_vals = s.get("engine_B_joint_values", [])
                for i, a_val in enumerate(a_vals):
                    # B is indexed by A's positions, so an A/B length mismatch
                    # raises IndexError rather than being silently truncated.
                    writer.writerow(
                        [s.get("batch_index", ""), i, str(a_val), str(b_vals[i])]
                    )
        print(f"Saved all joint values CSV to {csv_path}")
        return csv_path


def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def main():
    """Parse CLI arguments, seed RNGs, and run the synthetic or real-data evaluation."""
    parser = argparse.ArgumentParser(
        description="Evaluate model selection between two ACE checkpoints"
    )
    parser.add_argument(
        "--ckpt-a",
        type=str,
        default="checkpoints/tnpdmg_baseline/bav_rho43/best_model.pt",
        help="Path to checkpoint A",
    )
    parser.add_argument(
        "--ckpt-b",
        type=str,
        default="checkpoints/tnpdmg_baseline/bav_rho1/best_model.pt",
        help="Path to checkpoint B",
    )
    parser.add_argument(
        "--label-a", type=str, default="rho=4/3", help="Display label for engine A"
    )
    parser.add_argument(
        "--label-b", type=str, default="rho=1", help="Display label for engine B"
    )
    parser.add_argument(
        "--n-data",
        type=int,
        default=64,
        help="Number of data per class (true A / true B)",
    )
    parser.add_argument(
        "--n-points", type=int, default=400, help="Number of target points per sample"
    )
    parser.add_argument(
        "--n-mc", type=int, default=10, help="Number of MC permutations per sample"
    )
    parser.add_argument("--device", type=str, default="cpu", help="Device to run on")
    parser.add_argument(
        "--save-dir",
        type=str,
        default="./eval_results/model_selection",
        help="Directory to save results",
    )
    # nargs="?" + const=True lets "--real-data" work as a bare flag while still
    # accepting an explicit value such as "--real-data false".
    parser.add_argument(
        "--real-data",
        type=string2bool,
        nargs="?",
        const=True,
        default=False,
        help="Use real BAV data for evaluation (bare flag, or an explicit true/false)",
    )
    parser.add_argument(
        "--data-path",
        type=str,
        default="data/bav_real",
        help="Path to real BAV data (if --real-data is set)",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility"
    )

    args = parser.parse_args()

    set_seed(args.seed)

    if args.real_data:
        print("Running real data model selection evaluation")
        evaluator = ModelSelectionEvaluatorRealData(
            ckpt_a=args.ckpt_a,
            ckpt_b=args.ckpt_b,
            label_a=args.label_a,
            label_b=args.label_b,
            data_path=args.data_path,
            n_mc=args.n_mc,
            device=args.device,
            save_dir=args.save_dir,
            seed=args.seed,
        )

    else:
        print("Running synthetic data model selection evaluation")
        evaluator = ModelSelectionEvaluator(
            ckpt_a=args.ckpt_a,
            ckpt_b=args.ckpt_b,
            label_a=args.label_a,
            label_b=args.label_b,
            n_data=args.n_data,
            n_points=args.n_points,
            n_mc=args.n_mc,
            device=args.device,
            save_dir=args.save_dir,
        )

    evaluator.run_evaluation()


# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
