"""
Evaluation script for DTLZ benchmark: forward and inverse performance metrics.

This script evaluates trained Diag-CFM, CFM, and INN models on the DTLZ
benchmark, computing:
- Forward MSE: Error in predicting labels from designs
- Round-trip error: Error when generating designs and evaluating them with the
  ground-truth forward function
- Design diversity: Variety of generated designs for the same target label
- Diversity vs. epsilon: Design diversity as a function of a round-trip error
  threshold epsilon

The key advantage of DTLZ is that the forward function is analytical, enabling
precise round-trip error computation without surrogate model approximation.

Usage:
    python -m uq_diagcfm.evaluate_results_dtlz [--diversity] num_design_params

Pass ``--diversity`` to run only the diversity-vs-epsilon analysis.

Examples:
    python -m uq_diagcfm.evaluate_results_dtlz 12
    python -m uq_diagcfm.evaluate_results_dtlz --diversity 50
"""

import torch
import numpy as np
from pathlib import Path
from typing import Dict, List
from tqdm import tqdm

from uq_diagcfm.utils import get_device
from uq_diagcfm.ensembles import (
    load_dtlz_diag_cfm_ensemble,
    load_dtlz_cfm_ensemble,
    load_dtlz_inn_ensemble,
)
from uq_diagcfm.data_utils_dtlz import (
    DTLZ_DATASET_NAME,
    DTLZDataset,
    make_dtlz_surrogate,
)
from uq_diagcfm.evaluation_utils import (
    forward_pass,
    inverse_pass,
    forward_pass_inn,
    inverse_pass_inn,
    compute_mse_metrics,
    compute_design_diversity,
    compute_roundtrip_errors,
    aggregate_diversity_stats,
    compute_summary_statistics,
    save_results,
    print_summary,
    validate_ensemble_parameter_counts,
    compute_diversity_vs_epsilon,
    plot_diversity_vs_epsilon_combined,
    print_diversity_vs_epsilon_description,
)


def compute_forward_metrics(
    model,
    test_designs,
    test_labels,
    num_objectives,
    device,
    diag_cfm=True,
    is_inn=False,
):
    """
    Evaluate how well a model predicts labels from designs (forward direction).

    The model is put in eval mode and run without gradient tracking; its
    predictions are compared against ``test_labels`` via ``compute_mse_metrics``.

    Args:
        model: Trained model.
        test_designs: Test design parameters (N, P).
        test_labels: Ground-truth test labels (N, L).
        num_objectives: Number of objectives L (not used by the INN path).
        device: Computation device.
        diag_cfm: Whether this is a Diag-CFM model (ignored for INN).
        is_inn: Whether this is an INN model.

    Returns:
        Dictionary of forward metrics returned by ``compute_mse_metrics``.
    """
    model.eval()
    with torch.no_grad():
        predictions = (
            forward_pass_inn(model, test_designs, device)
            if is_inn
            else forward_pass(
                model, test_designs, num_objectives, device, diag_cfm=diag_cfm
            )
        )

    labels_np = test_labels.cpu().numpy()
    predictions_np = predictions.cpu().numpy()
    return compute_mse_metrics(labels_np, predictions_np)


