
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import time
from torch_geometric.data.batch import Batch
from torch_geometric.nn import GCNConv, GATConv, EdgeConv
from torch_geometric.data import Data
import math
import torch.nn.functional as F
#from .graph_layer import GraphLayer
#from .GDN import GDN
from models.mlp import MLP
from models.AutoregressiveGNN import AutoregressiveGNN



class MaskedModel(nn.Module):
    '''
    Wrapper module that adds feature masking to an existing model.

    Every sample is replicated ``n_masks`` times. In copy ``m`` the last state
    of the features listed in ``masking_groups[m]`` is zeroed, the base model is
    run on all copies at once, and the prediction for each feature is taken
    from the copy in which that feature was masked.

        base_model: already constructed nn.Module used for the predictions.
            It must expose ``nume_nodes``, ``cat_nodes``, ``cat_ranges`` and
            ``bin_nodes`` (objects with a ``shape``), a
            ``get_additional_loss_terms()`` method, and handle the different
            variable types itself.
        n_masks: the number of masks to be used
        n_steps: window size for the series (includes the last state)
        n_feats: number of features/nodes/sensors
        batch_size: number of mask rows that are pre-allocated
        masking_groups: list of lists with feature indices. Each list is a
            group of features to be predicted together from the other groups.
            Each feature is expected to appear in exactly one group; this is
            not validated here.
        device (keyword, required): device the base model and masks are moved to.

    Raises:
        ValueError: if ``masking_groups`` is None.
        KeyError: if the ``device`` keyword argument is missing.
    '''
    def __init__(self, base_model,n_masks, n_steps,n_feats,batch_size = 64,masking_groups = None,**kwargs):
        super().__init__()

        if masking_groups is None:
            raise ValueError("Masked model needs masking groups")

        self.n_steps = n_steps # includes last state
        self.n_feats = n_feats
        self.n_masks = n_masks
        self.device = kwargs['device']
        self.model = base_model
        self.model.to(self.device)

        # Node bookkeeping is owned by the base model; mirror it here so the
        # predictions can be gathered per variable type in forward().
        self.cat_nodes = self.model.cat_nodes
        self.cat_ranges = self.model.cat_ranges
        self.nume_nodes = self.model.nume_nodes
        self.bin_nodes = self.model.bin_nodes

        # Builds the masks and the node -> mask index mappings.
        self.prepare_masking(masking_groups,batch_size)

        print(f"Using masked model with {n_masks} masks on {n_feats} features.")

    def prepare_masking(self,masking_groups,batch_size):
        '''
        (Re)build everything that depends on the masking groups.

        Sets ``masking_groups``, ``feat_order`` (the group members flattened in
        group order), the mask tensor and the node -> mask index lists. The
        index lists are rebuilt here as well so that calling this method again
        with different groups cannot leave forward() gathering from stale masks.
        '''
        self.masking_groups = masking_groups
        self.feat_order = [n for g in masking_groups for n in g]
        self.masks = self._create_masks(self.n_feats,self.n_masks,batch_size).to(self.device)
        (self.nume_node_mask_inds,
         self.cat_node_mask_inds,
         self.bin_node_mask_inds) = self._create_mask_mappings()

    def get_additional_loss_terms(self):
        '''Return the additional loss terms reported by the wrapped model.'''
        return self.model.get_additional_loss_terms()

    def to(self,device):
        '''
        Move the module and the cached masks to ``device``.

        Returns ``self`` so that the usual ``model = model.to(device)`` idiom
        works, matching ``nn.Module.to``.
        '''
        super().to(device)
        self.device = device

        # The masks are a plain tensor attribute (not a buffer), so
        # nn.Module.to does not move them for us.
        masks = getattr(self, 'masks', None)
        if masks is not None:
            self.masks = masks.to(device)

        return self

    def _mask_ids_for(self, nodes):
        '''Return, for each node in ``nodes``, the index of every masking group containing it.'''
        return [
            mask_id
            for node in nodes
            for mask_id, group in enumerate(self.masking_groups)
            if node in group
        ]

    def _create_mask_mappings(self):
        '''
        Map every numerical, categorical and binary node to its mask index.

        Returns a tuple ``(nume, cat, bin)`` of lists. forward() uses them to
        pick each node's prediction from the mask copy in which it was masked,
        which only lines up when every node belongs to exactly one group.
        '''
        return (
            self._mask_ids_for(self.nume_nodes),
            self._mask_ids_for(self.cat_nodes),
            self._mask_ids_for(self.bin_nodes),
        )

    def _create_masks(self,n_feats, n_masks,bsz = 1):
        '''
        Generates masks that are used for prediction. We could have at most N masks for N nodes. We can also choose M<N masks such that each node is only masked once.

        Returns a tensor of shape (bsz, n_masks, n_feats, 1) that is 1
        everywhere except for the features of group ``i`` in mask ``i``,
        which are 0.
        '''
        masks = torch.ones(bsz,n_masks,n_feats,1)
        for i,group in enumerate(self.masking_groups):
            masks[:,i,group]=0

        return masks

    # NOTE: the loss is currently defined outside of this module; it may need
    # to be brought back in here.

    def _collect_from_masks(self, predictions, mask_ids, n_nodes, batch_size):
        '''Pick, for each of the ``n_nodes`` outputs, the prediction made under mask ``mask_ids[node]``.'''
        per_mask = predictions.reshape(batch_size,self.n_masks,n_nodes,-1)
        return per_mask[:,mask_ids,torch.arange(n_nodes)]

    def forward(self, data, **kwargs):
        '''
        Run the base model on all masked copies and collect the predictions.

        ``data`` must provide ``x``, ``edge_index`` and ``last_state``.
        ``x`` is reshaped to (batch, 1, n_feats, n_steps-1) and ``last_state``
        to (batch, 1, n_feats, 1).

        Side effect: ``data.x`` (and ``data.edge_index`` when it has more than
        one dimension) are overwritten with their mask-expanded versions.

        Returns a dict with the keys 'numerical', 'categorical' and 'binary'
        for the variable types that are present. Each value has one entry
        per sample and node.
        '''
        x, edge_index, last_state = data.x, data.edge_index, data.last_state
        n_masks = self.n_masks
        n_features = self.n_feats
        n_steps = self.n_steps

        x = x.reshape([-1, 1,n_features, n_steps-1])
        # Derive the batch size after the reshape so it is the number of
        # samples regardless of how the incoming x was laid out.
        batch_size = x.shape[0]

        # All rows of self.masks are identical, so a single row is broadcast
        # over the batch. This also works for batches larger than the
        # batch_size the masks were allocated with.
        masks = self.masks[:1].to(last_state.device).detach()

        #create masked version of last state
        masked_last_state = last_state.reshape(-1,1,n_features,1) * masks  #(batch_size,num_masks,n_nodes,1)

        repeated_x = x.expand(batch_size,n_masks,n_features,x.shape[3]) #shape: (bsz,n_masks,n_nodes,n_steps-1)

        x_with_masked_last_state = torch.cat((repeated_x,masked_last_state), dim=-1)
        x = x_with_masked_last_state.view(-1,n_features,n_steps) # shape: (bsz*n_masks,n_nodes,n_steps)

        effective_batch_size = x.shape[0]

        data.x = x
        if edge_index.dim()>1:
            n_edges = edge_index.shape[1]

            # One copy of the edge list per mask, shifted by n_features nodes
            # for every additional mask.
            # NOTE(review): the offset depends on the mask index only. If
            # edge_index already spans several batched graphs this does not
            # match the (sample, mask) ordering of x - confirm what the base
            # model expects.
            offsets = n_features*torch.arange(n_masks, device=edge_index.device)
            offsets = offsets.repeat_interleave(n_edges)
            data.edge_index = edge_index.repeat([1,n_masks])+offsets

        out = self.model(data=data,n_graphs = effective_batch_size, **kwargs) #dict with categorical, binary and numerical predictions

        #NOTE: Need to collect results in correct order from the masked data.
        result = dict()

        if self.nume_nodes.shape[0]!=0:
            result['numerical'] = self._collect_from_masks(
                out['numerical'],self.nume_node_mask_inds,self.nume_nodes.shape[0],batch_size)

        if self.cat_nodes.shape[0]!=0:
            # The categorical output is indexed by cat_ranges rather than cat_nodes.
            result['categorical'] = self._collect_from_masks(
                out['categorical'],self.cat_node_mask_inds,self.cat_ranges.shape[0],batch_size)

        if self.bin_nodes.shape[0]!=0:
            result['binary'] = self._collect_from_masks(
                out['binary'],self.bin_node_mask_inds,self.bin_nodes.shape[0],batch_size)

        return result

    def test_prediction(self,data, org_edge_index, last_state):
        '''
        Run a forward pass using the given edge index and last state.

        forward() only accepts ``data``, so the two tensors are attached to
        ``data`` (overwriting its ``edge_index`` and ``last_state``) before
        the module is called.
        '''
        data.edge_index = org_edge_index
        data.last_state = last_state
        return self(data)