import torch.nn as nn
import torch as th
import numpy as np
import nn as nn_modules
from einops import rearrange, repeat, reduce
from utils.utils import LambdaModule
from torch.nn.utils.parametrizations import spectral_norm as sp_norm
from nn.residual import SkipConnection

from typing import Union, Tuple


class SPConvNeXtBlock(nn.Module):
    """ConvNeXt-style residual block with spectrally-normalized weights.

    The residual branch (``layers``) applies a depthwise 7x7 convolution,
    GroupNorm, then a per-pixel MLP (Linear -> SiLU -> Linear) evaluated in
    channels-last layout via two permutes. The MLP hidden width is
    ``4 * max(in_channels, out_channels)``. The convolution and both linear
    layers are wrapped in ``spectral_norm``.

    The output is ``skip(input) + layers(input)``. ``skip`` is an identity
    when the channel count is preserved. Otherwise it is a spectrally-
    normalized 1x1 convolution, so the residual sum has matching shapes.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels; defaults to ``in_channels``.
        channels_per_group: desired channels per GroupNorm group, clamped to
            ``in_channels``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int = None,
            channels_per_group = 64,
        ):

        super(SPConvNeXtBlock, self).__init__()

        channels_per_group = min(channels_per_group, in_channels)

        if out_channels is None:
            out_channels = in_channels

        # GroupNorm requires num_channels % num_groups == 0. Start from the
        # requested group count and step down to the largest divisor of
        # in_channels (always terminates at 1) instead of failing here.
        num_groups = in_channels // channels_per_group
        while in_channels % num_groups != 0:
            num_groups -= 1

        hidden_channels = max(in_channels, out_channels) * 4

        self.layers = nn.Sequential(
            sp_norm(nn.Conv2d(in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels)),
            nn.GroupNorm(num_groups, in_channels),
            # (B, C, H, W) -> (B, H, W, C) so nn.Linear mixes channels per pixel.
            LambdaModule(lambda x: th.permute(x, [0, 2, 3, 1])),
            sp_norm(nn.Linear(in_channels, hidden_channels)),
            nn.SiLU(inplace=True),
            sp_norm(nn.Linear(hidden_channels, out_channels)),
            # Back to (B, C, H, W).
            LambdaModule(lambda x: th.permute(x, [0, 3, 1, 2])),
        )

        # The residual sum needs matching channel counts. nn.Identity has no
        # parameters, so width-preserving blocks keep the same state dict.
        if out_channels == in_channels:
            self.skip = nn.Identity()
        else:
            self.skip = sp_norm(nn.Conv2d(in_channels, out_channels, kernel_size=1))

    def forward(self, input: th.Tensor) -> th.Tensor:
        """Return ``skip(input) + layers(input)``."""
        return self.skip(input) + self.layers(input)

class SPPatchUpscale(nn.Module):
    """Upsample a feature map by ``kernel_size`` using a skip path plus a residual path.

    The output is ``skip(input) + residual(input)``, where:

    * ``skip`` is a project ``SkipConnection(in_channels, out_channels,
      scale_factor=kernel_size)``;
    * ``residual`` is SiLU -> spectrally-normalized 3x3 conv (channels kept)
      -> SiLU -> spectrally-normalized transposed conv with
      ``kernel_size == stride == kernel_size``, mapping to ``out_channels``.

    ``in_channels`` must be a multiple of ``out_channels`` (checked with an
    ``assert``).
    """

    def __init__(self, in_channels, out_channels, kernel_size = 2):
        super().__init__()
        assert in_channels % out_channels == 0

        self.skip = SkipConnection(in_channels, out_channels, scale_factor=kernel_size)

        # 3x3 refinement at the input resolution, padding keeps H and W.
        refine = sp_norm(nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            padding=1,
        ))
        # Non-overlapping transposed conv: each input pixel expands into a
        # kernel_size x kernel_size output patch.
        expand = sp_norm(nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=kernel_size,
        ))
        self.residual = nn.Sequential(nn.SiLU(), refine, nn.SiLU(), expand)

    def forward(self, input):
        """Return the sum of the skip path and the residual path."""
        shortcut = self.skip(input)
        return shortcut + self.residual(input)

class SPPatchDownscale(nn.Module):
    """Downsample by non-overlapping ``kernel_size`` x ``kernel_size`` patches.

    The output is ``shortcut + projected``:

    * shortcut: each patch is averaged, then every channel is repeated
      ``out_channels // in_channels`` times (``channels_factor``);
    * projected: each patch is flattened to ``in_channels * kernel_size**2``
      features and mapped to ``out_channels`` by a spectrally-normalized
      ``nn.Linear``.

    ``out_channels`` must be a multiple of ``in_channels`` (checked with an
    ``assert``). The einops patterns require H and W to be divisible by
    ``kernel_size``.
    """

    def __init__(self, in_channels, out_channels, kernel_size = 2):
        super().__init__()
        assert out_channels % in_channels == 0

        patch_features = in_channels * kernel_size**2
        self.layers = sp_norm(nn.Linear(patch_features, out_channels))

        self.kernel_size = kernel_size
        self.channels_factor = out_channels // in_channels

    def forward(self, input: th.Tensor):
        height, width = input.shape[2:]
        k = self.kernel_size

        # Parameter-free path: patch mean, then channel repetition.
        pooled = reduce(input, 'b c (h h2) (w w2) -> b c h w', 'mean', h2=k, w2=k)
        shortcut = repeat(pooled, 'b c h w -> b (c n) h w', n=self.channels_factor)

        # Learned path: one linear projection per flattened patch.
        patches = rearrange(input, 'b c (h h2) (w w2) -> (b h w) (c h2 w2)', h2=k, w2=k)
        projected = rearrange(
            self.layers(patches),
            '(b h w) c -> b c h w',
            h = height // k,
            w = width // k,
        )

        return shortcut + projected

class SPConvNeXtEncoder(nn.Module):
    """Four-stage spectrally-normalized ConvNeXt encoder.

    ``stem`` downsamples by 4 to ``base_channels``. Stage 0 keeps that
    resolution. Stages 1-3 each start with a 2x ``SPPatchDownscale`` that
    doubles the channel count (2x, 4x, 8x ``base_channels``). Each stage is
    followed by ``blocks[i]`` ``SPConvNeXtBlock``s. When ``blocks[3] == 0``,
    stage 3 is an identity (no downscale, no blocks).

    Args:
        in_channels: channels of the input image.
        base_channels: channel width after the stem.
        blocks: number of ConvNeXt blocks per stage (4 entries).
        return_features: if True, ``forward`` returns every intermediate
            feature map, deepest first.
    """

    def __init__(
        self, 
        in_channels, 
        base_channels, 
        blocks=(3, 3, 9, 3), 
        return_features = False,
    ):
        super(SPConvNeXtEncoder, self).__init__()
        self.return_features = return_features
        
        self.stem = SPPatchDownscale(in_channels, base_channels, kernel_size=4)

        self.layer0 = nn.Sequential(*[
            SPConvNeXtBlock(base_channels) for _ in range(blocks[0])
        ])

        self.layer1 = nn.Sequential(
            SPPatchDownscale(base_channels, base_channels * 2),
            *[SPConvNeXtBlock(base_channels * 2) for _ in range(blocks[1])]
        )

        self.layer2 = nn.Sequential(
            SPPatchDownscale(base_channels * 2, base_channels * 4),
            *[SPConvNeXtBlock(base_channels * 4) for _ in range(blocks[2])]
        )

        # An empty last stage degenerates to an identity, so the encoder then
        # ends at 4 * base_channels.
        self.layer3 = nn.Sequential(
            SPPatchDownscale(base_channels * 4, base_channels * 8) if blocks[3] > 0 else nn.Identity(),
            *[SPConvNeXtBlock(base_channels * 8) for _ in range(blocks[3])]
        )

    def forward(self, input: th.Tensor):
        """Encode ``input``.

        Returns:
            The output of the last stage. If ``return_features`` is set, this
            is instead the list ``[layer3, layer2, layer1, layer0, stem]``
            of outputs (deepest first).
        """
        features  = [self.stem(input)]
        features += [self.layer0(features[-1])]
        features += [self.layer1(features[-1])]
        features += [self.layer2(features[-1])]
        features += [self.layer3(features[-1])]

        if self.return_features:
            return list(reversed(features))

        return features[-1]


class SPConvNeXtDecoder(nn.Module):
    """Spectrally-normalized ConvNeXt decoder mirroring ``SPConvNeXtEncoder``.

    Decoder stage i runs ``blocks[3 - i]`` ``SPConvNeXtBlock``s and then an
    ``SPPatchUpscale``. Stages 0-2 upscale by 2 and halve the channels
    (8x -> 4x -> 2x -> 1x ``base_channels``). Stage 3 upscales by 4 to
    ``out_channels``. When ``blocks[3] == 0``, stage 0 is an identity, so
    the input is expected to already have ``4 * base_channels`` channels.

    Args:
        out_channels: channels of the decoded output.
        base_channels: channel width of the shallowest stage.
        blocks: number of ConvNeXt blocks per (encoder-ordered) stage,
            4 entries.
    """

    def __init__(
        self, 
        out_channels, 
        base_channels, 
        blocks=(3, 3, 9, 3), 
    ):
        super(SPConvNeXtDecoder, self).__init__()
        
        self.layer0 = nn.Sequential(
            *[SPConvNeXtBlock(base_channels * 8) for _ in range(blocks[3])],
            SPPatchUpscale(base_channels * 8, base_channels * 4) if blocks[3] > 0 else nn.Identity(),
        )

        self.layer1 = nn.Sequential(
            *[SPConvNeXtBlock(base_channels * 4) for _ in range(blocks[2])],
            SPPatchUpscale(base_channels * 4, base_channels * 2),
        )

        self.layer2 = nn.Sequential(
            *[SPConvNeXtBlock(base_channels * 2) for _ in range(blocks[1])],
            SPPatchUpscale(base_channels * 2, base_channels),
        )

        self.layer3 = nn.Sequential(
            *[SPConvNeXtBlock(base_channels) for _ in range(blocks[0])],
            SPPatchUpscale(base_channels, out_channels, kernel_size=4),
        )

    def forward(self, input: th.Tensor):
        """Run the four decoder stages in sequence and return the result."""
        x = self.layer0(input)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        
        return x


class SPConvNeXtUnet(nn.Module):
    """U-Net built on ``SPConvNeXtEncoder`` with a mirrored ConvNeXt decoder.

    The encoder runs with ``return_features=True``, yielding
    ``[layer3, layer2, layer1, layer0, stem]`` outputs (deepest first).
    Decoder stage 0 starts from the deepest map. Before each of decoder
    stages 1-3, the running tensor is concatenated along channels with the
    matching encoder map and fused back to the stage width by a 3x3 conv
    (``merge1``..``merge3``).

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output.
        base_channels: channel width of the shallowest stage.
        blocks: number of ConvNeXt blocks per (encoder-ordered) stage,
            4 entries.
    """

    def __init__(
        self, 
        in_channels,
        out_channels, 
        base_channels, 
        blocks=(3, 3, 9, 3), 
    ):
        super(SPConvNeXtUnet, self).__init__()
        
        self.encoder = SPConvNeXtEncoder(in_channels, base_channels, blocks, return_features=True)

        # Decoder stage i mirrors encoder stage 3 - i and therefore uses
        # blocks[3 - i] blocks, matching SPConvNeXtDecoder.
        # NOTE(review): layer1/layer3 previously used blocks[1]; state dicts
        # saved with that wiring won't load strictly when blocks[2] or
        # blocks[0] differ from blocks[1].
        self.layer0 = nn.Sequential(
            *[SPConvNeXtBlock(base_channels * 8) for _ in range(blocks[3])],
            SPPatchUpscale(base_channels * 8, base_channels * 4) if blocks[3] > 0 else nn.Identity(),
        )

        # NOTE(review): the merge convs are plain nn.Conv2d, unlike the
        # spectrally-normalized layers elsewhere in this file; confirm intent.
        self.merge1 = nn.Conv2d(base_channels * 8, base_channels * 4, kernel_size=3, padding=1)
        self.layer1 = nn.Sequential(
            *[SPConvNeXtBlock(base_channels * 4) for _ in range(blocks[2])],
            SPPatchUpscale(base_channels * 4, base_channels * 2),
        )

        self.merge2 = nn.Conv2d(base_channels * 4, base_channels * 2, kernel_size=3, padding=1)
        self.layer2 = nn.Sequential(
            *[SPConvNeXtBlock(base_channels * 2) for _ in range(blocks[1])],
            SPPatchUpscale(base_channels * 2, base_channels),
        )

        self.merge3 = nn.Conv2d(base_channels * 2, base_channels, kernel_size=3, padding=1)
        self.layer3 = nn.Sequential(
            *[SPConvNeXtBlock(base_channels) for _ in range(blocks[0])],
            SPPatchUpscale(base_channels, out_channels, kernel_size=4),
        )

    def forward(self, input: th.Tensor):
        """Encode ``input``, then decode with skip connections from the encoder."""
        features = self.encoder(input)

        x = self.layer0(features[0])
        x = self.layer1(self.merge1(th.cat((x, features[1]), dim=1)))
        x = self.layer2(self.merge2(th.cat((x, features[2]), dim=1)))
        x = self.layer3(self.merge3(th.cat((x, features[3]), dim=1)))
        
        return x
