"""
Centralized ensemble loading functions for each dataset.

This module provides the canonical way to load model ensembles. Each function
selects checkpoints using the default criteria used for evaluation, so that every
script that needs these models loads exactly the same ensemble.

All ensemble functions return models sorted by final training loss (ascending).
"""

import math
import numbers
from typing import Dict, List, Optional, Tuple

import torch

from uq_diagcfm.checkpointing import load_ensemble
from uq_diagcfm.utils import get_device
from uq_diagcfm.data_utils_gas_turbine import GAS_TURBINE_DATASET_NAME
from uq_diagcfm.data_utils_unifoil import UNIFOIL_DATASET_NAME
from uq_diagcfm.data_utils_dtlz import DTLZ_DATASET_NAME


def _validate_ensemble_parameter_counts(
    models: List[torch.nn.Module],
    checkpoint_names: List[str],
) -> None:
    """Validate that all models in the ensemble have the same number of parameters.

    Args:
        models: List of loaded models.
        checkpoint_names: List of checkpoint directory names (for error messages).

    Raises:
        ValueError: If models have different parameter counts.
    """
    if len(models) <= 1:
        return

    param_counts = [sum(p.numel() for p in model.parameters()) for model in models]

    if len(set(param_counts)) > 1:
        # Build detailed error message
        details = "\n".join(
            f"  {ckpt}: {count:,} parameters"
            for ckpt, count in zip(checkpoint_names, param_counts)
        )
        raise ValueError(
            f"Ensemble models have inconsistent parameter counts:\n{details}"
        )


def _sort_by_final_train_loss(
    models: List[torch.nn.Module],
    run_infos: List[Dict],
    checkpoint_names: List[str],
) -> Tuple[List[torch.nn.Module], List[Dict], List[str]]:
    """Sort models by final training loss (ascending).

    Args:
        models: List of loaded models.
        run_infos: List of run info dictionaries.
        checkpoint_names: List of checkpoint directory names.

    Returns:
        Tuple of (models, run_infos, checkpoint_names) sorted by final train loss.
    """
    if not models:
        return models, run_infos, checkpoint_names

    # Create list of tuples with final train loss for sorting
    # Use train_loss from run_info if available, otherwise use a large value
    items = []
    for model, run_info, ckpt_name in zip(models, run_infos, checkpoint_names):
        # Try different keys for final train loss
        final_loss = run_info.get("final_train_loss")
        if final_loss is None:
            # Fall back to last value in train_loss_trajectory if available
            train_loss_traj = run_info.get("train_loss_trajectory", [])
            if isinstance(train_loss_traj, list) and train_loss_traj:
                final_loss = train_loss_traj[-1]
            elif isinstance(train_loss_traj, dict):
                # Handle case where it might be a dict (shouldn't happen but be safe)
                final_loss = float("inf")
            else:
                final_loss = float("inf")  # Put models without loss info at the end
        # Ensure final_loss is a number
        if not isinstance(final_loss, (int, float)):
            final_loss = float("inf")
        items.append((final_loss, model, run_info, ckpt_name))

    # Sort by final loss (ascending)
    items.sort(key=lambda x: x[0])

    # Unpack sorted items
    sorted_models = [item[1] for item in items]
    sorted_run_infos = [item[2] for item in items]
    sorted_checkpoint_names = [item[3] for item in items]

    return sorted_models, sorted_run_infos, sorted_checkpoint_names


# =============================================================================
# Gas Turbine Ensembles
# =============================================================================


