# A learnable graph neural network layer, adapted from MTGNN, that mixes information across a batch of time series.

# Input shape: (number of time series, sequence length, feature dimension).
# The layer maps along the time-series and feature dimensions; time steps are not mixed
# (except by the optional sequence-wide layer norm).


from __future__ import division
import torch
import torch.nn as nn
from torch.nn import init
import numbers
import torch.nn.functional as F

class LayerNorm(nn.Module):
    """Layer normalisation whose affine parameters are sliced per node.

    The learnable ``weight``/``bias`` have shape ``normalized_shape``. In
    ``forward`` they are indexed as ``[:, idx, :]``, so the affine path
    assumes a 3-D ``normalized_shape`` with nodes on its middle axis
    (``gtnet`` uses ``(channels, nodes, length)``). Statistics are always
    computed over every axis of the input except the first.
    """

    __constants__ = ['normalized_shape', 'weight', 'bias', 'eps', 'elementwise_affine']

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        shape = ((normalized_shape,) if isinstance(normalized_shape, numbers.Integral)
                 else normalized_shape)
        self.normalized_shape = tuple(shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not self.elementwise_affine:
            # No affine transform: the attributes exist but are None.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        else:
            # Uninitialised storage; reset_parameters() fills it right below.
            self.weight = nn.Parameter(torch.Tensor(*shape))
            self.bias = nn.Parameter(torch.Tensor(*shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the affine parameters to the identity transform (weight=1, bias=0)."""
        if not self.elementwise_affine:
            return
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input, idx):
        """Normalise ``input`` over all axes except the first.

        Args:
            input: tensor whose trailing axes match the node-sliced parameter shape.
            idx: node indices used to slice ``weight``/``bias`` along axis 1;
                ignored when ``elementwise_affine`` is False.
        """
        dims = tuple(input.shape[1:])
        if not self.elementwise_affine:
            return F.layer_norm(input, dims, self.weight, self.bias, self.eps)
        gamma = self.weight[:, idx, :]
        beta = self.bias[:, idx, :]
        return F.layer_norm(input, dims, gamma, beta, self.eps)

    def extra_repr(self):
        fmt = '{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}'
        return fmt.format(**self.__dict__)
    

class nconv(nn.Module):
    """Propagate node features through a dense node-to-node matrix.

    Computes ``out[n, c, v, l] = sum_w A[v, w] * x[n, c, w, l]``, i.e. the node
    axis is axis 2 of ``x`` and ``A`` maps source nodes ``w`` to targets ``v``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, A):
        propagated = torch.einsum('ncwl,vw->ncvl', x, A)
        return propagated.contiguous()

class linear(nn.Module):
    """Channel-mixing linear map implemented as a 1x1 convolution.

    Acts independently at every spatial position of a ``(N, C, H, W)`` input,
    mapping ``c_in`` channels to ``c_out`` channels.
    """

    def __init__(self, c_in, c_out, bias=True):
        super().__init__()
        self.mlp = nn.Conv2d(
            in_channels=c_in,
            out_channels=c_out,
            kernel_size=1,
            padding=0,
            stride=1,
            bias=bias,
        )

    def forward(self, x):
        out = self.mlp(x)
        return out

class mixprop(nn.Module):
    """Mix-hop graph propagation (as in MTGNN).

    Starting from ``h_0 = x``, repeats ``gdep`` times
    ``h_{k+1} = alpha * x + (1 - alpha) * A_hat @ h_k``, where ``A_hat`` is the
    row-normalised ``adj + I``. Then it concatenates ``[h_0, ..., h_gdep]`` along
    the channel axis and projects them to ``c_out`` channels with a 1x1 conv.

    Args:
        c_in: channels of ``x`` (axis 1).
        c_out: output channels.
        gdep: number of propagation steps.
        dropout: stored for API compatibility; NOT applied in ``forward``.
        alpha: fraction of the original input retained at every step.
    """

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super(mixprop, self).__init__()
        self.nconv = nconv()
        self.mlp = linear((gdep + 1) * c_in, c_out)
        self.gdep = gdep
        # NOTE(review): kept as an attribute only; forward never uses it.
        self.dropout = dropout
        self.alpha = alpha

    def forward(self, x, adj):
        """Propagate ``x`` (``(batch, c_in, nodes, time)``) over ``adj`` (``(nodes, nodes)``)."""
        # Add self-loops. Allocating the identity directly on x's device avoids a
        # host allocation + host-to-device copy on every forward call.
        adj = adj + torch.eye(adj.size(0), device=x.device)
        d = adj.sum(1)
        a = adj / d.view(-1, 1)  # row-normalise so each node averages its neighbours
        h = x
        out = [h]
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
            out.append(h)
        ho = torch.cat(out, dim=1)
        return self.mlp(ho)

class graph_constructor(nn.Module):
    """Learn a sparse, directed adjacency matrix over nodes (as in MTGNN).

    Node vectors come from two embedding tables (or from ``static_feat`` rows).
    They are passed through ``tanh(alpha * Linear(.))`` and combined into an
    antisymmetric score matrix. Only the ``k`` largest entries of each row are kept.

    Args:
        nnodes: number of nodes (size of the embedding tables).
        k: number of outgoing edges kept per node; clamped to the number of
            nodes selected by ``idx``.
        dim: node-embedding dimension.
        device: stored as ``self.device``; the mask now follows ``adj``'s device.
        alpha: saturation factor applied inside the ``tanh`` calls.
        static_feat: optional per-node feature matrix indexed by rows; when
            given it replaces the learned embeddings.
    """

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super(graph_constructor, self).__init__()
        self.nnodes = nnodes
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1 = nn.Linear(xd, dim)
            self.lin2 = nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.emb2 = nn.Embedding(nnodes, dim)
            self.lin1 = nn.Linear(dim, dim)
            self.lin2 = nn.Linear(dim, dim)

        self.device = device
        self.k = k
        self.dim = dim
        self.alpha = alpha
        self.static_feat = static_feat

    def forward(self, idx):
        """Return the ``(len(idx), len(idx))`` sparsified adjacency for nodes ``idx``."""
        if self.static_feat is None:
            nodevec1 = self.emb1(idx)
            nodevec2 = self.emb2(idx)
        else:
            nodevec1 = self.static_feat[idx, :]
            nodevec2 = nodevec1

        nodevec1 = torch.tanh(self.alpha * self.lin1(nodevec1))
        nodevec2 = torch.tanh(self.alpha * self.lin2(nodevec2))

        # a is antisymmetric (a.T == -a), so after the ReLU at most one of
        # adj[i, j] / adj[j, i] is non-zero: the learned graph is directed.
        a = torch.mm(nodevec1, nodevec2.transpose(1, 0)) - torch.mm(nodevec2, nodevec1.transpose(1, 0))
        adj = F.relu(torch.tanh(self.alpha * a))

        # Keep the k strongest entries per row; the small noise breaks ties.
        # Clamping k avoids a topk error when fewer than k nodes are selected.
        k = min(self.k, adj.size(1))
        _, topk_idx = (adj + torch.rand_like(adj) * 0.01).topk(k, 1)
        mask = torch.zeros_like(adj)
        mask.scatter_(1, topk_idx, 1.0)
        return adj * mask

class gtnet(nn.Module):
    """Single graph-mixing block adapted from MTGNN.

    Maps an input of shape ``(num_nodes, seq_length, in_dim)`` to
    ``(num_nodes, seq_length, residual_channels)``. The steps are:
    1x1 input conv, dropout, graph propagation (or a 1x1 conv), a residual
    connection, and an optional layer norm. The per-time-step layer norm mixes
    only over channels and nodes; the sequence layer norm
    (``use_sequence_layer_norm=True``) also mixes across time steps.

    Args:
        gcn_true: use two ``mixprop`` layers (over A and A^T) instead of a 1x1 conv.
        buildA_true: learn the adjacency with ``graph_constructor`` instead of
            using ``predefined_A``.
        gcn_depth: propagation steps per ``mixprop``.
        num_nodes: number of time series (graph nodes).
        device: device for the default adjacency and the node index tensor.
        predefined_A: fixed adjacency used when ``buildA_true`` is False;
            defaults to the identity on ``device``.
        static_feat: optional static node features for ``graph_constructor``.
        dropout: dropout probability applied after the input conv.
        subgraph_size: edges kept per node by ``graph_constructor``.
        node_dim: node-embedding size for ``graph_constructor``.
        dilation_exponential: unused; kept for signature compatibility.
        conv_channels: unused; kept for signature compatibility (the graph
            layers consume ``start_conv``'s ``residual_channels`` output).
        residual_channels: hidden/output channel count.
        seq_length: required input sequence length.
        in_dim: input feature dimension.
        propalpha: ``alpha`` for ``mixprop``.
        tanhalpha: ``alpha`` for ``graph_constructor``.
        layer_norm_affline: learnable affine parameters in the layer norm.
        use_layer_norm: apply a layer norm after the residual connection.
        use_sequence_layer_norm: normalise jointly over the whole sequence.
    """

    def __init__(self, gcn_true, buildA_true, gcn_depth, num_nodes, device, predefined_A=None,
                 static_feat=None, dropout=0.3, subgraph_size=20, node_dim=40, dilation_exponential=1,
                 conv_channels=32, residual_channels=32, seq_length=12, in_dim=2, propalpha=0.05,
                 tanhalpha=3, layer_norm_affline=True, use_layer_norm=True, use_sequence_layer_norm=False):
        super(gtnet, self).__init__()
        self.use_sequence_layer_norm = use_sequence_layer_norm
        self.gcn_true = gcn_true
        self.buildA_true = buildA_true
        self.num_nodes = num_nodes
        self.dropout = dropout
        self.predefined_A = predefined_A
        if self.predefined_A is None:
            # Default graph: identity matrix, placed on the requested device
            # (previously hard-coded to "cuda", which broke CPU-only runs).
            self.predefined_A = torch.eye(self.num_nodes, device=device)

        # start_conv emits residual_channels channels, so every layer consuming
        # its output must accept residual_channels inputs.
        self.residual_convs = nn.Conv2d(in_channels=residual_channels,
                                        out_channels=residual_channels,
                                        kernel_size=(1, 1))
        self.gconv1 = mixprop(residual_channels, residual_channels, gcn_depth, dropout, propalpha)
        self.gconv2 = mixprop(residual_channels, residual_channels, gcn_depth, dropout, propalpha)
        self.seq_length = seq_length
        self.use_layer_norm = use_layer_norm

        if self.use_layer_norm and self.use_sequence_layer_norm:
            self.norm = LayerNorm((residual_channels, num_nodes, self.seq_length),
                                  elementwise_affine=layer_norm_affline)
        elif self.use_layer_norm and not self.use_sequence_layer_norm:
            self.norm = LayerNorm((residual_channels, num_nodes, 1),
                                  elementwise_affine=layer_norm_affline)

        self.start_conv = nn.Conv2d(in_channels=in_dim,
                                    out_channels=residual_channels,
                                    kernel_size=(1, 1))
        self.gc = graph_constructor(num_nodes, subgraph_size, node_dim, device,
                                    alpha=tanhalpha, static_feat=static_feat)

        self.idx = torch.arange(self.num_nodes).to(device)

    def forward(self, input, idx=None):
        """Apply the block to ``input`` of shape ``(nodes, seq_length, in_dim)``.

        ``idx`` optionally selects the nodes (default: all ``num_nodes``); it is
        used consistently for graph construction and for the layer norm.
        """
        # (N, L, C_in) -> (1, C_in, N, L): add a batch axis, move features to channels.
        input = input.unsqueeze(0)
        input = input.permute(0, 3, 1, 2)
        seq_len = input.size(3)

        assert seq_len == self.seq_length, 'input sequence length not equal to preset sequence length'

        node_idx = self.idx if idx is None else idx

        # Built before start_conv/dropout to keep the original RNG consumption order.
        adp = None
        if self.gcn_true:
            adp = self.gc(node_idx) if self.buildA_true else self.predefined_A

        x = self.start_conv(input)
        residual = x
        x = F.dropout(x, self.dropout, training=self.training)
        if self.gcn_true:
            x = self.gconv1(x, adp) + self.gconv2(x, adp.transpose(1, 0))
        else:
            x = self.residual_convs(x)

        x = x + residual[:, :, :, -x.size(3):]

        if self.use_layer_norm:
            x = self._apply_layer_norm(x, node_idx)

        # (1, C, N, L) -> (N, L, C)
        x = x.permute(0, 2, 3, 1)
        return x.squeeze(0)

    def _apply_layer_norm(self, x, idx):
        """Normalise ``x`` of shape ``(B, C, N, L)`` with ``self.norm``; same shape out."""
        if self.use_sequence_layer_norm:
            # Joint statistics over (C, N, L): this mixes time steps.
            return self.norm(x, idx)
        # Per-time-step statistics over (C, N) only: fold time into the batch axis.
        batch_size, seq_len = x.size(0), x.size(3)
        x = x.permute(0, 3, 1, 2).unsqueeze(-1)  # (B, L, C, N, 1)
        x = x.reshape(batch_size * seq_len, x.size(2), x.size(3), x.size(4))
        x = self.norm(x, idx)
        x = x.reshape(batch_size, seq_len, x.size(1), x.size(2), x.size(3))
        return x.squeeze(-1).permute(0, 2, 3, 1)  # back to (B, C, N, L)



def test_gnn_not_mixing_sequence_dimension(nn, datapoint):
    """Check that a model never mixes information across time steps.

    For every time step ``i``, the gradient of output ``y[0, i, 0]`` with
    respect to the input must be approximately zero at every other time step
    ``j != i``, across all series and features. The check is two-sided (past
    and future), even though the messages say "causal".

    Args:
        nn: callable model mapping ``datapoint`` to an output indexed as
            ``(series, time, feature)``. (The name shadows ``torch.nn`` locally.)
        datapoint: input tensor with ``requires_grad=True`` and time on axis 1.

    Raises:
        AssertionError: if any cross-time gradient reaches 1e-4 in magnitude.
    """
    from torch.autograd import grad
    # Fix the RNG so dropout / tie-breaking noise is reproducible.
    torch.random.manual_seed(0)
    x = datapoint
    y = nn(x)
    L = x.size(1)
    for i in range(L):  # include the last step, which range(L - 1) skipped
        g = grad(y[0, i, 0], x, retain_graph=True, allow_unused=True)[0]
        if g is None:
            # y[0, i, 0] does not depend on x at all, so nothing can leak.
            continue
        other_steps = torch.ones(L, dtype=torch.bool, device=g.device)
        other_steps[i] = False
        leaked = g[:, other_steps]  # every series/feature at every other time step
        if leaked.numel() > 0:
            assert torch.max(torch.abs(leaked)) < 1e-4, "function is not causal"
    print("Test passed, function is causal")


# Not a pytest test case: it takes a model and a datapoint as arguments.
test_gnn_not_mixing_sequence_dimension.__test__ = False

def main():
    """Smoke-test gtnet on random data: forward pass, time-mixing check, shape report."""
    # Random input laid out as (nr timeseries, sequence length, feature dim).
    n_series, seq_len, n_features = 10, 196, 32
    x = torch.rand(n_series, seq_len, n_features, requires_grad=True)

    config = dict(
        gcn_true=True,
        buildA_true=True,
        gcn_depth=1,
        num_nodes=n_series,
        device='cpu',
        predefined_A=None,
        static_feat=None,
        dropout=0.3,
        subgraph_size=n_series,  # must not exceed num_nodes
        node_dim=40,
        dilation_exponential=1,
        conv_channels=32,
        residual_channels=32,
        seq_length=seq_len,
        in_dim=n_features,
        propalpha=0.05,
        tanhalpha=3,
        layer_norm_affline=True,
        use_sequence_layer_norm=False,
    )
    model = gtnet(**config)

    y = model(x)
    test_gnn_not_mixing_sequence_dimension(model, x)
    print(y.shape)
    print("Input and output same shape: {0}".format(x.shape == y.shape))


if __name__ == '__main__':
    main()


