import pickle
from unicodedata import numeric
from models import AutoregressiveGNN,MaskedGNN,MaskedMLP,MaskedModel,TabularMaskedMLP,mlp, MixedVariableModel, ReconstructingModel, TwoStepModel,TwoStepV2
import dataset.json_graph as json_graph
import torch
import torch.nn.functional as F
from tqdm import tqdm
import os
import numpy as np
from models.AnomalyTransformer import AnomalyTransformer
import utils
import copy 
import json

import sys
from models.FeatureGrouper import get_feature_grouper
class AutoregressiveModel:
    """
    Wrapper around an autoregressive model: builds the underlying network,
    computes its loss, validates it, scores novelty and saves/loads it.

    Depending on ``config`` the wrapped ``self.model`` is:
      * a ``TwoStepV2.MaskedModel`` around a base model when ``config.n_masks > 0``;
      * a ``ReconstructingModel`` around a base model when ``config.reconstructing``;
      * the plain base model otherwise.
    """

    def __init__(self, n_feats, n_steps, node_info, model_type, config, device, hyper_params=None, lr=0.0001, **kwargs):
        """
        Args:
            n_feats: number of features (nodes) per state.
            n_steps: number of input time steps.
            node_info: dict with "cat_nodes", "cat_ranges", "nume_nodes",
                "bin_nodes" and optionally "random_nodes".
            model_type: base architecture name ('mlp', 'GCN', 'GAT',
                'GraphSAGE', or anything accepted by MixedVariableModel).
            config: run configuration object.
            device: torch device the model runs on.
            hyper_params: dict of model hyper-parameters (required).
            lr: learning rate stored on the instance.
            **kwargs: forwarded to the base model; the masking path also reads
                kwargs['train_dataset'].

        Raises:
            ValueError: if ``hyper_params`` is None.
            TypeError: if both masking and reconstruction are requested.
        """
        self.lr = lr
        self.n_feats = n_feats
        self.n_steps = n_steps
        self.config = config
        self.n_masks = config.n_masks
        self.cat_nodes = node_info["cat_nodes"]
        self.cat_ranges = torch.tensor(node_info["cat_ranges"])
        self.nume_nodes = node_info["nume_nodes"]
        self.bin_nodes = node_info["bin_nodes"]
        # Per-type index lists of nodes whose losses are zeroed / excluded.
        self.random_nodes = node_info.get("random_nodes", {"numerical": [], "binary": [], "categorical": []})
        self.anomaly_filter = config.__dict__.get("anomaly_filter", 0.95)
        self.use_reconstruction = config.reconstructing
        self.device = device
        self.feature_groups = None
        # Set later (externally or by load_model); initialised so that
        # save_model / getgraphloss do not fail with AttributeError before then.
        self.graph_val_losses = None

        self.per_node_val_losses = {
            "numerical": torch.zeros(len(self.nume_nodes)) + torch.inf,
            "binary": torch.zeros(len(self.bin_nodes)) + torch.inf,
            "categorical": torch.zeros(len(self.cat_nodes)) + torch.inf,
        }
        if hyper_params is None:
            raise ValueError("Need hyperparam dictionary")

        model_hparams = hyper_params

        if config.n_masks > 0:
            if self.use_reconstruction:
                raise TypeError("Cannot use both masking and reconstruction")

            if config.use_pretrained:
                # Load the hyper-parameters and weights of the unmasked (n_masks=0) model.
                fname = f"../models/model_configs/{model_type}_0_{config.task}.json"
                with open(fname, "r") as f:
                    base_hparams = json.load(f)
                base_model = self._create_base_model(model_type, n_feats, n_steps, node_info, self.device, config=config, **base_hparams, **kwargs)
                base_model.load_state_dict(torch.load(os.path.join(config.model_save_dir, config.model_type + "_0.pth")))
            else:
                base_model = self._create_base_model(model_type, n_feats, n_steps, node_info, self.device, config=config, **hyper_params, **kwargs)

            feature_grouper = get_feature_grouper(name="RandomGrouper", n_masks=config.n_masks, n_feats=n_feats, train_dataset=kwargs["train_dataset"], cat_nodes=self.cat_nodes)
            self.feature_groups = feature_grouper.get_groups()
            print(self.feature_groups)
            print("--------", config.use_pretrained, file=sys.stderr)
            # n_steps + 1 so the masked model also sees the last state.
            self.model = TwoStepV2.MaskedModel(
                base_model,
                n_masks=len(self.feature_groups),
                n_steps=n_steps + 1,
                n_feats=n_feats,
                batch_size=config.bsz,
                device=self.device,
                masking_groups=self.feature_groups,
                **model_hparams,
                loss_func=self.loss_func,
                anomaly_filter=self.anomaly_filter,
            )

        elif self.use_reconstruction:
            # +1 to include the last state in the reconstructed sequence.
            base_model = self._create_base_model(model_type, n_feats, n_steps + 1, node_info, self.device, config=config, **model_hparams, **kwargs)
            self.model = ReconstructingModel.ReconstructingModel(base_model, n_steps=n_steps + 1, n_feats=n_feats, device=self.device)

        else:
            self.model = self._create_base_model(model_type, n_feats, n_steps, node_info, self.device, config=config, **model_hparams, **kwargs)

    def _create_base_model(self, model_type, num_nodes, n_steps, node_info, device, **kwargs):
        """
        Build the underlying network for ``model_type``.

        'mlp' -> mlp.MLP; 'GCN'/'GAT'/'GraphSAGE' -> AutoregressiveGNN;
        anything else -> MixedVariableModel with ``underlying_model=model_type``.
        ``kwargs`` must contain 'config' for the GNN branch and is forwarded to
        the constructor in every branch.

        Raises:
            ValueError: for a GNN on 'gridworld'/'monopoly' without a json graph.
        """
        if model_type == "mlp":
            return mlp.MLP(num_nodes=num_nodes, n_steps=n_steps, node_info=node_info, device=device, **kwargs)

        if model_type in ["GCN", "GAT", "GraphSAGE"]:
            config = kwargs["config"]
            if not config.use_json_graph and config.task in ["gridworld", "monopoly"]:
                raise ValueError("Cannot train GNN without json graph on game data")
            return AutoregressiveGNN.AutoregressiveGNN(num_nodes, n_steps, node_info, conv_type=model_type, device=device, **kwargs)

        return MixedVariableModel.MixedVariableModel(num_nodes, n_steps, node_info, underlying_model=model_type, device=device, **kwargs)

    def load_model(self, path):
        """
        Load weights from ``path`` and validation statistics from
        ``path + "_statistics"`` (``"_reconstruction"`` is appended to ``path``
        first when reconstruction is enabled).
        """
        if self.use_reconstruction:
            path += "_reconstruction"
        self.model.load_state_dict(torch.load(path))
        print("loaded")
        # NOTE: pickle.load must only be used on trusted, locally produced files.
        with open(path + "_statistics", "rb") as f:
            tmp = pickle.load(f)
        self.per_node_val_losses = tmp["per_node"]
        self.graph_val_losses = tmp["graph"]
        self.feature_groups = tmp["feature_groups"]
        # NOTE(review): __init__ builds a TwoStepV2.MaskedModel, which is not in
        # this check, so saved masking groups are not re-applied for it unless
        # it subclasses one of these types -- confirm whether that is intended.
        if isinstance(self.model, MaskedModel.MaskedModel) or isinstance(self.model, TwoStepModel.MaskedModel):
            self.model.prepare_masking(tmp["feature_groups"], self.config.bsz)

    def save_model(self, path):
        """
        Save weights to ``path`` and the validation statistics (graph losses,
        per-node losses, feature groups) to ``path + "_statistics"``.
        """
        if self.use_reconstruction:
            path += "_reconstruction"
        torch.save(self.model.state_dict(), path)
        # Compute before opening the file so a failure cannot leave an empty
        # statistics file next to the saved weights.
        graph_losses = None if self.graph_val_losses is None else self.graph_val_losses.cpu()
        with open(path + "_statistics", "wb") as f:
            pickle.dump({"graph": graph_losses, "per_node": self.per_node_val_losses, "feature_groups": self.feature_groups}, f)

    def to(self, device):
        """Move the model and the categorical value table to ``device``."""
        self.device = device
        self.cat_ranges = self.cat_ranges.to(device)
        self.model.to(device)

    def getgraphloss(self):
        """
        Return the maximum validation graph loss when ``config.use_val_thresh``
        is set and graph validation losses are available, else None.
        """
        if self.config.__dict__.get("use_val_thresh", None) is None or self.graph_val_losses is None:
            return None
        return self.graph_val_losses.max()

    def compute_loss_statistics(self, losses, q=0.95):
        """
        Compute the ``q``-quantile (95th percentile by default) of each node's
        losses.

        Args:
            losses: dict with "numerical", "categorical" and "binary" entries,
                e.g. the first return value of get_graph_and_per_node_losses,
                whose tensors are laid out as (n_nodes, n_samples).
            q: quantile in [0, 1].

        Returns:
            dict mapping each of the three keys to a tensor of quantiles taken
            along the last dimension, or None when that entry is missing/empty.
        """
        stats = {}
        for key in ("numerical", "categorical", "binary"):
            values = losses.get(key)
            if values is None or len(values) == 0:
                stats[key] = None
                continue
            values = torch.as_tensor(values)
            if not values.is_floating_point():
                values = values.float()
            stats[key] = torch.quantile(values, q, dim=-1)
        return stats

    def loss_func(self, predictions, graph_batch, return_separate_losses=False):
        """
        Per-graph loss of ``predictions`` against ``graph_batch.last_state``.

        Numerical nodes use squared error, categorical nodes cross-entropy and
        binary nodes BCE-with-logits. Losses of random nodes are zeroed and left
        out of the sum; the sum is divided by the total number of features.

        Args:
            predictions: dict with a "numerical" tensor and, when the graph has
                such nodes, "categorical" / "binary" tensors.
            graph_batch: batch whose ``last_state`` reshapes to (n_graphs, n_feats).
            return_separate_losses: also return the per-node losses.

        Returns:
            The mean per-graph loss, or ``(per_graph_loss, loss_dict)`` when
            ``return_separate_losses`` is set; ``loss_dict`` maps each present
            variable type to a detached CPU tensor of shape (n_graphs, n_nodes_of_type).
        """
        reduction = "none"
        labels = graph_batch.last_state.reshape([-1, self.n_feats])
        n_graphs = labels.shape[0]

        # Numerical nodes: squared error.
        numerical_labels = labels[:, self.nume_nodes]
        nume_preds = predictions["numerical"].squeeze(2)
        random = self.random_nodes["numerical"]
        non_random = self.invert_index(np.arange(len(self.nume_nodes)), random)
        nume_loss = F.mse_loss(nume_preds, numerical_labels, reduction=reduction).reshape(n_graphs, -1)
        nume_loss[:, random] = 0
        loss = torch.sum(nume_loss[:, non_random], dim=1)

        # Categorical nodes: cross-entropy over each node's possible values.
        if self.cat_nodes.shape[0] != 0:
            cat_labels = labels[:, self.cat_nodes].int().unsqueeze(2)
            random = self.random_nodes["categorical"]
            non_random = self.invert_index(np.arange(len(self.cat_nodes)), random)
            # cat_ranges holds the possible values per categorical node; the
            # position of the value matching the label is the class target.
            one_hot = cat_labels == self.cat_ranges
            _, targets = one_hot.max(dim=-1)
            cat_loss = F.cross_entropy(
                predictions["categorical"].reshape([-1, self.cat_ranges.shape[1]]),
                targets.reshape([-1]),
                reduction=reduction,
            ).reshape(n_graphs, -1)
            cat_loss[:, random] = 0
            loss += torch.sum(cat_loss[:, non_random], dim=1)
        else:
            cat_loss = None

        # Binary nodes: BCE on logits.
        if self.bin_nodes.shape[0] != 0:
            binary_labels = labels[:, self.bin_nodes]
            random = self.random_nodes["binary"]
            non_random = self.invert_index(np.arange(len(self.bin_nodes)), random)
            # Labels may be encoded as -1/1; BCE needs 0/1.
            binary_labels[binary_labels == -1] = 0
            bin_loss = F.binary_cross_entropy_with_logits(predictions["binary"].squeeze(2), binary_labels, reduction=reduction).reshape(n_graphs, -1)
            bin_loss[:, random] = 0
            loss += torch.sum(bin_loss[:, non_random], dim=1)
        else:
            bin_loss = None

        loss = loss / labels.shape[-1]

        if return_separate_losses:
            loss_dict = {"numerical": nume_loss.detach().cpu()}
            if cat_loss is not None:
                loss_dict["categorical"] = cat_loss.detach().cpu()
            if bin_loss is not None:
                loss_dict["binary"] = bin_loss.detach().cpu()
            return loss, loss_dict
        return torch.mean(loss)

    def validation(self, valloader):
        """Return ``{"real_val_loss": mean loss}`` on ``valloader``."""
        return {f"{mode}_val_loss": self.validate(valloader, mode) for mode in ["real"]}

    def invert_index(self, a1, a2):
        """
        Return the sorted unique values of ``a1`` that do not occur in ``a2``.

        Uses ``np.setdiff1d`` so the result keeps ``a1``'s integer dtype even
        when ``a2`` is empty, and ignores ``a2`` entries not present in ``a1``.
        """
        return np.setdiff1d(a1, a2)

    def validate(self, valloader, mode="real", return_all_losses=False):
        """
        Run the model on ``valloader`` without gradients and collect losses.

        Args:
            valloader: pytorch geometric dataloader with validation data.
            mode: "real" to use the data as-is, "noise" to replace each
                batch's ``x`` with uniform noise (labels still come from
                ``last_state``).
            return_all_losses: return the per-batch losses instead of the mean.

        Returns:
            Mean per-batch loss as a float, or a numpy array of per-batch losses.

        Raises:
            ValueError: for an unknown ``mode``.
        """
        if mode not in ("real", "noise"):
            raise ValueError(f"Unknown validation mode: {mode!r}")

        pbar = tqdm(valloader, total=len(valloader))
        self.model.eval()

        all_losses = []
        try:
            with torch.no_grad():
                for graph_batch in pbar:
                    graph_batch = graph_batch.to(self.device)
                    if mode == "noise":
                        # Overwrites x on the batch object itself.
                        graph_batch.x = torch.rand_like(graph_batch.x)
                    output = self.model(graph_batch)
                    all_losses.append(self.loss_func(output, graph_batch))
        finally:
            self.model.train()

        all_losses = torch.tensor(all_losses).detach().cpu().numpy()

        if return_all_losses:
            return all_losses
        return float(np.mean(all_losses))

    def get_additional_loss_terms(self):
        """Delegate to the wrapped model's ``get_additional_loss_terms``."""
        return self.model.get_additional_loss_terms()

    def get_graph_and_per_node_losses(self, dataloader):
        """
        Compute per-graph and per-node losses over ``dataloader``.

        Returns:
            ``(result, graph_losses)``: ``result`` maps "numerical",
            "categorical" and "binary" to tensors of shape (n_nodes, n_graphs)
            (an empty numpy array when the type has no nodes); ``graph_losses``
            is a 1-D tensor of per-graph losses on ``self.device``.
        """
        self.model.eval()

        all_losses = {"categorical": [], "numerical": [], "binary": []}
        graph_losses = []
        try:
            with torch.no_grad():
                for graph_batch in dataloader:
                    graph_batch = graph_batch.to(self.device)
                    output = self.predict(graph_batch)
                    loss, losses = self.loss_func(output, graph_batch, return_separate_losses=True)
                    for k in losses:
                        # Extending iterates rows: one (n_nodes,) tensor per graph.
                        all_losses[k].extend(losses[k])
                    graph_losses.append(loss.cpu())
        finally:
            self.model.train()

        numerical = torch.cat([i.unsqueeze(-1) for i in all_losses["numerical"]], dim=-1)
        if all_losses["categorical"]:
            categorical = torch.cat([i.unsqueeze(-1) for i in all_losses["categorical"]], dim=-1)
        else:
            categorical = np.array([])

        if all_losses["binary"]:
            binary = torch.cat([i.unsqueeze(-1) for i in all_losses["binary"]], dim=-1)
        else:
            binary = np.array([])

        graph_losses = torch.cat(graph_losses).to(self.device)

        result = dict(numerical=numerical, categorical=categorical, binary=binary)
        return result, graph_losses

    def _calc_perc(self, v, values):
        """Fraction of ``values`` strictly smaller than the scalar tensor ``v``."""
        return int((values < v.item()).sum()) / len(values)

    def calculate_loss_percentiles(self, losses):
        """
        Percentile of each numerical node's loss (first graph in ``losses``)
        among that node's validation losses.

        Assumes ``self.per_node_val_losses["numerical"]`` is laid out as
        (n_nodes, n_samples), as produced by get_graph_and_per_node_losses; a
        1-D (n_nodes,) tensor is treated as one sample per node.

        Returns:
            list of floats, one per numerical node.
        """
        numerical = losses["numerical"][0]
        val_numerical = self.per_node_val_losses["numerical"]
        val_numerical = val_numerical.reshape(len(val_numerical), -1)
        return [self._calc_perc(numerical[i], val_numerical[i]) for i in range(val_numerical.shape[0])]

    def prediction_to_nodearray(self, predictions, node_info):
        """
        Convert the first graph's predictions into an (n_feats, 1) numpy array.

        Numerical entries are the raw predictions, categorical entries the
        value from ``node_info["cat_ranges"]`` at the argmax, and binary entries
        the raw model outputs (which loss_func treats as logits).
        """
        nume_nodes = node_info["nume_nodes"]
        cat_nodes = node_info["cat_nodes"]
        cat_ranges = node_info["cat_ranges"]
        bin_nodes = node_info["bin_nodes"]

        node_array = np.zeros((self.n_feats, 1))
        node_array[nume_nodes] = predictions["numerical"][0].detach().cpu().numpy()

        if cat_nodes.shape[0] != 0:
            cat_inds = predictions["categorical"][0].detach().cpu().numpy().argmax(axis=1).reshape(-1)
            cat_values = [cat_ranges[i][cat_inds[i]] for i in range(len(cat_ranges))]
            node_array[cat_nodes] = np.array(cat_values).reshape(-1, 1)

        if bin_nodes.shape[0] != 0:
            node_array[bin_nodes] = predictions["binary"][0].detach().cpu().numpy()
        return node_array

    def compute_percentiles(self, queries, values, loop_values=True):
        """
        Percentile of each query among its reference values.

        Args:
            queries: tensor of shape (B, K).
            values: tensor of shape (K, N) -- row k is the reference for
                column k of ``queries`` -- or (N,) to use one shared reference.
            loop_values: unused; kept for backward compatibility.

        Returns:
            (B, K) tensor: fraction of the reference values strictly less than
            each query.
        """
        # Broadcasting gives the same (B, K, N) comparison as repeating values
        # along a new leading axis, without materialising the copies.
        return torch.sum(queries.unsqueeze(-1) > values.unsqueeze(0), dim=-1) / values.shape[-1]

    def compute_novelty_scores(self, graph, return_prediction=False):
        """
        Score ``graph`` against the validation statistics.

        Returns:
            With ``return_prediction``: ``(percentiles, loss, prediction)``;
            ``percentiles`` holds "numerical" (per-node percentiles against
            per_node_val_losses), "losses" (the per-node loss dict) and "graph"
            (graph-loss percentile against graph_val_losses).
            Otherwise: ``(loss, losses)`` -- per-graph loss and per-node losses.
            Random-node losses are zeroed in the returned dict.
        """
        prediction = self.predict(graph)

        loss, losses = self.loss_func(prediction, graph, return_separate_losses=True)
        numerical_percentiles = self.compute_percentiles(losses["numerical"], self.per_node_val_losses["numerical"])
        percentiles = {"numerical": numerical_percentiles, "losses": losses}

        losses["numerical"][:, self.random_nodes["numerical"]] = 0
        if "categorical" in losses:
            losses["categorical"][:, self.random_nodes["categorical"]] = 0
        if "binary" in losses:
            losses["binary"][:, self.random_nodes["binary"]] = 0

        graph_percentile = self.compute_percentiles(loss.to(self.device).reshape(1, -1), self.graph_val_losses.reshape(1, -1).to(self.device), loop_values=False)
        percentiles["graph"] = graph_percentile

        if return_prediction:
            return percentiles, loss, prediction
        return loss, losses

    def compute_novelty_scores_old(self, graph, return_prediction=False):
        """
        Previous scoring variant: per-type percentiles for every variable type.

        Returns:
            ``percentiles`` dict with "numerical", "losses", "graph" and, when
            present, "categorical" / "binary"; with ``return_prediction`` the
            tuple ``(percentiles, loss, prediction)``.
        """
        prediction = self.predict(graph)

        loss, losses = self.loss_func(prediction, graph, return_separate_losses=True)

        numerical_percentiles = self.compute_percentiles(losses["numerical"], self.per_node_val_losses["numerical"])
        percentiles = {"numerical": numerical_percentiles, "losses": losses}

        if "categorical" in losses:
            percentiles["categorical"] = self.compute_percentiles(losses["categorical"], self.per_node_val_losses["categorical"])

        if "binary" in losses:
            percentiles["binary"] = self.compute_percentiles(losses["binary"], self.per_node_val_losses["binary"])

        graph_percentile = self.compute_percentiles(loss.to(self.device).reshape(1, -1), self.graph_val_losses.reshape(1, -1).to(self.device), loop_values=False)
        percentiles["graph"] = graph_percentile

        if return_prediction:
            return percentiles, loss, prediction
        return percentiles

    def get_diff_dictionary(self, t, p):
        """
        Recursively collect the entries of ``t`` that differ from ``p``.

        Leaves are compared by ``str``. Nested plain dicts are diffed
        recursively (an empty sub-dict means "no differences"). If every entry
        of a dict is a leaf and at least one differs, the whole dict is copied.
        Keys of ``t`` missing from ``p`` count as differing.
        """
        missing = object()

        def make_diffs(a, b):
            n_equal = 0
            n_leaves = 0
            diff = {}
            for k, a_val in a.items():
                b_val = b.get(k, missing)
                if type(a_val) is dict:
                    if isinstance(b_val, dict):
                        diff[k] = make_diffs(a_val, b_val)
                    else:
                        # Sub-dict absent (or not a dict) in b: all of it differs.
                        diff[k] = copy.deepcopy(a_val)
                else:
                    n_leaves += 1
                    if b_val is not missing and str(a_val) == str(b_val):
                        n_equal += 1
                    else:
                        diff[k] = copy.deepcopy(a_val)
            if n_equal != n_leaves and n_leaves == len(a):
                diff = copy.deepcopy(a)
            return diff

        return make_diffs(t, p)

    def predict(self, graph):
        """Move ``graph`` to the model device and run the model on it."""
        graph = graph.to(self.device)
        return self.model(graph, per_node_val_losses=self.per_node_val_losses)
