"""Evaluation script for ACE model with comprehensive metrics and timing."""

import math
import argparse
import json
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Callable

import hydra
import numpy as np
import torch
import csv
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm
from matplotlib import pyplot as plt

from src.data.utils import OfflineBatchLoader, SamplePermutationHelper
from src.models.benchmarks.tnp import TNP, SampleReshaper
from src.utils import DataAttr


def string2bool(b):
    """Parse a command-line style truthy/falsy value into a ``bool``.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        b: Either a ``bool`` (returned unchanged) or a value whose string
            form is one of ``yes/true/t/y/1`` or ``no/false/f/n/0``.
            Matching is case-insensitive and ignores surrounding whitespace.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised
            spelling. Such values used to fall through and silently
            yield ``None``, which argparse then stored as the option value.
    """
    if isinstance(b, bool):
        return b
    token = str(b).strip().lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {b!r}")


class BaselineModelEvaluator:
    """Professional evaluation framework for ACE models."""
    
    def __init__(
        self,
        checkpoint_path: str,
        data_path: Optional[str] = None,
        eval_functions: int=1000,
        num_predictions_per_target: int=100,
        max_context_points: int=4,
        max_target_points: int=128,
        independent_sample: bool=False,
        device: str = "cuda",
        save_dir: str = "./eval_results",
        compile: bool = False,
    ):
        """Initialize evaluator with model and evaluation settings.
        
        Args:
            checkpoint_path: Path to model checkpoint
            data_path: Path to evaluation data (if None, generates online)
            eval_functions: Total number of functions for evaluation statistics
            num_predictions_per_target: predict this number of samples (repetition per function)
            max_context_points: max number of context points used per prediction
            max_target_points: max number of target points used per prediction
            independent_sample: Whether to sample targets independently
            device: Device to run evaluation on
            save_dir: Directory to save evaluation results
            compile: Whether to use torch.compile on inference methods
        """
        self.device = torch.device(device)
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Setup data
        print(f"Loading data from {data_path}")
        self.data_path = data_path
        if data_path:
            self.dataloader = self._build_dataloader(data_path)
        else:
            raise NotImplementedError
        self.batch_size = self.dataloader.dataset.batch_size
        self.eval_functions = eval_functions
        self.num_predictions_per_target = num_predictions_per_target
        self.max_context_points = max_context_points
        self.max_target_points = max_target_points
        self.independent_sample = independent_sample
        n_batch_needed = self.eval_functions // self.batch_size
        if n_batch_needed * self.batch_size < self.eval_functions:
            n_batch_needed += 1
        self.num_eval_batches = min( n_batch_needed, len(self.dataloader) )
        print(f"Loaded {len(self.dataloader)} batches with batch size {self.batch_size}", flush=True)

        # Load model and config
        print(f"Loading checkpoint from {checkpoint_path}")
        self.checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.config = OmegaConf.create(self.checkpoint["config"])
        self.epoch = self.checkpoint.get("epoch", 0)
        self.global_step = self.checkpoint.get("global_step", 0)
        
        # Build and load model
        self.model = self._build_model()
        
        # Handle torch.compile prefix if present
        state_dict = self.checkpoint["model_state_dict"]
        if any(key.startswith("_orig_mod.") for key in state_dict.keys()):
            # Remove _orig_mod. prefix from keys
            state_dict = {
                key.replace("_orig_mod.", ""): value 
                for key, value in state_dict.items()
            }
        
        self.model.load_state_dict(state_dict)
        self.model = self.model.to(self.device)
        self.model.eval()
        
        # Compile model if requested
        self.compile = compile
        if self.compile:
            print("Compiling model with torch.compile (mode=default)...")
            self.model = torch.compile(self.model, mode="default")
        
        # Initialize metrics storage
        self.metrics = {
            "timing": {
                "sample_sequence_times": [],
                "eval_sequence_ll_times": [],
                "prepare_cache_times": [],
                "batch_decode_times": [],
                "per_sample_times": [],
                "total_inference_time": 0.0,
            },
            "performance": {
                "mae": [],
                "mse": [],
                "log_mean_likelihood": [],
                "mean_log_likelihood": [],
            },
            "model_stats": {
                "checkpoint_path": checkpoint_path,
                "epoch": self.epoch,
                "global_step": self.global_step,
                "num_parameters": sum(p.numel() for p in self.model.parameters()),
                "device": str(self.device),
                "compiled": self.compile,
            },
            "evaluation": {
                "num_batches": self.num_eval_batches,
                "batch_size": self.batch_size,
                "total_samples": 0,
                "independent_targets": self.independent_sample,
            },
            "data_stats": {},
            "predictions": [],
        }

    def _build_model(self) -> TNP:
        """Build model from config."""
        cfg = self.config.model
        model = hydra.utils.instantiate(cfg)
        return model

    def _build_dataloader(self, data_path: str) -> DataLoader:
        """Build dataloader for given split."""
        data_path = Path(data_path)
        
        if not data_path.exists():
            raise ValueError(f"Data path does not exist: {data_path}")
        
        dataset = OfflineBatchLoader(data_path, device=str(self.device))
        
        return DataLoader(
            dataset,
            batch_size=None,  # Pre-batched
            shuffle=False,
            num_workers=0,
            pin_memory=(self.device.type == "cuda"),
        )

    def _batch_permute(self, batch: DataAttr, repetition: int):
        if repetition > 1:
            batch_random_permute, perm_info = SamplePermutationHelper.repeat_and_permute_batch(batch, repetition)
            return batch_random_permute, perm_info
        else: # if repetition=1, no permutation is applied
            return batch, None

    def _prediction_unpermute_reshape(self, yt_perm, perm_info):
        """yt_perm: [repetition, B, T, Dy]"""
        if perm_info is None:
            y_dep = yt_perm
        else:
            y_dep = SamplePermutationHelper.unpermute_targets(yt_perm, None, perm_info=perm_info)[0] # [repetition, B, T, Dy]
        return SampleReshaper.torch_dist2custom(y_dep) # [B, T, repetition, Dy]

    def _time_call(self, fn: Callable, *args, **kwargs):
        if self.device.type == "cuda":
            torch.cuda.synchronize()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            out = fn(*args, **kwargs)
            end.record(); end.synchronize()
            t = start.elapsed_time(end) / 1000.0 # in seconds
        else:
            t0 = time.perf_counter()
            out = fn(*args, **kwargs)
            t = time.perf_counter() - t0
        
        return out, t

    def evaluate_batch(self, batch: DataAttr, repetition: int=128) -> Dict[str, float]:
        """Evaluate a single batch with timing."""
        batch_size = batch.xc.shape[0]
        num_targets = batch.xt.shape[1]

        # Time cache preparation, baselines don't have cache
        cache_time = 0.0

        # Time full sequence sampling
        if not self.independent_sample and not self.model.support_non_ar_joint:
            # if repetition=1, _batch_permute() applies no permutation and perm_info=None
            batch_execute, perm_info = self._batch_permute(batch, repetition)

        # start timing
        # sequence sampling
        with torch.no_grad():
            if self.independent_sample:
                yhat, seq_time = self._time_call(
                    self.model.sample, # function to be timed
                    batch.xc, batch.yc, batch.xt, num_samples=repetition # arguments to the function
                ) # yhat: [B, T, repetition, Dy]
            elif self.model.support_non_ar_joint:
                yhat, seq_time = self._time_call(
                    self.model.sample_joint_predictive, # function to be timed
                    batch.xc, batch.yc, batch.xt, num_samples=repetition
                ) # yhat: [B, T, repetition, Dy]
            else:
                yhat, seq_time = self._time_call(
                    self.model.sample_joint_predictive, # function to be timed
                    batch_execute.xc, batch_execute.yc, batch_execute.xt, num_samples=1
                ) # yhat: [repetition * B, T, Dy]
                yhat = yhat.view(repetition, batch_size, num_targets, -1) # [repetition, B, T, Dy]
                # if repetition=1, perm_info=None and we only swap dim 0 <-> 2
                yhat = self._prediction_unpermute_reshape(yhat, perm_info) # [B, T, repetition, Dy]

        with torch.no_grad():
            if self.independent_sample:
                ll, seqll_time = self._time_call(
                    self.model.eval_log_likelihood, # function to be timed
                    batch.xc, batch.yc, batch.xt, batch.yt
                ) # ll: [B,]
                llmean = - math.log(ll.numel()) + ll.view(-1).logsumexp(dim=0)
                meanll = ll.mean()
            elif self.model.support_non_ar_joint:
                ll, seqll_time = self._time_call(
                    self.model.eval_log_joint_likelihood,
                    batch.xc, batch.yc, batch.xt, batch.yt
                ) # ll: [B,]
                llmean = - math.log(ll.numel()) + ll.view(-1).logsumexp(dim=0)
                meanll = ll.mean()
            else:
                ll, seqll_time = self._time_call(
                    self.model.eval_log_joint_likelihood,
                    batch_execute.xc, batch_execute.yc, batch_execute.xt, batch_execute.yt
                ) # ll: [repetition * B,]
                llmean = - math.log(ll.numel()) + ll.view(-1).logsumexp(dim=0)
                meanll = ll.mean()

        mem_all = torch.cuda.get_device_properties(0).total_memory / 2**30
        mem_res = torch.cuda.memory_reserved(0) / 2**30
        mem_aloc = torch.cuda.memory_allocated(0) / 2**30
        mem_fre = mem_res - mem_aloc  # free inside reserved
        print("### cuda memory -- total | reserv | alloc: %.3f | %.3f | %.3f (GB)"%(mem_all, mem_res, mem_aloc), flush=True)
        # Calculate metrics
        per_sample_time = seq_time / (batch_size * repetition * num_targets)

        # Store batch metrics
        batch_metrics = {
            "batch_size": batch_size,
            "num_targets": num_targets,
            "num_context": batch.xc.shape[1],
            "cache_prep_time": cache_time,
            "sequence_time": seq_time,
            "sequence_ll_time": seqll_time,
            "per_sample_time": per_sample_time,
            "throughput_samples_per_sec": (batch_size * num_targets * repetition) / seq_time,
        }
        
        # Calculate prediction error if ground truth available
        if hasattr(batch, "yt") and batch.yt is not None:
            err = yhat - batch.yt.unsqueeze(-2) # [B, T, repetition, Dy]
            with torch.no_grad():
                mse = np.array([torch.mean(err ** 2).item()]) # [1]
                mae = np.array([torch.mean(torch.abs(err)).item()]) # [1]
                llmean = np.array([llmean.item()]) # [1]
                meanll = np.array([meanll.item()]) # [1]
            batch_metrics["mse"] = mse
            batch_metrics["mae"] = mae
            batch_metrics["log_mean_likelihood"] = llmean
            batch_metrics["mean_log_likelihood"] = meanll

        return batch_metrics, yhat

    def plot_evaluation(self, max_number_of_plotted_function: int=8):
        B = min(self.batch_size, max_number_of_plotted_function)
        print("\nVisualize predictions on the first batch")
        print(f"Device: {self.device}")
        print(f"Batch size: {B}")
        with torch.no_grad():
            # Get batch
            batch = self.dataloader.dataset[0]
            if batch.xc.shape[-1] > 1:
                print("Visualization of more than > 1Dx is not supported.")
                return
            C = min(batch.xc.shape[1], self.max_context_points)
            T = min(batch.xt.shape[1], self.max_target_points)

            if batch.xc.shape[0] > B or batch.xc.shape[1] > C or batch.xt.shape[1] > T:
                batch = DataAttr(
                    xc=batch.xc[:B, :C],
                    yc=batch.yc[:B, :C],
                    xb=batch.xb[:B] if batch.xb is not None else None,
                    yb=batch.yb[:B] if batch.yb is not None else None,
                    xt=batch.xt[:B, :T],
                    yt=batch.yt[:B, :T] if batch.yt is not None else None,
                )

            # Move to device if needed
            if batch.xc.device != self.device:
                batch = DataAttr(
                    xc=batch.xc.to(self.device),
                    yc=batch.yc.to(self.device),
                    xb=batch.xb.to(self.device) if batch.xb is not None else None,
                    yb=batch.yb.to(self.device) if batch.yb is not None else None,
                    xt=batch.xt.to(self.device),
                    yt=batch.yt.to(self.device) if batch.yt is not None else None,
                )

            n = self.num_predictions_per_target
            # joint samples
            if self.model.support_non_ar_joint:
                yhat_joint = self.model.sample_joint_predictive(batch.xc, batch.yc, batch.xt, num_samples=n) # [B, T, n, Dy]
            else:
                # if repetition=1, _batch_permute() applies no permutation and perm_info=None
                batch_execute, perm_info = self._batch_permute(batch, n)
                yhat_joint = self.model.sample_joint_predictive(batch_execute.xc, batch_execute.yc, batch_execute.xt, num_samples=1).squeeze(-2) # [n*B, T, Dy]
                yhat_joint = yhat_joint.view(n, batch.xt.shape[0], batch.xt.shape[1], -1) # [n, B, T, Dy]
                # if repetition=1, perm_info=None and we only swap dim 0 <-> 2
                yhat_joint = self._prediction_unpermute_reshape(yhat_joint, perm_info) # [B, T, n, Dy]
            # parallel samples
            if self.independent_sample:
                yhat = self.model.sample(batch.xc, batch.yc, batch.xt, num_samples=n) # [B, T, n, Dy]

        Dy = batch.yc.shape[-1]
        fig, axs = plt.subplots(B, Dy, figsize=(10 * Dy, 5 * B), sharex=True, sharey=True, squeeze=False)
        for i in range(B):
            xc = batch.xc[i, :, 0].cpu().numpy()  # Context inputs
            org = batch.xt[i, :, 0].argsort()  # Sort target inputs
            xt = batch.xt[i, org, 0].cpu().numpy()  # Target inputs
            for dy in range(Dy):
                yc = batch.yc[i, :, dy].cpu().numpy()  # Context outputs
                yt = batch.yt[i, org, dy].cpu().numpy()  # Target outputs
                # mean & std of joint predictive
                yhat_joint_i = yhat_joint[i, org, :, dy].cpu().numpy()  # Predicted outputs
                mu_joint = yhat_joint_i.mean(axis=-1)  # Mean prediction
                sigma_joint = yhat_joint_i.std(axis=-1)  # Std prediction

                axs[i, dy].plot(xc, yc, "o", color="black", label="Context")
                axs[i, dy].plot(xt, yt, "--", color="black", label="Ground Truth")
                axs[i, dy].plot(xt, mu_joint, "-", color="C0", label="Predictive sample mean")
                axs[i, dy].fill_between(
                    xt,
                    mu_joint - sigma_joint,
                    mu_joint + sigma_joint,
                    color="C0",
                    alpha=0.2,
                    label="Predictive sample std"
                )
                for r in range(min(10, n)):
                    axs[i, dy].plot(xt, yhat_joint_i[..., r], "--", color="C0", alpha=0.5, label="Predictive sample" if r==0 else None)

                if self.independent_sample:
                    # mean & std of independent predictive
                    yhat_i = yhat[i, org, :, dy].cpu().numpy()  # Predicted outputs
                    mu_indep = yhat_i.mean(axis=-1)  # Mean prediction
                    sigma_indep = yhat_i.std(axis=-1)  # Std prediction

                    axs[i, dy].plot(xt, mu_indep, "-", color="C1", label="Predictive independent sampling mean")
                    axs[i, dy].fill_between(
                        xt,
                        mu_indep - sigma_indep,
                        mu_indep + sigma_indep,
                        color="C1",
                        alpha=0.2,
                        label="Predictive independent sampling std"
                    )
                    for r in range(min(10, n)):
                        axs[i, dy].plot(xt, yhat_i[..., r], "--", color="C1", alpha=0.5, label="Predictive independent sample" if r==0 else None)

                axs[i, dy].set_title(
                    f"Function {i + 1}, {dy}th-y, {self.model.__class__.__name__} samples & {'joint distribution' if self.model.support_non_ar_joint else 'AR MC joint distribution'}"
                )

        for dy in range(Dy):
            axs[0, dy].legend()

        # Save results
        plot_path = self.save_dir / "predictions.svg"
        fig.savefig(plot_path, format="svg")
        plt.close('all')

    def run_evaluation(self):
        """Run full evaluation with comprehensive metrics."""
        print(f"\nStarting evaluation on {self.num_eval_batches} batches")
        print(f"Device: {self.device}")
        print(f"Batch size: {self.batch_size}")
        all_metrics = []
        total_samples = 0
        total_functions = 0
        eval_start = time.time()
        
        # Progress bar
        pbar = tqdm(range(self.num_eval_batches), desc="Evaluating")
        
        with torch.no_grad():
            data_iter = iter(self.dataloader)
            for batch_idx in pbar:
                # Get batch
                batch = next(data_iter)
                C = min(batch.xc.shape[1], self.max_context_points)
                T = min(batch.xt.shape[1], self.max_target_points)

                if batch.xc.shape[1] > C or batch.xt.shape[1] > T:
                    batch = DataAttr(
                        xc=batch.xc[:, :C],
                        yc=batch.yc[:, :C],
                        xb=batch.xb if batch.xb is not None else None,
                        yb=batch.yb if batch.yb is not None else None,
                        xt=batch.xt[:, :T],
                        yt=batch.yt[:, :T] if batch.yt is not None else None,
                    )
                # Move to device if needed
                if batch.xc.device != self.device:
                    batch = DataAttr(
                        xc=batch.xc.to(self.device),
                        yc=batch.yc.to(self.device),
                        xb=batch.xb.to(self.device) if batch.xb is not None else None,
                        yb=batch.yb.to(self.device) if batch.yb is not None else None,
                        xt=batch.xt.to(self.device),
                        yt=batch.yt.to(self.device) if batch.yt is not None else None,
                    )
                
                # Evaluate batch
                for i in range(self.batch_size):
                    func = DataAttr(
                        xc=batch.xc[i:i+1],
                        yc=batch.yc[i:i+1],
                        xb=batch.xb[i:i+1] if batch.xb is not None else None,
                        yb=batch.yb[i:i+1] if batch.yb is not None else None,
                        xt=batch.xt[i:i+1],
                        yt=batch.yt[i:i+1] if batch.yt is not None else None,
                    )

                    func_metrics, yhat = self.evaluate_batch(func, repetition=self.num_predictions_per_target)
                    all_metrics.append(func_metrics)
                    
                    # Update timing metrics
                    self.metrics["timing"]["sample_sequence_times"].append(func_metrics["sequence_time"])
                    self.metrics["timing"]["eval_sequence_ll_times"].append(func_metrics["sequence_ll_time"])
                    self.metrics["timing"]["prepare_cache_times"].append(func_metrics["cache_prep_time"])
                    self.metrics["timing"]["per_sample_times"].append(func_metrics["per_sample_time"])
                    self.metrics["performance"]["mae"].append(func_metrics["mae"][0])
                    self.metrics["performance"]["mse"].append(func_metrics["mse"][0])
                    self.metrics["performance"]["log_mean_likelihood"].append(func_metrics["log_mean_likelihood"][0])
                    self.metrics["performance"]["mean_log_likelihood"].append(func_metrics["mean_log_likelihood"][0])

                    total_functions += func_metrics["batch_size"]
                    total_samples += func_metrics["batch_size"] * func_metrics["num_targets"] * self.num_predictions_per_target
                
                # Update progress bar
                avg_time = np.mean(self.metrics["timing"]["sample_sequence_times"])
                pbar.set_postfix({
                    "avg_seq_time": f"{avg_time:.3f}s",
                    "throughput": f"{total_samples / (time.time() - eval_start):.1f} samples/s",
                })
        
        # Calculate final statistics
        eval_time = time.time() - eval_start
        self.metrics["timing"]["total_inference_time"] = eval_time
        self.metrics["evaluation"]["batch_size_per_evaluation"] = 1
        self.metrics["evaluation"]["total_samples"] = total_samples
        self.metrics["evaluation"]["samples_per_second"] = total_samples / eval_time
        self.metrics["evaluation"]["num_predictions_per_target"] = self.num_predictions_per_target
        
        # Aggregate batch metrics
        self._aggregate_metrics(all_metrics)

        print(f"\nEvaluation complete in {eval_time:.2f}s")
        print(f"Total samples processed: {total_samples:,}")
        print(f"Average throughput: {total_samples / eval_time:.1f} samples/s")

        # Save results
        self.save_results()

    def _aggregate_metrics(self, all_metrics: List[Dict]):
        """Aggregate metrics across all batches."""
        # Timing statistics
        timing_keys = ["sequence_time", "sequence_ll_time", "cache_prep_time", "per_sample_time", "throughput_samples_per_sec"]
        for key in timing_keys:
            values = [m[key] for m in all_metrics]
            self.metrics["evaluation"][f"{key}_mean"] = float(np.mean(values))
            self.metrics["evaluation"][f"{key}_std"] = float(np.std(values))
            self.metrics["evaluation"][f"{key}_min"] = float(np.min(values))
            self.metrics["evaluation"][f"{key}_max"] = float(np.max(values))
            self.metrics["evaluation"][f"{key}_p50"] = float(np.percentile(values, 50))
            self.metrics["evaluation"][f"{key}_p90"] = float(np.percentile(values, 90))
            self.metrics["evaluation"][f"{key}_p99"] = float(np.percentile(values, 99))
        
        # Error metrics if available
        if "mse" in all_metrics[0]:
            for metric in ["mse", "mae", "log_mean_likelihood", "mean_log_likelihood"]:
                values = np.concatenate([m[metric] for m in all_metrics])
                self.metrics["evaluation"][f"{metric}_mean"] = float(np.mean(values))
                self.metrics["evaluation"][f"{metric}_std"] = float(np.std(values))
        
        # Data statistics
        self.metrics["data_stats"] = {
            "num_context_mean": float(np.mean([m["num_context"] for m in all_metrics])),
            "num_targets_mean": float(np.mean([m["num_targets"] for m in all_metrics])),
            "batch_size": self.batch_size,
            "num_predictions_per_target": self.num_predictions_per_target,
        }

    def save_results(self):
        """Save evaluation results to disk."""
        # Save main metrics as JSON
        metrics_path = self.save_dir / "evaluation_metrics.json"
        with open(metrics_path, "w") as f:
            json.dump(self.metrics, f, indent=2)
        print(f"Saved metrics to {metrics_path}")
        
        # Save summary report
        self._generate_report()
        
        # Save timing & performance data as CSV for easy analysis
        self._save_timing_csv()
        self._save_performance_csv()

    def _generate_report(self):
        """Generate human-readable evaluation report."""
        report_path = self.save_dir / "evaluation_report.txt"
        
        with open(report_path, "w") as f:
            f.write("=" * 80 + "\n")
            f.write("BASELINE MODEL EVALUATION REPORT\n")
            f.write("=" * 80 + "\n\n")
            
            # Model information
            f.write("MODEL INFORMATION\n")
            f.write("-" * 40 + "\n")
            f.write(f"Checkpoint: {self.metrics['model_stats']['checkpoint_path']}\n")
            f.write(f"Epoch: {self.metrics['model_stats']['epoch']}\n")
            f.write(f"Global Step: {self.metrics['model_stats']['global_step']}\n")
            f.write(f"Parameters: {self.metrics['model_stats']['num_parameters']:,}\n")
            f.write(f"Device: {self.metrics['model_stats']['device']}\n\n")
            
            # Evaluation settings
            f.write("EVALUATION SETTINGS\n")
            f.write("-" * 40 + "\n")
            f.write(f"Number of batches: {self.metrics['evaluation']['num_batches']}\n")
            f.write(f"Batch size: {self.metrics['evaluation']['batch_size']}\n")
            f.write(f"Batch size per evaluation: {self.metrics['evaluation']['batch_size_per_evaluation']}\n")
            f.write(f"Independent targets: {self.metrics['evaluation']['independent_targets']}\n")
            f.write(f"Number of repetition predicted on each target: {self.metrics['evaluation']['num_predictions_per_target']}\n")
            f.write(f"Total samples: {self.metrics['evaluation']['total_samples']:,}\n\n")
            
            # Performance metrics
            f.write("PERFORMANCE METRICS\n")
            f.write("-" * 40 + "\n")
            f.write(f"Total evaluation time: {self.metrics['timing']['total_inference_time']:.2f}s\n")
            f.write(f"Overall throughput: {self.metrics['evaluation']['samples_per_second']:.1f} samples/s\n\n")
            
            # Timing statistics
            f.write("TIMING STATISTICS (seconds)\n")
            f.write("-" * 40 + "\n")
            timing_metrics = ["sequence_time", "sequence_ll_time", "cache_prep_time", "per_sample_time"]
            headers = ["Metric", "Mean", "Std", "Min", "Max", "P50", "P90", "P99"]
            f.write(f"{headers[0]:<20} " + " ".join(f"{h:>10}" for h in headers[1:]) + "\n")
            
            for metric in timing_metrics:
                f.write(f"{metric:<20} ")
                for stat in ["mean", "std", "min", "max", "p50", "p90", "p99"]:
                    key = f"{metric}_{stat}"
                    value = self.metrics["evaluation"].get(key, 0)
                    f.write(f"{value:>10.6f} ")
                f.write("\n")
            
            # Throughput statistics
            f.write(f"\n{'throughput (samples/s)':<20} ")
            for stat in ["mean", "std", "min", "max", "p50", "p90", "p99"]:
                key = f"throughput_samples_per_sec_{stat}"
                value = self.metrics["evaluation"].get(key, 0)
                f.write(f"{value:>10.1f} ")
            f.write("\n")
            
            # Error metrics if available
            if "mse_mean" in self.metrics["evaluation"]:
                f.write("\nPREDICTION ERROR METRICS\n")
                f.write("-" * 40 + "\n")
                f.write(f"MSE: {self.metrics['evaluation']['mse_mean']:.6f} ± {self.metrics['evaluation']['mse_std']:.6f}\n")
                f.write(f"MAE: {self.metrics['evaluation']['mae_mean']:.6f} ± {self.metrics['evaluation']['mae_std']:.6f}\n")
                f.write(f"log_mean_likelihood_mean: {self.metrics['evaluation']['log_mean_likelihood_mean']:.6f} ± {self.metrics['evaluation']['log_mean_likelihood_std']:.6f}\n")
                f.write(f"mean_log_likelihood_mean: {self.metrics['evaluation']['mean_log_likelihood_mean']:.6f} ± {self.metrics['evaluation']['mean_log_likelihood_std']:.6f}\n")

            f.write("\n" + "=" * 80 + "\n")
        
        print(f"Saved evaluation report to {report_path}")

    def _save_timing_csv(self):
        """Save detailed timing data as CSV."""

        csv_path = self.save_dir / "timing_data.csv"

        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["batch_idx", "sequence_time", "sequence_ll_time", "cache_prep_time", "per_sample_time"])
            
            for i in range(len(self.metrics["timing"]["sample_sequence_times"])):
                writer.writerow([
                    i,
                    self.metrics["timing"]["sample_sequence_times"][i],
                    self.metrics["timing"]["eval_sequence_ll_times"][i],
                    self.metrics["timing"]["prepare_cache_times"][i],
                    self.metrics["timing"]["per_sample_times"][i],
                ])
        
        print(f"Saved timing data to {csv_path}")

    def _save_performance_csv(self):
        """Save detailed performance data as CSV."""

        csv_path = self.save_dir / "performance_data.csv"

        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["batch_idx", "mse", "mae", "log_mean_likelihood", "mean_log_likelihood"])

            for i in range(len(self.metrics["performance"]["mse"])):
                writer.writerow([
                    i,
                    self.metrics["performance"]["mse"][i],
                    self.metrics["performance"]["mae"][i],
                    self.metrics["performance"]["log_mean_likelihood"][i],
                    self.metrics["performance"]["mean_log_likelihood"][i],
                ])

        print(f"Saved performance data to {csv_path}")


