"""Evaluation script for ACE model with comprehensive metrics and timing."""

import argparse
import json
import csv
import time
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from src.data.utils import OfflineBatchLoader, SamplePermutationHelper
from tqdm import tqdm

#from src.data.gp_sampler import GPBatchLoader, GPSampler, generate_offline_batches
from src.models.ace import AmortizedConditioningEngine, InferenceEngine2
from src.models.modules import Embedder, MixtureGaussian, Transformer, MultiChannelMixtureGaussian
from src.utils import DataAttr
from src.evaluate_model import ModelEvaluator, string2bool
from src.data.bav_real_data import BavTrueDataloader


class BavRealEvaluator(ModelEvaluator):
    """Evaluator for ACE model on Bavarian real-world dataset.

    Extends ``ModelEvaluator`` by loading batches through
    ``BavTrueDataloader`` and re-splitting the loaded points into a fixed
    number of context and target points before evaluation.
    """

    def __init__(
        self,
        checkpoint_path: str,
        K: int = 4,
        data_path: Optional[str] = None,
        eval_functions: int = 1000,
        num_predictions_per_target: int = 100,
        max_context_points: int = 128,
        max_target_points: int = 256,
        device: str = "cuda",
        save_dir: str = "./eval_results",
        compile: bool = True,
        randomize_idx_file: Optional[str] = None,
    ):
        """Set up the evaluator.

        All arguments except ``randomize_idx_file`` are forwarded unchanged
        to ``ModelEvaluator.__init__``.

        Args:
            randomize_idx_file: Optional path handed to ``BavTrueDataloader``
                as its ``randomize_idx_file`` argument.
        """
        # Must be set before super().__init__() so that it already exists when
        # _build_dataloader runs.
        # NOTE(review): presumably the base constructor calls
        # _build_dataloader -- confirm against ModelEvaluator.
        self.randomize_idx_file = randomize_idx_file

        super().__init__(
            checkpoint_path=checkpoint_path,
            K=K,
            data_path=data_path,
            eval_functions=eval_functions,
            num_predictions_per_target=num_predictions_per_target,
            max_context_points=max_context_points,
            max_target_points=max_target_points,
            device=device,
            save_dir=save_dir,
            compile=compile,
        )

        # Minimum number of points every loaded function must provide.
        # NOTE(review): relies on the base class storing max_context_points
        # and max_target_points as attributes -- confirm.
        self.total_num_points_eval = self.max_context_points + self.max_target_points

    def _build_dataloader(self, data_path: str) -> BavTrueDataloader:
        """Create the loader for the real data.

        Note that the loader yields data with a dummy empty point as context
        and puts the rest of the points as target; the actual context/target
        split is done later in ``_split_context_target``.

        Args:
            data_path: Location of the data, passed straight to
                ``BavTrueDataloader``.

        Returns:
            A ``BavTrueDataloader`` bound to ``self.device`` and
            ``self.randomize_idx_file``.
        """
        return BavTrueDataloader(
            data_path,
            device=self.device,
            randomize_idx_file=self.randomize_idx_file,
        )

    def _split_context_target(self, data_list):
        """Split data into context and target sets.

        When ``max_context_points`` is 0 the batches are returned unchanged.
        Otherwise, for every batch the first ``max_context_points`` points
        (along dim 1 of ``xt``/``yt``) become the context and the following
        ``max_target_points`` points become the target; any remaining points
        are dropped and ``xb``/``yb`` are set to ``None``.

        Args:
            data_list: Iterable of batches exposing ``xt`` and ``yt``.

        Returns:
            ``data_list`` itself, or a new list of ``DataAttr`` batches.

        Raises:
            ValueError: If a batch has fewer points than
                ``max_context_points + max_target_points``.
        """
        if self.max_context_points == 0:
            return data_list

        # Slice boundaries do not depend on the batch, so compute them once.
        c_end = self.max_context_points
        t_end = c_end + self.max_target_points
        context = slice(0, c_end)
        target = slice(c_end, t_end)

        new_list = []
        for batch in data_list:
            x, y = batch.xt, batch.yt  # [B, N, D]

            # Explicit check instead of `assert`: slicing past the end would
            # silently yield fewer points if assertions were disabled (-O).
            if t_end > x.size(1):
                raise ValueError(
                    "Requested slices exceed available points "
                    f"(required: {t_end}, available: {x.size(1)})"
                )

            new_list.append(
                DataAttr(
                    xc=x[:, context, :],
                    yc=y[:, context, :],
                    xb=None,
                    yb=None,
                    xt=x[:, target, :],
                    yt=y[:, target, :],
                )
            )

        return new_list

    def run_evaluation(self):
        """Run evaluation on the Bavarian dataset.

        Loads all batches, splits them into context/target, evaluates every
        batch with ``evaluate_batch``, accumulates timing and performance
        metrics in ``self.metrics``, aggregates them and saves the results.

        Raises:
            ValueError: If no batches were loaded, or the first batch has
                fewer than ``total_num_points_eval`` points.
        """
        data_list = self.dataloader.load_data()

        if len(data_list) == 0:
            raise ValueError("No data batches were loaded; nothing to evaluate.")

        available_points = data_list[0].xt.shape[1]
        if available_points < self.total_num_points_eval:
            raise ValueError(
                f"Not enough points in loaded data. Required: {self.total_num_points_eval}, "
                f"Available: {available_points}"
            )

        data_list = self._split_context_target(data_list)
        all_metrics = []
        total_samples = 0
        total_functions = 0
        max_saved_batches = 10  # Only the first few batches keep raw predictions
        timing = self.metrics["timing"]
        performance = self.metrics["performance"]

        # perf_counter is monotonic, so the elapsed time cannot go negative
        # if the wall clock is adjusted during a long evaluation.
        eval_start = time.perf_counter()

        for i, batch in tqdm(enumerate(data_list), total=len(data_list), desc="Evaluating"):
            metrics, yhat = self.evaluate_batch(batch, self.num_predictions_per_target)
            all_metrics.append(metrics)

            # Update timing metrics
            timing["sample_sequence_times"].append(metrics["sequence_time"])
            timing["eval_sequence_ll_times"].append(metrics["sequence_ll_time"])
            timing["per_sample_times"].append(metrics["per_sample_time"])

            # Update performance metrics
            performance["mae"].append(metrics["mae"])
            performance["mse"].append(metrics["mse"])
            performance["log_mean_likelihood"].append(metrics["log_mean_likelihood"])
            performance["mean_log_likelihood"].append(metrics["mean_log_likelihood"])

            total_functions += metrics["batch_size"]
            total_samples += (
                metrics["batch_size"] * metrics["num_targets"] * self.num_predictions_per_target
            )

            # Optionally save some predictions: the last batch element of the
            # first few batches.
            # NOTE(review): unlike xc/yc/xt/yt, `yhat` is stored in full
            # (no [-1] index) -- confirm this is intended.
            if i < max_saved_batches:
                self.metrics["predictions"].append({
                    "batch_idx": i,
                    "xc": batch.xc[-1].cpu().numpy().tolist(),
                    "yc": batch.yc[-1].cpu().numpy().tolist(),
                    "xt": batch.xt[-1].cpu().numpy().tolist(),
                    "yt": batch.yt[-1].cpu().numpy().tolist() if batch.yt is not None else None,
                    "predictions": yhat.cpu().numpy().tolist(),
                })

        # Calculate final statistics
        eval_time = time.perf_counter() - eval_start
        # Guard against a zero elapsed time so the summary never crashes.
        samples_per_second = total_samples / eval_time if eval_time > 0 else 0.0

        timing["total_inference_time"] = eval_time
        self.metrics["evaluation"]["total_samples"] = total_samples
        self.metrics["evaluation"]["samples_per_second"] = samples_per_second

        # Aggregate batch metrics
        self._aggregate_metrics(all_metrics)

        print(f"\nEvaluation complete in {eval_time:.2f}s", flush=True)
        print(f"Total samples processed: {total_samples:,}", flush=True)
        print(f"Average throughput: {samples_per_second:.1f} samples/s", flush=True)

        # Save results
        self.save_results()



