import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import time
from torch_geometric.data.batch import Batch
from torch_geometric.nn import GCNConv, GATConv, EdgeConv
from torch_geometric.data import Data
import math
import torch.nn.functional as F
#from .graph_layer import GraphLayer
#from .GDN import GDN
from models.mlp import MLP
from models.AutoregressiveGNN import AutoregressiveGNN



class ReconstructingModel(nn.Module):
    '''
    Wrapper module that implements a reconstructing model around a graph model.

    For every node/sensor the wrapped model receives the ``n_steps - 1`` past
    values followed by the last (current) state. It must return a dict with
    ``'numerical'``, ``'categorical'`` and/or ``'binary'`` predictions.

    Args:
        base_model: Already-created ``nn.Module`` to wrap. It must handle
            multiple variable types itself and expose ``cat_nodes``,
            ``cat_ranges``, ``nume_nodes`` and ``bin_nodes`` (array-likes
            with a ``.shape``).
        n_steps: Window size of the series, including the last state.
        n_feats: Number of features/nodes/sensors.
        **kwargs: Must contain ``device``; other keys are ignored.
    '''
    def __init__(self, base_model, n_steps, n_feats, **kwargs):
        super().__init__()

        self.n_steps = n_steps  # window size, includes the last state
        self.n_feats = n_feats
        self.device = kwargs['device']  # required; KeyError if missing
        self.model = base_model
        self.model.to(self.device)

        # Node groups per variable type, mirrored from the wrapped model so
        # forward() can reshape each prediction head.
        self.cat_nodes = self.model.cat_nodes
        self.cat_ranges = self.model.cat_ranges
        self.nume_nodes = self.model.nume_nodes
        self.bin_nodes = self.model.bin_nodes

    def get_additional_loss_terms(self):
        '''Return whatever extra loss terms the wrapped model reports.'''
        return self.model.get_additional_loss_terms()

    def to(self, device, *args, **kwargs):
        '''
        Move the module like ``nn.Module.to`` and remember ``device``.

        Extra positional/keyword arguments are forwarded to ``nn.Module.to``.
        Returns ``self`` so ``model = model.to(device)`` keeps working.
        '''
        super().to(device, *args, **kwargs)
        self.device = device
        return self

    def forward(self, data, **kwargs):
        '''
        Append the last state to the history window and run the wrapped model.

        Args:
            data: Graph batch exposing ``x`` (reshapeable to
                ``(-1, n_feats, n_steps - 1)``) and ``last_state``
                (reshapeable to ``(-1, n_feats, 1)``).
                NOTE: ``data.x`` is overwritten in place with the
                concatenated float tensor of shape ``(B, n_feats, n_steps)``.
            **kwargs: Forwarded unchanged to the wrapped model.

        Returns:
            dict holding only the variable types that have nodes:
            ``'numerical'``, ``'categorical'``, ``'binary'``. Each entry has
            shape ``(B, n_nodes_of_type, -1)``.
        '''
        n_feats = self.n_feats
        n_steps = self.n_steps

        x = data.x.reshape(-1, n_feats, n_steps - 1)
        last_state = data.last_state.reshape(-1, n_feats, 1)
        # Take the batch size from the reshaped tensor so it counts samples,
        # not node rows, when data.x arrives node-stacked (B * n_feats, ...).
        batch_size = x.shape[0]

        data.x = torch.cat([x, last_state], dim=-1).float()

        out = self.model(data=data, **kwargs)

        # NOTE: Need to collect results in correct order from the masked data.
        result = dict()

        if self.nume_nodes.shape[0] != 0:
            result['numerical'] = out['numerical'].reshape(
                batch_size, self.nume_nodes.shape[0], -1)

        if self.cat_nodes.shape[0] != 0:
            # presumably one cat_ranges entry per categorical node — confirm
            result['categorical'] = out['categorical'].reshape(
                batch_size, self.cat_ranges.shape[0], -1)

        if self.bin_nodes.shape[0] != 0:
            result['binary'] = out['binary'].reshape(
                batch_size, self.bin_nodes.shape[0], -1)

        return result

    def test_prediction(self, data, org_edge_index=None, last_state=None):
        '''
        Run a prediction on ``data``.

        ``org_edge_index`` and ``last_state`` are accepted for backward
        compatibility but ignored. ``forward`` reads the last state from
        ``data.last_state``. (The previous implementation passed them
        positionally to ``forward``, which does not accept them.)
        '''
        return self(data)