def compute_inverse_metrics(
    model,
    test_labels,
    num_design_params,
    num_objectives,
    device,
    ground_truth_forward,
    num_samples=10,
    diag_cfm=True,
    is_inn=False,
    eval_batch_size=100,
):
    """
    Compute inverse performance metrics using the ground truth forward function.

    Target labels are processed in chunks of ``eval_batch_size``. For each
    chunk the model generates ``num_samples`` designs per target. Those designs
    are clamped to [0, 1] and scored with ``ground_truth_forward`` to obtain
    round-trip errors, and their spread is summarized with
    ``compute_design_diversity``.

    Args:
        model: Trained model.
        test_labels: Test labels (N, L).
        num_design_params: Design dimension P.
        num_objectives: Number of objectives L.
        device: Computation device.
        ground_truth_forward: Analytical forward function (x -> y).
        num_samples: Number of designs to generate per target.
        diag_cfm: Whether this is a Diag-CFM model (ignored for INN).
        is_inn: Whether this is an INN model.
        eval_batch_size: Number of target labels processed per chunk; limits
            memory use during generation.

    Returns:
        Dictionary with ``inverse_roundtrip_error_mean``,
        ``inverse_roundtrip_error_std`` and the entries produced by
        ``aggregate_diversity_stats``.

    Raises:
        ValueError: If ``test_labels`` is empty or ``eval_batch_size`` is not
            positive.
    """
    model.eval()
    num_targets = test_labels.shape[0]
    if num_targets == 0:
        raise ValueError("compute_inverse_metrics requires at least one test label")
    if eval_batch_size <= 0:
        raise ValueError(f"eval_batch_size must be positive, got {eval_batch_size}")

    def forward_fn(x):
        # Generated designs may leave the [0, 1] DTLZ domain; clamp them before
        # querying the analytical forward function. Defined once, outside the
        # batch loop, because it does not depend on the batch.
        return ground_truth_forward(torch.clamp(x, 0, 1))

    all_roundtrip_errors = []
    all_diversities_var = []

    with torch.no_grad():
        for start_idx in tqdm(
            range(0, num_targets, eval_batch_size), desc="Computing inverse metrics"
        ):
            # Slicing past the end is safe; the last chunk is simply shorter.
            batch_labels = test_labels[start_idx : start_idx + eval_batch_size]

            # Generate designs
            if is_inn:
                x_samples = inverse_pass_inn(
                    model,
                    batch_labels,
                    num_design_params,
                    device,
                    num_samples,
                )
            else:
                x_samples = inverse_pass(
                    model,
                    batch_labels,
                    num_design_params,
                    num_objectives,
                    device,
                    num_samples,
                    diag_cfm=diag_cfm,
                )

            # Round-trip errors against the analytical forward function
            all_roundtrip_errors.append(
                compute_roundtrip_errors(batch_labels, x_samples, forward_fn)
            )

            # Design diversity among the samples generated for each target
            diversity = compute_design_diversity(x_samples)
            all_diversities_var.append(diversity["diversity_var"])

    roundtrip_errors = np.concatenate(all_roundtrip_errors, axis=0)
    diversities_var = np.concatenate(all_diversities_var, axis=0)

    result = {
        "inverse_roundtrip_error_mean": float(np.mean(roundtrip_errors)),
        "inverse_roundtrip_error_std": float(np.std(roundtrip_errors)),
    }
    result.update(aggregate_diversity_stats(diversities_var))

    return result


def compute_ensemble_metrics(
    models, test_designs, num_objectives, device, diag_cfm=True, is_inn=False
):
    """
    Measure how much an ensemble's forward predictions disagree.

    Every model is put in eval mode and predicts labels for ``test_designs``
    without gradient tracking. The predictions are stacked along a new leading
    model axis, the variance is taken across that axis, and the result is
    averaged over all remaining elements.

    Args:
        models: List of trained models.
        test_designs: Test design parameters (N, P).
        num_objectives: Number of objectives L.
        device: Computation device.
        diag_cfm: Whether these are Diag-CFM models (ignored for INN).
        is_inn: Whether these are INN models.

    Returns:
        Dictionary with ``ensemble_variance`` (mean cross-model variance) and
        ``ensemble_std`` (its square root).

    Note:
        ``Tensor.var`` applies Bessel's correction by default, so an ensemble
        with a single model yields NaN.
    """

    def _predict(member):
        member.eval()
        with torch.no_grad():
            if is_inn:
                return forward_pass_inn(member, test_designs, device)
            return forward_pass(
                member, test_designs, num_objectives, device, diag_cfm=diag_cfm
            )

    # Leading axis indexes the ensemble member: (num_models, N, L)
    stacked = torch.stack([_predict(member) for member in models], dim=0)

    variance = stacked.var(dim=0).mean().item()
    return {
        "ensemble_variance": float(variance),
        "ensemble_std": float(np.sqrt(variance)),
    }


