from typing import OrderedDict
import numpy as np
from tqdm import tqdm
import os,pickle,time,datetime
import torch
import torch.nn.functional as F
import dataset.json_graph as json_graph
from torch_geometric.loader import DataLoader as GraphDataLoader
from torch_geometric.nn import GCNConv, GATConv, SAGEConv

import utils

class MLP(torch.nn.Module):
    """Flat MLP baseline over per-node, per-time-step features.

    All node features for every time step are flattened into one vector.
    Numerical and binary values are used as-is; categorical codes are one-hot
    encoded against ``cat_ranges``. The vector goes through a ReLU MLP, and
    three linear heads decode it into numerical, categorical and binary
    predictions.

    Args:
        num_nodes: Nodes per sample; ``data.x`` is reshaped to
            ``[-1, num_nodes, n_steps]``.
        n_steps: Values stored per node.
        node_info: Dict with ``"cat_nodes"``, ``"nume_nodes"`` and
            ``"bin_nodes"`` (index arrays exposing ``.shape``), plus
            ``"cat_ranges"``. ``cat_ranges`` is 2-D, one row per categorical
            node, listing the code values that produce a one-hot hit;
            negative entries are padding. ``"random_nodes"`` is optional.
        out_dim: Width of the representation fed to the output heads.
        h_dim: Hidden width of the MLP.
        n_layers: Number of extra ``h_dim -> h_dim`` ReLU layers.
        dropout_rate: Dropout probability. In training mode it is applied to
            the encoded input and to the hidden representation.
        **kwargs: ``device`` (default ``"cpu"``) is where the index buffers
            are created.
    """

    def __init__(self, num_nodes, n_steps, node_info, out_dim=128, h_dim=256, n_layers=1, dropout_rate=0.2, **kwargs):
        super().__init__()

        self.dropout_rate = dropout_rate
        self.num_nodes = num_nodes
        self.n_steps = n_steps
        self.device = kwargs.get("device", "cpu")
        self.cat_nodes = node_info["cat_nodes"]
        self.nume_nodes = node_info["nume_nodes"]
        self.bin_nodes = node_info["bin_nodes"]
        self.random_nodes = node_info.get("random_nodes", {"numerical": [], "binary": [], "categorical": []})
        self.node_info = node_info

        cat_ranges = torch.tensor(node_info["cat_ranges"])
        cat_ranges_np = cat_ranges.numpy()
        # One-hot comparison table of shape [1, num_cat_nodes, n_steps, max_cat]:
        # each categorical node's code row repeated for every time step.
        input_selector = torch.tensor(np.tile(np.expand_dims(cat_ranges_np, [1]), [1, 1, n_steps, 1]))

        # Non-persistent buffers: .to()/.cuda()/.cpu() move them together with
        # the parameters, and the state_dict keys stay unchanged.
        self.register_buffer("cat_ranges", cat_ranges.to(self.device), persistent=False)
        self.register_buffer("input_selector", input_selector.to(self.device), persistent=False)
        # Selects the non-padding columns of the flattened one-hot tensor.
        self.register_buffer("input_flag", self.input_selector.flatten() >= 0, persistent=False)

        n_cat_feats = int((cat_ranges_np >= 0).sum())
        n_nume = self.nume_nodes.shape[0]
        n_bin = self.bin_nodes.shape[0]
        input_dim = (n_cat_feats + n_nume + n_bin) * n_steps
        print("Input dimension is ", input_dim)

        layers = [torch.nn.Linear(input_dim, h_dim), torch.nn.ReLU()]
        for _ in range(n_layers):
            # Build fresh modules per layer. ``[...] * n_layers`` would repeat
            # the same Linear instance and silently tie the weights.
            layers += [torch.nn.Linear(h_dim, h_dim), torch.nn.ReLU()]
        self.dense0 = torch.nn.Sequential(*layers)
        self.dense1 = torch.nn.Linear(h_dim, out_dim)

        self.dense_cat = torch.nn.Linear(out_dim, n_cat_feats)
        self.dense_nume = torch.nn.Linear(out_dim, n_nume)
        self.dense_bin = torch.nn.Linear(out_dim, n_bin)
        self.output_size = n_cat_feats + n_nume + n_bin

    def to(self, *args, **kwargs):
        """Move/cast like ``nn.Module.to`` and record the resulting device.

        Returns ``self`` so that ``model = MLP(...).to(device)`` works.
        """
        super().to(*args, **kwargs)
        self.device = self.input_selector.device
        return self

    def get_additional_loss_terms(self):
        """Return extra loss terms; this model contributes none (0)."""
        return 0

    def forward(self, data, **kwargs):
        """Predict every node type from the flattened node features.

        Args:
            data: Object whose ``x`` tensor is reshaped to
                ``[batch, num_nodes, n_steps]``. Categorical nodes must hold
                raw codes comparable (``==``) to the ``cat_ranges`` entries.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            dict with an entry for each node type present:
            ``"numerical"`` -> ``[batch, num_nume_nodes, 1]``,
            ``"categorical"`` -> ``[batch, cat_ranges.shape[0], cat_ranges.shape[1]]``
            logits, ``"binary"`` -> ``[batch, num_bin_nodes, 1]``.
        """
        x = data.x.reshape([-1, self.num_nodes, self.n_steps])
        batch_size = x.shape[0]
        dtype = self.dense1.weight.dtype

        feats = []
        if self.nume_nodes.shape[0] != 0:
            feats.append(x[:, self.nume_nodes, :].reshape([batch_size, -1]).to(dtype))

        if self.cat_nodes.shape[0] != 0:
            # [batch, cat_nodes, n_steps, 1] == [1, cat_nodes, n_steps, max_cat]
            cat_x = x[:, self.cat_nodes, :].unsqueeze(3)
            onehot = (cat_x == self.input_selector).reshape([batch_size, -1])[:, self.input_flag]
            # The comparison yields bool. Cast it so Linear gets floating
            # input even when categorical nodes are the only feature type.
            feats.append(onehot.to(dtype))

        if self.bin_nodes.shape[0] != 0:
            feats.append(x[:, self.bin_nodes, :].reshape([batch_size, -1]).to(dtype))

        feat_input = torch.cat(feats, dim=1)
        # Input dropout runs after encoding. On the raw codes, its 1/(1-p)
        # rescaling would turn category codes into values that no longer match
        # cat_ranges. training=self.training turns dropout off in eval().
        feat_input = F.dropout(feat_input, p=self.dropout_rate, training=self.training)
        hidden = F.dropout(self.dense0(feat_input), p=self.dropout_rate, training=self.training)
        xout = self.dense1(hidden)

        pred = dict()
        if self.nume_nodes.shape[0] != 0:
            pred['numerical'] = self.dense_nume(xout).unsqueeze(2)

        if self.cat_nodes.shape[0] != 0:
            sparse_logits = self.dense_cat(xout)
            # Scatter the compact logits into a dense table. Padding slots
            # (cat_ranges < 0) keep -1000, presumably so they never win an
            # argmax/softmax downstream.
            logits = sparse_logits.new_full((batch_size, self.cat_ranges.numel()), -1000.0)
            logits[:, self.cat_ranges.flatten() >= 0] = sparse_logits
            pred['categorical'] = logits.reshape([batch_size, self.cat_ranges.shape[0], self.cat_ranges.shape[1]])

        if self.bin_nodes.shape[0] != 0:
            pred['binary'] = self.dense_bin(xout).unsqueeze(2)

        return pred


