import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .LPF_layer import LPFAttLayer as GraphLayer
import math


############################################  Graph functions  ############################################

class OutLayer(nn.Module):
    """Per-node MLP head that maps ``in_num`` features to a single value.

    The stack is ``layer_num`` linear layers deep. Every layer except the
    last is followed by ``BatchNorm1d`` and ``ReLU`` and has ``inter_num``
    outputs; the last layer has one output. ``node_num`` is accepted for
    interface compatibility but is not used here.
    """

    def __init__(self, in_num, node_num, layer_num, inter_num = 512):
        super().__init__()
        stack = []
        width = in_num  # input width of the next Linear layer
        for depth in range(layer_num):
            is_last = depth == layer_num - 1
            if is_last:
                stack.append(nn.Linear(width, 1))
                continue
            stack.extend((
                nn.Linear(width, inter_num),
                nn.BatchNorm1d(inter_num),
                nn.ReLU(),
            ))
            width = inter_num

        self.mlp = nn.ModuleList(stack)

    def forward(self, x):
        """Apply the stack to ``x``; the permutes below require a 3-D input."""
        out = x
        for layer in self.mlp:
            if not isinstance(layer, nn.BatchNorm1d):
                out = layer(out)
                continue
            # BatchNorm1d normalises dim 1, so move the feature axis there
            # for the call and move it back afterwards.
            out = layer(out.permute(0, 2, 1)).permute(0, 2, 1)

        return out


class GNNLayer(nn.Module):
    """Graph layer followed by batch normalisation and ReLU.

    The graph operator is ``GraphLayer`` (imported from ``.LPF_layer``),
    built with ``concat=False``. ``node_num`` is accepted by both the
    constructor and ``forward`` but is not used in this class.
    """

    def __init__(self, in_channel, out_channel, inter_dim=0, heads=1, node_num=100):
        super().__init__()
        self.gnn = GraphLayer(
            in_channel,
            out_channel,
            inter_dim=inter_dim,
            heads=heads,
            concat=False,
        )
        self.bn = nn.BatchNorm1d(out_channel)
        self.relu = nn.ReLU()
        # Not used by forward(); kept so the module's attributes stay the same.
        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x, edge_index, embedding=None, node_num=0):
        """Run the graph layer on ``x`` and return ``relu(bn(.))`` of its output."""
        aggregated = self.gnn(x, edge_index, embedding, return_attention_weights=False)
        normalised = self.bn(aggregated)
        return self.relu(normalised)


def get_batch_edge_index(org_edge_index, batch_num, node_num):
    """Tile an edge index ``batch_num`` times, shifting each copy by ``node_num``.

    Copy ``i`` of the edges has ``i * node_num`` added to every entry and the
    copies are laid out one after another along dim 1. The input tensor is
    not modified; the result is detached and converted to ``long``.
    """
    base = org_edge_index.clone().detach()
    # One offset per batch element, in the same dtype/device as the edges.
    shifts = torch.arange(batch_num, device=base.device, dtype=base.dtype) * node_num
    # (rows, 1, edges) + (1, batch, 1) -> (rows, batch, edges)
    shifted = base.unsqueeze(1) + shifts.view(1, -1, 1)
    batch_edge_index = shifted.reshape(base.shape[0], -1).contiguous()

    return batch_edge_index.long()


