"""
Evaluate uncertainty quantification with OOD detection for Gas Turbine dataset.

This script evaluates Diag-CFM models on the Gas Turbine dataset,
measuring their ability to detect out-of-distribution generation requests.

Includes:
- Ensemble variance
- FM loss
- Surrogate error (using trained surrogates)
- Zero-deviation (label part deviation from zero in synthesis)
- Self-consistency (reconstruction error via analysis pass)

Usage:
    python -m uq_diagcfm.evaluate_uq_gas_turbine [nb_samples] [difficulty]

Examples:
    python -m uq_diagcfm.evaluate_uq_gas_turbine 2500 hard
    python -m uq_diagcfm.evaluate_uq_gas_turbine 2500 medium
"""

import json
import numpy as np
import torch

from uq_diagcfm.data_utils_gas_turbine import (
    GasTurbineDataset,
    make_surrogates,
    LEN_PARAMETERS,
    LEN_LABELS,
)
from uq_diagcfm.solvers import euler_method
from uq_diagcfm.uq_evaluation_utils import (
    evaluate_uq_with_ood_generic,
    create_y_complement,
    compute_zero_deviation,
    compute_self_consistency,
    create_all_plots,
    print_uq_summary,
    print_auc_summary,
)
from uq_diagcfm.ood_utils import DIFFICULTY_CONFIGS
from uq_diagcfm.ensembles import load_gas_turbine_diag_cfm_ensemble


def get_validation_labels(device: torch.device) -> torch.Tensor:
    """Stack the label tensors of the Gas Turbine validation split.

    Args:
        device: Not referenced in the body; the stacked tensor is returned
            without any device transfer. Presumably kept so the function fits
            the callback signature expected by the generic evaluator -- confirm
            against ``evaluate_uq_with_ood_generic``.

    Returns:
        Tensor built by stacking element ``[1]`` of every validation item
        along a new leading dimension.
    """
    dataset = GasTurbineDataset(split="val")
    num_items = len(dataset)
    # Element 1 of each dataset item is taken as the label.
    label_rows = [dataset[idx][1] for idx in range(num_items)]
    return torch.stack(label_rows, dim=0)


def get_training_labels(device: torch.device) -> torch.Tensor:
    """Stack the label tensors of (at most) the first 50000 training items.

    Args:
        device: Not referenced in the body; the stacked tensor is returned
            without any device transfer.

    Returns:
        Tensor built by stacking element ``[1]`` of the first
        ``min(50000, len(dataset))`` training items along a new leading
        dimension.
    """
    dataset = GasTurbineDataset(split="train")
    # Cap the number of items read so very large training sets stay manageable.
    limit = min(50000, len(dataset))
    label_rows = [dataset[idx][1] for idx in range(limit)]
    return torch.stack(label_rows, dim=0)


