import torch
from torch.nn import Module, Conv2d, BatchNorm2d, ReLU, Identity, ModuleList, Parameter, Unfold


class CostumCNN(Module):
    """
    Locally connected 2-D layer implemented with ``Unfold`` and ``einsum``.

    Unlike ``Conv2d``, the filter weights are NOT shared across spatial
    positions: each output location ``l`` has its own
    ``(out_channel, C * kernel_size * kernel_size)`` weight slice. The layer is
    therefore tied to the fixed per-sample input shape given at construction.

    Every side gets ``(kernel_size - 1) // 2`` zero padding. With an odd
    ``kernel_size`` and ``stride == 1``, the spatial size is preserved. No bias
    term is used.

    Args:
        in_shape (tuple): Expected per-sample input shape ``(C, H, W)``.
        out_channel (int): Number of output channels.
        kernel_size (int): Side length of the square sliding window.
        stride (int): Stride of the sliding window.

    Input:
        x (torch.Tensor): Tensor of shape ``(N, C, H, W)`` whose ``(C, H, W)``
            equals ``in_shape``.

    Output:
        torch.Tensor: Tensor of shape ``(N, out_channel, out_H, out_W)``, where
        ``out_H = (H + 2 * padding - kernel_size) // stride + 1`` (same for W).

    Raises:
        ValueError: In ``forward``, if the per-sample input shape differs
            from ``in_shape``.
    """

    def __init__(self, in_shape, out_channel, kernel_size, stride):
        super().__init__()
        self.in_shape = in_shape
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = (kernel_size - 1) // 2
        self.in_channel = in_shape[0]

        # Output spatial size (same formula as Conv2d / Unfold).
        self.out_H = (in_shape[1] + 2 * self.padding - self.kernel_size) // self.stride + 1
        self.out_W = (in_shape[2] + 2 * self.padding - self.kernel_size) // self.stride + 1

        # Number of sliding windows, i.e. number of output locations L.
        self.sequence_length = self.out_H * self.out_W

        # (N, C, H, W) -> (N, C * k * k, L)
        self.Unfold = Unfold(kernel_size, padding=self.padding, stride=stride)

        # One weight slice per output location: (out_channel, C * k * k, L).
        self.weight = Parameter(
            torch.empty(out_channel, self.in_channel * kernel_size * kernel_size, self.sequence_length))

        # He/Kaiming normal init in "fan_out" mode, with the fan computed as for a
        # k x k convolution: fan_out = out_channel * k * k. kaiming_normal_ cannot
        # be applied to this 3-D tensor directly. It would treat the trailing
        # dimension (L, the number of locations) as the receptive field, giving
        # fan_out = out_channel * L and a std that is sqrt(L / k**2) too small.
        fan_out = out_channel * kernel_size * kernel_size
        std = torch.nn.init.calculate_gain('leaky_relu', 0) / fan_out ** 0.5
        torch.nn.init.normal_(self.weight, mean=0.0, std=std)

    def forward(self, x):
        # Weights are per-location, so the input must match the construction
        # shape. A transposed H/W with the same L would otherwise silently
        # produce wrong results.
        if tuple(x.shape[1:]) != tuple(self.in_shape):
            raise ValueError(
                f"CostumCNN expected per-sample input shape {tuple(self.in_shape)}, "
                f"got {tuple(x.shape[1:])}")
        batch = x.shape[0]

        # (N, C, H, W) -> (N, C * k * k, L)
        patches = self.Unfold(x)

        # Contract over the patch dimension separately at each location l:
        # (N, C * k * k, L) x (out_channel, C * k * k, L) -> (N, out_channel, L)
        out = torch.einsum('bil,oil->bol', patches, self.weight)

        # (N, out_channel, L) -> (N, out_channel, out_H, out_W)
        return out.reshape(batch, self.out_channel, self.out_H, self.out_W)


