import numpy as np


import torch
from torch import nn
from torch.nn import Module, Linear, ReLU, SiLU, Tanh


class ParallelLayerV1(Module):
    '''
    Layer that sends every input feature through ReLU, SiLU and Tanh side by side.

    Pipeline: BatchNorm1d -> three parallel activations -> BatchNorm1d -> Linear.

    :param in_dim: number of input features
    :param out_dim: number of output features
    :param kwargs: optional 'num_fourier', 'num_DoG', 'num_gaussian'; they are
        only stored on the instance and do not affect this layer's computation
    '''

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        self.num_fourier = kwargs.get('num_fourier', 4)
        self.num_DoG = kwargs.get('num_DoG', 3)
        self.num_gaussian = kwargs.get('num_gaussian', 3)
        # One expanded column per activation (ReLU, SiLU, Tanh).
        self.num_params = 3
        expanded_dim = in_dim * self.num_params

        self.bn1 = nn.BatchNorm1d(in_dim)
        self.relu = ReLU()
        self.silu = SiLU()
        self.tanh = Tanh()
        self.bn2 = nn.BatchNorm1d(expanded_dim)
        self.fc = Linear(expanded_dim, out_dim)

    def forward(self, x):
        '''
        :param x: [N, in_dim]
        :return: [N, out_dim]
        '''
        batch_size, _ = x.shape
        normed = self.bn1(x).unsqueeze(2)  # [N, in_dim, 1]

        # Apply each activation to the same normalised input, in a fixed order.
        branches = [act(normed) for act in (self.relu, self.silu, self.tanh)]
        features = torch.cat(branches, dim=2)  # [N, in_dim, 3]
        features = features.view(batch_size, -1)  # [N, in_dim * 3]
        return self.fc(self.bn2(features))


class ParallelLayerV2(Module):
    '''
    Layer that expands every input feature with several basis functions in parallel.

    Per feature it produces: ReLU, SiLU, Tanh, num_DoG "-d * exp(-d^2 / sigma^2)"
    terms, num_gaussian "exp(-d^2 / sigma^2)" terms, and num_fourier sine plus
    num_fourier cosine terms. The expansion is wrapped as
    BatchNorm1d -> expansion -> BatchNorm1d -> Linear.

    :param in_dim: number of input features
    :param out_dim: number of output features
    :param kwargs: optional 'num_fourier' (default 4), 'num_DoG' (default 3),
        'num_gaussian' (default 3)
    '''

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        self.num_fourier = kwargs.get('num_fourier', 4)
        self.num_DoG = kwargs.get('num_DoG', 3)
        self.num_gaussian = kwargs.get('num_gaussian', 3)
        # 3 activations + sin and cos per frequency + one column per DoG / Gaussian centre.
        self.num_params = 3 + 2 * self.num_fourier + self.num_DoG + self.num_gaussian

        # Learnable centres start evenly spaced on [0, 1]; learnable widths start at 1.
        dog_centres = torch.linspace(0, 1, self.num_DoG).view(1, 1, -1)
        gaussian_centres = torch.linspace(0, 1, self.num_gaussian).view(1, 1, -1)
        self.DoG_mu = nn.Parameter(dog_centres, requires_grad=True)
        self.gaussian_mu = nn.Parameter(gaussian_centres, requires_grad=True)
        self.DoG_sigma = nn.Parameter(torch.ones(1, 1, self.num_DoG), requires_grad=True)
        self.gaussian_sigma = nn.Parameter(torch.ones(1, 1, self.num_gaussian), requires_grad=True)

        # Fixed (non-trainable) frequencies pi/2, pi, 3*pi/2, ...
        frequencies = [np.pi * t / 2 for t in range(1, self.num_fourier + 1)]
        self.t_values = nn.Parameter(
            torch.tensor(frequencies, dtype=torch.float32).view(1, 1, -1),
            requires_grad=False,
        )

        expanded_dim = in_dim * self.num_params
        self.bn1 = nn.BatchNorm1d(in_dim)
        self.relu = ReLU()
        self.silu = SiLU()
        self.tanh = Tanh()
        self.bn2 = nn.BatchNorm1d(expanded_dim)
        self.fc = Linear(expanded_dim, out_dim)

    def forward(self, x):
        '''
        :param x: [N, in_dim]
        :return: [N, out_dim]
        '''
        batch_size, _ = x.shape
        base = self.bn1(x).unsqueeze(2)  # [N, in_dim, 1]

        activations = [self.relu(base), self.silu(base), self.tanh(base)]

        dog_offset = base - self.DoG_mu  # [N, in_dim, num_DoG]
        dog = -dog_offset * torch.exp(-dog_offset ** 2 / self.DoG_sigma ** 2)

        gaussian_offset = base - self.gaussian_mu  # [N, in_dim, num_gaussian]
        gaussian = torch.exp(-gaussian_offset ** 2 / self.gaussian_sigma ** 2)

        phase = base * self.t_values  # [N, in_dim, num_fourier]
        waves = [torch.sin(phase), torch.cos(phase)]

        # Concatenation order determines the column layout seen by bn2 / fc.
        features = torch.cat(activations + [dog, gaussian] + waves, dim=2)
        features = features.view(batch_size, -1)  # [N, in_dim * num_params]
        return self.fc(self.bn2(features))


