import math

import torch
import torch.nn as nn

from .conv2d import Conv2DMod


class NormBatch(nn.Module):
    """Standardize ``x`` along dimension 1.

    Each slice along dim 1 is shifted by its own mean and divided by its own
    standard deviation (``Tensor.std`` default, i.e. the unbiased estimator).
    Statistics are reduced over dim 1 only. Assuming dim 0 is the batch axis,
    every sample is normalized independently rather than with batch-pooled
    statistics.

    Args:
        eps: Added to the standard deviation before dividing. The default
            ``0.0`` keeps the original arithmetic exactly, so constant slices
            yield NaN (0/0). A small positive value such as ``1e-5`` maps
            constant slices to 0 instead.
    """

    def __init__(self, eps=0.0):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=1, keepdim=True)
        std = x.std(dim=1, keepdim=True)
        # std is non-negative, so ``std + 0.0`` is bit-identical to ``std``.
        return (x - mean) / (std + self.eps)


class FlattenAndLinear(nn.Linear):
    """Reshape the input into rows of ``in_features`` values, then apply the linear map.

    The leading dimension is inferred (``-1``). A ``(B, C, H, W)`` input with
    ``C * H * W == in_features`` therefore becomes ``(B, in_features)`` before
    the projection, and the output is ``(rows, out_features)``.
    """

    def forward(self, x):
        # ``reshape`` rather than ``view``: ``view`` raises for tensors whose
        # memory layout cannot be reinterpreted without a copy (e.g.
        # channels_last or transposed inputs). ``reshape`` returns a view
        # whenever one is possible and copies only otherwise.
        x = x.reshape(-1, self.in_features)
        return super().forward(x)
    

class LinearUntil(nn.Linear):
    """Project the leading ``in_features`` columns and keep the rest untouched.

    Columns ``[0, in_features)`` along dim 1 are replaced by their linear
    projection (``out_features`` wide). The remaining columns are appended
    after it unchanged: ``cat((linear(head), tail), dim=1)``.
    """

    def forward(self, x):
        n = self.in_features
        head, tail = x[:, :n], x[:, n:]
        return torch.cat((super().forward(head), tail), dim=1)
    

class LinearFrom(nn.Linear):
    """Project the trailing ``in_features`` columns and keep the rest untouched.

    The input is split along dim 1 into a leading part, which passes through
    unchanged, and the last ``in_features`` columns, which go through the
    linear map. The result is ``cat((leading, linear(trailing)), dim=1)``.
    Requires ``x.shape[1] >= in_features``.
    """

    def forward(self, x):
        # Split so that the trailing slice always has exactly ``in_features``
        # columns. Splitting at ``in_features`` itself only fit the linear
        # layer when ``x.shape[1] == 2 * in_features``; in that case this
        # split point is identical.
        split = x.shape[1] - self.in_features
        x_until = x[:, :split]
        x_from = super().forward(x[:, split:])
        return torch.cat((x_until, x_from), dim=1)


class ConvAndFlatten(nn.Module):
    """Apply a channel-preserving ``Conv2DMod`` and flatten all dims after dim 0.

    Args:
        channels: Number of input and output channels of the convolution.
        kernel_size: Kernel size passed to ``Conv2DMod``; stride is fixed at 1.
    """

    def __init__(self, channels, kernel_size=1):
        super().__init__()
        self.conv = Conv2DMod(
            channels, channels, kernel_size=kernel_size, stride=1
        )

    def forward(self, x):
        x = self.conv(x)
        # flatten(1) instead of view(B, -1): view raises on non-viewable
        # layouts. Examples are a channels_last output, assuming Conv2DMod
        # preserves memory format like nn.Conv2d (TODO confirm), and
        # zero-element tensors. flatten only copies when it has to and keeps
        # the same logical element order.
        return x.flatten(start_dim=1)
    

class ImagenizeAndConv(nn.Module):
    """View flat per-sample vectors as images and apply a channel-preserving conv.

    Args:
        size: Per-sample target shape passed to ``view`` after the batch dim.
            ``size[0]`` is used as both input and output channel count of the
            convolution (presumably ``(C, H, W)``; confirm with callers).
        kernel_size: Kernel size passed to ``Conv2DMod``; stride is fixed at 1.
    """

    def __init__(self, size, kernel_size=1):
        super().__init__()
        self.size = size
        n_channels = size[0]
        self.conv = Conv2DMod(
            n_channels, n_channels, kernel_size=kernel_size, stride=1
        )

    def forward(self, x):
        batch = x.shape[0]
        image = x.view(batch, *self.size)
        return self.conv(image)


class SelectUntil(nn.Module):
    """Keep entries ``[0, idx)`` along dim 1, flattening to 2-D if needed.

    Args:
        idx: Exclusive end index of the slice taken along dim 1.

    Inputs with more than two dims are flattened to ``(x.shape[0], -1)``
    after slicing. 2-D inputs are returned as the slice itself.
    """

    def __init__(self, idx):
        super().__init__()
        self.idx = idx

    def forward(self, x):
        x = x[:, :self.idx]
        if x.dim() > 2:
            # flatten(1) rather than view: a dim-1 slice of a channels_last
            # tensor cannot be viewed as (B, -1), and view also rejects
            # zero-element slices. flatten copies only when required.
            x = x.flatten(start_dim=1)
        return x
    
    
class SelectFrom(nn.Module):
    """Keep entries ``[idx, end)`` along dim 1, flattening to 2-D if needed.

    Args:
        idx: Inclusive start index of the slice taken along dim 1.

    Inputs with more than two dims are flattened to ``(x.shape[0], -1)``
    after slicing. 2-D inputs are returned as the slice itself.
    """

    def __init__(self, idx):
        super().__init__()
        self.idx = idx

    def forward(self, x):
        x = x[:, self.idx:]
        if x.dim() > 2:
            # flatten(1) replaces view(-1, prod(shape[1:])). The view raised
            # for channels_last slices, and view(-1, 0) is ambiguous when the
            # slice is empty. flatten copies only when required.
            x = x.flatten(start_dim=1)
        return x