import tqdm
import torch
from typing import Type, Callable

from uq_diagcfm.utils import get_device
from uq_diagcfm.data_utils_gas_turbine import (
    GAS_TURBINE_DATASET_NAME,
    LEN_PARAMETERS as LEN_PARAMETERS_GAS_TURBINE,
    LEN_LABELS as LEN_LABELS_GAS_TURBINE,
    GasTurbineDataset,
    make_surrogates,
)
from uq_diagcfm.data_utils_unifoil import (
    UNIFOIL_DATASET_NAME,
    LEN_DESIGN_PARAMETERS as LEN_DESIGN_PARAMETERS_UNIFOIL,
    UnifoilDataset,
)
from uq_diagcfm.data_utils_dtlz import (
    DTLZ_DATASET_NAME,
    DTLZDataset,
    create_dtlz_dataset_class,
    make_dtlz_surrogate,
)
from uq_diagcfm.models_for_datasets import (
    models_for_gas_turbine,
    models_for_unifoil,
    models_for_dtlz,
    inn_for_gas_turbine,
    inn_for_dtlz,
    conditional_inn_for_unifoil,
)
from uq_diagcfm.solvers import euler_method
from uq_diagcfm.checkpointing import (
    name_run,
    save_run_info,
    save_model_checkpoint,
    save_epoch_checkpoint,
)
from uq_diagcfm.losses import inn_loss, conditional_inn_loss


def _unpack_fm_batch(batch, device):
    """Split a flow-matching batch and move its tensors to ``device``.

    Args:
        batch: Either ``(x, y)`` or ``(x, z, y)`` as yielded by a DataLoader.
        device: Target device.

    Returns:
        Tuple ``(b_x, b_z, b_y)``; ``b_z`` is ``None`` for two-element batches.

    Raises:
        ValueError: If the batch has neither 2 nor 3 elements.
    """
    if len(batch) == 2:
        b_x, b_y = batch
        b_z = None
    elif len(batch) == 3:
        b_x, b_z, b_y = batch
        b_z = b_z.to(device)
    else:
        raise ValueError("Unexpected batch length.")
    return b_x.to(device), b_z, b_y.to(device)


def _flow_matching_step(model, b_x, b_z, b_y, augment_x, augment_y, loss_fn, device):
    """Compute the flow-matching regression loss for one batch.

    Draws one ``t`` per example uniformly in ``[0, 1)``, linearly interpolates
    between ``augment_x(b_x)`` and ``augment_y(b_y)``, and regresses the
    model's predicted velocity onto their difference.

    Returns:
        Tuple ``(loss, b_y_aug)``; ``b_y_aug`` is returned so callers can use
        it as the starting state of a reverse-time integration.
    """
    b_x_aug = augment_x(b_x)
    b_y_aug = augment_y(b_y)
    t = torch.rand((b_x.shape[0], 1), device=device)
    # Linear interpolation path and its (constant) target velocity.
    b_t = (1 - t) * b_x_aug + t * b_y_aug
    v_target = b_y_aug - b_x_aug
    # Conditioning features, when present, sit between the state and t.
    model_inputs = (b_t, t) if b_z is None else (b_t, b_z, t)
    v_pred = model(torch.cat(model_inputs, dim=1))
    return loss_fn(v_pred, v_target), b_y_aug