class DoGLayer(Module):
    '''
    Derivative-of-Gaussian layer: BatchNorm1d -> Gaussian-derivative basis ->
    BatchNorm1d -> Linear.

    Each input feature is expanded into num_gaussian values
    -(x - mu) / sigma^2 * exp(-(x - mu)^2 / (2 * sigma^2)) with learnable mu and sigma.

    :param in_dim: number of input features
    :param out_dim: number of output features
    :param kwargs: optional 'num_gaussian' (default 3), the number of basis functions
    '''

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        self.num_gaussian = kwargs.get('num_gaussian', 3)
        expanded_dim = in_dim * self.num_gaussian

        self.bn1 = nn.BatchNorm1d(in_dim)
        # Centres start evenly spaced on [0, 1]; widths start at 1. Both are trainable.
        centres = torch.linspace(0, 1, self.num_gaussian).view(1, 1, -1)
        widths = torch.ones(1, 1, self.num_gaussian)
        self.mu = nn.Parameter(centres, requires_grad=True)
        self.sigma = nn.Parameter(widths, requires_grad=True)
        self.bn2 = nn.BatchNorm1d(expanded_dim)
        self.fc = Linear(expanded_dim, out_dim)

    def forward(self, x):
        '''
        :param x: [N, in_dim]
        :return: [N, out_dim]
        '''
        batch_size, _ = x.shape
        normed = self.bn1(x).unsqueeze(2)  # [N, in_dim, 1]

        # mu and sigma are [1, 1, num_gaussian] and broadcast over batch and features.
        offset = normed - self.mu  # [N, in_dim, num_gaussian]
        slope = -offset / (self.sigma ** 2)
        envelope = torch.exp(-offset ** 2 / (2 * self.sigma ** 2))
        basis = slope * envelope

        flat = basis.view(batch_size, -1)  # [N, in_dim * num_gaussian]
        return self.fc(self.bn2(flat))  # [N, out_dim]