def load_gas_turbine_diag_cfm_ensemble(
    device: Optional[torch.device] = None,
    verbose: bool = True,
) -> Tuple[List[torch.nn.Module], List[Dict], List[str], Dict]:
    """Load the Diag-CFM ensemble for the Gas Turbine dataset.

    Checkpoints are selected with the default criteria:
    - dataset: GAS_TURBINE_DATASET_NAME
    - epochs: 20
    - shuffle_params_seed: None (unshuffled only)
    - diag_cfm: True

    Args:
        device: Device to load models on. If None, ``get_device()`` picks one.
        verbose: Whether to print loading progress (forwarded to ``load_ensemble``).

    Returns:
        Tuple of (models, run_infos, checkpoint_names, criteria). The first
        three lists are aligned and sorted by final training loss (ascending);
        ``criteria`` is the dict used to select the checkpoints.

    Raises:
        ValueError: If the loaded models do not all have the same number of
            parameters.
    """
    if device is None:
        device = get_device()

    criteria = {
        "dataset": GAS_TURBINE_DATASET_NAME,
        "shuffle_params_seed": None,
        "epochs": 20,
        "diag_cfm": True,
    }

    models, run_infos, checkpoint_names = load_ensemble(
        criteria=criteria,
        device=device,
        verbose=verbose,
    )

    # Fail fast if the members' parameter counts differ.
    _validate_ensemble_parameter_counts(models, checkpoint_names)

    models, run_infos, checkpoint_names = _sort_by_final_train_loss(
        models, run_infos, checkpoint_names
    )

    return models, run_infos, checkpoint_names, criteria


def load_gas_turbine_cfm_ensemble(
    device: Optional[torch.device] = None,
    verbose: bool = True,
) -> Tuple[List[torch.nn.Module], List[Dict], List[str], Dict]:
    """Load the vanilla CFM ensemble for the Gas Turbine dataset.

    Checkpoints are selected with the default criteria:
    - dataset: GAS_TURBINE_DATASET_NAME
    - epochs: 20
    - shuffle_params_seed: None (unshuffled only)
    - diag_cfm: False

    Args:
        device: Device to load models on. If None, ``get_device()`` picks one.
        verbose: Whether to print loading progress (forwarded to ``load_ensemble``).

    Returns:
        Tuple of (models, run_infos, checkpoint_names, criteria). The first
        three lists are aligned and sorted by final training loss (ascending);
        ``criteria`` is the dict used to select the checkpoints.

    Raises:
        ValueError: If the loaded models do not all have the same number of
            parameters.
    """
    if device is None:
        device = get_device()

    criteria = {
        "dataset": GAS_TURBINE_DATASET_NAME,
        "shuffle_params_seed": None,
        "epochs": 20,
        "diag_cfm": False,
    }

    models, run_infos, checkpoint_names = load_ensemble(
        criteria=criteria,
        device=device,
        verbose=verbose,
    )

    # Fail fast if the members' parameter counts differ.
    _validate_ensemble_parameter_counts(models, checkpoint_names)

    models, run_infos, checkpoint_names = _sort_by_final_train_loss(
        models, run_infos, checkpoint_names
    )

    return models, run_infos, checkpoint_names, criteria


def load_gas_turbine_inn_ensemble(
    device: Optional[torch.device] = None,
    verbose: bool = True,
) -> Tuple[List[torch.nn.Module], List[Dict], List[str], Dict]:
    """Load the INN ensemble for the Gas Turbine dataset.

    Checkpoints are selected with the default criteria:
    - dataset: GAS_TURBINE_DATASET_NAME
    - model_type: INN
    - epochs: 20

    No ``shuffle_params_seed`` filter is applied.

    Args:
        device: Device to load models on. If None, ``get_device()`` picks one.
        verbose: Whether to print loading progress (forwarded to ``load_ensemble``).

    Returns:
        Tuple of (models, run_infos, checkpoint_names, criteria). The first
        three lists are aligned and sorted by final training loss (ascending);
        ``criteria`` is the dict used to select the checkpoints.

    Raises:
        ValueError: If the loaded models do not all have the same number of
            parameters.
    """
    if device is None:
        device = get_device()

    criteria = {
        "dataset": GAS_TURBINE_DATASET_NAME,
        "model_type": "INN",
        "epochs": 20,
    }

    models, run_infos, checkpoint_names = load_ensemble(
        criteria=criteria,
        device=device,
        verbose=verbose,
    )

    # Fail fast if the members' parameter counts differ.
    _validate_ensemble_parameter_counts(models, checkpoint_names)

    models, run_infos, checkpoint_names = _sort_by_final_train_loss(
        models, run_infos, checkpoint_names
    )

    return models, run_infos, checkpoint_names, criteria


# =============================================================================
# Unifoil Ensembles
# =============================================================================


