import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from layers.Embed import TextEmbedding_wo_pos, PositionalEmbedding
from layers.FreqAttention_Family import FullFreqAttention, FreqAttentionLayer
from utils.masking import TriangularCausalMask, ProbMask, ConstantMask

from layers.Retrieval import FreqRetrievalTool, FreqRetrievalToolCI, TextRetrievalTool
class my_Layernorm(nn.Module):
    """Magnitude-statistics normalization for real or complex tensors.

    The mean and (biased) variance are computed from ``|input|`` along the
    last dimension. The mean of the magnitudes is then subtracted from
    ``input`` itself, and the result is divided by ``sqrt(var + eps)``. For
    complex input, subtracting the real-valued mean shifts only the real
    part.
    """

    def __init__(self, eps=1e-5, elementwise_affine=True):
        """
        Args:
            eps: Added to the variance before the square root, for numerical
                stability.
            elementwise_affine: Accepted for API compatibility but ignored:
                ``weight`` and ``bias`` are always registered as ``None``, so
                no learnable affine transform is applied.
        """
        super(my_Layernorm, self).__init__()
        self.eps = eps
        # Keep the conventional parameter slots, but leave them empty.
        for slot in ('weight', 'bias'):
            self.register_parameter(slot, None)

    def forward(self, input):
        """Normalize ``input`` using statistics of its magnitude over the last dim."""
        mag = input.abs()
        center = mag.mean(dim=-1, keepdim=True)
        spread = (mag.var(dim=-1, keepdim=True, unbiased=False) + self.eps).sqrt()
        return (input - center) / spread



class MLP(nn.Module):
    """Feed-forward stack of ``nn.Linear`` layers.

    ReLU followed by dropout is applied after every layer except the last,
    so the final layer's output is a raw linear projection. The previous
    version additionally ran ReLU and dropout unconditionally after every
    layer. That applied dropout twice to hidden activations and
    ReLU+dropout to the output, contradicting the "all but last" guard.

    Args:
        layer_sizes: Sequence of widths ``[in, hidden..., out]``; one
            ``nn.Linear`` is created for each consecutive pair.
        dropout_rate: Drop probability used after each hidden activation.
    """

    def __init__(self, layer_sizes, dropout_rate=0.5):
        super().__init__()
        self.layers = nn.ModuleList()
        self.dropout = nn.Dropout(dropout_rate)
        # Same pairs, same creation order as before -> same parameter init.
        for in_size, out_size in zip(layer_sizes[:-1], layer_sizes[1:]):
            self.layers.append(nn.Linear(in_size, out_size))

    def forward(self, x):
        """Apply the layers in order; hidden outputs get ReLU + dropout once."""
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:
                x = self.dropout(F.relu(x))
        return x