def main():
    """Parse the command line, build the evaluator, then plot and evaluate."""
    parser = argparse.ArgumentParser(description="Evaluate trained ACE model")
    parser.add_argument("checkpoint", type=str, help="Path to model checkpoint")

    # Optional flags, registered in the order they appear in --help.
    option_specs = [
        ("--data-path", dict(type=str, default=None, help="Path to offline evaluation data")),
        ("--num-plot-functions", dict(type=int, default=5, help="Max number of functions to be plotted")),
        ("--num-eval-functions", dict(type=int, default=1000, help="Total number of functions for evaluation statistics")),
        ("--num-contexts", dict(type=int, default=128, help="Number of context points per prediction")),
        ("--num-targets", dict(type=int, default=256, help="Number of target points per prediction")),
        ("--repetition-per-function", dict(type=int, default=100, help="Repetition of inference per function, each with an individual order of target points")),
        ("--independent-sample", dict(type=string2bool, default=False, help="Samples targets independently if specified")),
        ("--device", dict(type=str, default="cuda", help="Device to run on")),
        ("--save-dir", dict(type=str, default="./eval_results", help="Directory to save results")),
        ("--compile", dict(type=string2bool, default=False, help="Use torch.compile on inference methods")),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    cli = parser.parse_args()

    print("Initialize evaluations")
    print(f"Set max {cli.num_contexts} context points")
    print(f"        {cli.num_targets} target points")

    evaluator_options = {
        "checkpoint_path": cli.checkpoint,
        "data_path": cli.data_path,
        "eval_functions": cli.num_eval_functions,
        "num_predictions_per_target": cli.repetition_per_function,
        "max_context_points": cli.num_contexts,
        "max_target_points": cli.num_targets,
        "independent_sample": cli.independent_sample,
        "device": cli.device,
        "save_dir": cli.save_dir,
        "compile": cli.compile,
    }
    evaluator = BaselineModelEvaluator(**evaluator_options)

    # Plotting is optional; a non-positive count skips it.
    if cli.num_plot_functions > 0:
        print("Plot a few predictive examples")
        evaluator.plot_evaluation(max_number_of_plotted_function=cli.num_plot_functions)
    print("Start evaluation")
    evaluator.run_evaluation()

    print("\nEvaluation complete! Results saved to:", cli.save_dir)


# Run the evaluation CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()