import math
from typing import Union, Tuple

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, reduce
from torch.autograd import Function
from torch.autograd.function import once_differentiable

import nn as nn_modules
from utils.utils import LambdaModule


class MemoryEfficientBottleneckFunction(Function):
    """Fused per-pixel ``Linear -> SiLU -> Linear`` with activation recomputation.

    Only the flattened input and the parameters are saved for backward. The
    hidden pre-activation and the SiLU output are recomputed in ``backward``
    instead of being kept alive, trading an extra matmul for activation memory.

    ``backward`` is marked ``once_differentiable``. The recomputation starts
    from a tensor that was created (and saved) inside ``forward``, where
    autograd is disabled. Higher-order gradients through this op would
    therefore silently drop the dependence on ``input``, so they are rejected
    instead.
    """

    @staticmethod
    def forward(ctx, input, weight1, bias1, weight2, bias2):
        """Apply ``silu(x @ weight1.T + bias1) @ weight2.T + bias2`` at every pixel.

        Args:
            input: ``(B, C_in, H, W)`` tensor.
            weight1: ``(C_hidden, C_in)``; bias1: ``(C_hidden,)``.
            weight2: ``(C_out, C_hidden)``; bias2: ``(C_out,)``.

        Returns:
            ``(B, C_out, H, W)`` tensor (a permuted view of a channels-last buffer).
        """
        B, _, H, W = input.shape
        # Channels-last and flattened, so both layers are plain 2-D matmuls.
        x = input.permute(0, 2, 3, 1).reshape(B * H * W, -1)

        hidden = th.matmul(x, weight1.t()) + bias1
        activated = hidden * th.sigmoid(hidden)  # SiLU
        out = th.matmul(activated, weight2.t()) + bias2

        # hidden / activated are deliberately not saved; backward recomputes them.
        ctx.save_for_backward(x, weight1, bias1, weight2)

        return out.reshape(B, H, W, -1).permute(0, 3, 1, 2)

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Return gradients w.r.t. ``(input, weight1, bias1, weight2, bias2)``.

        Gradients for inputs that do not require grad (per
        ``ctx.needs_input_grad``) are neither computed nor returned (``None``).
        """
        x, weight1, bias1, weight2 = ctx.saved_tensors
        need_input, need_w1, need_b1, need_w2, need_b2 = ctx.needs_input_grad

        B, _, H, W = grad_output.shape
        grad_out = grad_output.permute(0, 2, 3, 1).reshape(B * H * W, -1)

        # Recompute the forward activations that were not saved.
        hidden = th.matmul(x, weight1.t()) + bias1
        sig = th.sigmoid(hidden)

        grad_input = grad_weight1 = grad_bias1 = grad_weight2 = grad_bias2 = None

        # Second linear layer.
        if need_w2:
            grad_weight2 = th.matmul(grad_out.t(), hidden * sig)
        if need_b2:
            grad_bias2 = grad_out.sum(dim=0)

        if need_input or need_w1 or need_b1:
            # d/dh [h * sigmoid(h)] = sigmoid(h) + h * sigmoid(h) * (1 - sigmoid(h))
            grad_hidden = th.matmul(grad_out, weight2)
            grad_hidden = grad_hidden * (sig + hidden * sig * (1 - sig))

            # First linear layer.
            if need_input:
                grad_input = th.matmul(grad_hidden, weight1).reshape(B, H, W, -1).permute(0, 3, 1, 2)
            if need_w1:
                grad_weight1 = th.matmul(grad_hidden.t(), x)
            if need_b1:
                grad_bias1 = grad_hidden.sum(dim=0)

        return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2

class MemoryEfficientBottleneck(th.nn.Module):
    """Per-pixel two-layer MLP (4x hidden expansion) run through
    ``MemoryEfficientBottleneckFunction``.

    Args:
        in_features: channels of the incoming NCHW tensor.
        out_features: channels produced; the hidden width is ``4 * out_features``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        hidden_features = 4 * out_features

        self.weight1 = th.nn.Parameter(th.randn(hidden_features, in_features))
        self.bias1 = th.nn.Parameter(th.zeros(hidden_features))
        self.weight2 = th.nn.Parameter(th.randn(out_features, hidden_features))
        self.bias2 = th.nn.Parameter(th.zeros(out_features))

        # The randn draws above are overwritten with Xavier-uniform values.
        for weight in (self.weight1, self.weight2):
            th.nn.init.xavier_uniform_(weight)

    def forward(self, input):
        params = (self.weight1, self.bias1, self.weight2, self.bias2)
        return MemoryEfficientBottleneckFunction.apply(input, *params)

