import argparse, os
import sys
sys.path.append('../')
from PIL.Image import new

from dataset.gamedata import GameDataset
from torch_geometric.loader import DataLoader as GraphDataLoader
import torch
import numpy as np
import pickle as pkl
from sklearn.metrics import roc_auc_score
import numpy as np
import types
from natsort import natsorted
from Autoregressive_model import AutoregressiveModel
import copy
import matplotlib.pyplot as plt
import utils
import json
from sklearn.metrics import roc_curve,roc_auc_score
from sklearn.metrics import precision_recall_curve
import sklearn.metrics as metrics
from dataset import monopoly_json2vec,gridworld_json2vec

def test_model(config):
    """Evaluate the model described by ``config`` using :func:`test`.

    No pre-loaded model or hyper-parameters are passed, so :func:`test`
    falls back to its own defaults for both.
    """
    return test(config, model=None, hyperparams=None)


def test(config,model = None,hyperparams=None):
    """Evaluate a model on several novelty test sets and print aggregate metrics.

    For ``config.task == 'gridworld'`` each entry of the hard-coded list of
    dataset directories is evaluated. For ``'monopoly'`` a fixed test
    directory is evaluated once per injected novelty type. Other tasks
    iterate the same two names without modifying ``config``.

    Results from :func:`get_labels_predictions` are concatenated across test
    sets. The function then prints:

    * dataset statistics;
    * graph-level best-F1 scores using the sum and the max of per-node losses;
    * a node-level best-F1 score;
    * the mean per-instance average precision of node ranking ("MAP").

    Args:
        config: namespace read by :func:`get_labels_predictions`. Its
            ``test_dataset_dir`` and ``inject_novelty_type`` attributes are
            overwritten in place for gridworld/monopoly.
        model: optional already-loaded model. When ``None``,
            :func:`get_labels_predictions` loads one from disk.
        hyperparams: optional hyper-parameters used when loading the model.

    Returns:
        None. Results are printed.
    """
    from sklearn.metrics import average_precision_score

    node_pred_loss = []
    graph_label = []
    node_label = []

    if config.task == 'gridworld':
        # Each entry is a dataset directory holding step jsons plus labels.pkl / node_labels.pkl.
        test_configs = [
            "/home/plymper/data/gridworldsData/mixed_inventoryreset",
            "/home/plymper/data/gridworldsData/mixed_breakincrease",
        ]
    else:
        # Novelty types injected into a single monopoly test directory.
        test_configs = ["move_player", "cash_change"]

    for p in test_configs:
        print(p)
        if config.task == 'gridworld':
            config.test_dataset_dir = p
            config.inject_novelty_type = None
        elif config.task == 'monopoly':
            config.test_dataset_dir = "/home/yli52/dataset/monopoly/val_new/ori_json"
            config.inject_novelty_type = p
        (_graph_perc, _test_losses, test_labels, _per_node_perc_predictions,
         per_node_loss_predictions, per_node_test_labels, _predictions) = get_labels_predictions(
            config, model=model, hyper_params=hyperparams)

        node_pred_loss.append(per_node_loss_predictions)
        graph_label.append(test_labels)
        node_label.append(per_node_test_labels)

    node_pred_loss = np.concatenate(node_pred_loss)
    graph_label = np.concatenate(graph_label)
    node_label = np.concatenate(node_label)

    # One row per test instance, one column per node.
    node_pred_loss = node_pred_loss.reshape(len(graph_label), -1)
    node_label = node_label.reshape(len(graph_label), -1)

    # Average precision of the node ranking, only for instances that contain
    # at least one anomalous node (AP is undefined otherwise).
    node_average_precision = [
        average_precision_score(ls, ps)
        for ps, ls in zip(node_pred_loss, node_label)
        if np.sum(ls) > 0
    ]
    # Avoid np.mean([]) (RuntimeWarning) when no instance has anomalous nodes.
    node_map = np.mean(node_average_precision) if node_average_precision else float('nan')

    max_sensor_loss = node_pred_loss.max(axis=1)
    sum_sensor_loss = node_pred_loss.sum(axis=1)
    if config.task == "monopoly":
        graph_label = graph_label.flatten()
        node_label = node_label.flatten()

    print("------------DATASET STATS--------------")
    print("# Test instances", len(graph_label))

    print("State Anomaly fraction:", np.sum(graph_label) / len(graph_label))
    print("Node Anomaly fraction:", np.sum(node_label) / len(node_label.flatten()))

    print("----------Aggregated-----------")

    # The same node-level MAP is printed after each score block.
    score(sum_sensor_loss, graph_label, 'Per graph sum normalized loss')
    print("MAP:", node_map)
    score(max_sensor_loss, graph_label, 'Per graph max normalized loss')
    print("MAP:", node_map)

    node_pred_loss = node_pred_loss.reshape(-1)
    node_label = node_label.reshape(-1)

    score(node_pred_loss, node_label, 'Per node normalized loss')
    print("MAP:", node_map)

    return

