import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

def FFT_for_Period(x, k=4):
    """Pick the dominant periods of a batch of series from its amplitude spectrum.

    The real FFT is taken along axis 1 of ``x``; amplitudes are averaged over
    axis 0 and the last axis to rank the frequency bins.

    Args:
        x: Tensor indexed as (batch, time, channels); the code reduces over
            axes 0 and -1 and transforms axis 1.
        k: Number of frequency bins to select. It is clamped to the number of
            non-DC bins available, so short sequences no longer make
            ``torch.topk`` fail.

    Returns:
        A tuple ``(periods, weights)``. ``periods`` is a list of ints that
        starts with ``1`` followed by one period (at least 3) per selected bin,
        or ``[1, 4, 6, 8]`` when no bin could be selected. ``weights`` holds the
        channel-averaged amplitude of every selected bin, one row per batch item.
    """
    xf = torch.fft.rfft(x, dim=1)
    amplitude = abs(xf)
    frequency_list = amplitude.mean(0).mean(-1)

    # Rank only the non-DC bins. Bin 0 has no finite period: selecting it made
    # the period computation divide by zero and math.ceil fail on infinity.
    candidates = frequency_list[1:]
    k = max(0, min(k, candidates.shape[0]))
    _, top_idx = torch.topk(candidates, k)
    top_idx = (top_idx + 1).detach()  # shift back to indices of the full spectrum

    seq_len = x.shape[1]
    # A period is the sequence length divided by the bin index, rounded up and
    # never shorter than 3 steps.
    period = [max(math.ceil(seq_len / top), 3) for top in top_idx.tolist()]
    if not period:
        period = [4, 6, 8]
    return [1] + period, amplitude.mean(-1)[:, top_idx]

class moving_avg(nn.Module):
    """Block-wise average pooling along axis 1 of a 3-D tensor.

    Window and stride are both ``kernel_size``, so consecutive windows do not
    overlap and no padding is added.
    """
    def __init__(self, kernel_size):
        super().__init__()
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=kernel_size, padding=0)
    def forward(self, x):
        """Pool ``x`` over axis 1 and return it with the original axis order."""
        # AvgPool1d pools the last axis, so axes 1 and 2 are swapped around it.
        swapped = x.transpose(1, 2)
        pooled = self.avg(swapped)
        return pooled.transpose(1, 2)

class multi_scale_data(nn.Module):
    """Concatenate several pooled versions of a series into one fixed-length sequence.

    Each entry of ``kernel_size`` creates one ``moving_avg`` pooler. The pooled
    outputs are joined along axis 1 and the result is zero-padded or truncated
    to exactly ``return_len`` steps.

    Args:
        kernel_size: Iterable of pooling window sizes, one per scale.
        return_len: Length of axis 1 in the returned tensor.
    """
    def __init__(self, kernel_size, return_len):
        super(multi_scale_data, self).__init__()
        self.kernel_size = kernel_size
        self.max_len = return_len
        # ModuleList registers the poolers as sub-modules (a plain list hides
        # them from .modules()/.to()); iteration and indexing work as before.
        self.moving_avg = nn.ModuleList([moving_avg(k) for k in kernel_size])
    def forward(self, x):
        """Return the multi-scale sequence built from ``x``.

        Args:
            x: 3-D tensor; axis 1 is pooled and axes 0 and 2 are kept.

        Returns:
            Tensor whose axis 1 has length ``self.max_len`` and whose dtype and
            device match ``x``.
        """
        scales = []
        for pool in self.moving_avg:
            pooled = pool(x)
            if pooled.shape[1] == 0:
                continue
            scales.append(pooled)
        if scales:
            multi_scale_x = torch.cat(scales, dim=1)
        else:
            # torch.cat rejects an empty list; start from an empty sequence.
            multi_scale_x = x.new_zeros(x.shape[0], 0, x.shape[2])
        missing = self.max_len - multi_scale_x.shape[1]
        if missing > 0:
            # new_zeros inherits dtype and device from x, so the padding does
            # not change the dtype of non-float32 inputs.
            padding = x.new_zeros(x.shape[0], missing, x.shape[2])
            multi_scale_x = torch.cat([multi_scale_x, padding], dim=1)
        elif missing < 0:
            multi_scale_x = multi_scale_x[:, :self.max_len, :]
        return multi_scale_x

