import numpy as np

import torch
import torch.nn as nn
from torch.nn import Parameter, ModuleList
from torch.nn import Module, Conv2d, BatchNorm2d, ReLU, Unfold, Identity


class ParallelConv2dV1(Module):
    """
    Convolution-like layer that mixes several activations of every input patch.

    Each sliding window of the input is passed through SiLU, ReLU and Tanh.
    The three responses are then contracted with one learnable tensor
    (``darts_weight``) over input channel, kernel position and activation
    index, using unfold + einsum.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        stride (int): Step of the sliding window.
        kernel_size (int): Side length of the square window.
        **kwargs: Accepted for call-site compatibility and ignored.

    Input:
        x (torch.Tensor): Tensor of shape (N, C, H, W), where:
            - N is the batch size,
            - C is the number of input channels,
            - H and W are the spatial height and width.

    Output:
        torch.Tensor: Tensor of shape (N, out_channel, out_H, out_W), where
            out_H and out_W follow from the window size, the implicit padding
            of ``(kernel_size - 1) // 2`` and the stride.

    Note:
        ``self.bn`` is created by the constructor but is not applied in
        ``forward``.
    """

    def __init__(self, in_channel, out_channel, stride, kernel_size, **kwargs):
        super(ParallelConv2dV1, self).__init__()
        window_area = kernel_size * kernel_size

        self.in_channel, self.out_channel = in_channel, out_channel
        self.kernel_size, self.stride = kernel_size, stride
        self.padding = (kernel_size - 1) // 2

        # Sliding-window extraction; padding is derived from the kernel size.
        self.Unfold = Unfold(kernel_size, padding=self.padding, stride=stride)

        # The parallel branches, in the order they are stacked in forward().
        self.silu, self.relu, self.tanh = nn.SiLU(), nn.ReLU(), nn.Tanh()

        # One weight per (output channel, input channel, kernel position, branch).
        self.param_length = 3
        mixing = torch.Tensor(out_channel, in_channel, window_area, self.param_length)
        torch.nn.init.kaiming_normal_(mixing, mode='fan_out', nonlinearity='leaky_relu')
        self.darts_weight = Parameter(mixing)

        # Registered for completeness; forward() does not call it.
        self.bn = BatchNorm2d(out_channel)

    def forward(self, x):
        batch, channels, height, width = x.size()
        window_area = self.kernel_size * self.kernel_size

        # Spatial size of the output grid produced by the unfold operation.
        rows = (height + 2 * self.padding - self.kernel_size) // self.stride + 1
        cols = (width + 2 * self.padding - self.kernel_size) // self.stride + 1

        # (N, C * k * k, rows * cols) -> (N, C, k * k, 1, rows, cols); the
        # singleton axis is where the activation branches get stacked.
        patches = self.Unfold(x).view(batch, channels, window_area, 1, rows, cols)

        # (N, C, k * k, 3, rows, cols)
        branches = [activation(patches) for activation in (self.silu, self.relu, self.tanh)]
        stacked = torch.cat(branches, dim=3)

        # Contract channel, kernel position and branch -> (N, out_channel, rows, cols)
        return torch.einsum('bikfhw, oikf -> bohw', stacked, self.darts_weight)


