"""
Evaluation script for computing forward and inverse performance metrics for Gas Turbine.

Supports both Diag-CFM/CFM models and INN models.
Uses ground truth surrogate models for round-trip error computation.
"""

import torch
import numpy as np
from pathlib import Path
import json
from typing import Dict, List

from uq_diagcfm.utils import get_device
from uq_diagcfm.ensembles import (
    load_gas_turbine_diag_cfm_ensemble,
    load_gas_turbine_cfm_ensemble,
    load_gas_turbine_inn_ensemble,
)
from uq_diagcfm.data_utils_gas_turbine import (
    GasTurbineDataset,
    GAS_TURBINE_DATASET_NAME,
    LEN_PARAMETERS as GT_LEN_PARAMETERS,
    LEN_LABELS as GT_LEN_LABELS,
)
from uq_diagcfm.evaluation_utils import (
    forward_pass,
    inverse_pass,
    forward_pass_inn,
    inverse_pass_inn,
    compute_mse_metrics,
    compute_design_diversity,
    compute_roundtrip_errors,
    compute_summary_statistics,
    save_results,
    print_summary,
    validate_ensemble_parameter_counts,
    compute_diversity_vs_epsilon,
    plot_diversity_vs_epsilon_combined,
    print_diversity_vs_epsilon_description,
)


def compute_forward_metrics(model, dataloader, diag_cfm, P, L, device, is_inn=False):
    """Score the model's forward (design -> label) predictions over a dataloader.

    Args:
        model: Flow-matching model or INN model to evaluate.
        dataloader: Iterable of batches shaped ``(x, y)`` or ``(x, _, y)``.
        diag_cfm: Passed through to ``forward_pass`` (ignored for INN).
        P: Number of design parameters (unused; kept for a uniform signature).
        L: Number of labels, passed to ``forward_pass``.
        device: Torch device the inputs are moved to.
        is_inn: If True, use ``forward_pass_inn`` instead of ``forward_pass``.

    Returns:
        The result of ``compute_mse_metrics`` on the concatenated ground-truth
        and predicted labels (callers read its ``"forward_mse"`` entry).

    Raises:
        ValueError: If a batch has neither 2 nor 3 elements.
    """
    targets, predictions = [], []

    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            if len(batch) not in (2, 3):
                raise ValueError("Unexpected batch length")
            # Designs are always first and labels always last; a 3-element
            # batch carries an extra middle entry that is not needed here.
            designs = batch[0].to(device)
            labels = batch[-1].to(device)

            prediction = (
                forward_pass_inn(model, designs, device)
                if is_inn
                else forward_pass(model, designs, L, device, diag_cfm=diag_cfm)
            )

            targets.append(labels.cpu())
            predictions.append(prediction.cpu())

    return compute_mse_metrics(
        torch.cat(targets, dim=0).numpy(),
        torch.cat(predictions, dim=0).numpy(),
    )


def compute_inverse_metrics(
    model,
    dataloader,
    diag_cfm,
    P,
    L,
    device,
    surrogate_model=None,
    num_samples=10,
    is_inn=False,
):
    """Compute inverse performance metrics (round-trip error, design diversity).

    Args:
        model: The flow matching model or INN model
        dataloader: DataLoader with test data; batches are ``(x, y)`` or
            ``(x, _, y)``. Only the labels ``y`` are used here.
        diag_cfm: Whether using Diag-CFM (vs vanilla CFM) - ignored for INN
        P: Number of design parameters
        L: Number of labels
        device: Torch device
        surrogate_model: Ground truth surrogate for evaluating generated designs.
                        If None, uses the model's own forward pass.
        num_samples: Number of samples to generate per target
        is_inn: Whether the model is an INN

    Returns:
        Dictionary with ``"inverse_roundtrip_error"`` (mean round-trip error over
        all targets/samples) and ``"inverse_design_diversity_mean_var"``.

    Raises:
        ValueError: If a batch has neither 2 nor 3 elements, or if
            ``dataloader`` yields no batches.
    """

    # The forward model used to score generated designs does not depend on the
    # batch, so define it once rather than rebuilding a closure per batch.
    def forward_fn(x_gen):
        if surrogate_model is not None:
            return surrogate_model(x_gen)
        if is_inn:
            return forward_pass_inn(model, x_gen, device)
        return forward_pass(model, x_gen, L, device, diag_cfm=diag_cfm)

    all_roundtrip_errors = []
    all_diversities_var = []

    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            if len(batch) not in (2, 3):
                raise ValueError("Unexpected batch length")
            # Only the target labels (last element) are needed in the inverse
            # direction; the ground-truth designs are not used.
            y = batch[-1].to(device)

            # Inverse pass: generate designs from labels
            if is_inn:
                x_samples = inverse_pass_inn(model, y, P, device, num_samples)
            else:
                x_samples = inverse_pass(
                    model,
                    y,
                    P,
                    L,
                    device,
                    num_samples,
                    diag_cfm=diag_cfm,
                    noise_distribution="uniform",
                )

            # Round-trip error: map generated designs back to labels and
            # compare against the requested targets.
            all_roundtrip_errors.append(
                compute_roundtrip_errors(y, x_samples, forward_fn)
            )

            # Spread of the generated designs for each target.
            diversity = compute_design_diversity(x_samples)
            all_diversities_var.append(diversity["diversity_var"])

    if not all_roundtrip_errors:
        # np.concatenate([]) would fail with an opaque message; be explicit.
        raise ValueError("compute_inverse_metrics: dataloader yielded no batches")

    roundtrip_errors = np.concatenate(all_roundtrip_errors, axis=0)
    diversities_var = np.concatenate(all_diversities_var, axis=0)

    return {
        "inverse_roundtrip_error": float(np.mean(roundtrip_errors)),
        "inverse_design_diversity_mean_var": float(np.mean(diversities_var)),
    }