def score(x,y,name = 'Per graph Loss'):
    """Print the best-F1 operating point of scores ``x`` against binary labels ``y``.

    The precision-recall curve is computed from ``(y, x)``. The point with
    the highest F1 (a 1e-5 epsilon avoids division by zero) is selected, and
    its threshold, F1, precision and recall are printed.

    Args:
        x: anomaly scores; higher means more anomalous.
        y: binary ground-truth labels aligned with ``x``.
        name: label used in the printed report.
    """
    precision, recall, thresholds = precision_recall_curve(y, x)

    f1_scores = 2 * recall * precision / (recall + precision + 1e-5)
    best = int(np.argmax(f1_scores))
    # precision/recall carry one more entry than thresholds (the final
    # precision=1, recall=0 point has no threshold), so clamp the index.
    best_threshold = thresholds[min(best, len(thresholds) - 1)]

    print("--------------------")
    print(name)
    print('Best threshold: ', best_threshold)
    print(f'{name} F1-Score: ', f1_scores[best])
    print('Precision:', precision[best])
    print('Recall:', recall[best])
    

    # TODO: compute MRR (mean reciprocal rank, i.e. the mean of 1/rank).
    
def score_normalization(scores):
    """Robustly standardise ``scores`` column-wise: (scores - median) / IQR.

    Statistics are taken along axis 0. Columns with zero interquartile range
    are divided by 1, i.e. only median-centred, to avoid division by zero.
    1-D input is treated as a single column.

    Args:
        scores: array-like of shape ``(n,)`` or ``(n, d)``.

    Returns:
        np.ndarray of the same shape as ``scores``.
    """
    from scipy.stats import iqr

    scores = np.asarray(scores)
    spread = iqr(scores, axis=0)
    # For 1-D input iqr returns a NumPy scalar, which does not support item
    # assignment; np.where handles scalars and arrays alike.
    spread = np.where(spread == 0, 1, spread)
    return (scores - np.median(scores, axis=0)) / spread

def load_model_configs(model_type,config):
    """Load the JSON hyper-parameter file for ``model_type``.

    The file read is
    ``<root>/models/model_configs/<model_type>_<config.n_masks>_<config.task>.json``.
    ``<root>`` is ``config.model_config_root`` when that attribute is present
    on ``config``, otherwise ``.``.

    Returns:
        The parsed JSON content.
    """
    import json
    masks, task_name = config.n_masks, config.task
    root = config.model_config_root if 'model_config_root' in config.__dict__ else '.'
    config_path = f"{root}/models/model_configs/{model_type}_{masks}_{task_name}.json"
    with open(config_path, 'r') as handle:
        return json.load(handle)

def _reorder_node_labels(node_labels, old_node_info):
    """Reorder each episode's node-label columns into numerical, categorical, binary order."""
    reordered = []
    for ep_nl in node_labels:
        if len(ep_nl) < 1:
            reordered.append([])
            continue
        ep_nl = np.array(ep_nl)
        reordered.append(np.concatenate([
            ep_nl[:, old_node_info['nume_nodes']],
            ep_nl[:, old_node_info['cat_nodes']],
            ep_nl[:, old_node_info['bin_nodes']],
        ], axis=1))
    return reordered


