__all__ = ['SPM_backbone']

import torch
from torch import nn
import torch.nn.functional as F
from layers.RevIN import RevIN

# Cell
class SPM_backbone(nn.Module):
    """Linear forecasting backbone augmented with two memory banks.

    The last axis of the input (length ``context_window``) is projected to
    ``target_window`` features by a single linear layer.  That projection can be
    combined with

    * a *learnable* memory (``self.memory``): ``mem_num`` trainable slots read
      with softmax attention, and
    * an *episodic* memory (``self.episodic_memory``): a non-trainable buffer of
      ``c_in * ep_mem_num`` rows that is overwritten with the projections of the
      worst-predicted ("hard") training samples.

    ``gamma`` selects how the pieces are combined in :meth:`forward`:

    * ``-1``: the linear projection alone;
    * ``0``: projection + learnable-memory read;
    * ``-2``: projection + episodic-memory read;
    * any other value: projection + learnable-memory read + episodic-memory read
      scaled by ``gamma``.  Only for ``gamma > 0`` is the episodic memory
      updated during training.

    Args:
        c_in: number of variables (channels).
        context_window: length of the input sequence.
        target_window: length of the forecast; also the feature size of both
            memories.
        revin: whether to wrap the model in the project's ``RevIN`` layer.
        affine: forwarded to ``RevIN``.
        subtract_last: forwarded to ``RevIN``.
        ep_topk: number of episodic-memory rows retrieved per query.
        num_hard_example: number of hard samples stored per training batch.
        ep_mem_num: episodic-memory rows *per variable*.
        mem_num: number of learnable memory slots.  :meth:`query_memory` takes
            the two best-matching slots, so it needs ``mem_num >= 2``.
        gamma: combination mode / episodic scaling factor (see above).
        substitution: multiplier of the write window; the write pointer cycles
            over ``num_hard_example * c_in * substitution`` positions.

    Raises:
        AssertionError: if ``substitution * num_hard_example`` exceeds
            ``ep_mem_num``.
    """

    def __init__(self, c_in:int, context_window:int, target_window:int, revin=True, affine=True, subtract_last=False,
                 ep_topk=0, num_hard_example=0, ep_mem_num=0, mem_num=0, gamma=1, substitution=1):
        super().__init__()

        # Optional reversible instance normalisation.
        self.revin = revin
        if self.revin:
            self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last)

        # Memory hyper-parameters.
        self.n_vars = c_in
        self.h_units = target_window
        self.ep_topk = ep_topk
        self.num_hard_example = num_hard_example
        self.mem_num = mem_num
        self.mem_dim = target_window
        self.gamma = gamma
        self.ep_mem_num = self.n_vars * ep_mem_num
        self.ep_mem_dim = target_window
        self.target_window = target_window

        # NOTE: the creation order of the sub-modules below is kept stable so
        # that seeded initialisation and state_dict keys do not change.
        self.memory = self.construct_memory()
        self.episodic_memory, self.ep_frequency = self.construct_episodic_memory()
        if gamma == -2:
            self.end_conv = nn.Conv1d(self.h_units + self.ep_mem_dim, target_window, kernel_size=1)
        else:
            self.end_conv = nn.Conv1d(self.h_units + self.mem_dim + self.ep_mem_dim, target_window, kernel_size=1)
            self.end_conv2 = nn.Conv1d(self.h_units + self.mem_dim, target_window, kernel_size=1)
        self.Linear_backbone = nn.Linear(context_window, self.h_units)

        # Episodic write state: ``point`` indexes into ``self.id`` (row ids
        # ordered from least to most frequently retrieved); ``first`` marks that
        # the episodic memory has not been read yet.
        self.point = 0
        self.first = 1
        # Initialised here so update_episodic_memory() does not depend on a
        # previous call to query_episodic_memory().
        self.id = list(self.ep_frequency.keys())
        self.substitution = substitution
        assert self.substitution*self.num_hard_example <= (self.ep_mem_num/self.n_vars), 'Error temp memory is bigger ' \
                                                                          'than the episodic memory'

    def construct_memory(self):
        """Create the learnable memory.

        Returns:
            nn.ParameterDict with ``'Memory'`` of shape ``(mem_num, mem_dim)``
            and the query projection ``'Wq'`` of shape ``(h_units, mem_dim)``,
            both Xavier-normal initialised.
        """
        memory_dict = nn.ParameterDict()
        # torch.randn is kept (although overwritten below) so the RNG stream,
        # and therefore seeded initialisation, stays the same.
        memory_dict['Memory'] = nn.Parameter(torch.randn(self.mem_num, self.mem_dim), requires_grad=True)  # (M, d)
        memory_dict['Wq'] = nn.Parameter(torch.randn(self.h_units, self.mem_dim), requires_grad=True)  # query projection

        for param in memory_dict.values():
            nn.init.xavier_normal_(param)
        return memory_dict

    def query_memory(self, h_t: torch.Tensor):
        """Read the learnable memory with scaled softmax attention.

        Args:
            h_t: hidden features whose last axis has size ``h_units``
                (``(B, N, h_units)`` when called from :meth:`forward`).

        Returns:
            Tuple ``(value, query, pos, neg)``: the attention-weighted memory
            read, the projected query, and the best / second-best matching
            memory slots for every query.
        """
        slots = self.memory['Memory']
        query = torch.matmul(h_t, self.memory['Wq'])  # (B, N, d)
        att_score = torch.softmax(torch.matmul(query, slots.t() / (self.mem_dim ** 0.5)), dim=-1)  # (B, N, M)
        value = torch.matmul(att_score, slots)  # (B, N, d)
        _, matched_idx = torch.topk(att_score, k=2, dim=-1)
        pos = slots[matched_idx[..., 0]]  # best match, (B, N, d)
        neg = slots[matched_idx[..., 1]]  # runner-up, (B, N, d)

        return value, query, pos, neg

    # episodic memory
    def construct_episodic_memory(self):
        """Create the zero-filled, non-trainable episodic memory.

        Returns:
            Tuple ``(ep_memory, ep_frequency)``: a parameter of shape
            ``(ep_mem_num, ep_mem_dim)`` with ``requires_grad=False`` and a
            dict mapping each row index to its retrieval count (initially 1).
        """
        ep_memory = nn.Parameter(torch.zeros(self.ep_mem_num, self.ep_mem_dim), requires_grad=False)
        ep_frequency = {row: 1 for row in range(self.ep_mem_num)}
        return ep_memory, ep_frequency

    def query_episodic_memory(self, h_t: torch.Tensor):
        """Read the episodic memory by cosine k-nearest-neighbour lookup.

        The ``ep_topk`` most similar rows are combined with softmax weights
        over their similarity scores.  Retrieval counts are accumulated in
        ``self.ep_frequency`` (skipped once, on the first read of a still
        all-zero memory), and whenever the write pointer is at 0 the rows are
        re-ranked from least to most frequently retrieved.

        Args:
            h_t: queries of shape ``(B, N, ep_mem_dim)``.

        Returns:
            Tensor of shape ``(B, N, ep_mem_dim)``; it carries no gradient.
        """
        # The shape-preserving helper is used so that batch size 1, a single
        # variable or ep_topk == 1 do not collapse the (B, N, k) layout.
        matched_idx, k_score = self._nearest_keys(h_t, self.episodic_memory.detach(), k=self.ep_topk)
        weights = torch.softmax(k_score, dim=-1).unsqueeze(-1)  # (B, N, k, 1)
        mem = self.episodic_memory[matched_idx].detach()  # (B, N, k, d)
        value = torch.sum(mem * weights, dim=2)  # (B, N, d)

        counts = torch.bincount(matched_idx.reshape(-1)).tolist()
        if self.first and torch.sum(self.episodic_memory) == 0:
            # First read of an empty memory: every match is arbitrary, so the
            # counts are not recorded.
            self.first = 0
        else:
            for row, count in enumerate(counts):
                self.ep_frequency[row] += count
        if self.point == 0:
            # Start of a write cycle: rank rows by how rarely they were retrieved.
            self.ep_frequency = dict(sorted(self.ep_frequency.items(), key=lambda item: item[1]))  # slow
            self.id = list(self.ep_frequency.keys())
        return value

    def update_episodic_memory(self, h_t, flag):
        """Write hard examples into the least-retrieved episodic rows.

        Does nothing unless ``flag == 'train'``.

        Args:
            h_t: features of shape ``(num_examples, n_vars, ep_mem_dim)``.
            flag: run mode; only ``'train'`` triggers a write.
        """
        if flag != 'train':
            # Evaluation / test: the episodic memory is left untouched.
            return
        with torch.no_grad():
            for example_idx in range(h_t.shape[0]):
                for var_idx in range(self.n_vars):
                    # Replace the row that was retrieved least often and reset its count.
                    row = self.id[self.point]
                    self.episodic_memory[row, :] = h_t[example_idx, var_idx, :].detach()
                    self.ep_frequency[row] = 0
                    self.point = (self.point + 1) % (self.num_hard_example * self.n_vars * self.substitution)

    def _nearest_keys(self, query, key_dict, k, sim='cosine'):
        """k-NN search that always keeps the ``(B, N, k)`` layout.

        Similarities are computed one variable at a time to avoid building the
        full ``(B, N, K)`` similarity tensor.

        Returns:
            Tuple ``(k_index, k_score)``, both of shape ``(B, N, k)``.

        Raises:
            NotImplementedError: if ``sim`` is not ``'cosine'``.
        """
        if sim != 'cosine':
            raise NotImplementedError
        # The lookup itself needs no gradient.
        with torch.no_grad():
            keys = key_dict.unsqueeze(0)  # (1, K, d)
            k_index = []
            k_score = []
            for n in range(query.size(1)):
                query_n = query[:, n:n + 1]  # (B, 1, d)
                similarity = F.cosine_similarity(keys, query_n, dim=-1, eps=1e-6)  # (B, K)
                top = torch.topk(similarity, k)
                k_score.append(top.values)
                k_index.append(top.indices)
            return torch.stack(k_index, dim=1), torch.stack(k_score, dim=1)

    def get_nearest_key(self, query, key_dict, k=3, sim='cosine'):
        """k-nearest-neighbour lookup of ``query`` in ``key_dict``.

        Args:
            query: tensor of shape ``(B, N, d)``.
            key_dict: tensor of shape ``(K, d)`` holding the candidate keys.
            k: number of neighbours to return.
            sim: similarity function; only ``'cosine'`` is implemented.

        Returns:
            ``(k_index, k_score)`` when ``k > 1``, otherwise ``k_index`` only.
            All size-1 axes are squeezed out of the results (legacy contract,
            kept for existing callers); code that needs a stable
            ``(B, N, k)`` layout should use ``_nearest_keys``.

        Raises:
            NotImplementedError: if ``sim`` is not ``'cosine'``.
        """
        k_index, k_score = self._nearest_keys(query, key_dict, k, sim)
        if k > 1:
            return k_index.squeeze(), k_score.squeeze()
        return k_index.squeeze()

    def forward(self, x, target, flag):
        """Forecast and, in training, refresh the episodic memory.

        Args:
            x: input whose last axis has length ``context_window``; expected
                layout ``(B, n_vars, context_window)``.
            target: ground truth, only used when ``flag == 'train'`` and
                ``gamma > 0``; expected layout ``(B, target_window, n_vars)``
                (it is permuted to match the prediction).
            flag: run mode; ``'train'`` enables episodic-memory writes.

        Returns:
            Tuple ``(y, query, pos, neg)``: the forecast of shape
            ``(B, n_vars, target_window)`` and the three learnable-memory
            tensors returned by :meth:`query_memory`.
        """
        # norm (RevIN is applied on the permuted layout)
        if self.revin:
            x = self.revin_layer(x.permute(0, 2, 1), 'norm').permute(0, 2, 1)
        h = self.Linear_backbone(x)

        h_memory, query, pos, neg = self.query_memory(h)
        h_m_sum = torch.cat((h_memory, h), dim=-1)
        if self.gamma == -1:
            y = h
        elif self.gamma == 0:
            y = self.end_conv2(h_m_sum.permute(0, 2, 1)).permute(0, 2, 1)
        elif self.gamma == -2:
            h_ep = self.query_episodic_memory(h)
            h_e_sum = torch.cat((h, h_ep), dim=-1)
            y = self.end_conv(h_e_sum.permute(0, 2, 1)).permute(0, 2, 1)
        else:
            h_ep = self.query_episodic_memory(h)
            h_all = torch.cat((h_m_sum, h_ep * self.gamma), dim=-1)
            y = self.end_conv(h_all.permute(0, 2, 1)).permute(0, 2, 1)

        # denorm
        if self.revin:
            y = self.revin_layer(y.permute(0, 2, 1), 'denorm').permute(0, 2, 1)  # b * c * l

        if flag == 'train' and self.gamma > 0:
            # Hard-example mining: store the projections of the samples with
            # the largest mean absolute error.  No gradient is needed here.
            with torch.no_grad():
                per_sample_mae = (y - target.permute(0, 2, 1)).abs().mean(dim=(1, 2))
                # A batch may hold fewer samples than num_hard_example
                # (e.g. the last batch of an epoch); topk would raise otherwise.
                n_hard = min(self.num_hard_example, per_sample_mae.shape[0])
                _, hard_idx = torch.topk(per_sample_mae, k=n_hard, largest=True)
            self.update_episodic_memory(h.detach()[hard_idx], flag)
        return y, query, pos, neg