def compute_surrogate_error(
    simulated_designs: torch.Tensor,
    all_labels: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """Return the per-sample MSE between surrogate predictions and labels.

    The three surrogates returned by ``make_surrogates()`` are moved to
    ``device``, evaluated on ``simulated_designs``, and their outputs are
    concatenated along dim 1 before being compared with ``all_labels``.

    Args:
        simulated_designs: Input passed to each surrogate.
        all_labels: Reference labels; must be broadcast-compatible with the
            concatenated surrogate outputs.
        device: Device the surrogate models are moved to.

    Returns:
        Squared difference averaged over dim 1 (one value per row).
    """
    # make_surrogates() yields exactly three models, in label-column order.
    unmix_o_model, io_pd_model, ifd1_model = make_surrogates()
    surrogates = [
        net.to(device) for net in (unmix_o_model, io_pd_model, ifd1_model)
    ]

    # Evaluate each surrogate and place its output side by side (dim 1).
    predictions = [net(simulated_designs) for net in surrogates]
    predicted_labels = torch.cat(predictions, dim=1)

    residual = predicted_labels - all_labels
    return residual.pow(2).mean(dim=1)


def evaluate_uq_with_ood(nb_samples=2500, max_models=5, ood_difficulty="hard"):
    """Run the Gas Turbine UQ evaluation on in-distribution and OOD labels.

    The ensemble is obtained from ``load_gas_turbine_diag_cfm_ensemble()`` and
    handed to ``evaluate_uq_with_ood_generic``. On top of what the generic
    routine returns, two extra metrics are computed with its reference model:
    zero-deviation (from an Euler integration of the augmented labels from
    t=1 to t=0) and self-consistency.

    Args:
        nb_samples: Number of samples per category.
        max_models: Maximum number of models to use from the ensemble;
            ``None`` keeps all of them.
        ood_difficulty: OOD difficulty level ("easy", "medium", "hard").

    Returns:
        Tuple ``(results, metrics, ood_mask, in_dist_labels, ood_labels,
        train_labels, epochs)`` where ``results`` is the generic results
        mapping extended with ``"zero_deviation"`` and ``"self_consistency"``
        summaries, ``metrics`` maps metric names to per-sample values, and
        ``epochs`` is read from the ensemble selection criteria.
    """

    def _summarize(values):
        # Scalar summary; key order is kept stable because it is serialized.
        return {
            "mean": float(values.mean()),
            "std": float(values.std()),
            "min": float(values.min()),
            "max": float(values.max()),
        }

    # Fetch the ensemble through the canonical loader.
    print("Loading ensemble models...")
    models, run_infos, checkpoint_names, criteria = load_gas_turbine_diag_cfm_ensemble()
    epochs = criteria["epochs"]

    # Truncate the ensemble (and its parallel metadata lists) when capped.
    too_many_models = max_models is not None and len(models) > max_models
    if too_many_models:
        models = models[:max_models]
        run_infos = run_infos[:max_models]
        checkpoint_names = checkpoint_names[:max_models]

    # Generic evaluation; intermediate data is requested for the extra metrics.
    generic_outcome = evaluate_uq_with_ood_generic(
        models=models,
        run_infos=run_infos,
        checkpoint_names=checkpoint_names,
        get_validation_labels=get_validation_labels,
        get_training_labels=get_training_labels,
        compute_error_metric=compute_surrogate_error,
        error_metric_name="surrogate_error",
        num_design_params=LEN_PARAMETERS,
        num_labels=LEN_LABELS,
        get_conditioning=None,  # Gas Turbine has no conditioning
        nb_samples=nb_samples,
        ood_difficulty=ood_difficulty,
        return_intermediate_data=True,
    )
    # surrogate_errors is returned by the generic routine but not used below.
    (
        results,
        ensemble_variance,
        fm_losses,
        surrogate_errors,
        ood_mask,
        intermediate_data,
    ) = generic_outcome

    from uq_diagcfm.utils import get_device
    device = get_device()

    reference_model = intermediate_data["reference_model"]
    all_labels = intermediate_data["all_labels"]
    simulated_designs = intermediate_data["simulated_designs"]

    # Zero-deviation: integrate the augmented labels from t=1 down to t=0.
    print("Computing zero-deviation...")
    y_complement = create_y_complement(
        all_labels.shape[0], LEN_PARAMETERS, LEN_LABELS, device
    )
    augmented_labels = torch.cat((all_labels, y_complement), dim=1)

    with torch.no_grad():
        synthesis_output = euler_method(
            model=reference_model,
            input=augmented_labels,
            start_t=1,
            end_t=0,
            steps=30,
        )
    zero_deviation = compute_zero_deviation(synthesis_output, LEN_LABELS)
    print(f"Zero-deviation: mean={zero_deviation.mean():.6f}, std={zero_deviation.std():.6f}")

    # Self-consistency of the simulated designs against their target labels.
    print("Computing self-consistency...")
    self_consistency = compute_self_consistency(
        model=reference_model,
        simulated_designs=simulated_designs,
        target_labels=all_labels,
        num_labels=LEN_LABELS,
        device=device,
        conditioning=None,
        steps=30,
    )
    print(f"Self-consistency: mean={self_consistency.mean():.6f}, std={self_consistency.std():.6f}")

    # Record scalar summaries of the two extra metrics.
    results["zero_deviation"] = _summarize(zero_deviation)
    results["self_consistency"] = _summarize(self_consistency)

    # Per-sample UQ metrics for plotting (error metrics such as
    # surrogate_error are deliberately left out).
    metrics = {
        "zero_deviation": zero_deviation.detach().cpu().numpy(),
        "self_consistency": self_consistency.detach().cpu().numpy(),
        "ensemble_variance": ensemble_variance,
        "fm_loss": fm_losses,
    }

    # Label arrays for the point cloud plots. The slicing treats the first
    # n_in_dist rows of all_labels as in-distribution and the rest as OOD.
    # NOTE(review): this relies on the row ordering produced by the generic
    # routine -- confirm.
    n_in_dist = len(ood_mask) - ood_mask.sum()
    in_dist_labels = all_labels[:n_in_dist].detach().cpu().numpy()
    ood_labels = all_labels[n_in_dist:].detach().cpu().numpy()
    train_labels = get_training_labels(device).numpy()

    return results, metrics, ood_mask, in_dist_labels, ood_labels, train_labels, epochs


if __name__ == "__main__":
    # Script entry point: parse [nb_samples] [difficulty], run the evaluation,
    # write the JSON/NPZ results, then print summaries and create the plots.
    import os
    import sys

    from uq_diagcfm.paths import RESULTS_UQ_DIR, PAPER_FIGURES_DIR

    # Parse command line arguments (both optional).
    try:
        nb_samples = int(sys.argv[1]) if len(sys.argv) > 1 else 1000
    except ValueError:
        # Report bad input the same way as an invalid difficulty instead of
        # surfacing a raw traceback.
        print(f"Invalid nb_samples: {sys.argv[1]}. Expected an integer.")
        sys.exit(1)
    ood_difficulty = sys.argv[2] if len(sys.argv) > 2 else "hard"

    # Validate difficulty
    if ood_difficulty not in DIFFICULTY_CONFIGS:
        print(f"Invalid difficulty: {ood_difficulty}. Use one of: {list(DIFFICULTY_CONFIGS.keys())}")
        sys.exit(1)

    # Make sure the results directory exists *before* the long evaluation so
    # a missing or unwritable directory cannot discard the computed results.
    os.makedirs(RESULTS_UQ_DIR, exist_ok=True)

    # Run evaluation (models loaded via load_gas_turbine_diag_cfm_ensemble)
    print("=" * 80)
    print("EVALUATING UQ WITH OOD DETECTION FOR GAS TURBINE")
    print("=" * 80)

    results, metrics, ood_mask, in_dist_labels, ood_labels, train_labels, epochs = evaluate_uq_with_ood(
        nb_samples=nb_samples, max_models=5, ood_difficulty=ood_difficulty
    )
    print(f"Epochs: {epochs}, Samples per category: {nb_samples}, OOD difficulty: {ood_difficulty}")

    # Output files share one prefix; the epoch count comes from the ensemble.
    output_prefix = f"uq_ood_gas_turbine_{ood_difficulty}"

    # Save JSON results
    json_file = RESULTS_UQ_DIR / f"{output_prefix}_ep{epochs}.json"
    with open(json_file, "w") as f:
        json.dump(results, f, indent=2)
    print(f"Results saved to: {json_file}")

    # Save NPZ data (OOD mask plus one array per UQ metric)
    npz_file = RESULTS_UQ_DIR / f"{output_prefix}_ep{epochs}.npz"
    np.savez(npz_file, ood_mask=ood_mask, **metrics)
    print(f"Data saved to: {npz_file}")

    # Print summary
    print_uq_summary(results, error_metric_name="surrogate_error")

    # Create all plots and get AUC scores
    auc_scores = create_all_plots(
        metrics=metrics,
        ood_mask=ood_mask,
        in_dist_labels=in_dist_labels,
        ood_labels=ood_labels,
        output_dir=PAPER_FIGURES_DIR,
        dataset_name="Gas_Turbine",
        train_labels=train_labels,
    )

    # Print AUC summary
    print_auc_summary(auc_scores)
