import numpy as np
from tqdm import tqdm
import os,pickle,time,datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import dataset.json_graph as json_graph
from torch_geometric.loader import DataLoader as GraphDataLoader
from torch_geometric.nn import GCNConv, GATConv, SAGEConv

import utils

class AutoregressiveGNN(torch.nn.Module):
    """Graph neural network with categorical, numerical and binary node heads.

    Each graph node owns a learnable embedding vector. The embedding is
    concatenated with the node's input features, passed through three
    graph-convolution layers, and the resulting node representations are fed
    to three linear heads:

    * ``dense0`` scores every candidate category of each categorical node,
    * ``dense1`` produces one value per numerical node,
    * ``dense2`` produces one value (logit) per binary node.
    """

    def __init__(self, num_nodes, n_steps, node_info, emb_dim=64, conv_type='GraphSAGE', **kwargs):
        """Build the embedding table, the GNN stack and the prediction heads.

        Args:
            num_nodes: Number of nodes in a single graph.
            n_steps: Width of the per-node input feature vector (it is
                concatenated with the embedding before the first layer).
            node_info: Mapping with the keys ``"cat_nodes"``,
                ``"cat_ranges"``, ``"nume_nodes"`` and ``"bin_nodes"``.
                ``"cat_ranges"`` must be a 2-D integer table whose entries
                index rows of the embedding table; ``-1`` marks padding.
            emb_dim: Size of the learnable per-node embedding.
            conv_type: One of ``"GraphSAGE"``, ``"GCN"`` or ``"GAT"``.
            **kwargs: ``device`` is stored on ``self.device`` (defaults to
                ``'cpu'`` when omitted).

        Raises:
            ValueError: If ``conv_type`` is not one of the supported names.
        """
        super().__init__()

        # NOTE: stored for callers, but dropout is currently not applied
        # anywhere in forward().
        self.dropout_rate = 0.1
        self.num_nodes = num_nodes
        self.device = kwargs.get('device', 'cpu')
        self.cat_nodes = node_info["cat_nodes"]
        self.cat_ranges = torch.tensor(node_info["cat_ranges"])
        self.nume_nodes = node_info["nume_nodes"]
        self.bin_nodes = node_info["bin_nodes"]

        hidden_dim = 64
        out_dim = 64

        conv_classes = {"GraphSAGE": SAGEConv, "GCN": GCNConv, "GAT": GATConv}
        if conv_type not in conv_classes:
            raise ValueError(
                f"Unsupported conv_type {conv_type!r}; "
                f"expected one of {sorted(conv_classes)}"
            )
        CONV = conv_classes[conv_type]

        # allocate an embedding vector for each node
        # shape is (num_nodes, emb_dim)
        self.emb = torch.nn.Parameter(torch.randn(num_nodes, emb_dim, dtype=torch.float32) * 0.1)

        in_dim = emb_dim + n_steps
        if conv_type == 'GAT':
            # Multi-head attention concatenates the heads, so the following
            # layer receives 3 * hidden_dim input features.
            self.GNN_layers = torch.nn.ModuleList([
                CONV(in_dim, hidden_dim, heads=3),
                CONV(3 * hidden_dim, hidden_dim, heads=3),
                CONV(3 * hidden_dim, out_dim, heads=1),
            ])
        else:
            self.GNN_layers = torch.nn.ModuleList([
                CONV(in_dim, hidden_dim),
                CONV(hidden_dim, hidden_dim),
                CONV(hidden_dim, out_dim),
            ])

        # Boolean mask of padded category slots, shape (1, cat_nodes, max_classes).
        self.padding_flag = (self.cat_ranges == -1).unsqueeze(0)

        # Scores one (category embedding, node representation) pair.
        self.dense0 = torch.nn.Linear(emb_dim + out_dim, 1)

        # Heads for numerical / binary nodes: node representation + raw features.
        self.dense1 = torch.nn.Linear(out_dim + n_steps, 1)
        self.dense2 = torch.nn.Linear(out_dim + n_steps, 1)
        # Empty parameter; not used inside this class (presumably filled in
        # by external code - TODO confirm).
        self.validation_95th = torch.nn.parameter.Parameter()

    def to(self, *args, **kwargs):
        """Move the module like ``nn.Module.to`` and return ``self``.

        ``padding_flag`` is a plain tensor attribute (not a parameter or a
        registered buffer), so it is moved explicitly to the device of the
        embedding table.
        """
        super().to(*args, **kwargs)
        self.padding_flag = self.padding_flag.to(self.emb.device)
        return self

    def get_additional_loss_terms(self):
        """Return extra loss terms contributed by the model (always 0)."""
        return 0

    def forward(self, data, **kwargs):
        """Run the GNN on a batch of graphs and return per-node predictions.

        Args:
            data: Batch object exposing ``x`` (node features),
                ``edge_index`` and, unless ``n_graphs`` is given,
                ``num_graphs``. ``x`` may be 2-D ``(graphs * nodes, feat)``
                or 3-D ``(graphs, nodes, feat)``.
            **kwargs: Optional ``n_graphs`` overriding ``data.num_graphs``.

        Returns:
            The dict produced by :meth:`_predict`.
        """
        # node features, edge list
        x, edge_index = data.x, data.edge_index

        n_graphs = kwargs['n_graphs'] if 'n_graphs' in kwargs else data.num_graphs

        # flatten (graphs, nodes, feat) into (graphs * nodes, feat)
        if x.dim() == 3:
            x = x.reshape(x.shape[0] * x.shape[1], -1)

        # repeat embeddings for the different graphs of the batch
        batch_emb = self.emb.repeat(n_graphs, 1)

        # concatenate embeddings with node features
        hidden = torch.cat([batch_emb, x], dim=1)

        # run GNN: ReLU after every layer except the last one
        for layer in self.GNN_layers[:-1]:
            hidden = F.relu(layer(hidden, edge_index))
        xout = self.GNN_layers[-1](hidden, edge_index)

        # recover per-graph layout: (graphs, nodes, dim)
        xout = xout.reshape([-1, self.num_nodes, xout.shape[1]])
        x = x.reshape([xout.shape[0], self.num_nodes, -1])

        # predict for graph nodes
        return self._predict(xout, x)

    def _predict(self, xout, x):
        """Apply the three prediction heads to the node representations.

        Args:
            xout: GNN output of shape ``(batch, num_nodes, out_dim)``.
            x: Input node features of shape ``(batch, num_nodes, n_steps)``.

        Returns:
            dict with keys ``"categorical"`` (logits of shape
            ``(batch, cat_nodes, max_classes)``; padded slots are set to
            ``-1e3``), ``"numerical"`` and ``"binary"`` (each of shape
            ``(batch, selected_nodes, 1)``).
        """
        cat_nodes = self.cat_nodes
        ranges = self.cat_ranges
        batch_size = xout.shape[0]

        # Map padding entries to row 0 WITHOUT modifying self.cat_ranges;
        # the padded slots are masked out of the logits below.
        range_ind = ranges.clamp(min=0)

        # embedding vectors of the candidate categories
        # [batch_size, cat_nodes, classes, emb_dim]
        cat_emb = self.emb[range_ind, :].unsqueeze(0).expand(batch_size, -1, -1, -1)

        # node representation repeated for every candidate category
        # [batch_size, cat_nodes, classes, out_dim]
        cat_out = xout[:, cat_nodes, :].unsqueeze(2).expand(-1, -1, ranges.shape[1], -1)

        cat_input = torch.cat([cat_emb, cat_out], dim=3)

        # logits for categorical nodes with multiple possible values
        # shape is (batch_size, cat_nodes, max_num_class)
        logits = self.dense0(cat_input).squeeze(-1)

        # The mask is moved defensively in case the module was relocated by
        # something other than to() (e.g. .cuda()).
        logits = logits.masked_fill(self.padding_flag.to(logits.device), -1e3)

        # prediction for numerical nodes
        # (batch_size, selected_nodes, 1)
        nume_idx = torch.as_tensor(self.nume_nodes, dtype=torch.long, device=xout.device)
        nume_in = torch.cat([xout[:, nume_idx, :], x[:, nume_idx, :]], dim=-1)
        numerical_predict = self.dense1(nume_in)

        # predictions for binary nodes
        # (batch_size, selected_nodes, 1)
        bin_idx = torch.as_tensor(self.bin_nodes, dtype=torch.long, device=xout.device)
        bin_in = torch.cat([xout[:, bin_idx, :], x[:, bin_idx, :]], dim=-1)
        binary_predict = self.dense2(bin_in)

        return dict(categorical=logits, numerical=numerical_predict, binary=binary_predict)


