from abc import ABC, abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class Tokenizer(nn.Module, ABC):
    
    @abstractmethod
    def encode(self, x: torch.Tensor, state_labels: torch.Tensor | None) -> torch.Tensor:
        pass

    @abstractmethod
    def decode(self, x: torch.Tensor, state_labels: torch.Tensor | None) -> torch.Tensor:
        pass


def conv_module(ndim, transpose):
    """Return the ``torch.nn`` convolution class for ``ndim`` spatial dims.

    Args:
        ndim: number of spatial dimensions; 1, 2 or 3.
        transpose: truthy to get the transposed-convolution class instead.

    Raises:
        ValueError: if ``ndim`` is not 1, 2 or 3.
    """
    candidates = (
        (1, nn.Conv1d, nn.ConvTranspose1d),
        (2, nn.Conv2d, nn.ConvTranspose2d),
        (3, nn.Conv3d, nn.ConvTranspose3d),
    )
    for dims, regular_cls, transposed_cls in candidates:
        if ndim == dims:
            return transposed_cls if transpose else regular_cls
    raise ValueError("ndim should be 1, 2 or 3.")


class Downsample(nn.Module):
    """Strided (non-transposed) convolution that subsamples the data tensor.

    ``ndim`` selects Conv1d/Conv2d/Conv3d via ``conv_module``; the remaining
    arguments are forwarded to that constructor. No explicit ``padding`` is
    passed.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding_mode, bias, ndim):
        super().__init__()
        conv_cls = conv_module(ndim, False)
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding_mode=padding_mode, bias=bias)
        self.conv = conv_cls(in_channels, out_channels, **conv_kwargs)

    def forward(self, x):
        """Apply the strided convolution to ``x``."""
        subsampled = self.conv(x)
        return subsampled


class Upsample(nn.Module):
    """Strided transposed convolution that upsamples the data tensor.

    ``ndim`` selects ConvTranspose1d/2d/3d via ``conv_module``; the remaining
    arguments are forwarded to that constructor. Callers in this file pass
    ``padding_mode='zeros'`` (PyTorch's transposed convolutions only support
    that mode).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding_mode, bias, ndim):
        super().__init__()
        deconv_cls = conv_module(ndim, True)
        deconv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding_mode=padding_mode, bias=bias)
        self.convT = deconv_cls(in_channels, out_channels, **deconv_kwargs)

    def forward(self, x):
        """Apply the strided transposed convolution to ``x``."""
        upsampled = self.convT(x)
        return upsampled


class SubsampledLinear(nn.Module):
    """
    Linear layer restricted to a subset of a fixed state vocabulary.

    Works like a cross between ``nn.Linear`` and ``nn.EmbeddingBag``: ``labels``
    lists which state variables of the vocabulary are present, and only the
    matching slice of the full weight matrix is used.

    * ``subsample_in=True``: ``labels`` select input columns, so ``x`` must have
      ``len(labels)`` trailing channels; the full bias is used and the whole
      result is multiplied by ``sqrt(dim_in / len(labels))``.
    * ``subsample_in=False``: ``labels`` select output rows and the matching
      bias entries.

    Assumes (... C) input
    """
    def __init__(self, dim_in, dim_out, subsample_in=True):
        super().__init__()
        self.subsample_in = subsample_in
        self.dim_in = dim_in
        self.dim_out = dim_out
        # Borrow nn.Linear's default initialisation for the full-vocabulary weights.
        reference = nn.Linear(dim_in, dim_out)
        self.weight = nn.Parameter(reference.weight)
        self.bias = nn.Parameter(reference.bias)

    def forward(self, x, labels):
        # Note - really only works if all batches are the same input type
        n_present = len(labels)
        if not self.subsample_in:
            return F.linear(x, self.weight[labels], self.bias[labels])
        # Equivalent to re-initialising for the given subsample of the input.
        correction = (self.dim_in / n_present) ** .5
        return correction * F.linear(x, self.weight[:, labels], self.bias)