class BasicResidualV2(Module):
    """
    Bottleneck residual unit: 1x1 reduce -> ``module`` (k=3) -> 1x1 expand, plus shortcut.

    Each branch layer is followed by BatchNorm. ReLU is applied after the first
    two BatchNorms and again after the residual addition (post-activation
    ordering). The spatial stride is applied by the middle ``module`` layer and,
    when the shape changes, by a 1x1 projection shortcut (no BatchNorm on the
    shortcut).

    Args:
        in_shape (sequence): Per-sample input shape ``(C, H, W)``.
        out_channel (int): Number of output channels.
        stride (int): Spatial stride of the unit.
        module (callable): Factory called as
            ``module(hidden_shape, c_hidden, kernel_size=3, stride=stride)``.
            The returned layer must map ``c_hidden`` channels to ``c_hidden``
            channels, and its output spatial size must match the shortcut's.
        bottleneck (int): Channel reduction factor, so
            ``c_hidden = out_channel // bottleneck``.

    Raises:
        ValueError: If ``out_channel // bottleneck`` is smaller than 1.
    """

    def __init__(self, in_shape, out_channel, stride, module, bottleneck=4):
        super().__init__()
        self.in_shape = in_shape
        self.out_channel = out_channel
        self.stride = stride
        self.module = module

        self.in_channel = in_shape[0]
        # ceil(H / stride). This is the output size of both the 1x1 stride-s
        # shortcut (no padding) and a 3x3 pad-1 stride-s conv. Plain H // stride
        # is wrong for odd H with stride > 1.
        self.out_H = (in_shape[1] - 1) // self.stride + 1
        self.out_W = (in_shape[2] - 1) // self.stride + 1
        self.c_hidden = out_channel // bottleneck
        if self.c_hidden < 1:
            raise ValueError(
                f"out_channel // bottleneck must be >= 1, got {out_channel} // {bottleneck}")
        # conv1 is 1x1 / stride 1, so the middle layer sees the input's H and W.
        self.hidden_shape = (self.c_hidden, in_shape[1], in_shape[2])

        self.shortcut = Identity()
        if stride != 1 or self.in_channel != out_channel:
            # Project the input so it can be added to the branch output.
            self.shortcut = Conv2d(self.in_channel, out_channel, 1, stride=stride, bias=False)

        self.conv1 = Conv2d(self.in_channel, self.c_hidden, 1, stride=1, bias=False)
        self.bn1 = BatchNorm2d(self.c_hidden)
        self.relu1 = ReLU(inplace=True)
        self.conv2 = module(self.hidden_shape, self.c_hidden, kernel_size=3, stride=stride)
        self.bn2 = BatchNorm2d(self.c_hidden)
        self.relu2 = ReLU(inplace=True)
        self.conv3 = Conv2d(self.c_hidden, out_channel, 1, stride=1, bias=False)
        self.bn3 = BatchNorm2d(out_channel)
        self.relu3 = ReLU(inplace=True)

    def forward(self, x):
        shortcut = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv3(out)
        out = self.bn3(out)
        # In-place add is safe: `out` is a fresh tensor produced by bn3.
        out += shortcut
        out = self.relu3(out)
        return out


class BasicBlockV2(Module):
    """
    A stage made of ``n_stack`` consecutive ``BasicResidualV2`` units.

    Only the first unit changes the channel count and applies ``stride``. Every
    later unit maps ``output_shape`` to ``output_shape`` with stride 1.

    Args:
        in_shape (sequence): Per-sample input shape ``(C, H, W)``.
        out_channel (int): Number of output channels of every unit.
        stride (int): Spatial stride applied by the first unit.
        module (callable): Factory for the middle layer of each unit
            (forwarded to ``BasicResidualV2``).
        n_stack (int): Number of residual units; must be at least 1.
        bottleneck (int): Channel reduction factor inside each unit.

    Raises:
        ValueError: If ``n_stack`` is smaller than 1.
    """

    def __init__(self, in_shape, out_channel, stride, module, n_stack, bottleneck=4):
        super().__init__()
        if n_stack < 1:
            # With no units the stage would be an identity, and output_shape
            # (which later stages rely on) would be wrong.
            raise ValueError(f"n_stack must be >= 1, got {n_stack}")
        self.in_shape = in_shape
        self.out_channel = out_channel
        self.stride = stride
        self.module = module
        self.n_stack = n_stack

        self.in_channel = in_shape[0]
        # ceil(H / stride), matching what the first residual unit actually
        # outputs. H // stride would be wrong for odd H and would mis-size the
        # shape-dependent layers of all following units.
        self.out_H = (in_shape[1] - 1) // self.stride + 1
        self.out_W = (in_shape[2] - 1) // self.stride + 1
        self.output_shape = (out_channel, self.out_H, self.out_W)

        self.layers = ModuleList()
        unit_shape, unit_stride = in_shape, stride
        for _ in range(n_stack):
            self.layers.append(
                BasicResidualV2(unit_shape, out_channel, unit_stride, module, bottleneck=bottleneck))
            # Units after the first keep the stage's output shape.
            unit_shape, unit_stride = self.output_shape, 1

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class CostumeResNetV2(Module):
    """
    Bottleneck ResNet classifier whose middle 3x3 layers are built by ``module``.

    Layout:
    - a 7x7/2 conv stem (64 channels) with BN, ReLU and a 3x3/2 max-pool;
    - four ``BasicBlockV2`` stages of widths 256/512/1024/2048 and strides 1/2/2/2;
    - global average pooling and a linear classifier.

    The stage input shapes are fixed here, starting from a 64x56x56 stem output.
    That is what the stem produces for, e.g., a 3x224x224 input. Layers created
    by ``module`` whose parameters depend on spatial size (such as
    ``CostumCNN``) only accept inputs that reproduce these shapes.

    Args:
        module (callable): Factory for the middle layer of every residual unit.
        model_shape (sequence of int): Number of residual units in each of the
            four stages; only the first four entries are read.
        n_class (int): Number of output classes.
        bottleneck (int): Channel reduction factor inside each residual unit.
    """

    def __init__(self, module, model_shape, n_class, bottleneck=4):
        super().__init__()
        self.module = module
        self.model_shape = model_shape
        self.n_class = n_class
        self.bottleneck = bottleneck

        # Stem: spatial downsampling by 2 (conv) and then by 2 again (max-pool).
        self.conv_stem = Conv2d(3, 64, 7, stride=2, padding=3, bias=False)
        self.bn1 = BatchNorm2d(64)
        self.relu = ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        def make_stage(stage_in_shape, width, stage_stride, depth):
            # Builds one stage; shapes are [C, H, W] of that stage's input.
            return BasicBlockV2(stage_in_shape, width, stage_stride, self.module, depth,
                                bottleneck=bottleneck)

        self.block1 = make_stage([64, 56, 56], 256, 1, model_shape[0])
        self.block2 = make_stage([256, 56, 56], 512, 2, model_shape[1])
        self.block3 = make_stage([512, 28, 28], 1024, 2, model_shape[2])
        self.block4 = make_stage([1024, 14, 14], 2048, 2, model_shape[3])

        # Head: (N, 2048, h, w) -> (N, 2048, 1, 1) -> logits (N, n_class).
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(2048, n_class)

    def forward(self, x):
        pipeline = (
            self.conv_stem, self.bn1, self.relu, self.maxpool,
            self.block1, self.block2, self.block3, self.block4,
            self.avgpool,
        )
        for stage in pipeline:
            x = stage(x)
        x = torch.flatten(x, 1)
        return self.fc(x)


