import argparse, os, sys
sys.path.append('../')
from dataset.sensordata import make_sensor_datasets 
from torch_geometric.loader import DataLoader as GraphDataLoader
import torch
import numpy as np
from sklearn.metrics import roc_auc_score
import numpy as np
from Autoregressive_model import AutoregressiveModel
from train import train
from scipy.stats import iqr
from test_sensors import test_with_normalized_loss
import types
from io import StringIO
import re
import pandas as pd
import scipy.stats as stats
import tqdm


def collect_result(output):
    """Extract the test F1 score and best validation loss from captured logs.

    The first line fragment starting at ``"F1"`` and the first starting at
    ``"best"`` are located in ``output``. The last whitespace-separated token
    of each fragment is parsed as a float.

    Args:
        output: text captured from stdout during training and testing.

    Returns:
        dict with keys ``"F1"`` and ``"val_loss"`` (both floats).

    Raises:
        ValueError: if a marker is missing or its last token is not a number.
    """
    def _last_number(pattern, label):
        # Return the trailing numeric token of the first regex match.
        match = re.search(pattern, output)
        if match is None:
            raise ValueError(f"could not find {label!r} in captured output")
        # str.split() without an argument ignores trailing/repeated whitespace,
        # whereas split(' ') yields '' for a trailing space and breaks float().
        return float(match.group(0).split()[-1])

    f1 = _last_number(r"F1.*", "F1")
    valloss = _last_number(r"best.*", "best")
    return {"F1": f1, 'val_loss': valloss}



# Shared run configuration: read by the dataset/loader setup and passed to
# AutoregressiveModel as ``config``.
config = types.SimpleNamespace(
    emb_dim=32,
    test_dataset_dir="./",  # ignored
    val_dataset_dir="./",
    model_type='transformer',
    bsz=50,
    model_save_dir='./hyperparam_plots_results',
    task='swat',
    n_masks=50,  # set to 0 for reconstruction
    reconstructing=False,  # makes the model a reconstruction type model
    subsample=15,
    gpu=0,
    validation_step=1000,
)

# Build the train / validation / test splits for the configured task.
trainset, val, test_dataset = make_sensor_datasets(
    name=config.task,
    concat_steps=5,
    subsample=config.subsample,
)
# Train and validation loaders share settings; both iterate in dataset order
# (shuffle=False). NOTE(review): unshuffled training may be intentional for
# this sequential data — confirm.
trainloader, valloader = (
    GraphDataLoader(split, batch_size=config.bsz, shuffle=False)
    for split in (trainset, val)
)

def experiment(hyper_param_sets, n_train_epochs):
    """Train and evaluate one model per hyper-parameter set.

    For each entry a fresh ``AutoregressiveModel`` is built from the
    module-level ``trainset`` and ``config``. It is trained with ``train`` on
    ``trainloader`` (validating on ``valloader``) and evaluated with
    ``test_with_normalized_loss`` on ``test_dataset``. Everything printed to
    stdout during that run is captured and parsed by ``collect_result``.

    Args:
        hyper_param_sets: iterable of hyper-parameter dicts, each forwarded to
            ``AutoregressiveModel`` as ``hyper_params``.
        n_train_epochs: number of training epochs per configuration.

    Returns:
        pandas.DataFrame with one row per configuration and columns ``F1``,
        ``val_loss`` and ``params`` (the ``str`` of the hyper-parameter dict).
    """
    from contextlib import redirect_stdout

    result_sets = []
    for hyper_params in hyper_param_sets:
        output_buffer = StringIO()
        # Capture stdout so the metrics can be scraped from the printed logs.
        # redirect_stdout restores sys.stdout when the block exits, even if
        # training/testing raises; a bare reassignment would leave the process
        # writing into a dangling buffer.
        with redirect_stdout(output_buffer):
            model = AutoregressiveModel(
                trainset.num_nodes,
                trainset.node_feature_dim,
                trainset.node_info,
                model_type=config.model_type,
                config=config,
                train_dataset=trainset.all_node_array,
                device=config.gpu,
                feature_list=trainset.feature_list,
                hyper_params=hyper_params,
            )
            train(model, trainloader, valloader=valloader, epochs=n_train_epochs)
            test_with_normalized_loss(test_dataset, model)
        # Drop our reference before the next model is built, so the old one
        # can be freed first (no effect if something else still holds it).
        del model

        results = collect_result(output_buffer.getvalue())
        results['params'] = str(hyper_params)
        result_sets.append(results)

    return pd.DataFrame(result_sets)

# Sweep grid: output_dim in {4, 6, ..., 16} x n_layers in {1, 2} x
# lambda_ in {1, 2, 3} -> 42 configurations.
hyper_param_sets = [
    {"output_dim": o, "n_layers": layers, "lambda_": lamb}
    for o in range(4, 17, 2)
    for layers in [1, 2]
    for lamb in [1, 2, 3]
]

# Remember the stream that is active now and restore exactly that afterwards.
# sys.__stdout__ is not necessarily the active stream (e.g. under Jupyter or
# when the caller redirected output) and can be None (pythonw). The finally
# clause also restores output if the sweep raises.
_original_stdout = sys.stdout
try:
    resultdf = experiment(hyper_param_sets, 50)
    # Rank correlation between test F1 and best validation loss across runs.
    result = stats.spearmanr(a=resultdf['F1'], b=resultdf['val_loss'])
finally:
    sys.stdout = _original_stdout

print(result)

resultdf.to_csv("f1_valloss.csv")