"""
Evaluate model selection.
"""

from __future__ import annotations

import argparse
import random
import json
import math
import time
from pathlib import Path
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

import numpy as np
import torch
from omegaconf import OmegaConf
from src.data.utils import SamplePermutationHelper

from src.data.bav_samplerv2 import BavSampler
from src.data.bav_real_data import BavTrueDataloader
from src.models.ace import AmortizedConditioningEngine, InferenceEngine2
from src.models.modules import Embedder, MixtureGaussian, Transformer
from src.utils import DataAttr
import matplotlib.pyplot as plt

# Helpers

def string2bool(b):
    """Parse a command-line style boolean.

    Meant to be used as an ``argparse`` ``type=`` callable.

    Args:
        b: Either a ``bool`` (returned unchanged) or a value whose string form
            is one of ``yes/true/t/y/1`` or ``no/false/f/n/0`` (case-insensitive,
            surrounding whitespace ignored).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
            Raising (instead of silently returning ``None``) lets argparse
            report a proper usage error for typos such as ``--compile maybe``.
    """
    if isinstance(b, bool):
        return b
    token = str(b).strip().lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {b!r}")


def _format_seconds(x: float) -> str:
    """Render a duration given in seconds with a human-friendly unit.

    Args:
        x: Duration in seconds.

    Returns:
        The value scaled to ns / µs / ms / s / min depending on magnitude, or
        the bare string ``"nan"`` when ``x`` is NaN.
    """
    if math.isnan(x):
        # NaN compares False against every threshold below and would otherwise
        # fall through to the minutes branch and be rendered as "nan min".
        return "nan"
    if x < 1e-6:
        return f"{x*1e9:.1f} ns"
    if x < 1e-3:
        return f"{x*1e6:.1f} µs"
    if x < 1.0:
        return f"{x*1e3:.2f} ms"
    if x < 60.0:
        return f"{x:.3f} s"
    return f"{x/60.0:.2f} min"

def _expand_batch_for_mc(batch: DataAttr, n_mc: int):
    """Return ``batch`` expanded for ``n_mc`` Monte-Carlo repetitions.

    Delegates to ``SamplePermutationHelper.repeat_and_permute_batch`` and keeps
    only the first element of the pair it returns. Per the original note, the
    helper repeats along the batch axis and permutes the targets for each
    repetition (behaviour defined in that helper, not visible here).
    """
    expanded, _unused = SamplePermutationHelper.repeat_and_permute_batch(batch, n_mc)
    return expanded


def _build_mc_test_batches(batches: Sequence[DataAttr], n_mc: int) -> List[DataAttr]:
    """Apply ``_expand_batch_for_mc`` to every batch, preserving order.

    Args:
        batches: Batches to expand.
        n_mc: Number of Monte-Carlo repetitions forwarded to the expansion.

    Returns:
        A new list with one expanded batch per input batch.
    """
    expanded = []
    for item in batches:
        expanded.append(_expand_batch_for_mc(item, n_mc=n_mc))
    return expanded


# Model construction and loading
def _build_ace_model(config: OmegaConf) -> AmortizedConditioningEngine:
    """Assemble an ``AmortizedConditioningEngine`` from ``config.model``.

    The embedder, transformer backbone and mixture-of-Gaussians head are
    created in that order from the corresponding sub-sections of the model
    config and then wired into the engine.

    Args:
        config: Configuration object exposing a ``model`` section.

    Returns:
        A freshly constructed (untrained) engine.
    """
    model_cfg = config.model

    embedder_cfg = model_cfg.embedder
    embedder = Embedder(
        dim_x=model_cfg.dim_x,
        dim_y=model_cfg.dim_y,
        hidden_dim=embedder_cfg.hidden_dim,
        out_dim=model_cfg.dim_model,
        depth=embedder_cfg.depth,
    )

    backbone_cfg = model_cfg.backbone
    backbone = Transformer(
        num_layers=backbone_cfg.num_layers,
        dim_model=model_cfg.dim_model,
        num_head=backbone_cfg.num_heads,
        dim_feedforward=backbone_cfg.dim_feedforward,
        dropout=backbone_cfg.dropout,
    )

    head_cfg = model_cfg.head
    head = MixtureGaussian(
        dim_y=model_cfg.dim_y,
        dim_model=model_cfg.dim_model,
        dim_feedforward=head_cfg.dim_feedforward,
        num_components=head_cfg.num_components,
    )

    engine = AmortizedConditioningEngine(
        embedder=embedder,
        backbone=backbone,
        head=head,
        max_buffer_size=model_cfg.max_buffer_size,
        targets_block_size_for_buffer_attend=model_cfg.targets_block_size_for_buffer_attend,
    )
    return engine