class RMSGroupNorm(nn.Module):  #TODO: is repeated
    r"""Group Normalization with a learnable per-channel scale and no bias.

    Thin wrapper around :func:`torch.nn.functional.group_norm`, as described in
    the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__:

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma

    The input channels are separated into :attr:`num_groups` groups, each containing
    ``num_channels / num_groups`` channels. The mean and (biased) variance are
    computed separately over each group. When :attr:`affine` is ``True``,
    :math:`\gamma` is a learnable per-channel vector of size :attr:`num_channels`
    initialized to ones. Unlike :class:`torch.nn.GroupNorm` there is never a
    learnable bias: ``self.bias`` is always ``None``.

    NOTE(review): despite the "RMS" name, ``F.group_norm`` still subtracts the
    group mean; only the additive bias is dropped. Changing that would alter
    this layer's outputs, so it is left as-is.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        num_groups (int): number of groups to separate the channels into
        num_channels (int): number of channels expected in input; must be
            divisible by ``num_groups``
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        affine: if ``True``, learn a per-channel scale initialized to ones.
            Default: ``True``.
        device, dtype: forwarded to the factory of the scale parameter.

    Raises:
        ValueError: if ``num_channels`` is not divisible by ``num_groups``.

    Shape:
        - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 6, 10, 10)
        >>> # Separate 6 channels into 3 groups
        >>> m = RMSGroupNorm(3, 6)
        >>> output = m(input)
        >>> m.bias is None
        True
    """
    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']
    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if num_channels % num_groups != 0:
            raise ValueError('num_channels must be divisible by num_groups')

        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = nn.Parameter(torch.empty(num_channels, **factory_kwargs))
        else:
            self.register_parameter('weight', None)
        # No additive bias in either mode; registering None keeps `self.bias`
        # available for the F.group_norm call in forward().
        self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the per-channel scale to ones (no-op when not affine)."""
        if self.affine:
            nn.init.ones_(self.weight)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Normalize ``input`` per group and apply the optional per-channel scale."""
        return F.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return f'{self.num_groups}, {self.num_channels}, eps={self.eps}, affine={self.affine}'