def evaluate_loaded_dtlz_ensemble(
    models: List,
    run_infos: List[Dict],
    checkpoint_names: List[str],
    num_design_params: int,
    num_objectives: int,
    device,
    ground_truth_forward,
    test_designs,
    test_labels,
    num_inverse_samples: int = 10,
) -> List[Dict]:
    """Evaluate every model of an already-loaded DTLZ ensemble.

    For each model, a header is printed and forward and inverse metrics are
    computed. The model is treated as an INN when ``run_info["model_type"]``
    is ``"INN"``, and as a Diag-CFM according to ``run_info["diag_cfm"]``
    (default ``True``). Each result merges, in order, the run info, the
    forward metrics and the inverse metrics, followed by the checkpoint name
    and the model's total parameter count.

    Args:
        models: List of loaded PyTorch models.
        run_infos: Run-info dictionaries, aligned with ``models``.
        checkpoint_names: Checkpoint directory names, aligned with ``models``.
        num_design_params: Design dimension P.
        num_objectives: Number of objectives L.
        device: Torch device.
        ground_truth_forward: Ground truth forward function.
        test_designs: Test design parameters.
        test_labels: Test labels.
        num_inverse_samples: Number of designs to generate per target.

    Returns:
        One result dictionary per evaluated model.
    """
    banner = "=" * 60
    total = len(models)
    results: List[Dict] = []

    for index, (model, run_info, ckpt_name) in enumerate(
        zip(models, run_infos, checkpoint_names), start=1
    ):
        is_inn = run_info.get("model_type", "") == "INN"
        is_diag_cfm = run_info.get("diag_cfm", True)

        header_lines = [f"\n{banner}", f"Evaluating model {index}/{total}"]
        if is_inn:
            header_lines += [
                "  model_type: INN",
                f"  num_blocks: {run_info.get('num_blocks', 'N/A')}",
                f"  hidden_dim: {run_info.get('hidden_dim', 'N/A')}",
            ]
        else:
            header_lines += [
                f"  diag_cfm: {is_diag_cfm}",
                f"  shuffle_seed: {run_info.get('shuffle_params_seed', None)}",
            ]
        header_lines += [f"  checkpoint: {ckpt_name}", banner]
        for line in header_lines:
            print(line)

        print("Computing forward metrics...")
        forward_metrics = compute_forward_metrics(
            model,
            test_designs,
            test_labels,
            num_objectives,
            device,
            diag_cfm=is_diag_cfm,
            is_inn=is_inn,
        )
        print(f"  Forward MSE: {forward_metrics['forward_mse']:.6f}")

        print("Computing inverse metrics...")
        inverse_metrics = compute_inverse_metrics(
            model,
            test_labels,
            num_design_params,
            num_objectives,
            device,
            ground_truth_forward,
            num_samples=num_inverse_samples,
            diag_cfm=is_diag_cfm,
            is_inn=is_inn,
        )
        roundtrip_mean = inverse_metrics["inverse_roundtrip_error_mean"]
        diversity_mean = inverse_metrics["inverse_design_diversity_var_mean"]
        print(f"  Round-trip error: {roundtrip_mean:.6f}")
        print(f"  Design diversity: {diversity_mean:.4f}")

        # Later sources win on key collisions: run info < forward < inverse.
        result = dict(run_info)
        result.update(forward_metrics)
        result.update(inverse_metrics)
        result["checkpoint_name"] = ckpt_name
        result["num_parameters"] = sum(p.numel() for p in model.parameters())
        results.append(result)

    return results


