import torch
import torch.nn as nn
import torch.nn.functional as F
# reference: https://github.com/aikunyi/FourierGNN
class FGN(nn.Module):
    """Fourier-domain regressor adapted from FourierGNN.

    The input window is flattened over (variable, time), moved to the
    frequency domain with a real FFT, lifted to ``embed_size`` channels,
    passed through three complex-valued layers with residual connections,
    transformed back with the inverse FFT and finally reduced to one
    non-negative value per sample by two small MLP heads.

    Args:
        pre_length: Output width of the per-variable MLP ``fc``. The final
            head squeezes this axis, so it is expected to be 1.
        embed_size: Number of embedding channels attached to every
            frequency bin.
        feature_size: Number of input variables ``N``; it is the input width
            of the final ``fc1`` head.
        seq_length: Number of time steps ``L`` of the input window; it is
            the row count of the temporal projection ``embeddings_10``.
        hidden_size: Width of the second hidden layer of ``fc``.
        hard_thresholding_fraction: Stored on the module; not read by the
            code in this class.
        hidden_size_factor: Width multiplier for the complex layer weights.
            ``fourierGC`` contracts with the einsum pattern ``'bli,ii->bli'``,
            which needs square weight matrices, so only 1 works in practice.
        sparsity_threshold: ``lambd`` of the soft-shrinkage applied to every
            complex layer output.
    """

    def __init__(self, pre_length: int, embed_size: int,
                 feature_size: int, seq_length: int, hidden_size: int,
                 hard_thresholding_fraction: float = 1, hidden_size_factor: int = 1,
                 sparsity_threshold: float = 0.01) -> None:
        super().__init__()
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.number_frequency = 1
        self.pre_length = pre_length
        self.feature_size = feature_size
        self.seq_length = seq_length
        self.frequency_size = self.embed_size // self.number_frequency
        self.hidden_size_factor = hidden_size_factor
        self.sparsity_threshold = sparsity_threshold
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.scale = 1

        # Token embedding: one learnable row broadcast over every frequency bin.
        self.embeddings = nn.Parameter(torch.randn(1, self.embed_size))

        # Complex layer parameters. Index 0 of the leading axis holds the real
        # part, index 1 the imaginary part. Creation order is kept stable so a
        # fixed RNG seed yields the same initialisation as before.
        wide = self.frequency_size * self.hidden_size_factor
        self.w1 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size, wide))
        self.b1 = nn.Parameter(self.scale * torch.randn(2, wide))
        self.w2 = nn.Parameter(self.scale * torch.randn(2, wide, self.frequency_size))
        self.b2 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size))
        self.w3 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size, wide))
        self.b3 = nn.Parameter(self.scale * torch.randn(2, wide))

        # Temporal projection: compresses the L time steps down to 8 values.
        self.embeddings_10 = nn.Parameter(torch.randn(self.seq_length, 8))
        self.fc = nn.Sequential(
            nn.Linear(self.embed_size * 8, 64),
            nn.LeakyReLU(),
            nn.Linear(64, self.hidden_size),
            nn.LeakyReLU(),
            nn.Linear(self.hidden_size, self.pre_length)
        )
        # Defined for completeness; forward() does not apply it.
        self.dp = nn.Dropout(0.5)
        # Mixes the per-variable predictions into one value. The input width
        # follows feature_size instead of a hard-coded 14 so that the model
        # works for any number of input variables.
        self.fc1 = nn.Sequential(nn.Linear(self.feature_size, 1), nn.ReLU())

    def tokenEmb(self, x: torch.Tensor) -> torch.Tensor:
        """Lift every entry of ``x`` to ``embed_size`` channels.

        Args:
            x: Tensor of shape ``(B, M)``.

        Returns:
            Tensor of shape ``(B, M, embed_size)``: each scalar multiplied by
            the learnable embedding row.
        """
        x = x.unsqueeze(2)  # (B, M, 1)
        return x * self.embeddings  # broadcast against (1, embed_size)

    @staticmethod
    def _complex_layer(real, imag, weight, bias):
        """Apply one complex affine map followed by ReLU on each part.

        ``'bli,ii->bli'`` repeats the subscript ``i`` on the weight operand,
        so einsum uses only the diagonal of each (square) weight matrix: the
        operation is a per-channel complex scaling, not a dense matmul.
        """
        pattern = 'bli,ii->bli'
        out_real = F.relu(
            torch.einsum(pattern, real, weight[0])
            - torch.einsum(pattern, imag, weight[1])
            + bias[0]
        )
        out_imag = F.relu(
            torch.einsum(pattern, imag, weight[0])
            + torch.einsum(pattern, real, weight[1])
            + bias[1]
        )
        return out_real, out_imag

    def _shrink(self, real, imag):
        """Stack real/imag parts on a trailing axis and soft-shrink them."""
        stacked = torch.stack([real, imag], dim=-1)
        return F.softshrink(stacked, lambd=self.sparsity_threshold)

    # FourierGNN
    def fourierGC(self, x: torch.Tensor, B: int, N: int, L: int) -> torch.Tensor:
        """Run the three-layer complex network on the spectrum.

        Each layer consumes the un-shrunk ReLU output of the previous one;
        the soft-shrunk outputs of all three layers are summed.

        Args:
            x: Complex tensor of shape ``(B, F, frequency_size)``.
            B, N, L: Batch size, variable count and window length. They are
                accepted for interface compatibility and not used here.

        Returns:
            Complex tensor with the same shape as ``x``.
        """
        # 1 layer
        o1_real, o1_imag = self._complex_layer(x.real, x.imag, self.w1, self.b1)
        y = self._shrink(o1_real, o1_imag)  # (B, F, C, 2)

        # 2 layer
        o2_real, o2_imag = self._complex_layer(o1_real, o1_imag, self.w2, self.b2)
        x = self._shrink(o2_real, o2_imag) + y  # (B, F, C, 2)

        # 3 layer
        o3_real, o3_imag = self._complex_layer(o2_real, o2_imag, self.w3, self.b3)
        z = self._shrink(o3_real, o3_imag) + x  # (B, F, C, 2)

        # Fold the trailing (real, imag) axis back into a complex dtype.
        return torch.view_as_complex(z)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict one value per sample.

        Args:
            x: Real tensor of shape ``(B, L, N)`` with ``L == seq_length``
                and ``N == feature_size``.

        Returns:
            Tensor of shape ``(B, 1, 1)`` (for ``pre_length == 1``); values
            are non-negative because the last head ends in a ReLU.
        """
        x = x.permute(0, 2, 1).contiguous()  # (B, N, L)
        B, N, L = x.shape
        x = x.reshape(B, -1)  # (B, N*L): all variables share one long sequence

        n_freq = (N * L) // 2 + 1  # number of bins produced by rfft
        x = torch.fft.rfft(x, dim=1, norm='ortho')  # (B, n_freq), complex
        x = self.tokenEmb(x)  # (B, n_freq, embed_size)
        x = x.reshape(B, n_freq, self.frequency_size)

        bias = x  # residual around the whole complex network

        # FourierGNN
        x = self.fourierGC(x, B, N, L)

        x = x + bias

        x = x.reshape(B, n_freq, self.embed_size)

        x = torch.fft.irfft(x, n=N * L, dim=1, norm="ortho")  # (B, N*L, embed_size), real
        x = x.reshape(B, N, L, self.embed_size)  # (B, N, L, embed_size)
        x = x.permute(0, 1, 3, 2)  # (B, N, embed_size, L)

        # projection
        x = torch.matmul(x, self.embeddings_10)  # (B, N, embed_size, 8)
        x = x.reshape(B, N, -1)  # (B, N, embed_size * 8)
        x = self.fc(x)  # (B, N, pre_length)
        # squeeze(-1) only drops the last axis when pre_length == 1.
        x = self.fc1(x.squeeze(-1))  # (B, 1)
        return x.unsqueeze(-1)

# model = FGN(pre_length=1,
#             embed_size=128,
#             feature_size=14,
#             seq_length=180,
#             hidden_size=128)
# from torchsummary import summary
# summary(model)