def make_gas_turbine_surrogate(device):
    """Create the ground truth surrogate model for gas turbine.

    The three sub-surrogates returned by ``make_surrogates`` are moved to
    ``device`` and put in inference mode. Their outputs are concatenated along
    dim 1 to form the full label vector.

    Args:
        device: Torch device on which the surrogates run.

    Returns:
        A function that maps designs to labels.
    """
    from uq_diagcfm.data_utils_gas_turbine import make_surrogates

    model_Unmix_O, model_IO_PD, model_IFD1 = make_surrogates()
    # The surrogate is used as ground truth, so it must be deterministic.
    # .eval() disables dropout / freezes batch-norm statistics if the
    # sub-models contain such layers; it is a no-op otherwise.
    model_Unmix_O = model_Unmix_O.to(device).eval()
    model_IO_PD = model_IO_PD.to(device).eval()
    model_IFD1 = model_IFD1.to(device).eval()

    def surrogate_model(x: torch.Tensor) -> torch.Tensor:
        # Label column order: Unmix_O outputs, then IO_PD, then IFD1.
        y1 = model_Unmix_O(x)
        y2 = model_IO_PD(x)
        y3 = model_IFD1(x)
        return torch.cat((y1, y2, y3), dim=1)

    return surrogate_model


def evaluate_checkpoint(model: torch.nn.Module, run_info: dict, num_inverse_samples=10):
    """Run forward and inverse evaluation of one model on the Gas Turbine test split.

    Args:
        model: Loaded PyTorch model (Diag-CFM, CFM, or INN).
        run_info: Run metadata. ``"model_type" == "INN"`` selects the INN code
            path, and ``"diag_cfm"`` (default False) selects the CFM variant.
        num_inverse_samples: Number of designs generated per target label.

    Returns:
        A new dict with ``run_info`` merged with the forward and inverse
        metric dicts (later keys win on collision).
    """
    device = get_device()

    model_is_inn = run_info.get("model_type", "") == "INN"
    diag_cfm = run_info.get("diag_cfm", False)

    loader = torch.utils.data.DataLoader(
        GasTurbineDataset(split="test"), batch_size=256, shuffle=False
    )
    surrogate = make_gas_turbine_surrogate(device)

    # Short human-readable tag used only in progress messages.
    if model_is_inn:
        label = f"INN_hd{run_info.get('hidden_dim', 'N/A')}"
    else:
        label = f"diag{diag_cfm}_hd{run_info.get('model_hidden_dimension', 'N/A')}"

    print(f"Computing forward metrics for {label}...")
    forward_metrics = compute_forward_metrics(
        model,
        loader,
        diag_cfm,
        GT_LEN_PARAMETERS,
        GT_LEN_LABELS,
        device,
        is_inn=model_is_inn,
    )

    print(f"Computing inverse metrics for {label}...")
    inverse_metrics = compute_inverse_metrics(
        model,
        loader,
        diag_cfm,
        GT_LEN_PARAMETERS,
        GT_LEN_LABELS,
        device,
        surrogate_model=surrogate,
        num_samples=num_inverse_samples,
        is_inn=model_is_inn,
    )

    merged = dict(run_info)
    merged.update(forward_metrics)
    merged.update(inverse_metrics)
    return merged