def _load_model(
    checkpoint_path: Union[str, Path],
    device: Union[str, torch.device] = "cpu",
    compile_model: bool = False,
) -> Tuple[AmortizedConditioningEngine, OmegaConf]:
    """Restore a trained model and its config from a checkpoint file.

    The checkpoint is expected to hold a ``"config"`` entry and a
    ``"model_state_dict"`` entry. If any state-dict key starts with
    ``"_orig_mod."`` (presumably left behind by saving a ``torch.compile``
    wrapped model), that substring is removed from every key before loading.

    Args:
        checkpoint_path: Location of the checkpoint.
        device: Device used both for ``map_location`` and for the model.
        compile_model: When true, wrap the model with ``torch.compile``.

    Returns:
        ``(model, config)`` with the model switched to eval mode.
    """
    wrapper_prefix = "_orig_mod."

    payload = torch.load(Path(checkpoint_path), map_location=device)
    config = OmegaConf.create(payload["config"])  # type: ignore[index]
    model = _build_ace_model(config)

    weights = payload["model_state_dict"]  # type: ignore[index]
    needs_strip = False
    for name in weights.keys():
        if name.startswith(wrapper_prefix):
            needs_strip = True
            break
    if needs_strip:
        # Only rebuild the mapping when required so an un-prefixed state dict
        # is handed to load_state_dict untouched.
        weights = {name.replace(wrapper_prefix, ""): tensor for name, tensor in weights.items()}

    model.load_state_dict(weights)
    model = model.to(device)
    if compile_model:
        model = torch.compile(model)  # type: ignore[assignment]
    model.eval()
    return model, config


# Metrics helpers
def _prec(tp: int, fp: int) -> float:
    """Precision ``tp / (tp + fp)``; 0.0 when there are no positive predictions."""
    predicted_positive = tp + fp
    if predicted_positive > 0:
        return tp / predicted_positive
    return 0.0


def _rec(tp: int, fn: int) -> float:
    """Recall ``tp / (tp + fn)``; 0.0 when there are no actual positives."""
    actual_positive = tp + fn
    if actual_positive > 0:
        return tp / actual_positive
    return 0.0


def _f1(p: float, r: float) -> float:
    """Harmonic mean of precision ``p`` and recall ``r``; 0.0 when both are zero."""
    denominator = p + r
    if denominator > 0:
        return 2 * p * r / denominator
    return 0.0


def _compute_selection_metrics(cm: np.ndarray, label_a: str, label_b: str,
                               a_times: List[float], b_times: List[float], total_time: float) -> Dict[str, Any]:
    """Summarise a 2x2 confusion matrix and per-call timings.

    Args:
        cm: Confusion matrix indexed ``[true, predicted]`` with class A at
            index 0 and class B at index 1.
        label_a: Display label for class A.
        label_b: Display label for class B.
        a_times: Per-call durations (seconds) for engine A.
        b_times: Per-call durations (seconds) for engine B.
        total_time: Total evaluation time in seconds.

    Returns:
        A dict with ``labels``, ``accuracy``, ``macro_f1``, ``per_class``
        (precision / recall / F1 per label), ``per_data_accuracy`` and
        ``timing``. Ratios with an empty denominator are reported as NaN
        (accuracies, timing stats) or 0.0 (precision / recall / F1).
    """
    nan = float("nan")

    def _mean_and_std(samples):
        # Population statistics (ddof=0); NaN for an empty sample.
        arr = np.array(samples, dtype=float)
        if not arr.size:
            return nan, nan
        return float(arr.mean()), float(arr.std(ddof=0))

    n_total = cm.sum()
    if n_total:
        accuracy = (cm[0, 0] + cm[1, 1]) / n_total
    else:
        accuracy = nan

    per_class: Dict[str, Any] = {}
    per_data_accuracy: Dict[str, Any] = {}
    f1_scores = []
    for own, label in enumerate((label_a, label_b)):
        other = 1 - own
        # Row = true class, column = predicted class.
        tp = int(cm[own, own])
        fn = int(cm[own, other])
        fp = int(cm[other, own])
        precision = _prec(tp, fp)
        recall = _rec(tp, fn)
        f1 = _f1(precision, recall)
        f1_scores.append(f1)
        per_class[label] = {"precision": float(precision), "recall": float(recall), "f1": float(f1)}

        row_total = cm[own].sum()
        row_accuracy = cm[own, own] / row_total if row_total > 0 else nan
        per_data_accuracy[label] = float(row_accuracy)

    macro_f1 = (f1_scores[0] + f1_scores[1]) / 2.0

    mean_a, std_a = _mean_and_std(a_times)
    mean_b, std_b = _mean_and_std(b_times)
    if total_time > 0:
        throughput = float(n_total / total_time)
    else:
        throughput = nan

    timing = {
        "per_call_engine_A_mean_s": mean_a,
        "per_call_engine_A_std_s": std_a,
        "per_call_engine_B_mean_s": mean_b,
        "per_call_engine_B_std_s": std_b,
        "total_inference_time": float(total_time),
        "throughput_samples_per_sec": throughput,
    }

    return {
        "labels": {"row_col_0": label_a, "row_col_1": label_b},
        "accuracy": float(accuracy),
        "macro_f1": float(macro_f1),
        "per_class": per_class,
        "per_data_accuracy": per_data_accuracy,
        "timing": timing,
    }


