"""
Model builders for different datasets.

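This module provides factory functions that create MLP and INN models with
appropriate input/output dimensions for each dataset. The dimensions depend on:
- Whether Diag-CFM or vanilla CFM is used
- The dataset's design parameter and label dimensions
- Any conditioning inputs the dataset provides (e.g. unifoil's physical parameters)

Notation: P = design parameters, L = labels, C = conditioning variables.

For Diag-CFM:
    - Input: [state (P+L), time (1)] = P + L + 1
    - Output: [velocity (P+L)] = P + L

For vanilla CFM:
    - Input: [state (P), time (1)] = P + 1
    - Output: [velocity (P)] = P

For datasets with conditioning inputs (currently unifoil), the C conditioning
variables are appended to the MLP input, giving P + L + C + 1 (Diag-CFM) or
P + C + 1 (vanilla CFM); the output dimensions are unchanged.

For INN:
    - Input: [design parameters (P)] = P
    - Output: [labels (L), latent (P-L)] = P
    - The conditional INN (unifoil) additionally receives the C conditioning
      variables.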
"""

from uq_diagcfm.models import MLP, INN, ConditionalINN
from uq_diagcfm.data_utils_gas_turbine import (
    LEN_PARAMETERS as LEN_PARAMETERS_GAS_TURBINE,
    LEN_LABELS as LEN_LABELS_GAS_TURBINE,
)
from uq_diagcfm.data_utils_unifoil import (
    LEN_DESIGN_PARAMETERS as LEN_DESIGN_PARAMETERS_UNIFOIL,
    LEN_PHYSICAL_PARAMS as LEN_PHYSICAL_PARAMS_UNIFOIL,
    LEN_PHYSICAL_PERFORMANCE as LEN_PHYSICAL_PERFORMANCE_UNIFOIL,
)


def models_for_gas_turbine(
    diag_cfm: bool,
    model_hidden_dimension: int,
    model_depth: int,
    dropout: float,
    model_activation: str,
):
    """Build the velocity-field MLP for the gas turbine dataset.

    The flow state is the design parameters alone (vanilla CFM) or the design
    parameters stacked with the labels (Diag-CFM). The network receives the
    state plus one time scalar and predicts a velocity of the state's size.

    Args:
        diag_cfm: Use the Diag-CFM state layout (True) or vanilla CFM (False).
        model_hidden_dimension: Width of the hidden layers.
        model_depth: Number of hidden layers.
        dropout: Dropout probability.
        model_activation: Name of the activation function.

    Returns:
        The configured ``MLP`` instance.
    """
    state_size = (
        LEN_PARAMETERS_GAS_TURBINE + LEN_LABELS_GAS_TURBINE
        if diag_cfm
        else LEN_PARAMETERS_GAS_TURBINE
    )
    # +1 input feature carries the flow time t.
    return MLP(
        input_dim=state_size + 1,
        output_dim=state_size,
        hidden_dim=model_hidden_dimension,
        depth=model_depth,
        dropout=dropout,
        activation=model_activation,
    )


def models_for_unifoil(
    diag_cfm: bool,
    model_hidden_dimension: int,
    model_depth: int,
    dropout: float,
    model_activation: str,
):
    """Build the velocity-field MLP for the unifoil dataset.

    The flow state is the design parameters alone (vanilla CFM) or the design
    parameters stacked with the physical performance labels (Diag-CFM). The
    network input is that state, followed by the physical-parameter
    conditioning vector and one time scalar. The output is a velocity of the
    state's size.

    Args:
        diag_cfm: Use the Diag-CFM state layout (True) or vanilla CFM (False).
        model_hidden_dimension: Width of the hidden layers.
        model_depth: Number of hidden layers.
        dropout: Dropout probability.
        model_activation: Name of the activation function.

    Returns:
        The configured ``MLP`` instance.
    """
    if diag_cfm:
        state_size = LEN_DESIGN_PARAMETERS_UNIFOIL + LEN_PHYSICAL_PERFORMANCE_UNIFOIL
    else:
        state_size = LEN_DESIGN_PARAMETERS_UNIFOIL
    # Conditioning inputs (physical parameters) plus one feature for time t.
    extra_inputs = LEN_PHYSICAL_PARAMS_UNIFOIL + 1
    return MLP(
        input_dim=state_size + extra_inputs,
        output_dim=state_size,
        hidden_dim=model_hidden_dimension,
        depth=model_depth,
        dropout=dropout,
        activation=model_activation,
    )


def models_for_dtlz(
    diag_cfm: bool,
    model_hidden_dimension: int,
    model_depth: int,
    dropout: float,
    model_activation: str,
    num_design_params: int,
    num_objectives: int,
):
    """
    Build the velocity-field MLP for the DTLZ benchmark functions.

    Both the design dimension (P) and the number of objectives (L) are
    arguments, so the same builder serves scalability experiments.

    Diag-CFM uses a state of size P + L (design plus objectives). Vanilla CFM
    uses a state of size P. In both cases the network sees the state plus one
    time scalar and outputs a velocity of the state's size.

    Args:
        diag_cfm: Use the Diag-CFM state layout (True) or vanilla CFM (False).
        model_hidden_dimension: Width of the hidden layers.
        model_depth: Number of hidden layers.
        dropout: Dropout probability (0 disables dropout).
        model_activation: Name of the activation function ("ReLU", "SiLU", ...).
        num_design_params: Design space dimension P.
        num_objectives: Number of objectives L.

    Returns:
        ``MLP`` sized for the chosen CFM variant.

    Example:
        >>> model = models_for_dtlz(
        ...     diag_cfm=True,
        ...     model_hidden_dimension=1024,
        ...     model_depth=4,
        ...     dropout=0.0,
        ...     model_activation="LeakyReLU",
        ...     num_design_params=50,
        ...     num_objectives=3,
        ... )
    """
    # Diag-CFM widens the state with the L objective dimensions.
    state_size = (
        num_design_params + num_objectives if diag_cfm else num_design_params
    )
    return MLP(
        input_dim=state_size + 1,  # +1 for the flow time t
        output_dim=state_size,
        hidden_dim=model_hidden_dimension,
        depth=model_depth,
        dropout=dropout,
        activation=model_activation,
    )


def inn_for_gas_turbine(
    num_blocks: int = 8,
    hidden_dim: int = 512,
    subnet_depth: int = 3,
    clamp: float = 2.0,
    activation: str = "ReLU",
):
    """Build an INN sized for the gas turbine dataset.

    The network maps LEN_PARAMETERS_GAS_TURBINE design parameters to
    LEN_LABELS_GAS_TURBINE labels.

    Args:
        num_blocks: How many coupling blocks to stack.
        hidden_dim: Hidden width of each coupling subnet.
        subnet_depth: Depth of the s and t subnets.
        clamp: Clamp applied to the coupling scale factors.
        activation: Activation used inside the subnets.

    Returns:
        The configured ``INN`` instance.
    """
    architecture = dict(
        num_blocks=num_blocks,
        hidden_dim=hidden_dim,
        subnet_depth=subnet_depth,
        clamp=clamp,
        activation=activation,
    )
    return INN(
        input_dim=LEN_PARAMETERS_GAS_TURBINE,
        output_dim=LEN_LABELS_GAS_TURBINE,
        **architecture,
    )


def inn_for_unifoil(
    num_blocks: int = 4,
    hidden_dim: int = 256,
    subnet_depth: int = 3,
    clamp: float = 2.0,
    activation: str = "LeakyReLU",
):
    """Build an unconditional INN sized for the unifoil dataset.

    Maps LEN_DESIGN_PARAMETERS_UNIFOIL design parameters to
    LEN_PHYSICAL_PERFORMANCE_UNIFOIL performance labels.

    Note: this model ignores the physical-parameter conditioning. Use
    ``conditional_inn_for_unifoil`` when the model must condition on it.

    The defaults are chosen to roughly match the Diag-CFM parameter count
    (~2.1M).

    Args:
        num_blocks: How many coupling blocks to stack.
        hidden_dim: Hidden width of each coupling subnet.
        subnet_depth: Depth of the s and t subnets.
        clamp: Clamp applied to the coupling scale factors.
        activation: Activation used inside the subnets.

    Returns:
        The configured ``INN`` instance.
    """
    architecture = dict(
        num_blocks=num_blocks,
        hidden_dim=hidden_dim,
        subnet_depth=subnet_depth,
        clamp=clamp,
        activation=activation,
    )
    return INN(
        input_dim=LEN_DESIGN_PARAMETERS_UNIFOIL,
        output_dim=LEN_PHYSICAL_PERFORMANCE_UNIFOIL,
        **architecture,
    )


def inn_for_dtlz(
    num_design_params: int,
    num_objectives: int,
    num_blocks: int = 4,
    hidden_dim: int = 256,
    subnet_depth: int = 4,
    clamp: float = 2.0,
    activation: str = "LeakyReLU",
):
    """Build an INN sized for the DTLZ benchmark functions.

    The defaults target P in roughly 24-50. According to the training code,
    these values are adjusted by P to track the Diag-CFM parameter count:
    - P <= 20: nb=4, hd=128, sd=3 (~555K params)
    - P <= 50: nb=4, hd=256, sd=4 (~3.3M params)
    - P > 50:  nb=6, hd=384, sd=2 (~4.5M params)

    Args:
        num_design_params: Design space dimension P (INN input size).
        num_objectives: Number of objectives L (INN output size).
        num_blocks: How many coupling blocks to stack.
        hidden_dim: Hidden width of each coupling subnet.
        subnet_depth: Depth of the s and t subnets.
        clamp: Clamp applied to the coupling scale factors.
        activation: Activation used inside the subnets.

    Returns:
        The configured ``INN`` instance.
    """
    architecture = dict(
        num_blocks=num_blocks,
        hidden_dim=hidden_dim,
        subnet_depth=subnet_depth,
        clamp=clamp,
        activation=activation,
    )
    return INN(
        input_dim=num_design_params,
        output_dim=num_objectives,
        **architecture,
    )


def conditional_inn_for_unifoil(
    num_blocks: int = 4,
    hidden_dim: int = 256,
    subnet_depth: int = 3,
    clamp: float = 2.0,
    activation: str = "LeakyReLU",
):
    """Build a conditional INN sized for the unifoil dataset.

    The model is conditioned on the physical parameters (angle of attack,
    Mach number). The same condition is supplied in both the forward and
    inverse passes.

    The defaults are chosen to roughly match the Diag-CFM parameter count
    (~2.1M).

    Args:
        num_blocks: How many coupling blocks to stack.
        hidden_dim: Hidden width of each coupling subnet.
        subnet_depth: Depth of the s and t subnets.
        clamp: Clamp applied to the coupling scale factors.
        activation: Activation used inside the subnets.

    Returns:
        The configured ``ConditionalINN`` instance.
    """
    architecture = dict(
        num_blocks=num_blocks,
        hidden_dim=hidden_dim,
        subnet_depth=subnet_depth,
        clamp=clamp,
        activation=activation,
    )
    return ConditionalINN(
        input_dim=LEN_DESIGN_PARAMETERS_UNIFOIL,  # design parameters (14)
        output_dim=LEN_PHYSICAL_PERFORMANCE_UNIFOIL,  # performance labels (3)
        cond_dim=LEN_PHYSICAL_PARAMS_UNIFOIL,  # physical parameters (2)
        **architecture,
    )
