
import pandas as pd
import json
import torch
import importlib.resources
from pathlib import Path
import numpy as np
from sklearn.preprocessing import StandardScaler
import os
import pickle
from sklearn import metrics
import argparse
import copy
from dataset.gridworld_json2vec import data_info
from dataset import monopoly_json2vec,gridworld_json2vec
import dataset.gridworld_jsondata as gridworld_jsondata
import dataset.monopoly_jsondata as monopoly_jsondata
import sys
from Autoregressive_model import AutoregressiveModel
from test_games import evaluate_localization,evaluate_detection

def run_on_episode(model, path, dataset):
    """Score every JSON frame found in one episode directory.

    Frames are files in ``path`` whose name contains ``.json``; they are
    processed in numeric order of the part of the name before the first dot.
    Each frame is fed to ``dataset.receive_json_obj`` (with
    ``new_episode=True`` for the first file only), then ``dataset[-1]`` is
    scored with ``model.compute_novelty_scores``. Frames for which the dataset
    yields ``None`` are skipped.

    Args:
        model: object exposing ``device`` and ``compute_novelty_scores``.
        path: directory holding the episode's JSON frames.
        dataset: object exposing ``receive_json_obj`` and integer indexing.

    Returns:
        A tuple ``(mean_error, errors, labels, vec_errors)`` where
        ``errors`` holds the second value returned by
        ``compute_novelty_scores`` for each scored frame, ``vec_errors`` the
        first value, and ``labels`` the ``labels`` attribute of each scored
        sample. ``mean_error`` is the summed ``error.item()`` divided by the
        number of JSON files (skipped frames are still counted in the
        denominator), or ``0.0`` when the directory has no JSON files.
    """
    fnames = [name for name in os.listdir(path) if '.json' in name]
    # Numerical sorting of filenames, so that "10.json" follows "9.json".
    fnames.sort(key=lambda name: int(name.split(".")[0]))

    total_error = 0
    errors = []
    vec_errors = []
    labels = []
    for i, fname in enumerate(fnames):
        fpath = os.path.join(path, fname)
        with open(fpath, 'r') as fh:
            json_obj = json.load(fh)

        dataset.receive_json_obj(json_obj, new_episode=(i == 0))

        sample = dataset[-1]
        if sample is None:
            continue

        error, error_dict = model.compute_novelty_scores(sample.to(model.device))
        vec_errors.append(error)
        errors.append(error_dict)
        labels.append(sample.labels)
        total_error += error.item()

    if not fnames:
        # An empty episode directory used to raise ZeroDivisionError here.
        return 0.0, errors, labels, vec_errors

    mean_error = total_error / len(fnames)
    return mean_error, errors, labels, vec_errors


def run_on_dir(model,path, dataset):
    """Run ``run_on_episode`` on every numerically named entry of ``path``.

    Entries of ``path`` whose name is numeric are treated as episodes and
    visited in ascending integer order. Each episode is labelled
    ``"<last component of path>/<episode>"``.

    Returns:
        Five parallel lists, one item per episode, in this order: episode
        names, mean episode errors, per-frame error dicts, per-frame labels,
        per-frame error vectors.
    """
    episode_ids = sorted(
        (entry for entry in os.listdir(path) if str(entry).isnumeric()),
        key=int,
    )
    dir_label = path.split("/")[-1]

    names = []
    mean_scores = []
    frame_scores = []
    frame_labels = []
    vector_scores = []
    for episode_id in episode_ids:
        mean_err, err_dicts, lbls, vec_errs = run_on_episode(
            model, f"{path}/{episode_id}/", dataset
        )
        names.append(f"{dir_label}/{episode_id}")
        mean_scores.append(mean_err)
        frame_scores.append(err_dicts)
        frame_labels.append(lbls)
        vector_scores.append(vec_errs)

    return names, mean_scores, frame_scores, frame_labels, vector_scores
        