class ModelSelectionEvaluator:
    """Select between two ACE checkpoints on synthetic BAV data.

    Two ``BavSampler`` instances (``RHO_A=4/3`` and ``RHO_A=1``) each generate
    ``n_data`` test batches. Every (MC-expanded) batch is scored by both
    inference engines; the engine with the higher mean joint log-likelihood is
    the predicted generator. Predictions are accumulated in a 2x2 confusion
    matrix (rows = true generator, columns = prediction) and written to
    ``save_dir`` as JSON, CSV, ``.npy`` and a plain-text report.
    """

    def __init__(
        self,
        ckpt_a: str,
        ckpt_b: str,
        label_a: str = "rho=4/3",
        label_b: str = "rho=1",
        n_data: int = 64,
        n_points: int = 400,
        K: int = 16,
        n_mc: int = 10,
        device: str = "cpu",
        save_dir: str = "./eval_results/model_selection",
        compile: bool = False,
    ) -> None:
        """Build the data, load both models and prepare the metrics dict.

        Args:
            ckpt_a: Checkpoint path for engine A.
            ckpt_b: Checkpoint path for engine B.
            label_a: Display label for engine A / confusion-matrix index 0.
            label_b: Display label for engine B / confusion-matrix index 1.
            n_data: Number of synthetic batches generated per sampler.
            n_points: Number of target points requested per batch.
            K: Forwarded as ``K=`` to ``evaluate_joint_loglikelihood``.
            n_mc: Number of Monte-Carlo repetitions per batch.
            device: Torch device string.
            save_dir: Output directory (created if missing).
            compile: Whether to ``torch.compile`` the loaded models.
        """
        self.device = torch.device(device)
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.ckpt_a = ckpt_a
        self.ckpt_b = ckpt_b
        self.label_a = label_a
        self.label_b = label_b
        self.n_data = n_data
        self.n_points = n_points
        self.K = K
        self.n_mc = n_mc
        # Timing follows device automatically
        self.timer = "cuda" if self.device.type == "cuda" else "cpu"
        self.compile = compile

        # Data first, then models: keeps the order in which the RNG is consumed.
        self._build_data()

        # Models and inference engines
        self._load_models_and_inference()

        # Metrics scaffold
        self._initialize_metrics()

    def _load_models_and_inference(self):
        """Load both checkpoints and wrap each model in an ``InferenceEngine2``."""
        self.model_a, self.cfg_a = _load_model(self.ckpt_a, device=self.device, compile_model=self.compile)
        self.model_b, self.cfg_b = _load_model(self.ckpt_b, device=self.device, compile_model=self.compile)
        # NOTE(review): the two positional 128 arguments are interpreted by
        # InferenceEngine2.from_trained_model; their meaning is not visible here.
        self.engine_a = InferenceEngine2.from_trained_model(self.model_a, 128, 128).to(self.device)
        self.engine_b = InferenceEngine2.from_trained_model(self.model_b, 128, 128).to(self.device)
        self.engine_a.eval()
        self.engine_b.eval()

    def _build_data(self) -> None:
        """Generate the synthetic batches of both classes and MC-expand them."""
        sampler_a = BavSampler(RHO_A=4.0 / 3.0, device=self.device)
        sampler_b = BavSampler(RHO_A=1.0, device=self.device)
        data_a = [sampler_a.generate_test_batch(1, num_target=self.n_points) for _ in range(self.n_data)]
        data_b = [sampler_b.generate_test_batch(1, num_target=self.n_points) for _ in range(self.n_data)]
        if self.device.type == "cuda":
            # Workaround for batches that come back on CPU. Compare the device
            # *type*: torch.device("cuda:0") != torch.device("cuda"), so an
            # equality test would skip the move for indexed CUDA devices.
            data_a = [batch.to(self.device) for batch in data_a]
            data_b = [batch.to(self.device) for batch in data_b]
        self.data_a_mc = _build_mc_test_batches(data_a, n_mc=self.n_mc)
        self.data_b_mc = _build_mc_test_batches(data_b, n_mc=self.n_mc)

    def _initialize_metrics(self):
        """Create ``self.metrics`` with run settings and empty result slots."""
        self.metrics: Dict[str, Any] = {
            "model_stats": {
                "checkpoint_a": self.ckpt_a,
                "checkpoint_b": self.ckpt_b,
                "device": str(self.device),
                "compiled": self.compile,
                "timer": self.timer,
            },
            "evaluation": {
                "n_data_per_class": self.n_data,
                "n_points": self.n_points,
                "K": self.K,
                "n_mc": self.n_mc,
                "confusion_matrix": None,
            },
            "timing": {
                "per_call_engine_A_times": [],
                "per_call_engine_B_times": [],
                "total_inference_time": 0.0,
            },
            "labels": {"row_col_0": self.label_a, "row_col_1": self.label_b},
        }

    def _pred_label(self, pred_col: int) -> str:
        """Map a confusion-matrix column index to its display label."""
        return self.label_a if pred_col == 0 else self.label_b

    def _time_call(self, fn: Callable, *args, **kwargs):
        """Call ``fn(*args, **kwargs)`` and measure how long it takes.

        Uses CUDA events when the timer is ``"cuda"`` and CUDA is available,
        otherwise ``time.perf_counter``.

        Returns:
            ``(result, elapsed_seconds)``.
        """
        use_cuda = (self.timer == "cuda") and torch.cuda.is_available()
        if use_cuda:
            # Drain pending work so it is not attributed to this call.
            torch.cuda.synchronize()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            out = fn(*args, **kwargs)
            end.record()
            end.synchronize()
            # elapsed_time reports milliseconds.
            return out, start.elapsed_time(end) / 1000.0
        t0 = time.perf_counter()
        out = fn(*args, **kwargs)
        return out, (time.perf_counter() - t0)

    def _evaluate_dataset(self, engine_a, engine_b, batches: Sequence[DataAttr]):
        """Score every batch with both engines and count the predictions.

        For each batch the engines' log-likelihood outputs are summed over
        dim 1 and then averaged; engine B is predicted only when its mean is
        strictly larger, so ties (and NaN comparisons) count for engine A.

        Args:
            engine_a: Object exposing ``evaluate_joint_loglikelihood(batch, K=...)``.
            engine_b: Same interface as ``engine_a``.
            batches: Batches to score.

        Returns:
            ``(cm_row, a_times, b_times, batch_stats, all_joint_values)`` where
            ``cm_row`` is a length-2 count of predictions (A, B), the time
            lists hold per-call seconds, ``batch_stats`` holds per-batch
            mean / std / prediction / difference, and ``all_joint_values``
            holds the raw per-row joint values of both engines.
        """
        cm_row = np.zeros((2,), dtype=np.int64)
        a_times: List[float] = []
        b_times: List[float] = []
        batch_stats: List[Dict[str, Any]] = []
        all_joint_values: List[Dict[str, Any]] = []

        n_batch = len(batches)
        # Print progress roughly 100 times at most, and always for the last batch.
        step = max(1, math.ceil(n_batch / 100))
        for i, batch in enumerate(batches):
            ll_a, dt_a = self._time_call(engine_a.evaluate_joint_loglikelihood, batch, K=self.K)
            ll_b, dt_b = self._time_call(engine_b.evaluate_joint_loglikelihood, batch, K=self.K)
            a_times.append(dt_a)
            b_times.append(dt_b)

            joint_a = ll_a.sum(dim=1)
            joint_b = ll_b.sum(dim=1)
            joint_a_mean = joint_a.mean()
            joint_b_mean = joint_b.mean()
            a_mean = float(joint_a_mean.item())
            b_mean = float(joint_b_mean.item())
            a_std = float(joint_a.std().item())
            b_std = float(joint_b.std().item())
            diff = joint_a_mean - joint_b_mean

            pred_col = 1 if joint_b_mean > joint_a_mean else 0
            cm_row[pred_col] += 1

            # collect per-batch joint stats
            batch_stats.append({
                "batch_index": i,
                "engine_A_joint_mean": a_mean,
                "engine_A_joint_std": a_std,
                "engine_B_joint_mean": b_mean,
                "engine_B_joint_std": b_std,
                "pred_col": int(pred_col),
                "diff": float(diff.item()),
            })
            all_joint_values.append({
                "batch_index": i,
                "engine_A_joint_values": joint_a.cpu().numpy().tolist(),
                "engine_B_joint_values": joint_b.cpu().numpy().tolist(),
            })

            if (i + 1) % step == 0 or (i + 1) == n_batch:
                print(f"Batch {i} timing - Engine A: {dt_a:.4f}s, Engine B: {dt_b:.4f}s", flush=True)
                print(f"Batch {i} stats - Engine A: {a_mean:.4f} ± {a_std:.4f}, Engine B: {b_mean:.4f} ± {b_std:.4f}", flush=True)
        return cm_row, a_times, b_times, batch_stats, all_joint_values

    def run_evaluation(self):
        """Evaluate both synthetic datasets, store the metrics and save them.

        Returns:
            ``(cm, m)``: the 2x2 confusion matrix and the dict returned by
            ``_compute_selection_metrics``.
        """
        cm = np.zeros((2, 2), dtype=np.int64)
        # perf_counter is monotonic, so the elapsed time cannot go negative
        # if the wall clock is adjusted during the run.
        eval_start = time.perf_counter()
        with torch.no_grad():
            row_a, a0, b0, stats_a, _ = self._evaluate_dataset(self.engine_a, self.engine_b, self.data_a_mc)
            row_b, a1, b1, stats_b, _ = self._evaluate_dataset(self.engine_a, self.engine_b, self.data_b_mc)
        cm[0] += row_a
        cm[1] += row_b
        all_a_times: List[float] = a0 + a1
        all_b_times: List[float] = b0 + b1
        total_time = time.perf_counter() - eval_start

        self.metrics["timing"]["per_call_engine_A_times"] = all_a_times
        self.metrics["timing"]["per_call_engine_B_times"] = all_b_times
        self.metrics["timing"]["total_inference_time"] = total_time
        self.metrics["evaluation"]["confusion_matrix"] = cm.tolist()
        # Save per-batch joint stats by true dataset
        self.metrics["evaluation"]["batch_joint_stats"] = {
            self.label_a: stats_a,
            self.label_b: stats_b,
        }

        m = _compute_selection_metrics(cm, self.label_a, self.label_b, all_a_times, all_b_times, total_time)
        self.metrics["evaluation"].update({
            "accuracy": m["accuracy"],
            "macro_f1": m["macro_f1"],
            "per_class": m["per_class"],
            "per_data_accuracy": m["per_data_accuracy"],
            "throughput_samples_per_sec": m["timing"]["throughput_samples_per_sec"],
        })

        self.save_results()
        print(f"\nEvaluation complete in {total_time:.2f}s")
        print(f"Confusion matrix (rows=true, cols=pred):\n{cm}")
        print(f"Accuracy: {m['accuracy']:.3f} | MacroF1: {m['macro_f1']:.3f}")
        return cm, m

    def save_results(self):
        """Write the metrics JSON, both CSVs, the text report and the ``.npy`` matrix."""
        metrics_path = self.save_dir / "model_selection_metrics.json"
        with open(metrics_path, "w") as f:
            json.dump(self.metrics, f, indent=2)
        print(f"Saved metrics to {metrics_path}")
        self._save_timing_csv()
        self._save_batch_joint_stats_csv()
        self._generate_report()
        cm = np.array(self.metrics["evaluation"]["confusion_matrix"])  # type: ignore[index]
        np.save(self.save_dir / "confusion_matrix.npy", cm)

    def _save_timing_csv(self):
        """Write one row per engine call pair to ``timing_selection.csv``."""
        import csv
        csv_path = self.save_dir / "timing_selection.csv"
        timing = self.metrics["timing"]
        paired_times = zip(timing["per_call_engine_A_times"], timing["per_call_engine_B_times"])
        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["sample_idx", "engine_A_time_s", "engine_B_time_s"])
            for i, (ta, tb) in enumerate(paired_times):
                writer.writerow([i, ta, tb])
        print(f"Saved timing data to {csv_path}")

    def _save_batch_joint_stats_csv(self):
        """Write the per-batch joint statistics to ``batch_joint_stats.csv``."""
        import csv
        csv_path = self.save_dir / "batch_joint_stats.csv"
        evalm = self.metrics.get("evaluation", {})
        batch_stats = evalm.get("batch_joint_stats", {})
        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([
                "true_dataset", "batch_index",
                "engine_A_joint_mean", "engine_A_joint_std",
                "engine_B_joint_mean", "engine_B_joint_std",
                "pred_label"
            ])
            for dataset_label, stats in batch_stats.items():
                for s in stats:
                    writer.writerow([
                        dataset_label,
                        s.get("batch_index", ""),
                        s.get("engine_A_joint_mean", ""),
                        s.get("engine_A_joint_std", ""),
                        s.get("engine_B_joint_mean", ""),
                        s.get("engine_B_joint_std", ""),
                        self._pred_label(s.get("pred_col", 0)),
                    ])
        print(f"Saved batch joint stats to {csv_path}")

    def _generate_report(self):
        """Write the report to ``model_selection_report.txt`` and ``result.txt``."""
        report_path = self.save_dir / "model_selection_report.txt"
        cm = np.array(self.metrics["evaluation"]["confusion_matrix"])  # type: ignore[index]
        timing = self.metrics["timing"]
        evalm = self.metrics["evaluation"]
        # Build content lines
        lines: List[str] = []

        def w(s: str = ""):
            lines.append(s)

        w("=" * 80)
        w("ACE MODEL SELECTION REPORT")
        w("=" * 80)
        w("")
        w("MODELS")
        w("-" * 40)
        w(f"Engine A: {self.ckpt_a}")
        w(f"Engine B: {self.ckpt_b}")
        w(f"Device: {self.metrics['model_stats']['device']}")
        w(f"Compiled: {self.metrics['model_stats']['compiled']}")
        w(f"Timer: {self.metrics['model_stats']['timer']}")
        w("")
        w("EVALUATION SETTINGS")
        w("-" * 40)
        w(f"n_data_per_class: {self.n_data}")
        w(f"n_points_per_sample: {self.n_points}")
        w(f"K (buffer batch size): {self.K}")
        w(f"n_mc_per_sample: {self.n_mc}")
        w("")
        w("CONFUSION MATRIX (rows=true, cols=pred)")
        w("-" * 40)
        w(str(cm))
        w("")
        w("SUMMARY METRICS")
        w("-" * 40)
        w(f"Accuracy: {evalm['accuracy']:.6f}")
        w(f"Macro F1: {evalm['macro_f1']:.6f}")
        w(f"Per-data accuracy: {evalm['per_data_accuracy']}")
        w("")
        w("TIMING")
        w("-" * 40)
        a_times = np.array(timing["per_call_engine_A_times"])  # type: ignore[index]
        b_times = np.array(timing["per_call_engine_B_times"])  # type: ignore[index]
        a_mean = a_times.mean() if a_times.size else float('nan')
        a_std = a_times.std() if a_times.size else float('nan')
        b_mean = b_times.mean() if b_times.size else float('nan')
        b_std = b_times.std() if b_times.size else float('nan')
        w(f"Engine A per-call: mean={_format_seconds(a_mean)}, std={_format_seconds(a_std)}")
        w(f"Engine B per-call: mean={_format_seconds(b_mean)}, std={_format_seconds(b_std)}")
        w(f"Total evaluation time: {_format_seconds(timing['total_inference_time'])}")
        w(f"Throughput (samples/s): {evalm.get('throughput_samples_per_sec', 0.0):.2f}")
        w("")
        # Batch-level joint stats section
        w("JOINT LOG-LIKELIHOOD STATS")
        w("-" * 40)
        batch_stats = evalm.get("batch_joint_stats", {})
        for true_label, stats in batch_stats.items():
            w(f"True dataset: {true_label}")
            if stats:
                a_means = np.array([s["engine_A_joint_mean"] for s in stats], dtype=float)
                a_stds = np.array([s["engine_A_joint_std"] for s in stats], dtype=float)
                b_means = np.array([s["engine_B_joint_mean"] for s in stats], dtype=float)
                b_stds = np.array([s["engine_B_joint_std"] for s in stats], dtype=float)
                w(f"  Engine A (across batches): mean={a_means.mean():.6f}, std_of_means={a_means.std():.6f}, avg_batch_std={a_stds.mean():.6f}")
                w(f"  Engine B (across batches): mean={b_means.mean():.6f}, std_of_means={b_means.std():.6f}, avg_batch_std={b_stds.mean():.6f}")
            else:
                w("  No batch stats available")
            w("  Per-batch (idx | A mean±std | B mean±std | pred)")
            header = f"    {'idx':>4} | {'A mean±std':>22} | {'B mean±std':>22} | {'pred':>4}"
            w(header)
            for s in stats:
                a_str = f"{s['engine_A_joint_mean']:.6f} ± {s['engine_A_joint_std']:.6f}"
                b_str = f"{s['engine_B_joint_mean']:.6f} ± {s['engine_B_joint_std']:.6f}"
                pred_label = self._pred_label(s.get("pred_col", 0))
                w(f"    {s['batch_index']:>4} | {a_str:>22} | {b_str:>22} | {pred_label:>4}")
            w("")
        w("=" * 80)

        content = "\n".join(lines)
        # The report contains "±"; write UTF-8 explicitly so the output does
        # not depend on the platform's default encoding.
        with open(report_path, "w", encoding="utf-8") as f:
            f.write(content)
        # Also save a duplicate to result.txt for convenience
        result_path = self.save_dir / "result.txt"
        with open(result_path, "w", encoding="utf-8") as f:
            f.write(content)
        print(f"Saved evaluation report to {report_path}")
        print(f"Saved duplicate report to {result_path}")