def evaluate_diversity_vs_epsilon(
    P: int,
    num_objectives: int = 3,
    function_name: str = "dtlz2",
    sampling_strategy: str = "stratified",
    num_samples: int = 10,
    num_epsilon_points: int = 50,
    num_test_samples: int = 1000,
    log10_epsilon_min: float = -5.0,
    log10_epsilon_max: float = 0.0,
):
    """
    Evaluate diversity vs epsilon for DTLZ models and generate a combined plot.

    For each model family (Diag-CFM, CFM, INN), the first model returned by the
    family's ensemble loader generates ``num_samples`` designs per test label.
    Each design is clamped to [0, 1], evaluated with the analytical DTLZ
    forward function, and normalized by the test dataset's label scale. The
    per-sample mean squared round-trip errors are passed, together with the
    designs, to ``compute_diversity_vs_epsilon``. That call runs over a
    log-spaced epsilon sweep.

    Families that fail to load, or return no models, are skipped with a message.
    If at least one family was evaluated, a combined plot is written to
    ``PAPER_FIGURES_DIR / f"diversity_vs_epsilon_dtlz_P{P}.pdf"``.

    Args:
        P: Number of design parameters.
        num_objectives: Number of objectives L.
        function_name: DTLZ function name.
        sampling_strategy: Sampling strategy for test dataset.
        num_samples: Number of design samples per target.
        num_epsilon_points: Number of epsilon values to evaluate.
        num_test_samples: Number of test samples to use.
        log10_epsilon_min: log10 of the smallest epsilon in the sweep.
        log10_epsilon_max: log10 of the largest epsilon in the sweep.
    """
    from uq_diagcfm.paths import PAPER_FIGURES_DIR, ensure_paper_dirs_exist

    device = get_device()
    print(f"Using device: {device}")

    # Create test dataset
    test_dataset = DTLZDataset(
        split="test",
        num_design_params=P,
        num_objectives=num_objectives,
        function_name=function_name,
        normalize_labels=True,
        sampling_strategy=sampling_strategy,
    )
    label_scale = test_dataset.label_scale

    test_labels = test_dataset.labels[:num_test_samples].to(device)
    print(f"Test samples: {test_labels.shape[0]}")

    # Ground truth forward function, normalized like the dataset labels.
    base_forward = make_dtlz_surrogate(function_name, num_objectives)

    def ground_truth_forward(x):
        return base_forward(torch.clamp(x, 0, 1)) / label_scale

    epsilon_values = np.logspace(
        log10_epsilon_min, log10_epsilon_max, num_epsilon_points
    )

    results = {}

    ensemble_loaders = [
        ("Diag-CFM", load_dtlz_diag_cfm_ensemble),
        ("CFM", load_dtlz_cfm_ensemble),
        ("INN", load_dtlz_inn_ensemble),
    ]

    for model_name, load_fn in ensemble_loaders:
        print(f"\n{'='*60}")
        print(f"Loading {model_name} model...")

        try:
            models, run_infos, _, _ = load_fn(P=P, device=device, verbose=True)
        except Exception as e:
            # Best effort: one family failing to load should not abort the rest.
            print(f"Failed to load {model_name}: {e}")
            continue

        if not models:
            print(f"No {model_name} models found matching criteria")
            continue

        # Use the loader's first model (presumably the best by train loss —
        # ordering is decided by the loader).
        model = models[0]
        run_info = run_infos[0]
        is_inn = run_info.get("model_type", "") == "INN"
        diag_cfm = run_info.get("diag_cfm", True)

        print(f"Evaluating {model_name}...")

        model.eval()
        with torch.no_grad():
            if is_inn:
                x_samples = inverse_pass_inn(model, test_labels, P, device, num_samples)
            else:
                x_samples = inverse_pass(
                    model,
                    test_labels,
                    P,
                    num_objectives,
                    device,
                    num_samples,
                    diag_cfm=diag_cfm,
                )

            # One vector of per-target MSEs per generated sample, stacked along
            # a new leading sample axis.
            roundtrip_errors = torch.stack(
                [
                    torch.mean(
                        (test_labels - ground_truth_forward(x_samples[i])) ** 2, dim=1
                    )
                    for i in range(num_samples)
                ],
                dim=0,
            )

        model_results = compute_diversity_vs_epsilon(
            x_samples, roundtrip_errors, epsilon_values
        )
        results[model_name] = model_results

        print(f"  Diversity at max eps: {model_results['diversity_var_mean'][-1]:.6f}")
        print(f"  Valid samples at max eps: {model_results['num_valid_mean'][-1]:.1f}")

    if not results:
        print("No models were successfully evaluated!")
        return

    ensure_paper_dirs_exist()
    output_path = PAPER_FIGURES_DIR / f"diversity_vs_epsilon_dtlz_P{P}.pdf"
    plot_diversity_vs_epsilon_combined(
        results,
        output_path,
        num_samples=num_samples,
        title=f"DTLZ (P={P}): Diversity vs Round-Trip Error Threshold",
    )

    print_diversity_vs_epsilon_description()