class TextEncoder(nn.Module):
    """Map text-token embeddings to complex per-frequency features.

    Each token vector goes through ``text_embed_layer`` and
    ``text_proj_layer`` (down to ``text_emb`` features). It is then
    multiplied by two learned matrices, ``r1`` for the real part and ``i1``
    for the imaginary part. Each matrix expands the token to
    ``dominance_freq * embed_size`` values, which are reshaped to
    ``(dominance_freq, embed_size)``. ``softshrink`` zeroes small values
    before the two parts are combined into one complex tensor.
    """

    def __init__(self, configs):
        super(TextEncoder, self).__init__()
        self.configs = configs
        self.text_emb = 8  # configs.text_emb

        self.text_proj_layer = MLP(layer_sizes=[configs.llm_emb_size, self.text_emb], dropout_rate=0.1)
        self.text_embed_layer = TextEmbedding_wo_pos(configs.llm_emb_size, configs.llm_emb_size, configs.embed, configs.freq, configs.dropout)

        self.embed_size = 16  # configs.mm_emb_size
        self.hidden_size = 16  # configs.mm_hidden_size (not used in forward)
        self.pred_len = configs.pred_len
        self.feature_size = 1  # configs.n_ts_features
        self.seq_len = configs.seq_len
        self.channel_independence = configs.channel_independence
        self.sparsity_threshold = 0.01
        self.scale = 1
        self.T_f = self.pred_len // 2 + 1
        self.H_f = self.seq_len // 2 + 1

        self.dominance_freq = int(self.H_f)
        self.dominance_freq_pred = int(self.T_f)

        proj_shape = (self.feature_size, self.text_emb, self.dominance_freq * self.embed_size)
        bias_shape = (self.feature_size, self.dominance_freq, self.embed_size)
        self.r1 = nn.Parameter(self.scale * torch.randn(*proj_shape))
        self.i1 = nn.Parameter(self.scale * torch.randn(*proj_shape))
        # rb1/ib1 are not used by forward; kept so existing checkpoints still load.
        self.rb1 = nn.Parameter(self.scale * torch.randn(*bias_shape))
        self.ib1 = nn.Parameter(self.scale * torch.randn(*bias_shape))

    def forward(self, text_embeddings, x_mark_enc):
        """Encode text tokens into a complex tensor.

        Args:
            text_embeddings: Batch-first token embeddings. Presumably
                ``(B, N_t, llm_emb_size)``; TODO confirm against
                ``TextEmbedding_wo_pos``.
            x_mark_enc: Time marks forwarded to ``text_embed_layer``.

        Returns:
            Complex tensor of shape
            ``(B, 1, N_t, dominance_freq, embed_size)``. ``N_t`` is read
            from the input rather than assumed to equal ``seq_len``.
        """
        batch = text_embeddings.shape[0]
        n_channels = 1
        text_embeddings = self.text_embed_layer(text_embeddings, x_mark_enc)
        text_embeddings = self.text_proj_layer(text_embeddings)
        # B x N_c x N_t x text_emb; expand is a view, not a copy.
        text_embeddings = text_embeddings.unsqueeze(1).expand(-1, n_channels, -1, -1)
        n_tokens = text_embeddings.shape[2]
        out_shape = (batch, n_channels, n_tokens, self.dominance_freq, self.embed_size)

        # matmul broadcasts r1/i1 over the batch dims, so no per-sample copy.
        text_real = torch.matmul(text_embeddings, self.r1).reshape(out_shape)
        text_imag = torch.matmul(text_embeddings, self.i1).reshape(out_shape)

        text = torch.stack([text_real, text_imag], dim=-1)
        text = F.softshrink(text, lambd=self.sparsity_threshold)
        return torch.view_as_complex(text)

class FreqMLP(nn.Module):
    """Residual complex-valued linear block over the last dimension.

    Two real ``nn.Linear`` layers act as the real and imaginary weights of a
    complex linear map. Each output part is passed through ReLU, added to
    the matching input part (so ``d_in`` must equal ``d_out``), and
    instance-normalized. ``softshrink`` then zeroes small values before the
    result is repacked as a complex tensor.
    """

    def __init__(self, d_in, d_out):
        super(FreqMLP, self).__init__()
        # Registration order is kept: it fixes state_dict keys and init order.
        self.freq_upsampler_real = nn.Linear(d_in, d_out)  # real weights of the complex map
        self.freq_upsampler_imag = nn.Linear(d_in, d_out)  # imaginary weights of the complex map
        self.sparsity_threshold = 0.01
        self.real_norm = nn.InstanceNorm2d(1)
        self.imag_norm = nn.InstanceNorm2d(1)

    def forward(self, x):
        """Apply the residual complex linear map to complex tensor ``x``."""
        re_in, im_in = x.real, x.imag
        lin_r, lin_i = self.freq_upsampler_real, self.freq_upsampler_imag

        re_out = F.relu(lin_r(re_in) - lin_i(im_in)) + re_in
        im_out = F.relu(lin_r(im_in) + lin_i(re_in)) + im_in

        re_out = self.real_norm(re_out)
        im_out = self.imag_norm(im_out)

        packed = F.softshrink(torch.stack([re_out, im_out], dim=-1), lambd=self.sparsity_threshold)
        return torch.view_as_complex(packed)

        
