"""Evaluation script for ACE model with comprehensive metrics and timing."""

import argparse
import json
import csv
import time
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from src.data.utils import OfflineBatchLoader, SamplePermutationHelper
from tqdm import tqdm

#from src.data.gp_sampler import GPBatchLoader, GPSampler, generate_offline_batches
from src.models.ace import AmortizedConditioningEngine, InferenceEngine2
from src.models.modules import Embedder, MixtureGaussian, Transformer, MultiChannelMixtureGaussian
from src.utils import DataAttr

def string2bool(b):
    """Convert a command-line style truthy/falsy value to a ``bool``.

    Args:
        b: Either a ``bool`` (returned unchanged) or a value whose string
            form is one of "yes"/"true"/"t"/"y"/"1" or
            "no"/"false"/"f"/"n"/"0". Matching is case-insensitive and
            ignores surrounding whitespace.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised
            spelling. Raising (rather than silently returning ``None``)
            lets argparse report a usage error when this function is used
            as ``type=``.
    """
    if isinstance(b, bool):
        return b
    token = str(b).strip().lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {b!r}")

class ModelEvaluator:
    """Professional evaluation framework for ACE models."""
    
    def __init__(
        self,
        checkpoint_path: str,
        K: int = 4,
        data_path: Optional[str] = None,
        eval_functions: int=1000,
        num_predictions_per_target: int=100,
        max_context_points: int=4,
        max_target_points: int=128,
        independent_sample: bool=False,
        device: str = "cuda",
        save_dir: str = "./eval_results",
        compile: bool = False,
    ):
        """Initialize evaluator with model and evaluation settings.

        Args:
            checkpoint_path: Path to model checkpoint file.
            K: Decoding batch size for sample_sequence.
            data_path: Path to evaluation data (offline only).
            eval_functions: Total number of functions/samples to evaluate.
            num_predictions_per_target: Number of predictions per target.
            max_context_points: Max context points.
            max_target_points: Max target points.
            independent_sample: Whether to sample predictions independently.
            device: "auto", "cpu", "cuda", or explicit device string like "cuda:0".
            save_dir: Directory to save evaluation results.
            compile_model: Whether to use torch.compile on the model (PyTorch 2.x).
        """
        self.device = torch.device(device)
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.K=K

        # Setup data
        print(f"Loading data from {data_path}")
        self.data_path = data_path
        if data_path:
            self.dataloader = self._build_dataloader(data_path)
        else:
            raise NotImplementedError
        self.batch_size = self.dataloader.dataset.batch_size
        self.eval_functions = eval_functions
        self.num_predictions_per_target = num_predictions_per_target
        self.max_context_points = max_context_points
        self.max_target_points = max_target_points
        self.independent_sample = independent_sample
        n_batch_needed = self.eval_functions // self.batch_size
        if n_batch_needed * self.batch_size < self.eval_functions:
            n_batch_needed += 1
        self.num_eval_batches = min(n_batch_needed, len(self.dataloader))
        print(f"Loaded {len(self.dataloader)} batches with batch size {self.batch_size}", flush=True)

        # Load model and config
        print(f"Loading checkpoint from {checkpoint_path}")
        self.checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.config = OmegaConf.create(self.checkpoint["config"])
        self.epoch = self.checkpoint.get("epoch", 0)
        self.global_step = self.checkpoint.get("global_step", 0)
        
        # Build and load model
        self.model = self._build_model()
        
        # Handle torch.compile prefix if present
        state_dict = self.checkpoint["model_state_dict"]
        if any(key.startswith("_orig_mod.") for key in state_dict.keys()):
            # Remove _orig_mod. prefix from keys
            state_dict = {
                key.replace("_orig_mod.", ""): value 
                for key, value in state_dict.items()
            }
        
        self.model.load_state_dict(state_dict)
        self.model = self.model.to(self.device)
        self.model.eval()
        
        # Compile model if requested
        self.compile = compile
        if self.compile:
            print("Compiling model with torch.compile (mode=default)...")
            self.model = torch.compile(self.model, mode="default")
        
        # Create inference engine
        self.inference_engine = InferenceEngine2.from_trained_model(self.model)
        self.inference_engine = self.inference_engine.to(self.device)

            
        # Initialize metrics storage
        self.metrics = {
            "timing": {
                "sample_sequence_times": [],
                "eval_sequence_ll_times": [],
                "batch_decode_times": [],
                "per_sample_times": [],
                "total_inference_time": 0.0,
                # Detailed profiling
                "prefill_times": [],  # Self-attention prefill
                "transformer_decode_times": [],  # Cache-based decode
                "embedding_times": [],  # Embedding computation
                "head_times": [],  # Head prediction
                "context_update_times": [],  # Context embedding updates
            },
            "performance": {
                "mae": [],
                "mse": [],
                "log_mean_likelihood": [],
                "mean_log_likelihood": [],
            },
            "model_stats": {
                "checkpoint_path": checkpoint_path,
                "epoch": self.epoch,
                "global_step": self.global_step,
                "num_parameters": sum(p.numel() for p in self.model.parameters()),
                "device": str(self.device),
                "compiled": self.compile,
            },
            "evaluation": {
                "num_batches": self.num_eval_batches,
                "batch_size": self.batch_size,
                "K": self.K,
                "total_samples": 0,
            },
            "data_stats": {},
            "predictions": [],
        }
    
    def _build_model(self) -> AmortizedConditioningEngine:
        """Build model from config."""
        cfg = self.config.model
        
        embedder = Embedder(
            dim_x=cfg.dim_x,
            dim_y=cfg.dim_y,
            hidden_dim=cfg.embedder.hidden_dim,
            out_dim=cfg.dim_model,  # Use dim_model from config
            depth=cfg.embedder.depth,
        )
        
        backbone = Transformer(
            num_layers=cfg.backbone.num_layers,
            dim_model=cfg.dim_model,
            num_head=cfg.backbone.num_heads,
            dim_feedforward=cfg.backbone.dim_feedforward,
            dropout=cfg.backbone.dropout,
        )

        if cfg.dim_y == 1:
            head = MixtureGaussian(
                dim_y=cfg.dim_y,
                dim_model=cfg.dim_model,
                dim_feedforward=cfg.head.dim_feedforward,
                num_components=cfg.head.num_components,
            )
        else:
            head = MultiChannelMixtureGaussian(
                dim_y=cfg.dim_y,
                dim_model=cfg.dim_model,
                dim_feedforward=cfg.head.dim_feedforward,
                num_components=cfg.head.num_components,
            )

        model = AmortizedConditioningEngine(
            embedder=embedder,
            backbone=backbone,
            head=head,
            max_buffer_size=cfg.max_buffer_size,
            targets_block_size_for_buffer_attend=cfg.targets_block_size_for_buffer_attend,
        )
        
        return model
    
    def _build_dataloader(self, data_path: str) -> DataLoader:
        """Build dataloader for given split."""
        data_path = Path(data_path)
        
        if not data_path.exists():
            raise ValueError(f"Data path does not exist: {data_path}")
        
        dataset = OfflineBatchLoader(data_path, device=str(self.device))
        
        return DataLoader(
            dataset,
            batch_size=None,  # Pre-batched
            shuffle=False,
            num_workers=0,
            pin_memory=(self.device.type == "cuda"),
        )

    def _batch_permute(self, batch: DataAttr, repetition: int):
        batch_random_permute, perm_info = SamplePermutationHelper.repeat_and_permute_batch(batch, repetition)
        return batch_random_permute, perm_info

    def _prediction_unpermute(self, yt_perm, perm_info):
        """yt_perm: [repetition, B, T, Dy]"""
        y_dep = SamplePermutationHelper.unpermute_targets(yt_perm, None, perm_info=perm_info)[0] # [repetition, B, T, Dy]
        return y_dep

    def _time_call(self, fn, *args, **kwargs):
        """Time a callable, using CUDA events when on GPU for accuracy."""
        if self.device.type == "cuda":
            torch.cuda.synchronize()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            out = fn(*args, **kwargs)
            end.record(); end.synchronize()
            t = start.elapsed_time(end) / 1000.0  # seconds
        else:
            t0 = time.perf_counter()
            out = fn(*args, **kwargs)
            t = time.perf_counter() - t0
        return out, t


    def evaluate_batch(self, batch: DataAttr, repetition: int=100) -> Dict[str, float]:
        """Evaluate a single batch with timing."""

        batch_size = batch.xc.shape[0]
        num_targets = batch.xt.shape[1]

        if repetition > 1:
            batch_random_permute, perm_info = self._batch_permute(batch, repetition) # [repetition*B, T, Dx]
            batch_execute = batch_random_permute
        else:
            batch_execute = batch

        with torch.no_grad():
            # Time full sequence sampling
            pred, seq_time = self._time_call(
                self.inference_engine.sample_sequence, batch_execute, K=self.K
            )

            yhat = pred.yc
            yhat = yhat.view(repetition, batch_size, num_targets, -1)

            if repetition > 1:
                yhat = self._prediction_unpermute(yhat, perm_info) # [repetition, B, T, Dy]


            # Time joint log likelihood evaluation if ground truth available
            ll, ll_time = self._time_call(
                self.inference_engine.evaluate_joint_loglikelihood, batch_execute, K=self.K
            )
            ll = ll.sum(-2).view(repetition, batch_size)  # [repetition, B]
 
        # Run detailed profiling on a subset
        if len(self.metrics["timing"]["prefill_times"]) < 3:  # Profile first 3 batches
            profile_metrics = self._profile_inference_components(batch)
            for key, value in profile_metrics.items():
                if key in self.metrics["timing"]:
                    self.metrics["timing"][key].append(value)
        
        # Calculate metrics
        per_sample_time = seq_time / (batch_size * repetition * num_targets)

        # Store batch metrics
        batch_metrics = {
            "batch_size": batch_size,
            "num_targets": num_targets,
            "num_context": batch.xc.shape[1],
            "sequence_time": seq_time,
            "sequence_ll_time": ll_time,
            "per_sample_time": per_sample_time,
            "throughput_samples_per_sec": (batch_size * num_targets * repetition) / seq_time,
        }
    
        # Calculate prediction error if ground truth available
        if hasattr(batch, "yt") and batch.yt is not None:
            y_true = batch.yt
            y_true = y_true.unsqueeze(0)  # [1, B, T, Dy]
            # Handle case where sample_sequence doesn't predict all targets due to T // K
            num_predictions = yhat.shape[1]
            if num_predictions < y_true.shape[1]:
                # Only compare the predictions we have
                batch_yt_truncated = y_true[:, :num_predictions, :]
                mse = torch.mean((yhat - batch_yt_truncated) ** 2).item()
                mae = torch.mean(torch.abs(yhat - batch_yt_truncated)).item()
                llmean = (ll[:, :num_predictions].exp()).mean().log().item()
                meanll = ll[:, :num_predictions].mean().item()
                batch_metrics["warning"] = f"Only {num_predictions}/{y_true.shape[1]} targets predicted (T//K issue)"
                print(batch_metrics["warning"])
            else:
                mse = torch.mean((yhat - y_true) ** 2).item() 
                mae = torch.mean(torch.abs(yhat - y_true)).item()
                llmean = np.array((torch.logsumexp(ll.reshape(-1), dim=0) - math.log(ll.numel())).cpu()).item() #safe ll mean
                meanll = ll.mean().item()
            batch_metrics["mse"] = mse
            batch_metrics["mae"] = mae
            batch_metrics["log_mean_likelihood"] = llmean
            batch_metrics["mean_log_likelihood"] = meanll

        return batch_metrics, yhat

    def _profile_inference_components(self, batch: DataAttr) -> Dict[str, float]:
        """Profile individual components of inference."""
        import torch.nn.functional as F
        
        # We'll manually run through the inference steps with timing
        profile_times = {}
        
        # 1. Time context embedding
        embed_start = time.time()
        self.inference_engine.store_context_embeddings(batch)
        embed_time = time.time() - embed_start
        profile_times["embedding_times"] = embed_time
        
        # 2. Initialize KV cache
        T = batch.xt.shape[1]
        max_seq = self.inference_engine.context_embeddings.shape[1] + T
        self.inference_engine.init_kv_cache(batch.xt.shape[0], max_seq, device=str(batch.xt.device))
        
        # 3. Time prefill (self-attention)
        context_len = self.inference_engine.context_embeddings.shape[1]
        selfattention_mask = self.inference_engine._get_cached_selfattn_mask(
            context_len, str(batch.xt.device)
        )
        
        prefill_start = time.time()
        self.inference_engine.prefill_kv_cache(selfattention_mask)
        prefill_time = time.time() - prefill_start
        profile_times["prefill_times"] = prefill_time
        
        # 4. Profile a single decode step (using cache)
        # Get first target
        query = DataAttr(
            xc=None, yc=None, xb=None, yb=None,
            xt=batch.xt[:, 0:1, :],
            yt=batch.yt[:, 0:1, :] if batch.yt is not None else None
        )
        
        # Time embedding
        embed_start = time.time()
        query_embedding = self.inference_engine.embedder.embed_target(query)
        embed_single_time = time.time() - embed_start
        
        # Time transformer decode (cache-based)
        decode_start = time.time()
        z = self.inference_engine.transformer_decode(query_embedding)
        decode_time = time.time() - decode_start
        profile_times["transformer_decode_times"] = decode_time
        
        # Time head prediction
        head_start = time.time()
        head_output = self.inference_engine.head(z[:, -1, :].unsqueeze(1), num_samples=1)
        head_time = time.time() - head_start
        profile_times["head_times"] = head_time
        
        # Time context update
        yhat = head_output.samples.squeeze(2)
        from src.utils import create_context_buffer_datapoint
        prediction = create_context_buffer_datapoint(query, yhat)
        
        update_start = time.time()
        self.inference_engine.update_context_embeddings(prediction)
        update_time = time.time() - update_start
        profile_times["context_update_times"] = update_time
        
        return profile_times
    
    def run_evaluation(self):
        """Run full evaluation with comprehensive metrics."""
        # Comprehensive evaluation header
        print(f"\nStarting evaluation on {self.num_eval_batches} batches", flush=True)
        print(f"Checkpoint: {self.metrics['model_stats']['checkpoint_path']}", flush=True)
        print(f"Data path: {self.data_path}", flush=True)
        print(f"Device: {self.device} | Compiled: {self.compile}", flush=True)
        print(f"Batch size: {self.batch_size} | K (decode batch): {self.K}", flush=True)
        print(f"Max contexts: {self.max_context_points} | Max targets: {self.max_target_points}", flush=True)
        print(f"Repetition per function: {self.num_predictions_per_target}", flush=True)
        print(f"Planned eval functions: {self.eval_functions}", flush=True)
        print(f"Saving results to: {self.save_dir}", flush=True)
        
        all_metrics = []
        total_samples = 0
        total_functions = 0
        eval_start = time.time()
        
        # Progress bar
        pbar = tqdm(range(self.num_eval_batches), desc="Evaluating")
        
        with torch.no_grad():
            data_iter = iter(self.dataloader)
            for batch_idx in pbar:
                # Get batch
                batch = next(data_iter)
                C = min(batch.xc.shape[1], self.max_context_points)
                T = min(batch.xt.shape[1], self.max_target_points)

                if batch.xc.shape[1] > C or batch.xt.shape[1] > T:
                    batch = DataAttr(
                        xc=batch.xc[:, :C],
                        yc=batch.yc[:, :C],
                        xb=batch.xb if batch.xb is not None else None,
                        yb=batch.yb if batch.yb is not None else None,
                        xt=batch.xt[:, :T],
                        yt=batch.yt[:, :T] if batch.yt is not None else None,
                    )

                # Move to device if needed
                if batch.xc.device != self.device:
                    batch.to(self.device)
                    
                # Evaluate batch
                for i in range(self.batch_size):
                    func = DataAttr(
                        xc=batch.xc[i:i+1],
                        yc=batch.yc[i:i+1],
                        xb=batch.xb[i:i+1] if batch.xb is not None else None,
                        yb=batch.yb[i:i+1] if batch.yb is not None else None,
                        xt=batch.xt[i:i+1],
                        yt=batch.yt[i:i+1] if batch.yt is not None else None,
                    )

                    func_metrics, yhat = self.evaluate_batch(func, repetition=self.num_predictions_per_target)
                    all_metrics.append(func_metrics)

                    # Update timing metrics
                    self.metrics["timing"]["sample_sequence_times"].append(func_metrics["sequence_time"])
                    self.metrics["timing"]["eval_sequence_ll_times"].append(func_metrics["sequence_ll_time"])
                    self.metrics["timing"]["per_sample_times"].append(func_metrics["per_sample_time"])
                    self.metrics["performance"]["mae"].append(func_metrics["mae"])
                    self.metrics["performance"]["mse"].append(func_metrics["mse"])
                    self.metrics["performance"]["log_mean_likelihood"].append(func_metrics["log_mean_likelihood"])
                    self.metrics["performance"]["mean_log_likelihood"].append(func_metrics["mean_log_likelihood"])
                    total_functions += func_metrics["batch_size"]
                    total_samples += func_metrics["batch_size"] * func_metrics["num_targets"] * self.num_predictions_per_target
                
                # Optionally save some predictions
                if batch_idx < 10:  # Save last batch element of first 10 batches of predictions
                    self.metrics["predictions"].append({
                        "batch_idx": batch_idx,
                        "xc": batch.xc[-1].cpu().numpy().tolist(),
                        "yc": batch.yc[-1].cpu().numpy().tolist(),
                        "xt": batch.xt[-1].cpu().numpy().tolist(),
                        "yt": batch.yt[-1].cpu().numpy().tolist() if batch.yt is not None else None,
                        "predictions": yhat.cpu().numpy().tolist(), # yhat is the last yhat from above loop
                    })
        
        # Calculate final statistics
        eval_time = time.time() - eval_start
        self.metrics["timing"]["total_inference_time"] = eval_time
        self.metrics["evaluation"]["total_samples"] = total_samples
        self.metrics["evaluation"]["samples_per_second"] = total_samples / eval_time
        
        # Aggregate batch metrics
        self._aggregate_metrics(all_metrics)
        
        print(f"\nEvaluation complete in {eval_time:.2f}s", flush=True)
        print(f"Total samples processed: {total_samples:,}", flush=True)
        print(f"Average throughput: {total_samples / eval_time:.1f} samples/s", flush=True)
        
        # Save results
        self.save_results()
    
    def _aggregate_metrics(self, all_metrics: List[Dict]):
        """Aggregate metrics across all batches."""
        # Timing statistics
        timing_keys = ["sequence_time", "sequence_ll_time", "per_sample_time", "throughput_samples_per_sec"]
        for key in timing_keys:
            values = [m[key] for m in all_metrics]
            self.metrics["evaluation"][f"{key}_mean"] = float(np.mean(values))
            self.metrics["evaluation"][f"{key}_std"] = float(np.std(values))
            self.metrics["evaluation"][f"{key}_min"] = float(np.min(values))
            self.metrics["evaluation"][f"{key}_max"] = float(np.max(values))
            self.metrics["evaluation"][f"{key}_p50"] = float(np.percentile(values, 50))
            self.metrics["evaluation"][f"{key}_p90"] = float(np.percentile(values, 90))
            self.metrics["evaluation"][f"{key}_p99"] = float(np.percentile(values, 99))
        
        # Error metrics if available
        if "mse" in all_metrics[0]:
            for metric in ["mse", "mae", "log_mean_likelihood", "mean_log_likelihood"]:
                values = [m[metric] for m in all_metrics]
                self.metrics["evaluation"][f"{metric}_mean"] = float(np.mean(values))
                self.metrics["evaluation"][f"{metric}_std"] = float(np.std(values))
        
        # Data statistics
        self.metrics["data_stats"] = {
            "num_context_mean": float(np.mean([m["num_context"] for m in all_metrics])),
            "num_targets_mean": float(np.mean([m["num_targets"] for m in all_metrics])),
            "batch_size": self.batch_size,
        }
    
    def save_results(self):
        """Save evaluation results to disk."""
        # Save main metrics as JSON
        metrics_path = self.save_dir / "evaluation_metrics.json"
        with open(metrics_path, "w") as f:
            json.dump(self.metrics, f, indent=2)
        print(f"Saved metrics to {metrics_path}")
        
        # Save summary report
        self._generate_report()
        
        # Save timing & performance data as CSV for easy analysis
        self._save_timing_csv()
        self._save_performance_csv()
    
    def _generate_report(self):
        """Generate human-readable evaluation report."""
        report_path = self.save_dir / "evaluation_report.txt"
        
        with open(report_path, "w") as f:
            f.write("=" * 80 + "\n")
            f.write("ACE MODEL EVALUATION REPORT\n")
            f.write("=" * 80 + "\n\n")
            
            # Model information
            f.write("MODEL INFORMATION\n")
            f.write("-" * 40 + "\n")
            f.write(f"Checkpoint: {self.metrics['model_stats']['checkpoint_path']}\n")
            f.write(f"Epoch: {self.metrics['model_stats']['epoch']}\n")
            f.write(f"Global Step: {self.metrics['model_stats']['global_step']}\n")
            f.write(f"Parameters: {self.metrics['model_stats']['num_parameters']:,}\n")
            f.write(f"Device: {self.metrics['model_stats']['device']}\n\n")
            
            # Evaluation settings
            f.write("EVALUATION SETTINGS\n")
            f.write("-" * 40 + "\n")
            f.write(f"Number of batches: {self.metrics['evaluation']['num_batches']}\n")
            f.write(f"Batch size: {self.metrics['evaluation']['batch_size']}\n")
            f.write(f"K (decode batch size): {self.metrics['evaluation']['K']}\n")
            f.write(f"Total samples: {self.metrics['evaluation']['total_samples']:,}\n\n")
            
            # Performance metrics
            f.write("PERFORMANCE METRICS\n")
            f.write("-" * 40 + "\n")
            f.write(f"Total evaluation time: {self.metrics['timing']['total_inference_time']:.2f}s\n")
            f.write(f"Overall throughput: {self.metrics['evaluation']['samples_per_second']:.1f} samples/s\n\n")
            
            # Timing statistics
            f.write("TIMING STATISTICS (seconds)\n")
            f.write("-" * 40 + "\n")
            timing_metrics = ["sequence_time", "sequence_ll_time", "per_sample_time"]
            headers = ["Metric", "Mean", "Std", "Min", "Max", "P50", "P90", "P99"]
            f.write(f"{headers[0]:<20} " + " ".join(f"{h:>10}" for h in headers[1:]) + "\n")
            
            for metric in timing_metrics:
                f.write(f"{metric:<20} ")
                for stat in ["mean", "std", "min", "max", "p50", "p90", "p99"]:
                    key = f"{metric}_{stat}"
                    value = self.metrics["evaluation"].get(key, 0)
                    f.write(f"{value:>10.6f} ")
                f.write("\n")
            
            # Throughput statistics
            f.write(f"\n{'throughput (samples/s)':<20} ")
            for stat in ["mean", "std", "min", "max", "p50", "p90", "p99"]:
                key = f"throughput_samples_per_sec_{stat}"
                value = self.metrics["evaluation"].get(key, 0)
                f.write(f"{value:>10.1f} ")
            f.write("\n")
            
            # Error metrics if available
            if "mse_mean" in self.metrics["evaluation"]:
                f.write("\nPREDICTION ERROR METRICS\n")
                f.write("-" * 40 + "\n")
                f.write(f"MSE: {self.metrics['evaluation']['mse_mean']:.6f} ± {self.metrics['evaluation']['mse_std']:.6f}\n")
                f.write(f"MAE: {self.metrics['evaluation']['mae_mean']:.6f} ± {self.metrics['evaluation']['mae_std']:.6f}\n")
                f.write(f"log_mean_likelihood: {self.metrics['evaluation']['log_mean_likelihood_mean']:.6f} ± {self.metrics['evaluation']['log_mean_likelihood_std']:.6f}\n")
                f.write(f"mean_log_likelihood: {self.metrics['evaluation']['mean_log_likelihood_mean']:.6f} ± {self.metrics['evaluation']['mean_log_likelihood_std']:.6f}\n")

            # Detailed profiling results
            if self.metrics["timing"]["prefill_times"]:
                f.write("\nDETAILED PROFILING (seconds)\n")
                f.write("-" * 40 + "\n")
                
                # Calculate averages
                prefill_avg = np.mean(self.metrics["timing"]["prefill_times"])
                decode_avg = np.mean(self.metrics["timing"]["transformer_decode_times"])
                embed_avg = np.mean(self.metrics["timing"]["embedding_times"])
                head_avg = np.mean(self.metrics["timing"]["head_times"])
                update_avg = np.mean(self.metrics["timing"]["context_update_times"])
                
                f.write(f"Context Prefill (self-attention): {prefill_avg:.6f}\n")
                f.write(f"Transformer Decode (with cache):  {decode_avg:.6f}\n")
                f.write(f"Embedding Computation:            {embed_avg:.6f}\n")
                f.write(f"Head Prediction:                  {head_avg:.6f}\n")
                f.write(f"Context Update:                   {update_avg:.6f}\n")
                
                f.write(f"\nPrefill vs Cache Decode Ratio: {prefill_avg/decode_avg:.1f}x\n")
                f.write(f"Self-attention is {prefill_avg/decode_avg:.1f}x slower than cache-based decode\n")
                
                # Component breakdown for one decode step
                total_decode_step = decode_avg + head_avg
                f.write(f"\nSingle Token Generation Breakdown:\n")
                f.write(f"  Transformer: {decode_avg/total_decode_step*100:.1f}%\n")
                f.write(f"  Head:        {head_avg/total_decode_step*100:.1f}%\n")
                
                # Full sequence cost analysis
                avg_targets = self.metrics["data_stats"].get("num_targets_mean", 50)
                f.write(f"\nFull Sequence Generation Cost (avg {avg_targets:.0f} targets):\n")
                f.write(f"  Prefill (1x):     {prefill_avg:.6f}s\n")
                f.write(f"  Decode ({avg_targets:.0f}x):    {decode_avg * avg_targets:.6f}s\n")
                f.write(f"  Total:            {prefill_avg + decode_avg * avg_targets:.6f}s\n")
                f.write(f"  Prefill %:        {prefill_avg/(prefill_avg + decode_avg * avg_targets)*100:.1f}%\n")
                f.write(f"  Decode %:         {decode_avg * avg_targets/(prefill_avg + decode_avg * avg_targets)*100:.1f}%\n")
            
            f.write("\n" + "=" * 80 + "\n")
        
        print(f"Saved evaluation report to {report_path}")
    
    def _save_timing_csv(self):
        """Save detailed timing data as CSV."""
        
        csv_path = self.save_dir / "timing_data.csv"
        
        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["batch_idx", "sequence_time", "sequence_ll_time", "per_sample_time"])

            for i in range(len(self.metrics["timing"]["sample_sequence_times"])):
                writer.writerow([
                    i,
                    self.metrics["timing"]["sample_sequence_times"][i],
                    self.metrics["timing"]["eval_sequence_ll_times"][i],
                    self.metrics["timing"]["per_sample_times"][i],
                ])
        
        print(f"Saved timing data to {csv_path}")

    def _save_performance_csv(self):
        """Save detailed performance data as CSV."""

        csv_path = self.save_dir / "performance_data.csv"

        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["batch_idx", "mse", "mae", "log_mean_likelihood", "mean_log_likelihood"])

            for i in range(len(self.metrics["performance"]["mse"])):
                writer.writerow([
                    i,
                    self.metrics["performance"]["mse"][i],
                    self.metrics["performance"]["mae"][i],
                    self.metrics["performance"]["log_mean_likelihood"][i],
                    self.metrics["performance"]["mean_log_likelihood"][i],
                ])

        print(f"Saved performance data to {csv_path}")