class CostumeResNetV3(Module):
    """
    Small-input variant of the bottleneck ResNet classifier built on ``module``.

    The architecture matches ``CostumeResNetV2``:
    - a 7x7/2 conv stem (64 channels) with BN, ReLU and a 3x3/2 max-pool;
    - four ``BasicBlockV2`` stages of widths 256/512/1024/2048 and strides 1/2/2/2;
    - global average pooling and a linear classifier.

    The stage input shapes start from a 64x8x8 stem output, which is what the
    stem produces for, e.g., a 3x32x32 input. Spatial-size-dependent layers
    created by ``module`` (such as ``CostumCNN``) only accept inputs that
    reproduce these shapes.

    Args:
        module (callable): Factory for the middle layer of every residual unit.
        model_shape (sequence of int): Number of residual units in each of the
            four stages; only the first four entries are read.
        n_class (int): Number of output classes.
        bottleneck (int): Channel reduction factor inside each residual unit.
    """

    def __init__(self, module, model_shape, n_class, bottleneck=4):
        super().__init__()
        self.module = module
        self.model_shape = model_shape
        self.n_class = n_class
        self.bottleneck = bottleneck

        # Stem: two stride-2 reductions (conv, then max-pool).
        self.conv_stem = Conv2d(3, 64, 7, stride=2, padding=3, bias=False)
        self.bn1 = BatchNorm2d(64)
        self.relu = ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        def make_stage(stage_in_shape, width, stage_stride, depth):
            # Builds one stage; shapes are [C, H, W] of that stage's input.
            return BasicBlockV2(stage_in_shape, width, stage_stride, self.module, depth,
                                bottleneck=bottleneck)

        self.block1 = make_stage([64, 8, 8], 256, 1, model_shape[0])
        self.block2 = make_stage([256, 8, 8], 512, 2, model_shape[1])
        self.block3 = make_stage([512, 4, 4], 1024, 2, model_shape[2])
        self.block4 = make_stage([1024, 2, 2], 2048, 2, model_shape[3])

        # Head: (N, 2048, h, w) -> (N, 2048, 1, 1) -> logits (N, n_class).
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(2048, n_class)

    def forward(self, x):
        pipeline = (
            self.conv_stem, self.bn1, self.relu, self.maxpool,
            self.block1, self.block2, self.block3, self.block4,
            self.avgpool,
        )
        for stage in pipeline:
            x = stage(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