class nconv(nn.Module):
    """Multiply a 4-D tensor by an adjacency matrix along one of its axes.

    With ``gnn_type == 'time'`` the contraction runs over axis 1 of ``x``;
    for any other value it runs over axis 2.
    """
    def __init__(self, gnn_type):
        super().__init__()
        self.gnn_type = gnn_type
    def forward(self, x, A):
        """Contract ``x`` (indexed ``b, t, d, c``) with the 2-D matrix ``A``."""
        equation = 'btdc,tw->bwdc' if self.gnn_type == 'time' else 'btdc,dw->btwc'
        return torch.einsum(equation, x, A)

class gcn(nn.Module):
    """Graph convolution that mixes the input with repeated adjacency products.

    The input and each successive product with the adjacency matrix are
    concatenated on the last axis, projected by a linear layer, passed through
    GELU and then dropout.

    Args:
        c_in: Size of the last axis of the input.
        c_out: Size of the last axis of the output.
        dropout: Dropout probability applied to the output.
        gnn_type: Forwarded to ``nconv`` to choose the contracted axis.
        order: Number of successive adjacency products to concatenate.
    """
    def __init__(self, c_in, c_out, dropout, gnn_type, order=2):
        super().__init__()
        self.nconv = nconv(gnn_type)
        # The linear layer sees the input plus `order` propagated copies.
        self.c_in = (order + 1) * c_in
        self.mlp = nn.Linear(self.c_in, c_out)
        self.dropout = dropout
        self.order = order
        self.act = nn.GELU()
    def forward(self, x, a):
        """Apply the graph convolution to ``x`` using adjacency matrix ``a``."""
        hop = self.nconv(x, a)
        features = [x, hop]
        # The first product is always taken; higher orders add order - 1 more.
        for _ in range(self.order - 1):
            hop = self.nconv(hop, a)
            features.append(hop)
        stacked = torch.cat(features, dim=-1)
        activated = self.act(self.mlp(stacked))
        return F.dropout(activated, p=self.dropout, training=self.training)

