import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import time
from torch_geometric.data.batch import Batch
from torch_geometric.nn import GCNConv, GATConv, EdgeConv
from torch_geometric.data import Data
import math
import torch.nn.functional as F
#from .graph_layer import GraphLayer
#from .GDN import GDN
from models.mlp import MLP
from models.AutoregressiveGNN import AutoregressiveGNN
from models.MixedVariableModel import MixedVariableModel
from models import mlp
import sys

class MaskedModel(nn.Module):
    '''
    Wrapper Module for Models that implements masking.

    A forward pass first runs ``base_model`` on the input window, then feeds
    its (detached) prediction together with a masked copy of the last state
    into a second MLP ("current step model"). For every feature the final
    prediction is taken from the mask in which that feature was hidden.

        base_model: Already created nn.Module of the model to be used with masking. Must handle multiple variable types itself.
                    Must expose ``cat_nodes``, ``cat_ranges``, ``nume_nodes``, ``bin_nodes``, ``random_nodes``,
                    ``node_info`` and ``get_additional_loss_terms()``.
        n_masks: the number of masks to be used
        n_steps: window size for series (includes the last state)
        n_feats: number of features/nodes/sensors
        batch_size: number of per-sample mask copies that are pre-built
        masking_groups: list of lists with feature indices. Each list is a group of features to be predicted together from the other groups
        loss_func: callable ``loss_func(pred, data, return_separate_losses=True)`` returning ``(loss, per_type_errors)``
        **kwargs: ``device`` (required) and ``anomaly_filter`` (optional, default 0.95)

    Raises:
        ValueError: if ``masking_groups`` or ``loss_func`` is missing.
        KeyError: if ``device`` is not passed as keyword argument.
    '''
    def __init__(self, base_model,n_masks, n_steps,n_feats,batch_size = 64,masking_groups = None,loss_func = None,**kwargs):
        super().__init__()

        # Fail fast, before any model is moved or any tensor is allocated.
        if masking_groups is None:
            raise ValueError("Masked model needs masking groups")
        if loss_func is None:
            raise ValueError("Masked Model needs loss function")

        self.n_steps = n_steps # includes last state
        print(n_steps)
        self.n_feats = n_feats
        self.device = kwargs['device']
        self.model = base_model
        self.model.to(self.device)

        # Node index sets per variable type, taken from the wrapped model.
        # They are plain (non-buffer) tensors and therefore stay on the CPU.
        self.cat_nodes = torch.tensor(self.model.cat_nodes)
        self.cat_ranges = torch.tensor(self.model.cat_ranges)
        self.nume_nodes = torch.tensor(self.model.nume_nodes)
        self.bin_nodes = torch.tensor(self.model.bin_nodes)
        self.random_nodes = self.model.random_nodes
        self.loss_func = loss_func
        self.masking_groups = masking_groups
        self.testing = False

        # For each node of a type: index of the mask that hides it.
        self.nume_node_mask_inds,self.cat_node_mask_inds,self.bin_node_mask_inds = self._create_mask_mappings()

        # Relative-error threshold above which a feature's last state is not trusted.
        self.anomaly_filter=kwargs.get('anomaly_filter',0.95)

        self.n_masks = n_masks

        self.prepare_masking(masking_groups,batch_size)
        self.dropout = nn.Dropout(0.2)

        print(kwargs,file=sys.stderr)
        # Second-stage model. Its second argument (3) matches the three input
        # channels assembled in forward(): trust weight, base prediction and
        # masked last state.
        self.current_step_model = mlp.MLP(n_feats,3,self.model.node_info,h_dim = 128,n_layers=2,dropout_rate=0, device=self.device)
        self.current_step_model.to(self.device)

        print(f"Using masked model with {n_masks} masks on {n_feats} features.",file=sys.stderr)

    def prepare_masking(self,masking_groups,batch_size):
        '''
        (Re)builds the mask tensor for the given groups and stores the
        flattened feature order.
        '''
        self.masking_groups = masking_groups
        self.feat_order = [n for g in masking_groups for n in g]
        self.masks = self._create_masks(self.n_feats,self.n_masks,batch_size).to(self.device)

    def get_additional_loss_terms(self):
        '''
        Forwards to the wrapped model's additional loss terms.
        '''
        return self.model.get_additional_loss_terms()

    def to(self,device):
        '''
        Moves the module to ``device`` and remembers it in ``self.device``.

        Returns the module itself, like ``nn.Module.to``, so that the common
        ``model = model.to(device)`` pattern keeps working.
        '''
        super().to(device)

        self.device = device
        return self

    def _mask_ids_for(self,nodes):
        '''
        Returns, for every node in ``nodes`` (in order), the index of each
        masking group that contains it.
        '''
        return [
            mask_id
            for i in nodes
            for mask_id,g in enumerate(self.masking_groups)
            if i in g
        ]

    def _create_mask_mappings(self):
        '''
        Maps numerical, categorical and binary nodes to the index of the mask
        that hides them.

        Returns a tuple ``(nume, cat, bin)`` of lists. The lists only line up
        one-to-one with the node lists when every node belongs to exactly one
        masking group.
        '''
        nume_node_mask_inds = self._mask_ids_for(self.nume_nodes)
        cat_node_mask_inds = self._mask_ids_for(self.cat_nodes)
        bin_node_mask_inds = self._mask_ids_for(self.bin_nodes)

        return nume_node_mask_inds,cat_node_mask_inds,bin_node_mask_inds

    def _create_masks(self,n_feats, n_masks,bsz = 1):
        '''
        Generates masks that are used for prediction. We could have at most N masks for N nodes. We can also choose M<N masks such that each node is only masked once.

        Returns a tensor of shape (bsz, n_masks, n_feats, 1) that is 1
        everywhere except for the features of group ``i`` in mask ``i``,
        which are 0.
        '''
        masks = torch.ones(bsz,n_masks,n_feats,1)
        for i,group in enumerate(self.masking_groups):
            masks[:,i,group]=0

        return masks

    # NOTE: the loss is currently defined outside of this class and injected
    # via ``loss_func``; it may have to move back in here later.

    def forward(self, data, **kwargs):
        '''
        Runs the base model and then the masked current-step model.

            data: object with attributes ``x``, ``edge_index`` and ``last_state``.
                  ``x.shape[0]`` is used as batch size and ``x`` is reshaped to
                  (-1, n_feats, n_steps-1). ``data.x`` and (for 2-D edge
                  indices) ``data.edge_index`` are overwritten in place.
            **kwargs: passed through to both underlying models.

        Returns a dict with the keys 'numerical', 'categorical' and 'binary'
        (only for node types that exist); each entry holds, per feature, the
        prediction from the mask in which that feature was hidden.
        '''
        x, edge_index, last_state = data.x, data.edge_index, data.last_state
        n_masks = self.n_masks
        n_features = self.n_feats

        batch_size = x.shape[0]
        if batch_size <= self.masks.shape[0]:
            masks = self.masks[:batch_size]
        else:
            # All per-sample copies are identical, so a batch larger than the
            # pre-built one can be served by expanding a single copy.
            masks = self.masks[:1].expand(batch_size,-1,-1,-1)
        masks = masks.to(self.device).detach()

        # Stage 1: prediction of the wrapped model and its per-feature errors.
        pred = self.model(data=data,n_graphs = batch_size, **kwargs)
        _,pred_errors = self.loss_func(pred,data,return_separate_losses=True)

        # Scatter the per-type errors back into feature order.
        merged_pred_errors = torch.zeros((batch_size,n_features),device=self.device)
        if 'numerical' in pred_errors:
            merged_pred_errors[:,self.nume_nodes]=pred_errors['numerical'].float()
        if 'categorical' in pred_errors:
            merged_pred_errors[:,self.cat_nodes]=pred_errors['categorical'].float()
        if 'binary' in pred_errors:
            merged_pred_errors[:,self.bin_nodes]=pred_errors['binary'].float()

        # Errors of random nodes are ignored. This is done on the error tensor
        # itself so that the result is the same on every device.
        for node_ids in self.random_nodes.values():
            merged_pred_errors[:,node_ids]=0

        # Error relative to the largest error of the sample.
        # NOTE(review): a sample whose errors are all 0 yields NaN here and
        # therefore weight 0 for every feature - confirm this is intended.
        relative_errors = merged_pred_errors/merged_pred_errors.max(dim=-1, keepdim=True)[0]

        # 1 -> trust the observed last state, 0 -> fall back to the previous value.
        merged_pred_weights = (relative_errors<self.anomaly_filter).float()

        old_vals = x.reshape(-1,n_features,self.n_steps-1)[:,:,-1]

        modded_last_state = old_vals*(1-merged_pred_weights)
        modded_last_state += last_state.reshape(-1,n_features)*merged_pred_weights

        # One masked copy of the last state per mask: (bsz, n_masks, n_feats, 1)
        masked_last_state = modded_last_state.reshape(-1,1,n_features,1) * masks

        # Scatter the stage-1 predictions back into feature order.
        merged_preds = torch.zeros((batch_size,n_features,1),device=self.device)
        if 'numerical' in pred:
            merged_preds[:,self.nume_nodes]=pred['numerical']

        if 'categorical' in pred:
            cat_pred = pred['categorical']
            idx = cat_pred.squeeze(-1).float().argmax(dim=-1)
            samples = F.one_hot(idx,num_classes=cat_pred.shape[-1])

            n_cat, n_classes = self.cat_ranges.shape[0], self.cat_ranges.shape[1]
            samples = samples.reshape(batch_size,n_cat,n_classes)
            # cat_ranges lives on the CPU; move it next to the mask that indexes it.
            ranges = self.cat_ranges.to(samples.device).expand(batch_size,n_cat,n_classes)

            # Pick the category value that belongs to the arg-max class.
            cats = ranges[samples.long()==1].reshape(batch_size,-1,1)

            merged_preds[:,self.cat_nodes]=cats.float().to(self.device)

        if 'binary' in pred:
            merged_preds[:,self.bin_nodes]=pred['binary']

        # Stage-2 input channels per feature: [trust weight, stage-1 prediction, masked last state]
        merged_pred_weights = merged_pred_weights.reshape([-1, 1,n_features, 1])
        merged_preds = merged_preds.reshape([-1, 1,n_features, 1]).detach()
        merged_preds = torch.cat([merged_pred_weights,merged_preds],dim=-1)
        repeated_x = merged_preds.expand(merged_preds.shape[0],n_masks,merged_preds.shape[2],merged_preds.shape[3]) #shape: (bsz,n_masks,n_nodes,2)
        x_with_masked_last_state = torch.cat((repeated_x,masked_last_state), dim=-1)
        x = x_with_masked_last_state.view(-1,n_features,3) # shape: (bsz*n_masks,n_nodes,3)

        effective_batch_size = x.shape[0]

        data.x = x
        if edge_index.dim()>1:
            n_edges = edge_index.shape[1]

            # Replicate the edges once per mask and shift the node ids of copy
            # m by m*n_feats.
            # NOTE(review): for an edge_index that already spans several
            # graphs this offset looks too small - confirm against the caller.
            new_edge_index = edge_index.repeat([1,n_masks])
            factor = torch.arange(0,n_masks).reshape(1,-1)
            mod = (n_features*factor).repeat_interleave(n_edges,dim=1).to(edge_index.device)
            data.edge_index = new_edge_index+mod

        # Stage 2: dict with categorical, binary and numerical predictions
        pred = self.current_step_model(data=data,n_graphs = effective_batch_size, **kwargs)

        #NOTE: Need to collect results in correct order from the masked data.
        result = dict()

        if self.nume_nodes.shape[0]!=0:
            n_nume = self.nume_nodes.shape[0]
            predict_num = pred['numerical'].reshape(batch_size,n_masks,n_nume,-1)
            # collect each num feature vector from its corresponding mask
            result['numerical'] = predict_num[:,self.nume_node_mask_inds,torch.arange(n_nume)]

        if self.cat_nodes.shape[0]!=0:
            n_cat = self.cat_ranges.shape[0]
            predict_cat = pred['categorical'].reshape(batch_size,n_masks,n_cat,-1)
            # collect each cat feature vector from its corresponding mask
            result['categorical'] = predict_cat[:,self.cat_node_mask_inds,torch.arange(n_cat)]

        if self.bin_nodes.shape[0]!=0:
            n_bin = self.bin_nodes.shape[0]
            predict_bin = pred['binary'].reshape(batch_size,n_masks,n_bin,-1)
            # collect each bin feature vector from its corresponding mask
            result['binary'] = predict_bin[:,self.bin_node_mask_inds,torch.arange(n_bin)]

        return result

    def test_prediction(self,data, org_edge_index, last_state):
        '''
        Runs a forward pass for testing.

        forward() only accepts ``data`` and reads the edge index and last
        state from it, so the two extra arguments are attached to ``data``
        (when given) instead of being passed positionally.
        '''
        if org_edge_index is not None:
            data.edge_index = org_edge_index
        if last_state is not None:
            data.last_state = last_state
        return self(data)

    def tensor_difference(self,t1,t2):
        '''
        Computes the symmetric difference of two 1D arrays, i.e. the elements
        that occur in exactly one of them. Assumes non duplicate elements
        within each input.
        '''
        combined = np.concatenate((t1, t2))
        uniques, counts = np.unique(combined,return_counts=True)
        return uniques[counts == 1]

