import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import math
class Tem_Agg_Block(nn.Module):
    """Wrapper that applies one ``Tem_agg`` module to its whole input.

    Args:
        top_k: forwarded to ``Tem_agg`` (number of dominant periods).
        seq_len: forwarded to ``Tem_agg`` (nominal sequence length).
    """

    def __init__(self, top_k, seq_len):
        super().__init__()
        # A single aggregation module, called once on the full input tensor
        # (all channels together) rather than per channel.
        self.chan_share_tb = Tem_agg(top_k, seq_len)

    def split_and_apply(self, x):
        """Pass ``x`` through the shared aggregation module; no reshaping is done."""
        aggregated = self.chan_share_tb(x)
        return aggregated

    def forward(self, x):
        """Delegate to ``split_and_apply``."""
        out = self.split_and_apply(x)
        return out


class Tem_agg(nn.Module):
    """Parameter-free, period-based temporal aggregation.

    For each of the ``top_k`` dominant periods reported by ``FFT_for_Period``,
    the input ``[B, T, N]`` is zero-padded to a multiple of the period and
    folded to ``[B, N, cycles, period]``. Each cycle is then replaced by a
    softmax-attention-weighted mix of all cycles, and the result is unfolded
    back to ``[B, T, N]``. The per-period results are combined with softmax
    weights over the FFT amplitudes and added to the input as a residual.

    Args:
        top_k: number of dominant periods to aggregate over.
        seq_len: nominal sequence length. Kept for backward compatibility;
            ``forward`` uses the actual length of its input so that padding,
            truncation and weighting all agree.
    """

    def __init__(self, top_k, seq_len):
        super(Tem_agg, self).__init__()
        self.seq_len = seq_len
        self.k = top_k
        # parameter-efficient design: this module has no learnable weights

    def temporal_aggregation(self, x):
        """Scaled dot-product self-attention across cycles.

        Args:
            x: tensor ``[..., P, Q]`` holding P cycles, each of length Q.

        Returns:
            Tensor of the same shape whose row ``i`` is
            ``sum_j softmax_j(<x_i, x_j> / sqrt(Q)) * x_j``.
        """
        d = x.size(-1)
        scores = torch.matmul(x, x.transpose(-2, -1)) / math.sqrt(d)  # [..., P, P]
        attention_weights = torch.softmax(scores, dim=-1)
        # Mix the *other* cycles x_j with weights A[i, j]. The previous
        # broadcast form summed A[i, j] * x_i over j, which collapses to x_i
        # (softmax rows sum to 1): an expensive identity.
        return torch.matmul(attention_weights, x)

    def forward(self, x):
        """Aggregate ``x`` (indexed as [B, T, N]) over its dominant periods.

        Returns a tensor of the same shape: weighted aggregation plus ``x``.
        """
        B, T, N = x.size()
        period_list, period_weight = FFT_for_Period(x, self.k)
        res = []
        for i in range(self.k):
            period = int(period_list[i])
            # Pad the time axis up to the next multiple of the period.
            if T % period != 0:
                length = (T // period + 1) * period
                # new_zeros keeps x's dtype and device (e.g. under fp16/AMP).
                padding = x.new_zeros((B, length - T, N))
                out = torch.cat([x, padding], dim=1)
            else:
                length = T
                out = x
            # Fold 1D series into [B, N, cycles, period].
            out = out.reshape(B, length // period, period, N).permute(0, 3, 1, 2).contiguous()
            out = self.temporal_aggregation(out)
            # Unfold back to [B, length, N] and drop the padding.
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            res.append(out[:, :T, :])

        res = torch.stack(res, dim=-1)  # [B, T, N, k]
        # Adaptive aggregation: broadcast the [B, k] weights over T and N.
        period_weight = F.softmax(period_weight, dim=1)
        res = torch.sum(res * period_weight[:, None, None, :], -1)
        # Residual connection.
        return res + x
    
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.

    The series is edge-padded by repeating its first and last time steps so
    that, with ``stride=1``, the output has the same length as the input for
    both odd and even ``kernel_size``.

    Args:
        kernel_size: averaging window length.
        stride: step between windows (passed to ``nn.AvgPool1d``).
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """Average ``x`` (indexed as [batch, time, channel]) along the time axis."""
        # Total padding of kernel_size - 1 preserves length at stride 1.
        # Odd kernels get (k-1)//2 on each side as before; for even kernels the
        # extra step goes on the end instead of silently dropping one output
        # step (which made series_decomp's x - moving_mean fail).
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, self.kernel_size // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last dim, so move time there and back.
        x = self.avg(x.permute(0, 2, 1))
        return x.permute(0, 2, 1)

class series_decomp(nn.Module):
    """Split a series into a residual (seasonal) part and a moving-average trend.

    Args:
        kernel_size: window length of the moving average; stride is fixed at 1.

    ``forward(x)`` returns the pair ``(x - trend, trend)``.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        return x - trend, trend


def FFT_for_Period(x, k=2):
    """Find the ``k`` dominant periods of ``x`` from its real FFT.

    Args:
        x: tensor indexed as [B, T, C]; the FFT is taken over dim 1 (time).
        k: number of periods to return.

    Returns:
        period: numpy integer array of length ``k`` holding ``T // f`` for
            the ``k`` frequency bins ``f`` with the largest amplitude averaged
            over batch and channels (the zero-frequency bin is excluded).
        weights: tensor ``[B, k]`` with each sample's channel-averaged
            amplitude at those bins; it stays connected to ``x`` in autograd.
    """
    xf = torch.fft.rfft(x, dim=1)
    # Magnitude spectrum computed once and reused for selection and weights.
    amplitude = xf.abs()
    frequency_list = amplitude.mean(0).mean(-1)
    # Exclude the zero-frequency (DC) bin from the period search.
    frequency_list[0] = 0
    _, top_list = torch.topk(frequency_list, k)
    top_list = top_list.detach().cpu().numpy()
    period = x.shape[1] // top_list
    return period, amplitude.mean(-1)[:, top_list]