def _per_node_loss_predictions(predictions, num_nodes, node_info):
    """Scatter each step's per-type node losses into a length-``num_nodes`` row and concatenate all rows."""
    all_node_predictions = []
    for preds in predictions.values():
        losses = preds['percentile']['losses']
        node_preds = np.zeros(num_nodes)
        node_preds[node_info['cat_nodes']] = losses['categorical']
        node_preds[node_info['nume_nodes']] = losses['numerical']
        node_preds[node_info['bin_nodes']] = losses['binary']
        all_node_predictions.append(node_preds)
    return np.concatenate(all_node_predictions)


def get_labels_predictions(config,model=None, hyper_params=None):
    """Run a model over ``config.test_dataset_dir`` and collect labels and scores.

    If ``model`` is None, an AutoregressiveModel is built and loaded from
    ``<config.model_save_dir>/<config.model_type>_<config.n_masks>.pth``.

    Labels come from the dataset object for monopoly. For other tasks they are
    read from ``labels.pkl`` / ``node_labels.pkl`` in ``config.test_dataset_dir``.

    Returns:
        tuple ``(graph_perc, test_losses, test_labels, per_node_perc_predictions,
        per_node_loss_predictions, per_node_test_labels, predictions)``, where:

        * ``graph_perc`` is a list of floats, one per scored step;
        * ``test_losses`` is the concatenation of the per-step loss tensors;
        * ``per_node_perc_predictions`` is currently the same array as
          ``per_node_loss_predictions``;
        * ``predictions`` is the dict returned by :func:`run_on_jsons`.
    """
    ignore_intermediate_nodes = not config.use_json_graph
    # NOTE(review): the dataset is built from val_dataset_dir while the jsons
    # fed to the model are read from test_dataset_dir (run_on_jsons);
    # presumably val supplies graph structure / normalisation — confirm.
    dataset = GameDataset(data_path=config.val_dataset_dir, concat_steps=config.winsize, mode='test',
                          ignore_intermediate_nodes=ignore_intermediate_nodes, task=config.task,
                          inject_novelty_type=config.inject_novelty_type)

    if model is None:
        if hyper_params is None:
            hyper_params = load_model_configs(config.model_type, config)
        model = AutoregressiveModel(dataset.num_nodes, dataset.node_feature_dim, dataset.node_info,
                                    model_type=config.model_type, config=config, device='cpu',
                                    train_dataset=np.ones((dataset.num_nodes, dataset.num_nodes)),
                                    hyper_params=hyper_params)
        pth = os.path.join(config.model_save_dir, config.model_type + f"_{config.n_masks}.pth")
        model.load(pth)
        model.to('cpu')
        print(config)
    model.model.testing = True
    model.model.eval()
    loss_perc, test_losses, predictions = run_on_jsons(config, dataset, model)

    if config.task == 'monopoly':
        labels = dataset.labels
    else:
        # NOTE: pickle is only safe on trusted, locally generated dataset files.
        with open(os.path.join(config.test_dataset_dir, 'labels.pkl'), 'rb') as f:
            labels = pkl.load(f)

    graph_perc = [d['graph'][0].item() for episode in loss_perc for d in episode]

    if config.task == 'gridworld':
        # Drop the first winsize-1 labels of every episode (presumably steps
        # before a full window is available are not scored), then left-pad
        # with zeros so the length matches graph_perc.
        test_labels = [item for subl in labels for item in subl[config.winsize - 1:]]
        test_labels = [0] * (len(graph_perc) - len(test_labels)) + test_labels
    else:
        test_labels = labels

    test_losses = torch.cat([loss for episode in test_losses for loss in episode])

    if config.task == 'monopoly':
        node_labels = dataset.node_labels
    else:
        with open(os.path.join(config.test_dataset_dir, 'node_labels.pkl'), 'rb') as f:
            node_labels = pkl.load(f)

    if ignore_intermediate_nodes and config.task == 'gridworld':
        node_labels = _reorder_node_labels(node_labels, dataset.old_node_info)

    if config.task == 'gridworld':
        # Same trimming/padding as the graph labels above, per node.
        node_test_labels = [item for subl in node_labels for item in subl[config.winsize - 1:]]
        node_test_labels = [np.zeros(dataset.num_nodes)] * (len(graph_perc) - len(node_test_labels)) + node_test_labels
        per_node_test_labels = np.concatenate(node_test_labels)
    else:
        per_node_test_labels = np.concatenate(node_labels)

    per_node_loss_predictions = _per_node_loss_predictions(predictions, dataset.num_nodes, dataset.node_info)
    # Separate percentile-based node scores are not computed; loss scores are reused.
    per_node_perc_predictions = per_node_loss_predictions

    return graph_perc, test_losses, test_labels, per_node_perc_predictions, per_node_loss_predictions, per_node_test_labels, predictions