class AGCN(nn.Module):
    """Graph convolution whose weights and supports are derived from node embeddings.

    The adjacency ("support") matrix is computed on the fly as
    ``softmax(relu(E @ E.T))`` from the node embeddings ``E`` and expanded with
    the recursion ``T_k = 2 * A @ T_{k-1} - T_{k-2}`` (starting from the
    identity and ``A``).  Per-node weights and biases are generated by
    multiplying the embeddings with shared parameter pools.

    Args:
        dim_in: number of input features per node.
        dim_out: number of output features per node.
        cheb_k: number of support matrices the weight pool is sized for.  The
            support stack always contains at least the identity and ``A``, so
            ``cheb_k`` is expected to be >= 2.
        embed_dim: size of the node embeddings.
    """

    def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
        super().__init__()
        self.cheb_k = cheb_k
        self.dim_in = dim_in
        self.embed_dim = embed_dim
        self.dim_out = dim_out
        # Parameter pools, initialised with the in-place (non-deprecated)
        # Xavier initialiser; weights first, then bias, to keep the RNG order.
        self.weights_pool = nn.Parameter(torch.empty(embed_dim, cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.empty(embed_dim, dim_out))
        nn.init.xavier_uniform_(self.weights_pool)
        nn.init.xavier_uniform_(self.bias_pool)

    def forward(self, x, node_embeddings):
        """Apply the adaptive graph convolution.

        Args:
            x: node features of shape ``(B, N, dim_in)``.
            node_embeddings: node embeddings of shape ``(N, embed_dim)``.

        Returns:
            Tensor of shape ``(B, N, dim_out)``.
        """
        node_num = node_embeddings.shape[0]
        supports = F.softmax(F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1)  # (N, N)
        # Build the identity directly with the device and dtype of `supports`
        # so non-default dtypes do not break the recursion below.
        identity = torch.eye(node_num, device=supports.device, dtype=supports.dtype)
        support_set = [identity, supports]
        # default cheb_k = 3
        for _ in range(2, self.cheb_k):
            support_set.append(torch.matmul(2 * supports, support_set[-1]) - support_set[-2])
        supports = torch.stack(support_set, dim=0)  # (cheb_k, N, N)

        weights = torch.einsum('nd,dkio->nkio', node_embeddings, self.weights_pool)  # (N, cheb_k, dim_in, dim_out)
        bias = torch.matmul(node_embeddings, self.bias_pool)  # (N, dim_out)
        x_g = torch.einsum("knm,bmc->bknc", supports, x)  # (B, cheb_k, N, dim_in)
        x_g = x_g.permute(0, 2, 1, 3)  # (B, N, cheb_k, dim_in)
        x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias  # (B, N, dim_out)
        return x_gconv

            


