import numpy as np
from tqdm import tqdm
import os,pickle,time,datetime
import torch
import torch.nn.functional as F
import dataset.json_graph as json_graph
from torch_geometric.loader import DataLoader as GraphDataLoader
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, DenseGCNConv

import utils

class MaskedGNN(torch.nn.Module):
    """GNN that predicts each node's last state with that node's last state hidden.

    For every input graph, ``n_masks`` copies are built. In each copy a different
    group of nodes has its ``last_state`` feature zeroed (see ``_create_masks``).
    The GNN runs over all copies. For every node, the returned prediction is the
    one from the copy in which that node was masked. Outputs are split by node
    type (numerical, categorical, binary) as described by ``node_info``.
    """

    def __init__(self, num_nodes, n_masks, emb_dim, feat_dim, node_info, conv_type='GraphSAGE'):
        """Build the model.

        Args:
            num_nodes: number of nodes in every graph. ``forward`` reshapes the
                batch by this value, so all graphs must have this many nodes.
            n_masks: number of masked copies per graph. ``num_nodes`` must be
                divisible by it (``_create_masks`` stacks equal-sized groups).
            emb_dim: currently ignored; the embedding width is fixed at 64 (see NOTE below).
            feat_dim: number of input features per node, before the masked last
                state is appended.
            node_info: dict with keys "cat_nodes", "cat_ranges", "nume_nodes" and
                "bin_nodes". The node collections must support ``.shape[0]`` and
                be usable as indices. "cat_ranges" is a rectangular nested list
                of embedding-row ids per categorical node, padded with -1.
            conv_type: 'GraphSAGE', 'GCN' or 'GAT'.

        Raises:
            ValueError: if ``conv_type`` is not one of the supported values.
        """
        super().__init__()

        self.dropout_rate = 0.1
        self.num_nodes = num_nodes
        self.n_masks = n_masks
        # Node ids that _create_masks splits into masking groups.
        self.masking_indeces = torch.arange(num_nodes)

        self.cat_nodes = node_info["cat_nodes"]
        self.cat_ranges = torch.tensor(node_info["cat_ranges"])
        self.nume_nodes = node_info["nume_nodes"]
        self.bin_nodes = node_info["bin_nodes"]
        # Device that per-call tensors (the masks) are moved to; updated by to().
        self.device = 'cpu'

        # NOTE(review): the emb_dim argument is overridden here, so callers cannot
        # change the embedding width. Kept as-is so existing checkpoints still load.
        emb_dim = 64
        out_dim = 64
        feat_dim += 1  # the (masked) last state is appended as one extra input feature

        if conv_type == "GraphSAGE":
            CONV = SAGEConv
        elif conv_type == "GCN":
            CONV = GCNConv
        elif conv_type == 'GAT':
            CONV = GATConv
        else:
            raise ValueError(
                f"Unknown conv_type {conv_type!r}; expected 'GraphSAGE', 'GCN' or 'GAT'"
            )

        # One learnable embedding vector per node, shape (num_nodes, emb_dim).
        self.emb = torch.nn.Parameter(torch.randn(num_nodes, emb_dim, dtype=torch.float32) * 0.1)

        in_dim = emb_dim + feat_dim
        if conv_type == 'GAT':
            # Hidden GAT layers use 3 heads; the following layer takes 3 * 64 inputs.
            self.GNN_layers = torch.nn.ModuleList([
                CONV(in_dim, 64, heads=3),
                CONV(3 * 64, 64, heads=3),
                CONV(3 * 64, out_dim, heads=1),
            ])
        else:
            self.GNN_layers = torch.nn.ModuleList([
                CONV(in_dim, 64),
                CONV(64, 64),
                CONV(64, out_dim),
            ])

        # (1, n_cat_nodes, max_categories): True where cat_ranges holds -1 padding.
        self.padding_flag = (self.cat_ranges == -1).unsqueeze(0)

        # Scores a (category embedding, node output) pair -> one logit per candidate category.
        self.dense0 = torch.nn.Linear(emb_dim + out_dim, 1)
        # Output head for numerical nodes (one value per node).
        self.dense1 = torch.nn.Linear(out_dim, 1)
        # Output head for binary nodes (one value per node).
        self.dense2 = torch.nn.Linear(out_dim, 1)
        # Empty placeholder parameter; not read anywhere in this class.
        self.validation_95th = torch.nn.parameter.Parameter()

    def to(self, device):
        """Move the module to ``device`` and return it, like ``torch.nn.Module.to``.

        ``padding_flag`` is a plain tensor attribute, not a registered buffer.
        It is therefore moved explicitly here. The device is also recorded for
        the masks that ``forward`` creates.
        """
        super().to(device)
        self.device = device
        self.padding_flag = self.padding_flag.to(device)
        return self

    def _create_masks(self, n_nodes, n_masks, bsz=1):
        """Build ``n_masks`` node masks such that every node is masked exactly once.

        ``self.masking_indeces`` is split into consecutive chunks of size
        ``n_masks``. The m-th entry of every chunk is masked in mask m. With the
        default ``masking_indeces`` (``arange(num_nodes)``), node k is therefore
        zeroed in mask ``k % n_masks``. ``torch.stack`` requires equal chunks,
        so the number of nodes must be divisible by ``n_masks``.

        Returns:
            Float tensor of shape (bsz, n_masks, n_nodes, 1), created on CPU.
            It is 0 where the node is masked in that copy and 1 elsewhere.
        """
        node_inds = self.masking_indeces
        # (n_nodes // n_masks, n_masks): each row is one consecutive chunk of node ids.
        node_groups = torch.stack(torch.split(node_inds, n_masks))
        mask_inds = torch.arange(n_masks)
        masks = torch.ones(bsz, n_masks, n_nodes, 1)
        # mask_inds broadcasts across the rows of node_groups, so column m of
        # every chunk is zeroed in mask m.
        masks[:, mask_inds, node_groups] = 0
        return masks

    @staticmethod
    def _mask_owner(masks):
        """Return, for each node, the index of the mask copy in which it is zeroed.

        ``masks`` has shape (bsz, n_masks, n_nodes, 1) with exactly one zero per
        node along the mask axis. The layout is identical for every graph, so
        only the first graph is inspected.
        """
        return (masks[0, :, :, 0] == 0).float().argmax(dim=0)

    @staticmethod
    def _select_masked(pred, owner, nodes, n_masks):
        """Pick each node's prediction from the copy in which that node was masked.

        Args:
            pred: (n_graphs * n_masks, len(nodes), ...) predictions. Rows are
                ordered graph-major, then mask copy.
            owner: per-node mask index, as returned by ``_mask_owner``.
            nodes: node ids that ``pred``'s second axis corresponds to.
            n_masks: number of mask copies per graph.

        Returns:
            (n_graphs, len(nodes), ...) tensor whose node axis follows ``nodes``.
        """
        n_sel = pred.shape[1]
        per_mask = pred.reshape(-1, n_masks, *pred.shape[1:])
        which_mask = owner[nodes].to(per_mask.device)
        node_pos = torch.arange(n_sel, device=per_mask.device)
        # Adjacent advanced indices collapse (mask, node) into one node axis.
        return per_mask[:, which_mask, node_pos]

    def forward(self, data):
        """Predict every node's last state from a copy of its graph where it is masked.

        Args:
            data: batch exposing ``x`` (num_graphs * num_nodes rows of node
                features), ``edge_index``, ``last_state`` (num_graphs * num_nodes
                values) and ``num_graphs``.

        Returns:
            dict containing, only for node types that exist:
              'numerical':   (num_graphs, n_numerical, 1)
              'categorical': (num_graphs, n_cat_nodes, max_categories) logits
              'binary':      (num_graphs, n_binary, 1)
            Along the node axis, entries follow the order of ``nume_nodes`` /
            ``cat_nodes`` / ``bin_nodes``.
        """
        x, edge_index, last_state = data.x, data.edge_index, data.last_state
        n_masks = self.n_masks
        n_graphs = data.num_graphs

        # (n_graphs, n_masks, num_nodes, 1). n_masks must divide num_nodes exactly.
        # More masks mean finer prediction but proportionally more memory.
        masks = self._create_masks(n_nodes=self.num_nodes, n_masks=n_masks, bsz=n_graphs).to(self.device)

        # n_masks copies of each graph's last state, each with a different node group zeroed.
        masked_last_state = last_state.view([-1, 1, self.num_nodes, 1]) * masks
        x = x.reshape([-1, 1, self.num_nodes, x.shape[-1]])
        repeated_x = x.repeat(1, n_masks, 1, 1)

        # Rows are ordered (graph, mask copy, node): n_graphs * n_masks graphs in total.
        x_with_masked_last_state = torch.cat((repeated_x, masked_last_state), dim=-1)
        x_with_masked_last_state = x_with_masked_last_state.view(-1, x.shape[-1] + 1)

        # The same node embeddings are prepended for every graph copy.
        batch_emb = self.emb.repeat(n_graphs * n_masks, 1)
        xin = torch.cat([batch_emb, x_with_masked_last_state], dim=1)

        # NOTE(review): the node tensor has n_graphs * n_masks * num_nodes rows, but
        # data.edge_index is passed unchanged. If it is a plain PyG batch of
        # n_graphs graphs, only the first n_graphs copies receive edges. Confirm
        # the loader supplies an edge_index covering all masked copies.
        hidden = xin
        for layer in self.GNN_layers[:-1]:
            hidden = F.relu(layer(hidden, edge_index))
        hidden = F.dropout(hidden, training=self.training, p=self.dropout_rate)

        xout = self.GNN_layers[-1](hidden, edge_index)
        # (n_graphs * n_masks, num_nodes, out_dim)
        xout = xout.reshape([-1, self.num_nodes, xout.shape[1]])

        pred = self._predict(xout)
        owner = self._mask_owner(masks)

        result = {}
        if self.nume_nodes.shape[0] != 0:
            result['numerical'] = self._select_masked(pred['numerical'], owner, self.nume_nodes, n_masks)
        if self.cat_nodes.shape[0] != 0:
            result['categorical'] = self._select_masked(pred['categorical'], owner, self.cat_nodes, n_masks)
        if self.bin_nodes.shape[0] != 0:
            result['binary'] = self._select_masked(pred['binary'], owner, self.bin_nodes, n_masks)
        return result

    def _predict(self, xout):
        """Apply the per-node-type output heads to the GNN node outputs.

        Args:
            xout: (n_copies, num_nodes, out_dim) node outputs.

        Returns:
            dict containing, only for node types that exist:
              'categorical': (n_copies, n_cat_nodes, max_categories) logits, with
                             padded category slots set to -1e3
              'numerical':   (n_copies, n_numerical, 1)
              'binary':      (n_copies, n_binary, 1)
        """
        pred = {}
        if self.cat_nodes.shape[0] != 0:
            ranges = self.cat_ranges
            # Padded (negative) ids borrow embedding row 0; their logits are
            # overwritten below. clamp returns a copy, so self.cat_ranges keeps its padding.
            range_ind = ranges.clamp(min=0)

            # (n_copies, n_cat_nodes, max_categories, emb_dim): embedding of each candidate category.
            cat_emb = self.emb[range_ind, :].unsqueeze(0).repeat([xout.shape[0], 1, 1, 1])
            # (n_copies, n_cat_nodes, max_categories, out_dim): node output repeated per category.
            cat_out = xout[:, self.cat_nodes, :].unsqueeze(2).repeat(1, 1, ranges.shape[1], 1)

            cat_input = torch.cat([cat_emb, cat_out], dim=3)
            logits = self.dense0(cat_input).squeeze(-1)
            pred['categorical'] = logits.masked_fill(self.padding_flag, -1e3)

        if self.nume_nodes.shape[0] != 0:
            pred['numerical'] = self.dense1(xout[:, self.nume_nodes, :])

        if self.bin_nodes.shape[0] != 0:
            pred['binary'] = self.dense2(xout[:, self.bin_nodes, :])

        return pred