##### graph_module
class graph_module(nn.Module):
    """
    Graph block driven by a learned embedding per node (feature).

    Two modes are selected by ``adaptive_gcn_option``:

    * truthy  - an adaptive graph convolution whose adjacency is
      ``softmax(relu(E @ E.T))`` and whose weights/bias are generated from
      the (detached) node embeddings through ``weights_pool``/``bias_pool``;
    * falsy   - one ``GNNLayer`` per edge set, run on a top-k cosine
      similarity graph built from the (detached) node embeddings.

    Args:
        feature_dim: number of rows in the node embedding table.
        input_dim: with ``onehot_dim`` gives the per-node input width of the
            adaptive branch (``input_dim * onehot_dim``); it is the
            ``in_channel`` of each ``GNNLayer`` in the other branch.
        onehot_dim: per-node output width of the adaptive branch.
        dim: added to ``embed_dim`` to form ``inter_dim`` of each ``GNNLayer``.
        embed_dim: width of the node embeddings.
        edge_set_num: number of edge sets (cache slots and ``GNNLayer`` count).
        topk: neighbours kept per node in the non-adaptive branch.
        adaptive_gcn_option: selects the branch, see above.
    """
    def __init__(self, feature_dim, input_dim, onehot_dim, dim, embed_dim, edge_set_num, topk, adaptive_gcn_option):
        super(graph_module, self).__init__()
        self.feature_dim = feature_dim
        self.cache_edge_index_sets = [None] * edge_set_num
        self.embedding = nn.Embedding(feature_dim, embed_dim)
        self.topk = topk
        self.adaptive_gcn_option = adaptive_gcn_option

        if self.adaptive_gcn_option:
            self.lin = nn.Linear(input_dim*onehot_dim, onehot_dim, bias=False)
            self.cheb_k = 2
            # Allocated uninitialised here; init_params() fills both tensors.
            self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, self.cheb_k, onehot_dim, onehot_dim))
            self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, onehot_dim))
            self.learned_a = None
        else:
            self.gnn_layers = nn.ModuleList([
                GNNLayer(input_dim, 1, inter_dim=dim + embed_dim, heads=1)
                for _ in range(edge_set_num)
            ])
        self.init_params()

    def init_params(self):
        """Initialise the embedding table and, in adaptive mode, the pools."""
        nn.init.kaiming_uniform_(self.embedding.weight, a=math.sqrt(5))
        # Same truthiness test as __init__, so both always pick the same branch.
        if self.adaptive_gcn_option:
            nn.init.xavier_uniform_(self.weights_pool, gain=1.414)
            nn.init.xavier_uniform_(self.bias_pool, gain=1.414)

    def _adaptive_gcn(self, x, weights, batch_num, node_num):
        """Adaptive graph convolution; returns a (batch*node, onehot_dim) tensor."""
        a = F.softmax(F.relu(torch.matmul(weights, weights.T)), dim=1)
        supports = [torch.eye(node_num, device=x.device), a]
        for _ in range(2, self.cheb_k):
            supports.append(torch.matmul(2 * a, supports[-1]) - supports[-2])
        supports = torch.stack(supports, dim=0)                               # K, N, N
        # Projected into a local: rebinding ``x`` here would feed the already
        # projected tensor into self.lin again on the next edge set.
        x_proj = self.lin(x).view(batch_num, node_num, -1)                    # B, N, C
        ax = torch.einsum("knm,bmc->bknc", supports, x_proj)                  # B, K, N, C
        weight = torch.einsum('nd,dkio->nkio', weights, self.weights_pool)    # N, K, C, C
        bias = torch.matmul(weights, self.bias_pool)                          # N, C
        ax = ax.permute(0, 2, 1, 3)                                           # B, N, K, C
        ax = torch.einsum('bnki,nkio->bno', ax, weight) + bias                # B, N, C
        return torch.reshape(ax, (batch_num*node_num, -1))

    def _topk_gnn(self, set_index, x, weights, all_embeddings, batch_num, node_num):
        """Run ``gnn_layers[set_index]`` on the top-k cosine-similarity graph."""
        device = x.device
        norms = weights.norm(dim=-1)
        cos_ji_mat = torch.matmul(weights, weights.T)
        cos_ji_mat = cos_ji_mat / torch.matmul(norms.view(-1, 1), norms.view(1, -1))
        # torch.topk raises when k exceeds the dimension size, so clamp.
        topk_num = min(self.topk, node_num)
        topk_indices_ji = torch.topk(cos_ji_mat, topk_num, dim=-1)[1]
        self.learned_graph = topk_indices_ji
        # Each node id repeated topk_num times, pairing with its neighbours.
        gated_i = torch.arange(node_num, device=device).unsqueeze(1).repeat(1, topk_num).flatten().unsqueeze(0)
        gated_j = topk_indices_ji.flatten().unsqueeze(0)
        gated_edge_index = torch.cat((gated_j, gated_i), dim=0)
        batch_gated_edge_index = get_batch_edge_index(gated_edge_index, batch_num, node_num).to(device)
        return self.gnn_layers[set_index](
            x, batch_gated_edge_index, node_num=node_num*batch_num, embedding=all_embeddings)

    def forward(self, edge_index_sets, x):
        """Return the graph output reshaped to (batch, node, -1).

        ``x`` must be 3-D ``(batch, node, features)``. One result is produced
        per entry of ``edge_index_sets`` and the results are concatenated on
        the feature axis.
        """
        device = x.device
        batch_num, node_num, all_feature = x.shape
        x = x.reshape(-1, all_feature).contiguous()

        # The embeddings do not depend on the edge set, so look them up once.
        node_embeddings = self.embedding(torch.arange(node_num, device=device))
        # Detached copy: the graph structure / generated weights do not
        # back-propagate into the embedding table through this path.
        weights = node_embeddings.detach().clone().view(node_num, -1)
        all_embeddings = node_embeddings.repeat(batch_num, 1)

        gcn_outs = []
        for i, edge_index in enumerate(edge_index_sets):
            edge_num = edge_index.shape[1]
            cache_edge_index = self.cache_edge_index_sets[i]

            # The cache is refreshed as before but is not read in this class.
            if cache_edge_index is None or cache_edge_index.shape[1] != edge_num*batch_num:
                self.cache_edge_index_sets[i] = get_batch_edge_index(edge_index, batch_num, node_num).to(device)

            if self.adaptive_gcn_option:
                gcn_out = self._adaptive_gcn(x, weights, batch_num, node_num)
            else:
                gcn_out = self._topk_gnn(i, x, weights, all_embeddings, batch_num, node_num)

            gcn_outs.append(gcn_out)

        out = torch.cat(gcn_outs, dim=1)
        # Reshape the graph output itself; the previous code reshaped ``x``
        # here and so returned a tensor that ignored ``gcn_outs``.
        out = out.view(batch_num, node_num, -1)

        return out

    
