"""
Evaluate uncertainty quantification with OOD detection for Unifoil dataset.

Unifoil uses conditional generation with physical parameters (angle of attack, Mach number).

Includes:
- Ensemble variance
- FM loss
- Self-consistency error (reconstruction via forward pass)
- Zero-deviation (label part deviation from zero in synthesis)

Usage:
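    python -m uq_diagcfm.evaluate_uq_unifoil [nb_samples] [difficulty]

Both arguments are optional (defaults: 1000 samples, "hard" difficulty). The
number of epochs is not a command-line argument; it is read from the selection
criteria of the loaded ensemble.

Examples:
    python -m uq_diagcfm.evaluate_uq_unifoil 2500 hard
    python -m uq_diagcfm.evaluate_uq_unifoil 2500 medium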
"""

import json
import numpy as np
import torch

from uq_diagcfm.data_utils_unifoil import (
    UnifoilDataset,
    LEN_DESIGN_PARAMETERS,
    LEN_PHYSICAL_PARAMS,
    LEN_PHYSICAL_PERFORMANCE,
)
from uq_diagcfm.solvers import euler_method_with_conditioning
from uq_diagcfm.uq_evaluation_utils import (
    evaluate_uq_with_ood_generic,
    create_y_complement,
    compute_zero_deviation,
    compute_self_consistency,
    create_all_plots,
    print_uq_summary,
    print_auc_summary,
)
from uq_diagcfm.ood_utils import DIFFICULTY_CONFIGS
from uq_diagcfm.ensembles import load_unifoil_diag_cfm_ensemble

# Module-level cache of the validation set's physical params, needed to build
# the conditioning tensor. It stays None until get_validation_labels() has run
# (which fills it as a side effect); get_conditioning() reads it afterwards.
_val_physical_params_cache = None


def get_validation_labels(device: torch.device) -> torch.Tensor:
    """Return the stacked validation labels of the Unifoil dataset.

    As a side effect, the physical params of the same samples (same order)
    are stacked and stored in the module-level ``_val_physical_params_cache``
    so that they can later be used for conditioning.

    Args:
        device: Accepted for interface compatibility; not used here. The
            tensors are built with ``torch.from_numpy``.

    Returns:
        Tensor with one row of labels per validation sample that was read.
    """
    global _val_physical_params_cache

    dataset = UnifoilDataset(split="val")
    # Read at most 10000 samples to avoid memory issues.
    sample_count = min(len(dataset), 10000)

    label_rows = []
    physical_rows = []
    for idx in range(sample_count):
        # Each item is (design, physical params, labels); the design is not needed.
        _, physical, label = dataset[idx]
        label_rows.append(torch.from_numpy(label))
        physical_rows.append(torch.from_numpy(physical))

    stacked_labels = torch.stack(label_rows, dim=0)
    _val_physical_params_cache = torch.stack(physical_rows, dim=0)

    return stacked_labels


def get_training_labels(device: torch.device) -> torch.Tensor:
    """Return the stacked training labels used for OOD generation.

    Args:
        device: Accepted for interface compatibility; not used here. The
            tensors are built with ``torch.from_numpy``.

    Returns:
        Tensor with one row of labels per training sample that was read
        (at most the first 50000 samples).
    """
    dataset = UnifoilDataset(split="train")
    sample_count = min(50000, len(dataset))
    # Index 2 of a dataset item holds the labels.
    label_rows = [torch.from_numpy(dataset[idx][2]) for idx in range(sample_count)]
    return torch.stack(label_rows, dim=0)


def get_conditioning(
    in_dist_labels: torch.Tensor,
    ood_labels: torch.Tensor,
    device: torch.device,
    seed: int = 42,
) -> torch.Tensor:
    """Get physical conditioning parameters for in-dist and OOD samples.

    For in-distribution samples, the cached validation physical params are
    indexed with a permutation drawn after seeding the global torch RNG with
    ``seed``. This is meant to reproduce the permutation that
    sample_in_dist_and_ood (defined elsewhere) applies to the labels.
    For OOD samples, rows are drawn at random (with replacement) from the
    cached validation physical params.

    IMPORTANT: The seed must match the seed used in sample_in_dist_and_ood
    to ensure the physical params correspond to the correct labels.

    Note: this function reseeds the global torch RNG (``torch.manual_seed``)
    twice, first with ``seed`` and then with ``seed + 1000``.

    Args:
        in_dist_labels: In-distribution labels; only the row count is used.
        ood_labels: OOD labels; only the row count is used.
        device: Device the returned conditioning tensor is moved to.
        seed: Seed for the in-distribution permutation.

    Returns:
        Tensor with the in-distribution physical params followed by the OOD
        physical params, one row per label.

    Raises:
        RuntimeError: If the validation physical params have not been cached
            yet, i.e. get_validation_labels() has not been called.
        ValueError: If there are more in-distribution labels than cached
            physical params, so the two could not be aligned row by row.
    """
    cache = _val_physical_params_cache
    if cache is None:
        raise RuntimeError(
            "Validation physical params are not cached; call "
            "get_validation_labels() before get_conditioning()."
        )

    nb_in_dist = in_dist_labels.shape[0]
    nb_ood = ood_labels.shape[0]
    nb_cached = len(cache)

    # A truncated permutation cannot yield more rows than are cached; without
    # this check the result would silently have fewer rows than the labels.
    if nb_in_dist > nb_cached:
        raise ValueError(
            f"Got {nb_in_dist} in-distribution labels but only {nb_cached} "
            "cached validation physical params."
        )

    # In-distribution: replicate the same permutation used in sample_in_dist_and_ood
    # to ensure physical params match the permuted labels
    torch.manual_seed(seed)
    in_dist_perm = torch.randperm(nb_cached)[:nb_in_dist]
    in_dist_physical_params = cache[in_dist_perm].to(device)

    # OOD: sample random physical params from validation distribution
    # Use a different seed to avoid correlation with in-dist sampling
    torch.manual_seed(seed + 1000)
    ood_indices = torch.randint(0, nb_cached, (nb_ood,))
    ood_physical_params = cache[ood_indices].to(device)

    return torch.cat([in_dist_physical_params, ood_physical_params], dim=0)


def compute_placeholder_error(
    simulated_designs: torch.Tensor,
    all_labels: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """Return an all-zero error vector as a stand-in error metric.

    The real error metric is computed separately, after the generic
    evaluation function has run.

    Args:
        simulated_designs: Generated designs; only the row count is used.
        all_labels: Unused.
        device: Device on which the zero vector is created.

    Returns:
        1-D tensor of zeros with one entry per row of ``simulated_designs``.
    """
    batch_size = simulated_designs.size(0)
    return torch.zeros(batch_size, device=device)


def evaluate_uq_with_ood(nb_samples=2500, max_models=5, ood_difficulty="hard"):
    """Run the Unifoil UQ evaluation on in-distribution and OOD labels.

    The ensemble is loaded with load_unifoil_diag_cfm_ensemble() and handed to
    evaluate_uq_with_ood_generic() together with a placeholder error metric.
    The self-consistency error and the zero-deviation are then computed here
    from the intermediate data the generic evaluator returns, and their
    summary statistics are written into ``results``.

    Args:
        nb_samples: Number of samples per category.
        max_models: Maximum number of models to use from ensemble; ``None``
            keeps all of them.
        ood_difficulty: OOD difficulty level ("easy", "medium", "hard").

    Returns:
        Tuple of (results, metrics_dict, ood_mask, in_dist_labels, ood_labels,
        train_labels, epochs). ``metrics_dict`` maps "zero_deviation",
        "self_consistency", "ensemble_variance" and "fm_loss" to per-sample
        values; the three label entries are NumPy arrays.
    """

    def _summarize(values):
        # Scalar summary of one per-sample metric, as stored in `results`.
        return {
            "mean": float(values.mean()),
            "std": float(values.std()),
            "min": float(values.min()),
            "max": float(values.max()),
        }

    def _as_numpy(tensor):
        # Detach from autograd and move to CPU before converting.
        return tensor.detach().cpu().numpy()

    # Number of Euler steps, shared by self-consistency and synthesis.
    integration_steps = 30

    print("Loading ensemble models...")
    models, run_infos, checkpoint_names, criteria = load_unifoil_diag_cfm_ensemble()
    epochs = criteria["epochs"]

    # Keep only the first `max_models` ensemble members when a limit is set.
    exceeds_limit = max_models is not None and len(models) > max_models
    if exceeds_limit:
        models = models[:max_models]
        run_infos = run_infos[:max_models]
        checkpoint_names = checkpoint_names[:max_models]

    # The generic evaluator also hands back the intermediate tensors needed
    # for the Unifoil-specific metrics below.
    generic_outputs = evaluate_uq_with_ood_generic(
        models=models,
        run_infos=run_infos,
        checkpoint_names=checkpoint_names,
        get_validation_labels=get_validation_labels,
        get_training_labels=get_training_labels,
        compute_error_metric=compute_placeholder_error,
        error_metric_name="self_consistency_error",
        num_design_params=LEN_DESIGN_PARAMETERS,
        num_labels=LEN_PHYSICAL_PERFORMANCE,
        get_conditioning=get_conditioning,
        nb_samples=nb_samples,
        ood_grid_steps=15,  # Unifoil uses fewer grid steps
        ood_difficulty=ood_difficulty,
        return_intermediate_data=True,
    )
    # The fourth entry is the placeholder error metric and is discarded.
    (
        results,
        ensemble_variance,
        fm_losses,
        _placeholder_errors,
        ood_mask,
        intermediate_data,
    ) = generic_outputs

    from uq_diagcfm.utils import get_device

    device = get_device()

    reference_model, simulated_designs, all_labels, conditioning = (
        intermediate_data[key]
        for key in ("reference_model", "simulated_designs", "all_labels", "conditioning")
    )

    # Self-consistency error, computed with the physical conditioning.
    print("Computing self-consistency errors...")
    self_consistency_errors = compute_self_consistency(
        model=reference_model,
        simulated_designs=simulated_designs,
        target_labels=all_labels,
        num_labels=LEN_PHYSICAL_PERFORMANCE,
        device=device,
        conditioning=conditioning,
        steps=integration_steps,
    )
    # Store the real statistics under the key the placeholder was registered with.
    results["self_consistency_error"] = _summarize(self_consistency_errors)

    # Zero-deviation: integrate the augmented labels from t=1 to t=0.
    print("Computing zero-deviation...")
    y_complement = create_y_complement(
        all_labels.shape[0], LEN_DESIGN_PARAMETERS, LEN_PHYSICAL_PERFORMANCE, device
    )
    augmented_labels = torch.cat([all_labels, y_complement], dim=1)

    with torch.no_grad():
        synthesis_output = euler_method_with_conditioning(
            model=reference_model,
            input=augmented_labels,
            conditioning=conditioning,
            start_t=1,
            end_t=0,
            steps=integration_steps,
        )
    zero_deviation = compute_zero_deviation(synthesis_output, LEN_PHYSICAL_PERFORMANCE)
    print(f"Zero-deviation: mean={zero_deviation.mean():.6f}, std={zero_deviation.std():.6f}")
    results["zero_deviation"] = _summarize(zero_deviation)

    # Per-sample UQ metrics for plotting (canonical order).
    metrics = {
        "zero_deviation": _as_numpy(zero_deviation),
        "self_consistency": _as_numpy(self_consistency_errors),
        "ensemble_variance": ensemble_variance,
        "fm_loss": fm_losses,
    }

    # Split the labels for the point cloud plots.
    # NOTE(review): assumes in-distribution rows precede the OOD rows in
    # all_labels (the order get_conditioning uses) - confirm against the
    # generic evaluator.
    n_in_dist = len(ood_mask) - ood_mask.sum()
    in_dist_labels = _as_numpy(all_labels[:n_in_dist])
    ood_labels_np = _as_numpy(all_labels[n_in_dist:])
    train_labels = get_training_labels(device).numpy()

    return results, metrics, ood_mask, in_dist_labels, ood_labels_np, train_labels, epochs


if __name__ == "__main__":
    import sys

    from uq_diagcfm.paths import RESULTS_UQ_DIR, PAPER_FIGURES_DIR

    # Command line: [nb_samples] [difficulty], both optional.
    # Epochs are not an argument; they come from the loaded ensemble.
    try:
        nb_samples = int(sys.argv[1]) if len(sys.argv) > 1 else 1000
    except ValueError:
        # Report a bad sample count the same way as a bad difficulty.
        print(f"Invalid nb_samples: {sys.argv[1]}. Expected an integer.")
        sys.exit(1)
    ood_difficulty = sys.argv[2] if len(sys.argv) > 2 else "hard"

    # Validate difficulty
    if ood_difficulty not in DIFFICULTY_CONFIGS:
        print(f"Invalid difficulty: {ood_difficulty}. Use one of: {list(DIFFICULTY_CONFIGS.keys())}")
        sys.exit(1)

    # Create the output directory up front so a missing directory is caught
    # before the expensive evaluation instead of when the results are written.
    RESULTS_UQ_DIR.mkdir(parents=True, exist_ok=True)

    # Run evaluation (models loaded via load_unifoil_diag_cfm_ensemble)
    print("=" * 80)
    print("EVALUATING UQ WITH OOD DETECTION FOR UNIFOIL")
    print("=" * 80)

    results, metrics, ood_mask, in_dist_labels, ood_labels, train_labels, epochs = evaluate_uq_with_ood(
        nb_samples=nb_samples, max_models=5, ood_difficulty=ood_difficulty
    )
    print(f"Epochs: {epochs}, Samples per category: {nb_samples}, OOD difficulty: {ood_difficulty}")

    # Output files share one prefix and are tagged with the ensemble's epochs.
    output_prefix = f"uq_ood_unifoil_{ood_difficulty}"

    # Save JSON results (summary statistics)
    json_file = RESULTS_UQ_DIR / f"{output_prefix}_ep{epochs}.json"
    with open(json_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"Results saved to: {json_file}")

    # Save NPZ data (OOD mask plus the per-sample metric arrays)
    npz_file = RESULTS_UQ_DIR / f"{output_prefix}_ep{epochs}.npz"
    np.savez(npz_file, ood_mask=ood_mask, **metrics)
    print(f"Data saved to: {npz_file}")

    # Print summary
    print_uq_summary(results, error_metric_name="self_consistency_error")

    # Create all plots and get AUC scores
    auc_scores = create_all_plots(
        metrics=metrics,
        ood_mask=ood_mask,
        in_dist_labels=in_dist_labels,
        ood_labels=ood_labels,
        output_dir=PAPER_FIGURES_DIR,
        dataset_name="Unifoil",
        train_labels=train_labels,
    )

    # Print AUC summary
    print_auc_summary(auc_scores)
