import torch as th
import torch.nn as nn
from torch.autograd import Function
import torch.nn.functional as F

from flash_attn.modules.mha import MHA

class MemoryEfficientBottleneckFunction(Function):
    """Fused ``Linear -> SiLU -> Linear`` with recomputation in backward.

    Only the flattened input and the parameters are saved for the backward
    pass. The hidden pre-activation and the SiLU output are recomputed from
    them in ``backward`` rather than being stored by ``forward``.
    """

    @staticmethod
    def forward(ctx, input, weight1, bias1, weight2, bias2):
        """Compute ``silu(input @ weight1.T + bias1) @ weight2.T + bias2``.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            input: Tensor of shape ``(..., in_features)``; every leading
                dimension is treated as a batch dimension.
            weight1: First layer weight, ``(hidden_features, in_features)``.
            bias1: First layer bias, added to the first matmul result.
            weight2: Second layer weight, ``(out_features, hidden_features)``.
            bias2: Second layer bias, added to the second matmul result.

        Returns:
            Tensor of shape ``(..., out_features)`` with the same leading
            dimensions as ``input``.
        """
        # Remember the caller's shape so backward can restore it exactly.
        ctx.input_shape = input.shape

        # Collapse all leading dimensions. Explicit sizes (instead of a
        # second -1) keep the reshape well-defined for empty batches.
        flat_input = input.reshape(-1, input.shape[-1])

        # First linear layer
        pre_act = th.matmul(flat_input, weight1.t()) + bias1

        # SiLU activation: x * sigmoid(x)
        hidden = pre_act * th.sigmoid(pre_act)

        # Second linear layer
        output = th.matmul(hidden, weight2.t()) + bias2

        # Intermediates are deliberately NOT saved; backward recomputes them.
        ctx.save_for_backward(flat_input, weight1, bias1, weight2)

        return output.reshape(*input.shape[:-1], weight2.shape[0])

    @staticmethod
    def backward(ctx, grad_output):
        """Recompute the activations and return gradients for all inputs.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient w.r.t. the forward output,
                shape ``(..., out_features)``.

        Returns:
            Tuple ``(grad_input, grad_weight1, grad_bias1, grad_weight2,
            grad_bias2)``. Entries for inputs that do not require a gradient
            are ``None``.
        """
        flat_input, weight1, bias1, weight2 = ctx.saved_tensors
        need_input, need_w1, need_b1, need_w2, need_b2 = ctx.needs_input_grad
        need_first_layer = need_input or need_w1 or need_b1

        grad_output = grad_output.reshape(-1, grad_output.shape[-1])

        grad_input = grad_weight1 = grad_bias1 = None
        grad_weight2 = grad_bias2 = None

        # The second bias gradient needs no recomputation at all.
        if need_b2:
            grad_bias2 = grad_output.sum(dim=0)

        if need_w2 or need_first_layer:
            # Recompute the forward intermediates that were not saved.
            pre_act = th.matmul(flat_input, weight1.t()) + bias1
            gate = th.sigmoid(pre_act)

            if need_w2:
                grad_weight2 = th.matmul(grad_output.t(), pre_act * gate)

            if need_first_layer:
                # Gradient flowing into the SiLU output.
                grad_hidden = th.matmul(grad_output, weight2)

                # d/dx [x * s(x)] = s + x * s * (1 - s), with s = sigmoid(x)
                grad_pre_act = (
                    grad_hidden * gate
                    + pre_act * grad_hidden * gate * (1 - gate)
                )

                if need_input:
                    grad_input = th.matmul(grad_pre_act, weight1).reshape(
                        ctx.input_shape
                    )
                if need_w1:
                    grad_weight1 = th.matmul(grad_pre_act.t(), flat_input)
                if need_b1:
                    grad_bias1 = grad_pre_act.sum(dim=0)

        return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2

class MemoryEfficientBottleneck(th.nn.Module):
    """Two-layer MLP whose hidden width is ``4 * out_features``.

    Holds the weights and biases of both linear layers as parameters and
    delegates the actual computation to
    ``MemoryEfficientBottleneckFunction``.
    """

    def __init__(self, in_features, out_features):
        """Create the parameters.

        Args:
            in_features: Size of the last dimension of the input.
            out_features: Size of the last dimension of the output; the
                hidden layer is four times this wide.
        """
        super().__init__()
        hidden_features = out_features * 4
        make_param = th.nn.Parameter

        # Registration order (weight1, bias1, weight2, bias2) is kept so
        # that parameter iteration order stays stable.
        self.register_parameter(
            "weight1", make_param(th.randn(hidden_features, in_features))
        )
        self.register_parameter(
            "bias1", make_param(th.zeros(hidden_features))
        )
        self.register_parameter(
            "weight2", make_param(th.randn(out_features, hidden_features))
        )
        self.register_parameter(
            "bias2", make_param(th.zeros(out_features))
        )

        # Both weight matrices are overwritten with Xavier-uniform values;
        # the biases stay at zero.
        for weight in (self.weight1, self.weight2):
            th.nn.init.xavier_uniform_(weight)

    def forward(self, input):
        """Apply the bottleneck function to ``input`` with this module's parameters."""
        layer_params = (self.weight1, self.bias1, self.weight2, self.bias2)
        return MemoryEfficientBottleneckFunction.apply(input, *layer_params)

class AttentionBlock(nn.Module):
    """Attention sub-layer followed by an MLP sub-layer, each with a residual add.

    Each sub-layer receives a LayerNorm-ed copy of the running activations
    and its output is added back onto the un-normalised activations.
    """

    def __init__(self, dim, head_dim=64):
        """Build the attention, MLP and normalisation layers.

        Args:
            dim: Feature width of the block's input and output.
            head_dim: Target width of each attention head; the number of
                heads is ``dim // head_dim``. Defaults to 64.

        Raises:
            ValueError: If ``head_dim`` is not positive or ``dim`` is smaller
                than ``head_dim`` (which would give zero attention heads).
        """
        super().__init__()
        if head_dim <= 0 or dim < head_dim:
            raise ValueError(
                f"dim ({dim}) must be at least head_dim ({head_dim}) and "
                "head_dim must be positive so that at least one attention "
                "head is created"
            )

        # Submodules are created in the original order (mixer, norm1, mlp,
        # norm2) so state-dict key order is unchanged.
        self.mixer = MHA(dim, num_heads=dim // head_dim)
        self.norm1 = nn.LayerNorm(dim)
        self.mlp   = MemoryEfficientBottleneck(dim, dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, input):
        """Run the attention and MLP sub-layers with residual connections.

        Args:
            input: Activations passed to ``norm1``/``mixer``; assumed to have
                ``dim`` features in the last dimension — TODO confirm the
                exact layout expected by ``MHA``.

        Returns:
            ``input`` plus the attention output plus the MLP output.
        """
        # Only the tensor fed to norm1 is cast to the norm's parameter dtype;
        # the residual stream itself keeps the caller's dtype.
        residual = self.norm1(input.to(dtype=self.norm1.weight.dtype))
        input    = input + self.mixer(residual)
        residual = self.norm2(input)
        residual = self.mlp(residual)
        return input + residual