class ModelSelectionEvaluatorRealData(ModelSelectionEvaluator):
    """Model selection on real BAV data loaded through ``BavTrueDataloader``.

    The true generator is unknown for real data, so this variant does not
    compute accuracy metrics. It records which engine assigns the higher mean
    joint log-likelihood to each batch, plots the per-batch difference and
    writes two CSV files, all inside ``save_dir``.
    """

    def __init__(self, ckpt_a, ckpt_b, label_a="rho=4/3", label_b="rho=1", data_path="data/bav_real",
                 K=16, n_mc=10, device="cpu", save_dir="./eval_results/model_selection_real",
                 compile=False, seed=42) -> None:
        """Load the real data and both models.

        The parent ``__init__`` is intentionally not called: it takes
        ``n_data`` / ``n_points`` as inputs, whereas here they are derived
        from the loaded data.

        Args:
            ckpt_a: Checkpoint path for engine A.
            ckpt_b: Checkpoint path for engine B.
            label_a: Display label for engine A.
            label_b: Display label for engine B.
            data_path: Location handed to ``BavTrueDataloader``.
            K: Forwarded as ``K=`` to ``evaluate_joint_loglikelihood``.
            n_mc: Number of Monte-Carlo repetitions per batch.
            device: Torch device string.
            save_dir: Output directory (created if missing).
            compile: Whether to ``torch.compile`` the loaded models.
            seed: Only used here to tag the output file names.

        Raises:
            ValueError: If the data loader returns no batches.
        """
        self.seed = seed
        self.data_path = data_path
        self.ckpt_a = ckpt_a
        self.ckpt_b = ckpt_b
        self.label_a = label_a
        self.label_b = label_b
        self.K = K
        self.n_mc = n_mc
        self.device = torch.device(device)
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.compile = compile
        # Timing follows device automatically
        self.timer = "cuda" if self.device.type == "cuda" else "cpu"

        self._build_data()
        if not self.data:
            # Fail early with a clear message instead of an IndexError below.
            raise ValueError(f"No real data batches were loaded from {self.data_path!r}")
        self._load_models_and_inference()
        self.n_data = len(self.data)  # number of real data samples
        self.n_points = self.data[0].xt.shape[1]

        self._initialize_metrics()

    def _build_data(self) -> None:
        """Load the real batches and MC-expand each of them into ``self.data``."""
        data_loader = BavTrueDataloader(self.data_path, device=self.device)
        data = data_loader.load_data()
        self.data = _build_mc_test_batches(data, n_mc=self.n_mc)

    def run_evaluation(self):
        """Score all real batches with both engines and save plot and CSVs.

        Unlike the parent implementation this returns ``None`` and does not
        call ``save_results``.
        """
        cm = np.zeros((2, 2), dtype=np.int64)
        eval_start = time.time()
        with torch.no_grad():
            cm0, a0, b0, batch_stats, all_joint_values = self._evaluate_dataset(self.engine_a, self.engine_b, self.data)

        # There is no ground-truth generator for real data, so the same
        # prediction counts are placed in both rows.
        cm[0] += cm0
        cm[1] += cm0
        total_time = time.time() - eval_start

        print(f"\nEvaluation complete in {total_time:.2f}s")
        print(f"Confusion matrix (rows=true, cols=pred):\n{cm}")

        self._plot_ll_differences([s['diff'] for s in batch_stats])
        self._write_real_batch_stats_csv(batch_stats, a0, b0)
        self._write_all_joint_values_csv(all_joint_values)

    def _plot_ll_differences(self, diffs):
        """Scatter-plot the per-batch log-likelihood differences (A minus B)."""
        # Separate indices for positive and negative diffs
        pos_idx = [i for i, d in enumerate(diffs) if d > 0]
        neg_idx = [i for i, d in enumerate(diffs) if d < 0]

        plt.figure(figsize=(10, 4))
        plt.scatter(pos_idx, [diffs[i] for i in pos_idx],
                    color='green', label=r'$\rho = \frac{4}{3}$ is better')
        plt.scatter(neg_idx, [diffs[i] for i in neg_idx],
                    color='blue', label=r'$\rho = 1$ is better')
        # Horizontal line at 0
        plt.axhline(0, color='red', linestyle='--')

        plt.xlabel('Sample Index (15 participants, each splited into 2, 30 total)')
        plt.title(f'LL Difference ({self.label_a} - {self.label_b})')
        plt.legend()
        # Save next to the other artefacts rather than into the current
        # working directory, where runs with different save dirs would
        # overwrite each other.
        plot_path = self.save_dir / "joint_loglikelihood_differences.png"
        plt.savefig(plot_path, dpi=300)
        plt.close()
        print(f"Saved LL difference plot to {plot_path}")

    def _write_real_batch_stats_csv(self, batch_stats, a_times, b_times):
        """Write one CSV row per batch with joint stats, timings and prediction."""
        import csv
        csv_path = self.save_dir / f"real_data_batch_stats_TNPB-K{self.K}_seed{self.seed}.csv"
        # Look timings up by batch index to keep them aligned with the stats.
        a_time_map = dict(enumerate(a_times))
        b_time_map = dict(enumerate(b_times))
        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([
                "batch_index",
                "engine_A_joint_mean",
                "engine_A_joint_std",
                "engine_B_joint_mean",
                "engine_B_joint_std",
                "engine_A_time_s",
                "engine_B_time_s",
                "pred_col",
                "pred_label",
                "diff",
            ])
            for s in batch_stats:
                idx = int(s.get("batch_index", 0))
                pred = int(s.get("pred_col", 0))
                pred_label = self.label_a if pred == 0 else self.label_b
                writer.writerow([
                    idx,
                    s.get("engine_A_joint_mean", ""),
                    s.get("engine_A_joint_std", ""),
                    s.get("engine_B_joint_mean", ""),
                    s.get("engine_B_joint_std", ""),
                    a_time_map.get(idx, ""),
                    b_time_map.get(idx, ""),
                    pred,
                    pred_label,
                    s.get("diff", ""),
                ])
        print(f"Saved batch stats CSV to {csv_path}")

    def _write_all_joint_values_csv(self, all_joint_values):
        """Write one CSV row per (batch, MC repetition) with both joint values."""
        import csv

        def first_component(value):
            # Each entry is normally a one-element list (the code always read
            # index 0); a bare scalar is accepted as well.
            return value[0] if isinstance(value, (list, tuple)) else value

        csv_path = self.save_dir / f"real_data_all_joint_values_TNPB-K{self.K}_seed{self.seed}.csv"
        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([
                "batch_index",
                "mc_index",
                "engine_A_joint_values",
                "engine_B_joint_values",
            ])
            for s in all_joint_values:
                a_values = s.get("engine_A_joint_values", [])
                b_values = s.get("engine_B_joint_values", [])
                for mc_index, a_value in enumerate(a_values):
                    writer.writerow([
                        s.get("batch_index", ""),
                        mc_index,
                        str(first_component(a_value)),
                        str(first_component(b_values[mc_index])),
                    ])
        print(f"Saved all joint values CSV to {csv_path}")

