import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import time
from torch_geometric.data.batch import Batch
from torch_geometric.nn import GCNConv, GATConv, EdgeConv
from torch_geometric.data import Data
import math
import torch.nn.functional as F
#from .graph_layer import GraphLayer
#from .GDN import GDN
from models.mlp import MLP
from models.AutoregressiveGNN import AutoregressiveGNN

import sys

class MaskedModel(nn.Module):
    '''
    Wrapper Module for models that implements masking.

    Each forward pass feeds the base model one copy of the input per mask. Mask ``i``
    zeroes the last state of masking group ``i``, and every feature's prediction is
    read from the copy in which its own group was hidden.

        base_model: Already created nn.Module of the model to be used with masking. Must handle
            multiple variable types itself and expose ``cat_nodes``, ``cat_ranges``,
            ``nume_nodes`` and ``bin_nodes``.
        n_masks: the number of masks to be used (must be at least ``len(masking_groups)``)
        n_steps: window size for series, including the last state
        n_feats: number of features/nodes/sensors
        batch_size: batch size used to pre-build the mask tensor (larger batches also work)
        masking_groups: list of lists with feature indices. Each list is a group of features to
            be predicted together from the other groups. Groups may have different lengths.
        device (required keyword): device the base model and masks are moved to.
    '''
    def __init__(self, base_model, n_masks, n_steps, n_feats, batch_size=64, masking_groups=None, **kwargs):
        super().__init__()
        # Validate first so a missing argument surfaces as this ValueError instead of
        # an unrelated failure while processing the groups.
        if masking_groups is None:
            raise ValueError("Masked model needs masking groups")

        self.n_steps = n_steps  # includes last state
        self.n_feats = n_feats
        self.device = kwargs['device']
        self.model = base_model
        self.model.to(self.device)

        # Node-index / category tables taken from the base model.
        self.cat_nodes = torch.tensor(self.model.cat_nodes)
        self.cat_ranges = torch.tensor(self.model.cat_ranges)
        self.nume_nodes = torch.tensor(self.model.nume_nodes)
        self.bin_nodes = torch.tensor(self.model.bin_nodes)
        # When True, forward() replaces suspicious last-state values with predictions.
        self.testing = False
        self.n_masks = n_masks

        # Sets masking_groups, feat_order, the per-type mask mappings and the masks.
        self.prepare_masking(masking_groups, batch_size)

        # Not used by forward(); kept so the module structure (submodules) is unchanged.
        self.dropout = nn.Dropout(0.4)

        print(f"Using masked model with {n_masks} masks on {n_feats} features.")

    def prepare_masking(self, masking_groups, batch_size):
        '''
        (Re)build everything derived from the masking groups.

        Stores the groups, the flattened feature order, the node -> mask mappings used
        to collect predictions, and the mask tensor of shape
        (batch_size, n_masks, n_feats, 1) on ``self.device``.
        '''
        self.masking_groups = masking_groups
        self.feat_order = [n for g in masking_groups for n in g]
        # Recomputed here so calling prepare_masking with new groups keeps forward() consistent.
        (self.nume_node_mask_inds,
         self.cat_node_mask_inds,
         self.bin_node_mask_inds) = self._create_mask_mappings()
        self.masks = self._create_masks(self.n_feats, self.n_masks, batch_size).to(self.device)

    def get_additional_loss_terms(self):
        '''Delegate to the base model's additional loss terms.'''
        return self.model.get_additional_loss_terms()

    def to(self, device):
        '''
        Move the module to ``device`` and remember it for tensors created in forward().

        Returns the module itself, like ``nn.Module.to``, so ``model = model.to(dev)`` works.
        '''
        module = super().to(device)
        self.device = device
        return module

    def _create_mask_mappings(self):
        '''
        For each numerical, categorical and binary node (in node order), return the index
        of the masking group -- and therefore the mask -- that contains it.
        '''
        def mask_ids(nodes):
            # A node listed in several groups contributes one entry per group.
            return [mask_id
                    for node in nodes.tolist()
                    for mask_id, group in enumerate(self.masking_groups)
                    if node in group]

        return mask_ids(self.nume_nodes), mask_ids(self.cat_nodes), mask_ids(self.bin_nodes)

    def _create_masks(self, n_feats, n_masks, bsz=1):
        '''
        Build the 0/1 mask tensor of shape (bsz, n_masks, n_feats, 1).

        Mask ``i`` is 0 on the features of masking group ``i`` and 1 elsewhere; masks
        beyond ``len(masking_groups)`` hide nothing. With one group per mask, each node
        is masked exactly once if every node appears in exactly one group.
        '''
        masks = torch.ones(bsz, n_masks, n_feats, 1)
        for i, group in enumerate(self.masking_groups):
            masks[:, i, group] = 0
        return masks

    def forward(self, data, **kwargs):
        '''
        Predict every feature's last state from the other masking groups.

        Expects ``data.x`` of shape (batch, n_feats, n_steps - 1) with
        ``batch = data.x.shape[0]``, and ``data.last_state`` reshapeable to
        (batch, n_feats). Extra keyword arguments are forwarded to the base model.

        NOTE: ``data.x`` (and ``data.edge_index`` when it is 2-D) are overwritten in
        place with the masked, mask-expanded inputs.

        Returns the dict of per-type predictions collected from the masks; in training
        mode returns ``(result, out_no_last)``, where ``out_no_last`` is the base model's
        output when the whole last state is zeroed.
        '''
        x, edge_index, last_state = data.x, data.edge_index, data.last_state
        n_masks = self.n_masks
        n_features = self.n_feats
        n_steps = self.n_steps
        batch_size = x.shape[0]

        # All rows of self.masks are identical; broadcasting a single row supports
        # batches larger than the batch_size used at construction.
        masks = self.masks[:1].to(self.device).detach()

        out_no_last = self._predict_without_last_state(data, x, batch_size, **kwargs)

        if self.testing:
            last_state_input = self._impute_last_state(out_no_last, data, last_state, batch_size)
        else:
            last_state_input = last_state

        # (batch_size, n_masks, n_feats, 1)
        masked_last_state = last_state_input.reshape(-1, 1, n_features, 1) * masks

        x = x.reshape([-1, 1, n_features, n_steps - 1])
        repeated_x = x.expand(x.shape[0], n_masks, x.shape[2], x.shape[3])  # (bsz, n_masks, n_nodes, n_steps-1)
        x_with_masked_last_state = torch.cat((repeated_x, masked_last_state), axis=-1)
        # One "graph" per (sample, mask) pair, ordered sample-major: row = b * n_masks + m.
        data.x = x_with_masked_last_state.view(-1, n_features, n_steps)
        effective_batch_size = data.x.shape[0]

        if len(edge_index.shape) > 1:
            data.edge_index = self._expand_edge_index(edge_index)

        out = self.model(data=data, n_graphs=effective_batch_size, **kwargs)  # dict with categorical, binary and numerical predictions
        result = self._collect_masked_predictions(out, batch_size)

        if not self.training:
            return result
        return result, out_no_last

    def _predict_without_last_state(self, data, x, batch_size, **kwargs):
        '''Run the base model with the history intact and the last time step set to zero.'''
        x_with_0_last_state = torch.zeros(batch_size, self.n_feats, self.n_steps, device=self.device)
        x_with_0_last_state[:, :, :-1] = x  # last step stays zero
        data.x = x_with_0_last_state
        return self.model(data=data, n_graphs=batch_size, **kwargs)

    def _impute_last_state(self, out_no_last, data, last_state, batch_size):
        '''
        Testing-time cleaning of the observed last state.

        Prediction errors of the no-last-state pass are normalised per variable type by
        that type's batch-wide max, then per sample by the row max. Features whose
        normalised error is >= 0.95 get their observed value replaced by the prediction;
        all other features keep the observed value.
        '''
        _, pred_errors = self.loss_func(out_no_last, data, return_separate_losses=True)
        pred_errors = {k: (v / (torch.max(v) + 1e-6)).detach() for k, v in pred_errors.items()}

        preds = self._point_predictions(out_no_last, batch_size)

        merged_pred_errors = torch.zeros((batch_size, self.n_feats))
        for key, nodes in (('numerical', self.nume_nodes),
                           ('categorical', self.cat_nodes),
                           ('binary', self.bin_nodes)):
            if key in pred_errors:
                merged_pred_errors[:, nodes] = pred_errors[key].float()

        # Clamp guards against division by zero (NaN) for a row whose errors are all 0.
        row_max = merged_pred_errors.max(dim=-1, keepdim=True).values
        row_max = row_max.clamp_min(torch.finfo(merged_pred_errors.dtype).tiny)
        normalised_errors = merged_pred_errors / row_max

        keep_observed = (normalised_errors < 0.95).float().to(self.device)
        modded_last_state = preds * (1 - keep_observed)
        modded_last_state += last_state.reshape(-1, self.n_feats) * keep_observed
        return modded_last_state

    def _point_predictions(self, out_no_last, batch_size):
        '''
        Turn the base model's per-type outputs into a (batch_size, n_feats) value tensor.

        Numerical and binary outputs are used as-is; categorical nodes take the category
        value from ``cat_ranges`` at the argmax logit.
        '''
        preds = torch.zeros((batch_size, self.n_feats)).to(self.device)
        if 'numerical' in out_no_last:
            preds[:, self.nume_nodes] = out_no_last['numerical'].squeeze(-1).float().detach()
        if 'categorical' in out_no_last:
            n_cat, max_cat = self.cat_ranges.shape
            logits = out_no_last['categorical'].squeeze(-1).float()
            idx = logits.argmax(dim=-1).reshape(batch_size, n_cat, 1)
            # cat_ranges lives on the CPU; move it next to idx before gathering.
            ranges = self.cat_ranges.to(idx.device).expand(batch_size, n_cat, max_cat)
            cats = ranges.gather(-1, idx).squeeze(-1)
            preds[:, self.cat_nodes] = cats.float().to(self.device)
        if 'binary' in out_no_last:
            preds[:, self.bin_nodes] = out_no_last['binary'].squeeze(-1).float().detach()
        return preds

    def _expand_edge_index(self, edge_index):
        '''
        Replicate ``edge_index`` for every mask copy of the batch.

        Node ``n`` of graph ``b`` (id ``b * n_feats + n``) becomes node
        ``(b * n_masks + m) * n_feats + n`` in mask copy ``m``, matching the row order of
        the expanded ``data.x``. For a single-graph edge index (``b == 0``) this reduces
        to an offset of ``m * n_feats``.
        '''
        n_features = self.n_feats
        n_edges = edge_index.shape[1]
        tiled = edge_index.repeat([1, self.n_masks])  # mask-major: all edges for m=0, then m=1, ...
        mask_ids = torch.arange(self.n_masks, device=edge_index.device).repeat_interleave(n_edges)
        graph_ids = torch.div(tiled, n_features, rounding_mode='floor')
        return tiled + (graph_ids * (self.n_masks - 1) + mask_ids) * n_features

    def _collect_masked_predictions(self, out, batch_size):
        '''Pick each node's prediction from the mask in which its own group was hidden.'''
        def from_owning_mask(pred, n_nodes, mask_inds):
            per_mask = pred.reshape(batch_size, self.n_masks, n_nodes, -1)
            return per_mask[:, mask_inds, torch.arange(n_nodes)]

        result = dict()
        if self.nume_nodes.shape[0] != 0:
            result['numerical'] = from_owning_mask(
                out['numerical'], self.nume_nodes.shape[0], self.nume_node_mask_inds)
        if self.cat_nodes.shape[0] != 0:
            result['categorical'] = from_owning_mask(
                out['categorical'], self.cat_ranges.shape[0], self.cat_node_mask_inds)
        if self.bin_nodes.shape[0] != 0:
            result['binary'] = from_owning_mask(
                out['binary'], self.bin_nodes.shape[0], self.bin_node_mask_inds)
        return result

    def test_prediction(self, data, org_edge_index, last_state, **kwargs):
        '''
        Run forward() with an explicit edge index and last state.

        forward() reads both from ``data``, so they are attached to ``data`` first;
        extra keyword arguments are forwarded to forward().
        '''
        data.edge_index = org_edge_index
        data.last_state = last_state
        return self(data, **kwargs)

    def loss_func(self, predictions, graph_batch, return_separate_losses=False):
        """
        Loss of last-state predictions against ``graph_batch.last_state``.

        MSE for numerical nodes, cross-entropy over ``cat_ranges`` for categorical nodes
        and BCE-with-logits for binary nodes (-1 labels are treated as 0). Each type's
        per-node losses are averaged per sample and summed across types.

        Returns the batch-mean loss, or -- with ``return_separate_losses`` -- the
        per-sample loss together with a dict of detached CPU per-node losses per type.
        """
        reduction = 'none'
        # (batch_size, num_nodes)
        labels = graph_batch.last_state.reshape([-1, self.n_feats])
        n_graphs = labels.shape[0]

        loss = 0
        nume_loss = cat_loss = bin_loss = None

        if self.nume_nodes.shape[0] != 0:
            numerical_labels = labels[:, self.nume_nodes]
            nume_loss = F.mse_loss(predictions["numerical"].squeeze(2), numerical_labels,
                                   reduction=reduction).reshape(n_graphs, -1)
            loss += torch.mean(nume_loss, dim=1)

        if self.cat_nodes.shape[0] != 0:
            cat_labels = labels[:, self.cat_nodes].int().unsqueeze(2)
            # Possible categories per node, shape (num_cat_nodes, max_cat_num), moved to the labels' device.
            cat_ranges = self.cat_ranges.to(cat_labels.device)
            one_hot = (cat_labels == cat_ranges)
            _, targets = one_hot.max(dim=-1)
            cat_loss = F.cross_entropy(predictions["categorical"].reshape([-1, self.cat_ranges.shape[1]]),
                                       targets.reshape([-1]), reduction=reduction).reshape(n_graphs, -1)
            loss += torch.mean(cat_loss, dim=1)

        if self.bin_nodes.shape[0] != 0:
            # Advanced indexing returns a copy, so the in-place fix below leaves last_state untouched.
            binary_labels = labels[:, self.bin_nodes]
            binary_labels[binary_labels == -1] = 0  # labels are -1/1; BCE needs 0/1
            bin_loss = F.binary_cross_entropy_with_logits(predictions["binary"].squeeze(2), binary_labels,
                                                          reduction=reduction).reshape(n_graphs, -1)
            loss += torch.mean(bin_loss, dim=1)

        if return_separate_losses:
            loss_dict = {}
            if nume_loss is not None:
                loss_dict['numerical'] = nume_loss.detach().cpu()
            if cat_loss is not None:
                loss_dict['categorical'] = cat_loss.detach().cpu()
            if bin_loss is not None:
                loss_dict['binary'] = bin_loss.detach().cpu()
            return loss, loss_dict
        return torch.mean(loss)