class single_scale_gnn(nn.Module):
    """One encoder layer with a time-axis and a variable-axis graph convolution.

    The input series is expanded into several pooled scales, lifted to
    ``configs.hidden`` features per value, optionally mixed along the time axis
    (``use_tgcn``) and along the variable axis (``use_ngcn``) with learned
    adjacency matrices, and projected back to one value per (step, variable).

    Args:
        configs: Object providing ``tk``, ``scale_number``, ``use_tgcn``,
            ``use_ngcn``, ``seq_len``, ``pred_len``, ``enc_in``, ``dropout``,
            ``gpu``, ``tvechidden``, ``nvechidden`` and ``hidden``.
    """
    def __init__(self, configs):
        super(single_scale_gnn, self).__init__()
        self.tk = configs.tk
        self.scale_number = configs.scale_number
        self.use_tgcn = configs.use_tgcn
        self.use_ngcn = configs.use_ngcn
        self.init_seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.channels = configs.enc_in
        self.dropout = configs.dropout
        # Kept for callers that read it. Parameters are no longer created on
        # this device: they follow the module through .to()/.cuda() like every
        # other parameter, so the layer can also be built on CPU-only hosts.
        self.device = f'cuda:{configs.gpu}'
        self.GraphforPre = False
        self.tvechidden = configs.tvechidden
        self.d_model = configs.hidden
        self.start_linear = nn.Linear(1, self.d_model)

        # Length of the concatenated multi-scale sequence used in forward().
        self.max_seq_len = 1024
        self.timevec1_full = nn.Parameter(torch.randn(self.max_seq_len, self.tvechidden))
        self.timevec2_full = nn.Parameter(torch.randn(self.tvechidden, self.max_seq_len))

        self.tgcn = gcn(self.d_model, self.d_model, self.dropout, gnn_type='time')
        self.nodevec1 = nn.Parameter(torch.randn(self.channels, configs.nvechidden))
        self.nodevec2 = nn.Parameter(torch.randn(configs.nvechidden, self.channels))
        self.gconv = gcn(self.d_model, self.d_model, self.dropout, gnn_type='nodes')
        # Floor division already yields an int, so ceil is a no-op here; the
        # expression is left unchanged to preserve the layer's shape.
        self.grang_emb_len = math.ceil(self.d_model // 4)
        # Not used by forward() in this class.
        self.graph_mlp = nn.Linear(2 * self.tvechidden, self.grang_emb_len)
        self.Linear = nn.Linear(2 * self.d_model, 1)

    def logits_warper_softmax(self, adj, mask):
        """Softmax ``adj`` over axis 0 after suppressing the entries where ``mask`` is True."""
        adj = adj.masked_fill(mask, -1e9)
        adj = F.softmax(adj, dim=0)
        return adj

    def logits_warper(self, adj, mask, pos, neg):
        """Build a signed adjacency from the positive and negative edge sets.

        Entries selected by ``pos`` get a row-wise softmax of ``adj``; entries
        selected by ``neg`` get minus a row-wise softmax of ``1 / (adj + 1)``.
        ``mask`` is accepted for interface compatibility and is not used.
        """
        pos_inv = ~pos
        neg_inv = ~neg
        processed_pos = pos * F.softmax(adj.masked_fill(pos_inv, -1e9), dim=-1)
        processed_neg = -1 * neg * F.softmax((1/(adj+1)).masked_fill(neg_inv, -1e9), dim=-1)
        return processed_pos + processed_neg

    def add_cross_scale_connect(self, adj, periods, max_L):
        """Mark, per scale, the entries of ``adj`` that fall below the row-wise top-k.

        Columns of ``adj`` are split into consecutive segments, one per entry
        of ``periods``, each ``init_seq_len // period`` wide and clipped at
        ``max_L``. Inside a segment an entry is masked (True) when it is
        smaller than the row's ``kp``-th largest value of that segment, with
        ``kp = min(max(tk // period, 5), segment_width)``. Columns not covered
        by any segment are left unmasked.

        Returns:
            Boolean tensor with ``adj.shape[0]`` rows and ``max_L`` columns.
        """
        k = self.tk
        n_rows = adj.shape[0]
        pieces = []
        start = 0
        for period in periods:
            end = min(start + self.init_seq_len // period, max_L)
            ls = end - start
            if ls <= 0:
                # This scale contributes no steps; topk with k=0 would fail.
                continue
            kp = min(max(k // period, 5), ls)
            segment = adj[:, start:end]
            threshold = torch.topk(segment, k=kp)[0][..., -1, None]
            pieces.append(segment < threshold)
            start = end
            if start == max_L:
                break
        if start < max_L:
            # Row count follows adj itself so the pieces always concatenate.
            pieces.append(torch.zeros(n_rows, max_L - start, dtype=torch.bool, device=adj.device))
        if not pieces:
            return torch.zeros(n_rows, 0, dtype=torch.bool, device=adj.device)
        return torch.cat(pieces, dim=1)

    def add_adjecent_connect(self, mask, L):
        """Clear ``mask`` in place at (i, i+1) and (i+1, i) for i in [0, L-1) and return it."""
        idx = torch.arange(L - 1, device=mask.device)
        # Explicit (row, col) indexing. The previous single (2, N) array index
        # only meant this under the legacy non-tuple-sequence indexing rule.
        mask[idx, idx + 1] = False
        mask[idx + 1, idx] = False
        return mask

    def add_cross_var_adj(self, adj):
        """Select the strongest and weakest ``k`` (at most 3) entries of every row.

        Returns:
            ``(mask, mask_pos, mask_neg)``: ``mask`` is all False,
            ``mask_pos`` marks entries >= the row's k-th largest value and
            ``mask_neg`` marks entries <= the row's k-th smallest value.
        """
        k = min(3, adj.shape[-1])
        # The previous expression required an entry to be both below and above
        # the same threshold, which is never true, and it raised IndexError
        # when there were 3 or fewer variables. The all-False result is kept.
        mask = torch.zeros_like(adj, dtype=torch.bool)
        mask_pos = adj >= torch.topk(adj, k=k)[0][..., -1, None]
        # NOTE(review): keepdim=True gives one threshold per row, matching
        # mask_pos. Previously the last row's threshold was applied to every
        # row, so this changes outputs of already-trained models.
        mask_neg = adj <= torch.kthvalue(adj, k=k, dim=-1, keepdim=True)[0]
        return mask, mask_pos, mask_neg

    def forward(self, x):
        """Run the layer on ``x`` indexed as (batch, time, variables).

        Returns:
            Tensor holding the first ``init_seq_len`` steps of the processed
            multi-scale sequence, with the feature axis projected back to one
            value.
        """
        periods, _ = FFT_for_Period(x, self.scale_number)
        multi_scale_func = multi_scale_data(kernel_size=periods, return_len=self.max_seq_len)
        x = multi_scale_func(x)
        self.seq_len = x.shape[1]
        self.timevec1 = self.timevec1_full[:self.seq_len, :]
        self.timevec2 = self.timevec2_full[:, :self.seq_len]

        x = self.expand_channel(x)
        x_ = x  # kept to concatenate with the graph-convolved features below

        if self.use_tgcn:
            adj = F.relu(torch.einsum('td,dm->tm', self.timevec1, self.timevec2))
            mask = self.add_cross_scale_connect(adj, periods, self.seq_len)
            mask = self.add_adjecent_connect(mask, self.seq_len)
            time_adp = self.logits_warper_softmax(adj, mask)
            x = self.tgcn(x, time_adp) + x

        if self.use_ngcn:
            adj = F.relu(torch.einsum('td,dm->tm', self.nodevec1, self.nodevec2))
            mask, pos, neg = self.add_cross_var_adj(adj)
            var_adp = self.logits_warper(adj, mask, pos, neg)
            x = self.gconv(x, var_adp) + x

        x = torch.cat([x_, x], dim=-1)
        x = self.Linear(x).squeeze(-1)
        return x[:, :self.init_seq_len, :]

    def expand_channel(self, x):
        """Append a trailing axis to ``x`` and lift each value to ``d_model`` features."""
        return self.start_linear(x.unsqueeze(-1))

class Model(nn.Module):
    """Stack of ``single_scale_gnn`` layers followed by a linear map over time.

    Args:
        configs: Object providing ``seq_len``, ``pred_len``, ``e_layers`` and
            ``anti_ood``, plus whatever ``single_scale_gnn`` reads from it.
    """
    def __init__(self, configs):
        super().__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        encoders = [single_scale_gnn(configs) for _ in range(configs.e_layers)]
        self.graph_encs = nn.ModuleList(encoders)
        self.Linear = nn.Linear(self.seq_len, self.pred_len)
        self.anti_ood = configs.anti_ood
    def forward(self, x, x_mark_enc, x_dec, x_mark_dec):
        """Map ``x`` from ``seq_len`` to ``pred_len`` steps along axis 1.

        Only ``x`` is used; the remaining arguments are accepted and ignored.
        When ``anti_ood`` is set, the last time step of ``x`` is subtracted
        before the layers and added back to the output.
        """
        if self.anti_ood:
            # Detached so the offset itself receives no gradient.
            seq_last = x[:, -1:, :].detach()
            x = x - seq_last
        for encoder in self.graph_encs:
            x = encoder(x)
        # nn.Linear acts on the last axis, so time is moved there and back.
        projected = self.Linear(x.transpose(1, 2))
        x = projected.transpose(1, 2)
        if self.anti_ood:
            x = x + seq_last
        return x
