"""Training utility functions for neural operator models.

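This module provides utility functions commonly used during training: building
the optimizer, learning-rate scheduler and loss functions from a configuration,
converting a configuration into a quick debug variant, and seeding random
number generators for reproducibility.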
"""

from typing import Tuple

import torch
from torch import nn as nn

from src.config import Config


def get_optimizer(
        model: torch.nn.Module,
        config: Config
) -> torch.optim.Optimizer:
    """Instantiate optimizer based on configuration.

    The optimizer class is chosen by ``config.training.optimizer['name']``
    (case-insensitive). It is constructed with the keyword arguments in
    ``config.training.optimizer['optimizer_params']``. When that entry is
    absent or ``None``, the optimizer's own constructor defaults are used.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose parameters will be optimized.
    config : Config
        Experiment configuration. ``config.training.optimizer`` must be a
        mapping with a ``'name'`` key (``'adam'``, ``'adamw'`` or ``'sgd'``)
        and, optionally, an ``'optimizer_params'`` mapping forwarded to the
        optimizer constructor.

    Returns
    -------
    torch.optim.Optimizer
        Configured optimizer instance.

    Raises
    ------
    ValueError
        If optimizer name is not supported.
    """
    optimizers = {
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
        "sgd": torch.optim.SGD,
    }
    optimizer_config = config.training.optimizer
    name = optimizer_config['name']
    optimizer_cls = optimizers.get(name.lower())
    if optimizer_cls is None:
        # Report only the offending name; the full config mapping is noise.
        raise ValueError(f"Unknown optimizer: {name}")
    # A missing/empty params entry falls back to the constructor defaults
    # instead of failing with KeyError / TypeError on ``**None``.
    optimizer_params = optimizer_config.get('optimizer_params') or {}
    return optimizer_cls(model.parameters(), **optimizer_params)


def get_scheduler(
        optimizer: torch.optim.Optimizer,
        config: Config,
        *,
        plateau_factor: float = 0.1,
        plateau_patience: int = 5,
) -> (torch.optim.lr_scheduler.LRScheduler
      | torch.optim.lr_scheduler.ReduceLROnPlateau):
    """Instantiate LR scheduler based on configuration.

    ``config.training.scheduler`` (case-insensitive) selects one of:

    * ``"reduce_on_plateau"`` -> ``ReduceLROnPlateau`` in ``'min'`` mode,
      scaling the LR by ``plateau_factor`` once the monitored value has not
      improved for ``plateau_patience`` ``step()`` calls. Per the PyTorch
      API, its ``step()`` must be given the monitored metric.
    * ``"cosine"`` -> ``CosineAnnealingLR`` with ``T_max`` set to
      ``config.training.n_epochs``.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer to schedule.
    config : Config
        Experiment configuration; reads ``config.training.scheduler`` and,
        for the cosine schedule, ``config.training.n_epochs``.
    plateau_factor : float, optional
        Multiplicative LR decay for ``reduce_on_plateau``. Default 0.1.
    plateau_patience : int, optional
        Patience for ``reduce_on_plateau``. Default 5.

    Returns
    -------
    torch.optim.lr_scheduler.LRScheduler | torch.optim.lr_scheduler.ReduceLROnPlateau
        Configured scheduler instance (never ``None``).

    Raises
    ------
    ValueError
        If scheduler name is not supported.
    """
    scheduler_name = config.training.scheduler.lower()
    if scheduler_name == "reduce_on_plateau":
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=plateau_factor,
            patience=plateau_patience,
        )
    if scheduler_name == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=config.training.n_epochs,
        )
    raise ValueError(f"Unknown scheduler: {config.training.scheduler}")


def get_loss_function(loss_name: str, val_loss: str | None) -> Tuple[nn.Module, nn.Module]:
    """Get loss function by name.

    Parameters
    ----------
    loss_name : str
        Name of the loss function (e.g., 'mse', 'l1', 'gll_loss').
    val_loss : str | None
        Name of the validation loss function, if different from training loss.

    Returns
    -------
    nn.Module
        Instantiated loss function module.

    Raises
    ------
    KeyError
        If the loss function name is not recognized.
    """
    loss_functions = {
        "mse": nn.MSELoss(),
        "l1": nn.L1Loss(),
        "cross_entropy": nn.CrossEntropyLoss(),
        "bce": nn.BCELoss(),
    }
    if loss_name.lower() not in loss_functions:
        raise KeyError(f"Unknown loss function: {loss_name}")
    if val_loss is not None and val_loss.lower() not in loss_functions:
        raise KeyError(f"Unknown validation loss function: {val_loss}")

    loss = loss_functions[loss_name.lower()]
    val_loss = loss if val_loss is None else loss_functions[val_loss.lower()]
    return loss, val_loss


def convert_to_debug_config(config: Config) -> Config:
    """Shrink a configuration, in place, for a fast smoke-test run.

    Overrides these ``config.training`` fields: ``n_epochs`` = 3,
    ``batch_size`` = 2, ``device`` = ``"cpu"``, ``n_workers`` = 2 and
    ``use_wandb`` = ``False``. All other fields are left untouched.

    The object passed in is itself modified; the same object is returned
    for convenience, not a copy.

    Parameters
    ----------
    config : Config
        Configuration to modify.

    Returns
    -------
    Config
        The same ``config`` object, now carrying the debug settings.
    """
    debug_overrides = {
        "n_epochs": 3,
        "batch_size": 2,
        "device": "cpu",
        "n_workers": 2,
        "use_wandb": False,
    }
    training_section = config.training
    for field_name, debug_value in debug_overrides.items():
        setattr(training_section, field_name, debug_value)
    return config


def set_random_seed(seed: int) -> None:
    """Seed PyTorch, NumPy and Python RNGs and make cuDNN deterministic.

    Seeds, in order: PyTorch's default generator, the generators of all
    CUDA devices, NumPy's global generator, and Python's ``random`` module.
    Afterwards it sets ``torch.backends.cudnn.deterministic = True`` and
    ``torch.backends.cudnn.benchmark = False``.

    Parameters
    ----------
    seed : int
        Value handed to every seeding function.
    """
    import random
    import numpy as np

    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # cuDNN benchmark mode may autotune to different kernels between runs;
    # turn it off and ask for deterministic algorithms instead.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