def evaluate_loaded_ensemble(
    models: List[torch.nn.Module],
    run_infos: List[Dict],
    checkpoint_names: List[str],
) -> List[Dict]:
    """Evaluate an already-loaded ensemble of models.

    Per-model failures are reported (message plus traceback) and skipped, so
    one bad checkpoint or incomplete ``run_info`` does not abort the ensemble.

    Args:
        models: List of loaded PyTorch models.
        run_infos: List of run info dictionaries (parallel to ``models``).
        checkpoint_names: List of checkpoint directory names (parallel to ``models``).

    Returns:
        List of result dictionaries for each successfully evaluated model, each
        augmented with ``"checkpoint_name"`` and ``"num_parameters"``.
    """
    import traceback

    all_results = []

    for model, run_info, ckpt_name in zip(models, run_infos, checkpoint_names):
        # Build a display name with .get() so that missing metadata keys only
        # degrade the log label instead of crashing the whole ensemble loop
        # (the same convention evaluate_checkpoint uses for its own label).
        if run_info.get("model_type", "") == "INN":
            run_name = (
                f"INN_nb{run_info.get('num_blocks', 'N/A')}"
                f"_hd{run_info.get('hidden_dim', 'N/A')}"
                f"_sd{run_info.get('subnet_depth', 'N/A')}"
            )
        else:
            run_name = (
                f"diag{run_info.get('diag_cfm', False)}"
                f"_hd{run_info.get('model_hidden_dimension', 'N/A')}"
                f"_d{run_info.get('model_depth', 'N/A')}"
            )
        print(f"\nEvaluating: {run_name}")
        try:
            results = evaluate_checkpoint(model, run_info)
            results["checkpoint_name"] = ckpt_name
            results["num_parameters"] = sum(p.numel() for p in model.parameters())
        except Exception as e:
            # Deliberate best-effort: keep going, but keep the traceback so
            # the failure can actually be diagnosed.
            print(f"Error evaluating {run_name}: {e}")
            traceback.print_exc()
            continue
        all_results.append(results)

    return all_results


def compute_gas_turbine_summary(results: List[Dict]) -> Dict:
    """Aggregate per-model metrics into per-family summary statistics.

    Each result is assigned to exactly one family. It is ``"INN"`` when
    ``model_type == "INN"``, otherwise ``"Diag-CFM"`` when ``diag_cfm`` is
    truthy, otherwise ``"CFM"``. Families without any results are omitted.
    Within a family, a metric's stats are computed only over results that
    contain that metric, and the metric is omitted if none do.

    Args:
        results: Per-model result dictionaries.

    Returns:
        ``{family: {metric: {"mean", "std", "min", "max", "n"}}}`` with families
        in the fixed order Diag-CFM, CFM, INN. ``std`` is the population
        standard deviation (``np.std`` default).
    """
    family_order = ("Diag-CFM", "CFM", "INN")
    tracked_metrics = (
        "forward_mse",
        "inverse_roundtrip_error",
        "inverse_design_diversity_mean_var",
    )

    def _family_of(record):
        if record.get("model_type", "") == "INN":
            return "INN"
        return "Diag-CFM" if record.get("diag_cfm", False) else "CFM"

    members = {family: [] for family in family_order}
    for record in results:
        members[_family_of(record)].append(record)

    summary = {}
    for family in family_order:
        records = members[family]
        if not records:
            continue

        family_stats = {}
        for metric in tracked_metrics:
            samples = [rec[metric] for rec in records if metric in rec]
            if not samples:
                continue
            family_stats[metric] = {
                "mean": float(np.mean(samples)),
                "std": float(np.std(samples)),
                "min": float(np.min(samples)),
                "max": float(np.max(samples)),
                "n": len(samples),
            }
        summary[family] = family_stats

    return summary


