import numpy as np
from tqdm import tqdm
import os,pickle,time,datetime
import torch
import torch.nn.functional as F
import dataset.json_graph as json_graph
from torch_geometric.loader import DataLoader as GraphDataLoader
from torch_geometric.nn import GCNConv, GATConv, SAGEConv

import utils

class TabularMaskedMLP(torch.nn.Module):
    """MLP that predicts node values which were masked out of its own input.

    Every sample consists of ``num_nodes`` nodes with ``feat_dim`` features
    each.  Nodes are partitioned (via ``node_info``) into categorical,
    numerical and binary nodes.  In ``forward`` each sample is replicated
    ``n_masks`` times; in copy ``j`` the nodes ``j, j + n_masks, ...`` are
    zeroed.  The network sees the masked copies and, for every node, only the
    prediction coming from the copy in which that node was masked is returned.
    """

    def __init__(self, num_nodes, n_masks, feat_dim, node_info, *,
                 dropout_rate=0.3, h_dim=128, out_dim=64):
        """Build the network.

        Args:
            num_nodes: Number of nodes per sample.
            n_masks: Number of masked copies created per sample.  It is also
                used as the chunk size when splitting the node indices, so
                ``num_nodes`` must be divisible by ``n_masks``.
            feat_dim: Number of features per node.
            node_info: Mapping with the keys ``"cat_nodes"``, ``"nume_nodes"``
                and ``"bin_nodes"`` (arrays of node indices) and
                ``"cat_ranges"`` (2-D NumPy array, one row per categorical
                node, whose entries ``>= 0`` are the valid category values;
                negative entries are treated as padding).
            dropout_rate: Dropout probability applied after the hidden stack.
            h_dim: Width of the hidden layers.
            out_dim: Width of the shared representation fed to the heads.
        """
        super().__init__()

        self.dropout_rate = dropout_rate
        self.num_nodes = num_nodes
        self.feat_dim = feat_dim
        self.n_masks = n_masks

        self.cat_nodes = node_info["cat_nodes"]
        self.cat_ranges = node_info["cat_ranges"]
        self.nume_nodes = node_info["nume_nodes"]
        self.bin_nodes = node_info["bin_nodes"]

        # Node indices that get distributed over the masks.
        self.masking_indeces = torch.arange(num_nodes)

        # Category values to compare categorical inputs against when building
        # the one-hot encoding; shape [1, cat_nodes, feat_dim, max_cat].
        temp = np.expand_dims(self.cat_ranges, [1])
        self.input_selector = torch.tensor(np.tile(temp, [1, 1, self.feat_dim, 1]))
        # Flat flag marking the non-padding entries of ``input_selector``.
        self.input_flag = (self.input_selector.flatten() >= 0)

        # Plain Python ints so the layer sizes do not depend on NumPy scalars.
        n_cat_values = int(np.sum((self.cat_ranges >= 0).astype(np.int32)))
        n_nume = self.nume_nodes.shape[0]
        n_bin = self.bin_nodes.shape[0]

        input_dim = (n_cat_values + n_nume + n_bin) * feat_dim
        print("Input dimension is ", input_dim)
        print(n_cat_values, n_nume, n_bin)

        self.dense0 = torch.nn.Sequential(
            torch.nn.Linear(input_dim, h_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(h_dim, h_dim),
            torch.nn.ReLU(),
        )
        self.dense1 = torch.nn.Linear(h_dim, out_dim)

        # One output head per node type.
        self.dense_cat = torch.nn.Linear(out_dim, n_cat_values)
        self.dense_nume = torch.nn.Linear(out_dim, n_nume)
        self.dense_bin = torch.nn.Linear(out_dim, n_bin)

        # Kept with its original default for code that reads ``model.device``;
        # ``forward`` takes the device from its input and does not rely on it.
        self.device = 0

    def to(self, device):
        """Move parameters and helper tensors to ``device`` and return ``self``.

        Returning ``self`` matches ``torch.nn.Module.to`` so that the usual
        ``model = model.to(device)`` idiom works.
        """
        super().to(device)
        self.input_selector = self.input_selector.to(device)
        self.device = device
        return self

    def _create_masks(self, n_nodes, n_masks, bsz=1, device=None):
        """Create the multiplicative masks used for prediction.

        Ideally there would be one mask per node, but that is too expensive.
        Instead ``n_masks`` masks are built such that every node is zeroed in
        exactly one of them: mask ``j`` zeroes nodes ``j, j + n_masks, ...``.

        Args:
            n_nodes: Size of the node dimension of the masks.
            n_masks: Number of masks; must divide the number of node indices
                exactly (``torch.stack`` raises otherwise).
            bsz: Number of samples the masks are replicated for.
            device: Device on which the masks are created (default device if
                ``None``).

        Returns:
            Tensor of shape ``(bsz, n_masks, n_nodes, 1)`` holding ones, with
            zeros at the masked nodes.
        """
        masks = torch.ones(bsz, n_masks, n_nodes, 1, device=device)
        # Chunks of size n_masks: row g holds the nodes g*n_masks .. +n_masks-1,
        # so column j lists all nodes assigned to mask j.
        node_groups = torch.stack(torch.split(self.masking_indeces, n_masks))
        node_groups = node_groups.to(masks.device)
        mask_inds = torch.arange(n_masks, device=masks.device)
        masks[:, mask_inds, node_groups] = 0
        return masks

    @staticmethod
    def _node_index(nodes, device):
        """Return ``nodes`` as a long index tensor on ``device``."""
        return torch.as_tensor(nodes, dtype=torch.long, device=device)

    @staticmethod
    def _gather_masked(values, masks, node_idx, n_graphs, n_masks):
        """Keep, for every node, the prediction from the copy that masked it.

        Args:
            values: Predictions of shape ``(n_graphs * n_masks, n_sel, width)``.
            masks: Masks of shape ``(n_graphs, n_masks, n_nodes, 1)``.
            node_idx: Long tensor with the ``n_sel`` node indices that
                ``values`` refers to.

        Returns:
            Tensor of shape ``(n_graphs, n_sel, width)``.
        """
        n_sel = node_idx.shape[0]
        width = values.shape[-1]
        per_mask = values.reshape(n_graphs, n_masks, n_sel, width)
        is_masked = masks[..., 0][:, :, node_idx] == 0
        # NOTE(review): boolean selection runs mask-major, so within a sample
        # the rows are ordered by mask first and node second.  For
        # 1 < n_masks < num_nodes that is not plain node order -- confirm the
        # loss code arranges its targets the same way.
        # Explicit n_graphs (instead of -1) keeps the view valid when
        # n_sel == 0.
        return per_mask[is_masked].view(n_graphs, n_sel, width)

    def forward(self, data):
        """Predict the masked node values of a batch.

        Args:
            data: Object exposing ``x`` (node features that can be reshaped to
                ``(num_graphs, num_nodes, feat_dim)``) and ``num_graphs``.

        Returns:
            Dict with ``"numerical"`` of shape ``(num_graphs, n_nume, 1)`` and,
            when the corresponding nodes exist, ``"categorical"`` of shape
            ``(num_graphs, n_cat, max_cat)`` (padding slots hold ``-1000``) and
            ``"binary"`` of shape ``(num_graphs, n_bin, 1)``.
        """
        x = data.x
        n_graphs = data.num_graphs
        # Number of groups the nodes are split into for prediction.  With
        # n_masks == num_nodes each node is predicted from all other nodes;
        # with n_masks == 1 every node is masked at once.  More masks cost
        # more memory because the batch is replicated n_masks times.
        n_masks = self.n_masks
        # Work on the device of the input so the model also runs when it was
        # never moved with ``to`` (e.g. on a CPU-only machine).
        device = x.device

        x = x.reshape([n_graphs, self.num_nodes, self.feat_dim])

        # masks: (n_graphs, n_masks, num_nodes, 1)
        masks = self._create_masks(n_nodes=self.num_nodes, n_masks=n_masks,
                                   bsz=n_graphs, device=device)

        # Create n_masks copies of every sample with the respective nodes
        # zeroed.
        # NOTE(review): this view is only shape-consistent for feat_dim == 1;
        # the intended masking for feat_dim > 1 should be confirmed.
        # NOTE(review): masked entries become 0, which is indistinguishable
        # from a genuine 0 input (e.g. category value 0) -- confirm intended.
        x = x.view([-1, 1, self.num_nodes, 1]) * masks

        # n_graphs * n_masks partially masked samples.
        batch_size = n_graphs * n_masks
        x = x.view(batch_size, self.num_nodes, self.feat_dim)

        nume_idx = self._node_index(self.nume_nodes, device)
        bin_idx = self._node_index(self.bin_nodes, device)
        has_cat = self.cat_nodes.shape[0] != 0
        has_bin = self.bin_nodes.shape[0] != 0

        nume_feat = x[:, nume_idx, :].reshape([batch_size, -1])
        bin_feat = x[:, bin_idx, :].reshape([batch_size, -1])

        if has_cat:
            cat_idx = self._node_index(self.cat_nodes, device)
            # [batch, cat_nodes, feat_dim, 1] compared against every category
            # value gives the one-hot encoding; padding columns are dropped.
            cat_x = x[:, cat_idx, :].unsqueeze(3)
            onehot_inputs = cat_x == self.input_selector.to(device)
            cat_feat = onehot_inputs.reshape([batch_size, -1])[:, self.input_flag.to(device)]
            feat_input = torch.cat(
                [cat_feat.to(nume_feat.dtype), nume_feat, bin_feat], dim=1)
        else:
            feat_input = torch.cat([nume_feat, bin_feat], dim=1)

        hidden = self.dense0(feat_input)
        # Dropout must be disabled in eval mode; F.dropout defaults to
        # training=True, so pass the module state explicitly.
        hidden = F.dropout(hidden, p=self.dropout_rate, training=self.training)
        xout = self.dense1(hidden)

        # Prediction for numerical nodes.
        numerical_predict = self._gather_masked(
            self.dense_nume(xout).unsqueeze(2), masks, nume_idx, n_graphs, n_masks)
        pred = dict(numerical=numerical_predict)

        if has_cat:
            # The head only produces logits for valid category slots; scatter
            # them into a dense [cat_nodes, max_cat] grid whose padding slots
            # keep a large negative logit.
            sparse_logits = self.dense_cat(xout)
            valid = torch.as_tensor(self.cat_ranges.flatten() >= 0, device=device)
            logits = torch.full([batch_size, self.cat_ranges.size], -1000.0,
                                dtype=torch.float32, device=device)
            logits[:, valid] = sparse_logits
            logits = logits.reshape(
                [batch_size, self.cat_ranges.shape[0], self.cat_ranges.shape[1]])

            pred['categorical'] = self._gather_masked(
                logits, masks, cat_idx, n_graphs, n_masks)

        if has_bin:
            # Predictions for binary nodes.
            pred['binary'] = self._gather_masked(
                self.dense_bin(xout).unsqueeze(2), masks, bin_idx, n_graphs, n_masks)

        return pred


