import numpy as np
from tqdm import tqdm
import os,pickle,time,datetime
import torch
import torch.nn.functional as F
import dataset.json_graph as json_graph
from torch_geometric.loader import DataLoader as GraphDataLoader
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from models.AnomalyTransformer.AnomalyTransformer import AnomalyTransformer
from models.GDN.GDN import GDN
import utils
from models.GDN import net_struct as gdn_net_struct
from models.GDN import preprocess as gdn_preprocess
from torch import nn
import sys
class MixedVariableModel(torch.nn.Module):
    '''
    Wrapper for implementing models with mixed variable types.

    Handles the logistics of the different node types (numerical, categorical,
    binary): selects each group of nodes, one-hot encodes the categorical ones,
    feeds the stacked features to an underlying model and produces a separate
    prediction per node type.

    The underlying model simply needs to accept (b, input_dim, n_steps) and
    return something that flattens to (b, input_dim * out_dim), where input_dim
    is the final feature count after one-hot encoding.

    Args:
        num_nodes: number of nodes per graph; used to reshape data.x.
        n_steps: number of time steps stored per node.
        node_info: mapping with "cat_nodes", "nume_nodes", "bin_nodes" (index
            arrays exposing .shape) and "cat_ranges", a [cat_nodes, max_cat]
            table of category values in which negative entries are padding.
            An optional "random_nodes" entry is stored as-is.
        underlying_model: one of 'transformer', 'GDN' or 'mlp'.
        **kwargs: must contain 'device' and 'output_dim' (or 'out_dim');
            'GDN' additionally reads kwargs['config'].task. All kwargs are
            forwarded to the transformer / GDN constructors.

    Raises:
        ValueError: if no output dimension is given or underlying_model is not
            one of the supported names.
    '''
    def __init__(self, num_nodes, n_steps, node_info, underlying_model='transformer', **kwargs):

        super().__init__()

        self.dropout_rate = 0.2
        self.num_nodes = num_nodes
        self.n_steps = n_steps
        self.device = kwargs['device']
        self.cat_nodes = node_info["cat_nodes"]
        self.cat_ranges = torch.tensor(node_info["cat_ranges"]).to(self.device)
        self.nume_nodes = node_info["nume_nodes"]
        self.bin_nodes = node_info["bin_nodes"]
        self.random_nodes = node_info.get("random_nodes", {"numerical": [], "binary": [], "categorical": []})
        self.node_info = node_info

        # True where cat_ranges holds a real category value (negative = padding).
        self.cat_mask = self.cat_ranges >= 0

        # cat_ranges tiled over the time axis: [1, cat_nodes, n_steps, max_cat].
        temp = np.expand_dims(self.cat_ranges.cpu().numpy(), [1])
        self.input_selector = torch.tensor(np.tile(temp, [1, 1, self.n_steps, 1])).to(self.device)
        self.input_flag = (self.input_selector.flatten() >= 0)

        out_dim = kwargs.get('output_dim', None)
        if out_dim is None:
            out_dim = kwargs.get('out_dim', None)
        if out_dim is None:
            raise ValueError("MixedVariableModel requires 'output_dim' or 'out_dim' in kwargs")

        # Plain ints so layer sizes do not depend on numpy scalar handling.
        n_cat_outputs = int(np.sum((self.cat_ranges.cpu().numpy() >= 0).astype(np.int32)))
        n_nume = int(self.nume_nodes.shape[0])
        n_bin = int(self.bin_nodes.shape[0])
        input_dim = n_cat_outputs + n_nume + n_bin
        print("Input dimension is ", input_dim)

        if underlying_model == 'transformer':
            self.underlying_model = AnomalyTransformer(input_dim, n_steps, task='prediction', **kwargs)

        elif underlying_model == 'GDN':
            dataset = kwargs['config'].task
            fc_struc = gdn_net_struct.get_fc_graph_struc(dataset)
            feature_names = [f'feature_{i}' for i in range(input_dim)]
            fc_edge_index = gdn_preprocess.build_loc_net(fc_struc, feature_names, feature_map=feature_names)
            fc_edge_index = torch.tensor(fc_edge_index, dtype=torch.long)
            edge_index_sets = [fc_edge_index]
            self.underlying_model = GDN(edge_index_sets, input_dim, dim=out_dim, input_dim=n_steps, out_layer_inter_dim=out_dim, **kwargs)

        elif underlying_model == 'mlp':

            print('----', input_dim, file=sys.stderr)
            # nn.Linear acts on the last axis, which is n_steps for the
            # (b, input_dim, n_steps) input, so the first layer must take
            # n_steps features (not input_dim) to yield (b, input_dim, out_dim).
            self.underlying_model = nn.Sequential(nn.Linear(n_steps, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, out_dim))

        else:
            raise ValueError(
                f"Unknown underlying_model {underlying_model!r}; expected 'transformer', 'GDN' or 'mlp'"
            )

        self.underlying_model.to(self.device)

        # One linear head per variable type on top of the flattened model output.
        self.dense_cat = torch.nn.Linear(out_dim * input_dim, n_cat_outputs).to(self.device)
        self.dense_nume = torch.nn.Linear(out_dim * input_dim, n_nume).to(self.device)
        self.dense_bin = torch.nn.Linear(out_dim * input_dim, n_bin).to(self.device)
        self.output_size = n_cat_outputs + n_nume + n_bin

    def to(self, device):
        '''
        Move parameters and the helper tensors kept as plain attributes to device.

        Returns self so that `model = model.to(device)` works as it does for
        any other torch.nn.Module.
        '''
        super().to(device)
        # These tensors are not registered buffers, so they must be moved by hand.
        self.input_selector = self.input_selector.to(device)
        self.input_flag = self.input_flag.to(device)
        self.cat_ranges = self.cat_ranges.to(device)
        self.cat_mask = self.cat_mask.to(device)
        self.device = device
        self.underlying_model.to(device)
        return self

    def get_additional_loss_terms(self):
        '''
        Return the extra loss terms reported by the underlying model.

        NOTE(review): this only delegates; the 'mlp' backbone is a plain
        nn.Sequential, which has no such method, so calling this with it raises
        AttributeError.
        '''
        return self.underlying_model.get_additional_loss_terms()

    def forward(self, data, **kwargs):
        '''
        Run the underlying model and split its output into per-type predictions.

        Args:
            data: object whose .x reshapes to [-1, num_nodes, n_steps].
            **kwargs: optional 'n_graphs' overriding the inferred batch size.

        Returns:
            dict with an entry for every node type that is present:
            'numerical' [b, len(nume_nodes), 1], 'categorical'
            [b, cat_nodes, max_cat] logits (padding slots hold -1000) and
            'binary' [b, len(bin_nodes), 1].
        '''
        x = data.x.reshape([-1, self.num_nodes, self.n_steps])
        batch_size = kwargs.get('n_graphs', x.shape[0])

        # Collect the feature rows of every node type that exists.
        feats = []
        if self.nume_nodes.shape[0] != 0:
            nume_feat = x[:, self.nume_nodes, :].reshape([batch_size, -1, self.n_steps])
            feats.append(nume_feat)

        if self.cat_nodes.shape[0] != 0:
            # [batch, cat_nodes, n_steps, 1]
            cat_x = x[:, self.cat_nodes, :].unsqueeze(3)

            # [batch, cat_nodes, n_steps, max_cat]
            onehot_inputs = (cat_x == self.input_selector)

            # Put the time axis last BEFORE dropping padded categories, so each
            # resulting row is the time series of a single category indicator.
            # Masking the (node, step, category) flattening and then reshaping
            # to (..., n_steps) would interleave steps with categories.
            cat_feat = onehot_inputs.permute(0, 1, 3, 2)[:, self.cat_mask]
            cat_feat = cat_feat.reshape(batch_size, -1, self.n_steps).to(x.dtype)
            feats.append(cat_feat)

        if self.bin_nodes.shape[0] != 0:
            # [batch, bin_nodes, n_steps]
            bin_feat = x[:, self.bin_nodes, :].reshape([batch_size, -1, self.n_steps])
            feats.append(bin_feat)

        # (b, input_dim, n_steps)
        feat_input = torch.cat(feats, dim=1).float()

        # Apply the underlying model and flatten its output per sample.
        xout = self.underlying_model(feat_input).reshape(batch_size, -1)

        pred = dict()

        if self.nume_nodes.shape[0] != 0:
            # prediction for numerical nodes [b, len(nume_nodes), 1]
            pred['numerical'] = self.dense_nume(xout).unsqueeze(2)

        if self.cat_nodes.shape[0] != 0:
            # One logit per valid (node, category) pair, scattered into a dense
            # [b, cat_nodes, max_cat] table; padding slots keep a large negative
            # logit.
            sparse_logits = self.dense_cat(xout)
            valid = self.cat_mask.flatten().to(sparse_logits.device)

            # Follow the head's dtype/device so the masked assignment is valid
            # wherever the parameters currently live.
            logits = torch.full(
                (batch_size, valid.numel()),
                -1000.0,
                dtype=sparse_logits.dtype,
                device=sparse_logits.device,
            )
            logits[:, valid] = sparse_logits
            logits = logits.reshape([batch_size, self.cat_ranges.shape[0], self.cat_ranges.shape[1]])

            pred['categorical'] = logits

        if self.bin_nodes.shape[0] != 0:
            # predictions for binary nodes [b, len(bin_nodes), 1]
            pred['binary'] = self.dense_bin(xout).unsqueeze(2)

        return pred