class Model(nn.Module):
    """Text-conditioned frequency-domain forecaster with a retrieval branch.

    Paper link: https://arxiv.org/pdf/2311.06184.pdf

    ``forecast`` combines two predictions:

    * A retrieval prediction, built from tensors precomputed once by
      ``prepare_dataset``, with one ``nn.Linear`` per entry of
      ``period_num``.
    * A spectral prediction. The instance-normalized input is rFFT'd over
      time and every frequency bin is lifted to ``embed_size`` complex
      features. These are fused with text features through
      ``freq_attn_layer`` and mapped from ``H_f = seq_len // 2 + 1`` to
      ``T_f = pred_len // 2 + 1`` bins. Finally they are projected back to
      one value per bin and inverse-rFFT'd.

    ``linear_pred`` mixes the two along the time axis, then the input
    statistics are restored. Only ``long_term_forecast`` and
    ``short_term_forecast`` are implemented.
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        # NOTE(review): always targets CUDA; retrieval tensors are moved here.
        self.device = torch.device(f'cuda:{configs.gpu}')
        self.task_name = configs.task_name
        # The task-dependent horizon is no longer overwritten by
        # configs.pred_len afterwards; forecasting tasks are unaffected.
        if self.task_name in ('classification', 'anomaly_detection', 'imputation'):
            self.pred_len = configs.seq_len
        else:
            self.pred_len = configs.pred_len
        self.embed_size = 16  # configs.mm_emb_size
        self.hidden_size = 16  # configs.mm_hidden_size
        self.feature_size = 1  # configs.n_ts_features
        self.seq_len = configs.seq_len
        self.channel_independence = configs.channel_independence
        self.sparsity_threshold = 0.01
        self.scale = 1
        self.embeddings = nn.Parameter(torch.randn(1, self.embed_size))  # used only by tokenEmb

        self.configs = configs
        self.fc = nn.Linear(self.embed_size, 1)  # not used by forecast

        self.text_emb = 12  # configs.text_emb
        llm_dim = 4096
        # text_proj_layer / text_embed_layer are not used by forecast; kept for checkpoints.
        self.text_proj_layer = MLP(layer_sizes=[llm_dim, int(llm_dim / 8), self.text_emb], dropout_rate=0.3)
        self.text_embed_layer = TextEmbedding_wo_pos(768, 768, configs.embed, configs.freq, configs.dropout)

        self.T_f = self.pred_len // 2 + 1  # rFFT bins of the forecast horizon
        self.H_f = self.seq_len // 2 + 1  # rFFT bins of the input window
        self.dominance_freq = int(self.H_f)
        self.dominance_freq_pred = int(self.T_f)

        # Complex linear map from H_f input bins to T_f output bins.
        self.freq_upsampler_real = nn.Linear(self.H_f, self.T_f)
        self.freq_upsampler_imag = nn.Linear(self.H_f, self.T_f)

        self.real_norm = nn.InstanceNorm2d(1)
        self.imag_norm = nn.InstanceNorm2d(1)

        scale = 0.1
        # Lift each scalar bin to embed_size features (enc) and back to 1 (dec).
        self.embedding_real = nn.Parameter(scale * torch.randn(1, self.embed_size))
        self.embedding_imag = nn.Parameter(scale * torch.randn(1, self.embed_size))
        self.embedding_real_dec = nn.Parameter(scale * torch.randn(self.embed_size, 1))
        self.embedding_imag_dec = nn.Parameter(scale * torch.randn(self.embed_size, 1))

        self.freq_attn_layer = FreqAttentionLayer(
            FullFreqAttention(False, 5, attention_dropout=configs.dropout, output_attention=False),
            d_model=self.embed_size, n_heads=1)

        self.position_emb = PositionalEmbedding(self.embed_size)
        self.dropout = nn.Dropout(0.1)

        self.channels = configs.enc_in

        self.linear_x = nn.Linear(self.seq_len, self.pred_len)  # not used by forecast

        self.n_period = configs.n_period
        self.topm = configs.topm

        # Both settings of configs.channel_independence previously built the
        # same TextRetrievalTool. NOTE(review): if a channel-independent tool
        # (e.g. FreqRetrievalToolCI) was intended for one setting, branch here.
        self.rt = TextRetrievalTool(
            seq_len=self.seq_len,
            pred_len=self.pred_len,
            channels=self.channels,
            n_period=self.n_period,
            topm=self.topm,
        )
        self.period_num = self.rt.period_num[-1 * self.n_period:]
        # One projection per period g: (pred_len // g) subsampled steps -> pred_len.
        self.retrieval_pred = nn.ModuleList([
            nn.Linear(self.pred_len // g, self.pred_len)
            for g in self.period_num
        ])
        # Mixes [spectral ; retrieval] (2 * pred_len steps) into pred_len steps.
        self.linear_pred = nn.Linear(2 * self.pred_len, self.pred_len)

    def prepare_dataset(self, train_data, valid_data, test_data):
        """Precompute retrieval results for all splits and drop the retrieval tool.

        Fits ``self.rt`` on ``train_data``, runs ``retrieve_all`` on each
        split, and stores the detached results in
        ``self.retrieval_dict['train' | 'valid' | 'test']`` on ``self.device``.
        ``self.rt`` is deleted afterwards, so this method can only be called
        once per instance.
        """
        print('start prepare dataset')
        print('original train_data shape: ', len(train_data))
        self.rt.prepare_dataset(train_data)
        print('len of train_data: ', len(train_data))
        self.retrieval_dict = {}

        print('Doing Train Retrieval')
        train_rt = self.rt.retrieve_all(train_data, train=True, device=self.device)

        print('Doing Valid Retrieval')
        valid_rt = self.rt.retrieve_all(valid_data, train=False, device=self.device)

        print('Doing Test Retrieval')
        test_rt = self.rt.retrieve_all(test_data, train=False, device=self.device)

        # Free the retrieval index; only the precomputed tensors are kept.
        del self.rt
        torch.cuda.empty_cache()

        self.retrieval_dict['train'] = train_rt.detach().to(self.device)
        self.retrieval_dict['valid'] = valid_rt.detach().to(self.device)
        self.retrieval_dict['test'] = test_rt.detach().to(self.device)

    # dimension extension
    def tokenEmb(self, x):
        """(B, T, C) -> (B, C, T, embed_size): scale each scalar by ``embeddings``."""
        x = x.permute(0, 2, 1)
        x = x.unsqueeze(3)
        y = self.embeddings
        return x * y

    def realEmb(self, x):
        """(B, T, C) -> (B, C, T, embed_size): scale each scalar by ``embedding_real``."""
        x = x.permute(0, 2, 1)
        x = x.unsqueeze(3)
        y = self.embedding_real
        return x * y

    def imagEmb(self, x):
        """(B, T, C) -> (B, C, T, embed_size): scale each scalar by ``embedding_imag``."""
        x = x.permute(0, 2, 1)
        x = x.unsqueeze(3)
        y = self.embedding_imag
        return x * y

    def realEmb_dec(self, x):
        """Contract the last (embed_size) dim with ``embedding_real_dec``; trailing dim becomes 1."""
        y = self.embedding_real_dec.unsqueeze(0).unsqueeze(0)
        return torch.matmul(x, y)

    def imagEmb_dec(self, x):
        """Contract the last (embed_size) dim with ``embedding_imag_dec``; trailing dim becomes 1."""
        y = self.embedding_imag_dec.unsqueeze(0).unsqueeze(0)
        return torch.matmul(x, y)

    def realEmb_ffd(self, x):
        """Like ``realEmb_dec`` but with ``embedding_real_ffd``.

        NOTE(review): ``embedding_real_ffd`` is never created in ``__init__``,
        so this raises AttributeError unless the attribute is set externally.
        """
        y = self.embedding_real_ffd.unsqueeze(0).unsqueeze(0)
        return torch.matmul(x, y)

    def imagEmb_ffd(self, x):
        """Like ``imagEmb_dec`` but with ``embedding_imag_ffd``.

        NOTE(review): ``embedding_imag_ffd`` is never created in ``__init__``,
        so this raises AttributeError unless the attribute is set externally.
        """
        y = self.embedding_imag_ffd.unsqueeze(0).unsqueeze(0)
        return torch.matmul(x, y)

    def FreqMLP(self, x, real_layer, imag_layer):
        """Complex linear map of complex ``x`` built from two real callables, then softshrink."""
        x_real = real_layer(x.real) - imag_layer(x.imag)
        x_imag = imag_layer(x.real) + real_layer(x.imag)
        x = torch.stack([x_real, x_imag], dim=-1)
        x = F.softshrink(x, lambd=self.sparsity_threshold)
        x = torch.view_as_complex(x)
        return x

    def forecast(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, text_embeddings=None, index=None, mode='train'):
        """Predict ``pred_len`` steps for every channel of ``x_enc``.

        Args:
            x_enc: Input series, unpacked as ``(batch, seq_len, channels)``.
            x_mark_enc, x_dec, x_mark_dec: Accepted for interface
                compatibility; unused.
            text_embeddings: Text features indexed as
                ``(batch, channels, n_text, n_freq, embed_size)``. Only
                frequency slot 0 is used as attention key/value. Presumably
                complex, as produced by ``TextEncoder``.
            index: Sample indices selecting entries of ``retrieval_dict[mode]``.
            mode: Key into ``retrieval_dict`` (``'train'``, ``'valid'`` or
                ``'test'``), filled by ``prepare_dataset``.

        Returns:
            Tensor of shape ``(batch, pred_len, channels)`` in the input's
            original scale.
        """
        bsz, _, channels = x_enc.shape

        # ---- Retrieval branch ------------------------------------------------
        pred_from_retrieval = self.retrieval_dict[mode][:, index]  # G, B, C, P
        pred_from_retrieval = pred_from_retrieval.to(self.device)
        retrieval_preds = []
        for i, pr in enumerate(pred_from_retrieval):
            pr = pr.transpose(1, 2)  # B, P, C
            assert((bsz, self.pred_len, channels) == pr.shape), f"Shape mismatch: expected {(bsz, self.pred_len, channels)}, got {pr.shape}"
            g = self.period_num[i]
            # Keep the first step of every length-g period: B, P // g, C.
            pr = pr.reshape(bsz, self.pred_len // g, g, channels)[:, :, 0, :]
            pr = self.retrieval_pred[i](pr.permute(0, 2, 1)).permute(0, 2, 1)
            retrieval_preds.append(pr.reshape(bsz, self.pred_len, self.channels))
        retrieval_pred = torch.stack(retrieval_preds, dim=1).sum(dim=1)  # B, P, C

        # ---- Instance normalization (out-of-place: var's input is kept for backward)
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc = x_enc / stdev

        # ---- Spectral embedding ----------------------------------------------
        x = torch.fft.rfft(x_enc, dim=1, norm='ortho')  # B, H_f, C (complex)
        x = self.FreqMLP(x, self.realEmb, self.imagEmb)  # B, C, H_f, D
        # Same positional term added to the real and the imaginary part.
        position_emb = self.position_emb(x.real[:, 0, :, :]).unsqueeze(1)
        x = x + torch.complex(position_emb, position_emb)

        # ---- TS-text fusion --------------------------------------------------
        B, N_c, H_f, D = x.shape
        attn_mask = ConstantMask(B, N_c, text_embeddings.shape[1], device=x_enc.device)
        n_text = text_embeddings.shape[2]
        query = x.reshape(B * N_c, H_f, D)
        text = text_embeddings[:, :, :, 0, :].reshape(B * N_c, n_text, text_embeddings.shape[-1])
        output, _ = self.freq_attn_layer(query, text, text, attn_mask)
        output = output.reshape(B, N_c, H_f, D)

        # Complex product (attention output * series) plus a residual.
        real = output.real * x.real - output.imag * x.imag + x.real
        imag = output.real * x.imag + output.imag * x.real + x.imag
        real = self.real_norm(real)
        imag = self.imag_norm(imag)

        x = torch.stack([real, imag], dim=-1)
        x = F.softshrink(x, lambd=self.sparsity_threshold)
        x = torch.view_as_complex(x)

        # ---- Frequency upsampling: H_f -> T_f bins ---------------------------
        x = x.permute(0, 1, 3, 2)  # B, N_c, D, H_f
        x = self.FreqMLP(x, self.freq_upsampler_real, self.freq_upsampler_imag)  # B, N_c, D, T_f
        x = x.permute(0, 1, 3, 2)  # B, N_c, T_f, D

        # ---- Project D features -> 1 per bin, back to the time domain -------
        x_real = self.realEmb_dec(x.real) - self.imagEmb_dec(x.imag)
        x_imag = self.imagEmb_dec(x.real) + self.realEmb_dec(x.imag)
        spec = torch.stack([x_real, x_imag], dim=-1)
        spec = F.softshrink(spec, lambd=self.sparsity_threshold)
        spec = torch.view_as_complex(spec)  # B, N_c, T_f, 1
        x = torch.fft.irfft(spec, n=self.pred_len, dim=2, norm="ortho")  # B, N_c, P, 1
        # B, N_c, P, 1 -> B, P, N_c. (A plain squeeze(1) only gave this layout
        # when N_c == 1 and broke the concatenation below for multichannel input.)
        x = x.squeeze(-1).permute(0, 2, 1)

        # ---- Mix spectral and retrieval predictions, then de-normalize ------
        pred = torch.cat([x, retrieval_pred], dim=1)  # B, 2P, C
        pred = self.linear_pred(pred.permute(0, 2, 1)).permute(0, 2, 1).reshape(bsz, self.pred_len, self.channels)
        return pred * stdev + means

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, text_embeddings=None, index=None, mode='train'):
        """Dispatch to ``forecast`` for forecasting tasks; raise otherwise."""
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec, text_embeddings, index=index, mode=mode)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        else:
            raise ValueError('Only forecast tasks implemented yet')