def main():
    """Command-line entry point: parse arguments, evaluate, report.

    Builds a ``ModelEvaluator`` from the parsed options, runs the full
    evaluation and prints where the results were written.
    """
    cli = argparse.ArgumentParser(description="Evaluate trained ACE model")
    cli.add_argument("checkpoint", type=str, help="Path to model checkpoint")

    # (flag, type, default, help) -- registered in this order.
    option_specs = (
        ("--K", int, 4, "Decoding batch size for sample_sequence"),
        ("--data-path", str, None, "Path to offline evaluation data"),
        ("--num-plot-functions", int, 5, "Max number of functions to be plotted"),
        ("--num-eval-functions", int, 1000, "Total number of functions for evaluation statistics"),
        ("--num-contexts", int, 128, "Number of context points per prediction"),
        ("--num-targets", int, 256, "Number of target points per prediction"),
        ("--repetition-per-function", int, 100, "Repetition of inference per function, each with an individual order of target points"),
        ("--device", str, "cuda", "Device to run on"),
        ("--save-dir", str, "./eval_results", "Directory to save results"),
        ("--compile", string2bool, True, "Use torch.compile on inference methods"),
    )
    for flag, caster, fallback, description in option_specs:
        cli.add_argument(flag, type=caster, default=fallback, help=description)

    opts = cli.parse_args()

    # Create the evaluator and run it
    runner = ModelEvaluator(
        checkpoint_path=opts.checkpoint,
        K=opts.K,
        data_path=opts.data_path,
        eval_functions=opts.num_eval_functions,
        num_predictions_per_target=opts.repetition_per_function,
        max_context_points=opts.num_contexts,
        max_target_points=opts.num_targets,
        device=opts.device,
        save_dir=opts.save_dir,
        compile=opts.compile,
    )
    runner.run_evaluation()

    print("\nEvaluation complete! Results saved to:", opts.save_dir)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
