"""Evaluation script for the baseline model on the Bavarian real-world dataset, with comprehensive metrics and timing."""

import argparse
import json
import time
from typing import Optional

import numpy as np
from tqdm import tqdm

from src.data.bav_real_data import BavTrueDataloader
from src.evaluate_baseline_model import BaselineModelEvaluator
from src.evaluate_model import string2bool
from src.utils import DataAttr


class NumpyArrayEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy arrays and scalars to native Python types.

    Use as ``json.dump(obj, f, cls=NumpyArrayEncoder)``. Objects that are not
    NumPy arrays or NumPy integer/float/bool scalars fall through to the base
    encoder, which raises ``TypeError``.
    """

    def default(self, obj):
        # NumPy arrays become (nested) Python lists.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # NumPy scalars become the equivalent Python scalar. np.bool_ is included
        # because, like np.int64/np.float32, it is not a subclass of a JSON-native
        # type and would otherwise make json.dump fail.
        if isinstance(obj, (np.integer, np.floating, np.bool_)):
            return obj.item()
        # Anything else: let the base class raise the TypeError.
        return super().default(obj)


class BavRealEvaluator(BaselineModelEvaluator):
    """Evaluator for the baseline model on the Bavarian real-world dataset.

    Loads the data with ``BavTrueDataloader``, re-splits every batch into a
    fixed number of context and target points, evaluates each batch with
    ``evaluate_batch`` and records timing/performance values in
    ``self.metrics`` before writing everything to ``save_dir``.
    """

    def __init__(
        self,
        checkpoint_path: str,
        data_path: Optional[str] = None,
        eval_functions: int = 1000,
        num_predictions_per_target: int = 100,
        max_context_points: int = 128,
        max_target_points: int = 256,
        device: str = "cuda",
        save_dir: str = "./eval_results",
        randomize_idx_file: Optional[str] = None,
        independent_sample: bool = False,
    ):
        """Create the evaluator.

        Args:
            checkpoint_path: Path to the model checkpoint.
            data_path: Path to the dataset directory, passed to the dataloader.
            eval_functions: Forwarded to ``BaselineModelEvaluator``.
            num_predictions_per_target: Number of inference repetitions per
                batch; also used as a multiplier in the sample count.
            max_context_points: Number of leading points used as context.
            max_target_points: Number of points after the context used as target.
            device: Device passed to the base evaluator and the dataloader.
            save_dir: Output directory for metrics and reports.
            randomize_idx_file: Optional permutation file for ``BavTrueDataloader``.
            independent_sample: Forwarded to ``BaselineModelEvaluator``.
        """
        # Set before super().__init__ because _build_dataloader reads it; the base
        # constructor presumably builds the dataloader (TODO confirm).
        self.randomize_idx_file = randomize_idx_file

        super().__init__(
            checkpoint_path=checkpoint_path,
            data_path=data_path,
            eval_functions=eval_functions,
            num_predictions_per_target=num_predictions_per_target,
            max_context_points=max_context_points,
            max_target_points=max_target_points,
            device=device,
            save_dir=save_dir,
            independent_sample=independent_sample,
        )

        # Minimum number of points each loaded function must provide.
        self.total_num_points_eval = self.max_context_points + self.max_target_points

    def _build_dataloader(self, data_path: str) -> BavTrueDataloader:
        """
        Load real data using the BavTrueDataloader.
        Note that this will load data with dummy empty point as context
        and put rest of the points as target.
        """
        return BavTrueDataloader(
            data_path, device=self.device, randomize_idx_file=self.randomize_idx_file
        )

    def _split_context_target(self, data_list):
        """Split each loaded batch into fixed-size context and target sets.

        The loader puts all real points into ``xt``/``yt``. The first
        ``max_context_points`` of them become the new context and the next
        ``max_target_points`` the new target; ``xb``/``yb`` are set to None.
        When ``max_context_points`` is 0, ``data_list`` is returned unchanged.

        Args:
            data_list: Batches exposing ``xt``/``yt`` tensors; indexed below as
                ``[batch, point, feature]``.

        Returns:
            A new list of ``DataAttr`` batches, or ``data_list`` itself.

        Raises:
            ValueError: If a batch has fewer points than context + target.
        """
        if self.max_context_points == 0:
            return data_list

        # The slice boundaries do not depend on the batch, so compute them once.
        c_end = self.max_context_points
        t_end = c_end + self.max_target_points
        context = slice(0, c_end)
        target = slice(c_end, t_end)

        new_list = []
        for batch in data_list:
            x, y = batch.xt, batch.yt  # [B, N, D]
            # Explicit raise instead of assert so the check survives `python -O`.
            if t_end > x.size(1):
                raise ValueError("Requested slices exceed available points")
            new_list.append(
                DataAttr(
                    xc=x[:, context, :],
                    yc=y[:, context, :],
                    xb=None,
                    yb=None,
                    xt=x[:, target, :],
                    yt=y[:, target, :],
                )
            )
        return new_list

    def run_evaluation(self):
        """Run evaluation on the Bavarian dataset and save the results.

        For every batch this calls ``evaluate_batch`` and appends its timing and
        performance values to ``self.metrics``. The last batch element of each
        of the first 10 batches is stored together with the predictions.
        Afterwards the batch metrics are aggregated, a summary is printed and
        ``save_results`` is called.

        Raises:
            ValueError: If no data was loaded or the first batch has fewer
                points than ``max_context_points + max_target_points``.
        """
        data_list = self.dataloader.load_data()

        # Explicit checks (not assert) so validation also runs under `python -O`.
        if not data_list:
            raise ValueError("No data loaded from the dataloader")
        available = data_list[0].xt.shape[1]
        if available < self.total_num_points_eval:
            raise ValueError(
                f"Not enough points in loaded data. Required: {self.total_num_points_eval}, Available: {available}"
            )

        data_list = self._split_context_target(data_list)
        all_metrics = []
        total_samples = 0

        # (self.metrics["timing"] key, evaluate_batch key) pairs copied per batch.
        timing_keys = (
            ("sample_sequence_times", "sequence_time"),
            ("eval_sequence_ll_times", "sequence_ll_time"),
            ("prepare_cache_times", "cache_prep_time"),
            ("per_sample_times", "per_sample_time"),
        )
        # Keys shared by evaluate_batch output and self.metrics["performance"].
        performance_keys = ("mae", "mse", "log_mean_likelihood", "mean_log_likelihood")

        # perf_counter is monotonic and high-resolution, unlike time.time.
        eval_start = time.perf_counter()

        for i, batch in tqdm(
            enumerate(data_list), total=len(data_list), desc="Evaluating"
        ):
            batch_metrics, yhat = self.evaluate_batch(
                batch, self.num_predictions_per_target
            )
            all_metrics.append(batch_metrics)

            for metrics_key, batch_key in timing_keys:
                self.metrics["timing"][metrics_key].append(batch_metrics[batch_key])
            for key in performance_keys:
                self.metrics["performance"][key].append(batch_metrics[key])

            total_samples += (
                batch_metrics["batch_size"]
                * batch_metrics["num_targets"]
                * self.num_predictions_per_target
            )

            # Save last batch element of the first 10 batches of predictions.
            if i < 10:
                self.metrics["predictions"].append(
                    {
                        "batch_idx": i,
                        "xc": batch.xc[-1].cpu().numpy().tolist(),
                        "yc": batch.yc[-1].cpu().numpy().tolist(),
                        "xt": batch.xt[-1].cpu().numpy().tolist(),
                        "yt": (
                            batch.yt[-1].cpu().numpy().tolist()
                            if batch.yt is not None
                            else None
                        ),
                        "predictions": yhat.cpu().numpy().tolist(),
                    }
                )

        # Final statistics.
        eval_time = time.perf_counter() - eval_start
        # Guard against a zero-length interval (e.g. an empty split result).
        throughput = total_samples / eval_time if eval_time > 0 else 0.0

        self.metrics["timing"]["total_inference_time"] = eval_time
        self.metrics["evaluation"]["batch_size_per_evaluation"] = 1
        self.metrics["evaluation"]["total_samples"] = total_samples
        self.metrics["evaluation"]["samples_per_second"] = throughput
        self.metrics["evaluation"][
            "num_predictions_per_target"
        ] = self.num_predictions_per_target

        # Aggregate batch metrics.
        self._aggregate_metrics(all_metrics)

        print(f"\nEvaluation complete in {eval_time:.2f}s")
        print(f"Total samples processed: {total_samples:,}")
        print(f"Average throughput: {throughput:.1f} samples/s")

        self.save_results()

    def save_results(self):
        """Save evaluation results to disk.

        Writes ``self.metrics`` as JSON to ``<save_dir>/evaluation_metrics.json``,
        using ``NumpyArrayEncoder`` for NumPy values. It then calls the inherited
        report and CSV helpers.
        """
        # NOTE(review): assumes self.save_dir is a pathlib.Path (set by the base class).
        metrics_path = self.save_dir / "evaluation_metrics.json"
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(self.metrics, f, indent=2, cls=NumpyArrayEncoder)
        print(f"Saved metrics to {metrics_path}")

        # Save summary report.
        self._generate_report()

        # Save timing & performance data as CSV for easy analysis.
        self._save_timing_csv()
        self._save_performance_csv()


def main():
    """Parse command-line arguments, build a ``BavRealEvaluator`` and run it."""
    parser = argparse.ArgumentParser(description="Evaluate trained baseline model")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="checkpoints/tnpamg_baseline/bav_rho43/best_model.pt",
        help="Path to model checkpoint",
    )
    parser.add_argument(
        "--data-path",
        type=str,
        default="data/bav_real",
        help="Path to offline evaluation data",
    )
    # NOTE(review): parsed but not used anywhere in this script; kept so existing
    # command lines that pass it keep working.
    parser.add_argument(
        "--num-plot-functions",
        type=int,
        default=5,
        help="Max number of functions to be plotted",
    )
    parser.add_argument(
        "--num-eval-functions",
        type=int,
        default=30,
        help="Total number of functions for evaluation statistics",
    )
    parser.add_argument(
        "--num-contexts",
        type=int,
        default=128,
        help="Number of context points per prediction",
    )
    parser.add_argument(
        "--num-targets",
        type=int,
        default=256,
        help="Number of target points per prediction",
    )
    parser.add_argument(
        "--repetition-per-function",
        type=int,
        default=100,
        help="Repetition of inference per function, each with an individual order of target points",
    )
    parser.add_argument("--device", type=str, default="cpu", help="Device to run on")
    parser.add_argument(
        "--save-dir",
        type=str,
        default="./eval_results/bav_prediction",
        help="Directory to save results",
    )
    parser.add_argument(
        # Hyphenated spelling matches the other flags; the underscore form is
        # kept as an alias for backward compatibility.
        "--rand-idx-file",
        "--rand_idx_file",
        dest="rand_idx_file",
        type=str,
        default="data/bav_real/data_perm/perm_0.json",
        help="Path to randomization index file for data loading",
    )
    parser.add_argument(
        "--independent-sample",
        type=string2bool,
        # nargs="?" + const=True: a bare `--independent-sample` enables it, while
        # an explicit value (parsed by string2bool) is still accepted.
        nargs="?",
        const=True,
        default=False,
        help="Samples targets independently if specified",
    )

    args = parser.parse_args()

    # Create evaluator.
    evaluator = BavRealEvaluator(
        checkpoint_path=args.checkpoint,
        data_path=args.data_path,
        eval_functions=args.num_eval_functions,
        num_predictions_per_target=args.repetition_per_function,
        max_context_points=args.num_contexts,
        max_target_points=args.num_targets,
        device=args.device,
        save_dir=args.save_dir,
        randomize_idx_file=args.rand_idx_file,
        independent_sample=args.independent_sample,
    )

    # Run evaluation (also saves the results).
    evaluator.run_evaluation()

    print("\nEvaluation complete! Results saved to:", args.save_dir)


if __name__ == "__main__":
    # Run the CLI entry point only when executed as a script, not on import.
    main()