def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def main():
    """Parse the command line, seed the RNGs and run the chosen evaluation.

    ``--real-data`` selects ``ModelSelectionEvaluatorRealData``; otherwise the
    synthetic ``ModelSelectionEvaluator`` is used.
    """
    parser = argparse.ArgumentParser(description="Evaluate model selection between two ACE checkpoints")

    # (flag, type, default, help) — registered in this order.
    option_table = [
        ("--ckpt-a", str, "checkpoints/bavrho43_model/best_model.pt", "Path to checkpoint A"),
        ("--ckpt-b", str, "checkpoints/bavrho1_model/best_model.pt", "Path to checkpoint B"),
        ("--label-a", str, "rho=4/3", "Display label for engine A"),
        ("--label-b", str, "rho=1", "Display label for engine B"),
        ("--n-data", int, 64, "Number of data per class (true A / true B)"),
        ("--n-points", int, 400, "Number of target points per sample"),
        ("--n-mc", int, 10, "Number of MC permutations per sample"),
        ("--K", int, 16, "Buffer batch size for inference engine"),
        ("--device", str, "cpu", "Device to run on"),
        ("--save-dir", str, "./eval_results/model_selection", "Directory to save results"),
        ("--compile", string2bool, False, "Use torch.compile on models"),
        ("--real-data", string2bool, False, "Use real BAV data for evaluation"),
        ("--data-path", str, "data/bav_real", "Path to real BAV data (if --real-data is set)"),
        ("--seed", int, 42, "Random seed for reproducibility"),
    ]
    for flag, caster, default_value, help_text in option_table:
        parser.add_argument(flag, type=caster, default=default_value, help=help_text)

    args = parser.parse_args()

    set_seed(args.seed)

    # Keyword arguments accepted by both evaluator classes.
    shared_kwargs = dict(
        ckpt_a=args.ckpt_a,
        ckpt_b=args.ckpt_b,
        label_a=args.label_a,
        label_b=args.label_b,
        K=args.K,
        n_mc=args.n_mc,
        device=args.device,
        save_dir=args.save_dir,
        compile=args.compile,
    )

    if not args.real_data:
        print("Running synthetic data model selection evaluation")
        evaluator = ModelSelectionEvaluator(
            n_data=args.n_data,
            n_points=args.n_points,
            **shared_kwargs,
        )
    else:
        print("Running real data model selection evaluation")
        evaluator = ModelSelectionEvaluatorRealData(
            data_path=args.data_path,
            seed=args.seed,
            **shared_kwargs,
        )

    evaluator.run_evaluation()


if __name__ == "__main__":
    main()