class ConvNeXtBlock(nn.Module):
    """Residual block: depthwise 7x7 conv -> GroupNorm -> per-pixel MLP.

    The MLP branch is scaled per channel by a learnable ``alpha`` (initialised
    to the ``alpha`` argument) before being added back onto the input.

    Args:
        channels: number of input/output channels.
        alpha: initial per-channel residual scale; must be <= 1.
        num_norm_groups: number of GroupNorm groups; must divide ``channels``.
    """

    def __init__(
            self,
            channels: int,
            alpha: float = 1e-6,
            num_norm_groups: int = 1
        ):
        super().__init__()

        assert channels % num_norm_groups == 0
        assert alpha <= 1, 'invalid alpha value in ConvNeXtBlock'

        depthwise = nn.Conv2d(channels, channels, kernel_size=7, padding=3, groups=channels)
        norm = nn.GroupNorm(num_norm_groups, channels)
        mlp = MemoryEfficientBottleneck(channels, channels)
        self.layers = nn.Sequential(depthwise, norm, mlp)

        self.alpha = nn.Parameter(alpha * th.ones(1, channels, 1, 1))

    def forward(self, input: th.Tensor) -> th.Tensor:
        residual = self.layers(input)
        return input + self.alpha * residual

class ConvNeXtStem(nn.Module):
    """Patchify stem: split the input into non-overlapping ``K x K`` patches
    and project each flattened patch with one linear layer.

    Mathematically this equals a ``K x K`` convolution with stride ``K``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the produced feature map.
        kernel_size: patch size ``K``; input ``H`` and ``W`` must be divisible by it.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 4,
        ):
        super(ConvNeXtStem, self).__init__()

        self.kernel_size = kernel_size
        self.layers = nn.Linear(in_channels * kernel_size**2, out_channels)

    def forward(self, input: th.Tensor) -> th.Tensor:
        # (B, C, H*K, W*K) -> (B, C*K*K, H, W). The channel index is
        # c*K*K + row*K + col, which is the feature order the Linear weights use.
        patches = F.pixel_unshuffle(input, self.kernel_size)
        # Project over channels in NHWC layout, then return to NCHW.
        return self.layers(patches.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

class ConvNeXtPatchDown(nn.Module):
    """Downsample by ``K``: GroupNorm, then merge each ``K x K`` patch and
    project it with a linear layer.

    Args:
        in_channels: channels of the incoming map; must be divisible by
            ``num_norm_groups``.
        out_channels: channels of the downsampled map.
        kernel_size: downsampling factor ``K``; ``H`` and ``W`` must be divisible by it.
        num_norm_groups: number of GroupNorm groups applied before merging.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 2,
            num_norm_groups: int = 1
        ):
        super(ConvNeXtPatchDown, self).__init__()

        assert(in_channels % num_norm_groups == 0)

        self.kernel_size = kernel_size
        self.layers = nn.Linear(in_channels * kernel_size**2, out_channels)
        self.norm   = nn.GroupNorm(num_norm_groups, in_channels)
        self.num_norm_groups = num_norm_groups

    def forward(self, input: th.Tensor) -> th.Tensor:
        # (B, C, H*K, W*K) -> (B, C*K*K, H, W). The channel index is
        # c*K*K + row*K + col, which is the feature order the Linear weights use.
        patches = F.pixel_unshuffle(self.norm(input), self.kernel_size)
        # Project over channels in NHWC layout, then return to NCHW.
        return self.layers(patches.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

class ConvNeXtPatchUp(nn.Module):
    """Upsample by ``K``: per-pixel linear projection to ``C_out * K * K``
    channels, then scatter those channels into a ``K x K`` output patch.

    Args:
        in_channels: channels of the incoming map.
        out_channels: channels of the upsampled map.
        kernel_size: upsampling factor ``K``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 2,
        ):
        super(ConvNeXtPatchUp, self).__init__()

        self.kernel_size = kernel_size
        self.layers = nn.Linear(in_channels, out_channels * kernel_size**2)

    def forward(self, input: th.Tensor) -> th.Tensor:
        # Linear over channels in NHWC layout, then back to NCHW.
        projected = self.layers(input.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        # (B, C*K*K, H, W) -> (B, C, H*K, W*K); channel c*K*K + row*K + col
        # lands at output pixel (h*K + row, w*K + col).
        return F.pixel_shuffle(projected, self.kernel_size)

class ConvNeXtEncoder(nn.Module):
    """Four-stage ConvNeXt-style encoder.

    A patchify stem (``/4``) is followed by four stages of widths
    ``base_channels * (1, 2, 4, 8)``. Stages 1 and 2 begin with a ``/2``
    patch-merge; stage 3 does too, unless ``blocks[3] == 0``, in which case
    the whole stage is an identity.

    Args:
        in_channels: channels of the input image.
        base_channels: width of the stem and the first stage.
        blocks: number of ``ConvNeXtBlock``s per stage (4 entries).
        norm_group_size: target channels per GroupNorm group; clamped to
            ``base_channels`` and must divide it.
        alpha: initial residual scale passed to every ``ConvNeXtBlock``.
        return_features: if True, ``forward`` returns every intermediate map.
    """

    def __init__(
        self,
        in_channels,
        base_channels,
        blocks=(3, 3, 9, 3),
        norm_group_size=32,
        alpha=1e-6,
        return_features=False,
    ):
        super(ConvNeXtEncoder, self).__init__()
        self.return_features = return_features

        # Clamp so the base width always holds at least one full group.
        if base_channels < norm_group_size:
            norm_group_size = base_channels

        assert(base_channels % norm_group_size == 0)

        def stage_blocks(mult, count):
            # ``count`` residual blocks at width ``mult * base_channels``,
            # using GroupNorm groups of ``norm_group_size`` channels each.
            width = mult * base_channels
            return [ConvNeXtBlock(width, alpha, width // norm_group_size) for _ in range(count)]

        self.stem = ConvNeXtStem(in_channels, base_channels)

        self.layer0 = nn.Sequential(*stage_blocks(1, blocks[0]))

        self.layer1 = nn.Sequential(
            ConvNeXtPatchDown(base_channels, base_channels * 2),
            *stage_blocks(2, blocks[1]),
        )

        self.layer2 = nn.Sequential(
            ConvNeXtPatchDown(base_channels * 2, base_channels * 4),
            *stage_blocks(4, blocks[2]),
        )

        # With no blocks in the last stage, its downsampling is skipped too.
        self.layer3 = nn.Sequential(
            ConvNeXtPatchDown(base_channels * 4, base_channels * 8) if blocks[3] > 0 else nn.Identity(),
            *stage_blocks(8, blocks[3]),
        )

    def forward(self, input: th.Tensor):
        """Encode ``input``.

        Returns:
            The last stage's output; or, if ``return_features`` is set, the list
            of ``[layer3, layer2, layer1, layer0, stem]`` outputs (deepest first).
        """
        features = [self.stem(input)]
        for stage in (self.layer0, self.layer1, self.layer2, self.layer3):
            features.append(stage(features[-1]))

        if self.return_features:
            return list(reversed(features))

        return features[-1]


class ConvNeXtDecoder(nn.Module):
    """Mirror of ``ConvNeXtEncoder`` without skip connections.

    Stages run from the widest (``base_channels * 8``) to the narrowest
    (``base_channels``). Each stage ends with a ``ConvNeXtPatchUp``: x2 for the
    first three stages and x4 for the last. When ``blocks[3] == 0`` the first
    stage is an identity, so the input must then already have
    ``base_channels * 4`` channels.

    Args:
        out_channels: channels of the reconstructed output.
        base_channels: width of the narrowest stage.
        blocks: ``ConvNeXtBlock`` counts indexed like the encoder's
            (``blocks[3]`` sizes the widest stage here).
        norm_group_size: target channels per GroupNorm group; clamped to
            ``base_channels`` and must divide it.
        alpha: initial residual scale passed to every ``ConvNeXtBlock``.
    """

    def __init__(
        self,
        out_channels,
        base_channels,
        blocks=(3, 3, 9, 3),
        norm_group_size=32,
        alpha=1e-6,
    ):
        super(ConvNeXtDecoder, self).__init__()

        # Clamp so the base width always holds at least one full group.
        if base_channels < norm_group_size:
            norm_group_size = base_channels

        assert(base_channels % norm_group_size == 0)

        def stage_blocks(mult, count):
            # ``count`` residual blocks at width ``mult * base_channels``,
            # using GroupNorm groups of ``norm_group_size`` channels each.
            width = mult * base_channels
            return [ConvNeXtBlock(width, alpha, width // norm_group_size) for _ in range(count)]

        self.layer0 = nn.Sequential(
            *stage_blocks(8, blocks[3]),
            ConvNeXtPatchUp(base_channels * 8, base_channels * 4) if blocks[3] > 0 else nn.Identity(),
        )

        self.layer1 = nn.Sequential(
            *stage_blocks(4, blocks[2]),
            ConvNeXtPatchUp(base_channels * 4, base_channels * 2),
        )

        self.layer2 = nn.Sequential(
            *stage_blocks(2, blocks[1]),
            ConvNeXtPatchUp(base_channels * 2, base_channels),
        )

        self.layer3 = nn.Sequential(
            *stage_blocks(1, blocks[0]),
            ConvNeXtPatchUp(base_channels, out_channels, kernel_size=4),
        )

    def forward(self, input: th.Tensor):
        """Decode ``input`` through all four stages (total upsampling x32,
        or x16 when ``blocks[3] == 0``)."""
        x = input
        for stage in (self.layer0, self.layer1, self.layer2, self.layer3):
            x = stage(x)

        return x


class ConvNeXtUnet(nn.Module):
    """U-Net built from a ``ConvNeXtEncoder`` and a ConvNeXt-style decoder.

    Each decoder stage output is concatenated with the same-resolution encoder
    feature. A 3x3 conv (``merge*``) halves the channels again, and the result
    feeds the next stage.

    The output has ``out_channels`` channels at the input resolution. The
    encoder downsamples by 32 (16 when ``blocks[3] == 0``), so input ``H`` and
    ``W`` must be divisible by that factor.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map.
        base_channels: width of the narrowest stage.
        blocks: per-stage ``ConvNeXtBlock`` counts (4 entries).
        norm_group_size: target channels per GroupNorm group; clamped to
            ``base_channels`` and must divide it.
        alpha: initial residual scale passed to every ``ConvNeXtBlock``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        base_channels,
        blocks=(3, 3, 9, 3),
        norm_group_size=32,
        alpha=1e-6,
    ):
        super(ConvNeXtUnet, self).__init__()

        # return_features=True: the encoder yields its maps deepest-first.
        self.encoder = ConvNeXtEncoder(in_channels, base_channels, blocks, norm_group_size, alpha, True)

        # Clamp so the base width always holds at least one full group.
        if base_channels < norm_group_size:
            norm_group_size = base_channels

        assert(base_channels % norm_group_size == 0)

        # NOTE(review): layer1/layer2/layer3 below all size their block stacks
        # from blocks[1]. ConvNeXtDecoder instead uses blocks[2]/blocks[1]/blocks[0]
        # for the same widths. Kept as-is because changing it alters the
        # state_dict layout of existing checkpoints -- confirm whether intended.
        self.layer0 = nn.Sequential(
            *[ConvNeXtBlock(base_channels * 8, alpha, 8 * base_channels // norm_group_size) for _ in range(blocks[3])],
            ConvNeXtPatchUp(base_channels * 8, base_channels * 4) if blocks[3] > 0 else nn.Identity(),
        )

        # Each merge takes [decoder x, encoder skip] (2 * width) back to width.
        self.merge1 = nn.Conv2d(base_channels * 8, base_channels * 4, kernel_size=3, padding=1)
        self.layer1 = nn.Sequential(
            *[ConvNeXtBlock(base_channels * 4, alpha, 4 * base_channels // norm_group_size) for _ in range(blocks[1])],
            ConvNeXtPatchUp(base_channels * 4, base_channels * 2),
        )

        self.merge2 = nn.Conv2d(base_channels * 4, base_channels * 2, kernel_size=3, padding=1)
        self.layer2 = nn.Sequential(
            *[ConvNeXtBlock(base_channels * 2, alpha, 2 * base_channels // norm_group_size) for _ in range(blocks[1])],
            ConvNeXtPatchUp(base_channels * 2, base_channels),
        )

        self.merge3 = nn.Conv2d(base_channels * 2, base_channels, kernel_size=3, padding=1)
        self.layer3 = nn.Sequential(
            *[ConvNeXtBlock(base_channels, alpha, base_channels // norm_group_size) for _ in range(blocks[1])],
            ConvNeXtPatchUp(base_channels, out_channels, kernel_size=4),
        )

    def forward(self, input: th.Tensor):
        """Run encoder, then decoder with skip connections; returns the output map."""
        # features = [layer3, layer2, layer1, layer0, stem] encoder outputs.
        features = self.encoder(input)

        x = self.layer0(features[0])
        x = self.layer1(self.merge1(th.cat((x, features[1]), dim=1)))
        x = self.layer2(self.merge2(th.cat((x, features[2]), dim=1)))
        x = self.layer3(self.merge3(th.cat((x, features[3]), dim=1)))

        return x