def load_unifoil_diag_cfm_ensemble(
    device: Optional[torch.device] = None,
    verbose: bool = True,
) -> Tuple[List[torch.nn.Module], List[Dict], List[str], Dict]:
    """Load the Diag-CFM ensemble for the Unifoil dataset.

    Checkpoints are selected with the default criteria:
    - dataset: UNIFOIL_DATASET_NAME
    - epochs: 100
    - shuffle_params_seed: None (unshuffled only)
    - model_depth: 3
    - model_activation: ReLU
    - diag_cfm: True

    Args:
        device: Device to load models on. If None, ``get_device()`` picks one.
        verbose: Whether to print loading progress (forwarded to ``load_ensemble``).

    Returns:
        Tuple of (models, run_infos, checkpoint_names, criteria). The first
        three lists are aligned and sorted by final training loss (ascending);
        ``criteria`` is the dict used to select the checkpoints.

    Raises:
        ValueError: If the loaded models do not all have the same number of
            parameters.
    """
    if device is None:
        device = get_device()

    criteria = {
        "dataset": UNIFOIL_DATASET_NAME,
        "shuffle_params_seed": None,
        "epochs": 100,
        "model_depth": 3,
        "model_activation": "ReLU",
        "diag_cfm": True,
    }

    models, run_infos, checkpoint_names = load_ensemble(
        criteria=criteria,
        device=device,
        verbose=verbose,
    )

    # Fail fast if the members' parameter counts differ.
    _validate_ensemble_parameter_counts(models, checkpoint_names)

    models, run_infos, checkpoint_names = _sort_by_final_train_loss(
        models, run_infos, checkpoint_names
    )

    return models, run_infos, checkpoint_names, criteria


def load_unifoil_cfm_ensemble(
    device: Optional[torch.device] = None,
    verbose: bool = True,
) -> Tuple[List[torch.nn.Module], List[Dict], List[str], Dict]:
    """Load the vanilla CFM ensemble for the Unifoil dataset.

    Checkpoints are selected with the default criteria:
    - dataset: UNIFOIL_DATASET_NAME
    - epochs: 100
    - shuffle_params_seed: None (unshuffled only)
    - model_depth: 3
    - model_activation: ReLU
    - diag_cfm: False

    Args:
        device: Device to load models on. If None, ``get_device()`` picks one.
        verbose: Whether to print loading progress (forwarded to ``load_ensemble``).

    Returns:
        Tuple of (models, run_infos, checkpoint_names, criteria). The first
        three lists are aligned and sorted by final training loss (ascending);
        ``criteria`` is the dict used to select the checkpoints.

    Raises:
        ValueError: If the loaded models do not all have the same number of
            parameters.
    """
    if device is None:
        device = get_device()

    criteria = {
        "dataset": UNIFOIL_DATASET_NAME,
        "shuffle_params_seed": None,
        "epochs": 100,
        "model_depth": 3,
        "model_activation": "ReLU",
        "diag_cfm": False,
    }

    models, run_infos, checkpoint_names = load_ensemble(
        criteria=criteria,
        device=device,
        verbose=verbose,
    )

    # Fail fast if the members' parameter counts differ.
    _validate_ensemble_parameter_counts(models, checkpoint_names)

    models, run_infos, checkpoint_names = _sort_by_final_train_loss(
        models, run_infos, checkpoint_names
    )

    return models, run_infos, checkpoint_names, criteria


def load_unifoil_inn_ensemble(
    device: Optional[torch.device] = None,
    verbose: bool = True,
) -> Tuple[List[torch.nn.Module], List[Dict], List[str], Dict]:
    """Load the INN ensemble for the Unifoil dataset.

    Checkpoints are selected with the default criteria:
    - dataset: UNIFOIL_DATASET_NAME
    - model_type: INN
    - epochs: 100

    No ``shuffle_params_seed`` filter is applied.

    Args:
        device: Device to load models on. If None, ``get_device()`` picks one.
        verbose: Whether to print loading progress (forwarded to ``load_ensemble``).

    Returns:
        Tuple of (models, run_infos, checkpoint_names, criteria). The first
        three lists are aligned and sorted by final training loss (ascending);
        ``criteria`` is the dict used to select the checkpoints.

    Raises:
        ValueError: If the loaded models do not all have the same number of
            parameters.
    """
    if device is None:
        device = get_device()

    criteria = {
        "dataset": UNIFOIL_DATASET_NAME,
        "model_type": "INN",
        "epochs": 100,
    }

    models, run_infos, checkpoint_names = load_ensemble(
        criteria=criteria,
        device=device,
        verbose=verbose,
    )

    # Fail fast if the members' parameter counts differ.
    _validate_ensemble_parameter_counts(models, checkpoint_names)

    models, run_infos, checkpoint_names = _sort_by_final_train_loss(
        models, run_infos, checkpoint_names
    )

    return models, run_infos, checkpoint_names, criteria


