import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.inits import glorot, zeros
from .graph_generator import Graph_Generator
from .util import normalize_adjacency,get_laplacian_matrix
from torch_geometric.nn import  DenseGraphConv,DenseSAGEConv
from torch_geometric.utils import dense_to_sparse

from torch.nn.utils import weight_norm
import datetime

class Chomp1d(nn.Module):
    """Drop the trailing ``chomp_size`` entries along the third axis.

    Used after a ``Conv1d`` that pads both ends of the sequence, so that the
    extra right-hand positions introduced by the padding are removed again.

    Args:
        chomp_size: Number of trailing positions to remove. ``0`` leaves the
            input untouched.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return a contiguous copy of ``x`` without its last ``chomp_size`` steps."""
        # ``x[:, :, :-0]`` is ``x[:, :, :0]`` (an empty tensor), so a zero
        # chomp has to be handled explicitly to act as a no-op.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
    """Residual block: temporal conv -> dense graph conv -> temporal conv.

    Each temporal stage is ``weight_norm(Conv1d) -> Chomp1d -> ReLU -> Dropout``.
    Between the two stages the tensor has axes 1 and 2 swapped, is passed
    through a ``DenseGraphConv`` together with the adjacency, and is swapped
    back. The block input (projected by a 1x1 convolution when the channel
    counts differ) is added to the result before a final ReLU.

    Args:
        n_inputs: Channel count of the incoming tensor.
        n_outputs: Channel count produced by the block.
        kernel_size: Kernel size of both ``Conv1d`` layers.
        stride: Stride of both ``Conv1d`` layers.
        dilation: Dilation of both ``Conv1d`` layers.
        padding: Padding of both ``Conv1d`` layers; the same amount is
            chomped off afterwards.
        dropout: Dropout probability used after each temporal stage.

    NOTE(review): the residual sum needs ``out`` and ``res`` to have matching
    shapes. That holds for ``stride=1`` with
    ``padding=(kernel_size-1)*dilation`` as passed by ``TemporalConvNet``;
    other combinations are unverified.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        conv_options = {"stride": stride, "padding": padding, "dilation": dilation}

        # First temporal stage.
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, **conv_options))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Graph convolution applied between the two temporal stages.
        self.gconv1 = DenseGraphConv(in_channels=n_outputs, out_channels=n_outputs, aggr='mean')

        # Second temporal stage.
        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, **conv_options))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net1 = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1)
        self.net2 = nn.Sequential(self.conv2, self.chomp2, self.relu2, self.dropout2)

        # A 1x1 convolution aligns channel counts on the skip path when needed.
        if n_inputs == n_outputs:
            self.downsample = None
        else:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Draw the convolution weights from N(0, 0.01); ``gconv1`` is left as is.

        NOTE(review): ``conv1``/``conv2`` are wrapped in ``weight_norm``, which
        derives ``weight`` from ``weight_g``/``weight_v``; writing to
        ``weight.data`` may therefore have no lasting effect - confirm.
        """
        convs = [self.conv1, self.conv2]
        if self.downsample is not None:
            convs.append(self.downsample)
        for conv in convs:
            conv.weight.data.normal_(0, 0.01)

    def forward(self, x_adj):
        """Run the block on ``[x, adj]`` and return ``[activation, adj]``.

        Args:
            x_adj: Sequence whose first item is the feature tensor fed to the
                ``Conv1d`` layers and whose second item is the adjacency
                passed to ``DenseGraphConv``.
                (assumes x is (batch, channels, length) - TODO confirm)

        Returns:
            A two-element list: the block output and the unchanged adjacency,
            so blocks can be chained inside ``nn.Sequential``.
        """
        inputs = x_adj[0]
        adj = x_adj[1]

        hidden = self.net1(inputs)
        # DenseGraphConv is fed the tensor with axes 1 and 2 swapped; swap
        # back afterwards for the second Conv1d stage.
        hidden = self.gconv1(hidden.permute(0, 2, 1), adj).permute(0, 2, 1)
        hidden = self.net2(hidden)

        if self.downsample is None:
            shortcut = inputs
        else:
            shortcut = self.downsample(inputs)
        return [self.relu(hidden + shortcut), adj]
    

