import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

class MemoryEfficientConv3x3ResidualFunction(Function):
    """Residual block ``input + conv2(silu(conv1(input)))`` with activation recomputation.

    Only ``input``, ``weight1``, ``bias1`` and ``weight2`` are saved for the
    backward pass. The hidden activations (the output of the first convolution
    and its SiLU) are *not* stored; ``backward`` recomputes them. This trades
    one extra convolution for not keeping the expanded hidden tensor alive
    between forward and backward.

    Both convolutions use ``padding=1`` with stride 1. The output therefore has
    the same spatial size as ``input`` for 3x3 kernels, which the residual
    addition requires. ``weight2`` must map back to the channel count of
    ``input``. ``bias1``/``bias2`` may be ``None``.

    NOTE(review): under ``torch.autocast`` the forward pass may run in reduced
    precision while the backward recomputation does not. Consider
    ``torch.amp.custom_fwd``/``custom_bwd`` if mixed precision is used.
    """

    @staticmethod
    def forward(ctx, input, weight1, bias1, weight2, bias2):
        """Compute ``input + conv2d(silu(conv2d(input, weight1, bias1)), weight2, bias2)``.

        Args:
            input: batched activations; the bias-gradient reductions over
                dims ``(0, 2, 3)`` in ``backward`` assume a 4-D (N, C, H, W) tensor.
            weight1, bias1: first convolution (``bias1`` may be ``None``).
            weight2, bias2: second convolution (``bias2`` may be ``None``).

        Returns:
            A new tensor with the same shape as ``input``.
        """
        # First 3x3 convolution (expands channels).
        hidden = F.conv2d(input, weight1, bias1, padding=1)

        # SiLU in place: ``hidden`` is a fresh tensor that is not needed again in
        # forward, and autograd is disabled inside Function.forward.
        hidden = F.silu(hidden, inplace=True)

        # Second 3x3 convolution (projects back to the input channel count).
        output = F.conv2d(hidden, weight2, bias2, padding=1)
        del hidden  # drop the large hidden activation before the residual add

        # Only what is needed to recompute the hidden activations is saved.
        # save_for_backward accepts None for a missing bias.
        ctx.save_for_backward(input, weight1, bias1, weight2)

        # Residual connection; ``output`` is a fresh tensor, so an in-place add is safe.
        return output.add_(input)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(input, weight1, bias1, weight2, bias2)``.

        Only gradients flagged in ``ctx.needs_input_grad`` are computed; the
        rest are returned as ``None``. This is required for inputs that were
        ``None`` (absent biases) and avoids needless work for frozen tensors.
        """
        input, weight1, bias1, weight2 = ctx.saved_tensors
        (need_input, need_weight1, need_bias1,
         need_weight2, need_bias2) = ctx.needs_input_grad

        grad_input = grad_weight1 = grad_bias1 = grad_weight2 = grad_bias2 = None

        # The bias2 gradient depends only on the upstream gradient.
        if need_bias2:
            grad_bias2 = grad_output.sum(dim=(0, 2, 3))

        need_first_conv_grads = need_input or need_weight1 or need_bias1
        if not (need_weight2 or need_first_conv_grads):
            # Nothing below requires the hidden activations; skip the recompute.
            return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2

        # Recompute the activations that forward deliberately did not save.
        output1 = F.conv2d(input, weight1, bias1, padding=1)
        output1_sigmoid = th.sigmoid(output1)

        if need_weight2:
            output2 = output1 * output1_sigmoid  # SiLU(output1)
            grad_weight2 = F.grad.conv2d_weight(output2, weight2.shape, grad_output, padding=1)
            del output2  # release before the next large allocation

        if need_first_conv_grads:
            # Gradient w.r.t. the SiLU output (input of the second convolution).
            grad_output2 = F.grad.conv2d_input(output1.shape, weight2, grad_output, padding=1)

            # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
            grad_silu = grad_output2 * output1_sigmoid * (1 + output1 * (1 - output1_sigmoid))
            del grad_output2, output1, output1_sigmoid

            if need_input:
                # The residual branch passes grad_output through unchanged.
                grad_input = F.grad.conv2d_input(input.shape, weight1, grad_silu, padding=1) + grad_output
            if need_weight1:
                grad_weight1 = F.grad.conv2d_weight(input, weight1.shape, grad_silu, padding=1)
            if need_bias1:
                grad_bias1 = grad_silu.sum(dim=(0, 2, 3))

        return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2


class MemoryEfficientConv3x3Residual(th.nn.Module):
    """Residual 3x3 conv -> SiLU -> 3x3 conv block with recomputed activations.

    Computes ``x + conv2(silu(conv1(x)))``. ``conv1`` expands ``channels`` to
    ``channels * expand_ratio`` and ``conv2`` projects back to ``channels``.
    The computation is delegated to ``MemoryEfficientConv3x3ResidualFunction``,
    which recomputes the hidden activations during backward instead of storing
    them.

    Args:
        channels: number of input (and output) channels.
        expand_ratio: multiplier for the hidden channel count.
    """

    def __init__(self, channels, expand_ratio=4):
        super().__init__()

        self.channels = channels
        self.expand_ratio = expand_ratio
        hidden = channels * expand_ratio

        # Parameters are registered in the same order and under the same names as
        # before, so state_dict keys are unchanged. th.empty is used because
        # xavier_uniform_ below overwrites every weight element anyway.
        self.weight1 = th.nn.Parameter(th.empty(hidden, channels, 3, 3))
        self.bias1 = th.nn.Parameter(th.zeros(hidden))
        self.weight2 = th.nn.Parameter(th.empty(channels, hidden, 3, 3))
        self.bias2 = th.nn.Parameter(th.zeros(channels))

        th.nn.init.xavier_uniform_(self.weight1)
        th.nn.init.xavier_uniform_(self.weight2)

    def extra_repr(self):
        """Describe the block's configuration in ``repr(module)``."""
        return f"channels={self.channels}, expand_ratio={self.expand_ratio}"

    def forward(self, input):
        """Apply the residual block to ``input``; the output has the same shape."""
        return MemoryEfficientConv3x3ResidualFunction.apply(
            input, self.weight1, self.bias1, self.weight2, self.bias2
        )