def evaluate_epoch_checkpoints(
    run_path: Path, num_epochs: int, num_inverse_samples: int = 10
):
    """Evaluate all per-epoch checkpoints from a single (non-INN) training run.

    The model is rebuilt once from the run's ``run_info`` via
    ``models_for_gas_turbine`` (so only CFM/Diag-CFM runs are supported), then
    ``model_checkpoint_epoch{epoch}.pth`` for epochs ``1..num_epochs`` is loaded
    into it and evaluated. Missing checkpoints are skipped with a warning.

    Args:
        run_path: Path to the training run directory containing epoch checkpoints
        num_epochs: Highest epoch index to look for (epochs 1..num_epochs)
        num_inverse_samples: Number of samples for inverse pass

    Returns:
        Tuple ``(epoch_results, run_info)``. ``epoch_results`` is a list with
        one dict per checkpoint that was found. Each dict has ``"epoch"`` plus
        the forward and inverse metrics. ``run_info`` is the loaded run
        metadata.
    """
    from uq_diagcfm.checkpointing import load_run_info
    from uq_diagcfm.models_for_datasets import models_for_gas_turbine

    device = get_device()
    print(f"Device: {device}")

    # Load run info
    run_info = load_run_info(run_path)
    print(f"Evaluating run: {run_path.name}")

    # Create the model once; each epoch's weights are loaded into it in turn.
    model = models_for_gas_turbine(
        diag_cfm=run_info["diag_cfm"],
        model_hidden_dimension=run_info["model_hidden_dimension"],
        model_depth=run_info["model_depth"],
        dropout=run_info.get("dropout", 0),
        model_activation=run_info.get("model_activation", "LeakyReLU"),
    ).to(device)

    # Load test data
    test_dataset = GasTurbineDataset(split="test")
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=256, shuffle=False
    )

    # Load surrogate model
    surrogate_model = make_gas_turbine_surrogate(device)

    P = GT_LEN_PARAMETERS
    L = GT_LEN_LABELS
    diag_cfm = run_info["diag_cfm"]

    all_epoch_results = []

    for epoch in range(1, num_epochs + 1):
        checkpoint_path = run_path / f"model_checkpoint_epoch{epoch}.pth"
        if not checkpoint_path.exists():
            print(
                f"Warning: Checkpoint for epoch {epoch} not found at {checkpoint_path}"
            )
            continue

        print(f"\nEvaluating epoch {epoch}...")

        # Load checkpoint (expected to be a plain state_dict for `model`).
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model.eval()

        # Compute forward metrics
        forward_metrics = compute_forward_metrics(
            model, test_loader, diag_cfm, P, L, device, is_inn=False
        )

        # Compute inverse metrics
        inverse_metrics = compute_inverse_metrics(
            model,
            test_loader,
            diag_cfm,
            P,
            L,
            device,
            surrogate_model=surrogate_model,
            num_samples=num_inverse_samples,
            is_inn=False,
        )

        epoch_results = {
            "epoch": epoch,
            **forward_metrics,
            **inverse_metrics,
        }
        all_epoch_results.append(epoch_results)

        print(f"  Forward MSE: {forward_metrics['forward_mse']:.6f}")
        print(f"  Roundtrip Error: {inverse_metrics['inverse_roundtrip_error']:.6f}")
        print(
            f"  Diversity (Var): {inverse_metrics['inverse_design_diversity_mean_var']:.6f}"
        )

    if not all_epoch_results:
        # Make an all-missing run obvious instead of silently returning [].
        print(f"Warning: no epoch checkpoints were found in {run_path}")

    return all_epoch_results, run_info


def _collect_inverse_samples_and_errors(
    model,
    test_loader,
    surrogate_model,
    P,
    L,
    device,
    num_samples,
    is_inn,
    diag_cfm,
):
    """Generate inverse designs for every test target and score them with the surrogate.

    Returns:
        Tuple ``(x_samples, roundtrip_errors)`` with all batches concatenated
        along dim 1. Index ``i`` along dim 0 selects the i-th generated sample
        (``i < num_samples``). ``roundtrip_errors[i]`` is the per-target MSE
        between the requested labels and the surrogate's prediction for that
        sample.
    """
    from tqdm import tqdm

    all_x_samples = []
    all_errors = []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Processing batches"):
            if len(batch) == 2:
                _, y = batch
            elif len(batch) == 3:
                _, _, y = batch
            else:
                raise ValueError("Unexpected batch length")

            y = y.to(device)

            # Generate designs
            if is_inn:
                x_samples = inverse_pass_inn(model, y, P, device, num_samples)
            else:
                x_samples = inverse_pass(
                    model,
                    y,
                    P,
                    L,
                    device,
                    num_samples,
                    diag_cfm=diag_cfm,
                    noise_distribution="uniform",
                )

            # Round-trip error of each generated sample against the targets.
            roundtrip_errors = []
            for i in range(num_samples):
                y_pred = surrogate_model(x_samples[i])
                roundtrip_errors.append(torch.mean((y - y_pred) ** 2, dim=1))
            roundtrip_errors = torch.stack(roundtrip_errors, dim=0)

            all_x_samples.append(x_samples)
            all_errors.append(roundtrip_errors)

    return torch.cat(all_x_samples, dim=1), torch.cat(all_errors, dim=1)


