import argparse, os
import sys
sys.path.append('../')
from PIL.Image import new

from dataset.data_gan import MyDataset
from dataset.gamedata import GameDataset
from dataset.json_graph import JsonToGraph
from torch.utils.data import DataLoader,Dataset
import pickle 
from torch_geometric.loader import DataLoader as GraphDataLoader
import torch
import numpy as np
import pickle as pkl
from sklearn.metrics import roc_auc_score
import numpy as np
import types
from natsort import natsorted
from Autoregressive_model import AutoregressiveModel
from gan_model import GANModel

import matplotlib.pyplot as plt
import main
import json


def _stack_node_predictions(predictions, node_info, num_nodes, select):
    """Scatter typed per-node scores into flat vectors and concatenate them.

    For every value of ``predictions`` (in dict iteration order), ``select``
    returns a mapping with ``'categorical'``, ``'numerical'`` and ``'binary'``
    scores. These are written into a zero vector of length ``num_nodes`` at
    ``node_info['cat_nodes']``, ``node_info['nume_nodes']`` and
    ``node_info['bin_nodes']`` respectively (in that order, so a later group
    overwrites an earlier one if the index sets ever overlap). The per-sample
    vectors are concatenated with ``np.concatenate``.
    """
    per_sample = []
    for preds in predictions.values():
        scores = select(preds)
        node_preds = np.zeros(num_nodes)
        node_preds[node_info['cat_nodes']] = scores['categorical']
        node_preds[node_info['nume_nodes']] = scores['numerical']
        node_preds[node_info['bin_nodes']] = scores['binary']
        per_sample.append(node_preds)
    return np.concatenate(per_sample)


def eval_model(test_dataset_dir, n_masks, idx,
               val_dataset_dir="/home/plymper/data/gridworldsData/novel_nonov/"):
    """Evaluate one trained MLP autoregressive model and return ROC-AUC scores.

    Loads the checkpoint ``../results/mlp_masks_{n_masks}_{idx}/mlp.pth`` and
    runs ``main.run_on_jsons``. That call presumably scores
    ``config.test_dataset_dir`` (confirm in ``main``). Predictions are compared
    against the ground truth in ``labels.pkl`` (graph level) and
    ``node_labels.pkl`` (node level), both inside ``test_dataset_dir``.

    Args:
        test_dataset_dir: directory of the test split; must contain
            ``labels.pkl`` and ``node_labels.pkl``.
        n_masks: mask count the model was trained with; selects the
            checkpoint directory.
        idx: run index of the trained model; selects the checkpoint directory.
        val_dataset_dir: data path used to build the ``GameDataset`` whose
            node metadata (``num_nodes``, ``node_info``, ...) configures the
            model. Defaults to the previously hard-coded path.

    Returns:
        ``(node_scores, graph_scores)``. Each is a
        ``(loss_auc, percentile_auc)`` pair of ``roc_auc_score`` results.
    """
    config = types.SimpleNamespace()
    config.emb_dim = 32
    config.test_dataset_dir = test_dataset_dir
    config.val_dataset_dir = val_dataset_dir
    config.model_type = 'mlp'
    config.n_masks = n_masks
    config.model_save_dir = f'../results/mlp_masks_{config.n_masks}_{idx}'
    config.task = 'gridworld'

    # Original note: "use same jgraph obj as validation set". Presumably this
    # keeps the node layout identical to the one the model was trained on.
    dataset = GameDataset(data_path=config.val_dataset_dir, concat_steps=5, mode='test')
    model = AutoregressiveModel(dataset.num_nodes, dataset.node_feature_dim, dataset.node_info,
                                model=config.model_type, config=config)
    pth = os.path.join(config.model_save_dir, config.model_type + ".pth")
    print(pth)
    model.load_model(pth)
    model.to('cpu')
    loss_perc, test_losses, predictions = main.run_on_jsons(config, dataset, model)

    # ---- graph-level ROC-AUC ----
    with open(os.path.join(config.test_dataset_dir, 'labels.pkl'), 'rb') as f:
        labels = pkl.load(f)

    graph_perc = [d['graph'][0].item() for batch in loss_perc for d in batch]
    test_losses = torch.cat([t for sublist in test_losses for t in sublist])

    # Drop the first 4 labels of every sequence. This presumably matches
    # concat_steps=5, where no prediction exists until 5 steps are seen
    # (TODO confirm). Then left-pad with 0 labels so the length matches
    # graph_perc.
    graph_labels = [item for subl in labels for item in subl[4:]]
    graph_labels = [0] * (len(graph_perc) - len(graph_labels)) + graph_labels

    graph_scores = (roc_auc_score(graph_labels, test_losses),
                    roc_auc_score(graph_labels, graph_perc))

    # ---- node-level ROC-AUC ----
    with open(os.path.join(config.test_dataset_dir, 'node_labels.pkl'), 'rb') as f:
        node_labels = pkl.load(f)

    # Same trimming/padding as the graph labels, with an all-zero node vector
    # as padding.
    node_label_list = [item for subl in node_labels for item in subl[4:]]
    node_label_list = ([np.zeros(dataset.num_nodes)] * (len(graph_perc) - len(node_label_list))
                       + node_label_list)
    per_node_test_labels = np.concatenate(node_label_list)

    node_info = dataset.node_info
    per_node_loss_predictions = _stack_node_predictions(
        predictions, node_info, dataset.num_nodes,
        lambda p: p['percentile']['losses'])
    per_node_perc_predictions = _stack_node_predictions(
        predictions, node_info, dataset.num_nodes,
        lambda p: p['percentile'])

    node_scores = (roc_auc_score(per_node_test_labels, per_node_loss_predictions),
                   roc_auc_score(per_node_test_labels, per_node_perc_predictions))

    print("-------------------")
    return node_scores, graph_scores


if __name__=="__main__":

    # Novelty test sets. The suffix after the last '_' in each path (e.g.
    # "inventoryreset") is used as the results key.
    nov1 = "/home/plymper/data/gridworldsData/mixed_inventoryreset"
    nov2 = "/home/plymper/data/gridworldsData/mixed_breakincrease"
    nov3 = "/home/plymper/data/gridworldsData/mixed_woodgift"

    paths = [nov1, nov2, nov3]
    masks = [1, 29, 116]
    # Number of trained runs per (mask, path) configuration. Runs use
    # idx 0..n_runs-1, and their scores are averaged. Keeping this in one
    # place keeps the loop bound and the averaging divisor in sync.
    n_runs = 4

    results = {}

    for mask in masks:
        results[mask] = {}
        for path in paths:
            novelty = path.split('_')[-1]
            node_runs = []
            graph_runs = []
            for i in range(n_runs):
                node_score, graph_score = eval_model(path, mask, i)
                print(node_score)
                node_runs.append(node_score)
                graph_runs.append(graph_score)

            # Each score is a (loss_auc, percentile_auc) pair; average over runs.
            node_avg = np.mean(node_runs, axis=0)
            graph_avg = np.mean(graph_runs, axis=0)
            results[mask][novelty] = {
                "node": {"loss": float(node_avg[0]), "perc": float(node_avg[1])},
                "graph": {"loss": float(graph_avg[0]), "perc": float(graph_avg[1])},
            }

    # json.dump turns the integer mask keys into strings.
    with open('results.json', 'w') as f:
        json.dump(results, f)

    print(results)


    