def train_one_model(
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    val_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    nb_epochs: int,
    diag_cfm: bool,
    ground_truth_surrogate: Callable | None,
    run_path: str | None = None,
    save_every_epoch: bool = False,
):
    """Train a (Diag-)CFM velocity model with flow matching.

    Args:
        model: Velocity network taking ``cat(state, [z,] t)`` along dim 1.
        train_dataloader: Yields ``(x, y)`` or ``(x, z, y)`` batches.
        val_dataloader: Same layout as ``train_dataloader``.
        optimizer: Optimizer over ``model``'s parameters.
        nb_epochs: Number of epochs to run.
        diag_cfm: If True, source/target live in a ``y_dim + x_dim`` space
            (``[0_y, x]`` -> ``[y, noise]``); otherwise ``x`` -> ``[y, noise]``.
        ground_truth_surrogate: Optional map from designs to labels used to
            score designs produced by reverse Euler integration.
        run_path: Directory passed to ``save_epoch_checkpoint``.
        save_every_epoch: Save a checkpoint after each epoch (needs
            ``run_path``).

    Returns:
        Dict with ``train_loss_trajectory`` (per-batch training loss),
        ``val_loss_trajectory`` (per-epoch mean validation FM loss, NaN if the
        validation loader is empty) and ``val_surrogate_loss_trajectory``
        (per-epoch mean surrogate MSE; empty without a surrogate).

    Raises:
        ValueError: If the training loader yields no batches, or a batch is
            neither ``(x, y)`` nor ``(x, z, y)``.
    """
    device = get_device()
    model.to(device)

    # Peek at one training batch to infer the design (x) and label (y) widths.
    first_batch = next(iter(train_dataloader), None)
    if first_batch is None:
        raise ValueError("Training dataloader yielded no batches.")
    b_x, _, b_y = _unpack_fm_batch(first_batch, device)
    _, x_dim = b_x.shape
    y_dim = b_y.shape[1]

    if diag_cfm:
        # Diag-CFM: source [0_y, x], target [y, U(0,1) noise of width x_dim].

        def _augment_x(b_x):
            x_complement = torch.zeros(b_x.shape[0], y_dim, device=device)
            return torch.cat((x_complement, b_x), dim=1)

        def _augment_y(b_y):
            y_complement = torch.rand(b_y.shape[0], x_dim, device=device)
            return torch.cat((b_y, y_complement), dim=1)

    else:
        # Plain CFM: source x, target y padded with U(0,1) noise up to x_dim
        # (requires x_dim >= y_dim).

        def _augment_x(b_x):
            return b_x

        def _augment_y(b_y):
            y_complement = torch.rand(b_y.shape[0], x_dim - y_dim, device=device)
            return torch.cat((b_y, y_complement), dim=1)

    mse = torch.nn.MSELoss()

    train_loss_trajectory = []
    val_loss_trajectory = []
    val_surrogate_loss_trajectory = []

    for epoch in range(nb_epochs):
        model.train()
        print(f"Epoch {epoch} — LR: {optimizer.param_groups[0]['lr']:.6f}")
        pbar = tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch}")
        for batch in pbar:
            b_x, b_z, b_y = _unpack_fm_batch(batch, device)
            loss, _ = _flow_matching_step(
                model, b_x, b_z, b_y, _augment_x, _augment_y, mse, device
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            pbar.set_postfix({"loss": loss_value})
            train_loss_trajectory.append(loss_value)

        model.eval()
        with torch.no_grad():
            val_fm_loss_list = []
            val_surrogate_loss_list = []
            for batch in val_dataloader:
                b_x, b_z, b_y = _unpack_fm_batch(batch, device)
                val_loss, b_y_aug = _flow_matching_step(
                    model, b_x, b_z, b_y, _augment_x, _augment_y, mse, device
                )
                val_fm_loss_list.append(val_loss.item())
                if ground_truth_surrogate is not None:
                    # NOTE(review): euler_method receives only the state, not
                    # b_z, so this presumably does not support conditioned
                    # (x, z, y) batches — confirm before enabling.
                    simulated_designs_euler = euler_method(
                        model=model, input=b_y_aug, start_t=1, end_t=0
                    )
                    if diag_cfm:
                        # Drop the leading label block; keep the design part.
                        simulated_designs_euler = simulated_designs_euler[:, y_dim:]

                    surrogate_preds = ground_truth_surrogate(simulated_designs_euler)
                    surrogate_loss = mse(surrogate_preds, b_y)
                    val_surrogate_loss_list.append(surrogate_loss.item())

            # Guard against an empty validation loader (previously a
            # ZeroDivisionError); NaN keeps one entry per epoch.
            if val_fm_loss_list:
                epoch_val_fm_loss = sum(val_fm_loss_list) / len(val_fm_loss_list)
            else:
                epoch_val_fm_loss = float("nan")
            print(f"Val FM Loss: {epoch_val_fm_loss:.6f}")
            val_loss_trajectory.append(epoch_val_fm_loss)
            if val_surrogate_loss_list:
                epoch_val_surrogate_loss = sum(val_surrogate_loss_list) / len(
                    val_surrogate_loss_list
                )
                print(f"Surrogate Loss: {epoch_val_surrogate_loss:.6f}")
                val_surrogate_loss_trajectory.append(epoch_val_surrogate_loss)

        if save_every_epoch and run_path is not None:
            save_epoch_checkpoint(run_path, model, epoch + 1)
            print(f"Saved checkpoint for epoch {epoch + 1}")

    return {
        "train_loss_trajectory": train_loss_trajectory,
        "val_loss_trajectory": val_loss_trajectory,
        "val_surrogate_loss_trajectory": val_surrogate_loss_trajectory,
    }


def train_one_model_on_dataset(
    model: torch.nn.Module,
    dataset_class: Type[torch.utils.data.Dataset],
    batch_size: int,
    nb_epochs: int,
    learning_rate: float,
    diag_cfm: bool,
    ground_truth_surrogate: Callable | None,
    params_length: int,
    shuffle_params_seed: None | int = None,
    run_path: str | None = None,
    save_every_epoch: bool = False,
):
    """Build loaders and an Adam optimizer for ``dataset_class``, then train.

    When ``shuffle_params_seed`` is given, the first element of every sample
    (the design vector) is permuted by a seeded random permutation of length
    ``params_length``. In that case a given ``ground_truth_surrogate`` is
    wrapped so that it undoes the permutation on its inputs before evaluation.

    Returns:
        The dict produced by ``train_one_model`` plus ``"param_permutation"``
        (the permutation as a list, or ``None`` when nothing was permuted).
    """
    permutation = None
    sample_transform = None
    surrogate = ground_truth_surrogate

    if shuffle_params_seed is not None:
        seeded = torch.Generator().manual_seed(shuffle_params_seed)
        permutation = torch.randperm(params_length, generator=seeded)

        def sample_transform(*sample):
            # Reorder the design vector; leave every other field untouched.
            return (sample[0][permutation], *sample[1:])

        if ground_truth_surrogate is not None:
            inverse_permutation = torch.argsort(permutation)

            def _unpermuting_surrogate(inputs):
                # Restore the original parameter order before evaluating.
                restore = inverse_permutation.to(inputs.device)
                return ground_truth_surrogate(inputs[..., restore])

            surrogate = _unpermuting_surrogate

    splits = ("train", "val")
    datasets = {
        split: dataset_class(split, transform=sample_transform) for split in splits
    }
    loaders = {
        split: torch.utils.data.DataLoader(
            datasets[split], batch_size=batch_size, shuffle=(split == "train")
        )
        for split in splits
    }

    results = train_one_model(
        model=model,
        train_dataloader=loaders["train"],
        val_dataloader=loaders["val"],
        optimizer=torch.optim.Adam(model.parameters(), lr=learning_rate),
        nb_epochs=nb_epochs,
        diag_cfm=diag_cfm,
        ground_truth_surrogate=surrogate,
        run_path=run_path,
        save_every_epoch=save_every_epoch,
    )
    results["param_permutation"] = (
        None if permutation is None else permutation.tolist()
    )
    return results


def train_one_inn_model(
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    val_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    nb_epochs: int,
    lambda_y: float = 1.0,
    lambda_z: float = 1.0,
    lambda_x: float = 1.0,
    ground_truth_surrogate: Callable | None = None,
):
    """Train an INN model with bidirectional loss.

    Args:
        model: INN model with forward(x) -> (y, z) and inverse(y, z) -> x.
            Must expose ``latent_dim`` when ``ground_truth_surrogate`` is set.
        train_dataloader: Training data loader yielding ``(x, y)`` batches.
        val_dataloader: Validation data loader yielding ``(x, y)`` batches.
        optimizer: Optimizer for the model.
        nb_epochs: Number of training epochs.
        lambda_y: Weight for forward MSE loss.
        lambda_z: Weight for latent MMD loss.
        lambda_x: Weight for backward MMD loss.
        ground_truth_surrogate: Optional surrogate model for evaluation.

    Returns:
        Dictionary with ``train_loss_trajectory`` (the per-batch ``loss_dict``
        from ``inn_loss``), ``val_loss_trajectory`` (per-epoch averages of
        ``total``, ``L_y``, ``L_z``, ``L_x``; NaN when the validation loader
        yields no batches) and ``val_surrogate_loss_trajectory`` (per-epoch
        mean surrogate MSE; empty without a surrogate).

    Raises:
        ValueError: If a batch is not a 2-tuple ``(x, y)``.
    """
    device = get_device()
    model.to(device)
    loss_keys = ("total", "L_y", "L_z", "L_x")

    def _to_device(batch):
        # Validate the (x, y) layout and move both tensors to the device.
        if len(batch) != 2:
            raise ValueError("INN training expects batch of (x, y).")
        b_x, b_y = batch
        return b_x.to(device), b_y.to(device)

    def _average(loss_dicts):
        # Per-term mean over batches; NaN when there is nothing to average
        # (previously an empty validation loader raised ZeroDivisionError).
        if not loss_dicts:
            return {key: float("nan") for key in loss_keys}
        count = len(loss_dicts)
        return {key: sum(d[key] for d in loss_dicts) / count for key in loss_keys}

    train_loss_trajectory = []
    val_loss_trajectory = []
    val_surrogate_loss_trajectory = []

    for epoch in range(nb_epochs):
        model.train()
        print(f"Epoch {epoch} — LR: {optimizer.param_groups[0]['lr']:.6f}")
        pbar = tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch}")

        for batch in pbar:
            b_x, b_y = _to_device(batch)
            loss, loss_dict = inn_loss(
                model,
                b_x,
                b_y,
                lambda_y=lambda_y,
                lambda_z=lambda_z,
                lambda_x=lambda_x,
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(
                {
                    "loss": f"{loss.item():.4f}",
                    "L_y": f"{loss_dict['L_y']:.4f}",
                    "L_z": f"{loss_dict['L_z']:.4f}",
                    "L_x": f"{loss_dict['L_x']:.4f}",
                }
            )
            train_loss_trajectory.append(loss_dict)

        # Validation
        model.eval()
        with torch.no_grad():
            val_loss_list = []
            val_surrogate_loss_list = []

            for batch in val_dataloader:
                b_x, b_y = _to_device(batch)
                _, loss_dict = inn_loss(
                    model,
                    b_x,
                    b_y,
                    lambda_y=lambda_y,
                    lambda_z=lambda_z,
                    lambda_x=lambda_x,
                )
                val_loss_list.append(loss_dict)

                if ground_truth_surrogate is not None:
                    # Sample a standard-normal latent, invert (y, z) -> x and
                    # score the surrogate's prediction on x against y.
                    z_sampled = torch.randn(
                        b_y.shape[0], model.latent_dim, device=device
                    )
                    x_generated = model.inverse(b_y, z_sampled)
                    surrogate_preds = ground_truth_surrogate(x_generated)
                    surrogate_loss = torch.nn.functional.mse_loss(surrogate_preds, b_y)
                    val_surrogate_loss_list.append(surrogate_loss.item())

            averages = _average(val_loss_list)
            print(
                f"Val Loss: {averages['total']:.6f} (L_y: {averages['L_y']:.6f}, "
                f"L_z: {averages['L_z']:.6f}, L_x: {averages['L_x']:.6f})"
            )
            val_loss_trajectory.append(averages)

            if val_surrogate_loss_list:
                avg_surrogate_loss = sum(val_surrogate_loss_list) / len(
                    val_surrogate_loss_list
                )
                print(f"Surrogate Loss: {avg_surrogate_loss:.6f}")
                val_surrogate_loss_trajectory.append(avg_surrogate_loss)

    return {
        "train_loss_trajectory": train_loss_trajectory,
        "val_loss_trajectory": val_loss_trajectory,
        "val_surrogate_loss_trajectory": val_surrogate_loss_trajectory,
    }


def train_one_conditional_inn_model(
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    val_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    nb_epochs: int,
    lambda_y: float = 1.0,
    lambda_z: float = 1.0,
    lambda_x: float = 1.0,
):
    """Train a conditional INN model with bidirectional loss.

    Args:
        model: Conditional INN model with forward(x, c) -> (y, z) and inverse(y, z, c) -> x.
        train_dataloader: Training data loader yielding (x, c, y) tuples.
        val_dataloader: Validation data loader yielding (x, c, y) tuples.
        optimizer: Optimizer for the model.
        nb_epochs: Number of training epochs.
        lambda_y: Weight for forward MSE loss.
        lambda_z: Weight for latent MMD loss.
        lambda_x: Weight for backward MMD loss.

    Returns:
        Dictionary with ``train_loss_trajectory`` (the per-batch ``loss_dict``
        from ``conditional_inn_loss``) and ``val_loss_trajectory`` (per-epoch
        averages of ``total``, ``L_y``, ``L_z``, ``L_x``; NaN when the
        validation loader yields no batches).

    Raises:
        ValueError: If a batch is not a 3-tuple ``(x, c, y)``.
    """
    device = get_device()
    model.to(device)
    loss_keys = ("total", "L_y", "L_z", "L_x")

    def _to_device(batch):
        # Validate the (x, c, y) layout and move all tensors to the device.
        if len(batch) != 3:
            raise ValueError("Conditional INN training expects batch of (x, c, y).")
        b_x, b_c, b_y = batch
        return b_x.to(device), b_c.to(device), b_y.to(device)

    def _average(loss_dicts):
        # Per-term mean over batches; NaN when there is nothing to average
        # (previously an empty validation loader raised ZeroDivisionError).
        if not loss_dicts:
            return {key: float("nan") for key in loss_keys}
        count = len(loss_dicts)
        return {key: sum(d[key] for d in loss_dicts) / count for key in loss_keys}

    train_loss_trajectory = []
    val_loss_trajectory = []

    for epoch in range(nb_epochs):
        model.train()
        print(f"Epoch {epoch} — LR: {optimizer.param_groups[0]['lr']:.6f}")
        pbar = tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch}")

        for batch in pbar:
            b_x, b_c, b_y = _to_device(batch)
            loss, loss_dict = conditional_inn_loss(
                model,
                b_x,
                b_y,
                b_c,
                lambda_y=lambda_y,
                lambda_z=lambda_z,
                lambda_x=lambda_x,
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(
                {
                    "loss": f"{loss.item():.4f}",
                    "L_y": f"{loss_dict['L_y']:.4f}",
                    "L_z": f"{loss_dict['L_z']:.4f}",
                    "L_x": f"{loss_dict['L_x']:.4f}",
                }
            )
            train_loss_trajectory.append(loss_dict)

        # Validation
        model.eval()
        with torch.no_grad():
            val_loss_list = []

            for batch in val_dataloader:
                b_x, b_c, b_y = _to_device(batch)
                _, loss_dict = conditional_inn_loss(
                    model,
                    b_x,
                    b_y,
                    b_c,
                    lambda_y=lambda_y,
                    lambda_z=lambda_z,
                    lambda_x=lambda_x,
                )
                val_loss_list.append(loss_dict)

            averages = _average(val_loss_list)
            print(
                f"Val Loss: {averages['total']:.6f} (L_y: {averages['L_y']:.6f}, "
                f"L_z: {averages['L_z']:.6f}, L_x: {averages['L_x']:.6f})"
            )
            val_loss_trajectory.append(averages)

    return {
        "train_loss_trajectory": train_loss_trajectory,
        "val_loss_trajectory": val_loss_trajectory,
    }


def train_one_inn_on_gas_turbine_dataset(epochs: int = 50):
    """Train an INN model on the gas turbine dataset.

    The run configuration is saved before training. After training, the loss
    trajectories are merged into it, and both the run info and the model
    checkpoint are written under the run path.

    Args:
        epochs: Number of training epochs.

    Returns:
        Dictionary containing run information and training metrics.
    """
    # Sized to roughly match the Diag-CFM parameter count (~2.1M according to
    # the original tuning: 4 blocks, hidden 256, subnet depth 3).
    run_info = {
        "dataset": GAS_TURBINE_DATASET_NAME,
        "model_type": "INN",
        "num_blocks": 4,
        "hidden_dim": 256,
        "subnet_depth": 3,
        "clamp": 2.0,
        "batch_size": 1000,
        "epochs": epochs,
        "learning_rate": 1e-3,
        "lambda_y": 1.0,
        "lambda_z": 1.0,
        "lambda_x": 10.0,
    }

    run_path, run_name = name_run(GAS_TURBINE_DATASET_NAME, run_info)
    # NOTE(review): the "run_path" entry holds the run *name*, not run_path;
    # kept unchanged — confirm this is intended.
    run_info["run_path"] = str(run_name)
    print(f"Run name: {run_name}")
    save_run_info(run_path, run_info)

    device = get_device()
    architecture = {
        key: run_info[key]
        for key in ("num_blocks", "hidden_dim", "subnet_depth", "clamp")
    }
    model = inn_for_gas_turbine(**architecture).to(device)

    print("Model:")
    print(model)
    n_params = sum(param.numel() for param in model.parameters())
    print(f"Number of parameters: {n_params:,}")
    run_info["number_of_parameters"] = n_params

    batch_size = run_info["batch_size"]
    train_split = GasTurbineDataset("train")
    val_split = GasTurbineDataset("val")
    train_dataloader = torch.utils.data.DataLoader(
        train_split, batch_size=batch_size, shuffle=True
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_split, batch_size=batch_size, shuffle=False
    )

    unmix_o, io_pd, ifd1 = (head.to(device) for head in make_surrogates())

    def ground_truth_surrogate(x: torch.Tensor) -> torch.Tensor:
        # Evaluate the three surrogate heads and stack their outputs feature-wise.
        return torch.cat([unmix_o(x), io_pd(x), ifd1(x)], dim=1)

    train_output = train_one_inn_model(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        optimizer=torch.optim.Adam(model.parameters(), lr=run_info["learning_rate"]),
        nb_epochs=run_info["epochs"],
        lambda_y=run_info["lambda_y"],
        lambda_z=run_info["lambda_z"],
        lambda_x=run_info["lambda_x"],
        ground_truth_surrogate=ground_truth_surrogate,
    )
    run_info.update(train_output)

    save_run_info(run_path, run_info)
    save_model_checkpoint(run_path, model)
    return run_info


def train_one_inn_on_dtlz_dataset(
    num_design_params: int = 12,
    num_objectives: int = 3,
    function_name: str = "dtlz2",
    epochs: int = 50,
):
    """Train an INN model on the DTLZ dataset.

    The INN size is picked from ``num_design_params``. Validation uses the
    analytical DTLZ forward function on designs clamped to ``[0, 1]`` and
    divided by the training set's ``label_scale``.

    Args:
        num_design_params: Design space dimension P.
        num_objectives: Number of objectives L.
        function_name: DTLZ function name ("dtlz1" or "dtlz2").
        epochs: Number of training epochs.

    Returns:
        Dictionary containing run information and training metrics.
    """
    # Rows are (max P, (num_blocks, hidden_dim, subnet_depth)); the first row
    # whose bound covers P wins. Sizes target the Diag-CFM parameter counts
    # (P=12 ~542K, P=24 ~3.2M, P=50 ~3.3M, P=100 ~4.4M).
    size_table = (
        (20, (2, 128, 5)),
        (35, (3, 256, 5)),
        (70, (5, 224, 4)),
    )
    num_blocks, hidden_dim, subnet_depth = next(
        (sizes for bound, sizes in size_table if num_design_params <= bound),
        (5, 256, 4),
    )

    run_info = {
        "dataset": DTLZ_DATASET_NAME,
        "model_type": "INN",
        "function_name": function_name,
        "num_design_params": num_design_params,
        "num_objectives": num_objectives,
        "num_blocks": num_blocks,
        "hidden_dim": hidden_dim,
        "subnet_depth": subnet_depth,
        "clamp": 2.0,
        "batch_size": 1000,
        "epochs": epochs,
        "learning_rate": 1e-3,
        "lambda_y": 1.0,
        "lambda_z": 1.0,
        "lambda_x": 10.0,
        "sampling_strategy": "stratified",
        "g_max": 2.0,
    }

    run_path, run_name = name_run(DTLZ_DATASET_NAME, run_info)
    # NOTE(review): the "run_path" entry holds the run *name*, not run_path.
    run_info["run_path"] = str(run_name)
    print(f"Run name: {run_name}")
    print(
        f"DTLZ INN config: P={num_design_params}, L={num_objectives}, func={function_name}"
    )
    save_run_info(run_path, run_info)

    device = get_device()
    model = inn_for_dtlz(
        num_design_params=num_design_params,
        num_objectives=num_objectives,
        **{
            key: run_info[key]
            for key in ("num_blocks", "hidden_dim", "subnet_depth", "clamp")
        },
    ).to(device)

    print("Model:")
    print(model)
    n_params = sum(param.numel() for param in model.parameters())
    print(f"Number of parameters: {n_params:,}")
    run_info["number_of_parameters"] = n_params

    # Both splits share every option except the split name.
    shared_dataset_options = dict(
        num_design_params=num_design_params,
        num_objectives=num_objectives,
        function_name=function_name,
        normalize_labels=True,
        sampling_strategy=run_info["sampling_strategy"],
        g_max=run_info["g_max"],
    )
    train_split = DTLZDataset(split="train", **shared_dataset_options)
    val_split = DTLZDataset(split="val", **shared_dataset_options)
    label_scale = train_split.label_scale

    batch_size = run_info["batch_size"]
    train_dataloader = torch.utils.data.DataLoader(
        train_split, batch_size=batch_size, shuffle=True
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_split, batch_size=batch_size, shuffle=False
    )

    analytic_forward = make_dtlz_surrogate(function_name, num_objectives)

    def ground_truth_surrogate(x: torch.Tensor) -> torch.Tensor:
        """Analytical forward on designs clamped to [0, 1], label-normalized."""
        return analytic_forward(torch.clamp(x, 0, 1)) / label_scale

    train_output = train_one_inn_model(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        optimizer=torch.optim.Adam(model.parameters(), lr=run_info["learning_rate"]),
        nb_epochs=run_info["epochs"],
        lambda_y=run_info["lambda_y"],
        lambda_z=run_info["lambda_z"],
        lambda_x=run_info["lambda_x"],
        ground_truth_surrogate=ground_truth_surrogate,
    )
    run_info.update(train_output)

    save_run_info(run_path, run_info)
    save_model_checkpoint(run_path, model)
    return run_info


def train_one_model_on_gas_turbine_dataset(
    diag_cfm: bool,
    shuffle_params_seed: None | int,
):
    """Train a Diag-CFM or plain CFM model on the gas turbine dataset.

    Hyperparameters are fixed below, recorded in ``run_info`` and saved before
    training. The three gas-turbine surrogates are concatenated into one
    ``ground_truth_surrogate`` used during validation. After training, the
    outputs are merged into ``run_info``, which is saved again together with
    the model checkpoint.

    Args:
        diag_cfm: Train the Diag-CFM variant when True.
        shuffle_params_seed: Seed for permuting design parameters
            (``None`` = no permutation).

    Returns:
        The ``run_info`` dict including training outputs.
    """
    # NOTE(review): "scheduler_step_size"/"scheduler_gamma" are only recorded;
    # no LR scheduler is created in this call path — confirm intent.
    run_info = {
        "dataset": GAS_TURBINE_DATASET_NAME,
        "model_hidden_dimension": 512 * 2,
        "model_depth": 3,
        "model_activation": "LeakyReLU",
        "batch_size": 5000,
        "epochs": 20,
        "learning_rate": 1e-3,
        "scheduler_step_size": 20,
        "scheduler_gamma": 0.8,
        "dropout": 0,
        "diag_cfm": diag_cfm,
        "shuffle_params_seed": shuffle_params_seed,
    }

    run_path, run_name = name_run(GAS_TURBINE_DATASET_NAME, run_info)
    # NOTE(review): the "run_path" entry holds the run *name*, not run_path.
    run_info["run_path"] = str(run_name)
    print(f"Run name: {run_name}")
    save_run_info(run_path, run_info)

    device = get_device()
    model = models_for_gas_turbine(
        diag_cfm=diag_cfm,
        **{
            key: run_info[key]
            for key in (
                "model_hidden_dimension",
                "model_depth",
                "dropout",
                "model_activation",
            )
        },
    ).to(device)

    n_params = sum(param.numel() for param in model.parameters())
    print("Model:")
    print(model)
    print(f"Number of parameters: {n_params}")
    run_info["number_of_parameters"] = n_params

    unmix_o, io_pd, ifd1 = (head.to(device) for head in make_surrogates())

    def ground_truth_surrogate(x: torch.Tensor) -> torch.Tensor:
        # Evaluate the three surrogate heads and stack their outputs feature-wise.
        return torch.cat([unmix_o(x), io_pd(x), ifd1(x)], dim=1)

    run_info.update(
        train_one_model_on_dataset(
            model=model,
            dataset_class=GasTurbineDataset,
            batch_size=run_info["batch_size"],
            nb_epochs=run_info["epochs"],
            learning_rate=run_info["learning_rate"],
            diag_cfm=diag_cfm,
            ground_truth_surrogate=ground_truth_surrogate,
            params_length=LEN_PARAMETERS_GAS_TURBINE,
            shuffle_params_seed=run_info["shuffle_params_seed"],
        )
    )

    save_run_info(run_path, run_info)
    save_model_checkpoint(run_path, model)
    return run_info


def train_one_model_on_unifoil_dataset(
    diag_cfm: bool,
    shuffle_params_seed: None | int,
):
    """Train a Diag-CFM or plain CFM model on the unifoil dataset.

    No ground-truth surrogate is used, so only the flow-matching validation
    loss is tracked. The run configuration is saved before training, then
    saved again with the training outputs alongside the model checkpoint.

    Args:
        diag_cfm: Train the Diag-CFM variant when True.
        shuffle_params_seed: Seed for permuting design parameters
            (``None`` = no permutation).

    Returns:
        The ``run_info`` dict including training outputs.
    """
    # NOTE(review): "scheduler_step_size"/"scheduler_gamma" are only recorded;
    # no LR scheduler is created in this call path — confirm intent.
    run_info = {
        "dataset": UNIFOIL_DATASET_NAME,
        "model_hidden_dimension": 512 * 2,
        "model_depth": 3,
        "model_activation": "ReLU",
        "batch_size": 100,
        "epochs": 100,
        "learning_rate": 1e-3,
        "scheduler_step_size": 20,
        "scheduler_gamma": 0.8,
        "dropout": 0,
        "diag_cfm": diag_cfm,
        "shuffle_params_seed": shuffle_params_seed,
    }

    run_path, run_name = name_run(UNIFOIL_DATASET_NAME, run_info)
    # NOTE(review): the "run_path" entry holds the run *name*, not run_path.
    run_info["run_path"] = str(run_name)
    print(f"Run name: {run_name}")
    save_run_info(run_path, run_info)

    device = get_device()
    model = models_for_unifoil(
        diag_cfm=diag_cfm,
        **{
            key: run_info[key]
            for key in (
                "model_hidden_dimension",
                "model_depth",
                "dropout",
                "model_activation",
            )
        },
    ).to(device)

    n_params = sum(param.numel() for param in model.parameters())
    print("Model:")
    print(model)
    print(f"Number of parameters: {n_params}")
    run_info["number_of_parameters"] = n_params

    run_info.update(
        train_one_model_on_dataset(
            model=model,
            dataset_class=UnifoilDataset,
            batch_size=run_info["batch_size"],
            nb_epochs=run_info["epochs"],
            learning_rate=run_info["learning_rate"],
            diag_cfm=diag_cfm,
            ground_truth_surrogate=None,
            params_length=LEN_DESIGN_PARAMETERS_UNIFOIL,
            shuffle_params_seed=run_info["shuffle_params_seed"],
        )
    )

    save_run_info(run_path, run_info)
    save_model_checkpoint(run_path, model)
    return run_info


def train_one_inn_on_unifoil_dataset(epochs: int = 100):
    """Train a conditional INN model on the unifoil dataset.

    The unifoil dataset requires conditioning on physical parameters
    (angle of attack, Mach number) during both forward and inverse passes.
    The run configuration is saved before training. It is saved again with
    the training outputs, alongside the model checkpoint.

    Args:
        epochs: Number of training epochs.

    Returns:
        Dictionary containing run information and training metrics.
    """
    # Sized to roughly match the Diag-CFM parameter count on unifoil
    # (~2.1M per the original tuning; this config reportedly gives ~2.17M).
    run_info = {
        "dataset": UNIFOIL_DATASET_NAME,
        "model_type": "INN",
        "num_blocks": 4,
        "hidden_dim": 256,
        "subnet_depth": 3,
        "clamp": 2.0,
        "batch_size": 100,
        "epochs": epochs,
        "learning_rate": 1e-3,
        "lambda_y": 1.0,
        "lambda_z": 1.0,
        "lambda_x": 10.0,
    }

    run_path, run_name = name_run(UNIFOIL_DATASET_NAME, run_info)
    # NOTE(review): the "run_path" entry holds the run *name*, not run_path.
    run_info["run_path"] = str(run_name)
    print(f"Run name: {run_name}")
    save_run_info(run_path, run_info)

    device = get_device()
    architecture = {
        key: run_info[key]
        for key in ("num_blocks", "hidden_dim", "subnet_depth", "clamp")
    }
    model = conditional_inn_for_unifoil(**architecture).to(device)

    print("Model:")
    print(model)
    n_params = sum(param.numel() for param in model.parameters())
    print(f"Number of parameters: {n_params:,}")
    run_info["number_of_parameters"] = n_params

    batch_size = run_info["batch_size"]
    train_split = UnifoilDataset("train")
    val_split = UnifoilDataset("val")
    train_dataloader = torch.utils.data.DataLoader(
        train_split, batch_size=batch_size, shuffle=True
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_split, batch_size=batch_size, shuffle=False
    )

    train_output = train_one_conditional_inn_model(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        optimizer=torch.optim.Adam(model.parameters(), lr=run_info["learning_rate"]),
        nb_epochs=run_info["epochs"],
        lambda_y=run_info["lambda_y"],
        lambda_z=run_info["lambda_z"],
        lambda_x=run_info["lambda_x"],
    )
    run_info.update(train_output)

    save_run_info(run_path, run_info)
    save_model_checkpoint(run_path, model)
    return run_info


def train_one_model_on_dtlz_dataset(
    diag_cfm: bool,
    shuffle_params_seed: None | int,
    num_design_params: int = 50,
    num_objectives: int = 3,
    function_name: str = "dtlz2",
    epochs: int = 50,
):
    """
    Train a single Diag-CFM or CFM model on the DTLZ benchmark.

    This function trains a flow matching model on synthetically generated
    DTLZ data. The key advantage is that the ground truth forward function
    is analytical, enabling precise round-trip error evaluation.

    Args:
        diag_cfm: Whether to use Diagonal CFM (True) or vanilla CFM (False).
        shuffle_params_seed: Seed for parameter permutation (None = no shuffle).
        num_design_params: Design space dimension P (default: 50).
        num_objectives: Number of objectives L (default: 3).
        function_name: DTLZ function name ("dtlz1" or "dtlz2").
        epochs: Number of training epochs (default: 50).

    Returns:
        Dictionary containing the run configuration, the parameter count and
        every entry returned by ``train_one_model_on_dataset``.

    Example:
        >>> out = train_one_model_on_dtlz_dataset(
        ...     diag_cfm=True,
        ...     shuffle_params_seed=None,
        ...     num_design_params=50,
        ...     num_objectives=3,
        ... )
    """
    # Model capacity grows with the design-space dimension P.
    if num_design_params <= 20:
        hidden_dim, depth = 512, 3
    elif num_design_params <= 50:
        hidden_dim, depth = 1024, 4
    else:
        hidden_dim, depth = 1024, 5

    # NOTE(review): the scheduler_* entries are recorded in run_info but are not
    # passed to train_one_model_on_dataset from this function — confirm they
    # are applied elsewhere, otherwise they are informational only.
    run_info = {
        "dataset": DTLZ_DATASET_NAME,
        "function_name": function_name,
        "num_design_params": num_design_params,
        "num_objectives": num_objectives,
        "model_hidden_dimension": hidden_dim,
        "model_depth": depth,
        "model_activation": "LeakyReLU",
        "batch_size": 1000,
        "epochs": epochs,
        "learning_rate": 1e-3,
        "scheduler_step_size": 20,
        "scheduler_gamma": 0.8,
        "dropout": 0,
        "diag_cfm": diag_cfm,
        "shuffle_params_seed": shuffle_params_seed,
        "sampling_strategy": "stratified",
        "g_max": 2.0,
    }

    run_path, run_name = name_run(DTLZ_DATASET_NAME, run_info)
    run_info["run_path"] = str(run_name)
    print(f"Run name: {run_name}")
    print(
        f"DTLZ config: P={num_design_params}, L={num_objectives}, func={function_name}"
    )

    save_run_info(run_path, run_info)

    device = get_device()
    model = models_for_dtlz(
        diag_cfm=diag_cfm,
        model_hidden_dimension=run_info["model_hidden_dimension"],
        model_depth=run_info["model_depth"],
        dropout=run_info["dropout"],
        model_activation=run_info["model_activation"],
        num_design_params=num_design_params,
        num_objectives=num_objectives,
    ).to(device)

    print("Model:")
    print(model)
    # Count once and report in the same format as the other trainers.
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Number of parameters: {num_params:,}")
    run_info["number_of_parameters"] = num_params

    # Create the ground truth surrogate (analytical forward function).
    # The DTLZ dataset normalizes labels, so the surrogate divides by the same
    # scale taken from the training split.
    train_dataset_for_scale = DTLZDataset(
        split="train",
        num_design_params=num_design_params,
        num_objectives=num_objectives,
        function_name=function_name,
        normalize_labels=True,
    )
    # NOTE(review): if label_scale is a tensor, confirm it lives on the same
    # device as the inputs passed to ground_truth_surrogate.
    label_scale = train_dataset_for_scale.label_scale

    base_forward = make_dtlz_surrogate(function_name, num_objectives)

    def ground_truth_surrogate(x: torch.Tensor) -> torch.Tensor:
        """Analytical DTLZ forward function, rescaled like the dataset labels."""
        return base_forward(x) / label_scale

    # Dataset class pre-configured with this run's DTLZ settings.
    DatasetClass = create_dtlz_dataset_class(
        num_design_params=num_design_params,
        num_objectives=num_objectives,
        function_name=function_name,
        normalize_labels=True,
    )

    train_output = train_one_model_on_dataset(
        model=model,
        dataset_class=DatasetClass,
        batch_size=run_info["batch_size"],
        nb_epochs=run_info["epochs"],
        learning_rate=run_info["learning_rate"],
        diag_cfm=diag_cfm,
        ground_truth_surrogate=ground_truth_surrogate,
        params_length=num_design_params,
        shuffle_params_seed=run_info["shuffle_params_seed"],
    )
    for k, v in train_output.items():
        run_info[k] = v

    save_run_info(run_path, run_info)
    save_model_checkpoint(run_path, model)
    return run_info


def main(
    trainer_function: Callable,
    number_of_models: int,
    diag_cfm_values: list[bool],
    shuffle_params_seeds: list[int | None],
):
    all_runs_output = list()
    train_run_count = 1
    for shuffle_seed in shuffle_params_seeds:
        print(f"Shuffle seed: {shuffle_seed}")
        for diag_cfm in diag_cfm_values:
            for i in range(number_of_models):
                print(
                    f"Training {'diag_cfm' if diag_cfm else 'non-diag_cfm'} model {i + 1}/{number_of_models} for shuffle seed {shuffle_seed}."
                )
                print(
                    f"This is train run {train_run_count} of {len(shuffle_params_seeds) * len(diag_cfm_values) * number_of_models}."
                )
                out = trainer_function(
                    diag_cfm=diag_cfm,
                    shuffle_params_seed=shuffle_seed,
                )
                all_runs_output.append(out)
                print("----------------------------------------")
                print()
                print()
                train_run_count += 1
    return all_runs_output


def train_diag_cfm_with_epoch_checkpoints(nb_epochs: int = 20):
    """Train one Diag-CFM model on the gas turbine data, checkpointing every epoch.

    The configuration is fixed: hidden dimension 1024, depth 3, batch size
    5000, learning rate 1e-3, no parameter shuffling. Only the number of
    epochs can be chosen. Training is delegated to
    ``train_one_model_on_dataset`` with ``save_every_epoch=True``. After
    training, the final weights are also saved as the run's default
    checkpoint.

    Args:
        nb_epochs: Number of training epochs.

    Returns:
        Tuple ``(run_path, run_info)``. ``run_path`` can be used to reload the
        saved checkpoints afterwards.
    """
    run_info = {
        "dataset": GAS_TURBINE_DATASET_NAME,
        "model_hidden_dimension": 1024,
        "model_depth": 3,
        "model_activation": "LeakyReLU",
        "batch_size": 5000,
        "epochs": nb_epochs,
        "learning_rate": 1e-3,
        "scheduler_step_size": 20,
        "scheduler_gamma": 0.8,
        "dropout": 0,
        "diag_cfm": True,
        "shuffle_params_seed": None,
    }

    run_path, run_name = name_run(GAS_TURBINE_DATASET_NAME, run_info)
    run_info["run_path"] = str(run_name)
    print(f"Run name: {run_name}")
    save_run_info(run_path, run_info)

    device = get_device()
    model = models_for_gas_turbine(
        diag_cfm=True,
        model_hidden_dimension=run_info["model_hidden_dimension"],
        model_depth=run_info["model_depth"],
        dropout=run_info["dropout"],
        model_activation=run_info["model_activation"],
    ).to(device)

    parameter_count = sum(param.numel() for param in model.parameters())
    print("Model:")
    print(model)
    print(f"Number of parameters: {parameter_count:,}")
    run_info["number_of_parameters"] = parameter_count

    # Three surrogate networks, all moved to the training device; their
    # outputs are concatenated along dim 1 to form the full label vector.
    unmix_o, io_pd, ifd1 = make_surrogates()
    unmix_o = unmix_o.to(device)
    io_pd = io_pd.to(device)
    ifd1 = ifd1.to(device)

    def ground_truth_surrogate(x: torch.Tensor) -> torch.Tensor:
        pieces = (unmix_o(x), io_pd(x), ifd1(x))
        return torch.cat(pieces, dim=1)

    metrics = train_one_model_on_dataset(
        model=model,
        dataset_class=GasTurbineDataset,
        batch_size=run_info["batch_size"],
        nb_epochs=run_info["epochs"],
        learning_rate=run_info["learning_rate"],
        diag_cfm=True,
        ground_truth_surrogate=ground_truth_surrogate,
        params_length=LEN_PARAMETERS_GAS_TURBINE,
        shuffle_params_seed=None,
        run_path=run_path,
        save_every_epoch=True,
    )
    run_info.update(metrics)

    save_run_info(run_path, run_info)
    # The final weights double as the run's default checkpoint.
    save_model_checkpoint(run_path, model)

    return run_path, run_info


if __name__ == "__main__":
    import sys

    # Command-line entry point. sys.argv[1] selects the experiment; some
    # commands accept extra positional integers (documented per branch).
    # All branches form a single if/elif chain so the final `else` can report
    # unrecognised commands.
    if len(sys.argv) == 2 and sys.argv[1] == "train_gas_turbine":
        torch.manual_seed(1)
        diag_cfm = True
        out = train_one_model_on_gas_turbine_dataset(
            diag_cfm=diag_cfm,
            shuffle_params_seed=1,
        )
        for k, v in out.items():
            print(f"{k}:\n{v}")

    elif len(sys.argv) == 2 and sys.argv[1] == "train_unifoil":
        torch.manual_seed(1)
        diag_cfm = True
        out = train_one_model_on_unifoil_dataset(
            diag_cfm=diag_cfm,
            shuffle_params_seed=None,
        )
        for k, v in out.items():
            print(f"{k}:\n{v}")

    elif len(sys.argv) == 2 and sys.argv[1] == "main_gas_turbine":
        diag_cfm_values = [True, False]
        number_of_models = 5
        shuffle_params_seeds = [None]
        all_out = main(
            trainer_function=train_one_model_on_gas_turbine_dataset,
            number_of_models=number_of_models,
            diag_cfm_values=diag_cfm_values,
            shuffle_params_seeds=shuffle_params_seeds,
        )

    elif len(sys.argv) == 2 and sys.argv[1] == "main_unifoil":
        diag_cfm_values = [True, False]
        number_of_models = 5
        shuffle_params_seeds = [None]
        all_out = main(
            trainer_function=train_one_model_on_unifoil_dataset,
            number_of_models=number_of_models,
            diag_cfm_values=diag_cfm_values,
            shuffle_params_seeds=shuffle_params_seeds,
        )

    elif len(sys.argv) == 2 and sys.argv[1] == "train_dtlz":
        # Train a single DTLZ model for testing
        torch.manual_seed(1)
        out = train_one_model_on_dtlz_dataset(
            diag_cfm=True,
            shuffle_params_seed=None,
            num_design_params=50,
            num_objectives=3,
            function_name="dtlz2",
        )
        for k, v in out.items():
            print(f"{k}:\n{v}")

    elif len(sys.argv) == 2 and sys.argv[1] == "main_dtlz":
        # Full DTLZ experiment: train ensembles for P=50 with Diag-CFM and CFM
        # This trains 5 models each for Diag-CFM and CFM
        from functools import partial

        num_design_params = 50
        num_objectives = 3

        # Create a partial function with fixed DTLZ parameters
        trainer_fn = partial(
            train_one_model_on_dtlz_dataset,
            num_design_params=num_design_params,
            num_objectives=num_objectives,
            function_name="dtlz2",
        )

        diag_cfm_values = [True, False]
        number_of_models = 5
        shuffle_params_seeds = [None]

        all_out = main(
            trainer_function=trainer_fn,
            number_of_models=number_of_models,
            diag_cfm_values=diag_cfm_values,
            shuffle_params_seeds=shuffle_params_seeds,
        )

    elif len(sys.argv) == 2 and sys.argv[1] == "main_dtlz_scaling":
        # Scaling experiment: train models for different design dimensions
        # This demonstrates how Diag-CFM scales with dimensionality
        from functools import partial

        dimensions = [12, 24, 50, 100]
        num_objectives = 3

        for P in dimensions:
            print(f"\n{'='*80}")
            print(f"TRAINING DTLZ WITH P={P} DESIGN PARAMETERS")
            print(f"{'='*80}\n")

            trainer_fn = partial(
                train_one_model_on_dtlz_dataset,
                num_design_params=P,
                num_objectives=num_objectives,
                function_name="dtlz2",
            )

            # Train 5 models each for Diag-CFM and CFM at this dimension
            all_out = main(
                trainer_function=trainer_fn,
                number_of_models=5,
                diag_cfm_values=[True, False],
                shuffle_params_seeds=[None],  # Only unshuffled for scaling
            )

    elif len(sys.argv) >= 2 and sys.argv[1] == "train_inn_gas_turbine":
        # Train a single INN model on gas turbine dataset
        # Usage: train_inn_gas_turbine [epochs]
        epochs = int(sys.argv[2]) if len(sys.argv) > 2 else 50
        torch.manual_seed(1)
        out = train_one_inn_on_gas_turbine_dataset(epochs=epochs)
        for k, v in out.items():
            print(f"{k}:\n{v}")
    elif len(sys.argv) >= 2 and sys.argv[1] == "main_inn_gas_turbine":
        # Train an ensemble of 5 INN models on gas turbine dataset
        # Usage: main_inn_gas_turbine [epochs]
        epochs = int(sys.argv[2]) if len(sys.argv) > 2 else 20
        number_of_models = 5

        for i in range(number_of_models):
            print(
                f"Training INN model {i + 1}/{number_of_models} on gas turbine dataset."
            )
            out = train_one_inn_on_gas_turbine_dataset(epochs=epochs)
            for k, v in out.items():
                print(f"{k}:\n{v}")
            print("----------------------------------------")
            print()
            print()

    elif len(sys.argv) >= 2 and sys.argv[1] == "train_inn_dtlz":
        # Train a single INN model on DTLZ dataset
        # Usage: train_inn_dtlz [P] [epochs]
        P = int(sys.argv[2]) if len(sys.argv) > 2 else 12
        epochs = int(sys.argv[3]) if len(sys.argv) > 3 else 50
        torch.manual_seed(1)
        out = train_one_inn_on_dtlz_dataset(
            num_design_params=P,
            num_objectives=3,
            function_name="dtlz2",
            epochs=epochs,
        )
        for k, v in out.items():
            print(f"{k}:\n{v}")

    elif len(sys.argv) >= 2 and sys.argv[1] == "main_inn_dtlz":
        # Train ensemble of INN models on DTLZ dataset
        # Usage: main_inn_dtlz [P] [epochs]
        P = int(sys.argv[2]) if len(sys.argv) > 2 else 12
        epochs = int(sys.argv[3]) if len(sys.argv) > 3 else 50
        number_of_models = 5

        for i in range(number_of_models):
            print(
                f"Training INN model {i + 1}/{number_of_models} on DTLZ P={P} dataset."
            )
            out = train_one_inn_on_dtlz_dataset(
                num_design_params=P,
                num_objectives=3,
                function_name="dtlz2",
                epochs=epochs,
            )
            for k, v in out.items():
                print(f"{k}:\n{v}")
            print("----------------------------------------")
            print()
            print()

    elif len(sys.argv) >= 2 and sys.argv[1] == "train_inn_unifoil":
        # Train a single conditional INN model on unifoil dataset
        # Usage: train_inn_unifoil [epochs]
        epochs = int(sys.argv[2]) if len(sys.argv) > 2 else 100
        torch.manual_seed(1)
        out = train_one_inn_on_unifoil_dataset(epochs=epochs)
        for k, v in out.items():
            print(f"{k}:\n{v}")

    elif len(sys.argv) >= 2 and sys.argv[1] == "main_inn_unifoil":
        # Train ensemble of conditional INN models on unifoil dataset
        # Usage: main_inn_unifoil [epochs]
        epochs = int(sys.argv[2]) if len(sys.argv) > 2 else 100
        number_of_models = 5

        for i in range(number_of_models):
            print(
                f"Training conditional INN model {i + 1}/{number_of_models} on unifoil dataset."
            )
            out = train_one_inn_on_unifoil_dataset(epochs=epochs)
            for k, v in out.items():
                print(f"{k}:\n{v}")
            print("----------------------------------------")
            print()
            print()

    elif len(sys.argv) >= 2 and sys.argv[1] == "train_diag_cfm_epoch_checkpoints":
        # Train Diag-CFM with checkpoints at each epoch
        # Usage: train_diag_cfm_epoch_checkpoints [epochs]
        epochs = int(sys.argv[2]) if len(sys.argv) > 2 else 20
        torch.manual_seed(1)
        run_path, run_info = train_diag_cfm_with_epoch_checkpoints(nb_epochs=epochs)
        print("\nTraining complete!")
        print(f"Run path: {run_path}")
        print(
            f"Epoch checkpoints saved: model_checkpoint_epoch1.pth to model_checkpoint_epoch{epochs}.pth"
        )

    else:
        # Report an unrecognised command (or a wrong argument count) with a
        # non-zero exit status instead of silently doing nothing.
        _available_commands = [
            "train_gas_turbine",
            "train_unifoil",
            "main_gas_turbine",
            "main_unifoil",
            "train_dtlz",
            "main_dtlz",
            "main_dtlz_scaling",
            "train_inn_gas_turbine [epochs]",
            "main_inn_gas_turbine [epochs]",
            "train_inn_dtlz [P] [epochs]",
            "main_inn_dtlz [P] [epochs]",
            "train_inn_unifoil [epochs]",
            "main_inn_unifoil [epochs]",
            "train_diag_cfm_epoch_checkpoints [epochs]",
        ]
        sys.exit(
            f"Unknown command or wrong number of arguments: {sys.argv[1:]}\n"
            "Available commands:\n  " + "\n  ".join(_available_commands)
        )