def evaluate_localization(errors,labels, task='gridworld',**kwargs):
    """Score node/feature-level anomaly localization and print/return the metrics.

    Each ``errors[k][i]`` is converted to a tensor with the task's
    ``types2vec``. The results are concatenated and compared against the
    matching ``labels[k][i]`` tensors. The inputs are deep-copied and never
    modified.

    Args:
        errors: dict mapping a key to a list of per-step error objects
            accepted by ``types2vec``.
        labels: dict with the same keys, each a list of label tensors.
        task: ``'monopoly'`` or ``'gridworld'``; selects ``types2vec``.
        **kwargs: ``valloss`` — optional dict whose ``'feature'`` entry maps
            keys to tensors. The max over the last dim of each non-empty
            tensor becomes a per-feature threshold. Errors are binarised
            against it and F1/precision/recall are then computed on those
            hard predictions. Without ``valloss`` the best-F1 point of the
            PR curve is reported.

    Returns:
        dict with "Localization F1", "Localization Precision",
        "Localization Recall", "Localization PR-AUC", "Localization ROC-AUC".

    Raises:
        ValueError: if ``task`` is not recognised.
    """
    if task == 'monopoly':
        types2vec = monopoly_json2vec.types2vec
    elif task == 'gridworld':
        types2vec = gridworld_json2vec.types2vec
    else:
        raise ValueError(f"Unknown task {task!r}; expected 'monopoly' or 'gridworld'")
    errors = copy.deepcopy(errors)
    labels = copy.deepcopy(labels)

    vloss = None
    valloss = kwargs.get('valloss', None)
    if valloss is not None:
        # Build a new dict: overwriting valloss['feature'] in place would
        # corrupt the caller's data and re-reduce it on a second call.
        feature_valloss = {
            k: (v.max(-1).values if len(v) > 0 else v)
            for k, v in valloss['feature'].items()
        }
        vloss = torch.tensor(types2vec(feature_valloss))

    for k, v in errors.items():
        for i, item in enumerate(v):
            vec = torch.tensor(types2vec(item))
            if vloss is not None:
                # Hard 0/1 prediction: error above the validation threshold.
                vec = 1 * (vec > vloss)
            v[i] = vec
            labels[k][i] = labels[k][i].cpu()

    x = torch.cat([torch.cat(v) for v in errors.values()]).flatten()
    y = torch.cat([torch.cat(labels[k]) for k in errors], dim=-1).flatten()

    precision, recall, _ = metrics.precision_recall_curve(y, x)
    f1_scores = 2 * recall * precision / (recall + precision + 1e-6)
    auc_precision_recall = metrics.auc(recall, precision)
    auc_roc = metrics.roc_auc_score(y, x)

    if valloss is not None:
        # x already holds binarised predictions.
        y_hat = x
        f1 = metrics.f1_score(y, y_hat)
        prec = metrics.precision_score(y, y_hat)
        rec = metrics.recall_score(y, y_hat)
    else:
        best = np.argmax(f1_scores)
        f1 = np.max(f1_scores)
        prec = precision[best]
        rec = recall[best]

    print("AUC-PR", auc_precision_recall)
    print("AUC-ROC", auc_roc)

    print('F1-Score: ', f1)
    print('Precision:', prec)
    print('Recall:', rec)

    return {"Localization F1": f1, "Localization Precision": prec, "Localization Recall": rec,
            "Localization PR-AUC": auc_precision_recall, "Localization ROC-AUC": auc_roc}

