import torch
import torch.nn as nn
import torch.nn.functional as F

class FGN(nn.Module):
    """Fourier graph network forecaster: EV-FGN block followed by an MLP head.

    Shape notation: ``B`` batch, ``N = feature_size`` variables,
    ``L = seq_length`` time steps, ``D = embed_size``.

    The input is embedded per (variable, time step). It is then transformed
    with a 2-D real FFT over the variable and time axes and passed through
    three complex-valued MLP layers (``fourierGC``) with residual
    connections. An inverse FFT follows, and an MLP head produces
    ``pre_length`` outputs per variable.

    Args:
        pre_length: Number of outputs produced per variable.
        embed_size: Embedding dimension ``D`` of each scalar observation.
        feature_size: Number of variables ``N`` in the input.
        seq_length: Number of input time steps ``L``.
        hidden_size: Width of the last hidden layer of the output head.
        hard_thresholding_fraction: Fraction of the ``L // 2 + 1`` frequency
            modes (time axis) transformed by ``fourierGC``. The remaining
            modes pass through ``forward`` unchanged via the residual.
        hidden_size_factor: Expansion factor of the complex layers. Only ``1``
            is supported, because ``fourierGC`` sums the outputs of all three
            layers, so they must share the same width.
        sparsity_threshold: ``lambd`` of the soft-shrinkage applied after
            each complex layer.
        device: Device the module is moved to after construction. ``None``
            (default) selects ``'cuda:0'`` when CUDA is available and the CPU
            otherwise.

    Raises:
        ValueError: If ``hidden_size_factor`` is not 1.
    """

    def __init__(self, pre_length, embed_size,
                 feature_size, seq_length, hidden_size, hard_thresholding_fraction=1,
                 hidden_size_factor=1, sparsity_threshold=0.01, device=None):
        super().__init__()
        if hidden_size_factor != 1:
            # Layer outputs are summed in fourierGC and layer 3 must fit a
            # buffer shaped like the FFT input, so widths cannot be expanded.
            raise ValueError(
                f"hidden_size_factor must be 1, got {hidden_size_factor}")
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.number_frequency = 1
        self.pre_length = pre_length
        self.feature_size = feature_size
        self.seq_length = seq_length
        self.frequency_size = self.embed_size // self.number_frequency
        self.hidden_size_factor = hidden_size_factor
        self.sparsity_threshold = sparsity_threshold
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.scale = 0.02
        # Per-variable (N, D) and per-time-step (L, D) embeddings used by tokenEmb.
        self.embeddings_1 = nn.Parameter(torch.randn(self.feature_size, self.embed_size))
        self.embeddings_2 = nn.Parameter(torch.randn(self.seq_length, self.embed_size))
        # Complex layer weights/biases: index 0 is the real part, index 1 the imaginary part.
        self.w1 = nn.Parameter(
            self.scale * torch.randn(2, self.number_frequency, self.frequency_size,
                                     self.frequency_size * self.hidden_size_factor))
        self.b1 = nn.Parameter(
            self.scale * torch.randn(2, self.number_frequency,
                                     self.frequency_size * self.hidden_size_factor))
        self.w2 = nn.Parameter(
            self.scale * torch.randn(2, self.number_frequency,
                                     self.frequency_size * self.hidden_size_factor,
                                     self.frequency_size))
        self.b2 = nn.Parameter(
            self.scale * torch.randn(2, self.number_frequency, self.frequency_size))
        self.w3 = nn.Parameter(
            self.scale * torch.randn(2, self.number_frequency, self.frequency_size,
                                     self.frequency_size * self.hidden_size_factor))
        self.b3 = nn.Parameter(
            self.scale * torch.randn(2, self.number_frequency,
                                     self.frequency_size * self.hidden_size_factor))
        # Projects the time axis (L) down to 8 learned components before the head.
        self.embeddings_10 = nn.Parameter(torch.randn(self.seq_length, 8))
        self.fc = nn.Sequential(
            nn.Linear(self.embed_size * 8, 256),
            nn.LeakyReLU(),
            nn.Linear(256, self.hidden_size),
            nn.LeakyReLU(),
            nn.Linear(self.hidden_size, self.pre_length),
        )
        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.to(device)

    def tokenEmb(self, x):
        """Embed every scalar of ``x`` into an ``embed_size`` vector.

        Args:
            x: Tensor of shape ``(B, N, L)``.

        Returns:
            Tensor of shape ``(B, N, L, D)``. Each value is multiplied
            element-wise by its variable embedding and its time-step
            embedding.
        """
        variable_emb = self.embeddings_1.unsqueeze(1)  # (N, 1, D)
        time_emb = self.embeddings_2.unsqueeze(0)      # (1, L, D)
        return x.unsqueeze(3) * variable_emb * time_emb

    @staticmethod
    def _complex_layer(real, imag, weight, bias, kept_modes):
        """Apply a complex linear layer and ReLU to the first ``kept_modes`` modes.

        For m = a + bi and n = c + di, m * n = (ac - bd) + (ad + bc)i. Here
        (a, b) are the input's real/imaginary parts and (c, d) are
        ``weight[0]`` / ``weight[1]``. ReLU is applied to the real and
        imaginary parts separately.

        Args:
            real: Real part, shape ``(B, N, T, number_frequency, in_features)``.
            imag: Imaginary part, same shape as ``real``.
            weight: Shape ``(2, number_frequency, in_features, out_features)``.
            bias: Shape ``(2, number_frequency, out_features)``.
            kept_modes: Number of leading modes along dim 2 to transform.

        Returns:
            ``(out_real, out_imag)``, each of shape
            ``(B, N, T, number_frequency, out_features)``, with the dtype and
            device of ``real``. Entries at dim-2 index ``>= kept_modes`` are 0.
        """
        out_shape = real.shape[:-1] + (weight.shape[-1],)
        out_real = real.new_zeros(out_shape)
        out_imag = real.new_zeros(out_shape)
        re_in = real[:, :, :kept_modes]
        im_in = imag[:, :, :kept_modes]
        out_real[:, :, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', re_in, weight[0])
            - torch.einsum('...bi,bio->...bo', im_in, weight[1])
            + bias[0]
        )
        out_imag[:, :, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', im_in, weight[0])
            + torch.einsum('...bi,bio->...bo', re_in, weight[1])
            + bias[1]
        )
        return out_real, out_imag

    def fourierGC(self, x, B, N, L):
        """Apply the three complex-valued MLP layers of EV-FGN.

        Args:
            x: Complex tensor of shape
                ``(B, N, L // 2 + 1, number_frequency, frequency_size)``.
            B: Batch size; kept for interface compatibility (shapes are read
                from ``x``).
            N: Number of variables; kept for interface compatibility.
            L: Number of time steps before the FFT. Only the first
                ``int((L // 2 + 1) * hard_thresholding_fraction)`` modes
                along dim 2 are transformed; the rest of the output is zero.

        Returns:
            Complex tensor with the same shape as ``x``: the sum of the
            soft-shrunk outputs of the three layers.
        """
        total_modes = L // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)

        # layer 1
        o1_real, o1_imag = self._complex_layer(x.real, x.imag, self.w1, self.b1, kept_modes)
        y = F.softshrink(torch.stack([o1_real, o1_imag], dim=-1), lambd=self.sparsity_threshold)

        # layer 2: consumes the un-shrunk layer-1 output; the shrunk version
        # is used only in the residual sum.
        o2_real, o2_imag = self._complex_layer(o1_real, o1_imag, self.w2, self.b2, kept_modes)
        x = F.softshrink(torch.stack([o2_real, o2_imag], dim=-1), lambd=self.sparsity_threshold)
        x = x + y

        # layer 3
        o3_real, o3_imag = self._complex_layer(o2_real, o2_imag, self.w3, self.b3, kept_modes)
        z = F.softshrink(torch.stack([o3_real, o3_imag], dim=-1), lambd=self.sparsity_threshold)
        z = z + x
        return torch.view_as_complex(z)

    def forward(self, x):
        """Produce ``pre_length`` outputs per variable.

        Args:
            x: Tensor of shape ``(B, seq_length, feature_size)``.

        Returns:
            Tensor of shape ``(B, feature_size, pre_length)``.
        """
        # (B, L, N) -> (B, N, L)
        x = x.permute(0, 2, 1).contiguous()
        # embedding (B, N, L) -> (B, N, L, D)
        x = self.tokenEmb(x)
        B, N, L, D = x.shape

        # 2-D real FFT over the variable and time axes -> complex (B, N, L//2+1, D)
        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        x = x.reshape(B, N, L // 2 + 1, self.number_frequency, self.frequency_size)
        residual = x

        # EV-FGN
        x = self.fourierGC(x, B, N, L) + residual

        # inverse FFT back to (B, N, L, D)
        x = x.reshape(B, N, L // 2 + 1, D)
        x = torch.fft.irfft2(x, s=(N, L), dim=(1, 2), norm="ortho")

        # (B, N, D, L) @ (L, 8) -> (B, N, D, 8) -> (B, N, D * 8)
        x = x.permute(0, 1, 3, 2)
        x = torch.matmul(x, self.embeddings_10)
        x = x.reshape(B, N, -1)
        return self.fc(x)