def evaluate_diversity_vs_epsilon(
    num_samples: int = 10,
    num_epsilon_points: int = 50,
):
    """
    Evaluate diversity vs epsilon for Gas Turbine models and generate combined plot.

    For each of the Diag-CFM, CFM and INN ensembles, the first model returned
    by its loader is used to generate ``num_samples`` designs per test target.
    The designs are scored with the ground-truth surrogate and passed to
    ``compute_diversity_vs_epsilon`` over ``num_epsilon_points`` thresholds
    log-spaced in [1e-5, 1]. Ensembles that fail to load or are empty are
    skipped. The combined plot is written to
    ``PAPER_FIGURES_DIR / "diversity_vs_epsilon_gas_turbine.pdf"``.

    Args:
        num_samples: Number of design samples per target.
        num_epsilon_points: Number of epsilon values to evaluate.
    """
    from uq_diagcfm.paths import PAPER_FIGURES_DIR, ensure_paper_dirs_exist

    device = get_device()
    print(f"Using device: {device}")

    # Load test dataset
    test_dataset = GasTurbineDataset(split="test")
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=256, shuffle=False
    )
    print(f"Test samples: {len(test_dataset)}")

    # Load surrogate model
    surrogate_model = make_gas_turbine_surrogate(device)

    # Define epsilon values to sweep (log scale)
    epsilon_values = np.logspace(-5, 0, num_epsilon_points)

    P = GT_LEN_PARAMETERS
    L = GT_LEN_LABELS

    results = {}

    # Ensemble loading functions for each model type
    ensemble_loaders = [
        ("Diag-CFM", load_gas_turbine_diag_cfm_ensemble),
        ("CFM", load_gas_turbine_cfm_ensemble),
        ("INN", load_gas_turbine_inn_ensemble),
    ]

    for model_name, load_fn in ensemble_loaders:
        print(f"\n{'='*60}")
        print(f"Loading {model_name} model...")

        try:
            models, run_infos, _, _ = load_fn(device=device, verbose=True)
        except Exception as e:
            print(f"Failed to load {model_name}: {e}")
            continue

        if not models:
            print(f"No {model_name} models found matching criteria")
            continue

        # Use the first model in the loader's ordering (presumably the best by
        # train loss; confirm against the ensemble loader).
        model = models[0]
        run_info = run_infos[0]

        print(f"Evaluating {model_name}...")

        x_samples, roundtrip_errors = _collect_inverse_samples_and_errors(
            model,
            test_loader,
            surrogate_model,
            P,
            L,
            device,
            num_samples,
            is_inn=run_info.get("model_type", "") == "INN",
            diag_cfm=run_info.get("diag_cfm", False),
        )

        # Compute diversity for each epsilon
        model_results = compute_diversity_vs_epsilon(
            x_samples, roundtrip_errors, epsilon_values
        )
        results[model_name] = model_results

        print(f"  Diversity at max eps: {model_results['diversity_var_mean'][-1]:.6f}")
        print(f"  Valid samples at max eps: {model_results['num_valid_mean'][-1]:.1f}")

    if not results:
        print("No models were successfully evaluated!")
        return

    # Create combined plot
    ensure_paper_dirs_exist()
    output_path = PAPER_FIGURES_DIR / "diversity_vs_epsilon_gas_turbine.pdf"
    plot_diversity_vs_epsilon_combined(
        results,
        output_path,
        num_samples=num_samples,
        title="Gas Turbine: Diversity vs Round-Trip Error Threshold",
    )

    # Print description
    print_diversity_vs_epsilon_description()


