"""
Evaluation script for computing forward and inverse performance metrics for Unifoil dataset.

Unifoil requires conditioning on physical parameters (angle of attack, Mach number),
so we use forward_pass and inverse_pass with the conditioning parameter.

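Metrics computed:
- Forward MSE: how well the model predicts labels from designs.
- Round-trip error: y -> x_gen -> y_pred (using the surrogate), measured as the
  mean squared error between y and y_pred, averaged over the generated samples.
- Design diversity: variance across generated samples.
- Diversity vs. epsilon: design diversity as a function of the round-trip
  error threshold epsilon (plotted for each model family).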
"""

import torch
import numpy as np
from pathlib import Path
from typing import Dict, List

from uq_diagcfm.utils import get_device
from uq_diagcfm.ensembles import (
    load_unifoil_diag_cfm_ensemble,
    load_unifoil_cfm_ensemble,
    load_unifoil_inn_ensemble,
)
from uq_diagcfm.data_utils_unifoil import (
    UnifoilDataset,
    UNIFOIL_DATASET_NAME,
    LEN_DESIGN_PARAMETERS,
    LEN_PHYSICAL_PARAMS,
    LEN_PHYSICAL_PERFORMANCE,
    make_unifoil_surrogate,
)
from uq_diagcfm.evaluation_utils import (
    forward_pass,
    inverse_pass,
    forward_pass_conditional_inn,
    inverse_pass_conditional_inn,
    compute_mse_metrics,
    compute_design_diversity,
    compute_summary_statistics,
    save_results,
    print_summary,
    validate_ensemble_parameter_counts,
    compute_diversity_vs_epsilon,
    plot_diversity_vs_epsilon_combined,
    print_diversity_vs_epsilon_description,
)

# Short aliases for the dimension arguments passed to forward_pass/inverse_pass.
# P: design-vector length (noted as 14); L: performance-label length (noted as 3).
P, L = LEN_DESIGN_PARAMETERS, LEN_PHYSICAL_PERFORMANCE


def compute_forward_metrics(model, dataloader, diag_cfm, device, is_inn=False):
    """Evaluate the forward (design -> label) direction over a dataloader.

    Each batch is moved to ``device`` and pushed through the model, which is
    conditioned on the batch's physical parameters. Targets and predictions
    are gathered on the CPU and handed to ``compute_mse_metrics``.

    Args:
        model: Flow matching model or conditional INN.
        dataloader: Iterable of ``(x, physical_params, y)`` batches.
        diag_cfm: Whether the flow matching model is a Diagonal CFM
            (not used on the INN path).
        device: Torch device used for inference.
        is_inn: ``True`` to route through ``forward_pass_conditional_inn``.

    Returns:
        The metrics dictionary produced by ``compute_mse_metrics`` (e.g.
        ``forward_mse``).
    """
    targets = []
    predictions = []

    model.eval()
    with torch.no_grad():
        for designs, phys, labels in dataloader:
            designs = designs.to(device)
            phys = phys.to(device)
            labels = labels.to(device)

            if not is_inn:
                preds = forward_pass(
                    model, designs, L, device, diag_cfm=diag_cfm, conditioning=phys
                )
            else:
                preds = forward_pass_conditional_inn(model, designs, phys, device)

            targets.append(labels.cpu())
            predictions.append(preds.cpu())

    y_true = torch.cat(targets, dim=0).numpy()
    y_pred = torch.cat(predictions, dim=0).numpy()
    return compute_mse_metrics(y_true, y_pred)


def compute_roundtrip_metrics(
    model, dataloader, diag_cfm, device, surrogate_model, num_samples=10, is_inn=False
):
    """Measure how well generated designs reproduce their target labels.

    For every target ``y`` the model generates ``num_samples`` candidate
    designs (inverse direction). Each candidate is processed in three steps:

    1. Clamp it to ``[-10, 10]``.
    2. Concatenate it with the batch's physical parameters.
    3. Score it with ``surrogate_model``.

    The row-wise mean squared error between ``y`` and the surrogate prediction
    is averaged over the samples, giving one error value per target.

    Args:
        model: Flow matching model or conditional INN.
        dataloader: Iterable of ``(x, physical_params, y)`` batches; ``x`` is
            not used here.
        diag_cfm: Whether the flow matching model is a Diagonal CFM
            (not used on the INN path).
        device: Torch device used for inference.
        surrogate_model: Callable mapping
            ``cat([design, physical_params], dim=1)`` to predicted labels.
        num_samples: Number of designs generated per target.
        is_inn: ``True`` to route through ``inverse_pass_conditional_inn``.

    Returns:
        Dict with ``roundtrip_error_mean``, ``roundtrip_error_std``,
        ``roundtrip_error_median`` and ``roundtrip_error_90pct`` (90th
        percentile), computed over all targets.
    """

    def per_target_error(x_gen, phys, target):
        # Clamp to [-10, 10] (treated as a "reasonable range" for designs)
        # before querying the surrogate.
        bounded = torch.clamp(x_gen, -10, 10)
        predicted = surrogate_model(torch.cat([bounded, phys], dim=1))
        return torch.mean((target - predicted) ** 2, dim=1)

    per_batch = []

    model.eval()
    with torch.no_grad():
        for _, phys, target in dataloader:
            phys = phys.to(device)
            target = target.to(device)

            if not is_inn:
                x_samples = inverse_pass(
                    model,
                    target,
                    P,
                    L,
                    device,
                    num_samples=num_samples,
                    diag_cfm=diag_cfm,
                    conditioning=phys,
                )
            else:
                x_samples = inverse_pass_conditional_inn(
                    model, target, phys, P, device, num_samples=num_samples
                )

            sample_errors = [
                per_target_error(x_samples[k], phys, target)
                for k in range(num_samples)
            ]
            # Mean over the sample axis -> one error per target in the batch.
            per_batch.append(torch.stack(sample_errors, dim=0).mean(dim=0).cpu())

    errors = torch.cat(per_batch, dim=0).numpy()
    return {
        "roundtrip_error_mean": float(np.mean(errors)),
        "roundtrip_error_std": float(np.std(errors)),
        "roundtrip_error_median": float(np.median(errors)),
        "roundtrip_error_90pct": float(np.percentile(errors, 90)),
    }


def compute_diversity_metrics(
    model, dataloader, diag_cfm, device, num_samples=10, is_inn=False
):
    """Summarize how varied the generated designs are for each target.

    For every target ``y`` the model generates ``num_samples`` designs.
    ``compute_design_diversity`` turns each batch of samples into per-target
    variance values (its ``"diversity_var"`` entry). These values are
    concatenated over all batches and summarized by mean and standard
    deviation.

    Args:
        model: Flow matching model or conditional INN.
        dataloader: Iterable of ``(x, physical_params, y)`` batches; ``x`` is
            not used here.
        diag_cfm: Whether the flow matching model is a Diagonal CFM
            (not used on the INN path).
        device: Torch device used for inference.
        num_samples: Number of designs generated per target.
        is_inn: ``True`` to route through ``inverse_pass_conditional_inn``.

    Returns:
        Dict with ``design_diversity_mean_var`` and
        ``design_diversity_std_var``.
    """
    variances = []

    model.eval()
    with torch.no_grad():
        for _, phys, labels in dataloader:
            phys = phys.to(device)
            labels = labels.to(device)

            if not is_inn:
                x_samples = inverse_pass(
                    model,
                    labels,
                    P,
                    L,
                    device,
                    num_samples=num_samples,
                    diag_cfm=diag_cfm,
                    conditioning=phys,
                )
            else:
                x_samples = inverse_pass_conditional_inn(
                    model, labels, phys, P, device, num_samples=num_samples
                )

            variances.append(compute_design_diversity(x_samples)["diversity_var"])

    stacked = np.concatenate(variances, axis=0)
    return {
        "design_diversity_mean_var": float(np.mean(stacked)),
        "design_diversity_std_var": float(np.std(stacked)),
    }


def evaluate_single_model(
    model, run_info, device, surrogate_model=None, num_inverse_samples=10
):
    """Run the full metric suite for one model on the Unifoil validation split.

    A fresh validation DataLoader (batch size 256, no shuffling) is built from
    ``UnifoilDataset(split="val")``. The model type is read from ``run_info``:
    ``model_type == "INN"`` selects the INN code path, and ``diag_cfm`` is
    forwarded to the CFM helpers (default ``False``).

    Args:
        model: Loaded PyTorch model.
        run_info: Run metadata dictionary.
        device: Torch device.
        surrogate_model: Surrogate used for round-trip error; when ``None`` the
            round-trip metrics are skipped.
        num_inverse_samples: Number of samples per target for the inverse and
            diversity evaluations.

    Returns:
        Dict merging the forward, round-trip (if computed) and diversity
        metrics, in that order; later entries win on key collisions.
    """
    is_inn = run_info.get("model_type") == "INN"
    diag_cfm = run_info.get("diag_cfm", False)

    val_loader = torch.utils.data.DataLoader(
        UnifoilDataset(split="val"), batch_size=256, shuffle=False
    )

    results = {}

    print("  Computing forward metrics...")
    results.update(
        compute_forward_metrics(model, val_loader, diag_cfm, device, is_inn=is_inn)
    )

    if surrogate_model is not None:
        print("  Computing round-trip metrics (using surrogate)...")
        results.update(
            compute_roundtrip_metrics(
                model,
                val_loader,
                diag_cfm,
                device,
                surrogate_model,
                num_samples=num_inverse_samples,
                is_inn=is_inn,
            )
        )

    print("  Computing diversity metrics...")
    results.update(
        compute_diversity_metrics(
            model,
            val_loader,
            diag_cfm,
            device,
            num_samples=num_inverse_samples,
            is_inn=is_inn,
        )
    )

    return results


def evaluate_loaded_ensemble(
    models: List[torch.nn.Module],
    run_infos: List[Dict],
    checkpoint_names: List[str],
    device: torch.device,
    surrogate_model=None,
    num_inverse_samples: int = 10,
) -> List[Dict]:
    """Evaluate every member of an already-loaded ensemble.

    Models, run infos and checkpoint names are walked in lockstep (stopping
    at the shortest list). Each model is evaluated with
    ``evaluate_single_model``. Its result dict is then extended with:

    - ``run_info``: a reduced copy of the run metadata. INN runs keep
      ``model_type``, ``epochs``, ``num_blocks``, ``hidden_dim`` and
      ``subnet_depth``. CFM runs keep ``diag_cfm``, ``epochs``,
      ``model_hidden_dimension`` and ``model_depth``.
    - ``checkpoint_name``: the checkpoint directory name.
    - ``num_parameters``: the total element count of ``model.parameters()``.

    Args:
        models: List of loaded PyTorch models.
        run_infos: List of run info dictionaries.
        checkpoint_names: List of checkpoint directory names.
        device: Torch device.
        surrogate_model: Surrogate model for round-trip error computation.
        num_inverse_samples: Number of samples for inverse evaluation.

    Returns:
        List of result dictionaries, one per evaluated model.
    """
    total = len(models)
    collected: List[Dict] = []

    members = zip(models, run_infos, checkpoint_names)
    for idx, (model, info, ckpt) in enumerate(members, start=1):
        inn = info.get("model_type") == "INN"
        label = (
            f"INN_nb{info.get('num_blocks')}_ep{info['epochs']}"
            if inn
            else f"diag{info['diag_cfm']}_ep{info['epochs']}"
        )
        print(f"\nEvaluating model {idx}/{total}: {label}")

        metrics = evaluate_single_model(
            model,
            info,
            device,
            surrogate_model=surrogate_model,
            num_inverse_samples=num_inverse_samples,
        )

        if inn:
            reduced_info = {
                "model_type": "INN",
                "epochs": info["epochs"],
                "num_blocks": info.get("num_blocks"),
                "hidden_dim": info.get("hidden_dim"),
                "subnet_depth": info.get("subnet_depth"),
            }
        else:
            reduced_info = {
                "diag_cfm": info["diag_cfm"],
                "epochs": info["epochs"],
                "model_hidden_dimension": info.get("model_hidden_dimension"),
                "model_depth": info.get("model_depth"),
            }

        metrics["run_info"] = reduced_info
        metrics["checkpoint_name"] = ckpt
        metrics["num_parameters"] = sum(p.numel() for p in model.parameters())
        collected.append(metrics)

    return collected


def compute_unifoil_summary(results: List[Dict]) -> Dict:
    """Aggregate per-model metrics into mean/std/count per model family.

    Each result is assigned to a family from its ``run_info`` entry:

    - ``"INN"`` if ``model_type == "INN"``;
    - ``"Diag-CFM"`` otherwise, if ``diag_cfm`` is truthy;
    - ``"CFM"`` in all remaining cases.

    Only ``forward_mse``, ``roundtrip_error_mean`` and
    ``design_diversity_mean_var`` are aggregated, and only when present in a
    result.

    Args:
        results: List of result dictionaries (as built by
            ``evaluate_loaded_ensemble``).

    Returns:
        ``{family: {metric: {"mean", "std", "n"}}}``. Families appear in the
        order Diag-CFM, CFM, INN, and families without any tracked metric are
        omitted. ``std`` is the population standard deviation
        (``np.std`` default).
    """
    tracked = ("forward_mse", "roundtrip_error_mean", "design_diversity_mean_var")
    # Insertion order here fixes the family order of the returned summary.
    groups: Dict[str, Dict[str, list]] = {"Diag-CFM": {}, "CFM": {}, "INN": {}}

    for entry in results:
        info = entry.get("run_info", {})
        if info.get("model_type") == "INN":
            family = "INN"
        elif info.get("diag_cfm", False):
            family = "Diag-CFM"
        else:
            family = "CFM"

        bucket = groups[family]
        for metric in tracked:
            if metric in entry:
                bucket.setdefault(metric, []).append(entry[metric])

    return {
        family: {
            metric: {
                "mean": float(np.mean(values)),
                "std": float(np.std(values)),
                "n": len(values),
            }
            for metric, values in bucket.items()
        }
        for family, bucket in groups.items()
        if bucket
    }


def evaluate_diversity_vs_epsilon(
    num_samples: int = 10,
    num_epsilon_points: int = 50,
    num_test_samples: int = 5000,
):
    """
    Evaluate diversity vs epsilon for Unifoil models and generate combined plot.

    For each model family (Diag-CFM, CFM, INN), the first model returned by
    its ensemble loader is used:

    1. It generates ``num_samples`` designs per validation target.
    2. Each design is scored with the surrogate as a round-trip MSE.
    3. ``compute_diversity_vs_epsilon`` turns the samples and errors into
       curves over a log-spaced epsilon grid in ``[1e-5, 1]``.

    Families that fail to load, or that have no models, are skipped. The
    curves of all remaining families are plotted to
    ``PAPER_FIGURES_DIR / "diversity_vs_epsilon_unifoil.pdf"``.

    Returns ``None`` early, without plotting, if the surrogate cannot be found
    or no family could be evaluated.

    Args:
        num_samples: Number of design samples per target.
        num_epsilon_points: Number of epsilon values to evaluate.
        num_test_samples: Maximum number of validation targets to use. When
            the split is larger, a fixed-seed random subset of this size is
            drawn (default 5000).
    """
    import random

    from tqdm import tqdm
    from uq_diagcfm.paths import PAPER_FIGURES_DIR, ensure_paper_dirs_exist

    device = get_device()
    print(f"Using device: {device}")

    # Load validation dataset (limited to num_test_samples for speed)
    full_val_dataset = UnifoilDataset(split="val")
    if len(full_val_dataset) > num_test_samples:
        # A private RNG keeps the subset reproducible without reseeding the
        # process-global `random` module. Random(42) draws exactly what
        # random.seed(42) followed by random.sample(...) would.
        rng = random.Random(42)
        indices = rng.sample(range(len(full_val_dataset)), num_test_samples)
        val_dataset = torch.utils.data.Subset(full_val_dataset, indices)
        subset_note = f"random subset of {len(full_val_dataset)} total"
    else:
        val_dataset = full_val_dataset
        subset_note = "full validation set"
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=256, shuffle=False)
    print(f"Validation samples: {len(val_dataset)} ({subset_note})")

    # Load surrogate model for round-trip error computation
    try:
        surrogate_model = make_unifoil_surrogate()
        print("Surrogate model loaded successfully.")
    except FileNotFoundError:
        print("Warning: Surrogate model not found. Cannot run diversity analysis.")
        return

    # Define epsilon values to sweep (log scale)
    epsilon_values = np.logspace(-5, 0, num_epsilon_points)

    results = {}

    # Ensemble loading functions for each model type
    ensemble_loaders = [
        ("Diag-CFM", load_unifoil_diag_cfm_ensemble),
        ("CFM", load_unifoil_cfm_ensemble),
        ("INN", load_unifoil_inn_ensemble),
    ]

    for model_name, load_fn in ensemble_loaders:
        print(f"\n{'='*60}")
        print(f"Loading {model_name} model...")

        # Best-effort: a family that fails to load is reported and skipped so
        # the remaining families can still be plotted.
        try:
            models, run_infos, _, _ = load_fn(device=device, verbose=True)
        except Exception as e:
            print(f"Failed to load {model_name}: {e}")
            continue

        if not models:
            print(f"No {model_name} models found matching criteria")
            continue

        # Use only the first ensemble member (presumably the best by train
        # loss -- ordering is decided by the loader; confirm there).
        model = models[0]
        run_info = run_infos[0]
        is_inn = run_info.get("model_type") == "INN"
        diag_cfm = run_info.get("diag_cfm", False)

        print(f"Evaluating {model_name}...")

        all_x_samples = []
        all_errors = []

        model.eval()
        with torch.no_grad():
            for _, physical_params, y in tqdm(val_loader, desc="Processing batches"):
                physical_params = physical_params.to(device)
                y = y.to(device)

                # Generate designs with conditioning
                if is_inn:
                    x_samples = inverse_pass_conditional_inn(
                        model, y, physical_params, P, device, num_samples=num_samples
                    )
                else:
                    x_samples = inverse_pass(
                        model,
                        y,
                        P,
                        L,
                        device,
                        num_samples=num_samples,
                        diag_cfm=diag_cfm,
                        conditioning=physical_params,
                    )

                # Round-trip error of every sample via the surrogate; stacked
                # along dim 0 so rows index samples and columns index targets.
                roundtrip_errors = []
                for i in range(num_samples):
                    x_gen = torch.clamp(x_samples[i], -10, 10)
                    surrogate_input = torch.cat([x_gen, physical_params], dim=1)
                    y_pred = surrogate_model(surrogate_input)
                    roundtrip_errors.append(torch.mean((y - y_pred) ** 2, dim=1))
                roundtrip_errors = torch.stack(roundtrip_errors, dim=0)

                all_x_samples.append(x_samples)
                all_errors.append(roundtrip_errors)

        # Concatenate all batches along the target axis (dim 1); dim 0 stays
        # the sample index for both tensors.
        x_samples = torch.cat(all_x_samples, dim=1)
        roundtrip_errors = torch.cat(all_errors, dim=1)

        # Compute diversity for each epsilon
        model_results = compute_diversity_vs_epsilon(
            x_samples, roundtrip_errors, epsilon_values
        )
        results[model_name] = model_results

        print(f"  Diversity at max eps: {model_results['diversity_var_mean'][-1]:.6f}")
        print(f"  Valid samples at max eps: {model_results['num_valid_mean'][-1]:.1f}")

    if not results:
        print("No models were successfully evaluated!")
        return

    # Create combined plot
    ensure_paper_dirs_exist()
    output_path = PAPER_FIGURES_DIR / "diversity_vs_epsilon_unifoil.pdf"
    plot_diversity_vs_epsilon_combined(
        results,
        output_path,
        num_samples=num_samples,
        title="Unifoil: Diversity vs Round-Trip Error Threshold",
    )

    # Print description
    print_diversity_vs_epsilon_description()


if __name__ == "__main__":
    import sys

    # Parse command line arguments.
    # Usage: python -m uq_diagcfm.evaluate_results_unifoil [--all|--cfm|--inn|--diversity]
    #   (no model flag)  evaluate Diag-CFM models only
    #   --all            evaluate Diag-CFM, CFM and INN models
    #   --cfm            evaluate vanilla CFM models only
    #   --inn            evaluate INN models only
    #   --diversity      run only the diversity-vs-epsilon analysis and exit
    # If several model flags are given, the last one wins.
    model_filter = None  # one of None, "all", "cfm", "inn" (see table above)
    diversity_only = False

    for arg in sys.argv[1:]:
        if arg == "--all":
            model_filter = "all"
        elif arg == "--cfm":
            model_filter = "cfm"
        elif arg == "--inn":
            model_filter = "inn"
        elif arg == "--diversity":
            diversity_only = True
        else:
            # Surface typos (e.g. "--cfms") instead of silently falling back
            # to the Diag-CFM-only default.
            print(f"Warning: ignoring unrecognized argument: {arg}")

    if diversity_only:
        evaluate_diversity_vs_epsilon()
        sys.exit(0)

    print("=" * 70)
    print("UNIFOIL FORWARD/INVERSE EVALUATION")
    print("=" * 70)
    print("Note: Round-trip error computed using surrogate model")

    from uq_diagcfm.paths import RESULTS_DIR

    device = get_device()
    print(f"Using device: {device}")

    # Load surrogate model for round-trip error computation
    surrogate_model = None
    try:
        print("Loading surrogate model for round-trip error computation...")
        surrogate_model = make_unifoil_surrogate()
        print("  Surrogate model loaded successfully.")
    except FileNotFoundError:
        print("  Warning: Surrogate model not found. Skipping round-trip error.")

    all_results = []

    # One entry per model family: (display name, ensemble loader, selected?).
    # The display name is used verbatim in the progress messages below.
    ensemble_specs = [
        ("Diag-CFM", load_unifoil_diag_cfm_ensemble, model_filter not in ["cfm", "inn"]),
        ("CFM", load_unifoil_cfm_ensemble, model_filter in ["all", "cfm"]),
        ("INN", load_unifoil_inn_ensemble, model_filter in ["all", "inn"]),
    ]

    for family, load_fn, selected in ensemble_specs:
        if not selected:
            continue
        print(f"\n--- Evaluating {family} models ---")
        models, run_infos, ckpt_names, criteria = load_fn(device=device, verbose=True)
        print(f"{family} criteria: {criteria}")
        if not models:
            print(f"No {family} models found")
            continue
        all_results.extend(
            evaluate_loaded_ensemble(
                models,
                run_infos,
                ckpt_names,
                device,
                surrogate_model,
                num_inverse_samples=10,
            )
        )

    if not all_results:
        print("No models found for evaluation!")
        sys.exit(1)

    # Validate that all models of the same type have the same parameter count
    # Extract run_info from results for validation
    # NOTE: these are the same dict objects stored in all_results, so the
    # number_of_parameters key added below is also written to the saved JSON.
    run_infos_for_validation = [r.get("run_info", r) for r in all_results]
    # Add number_of_parameters to run_info if not present
    for r, run_info in zip(all_results, run_infos_for_validation):
        if "number_of_parameters" not in run_info:
            run_info["number_of_parameters"] = r.get("num_parameters")
    validate_ensemble_parameter_counts(run_infos_for_validation)

    # Compute combined summary
    summary = compute_unifoil_summary(all_results)

    # Print summary
    print_summary(summary, title="UNIFOIL EVALUATION SUMMARY")

    # Save results
    output_dir = RESULTS_DIR
    output_dir.mkdir(exist_ok=True, parents=True)
    output_path = output_dir / "unifoil_evaluation_results.json"
    save_results(
        all_results,
        summary,
        output_path,
        extra_fields={"dataset": UNIFOIL_DATASET_NAME},
    )

    # Run diversity vs epsilon analysis (always covers every model family,
    # independent of the --cfm/--inn/--all filter used above).
    print("\n" + "=" * 70)
    print("DIVERSITY VS EPSILON ANALYSIS")
    print("=" * 70)
    evaluate_diversity_vs_epsilon()
