import argparse, os, sys

from sklearn import datasets

sys.path.append('../')
from dataset.sensordata import make_sensor_datasets 
from dataset.gamedata import GameDataset
from torch_geometric.loader import DataLoader as GraphDataLoader
import torch
import numpy as np
from sklearn.metrics import roc_auc_score
import numpy as np
from Autoregressive_model import AutoregressiveModel
from train import train
from scipy.stats import iqr
from test_sensors import test_with_normalized_loss
from test_games import test
import types
from io import StringIO
import re
import pandas as pd
import scipy.stats as stats
import tqdm
import matplotlib.pyplot as plt
import ast
import matplotlib
from dataset.json_graph import JsonToGraph
import time    
import test_json
from dataset import gridworld_jsondata
from dataset import monopoly_jsondata
from pathlib import Path


from json_experiment import make_model_config,load_model_configs

if __name__ == '__main__':
    # Evaluate a saved AutoregressiveModel checkpoint on a JSON game task:
    # rebuild the model from the training split's metadata, load the weights
    # from ./saved_models/, switch it to evaluation mode and print the output
    # of test_json.test.

    model_type = 'transformer'
    task = 'monopoly'
    # Passed as `masks` to make_model_config and used in the checkpoint name.
    m = 28

    # Pick the dataset directory and JSON dataset class for the task. An
    # unknown task used to fall through here and crash later with a NameError
    # on `dataset_root`; fail fast with a clear message instead.
    if task == 'gridworld':
        dataset_root = "/home/plymper/data/polycraftv2"
        dataset_cls = gridworld_jsondata.GameDataset
    elif task == 'monopoly':
        dataset_root = "/home/plymper/data/monopoly"
        dataset_cls = monopoly_jsondata.GameDataset
    else:
        raise ValueError(
            f"Unknown task {task!r}; expected 'gridworld' or 'monopoly'"
        )

    DATASET_ROOT = Path(dataset_root)

    config = make_model_config(
        "Predictive",
        model=model_type,
        task=task,
        masks=m,
        train_path="./",
        val_path="./",
    )
    hprms = load_model_configs(config.model_type, config)

    # The training split is only used to supply the model's constructor
    # arguments (node count, feature dims, node info, feature list and the
    # `all_node_array` training data); no training happens in this script.
    trainset = dataset_cls(
        data_path=DATASET_ROOT / "json_normal_train_data.pkl",
        concat_steps=config.winsize,
        mode="training",
    )

    model = AutoregressiveModel(
        trainset.num_nodes,
        trainset.node_feature_dim,
        trainset.node_info,
        model_type=config.model_type,
        config=config,
        train_dataset=trainset.all_node_array,
        device=config.gpu,
        feature_list=trainset.feature_list,
        hyper_params=hprms,
    )
    model.load_model(f"./saved_models/{model_type}_{m}.pth")
    model.to(f"cuda:{config.gpu}")

    # NOTE(review): `testing` is a flag on the wrapped network; its exact
    # effect is defined outside this file — confirm against the model code.
    model.model.testing = True
    model.model.eval()

    results = test_json.test(model, task=config.task)
    print(results)