def test(model, task = 'gridworld'):
    """Evaluate ``model`` on the normal test episodes of ``task``.

    Runs ``run_on_dir`` on each test directory, pickles the per-episode
    errors and labels under ``./polycraft_scores/``, and returns the
    localization results updated with the detection results.

    Args:
        model: trained model; this function reads ``n_steps``, ``n_masks``
            and ``config`` and, when ``config`` has a truthy
            ``use_val_thresh`` entry, ``graph_val_losses`` and
            ``per_node_val_losses``.
        task: ``'gridworld'`` or ``'monopoly'``.

    Returns:
        The mapping returned by ``evaluate_localization`` after being
        updated with the mapping returned by ``evaluate_detection``.

    Raises:
        ValueError: if ``task`` is not one of the supported names.
    """
    if task == 'gridworld':
        base_path = "../dataset/polycraft/"
        GameDataset = gridworld_jsondata.GameDataset
    elif task == 'monopoly':
        base_path = "../dataset/monopoly/"
        GameDataset = monopoly_jsondata.GameDataset
    else:
        # Previously an unknown task surfaced later as UnboundLocalError.
        raise ValueError(
            f"unknown task {task!r}; expected 'gridworld' or 'monopoly'"
        )

    DATASET_ROOT = Path(base_path)
    test_paths = ["normal_test"]

    all_errors = {}
    all_labels = {}
    all_vec_errors = {}
    for p in test_paths:
        # Dataset is built with mode="test" (per the original note: only the
        # validation info is loaded in that mode).
        test_set = GameDataset(
            data_path=DATASET_ROOT/Path("json_normal_valid_data.pkl"),
            concat_steps=model.n_steps+1,
            mode="test",
        )
        names, _, ep_errors, ep_labels, ep_vec_errors = run_on_dir(
            model, base_path+p, test_set
        )
        all_errors.update(zip(names, ep_errors))
        all_labels.update(zip(names, ep_labels))
        all_vec_errors.update(zip(names, ep_vec_errors))

    n_masks = model.n_masks
    # Create the output directory up front so a fresh checkout does not fail
    # with FileNotFoundError when the score files are written.
    os.makedirs('./polycraft_scores', exist_ok=True)
    with open(f'./polycraft_scores/{task}_{n_masks}_test_errors.pkl','wb') as f:
        pickle.dump(all_errors,f)
    with open(f'./polycraft_scores/{task}_{n_masks}_test_labels.pkl','wb') as f:
        pickle.dump(all_labels,f)

    if model.config.__dict__.get("use_val_thresh",False):
        val_loss = {"graph":model.graph_val_losses, "feature":model.per_node_val_losses}
    else:
        val_loss = None

    results = evaluate_localization(all_errors,all_labels,task=task,val_loss=val_loss)
    detection_results = evaluate_detection(all_errors,all_labels, val_loss=val_loss, vec_errors=all_vec_errors)
    results.update(detection_results)
    return results

def pred_vec(pred):
    """Concatenate all values of ``pred`` along the last dimension.

    Each value is moved to the CPU first; values are joined in the mapping's
    iteration order.
    """
    cpu_parts = [tensor.cpu() for tensor in pred.values()]
    return torch.cat(cpu_parts, dim=-1)

if __name__=='__main__':
    # Script entry point: load a saved model, run the default (gridworld)
    # evaluation and write the results to results.csv.
    cli = argparse.ArgumentParser(description='graph-anomaly-detection')

    # model parameters
    cli.add_argument("--model_type", type=str, default="mlp", help="mlp or GCN or GAT or GraphSAGE")
    cli.add_argument("--model_save_dir", type=str, default="./polycraft_results", help="save dir")
    cli.add_argument("--winsize", type=int, default=10, help="window size")
    args = cli.parse_args()

    with importlib.resources.path("polycraft_nov_data", "dataset") as dataset_root:
        DATASET_ROOT = Path(dataset_root)

    valid_data_file = DATASET_ROOT/Path("json_normal_valid_data.pkl")
    test_set = gridworld_jsondata.GameDataset(data_path=valid_data_file, concat_steps=10, mode="test")

    # NOTE(review): 'mlp' and concat_steps=10 are literals here rather than
    # args.model_type / args.winsize; args is also forwarded to the model,
    # which may read those flags itself — confirm which is intended.
    hyper_params = {"h_dim":64}
    model = AutoregressiveModel(
        test_set.num_nodes,
        test_set.concat_steps,
        test_set.node_info,
        'mlp',
        args,
        device='cuda',
        hyper_params=hyper_params,
    )
    checkpoint_path = f"{args.model_save_dir}/mlp_1.pth"
    model.load_model(checkpoint_path)
    model.to('cuda')

    results = test(model)
    results_frame = pd.DataFrame(results)
    results_frame.to_csv("results.csv")