class CNN(nn.Module):
    """Convolutional tokenizer mapping fields to a coarse latent grid and back.

    ``encode`` applies three strided convolutions (strides 4, 2, 2; total 16)
    with RMSGroupNorm/GELU between them, ending at ``embed_dim`` channels.
    ``decode`` mirrors this with transposed convolutions (strides 2, 2, 4).

    With ``customize=True`` both 1D and 2D inputs are supported: a
    ``SubsampledLinear`` lifts only the fields listed in ``state_labels`` to
    ``embed_dim // 4`` channels, and the output head is sliced to those same
    fields. Otherwise a single ``spatial_ndims`` conv stack with ``in_chans``
    input/output channels is used.

    Size-1 spatial axes among the last two dims are squeezed before the conv
    stacks (so e.g. a ``(B, C, L, 1)`` tensor runs through the 1D path) and
    re-inserted at the same position afterwards.
    """

    def __init__(
        self, in_chans, embed_dim=768, spatial_ndims=2, groups=12, padding_mode='reflect',
        customize=False, finetune=False
    ):
        """
        Args:
            in_chans: input/output channels of the fixed path (unused with ``customize``).
            embed_dim: channels of the latent grid; ``embed_dim // 4`` is the hidden width.
            spatial_ndims: conv dimensionality of the fixed path (unused with ``customize``).
            groups: number of groups for every ``RMSGroupNorm``.
            padding_mode: padding mode of the encoder convolutions (the decoder uses ``'zeros'``).
            customize: build the state-label-aware 1D and 2D encoders/decoders instead.
            finetune: with ``customize``, widen the state vocabulary from 12 to 17 fields.
        """
        super().__init__()
        self.spatial_ndims = spatial_ndims
        self.customize = customize
        hidden = embed_dim // 4

        # Construction order below matches the historical layout so that
        # parameter names, registration order and RNG consumption are unchanged.
        if customize:
            n_states = 12
            if finetune:
                n_states += 5  # additional states for finetuning on shearflow or euler fields
            self.space_bag = SubsampledLinear(dim_in=n_states, dim_out=hidden, subsample_in=True)
            self.encoder1d = nn.Sequential(*self._encoder_layers(embed_dim, groups, padding_mode, 1))
            self.encoder2d = nn.Sequential(*self._encoder_layers(embed_dim, groups, padding_mode, 2))
            self.decoder1d = nn.Sequential(*self._decoder_layers(embed_dim, groups, 1))
            self.decoder2d = nn.Sequential(*self._decoder_layers(embed_dim, groups, 2))
            # Final 4x heads are kept as raw parameters so decode() can slice
            # their output channels down to the requested state_labels.
            out_head1d = conv_module(1, True)(hidden, n_states, kernel_size=4, stride=4)
            self.out_kernel1d = nn.Parameter(out_head1d.weight)
            self.out_bias1d = nn.Parameter(out_head1d.bias)
            out_head2d = conv_module(2, True)(hidden, n_states, kernel_size=4, stride=4)
            self.out_kernel2d = nn.Parameter(out_head2d.weight)
            self.out_bias2d = nn.Parameter(out_head2d.bias)
        else:
            self.encoder = nn.Sequential(
                # 1x1 lift from in_chans to the hidden width.
                conv_module(spatial_ndims, False)(
                    in_chans, hidden, kernel_size=1, stride=1, padding_mode=padding_mode, bias=True),
                *self._encoder_layers(embed_dim, groups, padding_mode, spatial_ndims),
            )
            self.decoder = nn.Sequential(
                *self._decoder_layers(embed_dim, groups, spatial_ndims),
                Upsample(hidden, in_chans, kernel_size=4, stride=4, padding_mode='zeros', bias=True,
                         ndim=spatial_ndims),
            )

    @staticmethod
    def _encoder_layers(embed_dim, groups, padding_mode, ndim):
        """Layers of the shared downsampling stack (strides 4, 2, 2) ending at ``embed_dim`` channels."""
        hidden = embed_dim // 4
        return [
            Downsample(hidden, hidden, kernel_size=4, stride=4, padding_mode=padding_mode, bias=False, ndim=ndim),
            RMSGroupNorm(groups, hidden, affine=True),
            nn.GELU(),
            Downsample(hidden, hidden, kernel_size=2, stride=2, padding_mode=padding_mode, bias=False, ndim=ndim),
            RMSGroupNorm(groups, hidden, affine=True),
            nn.GELU(),
            Downsample(hidden, embed_dim, kernel_size=2, stride=2, padding_mode=padding_mode, bias=False, ndim=ndim),
            RMSGroupNorm(groups, embed_dim, affine=True),
        ]

    @staticmethod
    def _decoder_layers(embed_dim, groups, ndim):
        """Layers of the shared upsampling trunk (strides 2, 2) down to ``embed_dim // 4``; callers add the 4x head."""
        hidden = embed_dim // 4
        return [
            Upsample(embed_dim, hidden, kernel_size=2, stride=2, padding_mode='zeros', bias=False, ndim=ndim),
            RMSGroupNorm(groups, hidden, affine=True),
            nn.GELU(),
            Upsample(hidden, hidden, kernel_size=2, stride=2, padding_mode='zeros', bias=False, ndim=ndim),
            RMSGroupNorm(groups, hidden, affine=True),
            nn.GELU(),
        ]

    @staticmethod
    def _squeeze_spatial(x):
        """Squeeze size-1 axes among the last two dims, never touching batch/channel (index < 2).

        Returns the squeezed tensor and the ascending positive indices that
        were removed, for ``_unsqueeze_spatial``.
        """
        squeezed = tuple(
            axis for axis in (x.ndim - 2, x.ndim - 1)
            if axis >= 2 and x.shape[axis] == 1
        )
        for axis in reversed(squeezed):  # highest first so lower indices stay valid
            x = x.squeeze(axis)
        return x, squeezed

    @staticmethod
    def _unsqueeze_spatial(x, squeezed):
        """Re-insert the axes removed by ``_squeeze_spatial`` at their original positions."""
        for axis in squeezed:  # ascending, so each index is valid when inserted
            x = x.unsqueeze(axis)
        return x

    def encode(self, x, state_labels):
        """Downsample ``x`` (total stride 16 per spatial axis) into ``embed_dim`` channels.

        Args:
            x: input fields. With ``customize`` this must be 4D ``(b, c, h, w)``
                (required by the ``rearrange`` patterns); a size-1 ``h`` or ``w``
                selects the 1D path.
            state_labels: indices of the fields present in ``x`` (consumed by
                ``space_bag`` with ``customize``; ignored otherwise).

        Returns:
            Latent tensor with the same rank as ``x``; squeezed singleton axes
            are restored at their original positions.

        Raises:
            ValueError: with ``customize``, if the squeezed input is neither 1D nor 2D.
        """
        if self.customize:
            # first linear depending on the type of fields ('state') present
            x = rearrange(x, 'b c h w -> b h w c')
            x = self.space_bag(x, state_labels)
            x = rearrange(x, 'b h w c -> b c h w')
        x, squeezed = self._squeeze_spatial(x)
        spatial_ndims = x.ndim - 2
        if not self.customize:
            x = self.encoder(x)
        elif spatial_ndims == 1:
            x = self.encoder1d(x)
        elif spatial_ndims == 2:
            x = self.encoder2d(x)
        else:
            raise ValueError(f"customized CNN supports 1D or 2D inputs, got {spatial_ndims}D")
        return self._unsqueeze_spatial(x, squeezed)

    def decode(self, x, state_labels):
        """Upsample a latent ``x`` back to field space.

        Args:
            x: latent tensor as produced by ``encode``.
            state_labels: with ``customize``, indices selecting which output
                fields the final head produces; ignored otherwise.

        Returns:
            Decoded tensor with the same rank as ``x``; squeezed singleton axes
            are restored at their original positions.

        Raises:
            ValueError: with ``customize``, if the squeezed input is neither 1D nor 2D.
        """
        x, squeezed = self._squeeze_spatial(x)
        spatial_ndims = x.ndim - 2
        if not self.customize:
            x = self.decoder(x)
        elif spatial_ndims == 1:
            x = self.decoder1d(x)
            x = F.conv_transpose1d(x, self.out_kernel1d[:, state_labels], self.out_bias1d[state_labels], stride=4)
        elif spatial_ndims == 2:
            x = self.decoder2d(x)
            x = F.conv_transpose2d(x, self.out_kernel2d[:, state_labels], self.out_bias2d[state_labels], stride=4)
        else:
            raise ValueError(f"customized CNN supports 1D or 2D inputs, got {spatial_ndims}D")
        return self._unsqueeze_spatial(x, squeezed)