class ParallelConv2dV2(Module):
    """
    Convolution-like layer that mixes a bank of element-wise transforms of
    every input patch.

    Each sliding window of the input is expanded into ``param_length`` feature
    maps: SiLU, ReLU, Tanh, ``num_gaussian`` Gaussian bumps, ``num_fourier``
    sine and ``num_fourier`` cosine responses, and ``num_DoG`` Gaussian-
    derivative shaped responses. The bank is then contracted with one
    learnable tensor (``darts_weight``) using unfold + einsum.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        stride (int): Step of the sliding window.
        kernel_size (int): Side length of the square window.
        **kwargs: Optional settings; unknown keys are ignored.
            - num_gaussian (int): Number of Gaussian responses (default 3).
            - num_DoG (int): Number of DoG responses (default 3).
            - num_fourier (int): Number of Fourier frequencies (default 4).

    Input:
        x (torch.Tensor): Tensor of shape (N, C, H, W), where:
            - N is the batch size,
            - C is the number of input channels,
            - H and W are the spatial height and width.

    Output:
        torch.Tensor: Tensor of shape (N, out_channel, out_H, out_W), where
            out_H and out_W follow from the window size, the implicit padding
            of ``(kernel_size - 1) // 2`` and the stride.

    Note:
        ``self.bn`` is created by the constructor but is not applied in
        ``forward``.
    """

    def __init__(self, in_channel, out_channel, stride, kernel_size, **kwargs):
        super(ParallelConv2dV2, self).__init__()
        window_area = kernel_size * kernel_size

        self.in_channel, self.out_channel = in_channel, out_channel
        self.kernel_size, self.stride = kernel_size, stride
        self.num_gaussian = kwargs.get('num_gaussian', 3)
        self.num_DoG = kwargs.get('num_DoG', 3)
        self.num_fourier = kwargs.get('num_fourier', 4)
        self.padding = (kernel_size - 1) // 2

        # Sliding-window extraction; padding is derived from the kernel size.
        self.Unfold = Unfold(kernel_size, padding=self.padding, stride=stride)

        def centre_grid(count):
            # Learnable centres: linspace(-1, 1, count) replicated for every
            # (channel, kernel position), shaped to broadcast against the
            # 6-D patch tensor used in forward().
            tiled = torch.linspace(-1, 1, count).repeat(in_channel, window_area)
            return Parameter(tiled.view(1, in_channel, window_area, count, 1, 1))

        self.gaussian_mu = centre_grid(self.num_gaussian)
        self.DoG_mu = centre_grid(self.num_DoG)

        # Fixed (non-trainable) Fourier frequencies: pi/2, pi, 3*pi/2, ...
        harmonics = [np.pi * t / 2 for t in range(1, self.num_fourier + 1)]
        self.t_values = Parameter(torch.tensor(harmonics, dtype=torch.float32), requires_grad=False)

        # Classic activation branches.
        self.silu, self.relu, self.tanh = nn.SiLU(), nn.ReLU(), nn.Tanh()

        # 3 activations + Gaussians + (sin, cos) per frequency + DoGs.
        self.param_length = self.num_DoG + self.num_gaussian + 2 * self.num_fourier + 3

        # One weight per (output channel, input channel, kernel position, transform).
        mixing = torch.Tensor(out_channel, in_channel, window_area, self.param_length)
        torch.nn.init.kaiming_normal_(mixing, mode='fan_out', nonlinearity='leaky_relu')
        self.darts_weight = Parameter(mixing)

        # Registered for completeness; forward() does not call it.
        self.bn = BatchNorm2d(out_channel)

    def forward(self, x):
        batch, channels, height, width = x.size()
        window_area = self.kernel_size * self.kernel_size

        # Spatial size of the output grid produced by the unfold operation.
        rows = (height + 2 * self.padding - self.kernel_size) // self.stride + 1
        cols = (width + 2 * self.padding - self.kernel_size) // self.stride + 1

        # (N, C * k * k, rows * cols) -> (N, C, k * k, 1, rows, cols); the
        # singleton axis is where the transform bank gets stacked.
        patches = self.Unfold(x).view(batch, channels, window_area, 1, rows, cols)

        # Classic activations, one feature map each.
        silu_part = self.silu(patches)
        relu_part = self.relu(patches)
        tanh_part = self.tanh(patches)

        # Gaussian bumps around the learnable centres: exp(-(x - mu)^2).
        gauss_offset = patches - self.gaussian_mu
        gaussian_part = torch.exp(-(gauss_offset ** 2))

        # Gaussian-derivative shaped response: -(x - mu) * exp(-(x - mu)^2).
        dog_offset = patches - self.DoG_mu
        dog_part = -dog_offset * torch.exp(-(dog_offset ** 2))

        # Scale each patch value by every frequency, then take sin and cos.
        phases = torch.einsum('bikphw, f -> bikfhw', patches, self.t_values)
        fourier_part = torch.cat([torch.sin(phases), torch.cos(phases)], dim=3)

        # (N, C, k * k, param_length, rows, cols); this order defines which
        # slice of darts_weight belongs to which transform.
        bank = torch.cat(
            [silu_part, relu_part, tanh_part, gaussian_part, fourier_part, dog_part],
            dim=3)

        # Contract channel, kernel position and transform -> (N, out_channel, rows, cols)
        return torch.einsum('bikfhw, oikf -> bohw', bank, self.darts_weight)


