import torch
from torch import nn
from box import Box


class BaseModule(nn.Module):
    """
    Base nn.Module that records its keyword arguments as hyper-parameters.

    Positional arguments are accepted but discarded. Keyword arguments are
    wrapped in a `Box` and stored on ``self.hparams``.
    """

    def __init__(self, *_ignored_args, **hparams):
        super().__init__()
        self.hparams = Box(hparams)


def get_activation(nonlinear_activation: str, nonlinear_activation_params: dict = {}):
    """
    Get activation function from torch.nn module.
    Parameters:
        nonlinear_activation (str): Name of the activation function.
        nonlinear_activation_params (dict): Parameters for the activation function.
    Returns:
        nn.Module: Activation function module
    """
    if hasattr(nn, nonlinear_activation):
        return getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
    else:
        raise NotImplementedError(
            f"Activation {nonlinear_activation} is not supported!"
        )


def get_norm(norm: str | None):
    """
    Get normalization layer from torch.nn module.
    Parameters:
        norm (str): Name of the normalization layer.
        norm_params (dict): Parameters for the normalization layer.
    Returns:
        nn.Module: Normalization layer module
    """
    if norm is None:
        return nn.Identity
    if hasattr(nn, norm):
        return getattr(nn, norm)
    else:
        raise NotImplementedError(f"Normalization {norm} is not supported!")


def collate_tensors(batch: list, fill_value: int = 0):
    """
    Collate tensors with different sizes into one padded batch tensor.
    Parameters:
        batch (list): List of tensors. All tensors must have the same number of
            dimensions. Their sizes along each dimension may differ.
        fill_value (int): Value written to padded positions.
    Returns:
        torch.Tensor: Tensor of shape ``(len(batch), *max_size)``, where
            ``max_size[d]`` is the largest size along dimension ``d`` across the
            batch. Each input is copied into the leading corner of its slot.
            dtype and device are taken from ``batch[0]``, so other inputs are
            cast on copy. An empty batch yields ``torch.tensor([])``.
    Raises:
        ValueError: If the tensors do not all have the same number of dimensions.
    """
    if len(batch) == 0:
        return torch.tensor([])

    dims = batch[0].dim()
    # Without this check, mismatched ranks fail later with an opaque IndexError
    # (from `b.size(i)` or from indexing the canvas with too many slices).
    for idx, b in enumerate(batch):
        if b.dim() != dims:
            raise ValueError(
                f"All tensors must have the same number of dimensions: "
                f"batch[0] has {dims}, batch[{idx}] has {b.dim()}."
            )

    max_size = [max(b.size(i) for b in batch) for i in range(dims)]
    canvas = torch.full(
        (len(batch),) + tuple(max_size),
        fill_value=fill_value,
        dtype=batch[0].dtype,
        device=batch[0].device,
    )
    for i, b in enumerate(batch):
        # Single combined index: sample i, then the leading region of b's shape.
        region = (i, *(slice(0, s) for s in b.shape))
        canvas[region] = b

    return canvas