# =============================================================================
# DTLZ Ensembles
# =============================================================================


def load_dtlz_diag_cfm_ensemble(
    P: int,
    device: Optional[torch.device] = None,
    verbose: bool = True,
) -> Tuple[List[torch.nn.Module], List[Dict], List[str], Dict]:
    """Load the Diag-CFM ensemble for the DTLZ dataset.

    Checkpoints are selected with the default criteria:
    - dataset: DTLZ_DATASET_NAME
    - epochs: 50
    - shuffle_params_seed: None (unshuffled only)
    - sampling_strategy: stratified
    - num_design_params: P
    - num_objectives: 3
    - diag_cfm: True

    Args:
        P: Number of design parameters (matched against ``num_design_params``).
        device: Device to load models on. If None, ``get_device()`` picks one.
        verbose: Whether to print loading progress (forwarded to ``load_ensemble``).

    Returns:
        Tuple of (models, run_infos, checkpoint_names, criteria). The first
        three lists are aligned and sorted by final training loss (ascending);
        ``criteria`` is the dict used to select the checkpoints.

    Raises:
        ValueError: If the loaded models do not all have the same number of
            parameters.
    """
    if device is None:
        device = get_device()

    criteria = {
        "dataset": DTLZ_DATASET_NAME,
        "shuffle_params_seed": None,
        "epochs": 50,
        "sampling_strategy": "stratified",
        "num_design_params": P,
        "num_objectives": 3,
        "diag_cfm": True,
    }

    models, run_infos, checkpoint_names = load_ensemble(
        criteria=criteria,
        device=device,
        verbose=verbose,
    )

    # Fail fast if the members' parameter counts differ.
    _validate_ensemble_parameter_counts(models, checkpoint_names)

    models, run_infos, checkpoint_names = _sort_by_final_train_loss(
        models, run_infos, checkpoint_names
    )

    return models, run_infos, checkpoint_names, criteria


def load_dtlz_cfm_ensemble(
    P: int,
    device: Optional[torch.device] = None,
    verbose: bool = True,
) -> Tuple[List[torch.nn.Module], List[Dict], List[str], Dict]:
    """Load the vanilla CFM ensemble for the DTLZ dataset.

    Checkpoints are selected with the default criteria:
    - dataset: DTLZ_DATASET_NAME
    - epochs: 50
    - shuffle_params_seed: None (unshuffled only)
    - sampling_strategy: stratified
    - num_design_params: P
    - num_objectives: 3
    - diag_cfm: False

    Args:
        P: Number of design parameters (matched against ``num_design_params``).
        device: Device to load models on. If None, ``get_device()`` picks one.
        verbose: Whether to print loading progress (forwarded to ``load_ensemble``).

    Returns:
        Tuple of (models, run_infos, checkpoint_names, criteria). The first
        three lists are aligned and sorted by final training loss (ascending);
        ``criteria`` is the dict used to select the checkpoints.

    Raises:
        ValueError: If the loaded models do not all have the same number of
            parameters.
    """
    if device is None:
        device = get_device()

    criteria = {
        "dataset": DTLZ_DATASET_NAME,
        "shuffle_params_seed": None,
        "epochs": 50,
        "sampling_strategy": "stratified",
        "num_design_params": P,
        "num_objectives": 3,
        "diag_cfm": False,
    }

    models, run_infos, checkpoint_names = load_ensemble(
        criteria=criteria,
        device=device,
        verbose=verbose,
    )

    # Fail fast if the members' parameter counts differ.
    _validate_ensemble_parameter_counts(models, checkpoint_names)

    models, run_infos, checkpoint_names = _sort_by_final_train_loss(
        models, run_infos, checkpoint_names
    )

    return models, run_infos, checkpoint_names, criteria