def main(argv: Optional[List[str]] = None):
    """Main evaluation script.

    Parses the command line, builds a ``BavRealEvaluator`` from the parsed
    options and runs the evaluation.

    Args:
        argv: Argument list to parse instead of the process command line.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Evaluate trained ACE model")
    parser.add_argument("checkpoint", type=str, help="Path to model checkpoint")
    parser.add_argument("--K", type=int, default=4, help="Decoding batch size for sample_sequence")
    parser.add_argument("--data-path", type=str, default="data/bav_real", help="Path to offline evaluation data")
    # Parsed for command-line compatibility; this function does not read it.
    parser.add_argument("--num-plot-functions", type=int, default=5, help="Max number of functions to be plotted")
    parser.add_argument("--num-eval-functions", type=int, default=30, help="Total number of functions for evaluation statistics")
    parser.add_argument("--num-contexts", type=int, default=128, help="Number of context points per prediction")
    parser.add_argument("--num-targets", type=int, default=256, help="Number of target points per prediction")
    parser.add_argument("--repetition-per-function", type=int, default=100, help="Repetition of inference per function, each with an individual order of target points")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run on")
    parser.add_argument("--save-dir", type=str, default="./eval_results", help="Directory to save results")
    parser.add_argument("--compile", type=string2bool, default=True, help="Use torch.compile on inference methods")
    # The hyphenated spelling matches every other flag; the original
    # underscore spelling stays accepted so existing commands keep working.
    parser.add_argument(
        "--rand-idx-file",
        "--rand_idx_file",
        dest="rand_idx_file",
        type=str,
        default=None,
        help="Path to randomization index file for data loading",
    )

    args = parser.parse_args(argv)

    # Create evaluator
    evaluator = BavRealEvaluator(
        checkpoint_path=args.checkpoint,
        K=args.K,
        data_path=args.data_path,
        eval_functions=args.num_eval_functions,
        num_predictions_per_target=args.repetition_per_function,
        max_context_points=args.num_contexts,
        max_target_points=args.num_targets,
        device=args.device,
        save_dir=args.save_dir,
        compile=args.compile,
        randomize_idx_file=args.rand_idx_file,
    )

    # Run evaluation
    evaluator.run_evaluation()

    print("\nEvaluation complete! Results saved to:", args.save_dir)


# Run the evaluation CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
