"""
Evaluate uncertainty quantification with OOD detection for DTLZ benchmark.

This script evaluates Diag-CFM models trained on DTLZ test functions,
measuring their ability to detect out-of-distribution generation requests
using four uncertainty quantification metrics:

Diag-CFM Specific Metrics:
- Zero-Deviation: ||output[:L]||² at t=0 after synthesis (no extra passes)
- Self-Consistency: ||y_reconstructed - y*||² via analysis pass (1 extra pass)

General-Purpose Metrics:
- Ensemble Variance: Variance across ensemble model predictions
- FM Loss: Flow matching loss at t=0.5

Usage:
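    python -m uq_diagcfm.evaluate_uq_dtlz [P] [nb_samples] [difficulty]

The epoch count is not a command-line argument; it is read from the
selection criteria returned by the ensemble loader.

Examples:
    python -m uq_diagcfm.evaluate_uq_dtlz 12 1000 hard
    python -m uq_diagcfm.evaluate_uq_dtlz 50 1000 hard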
"""

import json
from typing import Dict, List, Tuple

import numpy as np
import torch

from uq_diagcfm.data_utils_dtlz import (
    DTLZ_DATASET_NAME,
    DTLZDataset,
)
from uq_diagcfm.ensembles import load_dtlz_diag_cfm_ensemble
from uq_diagcfm.ood_utils import get_ood_points_by_difficulty, DIFFICULTY_CONFIGS
from uq_diagcfm.solvers import euler_method
from uq_diagcfm.uq_evaluation_utils import (
    compute_ensemble_variance,
    compute_fm_losses,
    compute_zero_deviation,
    compute_self_consistency,
    create_y_complement,
    create_all_plots,
    print_uq_summary,
    print_auc_summary,
    sample_in_dist_and_ood,
)
from uq_diagcfm.utils import get_device


# =============================================================================
# Main Evaluation Function
# =============================================================================


def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a tensor to a NumPy array, detaching and moving it to the CPU first."""
    return tensor.detach().cpu().numpy()


def _summary_stats(arr: np.ndarray) -> Dict[str, float]:
    """Return mean/std/min/max of ``arr`` as plain Python floats (JSON-friendly)."""
    return {
        "mean": float(np.mean(arr)),
        "std": float(np.std(arr)),
        "min": float(np.min(arr)),
        "max": float(np.max(arr)),
    }


def evaluate_uq_dtlz(
    models: List[torch.nn.Module],
    run_infos: List[Dict],
    checkpoint_names: List[str],
    num_design_params: int,
    num_objectives: int,
    function_name: str = "dtlz2",
    nb_samples: int = 1000,
    sampling_strategy: str = "stratified",
    g_max: float = 2.0,
    ood_difficulty: str = "hard",
    euler_steps: int = 30,
    seed: int = 42,
) -> Tuple[Dict, Dict, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Evaluate all UQ metrics with OOD detection for DTLZ.

    This function computes:
    - Zero-Deviation (Diag-CFM specific, no extra passes)
    - Self-Consistency (Diag-CFM specific, 1 extra pass)
    - Ensemble Variance (general-purpose)
    - FM Loss (general-purpose)

    The first model of the ensemble (``models[0]``) is used as the reference
    model for the synthesis pass, self-consistency and FM loss; the whole
    ensemble is used only for the ensemble-variance metric.

    Models must be loaded using load_dtlz_diag_cfm_ensemble from ensembles.py.

    Args:
        models: Pre-loaded ensemble models from ensembles.py. Must be non-empty.
        run_infos: Run info dictionaries for each model. Accepted for interface
            symmetry with the ensemble loader; not read by this function.
        checkpoint_names: Checkpoint names for each model (stored in results).
        num_design_params: Design space dimension P.
        num_objectives: Number of objectives L.
        function_name: DTLZ function name.
        nb_samples: Number of samples per category (in-dist and OOD).
        sampling_strategy: Sampling strategy for DTLZ data.
        g_max: Maximum g value for DTLZ data.
        ood_difficulty: OOD difficulty level ("easy", "medium", "hard").
        euler_steps: Number of Euler integration steps.
        seed: Random seed for reproducible sampling of in-dist and OOD points (default: 42).

    Returns:
        Tuple of:
        - results: Summary statistics dictionary
        - metrics: Dictionary of metric arrays
        - ood_mask: Binary OOD mask
        - in_dist_labels: In-distribution labels (numpy)
        - ood_labels: OOD labels (numpy)
        - train_labels: Training labels (numpy)

    Raises:
        ValueError: If ``models`` is empty.
    """
    # Fail early with a clear message instead of an IndexError at models[0].
    if not models:
        raise ValueError("evaluate_uq_dtlz requires at least one model, got an empty ensemble")

    device = get_device()
    print(f"Using device: {device}")

    # Use pre-loaded models (must be loaded via ensembles.py functions)
    print("\n" + "=" * 70)
    print(f"Using {len(models)} pre-loaded ensemble models")
    reference_model = models[0]
    ensemble_num_params = [sum(p.numel() for p in m.parameters()) for m in models]

    # =========================================================================
    # Load Data
    # =========================================================================
    print("\n" + "=" * 70)
    print("Loading datasets...")

    # Both splits share every setting except the split name.
    dataset_kwargs = dict(
        num_design_params=num_design_params,
        num_objectives=num_objectives,
        function_name=function_name,
        normalize_labels=True,
        sampling_strategy=sampling_strategy,
        g_max=g_max,
    )

    # Validation dataset (in-distribution)
    val_dataset = DTLZDataset(split="val", **dataset_kwargs)
    val_labels = val_dataset.labels

    # Training dataset (for OOD generation)
    train_dataset = DTLZDataset(split="train", **dataset_kwargs)
    train_labels = train_dataset.labels[:50000]  # Subset for efficiency

    print(f"Validation labels: {val_labels.shape}")
    print(f"Training labels (for OOD): {train_labels.shape}")

    # =========================================================================
    # Generate OOD Points
    # =========================================================================
    print("\n" + "=" * 70)
    print(f"Generating OOD points (difficulty={ood_difficulty})...")

    # The second return value (per-point distances) is not needed here.
    ood_points, _ood_dists, ood_config = get_ood_points_by_difficulty(
        train_labels,
        difficulty=ood_difficulty,
        n_points=nb_samples,
        grid_steps=30,
    )
    print(f"OOD config: {ood_config}")

    # Sample in-distribution and OOD points with fixed seed for reproducibility
    in_dist_labels, ood_labels, ood_mask = sample_in_dist_and_ood(
        val_labels=val_labels,
        ood_points=ood_points,
        nb_samples=nb_samples,
        device=device,
        seed=seed,
    )

    # Concatenate all labels (in-distribution first, then OOD)
    all_labels = torch.cat([in_dist_labels, ood_labels], dim=0)
    total_samples = all_labels.shape[0]

    # =========================================================================
    # Synthesis Pass with Zero-Deviation
    # =========================================================================
    print("\n" + "=" * 70)
    print("Running synthesis pass...")

    # Augment labels with noise complement
    augmented_labels = torch.cat([
        all_labels,
        create_y_complement(total_samples, num_design_params, num_objectives, device),
    ], dim=1)

    # Run synthesis pass (integrate from t=1 down to t=0 with the reference model)
    with torch.no_grad():
        synthesis_result = euler_method(
            model=reference_model,
            input=augmented_labels,
            start_t=1,
            end_t=0,
            steps=euler_steps,
        )

    # Compute zero-deviation from synthesis output
    zero_deviation = compute_zero_deviation(synthesis_result, num_objectives)
    print(f"Zero-deviation: mean={zero_deviation.mean():.6f}, "
          f"std={zero_deviation.std():.6f}")

    # Extract design parameters (columns after the first num_objectives)
    simulated_designs = synthesis_result[:, num_objectives:]
    print(f"Simulated designs shape: {simulated_designs.shape}")

    # =========================================================================
    # Self-Consistency
    # =========================================================================
    print("\n" + "=" * 70)
    print("Computing self-consistency...")

    self_consistency = compute_self_consistency(
        model=reference_model,
        simulated_designs=simulated_designs,
        target_labels=all_labels,
        num_labels=num_objectives,
        device=device,
        conditioning=None,
        steps=euler_steps,
    )
    print(f"Self-consistency: mean={self_consistency.mean():.6f}, "
          f"std={self_consistency.std():.6f}")

    # =========================================================================
    # Ensemble Variance
    # =========================================================================
    print("\n" + "=" * 70)
    print("Computing ensemble variance...")

    ensemble_variance = compute_ensemble_variance(
        simulated_designs=simulated_designs,
        models=models,
        num_labels=num_objectives,
        device=device,
        conditioning=None,
        steps=euler_steps,
    )
    print(f"Ensemble variance: mean={ensemble_variance.mean():.6f}, "
          f"std={ensemble_variance.std():.6f}")

    # =========================================================================
    # FM Loss
    # =========================================================================
    print("\n" + "=" * 70)
    print("Computing FM losses...")

    fm_loss = compute_fm_losses(
        simulated_designs=simulated_designs,
        augmented_labels=augmented_labels,
        model=reference_model,
        num_labels=num_objectives,
        device=device,
        conditioning=None,
        batch_size=100,
    )
    print(f"FM loss: mean={fm_loss.mean():.6f}")

    # =========================================================================
    # Compile Results
    # =========================================================================
    print("\n" + "=" * 70)
    print("Compiling results...")

    # Collect all UQ metrics for plotting (canonical order).
    # Every metric goes through the same detach -> cpu -> numpy path so the
    # conversion works regardless of which device the tensor lives on.
    metrics = {
        "zero_deviation": _to_numpy(zero_deviation),
        "self_consistency": _to_numpy(self_consistency),
        "ensemble_variance": _to_numpy(ensemble_variance),
        "fm_loss": _to_numpy(fm_loss),
    }

    results = {
        "dataset": DTLZ_DATASET_NAME,
        "function_name": function_name,
        "num_design_params": num_design_params,
        "num_objectives": num_objectives,
        "ood_difficulty": ood_difficulty,
        "ood_config": ood_config,
        "n_in_dist": int(in_dist_labels.shape[0]),
        "n_ood": int(ood_labels.shape[0]),
        "n_total": int(total_samples),
        "n_ensemble_models": len(models),
        "checkpoint_names": checkpoint_names,
        "num_parameters": ensemble_num_params,
    }

    # Add metric statistics
    for name, values in metrics.items():
        results[name] = _summary_stats(values)

    # Get label arrays for plotting
    in_dist_labels_np = _to_numpy(in_dist_labels)
    ood_labels_np = _to_numpy(ood_labels)
    train_labels_np = _to_numpy(train_labels)

    return results, metrics, ood_mask, in_dist_labels_np, ood_labels_np, train_labels_np


# =============================================================================
# Main Entry Point
# =============================================================================


if __name__ == "__main__":
    import sys
    from uq_diagcfm.paths import RESULTS_UQ_DIR, PAPER_FIGURES_DIR, ensure_paper_dirs_exist

    # Optional positional arguments, in order: P, nb_samples, difficulty.
    cli_args = sys.argv[1:]
    num_design_params = int(cli_args[0]) if len(cli_args) >= 1 else 12
    nb_samples = int(cli_args[1]) if len(cli_args) >= 2 else 1000
    ood_difficulty = cli_args[2] if len(cli_args) >= 3 else "hard"

    # Settings that are fixed for this script.
    function_name = "dtlz2"
    num_objectives = 3
    sampling_strategy = "stratified"
    max_models = 5

    # Reject unknown difficulty levels before doing any expensive work.
    if ood_difficulty not in DIFFICULTY_CONFIGS:
        print(f"Invalid difficulty: {ood_difficulty}. Use one of: {list(DIFFICULTY_CONFIGS.keys())}")
        sys.exit(1)

    # Header banner
    banner = "=" * 80
    print(banner)
    print(f"EVALUATING UQ WITH OOD DETECTION FOR DTLZ (P={num_design_params})")
    print(banner)

    # Load the ensemble through the canonical loader; the epoch count comes
    # from the criteria it reports.
    print("Loading ensemble models...")
    models, run_infos, checkpoint_names, criteria = load_dtlz_diag_cfm_ensemble(P=num_design_params)
    epochs = criteria["epochs"]

    # Truncate the ensemble (and its parallel metadata lists) when it is
    # larger than the configured cap.
    if max_models is not None and len(models) > max_models:
        models, run_infos, checkpoint_names = (
            seq[:max_models] for seq in (models, run_infos, checkpoint_names)
        )

    print(f"Epochs: {epochs}, OOD difficulty: {ood_difficulty}")

    # Run the evaluation
    eval_kwargs = {
        "models": models,
        "run_infos": run_infos,
        "checkpoint_names": checkpoint_names,
        "num_design_params": num_design_params,
        "num_objectives": num_objectives,
        "function_name": function_name,
        "nb_samples": nb_samples,
        "sampling_strategy": sampling_strategy,
        "ood_difficulty": ood_difficulty,
    }
    (
        results,
        metrics,
        ood_mask,
        in_dist_labels,
        ood_labels,
        train_labels,
    ) = evaluate_uq_dtlz(**eval_kwargs)

    # Make sure the output directories exist
    ensure_paper_dirs_exist()
    RESULTS_UQ_DIR.mkdir(exist_ok=True, parents=True)

    # Raw metric arrays -> .npz
    npz_path = RESULTS_UQ_DIR / f"uq_ood_dtlz_P{num_design_params}_{ood_difficulty}_ep{epochs}.npz"
    np.savez(npz_path, ood_mask=ood_mask, **metrics)
    print(f"\nData saved to: {npz_path}")

    # Summary statistics -> .json
    json_path = RESULTS_UQ_DIR / f"uq_ood_dtlz_P{num_design_params}_{ood_difficulty}_ep{epochs}.json"
    with open(json_path, "w") as fh:
        json.dump(results, fh, indent=2)
    print(f"Results saved to: {json_path}")

    # Console summary
    print_uq_summary(results, error_metric_name="self_consistency")

    # Generate every plot; the plotting helper also returns the AUC scores.
    auc_scores = create_all_plots(
        metrics=metrics,
        ood_mask=ood_mask,
        in_dist_labels=in_dist_labels,
        ood_labels=ood_labels,
        output_dir=PAPER_FIGURES_DIR,
        dataset_name=f"DTLZ_P{num_design_params}",
        train_labels=train_labels,
    )

    # AUC summary
    print_auc_summary(auc_scores)