############################################  Decomposition functions  ############################################

class moving_avg(nn.Module):
    """Moving average along dim 1 of a 3-D tensor, with edge replication.

    The first and last steps along dim 1 are each repeated
    ``(kernel_size - 1) // 2`` times before an ``AvgPool1d`` with no padding
    is applied along that axis.
    """

    def __init__(self, kernel_size, stride):
        super().__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        pad = (self.kernel_size - 1) // 2
        # Replicate the boundary steps on both ends of dim 1.
        head = x[:, :1, :].repeat(1, pad, 1)
        tail = x[:, -1:, :].repeat(1, pad, 1)
        padded = torch.cat((head, x, tail), dim=1)
        # AvgPool1d pools the last axis, so swap dims 1 and 2 around the call.
        pooled = self.avg(padded.transpose(1, 2))
        return pooled.transpose(1, 2)


class series_decomp(nn.Module):
    """Split a series into ``(residual, moving_mean)``.

    ``moving_mean`` is the output of a stride-1 ``moving_avg`` with the given
    kernel size; ``residual`` is the input minus that mean.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        smoothed = self.moving_avg(x)
        residual = x - smoothed
        return residual, smoothed
    

############################################  DLinear, GCNDLinear, ADA_DLinear  ############################################

class proactive_anomaly_detection(nn.Module):
    """
    Decomposition-Linear forecaster with an optional graph correction.

    Continuous channels are split by ``series_decomp`` into a seasonal and a
    trend part, each mapped from ``seq_len`` to ``pred_len`` by its own linear
    layer and summed. Categorical channels (``max_size`` slots per feature)
    go through a separate linear layer. When ``graph`` is set, the output of
    ``graph_module`` is added to both parts.

    Args:
        edge_index_sets: edge sets forwarded to ``graph_module``.
        input_dim: input sequence length (stored as ``seq_len``).
        decay: accepted for interface compatibility; not used in this class.
        dim: passed to ``graph_module`` as both ``dim`` and ``embed_dim``.
        topk: passed to ``graph_module``.
        trans: object whose ``col_index()`` returns
            ``(continuous_index, categorical_index, output_info)``. Required.
        pred_len: output length of the linear layers.
        graph: whether to build and apply ``graph_module``.
        adaptive_gcn_option: passed to ``graph_module``.

    Raises:
        ValueError: if ``trans`` is ``None``.
    """
    def __init__(self, edge_index_sets, input_dim, decay, dim=64, topk=20, trans = None, pred_len = 1, graph=True, adaptive_gcn_option=False):
        super(proactive_anomaly_detection, self).__init__()
        if trans is None:
            # The default is only a placeholder; col_index() is needed below.
            raise ValueError("trans must be an object providing col_index(); got None")
        self.seq_len = input_dim
        self.pred_len = pred_len

        ##### Decomposition kernel size
        kernel_size = 25
        self.decompsition = series_decomp(kernel_size)

        ##### Graph part
        self.graph = graph
        self.edge_index_sets = edge_index_sets
        edge_set_num = len(edge_index_sets)
        embed_dim = dim
        self.topk = topk
        self.adaptive_gcn_option = adaptive_gcn_option

        ##### Continuous / categorical index split
        self.continuous_index, self.categorical_index, self.output_info = trans.col_index()
        self.feature_dim = len(self.output_info)
        self.new_continuous_index = []
        self.new_categorical_index = []
        self.max_size = max(self.output_info)
        if self.graph:
            self.graph_total = graph_module(self.feature_dim, input_dim, self.max_size, dim, embed_dim, edge_set_num, topk, self.adaptive_gcn_option)

        # Lay the features out on the input channel axis: one slot per
        # continuous feature, ``max_size`` consecutive slots per other feature.
        position = 0
        for index in range(len(self.output_info)):
            if index in self.continuous_index:
                self.new_continuous_index.append(position)
                position += 1
            else:
                self.new_categorical_index.extend(range(position, position + self.max_size))
                position += self.max_size

        ##### Linear layers, one per component
        self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
        self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)
        self.Linear_Cat = nn.Linear(self.seq_len, self.pred_len)

    def datainfo(self):
        """Return ``(new_continuous_index, new_categorical_index, output_info)``."""
        return self.new_continuous_index, self.new_categorical_index, self.output_info

    def _build_graph_input(self, x_conti, x_cate, cate_info):
        """Assemble the (Batch, Channel, max_size*Seq len) input of the graph module."""
        conti = x_conti.permute(0, 2, 1).unsqueeze(-1)                     # Batch, Seq len, Conti Channel, 1
        # Zero-pad continuous channels to the same ``max_size`` slot width.
        conti_pad = torch.zeros_like(conti).repeat(1, 1, 1, self.max_size - 1)
        conti = torch.cat((conti, conti_pad), 3)                           # Batch, Seq len, Conti Channel, max_size

        cate = x_cate.permute(0, 2, 1)                                     # Batch, Seq len, Cate Channel
        # Explicit slot width instead of -1, which is ambiguous with 0 features.
        cate = cate.reshape(cate.size(0), cate.size(1), cate_info, self.max_size)

        x_graph = torch.cat((conti, cate), 2)                              # Batch, Seq len, Channel, max_size
        x_graph = x_graph.permute(0, 2, 3, 1)                              # Batch, Channel, max_size, Seq len
        return x_graph.reshape(x_graph.size(0), x_graph.size(1), -1)       # Batch, Channel, max_size*Seq len

    def forward(self, x):
        """Forecast from ``x`` of shape [Batch, Channel, Input length].

        NOTE(review): the final repeat/cat appears to line up only when
        ``pred_len == 1`` - confirm before using a longer horizon.
        """
        batch_size = x.size(0)
        x_conti_in = x[:, self.new_continuous_index, :]        # Batch, Conti Channel, Seq len
        x_cate_in = x[:, self.new_categorical_index, :]        # Batch, Cate Channel,  Seq len
        # Derived from the index list used for slicing above, so it always
        # matches the number of ``max_size``-wide groups in ``x_cate_in``.
        cate_info = len(self.output_info) - len(self.new_continuous_index)

        x_conti = x_conti_in.permute(0, 2, 1).contiguous()                                   # Batch, Seq len, Conti Channel
        seasonal_init, trend_init = self.decompsition(x_conti)
        seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)  # Batch, Conti Channel, Seq len
        # Each component goes through its own layer; the seasonal layer was
        # previously fed the trend, leaving the seasonal part unused.
        seasonal_output = self.Linear_Seasonal(seasonal_init)                                # Batch, Conti Channel, Pred len
        trend_output = self.Linear_Trend(trend_init)                                         # Batch, Conti Channel, Pred len
        x_conti = (seasonal_output + trend_output).permute(0, 2, 1)                          # Batch, Pred len, Conti Channel

        x_cate = self.Linear_Cat(x_cate_in)                                                  # Batch, Cate Channel, Pred len
        x_cate = x_cate.reshape(batch_size, cate_info, self.max_size * self.pred_len)        # Batch, cate_info, max_size*Pred len
        x_cate = x_cate.permute(0, 2, 1)                                                     # Batch, max_size*Pred len, cate_info

        if self.graph:
            x_graph = self._build_graph_input(x_conti_in, x_cate_in, cate_info)
            x_graph_out = self.graph_total(self.edge_index_sets, x_graph)
            x_graph_out = x_graph_out.permute(0, 2, 1)                          # Batch, graph features, Channel
            x_graph_con = x_graph_out[:, 0:1, self.continuous_index]            # Batch, 1, Conti Channel
            x_graph_cat = x_graph_out[:, :, self.categorical_index]             # Batch, graph features, Cate Channel
            x_conti = x_conti + x_graph_con
            x_cate = x_cate + x_graph_cat

        # Tile the continuous prediction to the categorical row count so the
        # two parts can be concatenated on the channel axis.
        x_conti = x_conti.repeat(1, x_cate.size(1), 1)
        out = torch.cat((x_conti, x_cate), 2)
        return out

    
    
    
    