class TemporalConvNet(nn.Module):
    """Stack of ``TemporalBlock`` layers with exponentially growing dilation.

    Level ``i`` uses dilation ``2 ** i`` and padding
    ``(kernel_size - 1) * 2 ** i``; all levels use stride 1.

    Args:
        num_inputs: Channel count of the tensor fed to the first block.
        num_channels: Output channel count of each level, in order; its
            length determines the number of blocks.
        kernel_size: Kernel size shared by every block.
        dropout: Dropout probability shared by every block.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        blocks = []
        prev_channels = num_inputs
        for level, channels in enumerate(num_channels):
            dilation = 2 ** level
            blocks.append(
                TemporalBlock(
                    prev_channels,
                    channels,
                    kernel_size,
                    stride=1,
                    dilation=dilation,
                    padding=(kernel_size - 1) * dilation,
                    dropout=dropout,
                )
            )
            # The next level consumes what this level produced.
            prev_channels = channels

        self.network = nn.Sequential(*blocks)

    def forward(self, x, adj):
        """Feed ``[x, adj]`` through every block and return the final ``[x, adj]`` list."""
        packed = [x, adj]
        return self.network(packed)
    

class SKI_CL(nn.Module):
    """Graph-sampling forecaster built on a graph-aware temporal conv encoder.

    A ``Graph_Generator`` samples an adjacency from the input, the adjacency
    is normalised, and a three-level ``TemporalConvNet`` (128 channels each,
    kernel size 3) encodes the input under that adjacency. A linear layer
    maps the 128 encoder features to ``horizon`` outputs.

    Args:
        num_nodes: Number of graph nodes; sizes the generator and ``mask``.
        input_dim: Stored on the instance; not otherwise used here.
        rnn_units: Stored as ``hidden_dim``; not otherwise used here
            (the encoder width is fixed at 128).
        output_dim: Stored on the instance; not otherwise used here.
        lag: Input channel count of the encoder's first ``Conv1d``.
        horizon: Size of the last output dimension.
        num_layers: Stored on the instance; not otherwise used here.
        cheb_k: Accepted for interface compatibility; unused.
    """

    def __init__(self,num_nodes,input_dim,rnn_units,output_dim,lag,horizon,num_layers,cheb_k):
        super(SKI_CL, self).__init__()

        self.num_nodes = num_nodes
        self.input_dim = input_dim
        self.hidden_dim = rnn_units
        self.output_dim = output_dim
        self.horizon = horizon
        self.num_layers = num_layers
        self.graph_generator = Graph_Generator(num_nodes)

        # Boolean identity mask. Registered as a non-persistent buffer so it
        # follows ``.to(device)`` / ``.cuda()`` with the module instead of
        # hard-requiring CUDA at construction time, while keeping the
        # ``state_dict`` keys unchanged.
        self.register_buffer(
            'mask', torch.eye(num_nodes, num_nodes).bool(), persistent=False
        )

        self.encoder = TemporalConvNet(lag, num_channels =[128]*3, kernel_size=3, dropout=0.2)
        self.out_proj = nn.Linear(128, horizon)

    def forward(self,source,prior_form,teacher_forcing_ratio=0.5):
        """Sample a graph from ``source`` and predict ``horizon`` values.

        Args:
            source: Input tensor; it is permuted with ``(0, 2, 1)`` before
                the encoder, so it must be 3-D.
                (assumes (batch, num_nodes, lag) - TODO confirm)
            prior_form: Forwarded unchanged to ``graph_generator.sample``.
            teacher_forcing_ratio: Accepted for interface compatibility;
                unused.

        Returns:
            Tuple ``(output, feature)`` where ``feature`` is the encoder
            output with axes 1 and 2 swapped and ``output`` is
            ``out_proj(feature)``.
        """
        # NOTE(review): squeeze() drops *every* size-1 axis, including a
        # batch axis of size 1 - confirm the generator copes with that.
        x = torch.squeeze(source)

        adj = self.graph_generator.sample(x,prior_form,hard = True)
        adj = normalize_adjacency(adj)
        adj = adj.to(source.device)

        # Unpacking fails fast if the adjacency is not 3-D.
        _batch_size, _node_num, _ = adj.shape

        x = torch.permute(source,(0,2,1))
        x_adj = self.encoder(x,adj)
        x = x_adj[0]
        feature = torch.permute(x,(0,2,1))

        output = self.out_proj(feature)

        return output,feature















    