def load_dtlz_inn_ensemble(
    P: int,
    device: Optional[torch.device] = None,
    verbose: bool = True,
) -> Tuple[List[torch.nn.Module], List[Dict], List[str], Dict]:
    """Load the INN ensemble for the DTLZ dataset.

    Checkpoints are selected with the default criteria:
    - dataset: DTLZ_DATASET_NAME
    - model_type: INN
    - epochs: 50
    - sampling_strategy: stratified
    - num_design_params: P
    - num_objectives: 3
    - num_blocks: 5 (added only when P == 100)

    No ``shuffle_params_seed`` filter is applied.

    Args:
        P: Number of design parameters (matched against ``num_design_params``).
        device: Device to load models on. If None, ``get_device()`` picks one.
        verbose: Whether to print loading progress (forwarded to ``load_ensemble``).

    Returns:
        Tuple of (models, run_infos, checkpoint_names, criteria). The first
        three lists are aligned and sorted by final training loss (ascending);
        ``criteria`` is the dict used to select the checkpoints.

    Raises:
        ValueError: If the loaded models do not all have the same number of
            parameters.
    """
    if device is None:
        device = get_device()

    criteria = {
        "dataset": DTLZ_DATASET_NAME,
        "model_type": "INN",
        "epochs": 50,
        "sampling_strategy": "stratified",
        "num_design_params": P,
        "num_objectives": 3,
    }

    if P == 100:
        # NOTE(review): only P=100 additionally pins num_blocks=5; presumably
        # INN runs with other block counts exist for that P. TODO confirm.
        criteria["num_blocks"] = 5

    models, run_infos, checkpoint_names = load_ensemble(
        criteria=criteria,
        device=device,
        verbose=verbose,
    )

    # Fail fast if the members' parameter counts differ.
    _validate_ensemble_parameter_counts(models, checkpoint_names)

    models, run_infos, checkpoint_names = _sort_by_final_train_loss(
        models, run_infos, checkpoint_names
    )

    return models, run_infos, checkpoint_names, criteria


if __name__ == "__main__":
    # Smoke test: load every default ensemble quietly and report its size.
    banner = "=" * 70
    print(banner)
    print("ENSEMBLE LOADING TEST")
    print(banner)

    quiet = {"verbose": False}
    dtlz_quiet = {"P": 12, "verbose": False}  # DTLZ is exercised with P=12

    # Each entry: (section header, [(line prefix, loader, keyword args), ...]).
    smoke_tests = [
        (
            "\n--- Gas Turbine ---",
            [
                ("Diag-CFM: ", load_gas_turbine_diag_cfm_ensemble, quiet),
                ("CFM:      ", load_gas_turbine_cfm_ensemble, quiet),
                ("INN:      ", load_gas_turbine_inn_ensemble, quiet),
            ],
        ),
        (
            "\n--- Unifoil ---",
            [
                ("Diag-CFM: ", load_unifoil_diag_cfm_ensemble, quiet),
                ("CFM:      ", load_unifoil_cfm_ensemble, quiet),
                ("INN:      ", load_unifoil_inn_ensemble, quiet),
            ],
        ),
        (
            "\n--- DTLZ (P=12) ---",
            [
                ("Diag-CFM: ", load_dtlz_diag_cfm_ensemble, dtlz_quiet),
                ("CFM:      ", load_dtlz_cfm_ensemble, dtlz_quiet),
                ("INN:      ", load_dtlz_inn_ensemble, dtlz_quiet),
            ],
        ),
    ]

    for header, checks in smoke_tests:
        print(header)
        for prefix, loader, kwargs in checks:
            loaded_models, _, _, _ = loader(**kwargs)
            print(f"{prefix}{len(loaded_models)} models")

    print("\n" + banner)
    print("All ensembles loaded successfully!")
    print(banner)