def pred_vec(pred):
    """Concatenate the tensors stored in ``pred`` into one tensor.

    Each value is moved to CPU first. Values are joined along the last
    dimension, in the dict's iteration order.
    """
    return torch.cat([tensor.cpu() for tensor in pred.values()], dim=-1)

def evaluate_detection(errors,labels,valloss=None,**kwargs):
    """Score step-level anomaly detection and print/return the metrics.

    Each ``errors[k][i]`` (a dict of tensors) is reduced to the mean of
    ``pred_vec(...)``. Each ``labels[k][i]`` is reduced to its max, so a step
    counts as anomalous if any entry is. The inputs are deep-copied and never
    modified.

    Args:
        errors: dict mapping a key to a list of per-step dicts of tensors.
        labels: dict with the same keys, each a list of label tensors.
        valloss: optional dict. The max of its ``'graph'`` entry is used as a
            fixed decision threshold. Without it the best-F1 point of the PR
            curve is reported.
        **kwargs: ``vec_errors`` — optional dict of precomputed per-step
            scores that replaces the reduced ``errors``.

    Returns:
        dict with "F1", "Precision", "Recall", "PR-AUC", "ROC-AUC".
    """
    errors = copy.deepcopy(errors)
    labels = copy.deepcopy(labels)

    for k, v in errors.items():
        for i in range(len(v)):
            v[i] = torch.mean(pred_vec(v[i])).item()
            labels[k][i] = torch.max(labels[k][i].cpu())

    vec_errors = kwargs.get('vec_errors', None)
    if vec_errors is not None:
        errors = vec_errors

    x = torch.cat([torch.tensor(v) for v in errors.values()]).flatten()
    y = torch.cat([torch.tensor(labels[k]) for k in errors], dim=-1).flatten()

    precision, recall, _ = metrics.precision_recall_curve(y, x)
    f1_scores = 2 * recall * precision / (recall + precision + 1e-6)
    auc_precision_recall = metrics.auc(recall, precision)
    auc_roc = metrics.roc_auc_score(y, x)

    if valloss is None:
        best = np.argmax(f1_scores)
        f1 = np.max(f1_scores)
        prec = precision[best]
        rec = recall[best]
    else:
        thresh = valloss['graph'].max().cpu()
        y_hat = x > thresh
        f1 = metrics.f1_score(y, y_hat)
        prec = metrics.precision_score(y, y_hat)
        rec = metrics.recall_score(y, y_hat)

    print("AUC-PR", auc_precision_recall)
    print("AUC-ROC", auc_roc)

    print('F1-Score: ', f1)
    print('Precision:', prec)
    print('Recall:', rec)

    return {"F1": f1, "Precision": prec, "Recall": rec, "PR-AUC": auc_precision_recall, "ROC-AUC": auc_roc}