if __name__ == "__main__":
    import sys
    from uq_diagcfm.paths import CHECKPOINTS_DIR, RESULTS_DIR

    # Parse command line arguments
    # Usage: python -m uq_diagcfm.evaluate_results_gas_turbine [--diversity | --epochs <run_name> <num_epochs>]
    mode = "full"  # full, epochs, diversity
    run_name = None
    num_epochs = 20  # only used for --epochs mode

    i = 1
    while i < len(sys.argv):
        arg = sys.argv[i]
        if arg == "--diversity":
            mode = "diversity"
        elif arg == "--epochs":
            mode = "epochs"
            if i + 1 < len(sys.argv):
                run_name = sys.argv[i + 1]
                i += 1
            if i + 1 < len(sys.argv) and sys.argv[i + 1].isdigit():
                num_epochs = int(sys.argv[i + 1])
                i += 1
        i += 1

    if mode == "diversity":
        evaluate_diversity_vs_epsilon()
        sys.exit(0)

    if mode == "epochs":
        if run_name is None:
            print(
                "Usage: python -m uq_diagcfm.evaluate_results_gas_turbine --epochs <run_name> [num_epochs]"
            )
            sys.exit(1)

        run_path = CHECKPOINTS_DIR / "gas_turbine" / run_name

        if not run_path.exists():
            print(f"Run path not found: {run_path}")
            sys.exit(1)

        epoch_results, run_info = evaluate_epoch_checkpoints(run_path, num_epochs)

        # Print results table
        print("\n" + "=" * 90)
        print("EPOCH-BY-EPOCH RESULTS")
        print("=" * 90)
        print(
            f"\n{'Epoch':<8} {'Forward MSE':>14} {'Roundtrip Err':>14} "
            f"{'Diversity (Var)':>16}"
        )
        print("-" * 70)

        for r in epoch_results:
            print(
                f"{r['epoch']:<8} {r['forward_mse']:>14.6f} "
                f"{r['inverse_roundtrip_error']:>14.6f} "
                f"{r['inverse_design_diversity_mean_var']:>16.6f}"
            )

        # Save results
        output_dir = RESULTS_DIR
        output_dir.mkdir(exist_ok=True, parents=True)
        output_path = output_dir / f"epoch_evaluation_{run_name}.json"

        with open(output_path, "w") as f:
            json.dump(
                {
                    "run_name": run_name,
                    "run_info": run_info,
                    "epoch_results": epoch_results,
                },
                f,
                indent=2,
            )
        print(f"\nResults saved to: {output_path}")
        sys.exit(0)

    # Full evaluation mode
    dataset_name = GAS_TURBINE_DATASET_NAME
    device = get_device()
    print(f"Using device: {device}")

    print(f"Evaluating {dataset_name}")
    print("Note: Using only unshuffled CFM/Diag-CFM runs + all INN models")

    # Load and evaluate Diag-CFM models
    print("\n--- Evaluating Diag-CFM models ---")
    diag_cfm_models, diag_cfm_run_infos, diag_cfm_ckpt_names, diag_cfm_criteria = (
        load_gas_turbine_diag_cfm_ensemble(device=device, verbose=True)
    )
    print(f"Diag-CFM criteria: {diag_cfm_criteria}")
    diag_cfm_results = evaluate_loaded_ensemble(
        diag_cfm_models, diag_cfm_run_infos, diag_cfm_ckpt_names
    )

    # Load and evaluate vanilla CFM models
    print("\n--- Evaluating CFM models ---")
    cfm_models, cfm_run_infos, cfm_ckpt_names, cfm_criteria = (
        load_gas_turbine_cfm_ensemble(device=device, verbose=True)
    )
    print(f"CFM criteria: {cfm_criteria}")
    cfm_results = evaluate_loaded_ensemble(cfm_models, cfm_run_infos, cfm_ckpt_names)

    # Load and evaluate INN models
    print("\n--- Evaluating INN models ---")
    inn_models, inn_run_infos, inn_ckpt_names, inn_criteria = (
        load_gas_turbine_inn_ensemble(device=device, verbose=True)
    )
    print(f"INN criteria: {inn_criteria}")
    inn_results = evaluate_loaded_ensemble(inn_models, inn_run_infos, inn_ckpt_names)

    # Merge results
    results = diag_cfm_results + cfm_results + inn_results

    if not results:
        print("No matching checkpoints found!")
        sys.exit(1)

    # Validate that all models of the same type have the same parameter count
    validate_ensemble_parameter_counts(results)

    summary = compute_gas_turbine_summary(results)

    # Print summary
    print_summary(summary)

    # Save results
    output_dir = RESULTS_DIR
    output_dir.mkdir(exist_ok=True)
    output_path = output_dir / f"{dataset_name}_evaluation_results.json"
    save_results(results, summary, output_path)

    # Run diversity vs epsilon analysis
    print("\n" + "=" * 70)
    print("DIVERSITY VS EPSILON ANALYSIS")
    print("=" * 70)
    evaluate_diversity_vs_epsilon()