class GaussianLayer(Module):
    '''
    Gaussian basis layer: BatchNorm1d -> Gaussian bumps -> BatchNorm1d -> Linear.

    Each input feature is expanded into num_gaussian values
    exp(-(x - mu)^2 / sigma^2) with learnable centres mu and widths sigma.

    :param in_dim: number of input features
    :param out_dim: number of output features
    :param kwargs: optional 'num_gaussian' (default 3), the number of basis functions
    '''

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        self.num_gaussian = kwargs.get('num_gaussian', 3)
        self.bn1 = nn.BatchNorm1d(in_dim)
        # Centres start evenly spaced on [0, 1]; widths start at 1. Both are trainable.
        self.mu = nn.Parameter(torch.linspace(0, 1, self.num_gaussian).view(1, 1, -1), requires_grad=True)
        self.sigma = nn.Parameter(torch.ones(1, 1, self.num_gaussian), requires_grad=True)
        self.bn2 = nn.BatchNorm1d(in_dim * self.num_gaussian)
        self.fc = Linear(in_dim * self.num_gaussian, out_dim)

    def forward(self, x):
        '''
        :param x: [N, in_dim]
        :return: [N, out_dim]
        '''
        N, _ = x.shape
        out = self.bn1(x).unsqueeze(2)  # [N, in_dim, 1]
        diff = out - self.mu  # [N, in_dim, num_gaussian]
        # sigma used to be registered but ignored, so it never received a gradient.
        # Dividing by sigma^2 makes the width trainable; at the initial sigma == 1
        # the result is exactly the previous exp(-diff^2).
        out = torch.exp(-diff ** 2 / self.sigma ** 2).view(N, -1)  # [N, in_dim * num_gaussian]
        out = self.bn2(out)
        return self.fc(out)  # already [N, out_dim]; no further reshape needed


class FourierLayer(Module):
    '''
    Fourier feature layer: BatchNorm1d -> sin/cos expansion -> BatchNorm1d -> Linear.

    Each input feature x is expanded into sin(t * x) and cos(t * x) for the
    fixed frequencies t = 1/2, 2/2, ..., num_fourier/2.

    :param in_dim: number of input features
    :param out_dim: number of output features
    :param kwargs: optional 'num_fourier' (default 4), the number of frequencies
    '''

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        self.F = kwargs.get('num_fourier', 4)
        expanded_dim = 2 * in_dim * self.F

        self.bn1 = nn.BatchNorm1d(in_dim)
        # Frequencies are stored as a parameter but excluded from training.
        half_steps = [t / 2 for t in range(1, self.F + 1)]
        self.t_values_expanded = nn.Parameter(
            torch.tensor(half_steps, dtype=torch.float32).view(1, 1, -1),
            requires_grad=False,
        )
        self.bn2 = nn.BatchNorm1d(expanded_dim)
        self.fc = Linear(expanded_dim, out_dim)

    def forward(self, x):
        '''
        :param x: [N, in_dim]
        :return: [N, out_dim]
        '''
        batch_size, _ = x.shape
        normed = self.bn1(x).unsqueeze(2)  # [N, in_dim, 1]
        phase = self.t_values_expanded * normed  # [N, in_dim, F]

        # Sine terms first, then cosine terms, along the last axis.
        waves = torch.cat([torch.sin(phase), torch.cos(phase)], dim=2)
        flat = waves.view(batch_size, -1)  # [N, 2 * in_dim * F]
        return self.fc(self.bn2(flat))


class PolyLayer(Module):
    '''
    Polynomial feature layer: BatchNorm1d -> powers x^1 .. x^Poly_size ->
    BatchNorm1d -> Linear.

    :param in_dim: number of input features
    :param out_dim: number of output features
    :param kwargs: optional 'Poly_size' (default 4), the highest power, and
        'repeat' (default 1), how many times each power column is duplicated
    '''

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.Poly_size = kwargs.get('Poly_size', 4)
        self.repeat = kwargs.get('repeat', 1)

        expanded_dim = self.in_dim * self.Poly_size * self.repeat
        self.bn1 = nn.BatchNorm1d(self.in_dim)
        self.bn2 = nn.BatchNorm1d(expanded_dim)
        self.fc = Linear(expanded_dim, self.out_dim)

    def forward(self, x):
        '''
        :param x: [N, in_dim]
        :return: [N, out_dim]
        '''
        batch_size, _ = x.shape
        normed = self.bn1(x)

        powers = [normed ** degree for degree in range(1, self.Poly_size + 1)]
        stacked = torch.stack(powers, dim=2)  # [N, in_dim, Poly_size]
        # Duplicate every power column `repeat` times along a new trailing axis.
        tiled = stacked.unsqueeze(-1).expand(batch_size, self.in_dim, self.Poly_size, self.repeat)
        flat = tiled.reshape(batch_size, -1)  # [N, in_dim * Poly_size * repeat]
        return self.fc(self.bn2(flat))