if __name__ == "__main__":
    import sys
    from uq_diagcfm.paths import RESULTS_DIR

    # Parse arguments.
    # Usage: python -m uq_diagcfm.evaluate_results_dtlz [--diversity] P
    cli_args = sys.argv[1:]
    diversity_only = "--diversity" in cli_args
    for arg in cli_args:
        # Warn instead of silently ignoring typos such as "--diversty", which
        # would otherwise trigger the (much longer) full evaluation.
        if arg.startswith("-") and arg != "--diversity":
            print(f"Warning: ignoring unknown option {arg!r}")
    positional_args = [arg for arg in cli_args if not arg.startswith("-")]

    P = None  # required
    if positional_args:
        try:
            P = int(positional_args[0])
        except ValueError:
            print(f"Invalid P: {positional_args[0]!r} is not an integer")

    if P is None:
        print("Usage: python -m uq_diagcfm.evaluate_results_dtlz [--diversity] P")
        print("  P: Number of design parameters (required)")
        sys.exit(1)

    if diversity_only:
        evaluate_diversity_vs_epsilon(P=P)
        sys.exit(0)

    print(f"Evaluating DTLZ with P={P}")

    device = get_device()
    print(f"Using device: {device}")

    num_objectives = 3
    function_name = "dtlz2"
    sampling_strategy = "stratified"
    num_test_samples = 1000
    num_inverse_samples = 10

    test_dataset = DTLZDataset(
        split="test",
        num_design_params=P,
        num_objectives=num_objectives,
        function_name=function_name,
        normalize_labels=True,
        sampling_strategy=sampling_strategy,
    )
    label_scale = test_dataset.label_scale

    test_designs = test_dataset.designs[:num_test_samples].to(device)
    test_labels = test_dataset.labels[:num_test_samples].to(device)
    print(f"Test samples: {test_designs.shape[0]}")

    # Ground truth forward function, normalized to match training labels.
    # (compute_inverse_metrics clamps generated designs before calling it.)
    base_forward = make_dtlz_surrogate(function_name, num_objectives)

    def ground_truth_forward(x):
        return base_forward(x) / label_scale

    all_results = []

    # Same loader table as evaluate_diversity_vs_epsilon. Looping (rather than
    # keeping one variable per family) also lets each family's models be freed
    # before the next family is loaded.
    ensemble_loaders = [
        ("Diag-CFM", load_dtlz_diag_cfm_ensemble),
        ("CFM", load_dtlz_cfm_ensemble),
        ("INN", load_dtlz_inn_ensemble),
    ]
    for model_name, load_fn in ensemble_loaders:
        print(f"\n--- Evaluating {model_name} models ---")
        models, run_infos, ckpt_names, criteria = load_fn(
            P=P, device=device, verbose=True
        )
        print(f"{model_name} criteria: {criteria}")
        if not models:
            print(f"No {model_name} models found")
            continue
        all_results.extend(
            evaluate_loaded_dtlz_ensemble(
                models,
                run_infos,
                ckpt_names,
                P,
                num_objectives,
                device,
                ground_truth_forward,
                test_designs,
                test_labels,
                num_inverse_samples=num_inverse_samples,
            )
        )

    if all_results:
        # validate_ensemble_parameter_counts reads "number_of_parameters";
        # mirror it from "num_parameters" where missing.
        for r in all_results:
            if "number_of_parameters" not in r and "num_parameters" in r:
                r["number_of_parameters"] = r["num_parameters"]
        validate_ensemble_parameter_counts(all_results)

        summary = compute_summary_statistics(all_results)
        print_summary(summary)

        output_dir = RESULTS_DIR
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / f"dtlz_P{P}_evaluation_results.json"
        save_results(all_results, summary, output_path)

        print("\n" + "=" * 70)
        print("DIVERSITY VS EPSILON ANALYSIS")
        print("=" * 70)
        # Pass the configuration explicitly so both analyses use the same setup.
        evaluate_diversity_vs_epsilon(
            P=P,
            num_objectives=num_objectives,
            function_name=function_name,
            sampling_strategy=sampling_strategy,
            num_samples=num_inverse_samples,
            num_test_samples=num_test_samples,
        )
    else:
        print("No results to save.")