def run_on_jsons(args,dataset, model):
    '''
    Feed the ``*.json`` files in ``args.test_dataset_dir`` to the model one at a time.

    Files are processed in natural-sort order. Each json is pushed into
    ``dataset`` via ``receive_json_obj``. ``new_episode=True`` is passed when
    the step number parsed from the filename is 1. When ``dataset[-1]`` is not
    None, the resulting graph is scored with :func:`predict_json` (no json
    reconstruction).

    Filenames are parsed as follows:

    * monopoly: ``..._<episode>_<step>.json``
    * other tasks: ``<episode>_<step>.json``

        args: SimpleNamespace with at least ``test_dataset_dir`` and ``task``.
        dataset: object providing ``receive_json_obj``, indexing and (non-monopoly) ``jgraph``.
        model: wrapper exposing ``model.eval()`` and ``compute_novelty_scores``.

    Returns:
        loss_percentile: list (one per episode) of per-step percentile objects
        losses: list (one per episode) of per-step loss objects
        predictions: dict fname -> {"true", "prediction", "loss", "percentile", "diff_json"}
            ("true"/"prediction"/"diff_json" are None because reconstruct=False)
    '''
    model.model.eval()
    loss_percentile = []
    losses = []
    predictions = {}
    # endswith (not substring) so e.g. "x.json.bak" is skipped and the
    # "[:-5]" suffix strip below is always valid.
    test_files = natsorted([f for f in os.listdir(args.test_dataset_dir) if f.endswith('.json')])

    for fname in test_files:
        split = fname.split('_')
        # ep_num is not used further; parsing it validates the filename format.
        if args.task == 'monopoly':
            ep_num = int(split[-2])
            step_num = int(split[-1].split('.')[0])
        else:
            ep_num = int(split[0])
            step_num = int(split[1][:-5])

        new_episode = step_num == 1
        # Also open buckets on the very first file so a directory that does
        # not start at step 1 cannot hit an IndexError on [-1] below.
        if new_episode or not losses:
            loss_percentile.append([])
            losses.append([])

        with open(os.path.join(args.test_dataset_dir, fname)) as f:
            json_obj = json.load(f)

        dataset.receive_json_obj(json_obj, new_episode=new_episode)

        graph = dataset[-1]
        if graph is None:
            # Presumably not enough steps buffered yet to form a graph.
            continue

        with torch.no_grad():
            x = graph.to('cpu')
            if args.task == 'monopoly':
                model_json = json_obj
            else:
                model_json = dataset.jgraph.prune_json(json_obj, task=args.task)
            perc, loss, pred_json, true_json, diff_json = predict_json(dataset, model, model_json, x, reconstruct=False)

            predictions[fname] = {"true": true_json, "prediction": pred_json, "loss": loss, "percentile": perc, "diff_json": diff_json}

            loss_percentile[-1].append(perc)
            losses[-1].append(loss)

    return loss_percentile, losses, predictions



def predict_json(dataset, model, json_obj, x, reconstruct = False):
    """Score graph ``x`` with ``model`` and optionally rebuild state jsons.

    Args:
        dataset: provides ``unnormalize``, ``jgraph`` and ``node_info``
            (only used when reconstructing).
        model: exposes ``compute_novelty_scores``. When reconstructing it must
            also expose ``prediction_to_nodearray`` and ``get_diff_dictionary``.
        json_obj: json state used as the template for reconstruction; it is
            deep-copied, never modified.
        x: graph passed to the model. When reconstructing, ``x.last_state``
            supplies the true node values.
        reconstruct: when False, only the scores are returned.

    Returns:
        ``(perc, loss, pred_json, true_json, diff_json)``. The last three are
        None when ``reconstruct`` is False.
    """
    perc, loss, pred = model.compute_novelty_scores(x, return_prediction=True)
    if reconstruct:
        predicted_nodes = model.prediction_to_nodearray(pred, dataset.node_info)
        pred_node_vals = dataset.unnormalize(predicted_nodes)
        pred_json = utils.reconstruct_json(copy.deepcopy(json_obj), pred_node_vals, dataset.jgraph, dataset.node_info)

        last_state = x.last_state.numpy().reshape(-1, 1)
        true_node_vals = dataset.unnormalize(np.array(last_state, dtype=float))
        true_json = utils.reconstruct_json(copy.deepcopy(json_obj), true_node_vals, dataset.jgraph, dataset.node_info)

        return perc, loss, pred_json, true_json, model.get_diff_dictionary(true_json, pred_json)
    return perc, loss, None, None, None