class KAAN(Module):
    '''
    Network built by chaining layers of a single type.

    :param model_shape: sequence of layer widths; each pair of consecutive
        entries gives the (in_dim, out_dim) of one layer
    :param module: callable invoked as module(in_dim, out_dim, **kwargs) to
        build every layer
    :param kwargs: forwarded unchanged to each module call
    '''

    def __init__(self, model_shape, module, **kwargs):
        super().__init__()
        self.width = model_shape
        self.module = module

        # Pair each width with its successor: (w0, w1), (w1, w2), ...
        dim_pairs = zip(self.width[:-1], self.width[1:])
        self.layers = nn.ModuleList(
            [self.module(fan_in, fan_out, **kwargs) for fan_in, fan_out in dim_pairs]
        )

    def forward(self, x):
        '''
        :param x: input accepted by the first layer
        :return: output of the last layer (x itself if there are no layers)
        '''
        hidden = x
        for stage in self.layers:
            hidden = stage(hidden)
        return hidden


class MLP_tanh(Module):
    '''
    Multi-layer perceptron with Tanh activations.

    Structure: BatchNorm1d on the input, then one Linear -> BatchNorm1d -> Tanh
    block per hidden width, then a final Linear.

    :param model_shape: sequence of widths [input, hidden..., output]
    '''

    def __init__(self, model_shape):
        super().__init__()
        self.bn = nn.BatchNorm1d(model_shape[0])

        # Every consecutive pair except the last one becomes a hidden block.
        hidden_pairs = zip(model_shape[:-2], model_shape[1:-1])
        self.layers = nn.ModuleList([
            nn.Sequential(
                Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                Tanh(),
            )
            for fan_in, fan_out in hidden_pairs
        ])
        self.fc = Linear(model_shape[-2], model_shape[-1])

    def forward(self, x):
        '''
        :param x: [N, model_shape[0]]
        :return: [N, model_shape[-1]]
        '''
        hidden = self.bn(x)
        for block in self.layers:
            hidden = block(hidden)
        return self.fc(hidden)


class MLP_silu(Module):
    '''
    Multi-layer perceptron with SiLU activations.

    Structure: BatchNorm1d on the input, then one Linear -> BatchNorm1d -> SiLU
    block per hidden width, then a final Linear.

    :param model_shape: sequence of widths [input, hidden..., output]
    '''

    def __init__(self, model_shape):
        super().__init__()
        self.bn = nn.BatchNorm1d(model_shape[0])

        # Every consecutive pair except the last one becomes a hidden block.
        hidden_pairs = zip(model_shape[:-2], model_shape[1:-1])
        self.layers = nn.ModuleList([
            nn.Sequential(
                Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                SiLU(),
            )
            for fan_in, fan_out in hidden_pairs
        ])
        self.fc = Linear(model_shape[-2], model_shape[-1])

    def forward(self, x):
        '''
        :param x: [N, model_shape[0]]
        :return: [N, model_shape[-1]]
        '''
        hidden = self.bn(x)
        for block in self.layers:
            hidden = block(hidden)
        return self.fc(hidden)


class MLP_relu(Module):
    '''
    Multi-layer perceptron with ReLU activations.

    Structure: BatchNorm1d on the input, then one Linear -> BatchNorm1d -> ReLU
    block per hidden width, then a final Linear.

    :param model_shape: sequence of widths [input, hidden..., output]
    '''

    def __init__(self, model_shape):
        super().__init__()
        self.bn = nn.BatchNorm1d(model_shape[0])

        # Every consecutive pair except the last one becomes a hidden block.
        hidden_pairs = zip(model_shape[:-2], model_shape[1:-1])
        self.layers = nn.ModuleList([
            nn.Sequential(
                Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                ReLU(),
            )
            for fan_in, fan_out in hidden_pairs
        ])
        self.fc = Linear(model_shape[-2], model_shape[-1])

    def forward(self, x):
        '''
        :param x: [N, model_shape[0]]
        :return: [N, model_shape[-1]]
        '''
        hidden = self.bn(x)
        for block in self.layers:
            hidden = block(hidden)
        return self.fc(hidden)