class BasicResidualV2(Module):
    """
    Bottleneck residual unit whose middle (spatial) layer is pluggable.

    The main path is: 1x1 conv -> BN -> ReLU -> ``module`` -> BN -> (ReLU)
    -> 1x1 conv -> BN. The result is added to the shortcut and passed
    through a final ReLU. The shortcut is the identity unless the stride
    or channel count changes, in which case it is a strided 1x1 conv.

    Args:
        module (callable): Factory for the middle layer. It is called as
            ``module(c_hidden, c_hidden, kernel_size=..., stride=...,
            padding=..., **extra)``, e.g. ``torch.nn.Conv2d`` or one of the
            parallel convolution classes.
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        stride (int): Stride applied by the middle layer and the shortcut.
        **kwargs: Options consumed by this block (and NOT forwarded):
            - bottleneck (int): ``c_hidden = out_channel // bottleneck``
              (default 4).
            - costum_module_flag (bool): When true, the ReLU after the
              middle layer's batch norm is skipped (default False). The
              key is spelled ``costum_module_flag``.
            - kernel_size (int): Kernel size of the middle layer
              (default 3).
            Every other key is forwarded unchanged to ``module``.

    Input:
        x (torch.Tensor): Tensor of shape (N, in_channel, H, W).

    Output:
        torch.Tensor: Tensor of shape (N, out_channel, H', W'), where the
            spatial size is reduced by ``stride``.
    """

    # Keys that configure this block itself. They are stripped before calling
    # `module`: `kernel_size` is already passed explicitly (forwarding it again
    # raises "got multiple values for keyword argument"), and the other two are
    # rejected by factories with a strict signature such as Conv2d.
    _BLOCK_ONLY_KWARGS = ('bottleneck', 'costum_module_flag', 'kernel_size')

    def __init__(self, module, in_channel, out_channel, stride, **kwargs):
        super(BasicResidualV2, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.stride = stride
        self.module = module
        self.bottleneck = kwargs.get('bottleneck', 4)
        self.c_hidden = out_channel // self.bottleneck
        self.costume_module_flag = kwargs.get('costum_module_flag', False)
        self.kernel_size = kwargs.get('kernel_size', 3)

        self.shortcut = Identity()
        if stride != 1 or in_channel != out_channel:
            # Projection so the shortcut matches the main path's shape.
            self.shortcut = Conv2d(in_channel, out_channel, 1, stride=stride, bias=False)

        module_kwargs = {
            key: value for key, value in kwargs.items()
            if key not in self._BLOCK_ONLY_KWARGS
        }
        # "Same"-style padding for odd kernels, so the main path and the
        # shortcut keep matching spatial sizes for kernel sizes other than 3
        # (equals the previous constant 1 when kernel_size == 3).
        padding = (self.kernel_size - 1) // 2

        self.conv1 = Conv2d(in_channel, self.c_hidden, 1, stride=1, bias=False)
        self.bn1 = BatchNorm2d(self.c_hidden)
        self.relu1 = ReLU(inplace=True)
        self.conv2 = module(self.c_hidden, self.c_hidden, kernel_size=self.kernel_size, stride=stride,
                            padding=padding, **module_kwargs)
        self.bn2 = BatchNorm2d(self.c_hidden)
        self.relu2 = ReLU(inplace=True)
        self.conv3 = Conv2d(self.c_hidden, out_channel, 1, stride=1, bias=False)
        self.bn3 = BatchNorm2d(out_channel)
        self.relu3 = ReLU(inplace=True)

    def forward(self, x):
        # Computed first, from the untouched input.
        shortcut = self.shortcut(x)

        # Channel reduction.
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        # Pluggable spatial layer; the ReLU is skipped when the flag is set.
        out = self.conv2(out)
        out = self.bn2(out)
        if not self.costume_module_flag:
            out = self.relu2(out)

        # Channel expansion.
        out = self.conv3(out)
        out = self.bn3(out)

        # Residual connection followed by the final activation.
        out += shortcut
        out = self.relu3(out)
        return out


class BasicBlockV2(Module):
    """
    A stage made of ``n_stack`` residual units applied one after another.

    Only the first unit receives ``in_channel`` and ``stride``; every later
    unit maps ``out_channel`` to ``out_channel`` with stride 1.

    Args:
        module (callable): Factory for the spatial layer inside each unit.
        in_channel (int): Number of channels entering the stage.
        out_channel (int): Number of channels leaving the stage.
        stride (int): Stride of the first unit.
        n_stack (int): Number of residual units in the stage.
        **kwargs: Forwarded unchanged to every ``BasicResidualV2``.
    """

    def __init__(self, module, in_channel, out_channel, stride, n_stack, **kwargs):
        super(BasicBlockV2, self).__init__()
        self.in_channel, self.out_channel = in_channel, out_channel
        self.stride = stride
        self.module = module
        self.n_stack = n_stack

        units = [
            BasicResidualV2(
                module,
                in_channel if position == 0 else out_channel,
                out_channel,
                stride if position == 0 else 1,
                **kwargs)
            for position in range(n_stack)
        ]
        self.layers = ModuleList(units)

    def forward(self, x):
        out = x
        for unit in self.layers:
            out = unit(out)
        return out


class CostumeResNetV2(Module):
    """
    ResNet-style classifier whose bottleneck spatial layer is pluggable.

    Layout: 7x7 stride-2 stem conv -> BN -> ReLU -> 3x3 stride-2 max-pool ->
    four ``BasicBlockV2`` stages (64->256, 256->512, 512->1024, 1024->2048,
    with strides 1, 2, 2, 2) -> global average pool -> linear classifier.

    Args:
        module (callable): Factory for the spatial layer used in every
            residual unit.
        model_shape: Indexable with 0..3; entry ``i`` is the number of
            residual units in stage ``i + 1``.
        n_class (int): Number of output classes.
        **kwargs: Forwarded unchanged to every ``BasicBlockV2``.

    Input:
        x (torch.Tensor): Tensor of shape (N, 3, H, W).

    Output:
        torch.Tensor: Tensor of shape (N, n_class).
    """

    def __init__(self, module, model_shape, n_class, **kwargs):
        super().__init__()
        self.module = module
        self.model_shape = model_shape
        self.n_class = n_class

        # Channel width after the stem and after each of the four stages.
        widths = (64, 256, 512, 1024, 2048)

        # Stem
        self.conv_stem = Conv2d(3, widths[0], 7, stride=2, padding=3, bias=False)
        self.bn1 = BatchNorm2d(widths[0])
        self.relu = ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages; only the first keeps the spatial resolution.
        self.block1 = BasicBlockV2(self.module, widths[0], widths[1], 1, model_shape[0], **kwargs)
        self.block2 = BasicBlockV2(self.module, widths[1], widths[2], 2, model_shape[1], **kwargs)
        self.block3 = BasicBlockV2(self.module, widths[2], widths[3], 2, model_shape[2], **kwargs)
        self.block4 = BasicBlockV2(self.module, widths[3], widths[4], 2, model_shape[3], **kwargs)

        # Classification head
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(widths[4], n_class)

    def forward(self, x):
        # Stem
        features = self.maxpool(self.relu(self.bn1(self.conv_stem(x))))

        # Residual stages, in order.
        for stage in (self.block1, self.block2, self.block3, self.block4):
            features = stage(features)

        # Head: (N, 2048, h, w) -> (N, 2048, 1, 1) -> (N, 2048) -> (N, n_class)
        pooled = self.avgpool(features)
        return self.fc(torch.flatten(pooled, 1))
