import argparse, os, sys
import json

from sklearn import datasets

sys.path.append('../')
from dataset.sensordata import make_sensor_datasets 
from dataset.gamedata import GameDataset
from torch_geometric.loader import DataLoader as GraphDataLoader
import torch
import numpy as np
from sklearn.metrics import roc_auc_score
import numpy as np
from Autoregressive_model import AutoregressiveModel
from train import train
from scipy.stats import iqr
from test_sensors import test_with_normalized_loss
from test_games import test
import types
from io import StringIO
import re
import pandas as pd
import scipy.stats as stats
import tqdm
import matplotlib.pyplot as plt
import ast
import matplotlib
from dataset.json_graph import JsonToGraph
import time    

def collect_result(output, prefix=''):
    """Parse final test metrics out of captured stdout text.

    For each of ``<prefix>F1-Score``, ``<prefix>Precision`` and
    ``<prefix>Recall``, take the first match of that text in *output*, extended
    to the end of its line. The metric value is the last space-separated token
    of that match. The validation loss is read the same way from the first
    occurrence of ``best``.

    Args:
        output: text captured from stdout while a model was being tested.
        prefix: literal text that precedes each metric name, e.g.
            ``"Per graph max normalized loss "``. It is regex-escaped, so it
            is matched literally.

    Returns:
        dict with keys ``"F1"``, ``"val_loss"``, ``"Precision"``, ``"Recall"``.

    Raises:
        IndexError: if one of the expected lines is missing from *output*.
        ValueError: if the last token of a matched line is not a number.
    """
    def _last_float(pattern):
        # First match, value = its last space-separated token.
        return float(re.findall(pattern, output)[0].split(' ')[-1])

    escaped = re.escape(prefix)
    f1 = _last_float(escaped + r"F1-Score.*")
    prec = _last_float(escaped + r"Precision.*")
    rec = _last_float(escaped + r"Recall.*")
    valloss = _last_float(r"best.*")
    return {"F1": f1, 'val_loss': valloss, 'Precision': prec, 'Recall': rec}

def collect_result_byline(output):
    """Return the per-validation hook score strings found in captured stdout.

    A line qualifies when it contains both ``"validation scores :"`` and
    ``"hook"``. The returned entry is the text after the marker, stripped.
    This holds wherever the marker appears in the line, so any leading log
    text is handled correctly.

    Args:
        output: text captured from stdout during training.

    Returns:
        list of str, one per qualifying line, in order of appearance.
    """
    marker = "validation scores :"
    scores = []
    for line in output.split("\n"):
        if marker in line and 'hook' in line:
            # partition() cuts at the marker's actual position; slicing by
            # len(marker) would only be right when the marker starts the line.
            scores.append(line.partition(marker)[2].strip())
    return scores


def make_model_config(model_type = "Reconstructing",task='swat', model='transformer', train_path = "./", val_path = "./", masks=50,use_pretrained = False,anomaly_filter=0.95):
    """Build the experiment configuration namespace.

    Args:
        model_type: ``'Reconstructing'`` (forces ``n_masks = 0`` and
            ``reconstructing = True``) or ``'Predictive'`` (keeps *masks*,
            ``reconstructing = False``).
        task: dataset/task name stored as ``config.task``.
        model: architecture name stored as ``config.model_type``.
        train_path, val_path: dataset directories.
        masks: number of masks for the predictive model.
        use_pretrained: stored as ``config.use_pretrained``.
        anomaly_filter: stored as ``config.anomaly_filter``.

    Returns:
        types.SimpleNamespace holding all settings.

    Raises:
        ValueError: if *model_type* is neither 'Reconstructing' nor 'Predictive'.
    """
    if model_type == 'Reconstructing':
        masks = 0
        recon = True
    elif model_type == "Predictive":
        recon = False
    else:
        # Without this, `recon` would be unbound and fail below with a less
        # helpful UnboundLocalError.
        raise ValueError(
            f"model_type must be 'Reconstructing' or 'Predictive', got {model_type!r}"
        )

    config = types.SimpleNamespace()
    config.emb_dim = 32
    config.train_dataset_dir = train_path
    config.test_dataset_dir = "./"  # ignored
    config.val_dataset_dir = val_path
    config.model_type = model

    config.bsz = 400
    config.model_save_dir = './K_experiment_models'
    config.task = task
    config.n_masks = masks  # 0 for reconstruction
    config.reconstructing = recon  # makes the model a reconstruction type model
    config.subsample = 15
    config.gpu = 0
    config.validation_step = 1
    config.winsize = 10
    config.model_config_root = "../"
    config.use_pretrained = use_pretrained
    config.anomaly_filter = anomaly_filter
    config.patience_limit = 15
    return config

def train_hook(test_dataset):
    """Create a training hook bound to *test_dataset*.

    The returned callable takes a model, runs
    ``test_with_normalized_loss(test_dataset, model)`` and hands back that
    call's result unchanged.
    """
    def _score(model):
        return test_with_normalized_loss(test_dataset, model)

    return _score


def experiment_measure_during_training(hyper_param_sets,n_train_epochs, trainset,config,trainloader,valloader,test_dataset, hook=None,lr=0.001):
    """Train one model per hyper-parameter set and collect per-validation hook scores.

    For each entry of *hyper_param_sets*, a fresh ``AutoregressiveModel`` is
    trained. Its training hook is ``hook(test_dataset)``, or no hook when
    *hook* is None. The model is then tested once more with
    ``test_with_normalized_loss``.

    Everything printed to stdout during this is captured. Each captured line
    containing ``validation scores :`` and ``hook`` is parsed with
    ``ast.literal_eval`` into one result row. The text is presumably a dict
    literal; the ``hook_result`` rename below relies on that.

    Args:
        hyper_param_sets: iterable of hyper-parameter dicts.
        n_train_epochs: epochs passed to ``train``.
        trainset: training dataset; its node/feature attributes size the model.
        config: experiment config (see ``make_model_config``).
        trainloader, valloader: data loaders passed to ``train``.
        test_dataset: data the hook and final test evaluate on.
        hook: factory taking *test_dataset* and returning a training hook
            (e.g. ``train_hook``), or None.
        lr: learning rate for the model.

    Returns:
        DataFrame of all rows across all hyper-parameter sets, with the
        ``hook_result`` column renamed to ``F1``. It is also printed to stderr.
    """
    from contextlib import redirect_stdout

    result_sets = []
    for hyper_params in hyper_param_sets:
        output_buffer = StringIO()
        # train()/test_with_normalized_loss() report via print(). Capture it so
        # the hook scores can be parsed. redirect_stdout restores sys.stdout
        # afterwards, even when training raises.
        with redirect_stdout(output_buffer):
            model = AutoregressiveModel(trainset.num_nodes, trainset.node_feature_dim, trainset.node_info,model_type= config.model_type, config = config,train_dataset = trainset.all_node_array, device = config.gpu, feature_list = trainset.feature_list, hyper_params=hyper_params,lr=lr)
            train_hook_fn = hook(test_dataset) if hook is not None else None
            train(model, trainloader, valloader=valloader, epochs=n_train_epochs, hook=train_hook_fn)
            test_with_normalized_loss(test_dataset,model)
        output = output_buffer.getvalue()
        # literal_eval only accepts Python literals, so it cannot execute code.
        result_sets.extend(ast.literal_eval(line) for line in collect_result_byline(output))

    resultdf = pd.DataFrame(result_sets)
    resultdf = resultdf.rename(columns={"hook_result": "F1"})

    print(resultdf, file=sys.stderr)
    return resultdf

def load_model_configs(model_type,config):
    """Load the saved hyper-parameters for one (model, mask count, task) combination.

    Reads ``../models/model_configs/{model_type}_{config.n_masks}_{config.task}.json``
    relative to the current working directory.

    Returns:
        The decoded JSON content.

    Raises:
        FileNotFoundError: if no config file exists for the combination.
    """
    config_path = "../models/model_configs/{}_{}_{}.json".format(
        model_type, config.n_masks, config.task
    )
    with open(config_path, 'r') as handle:
        return json.load(handle)

def K_experiment_measure_at_end(hyper_params,n_train_epochs, trainset,config_sets,trainloader,valloader,test_dataset,lr=0.001,hook=None):
    """Train and test one model per config and collect the final metrics.

    For each config in *config_sets*, a fresh ``AutoregressiveModel`` is
    trained and switched to testing mode (``model.model.testing = True``).
    It is then evaluated:

    - sensor tasks ('swat', 'wadi') via ``test_with_normalized_loss(test_dataset, model)``;
    - every other task via ``test(config, model, hyperparams=hyper_params)``.

    Metrics are parsed from the captured stdout with ``collect_result``.

    Args:
        hyper_params: hyper-parameter dict shared by all runs, or None to load
            one per config with ``load_model_configs``.
        n_train_epochs: epochs passed to ``train``.
        trainset: training dataset; its node/feature attributes size the model.
        config_sets: iterable of configs (see ``make_model_config``).
        trainloader, valloader: data loaders passed to ``train``.
        test_dataset: test data for the sensor tasks (not passed to ``test``).
        lr: learning rate for the model.
        hook: training hook forwarded to ``train``.

    Returns:
        DataFrame with one row per config. Columns are F1, val_loss, Precision
        and Recall (plus 'Node F1' for non-sensor tasks), 'params' (string form
        of the hyper-parameters actually used) and 'time' (training wall-clock
        seconds).
    """
    from contextlib import redirect_stdout

    result_sets = []
    for config in config_sets:
        if hyper_params is None:
            hprms = load_model_configs(config.model_type, config)
            print(hprms,config,file=sys.stderr)
        else:
            # Previously `hprms` was only bound in the branch above, so passing
            # explicit hyper-parameters raised NameError.
            hprms = hyper_params

        output_buffer = StringIO()
        # Capture train/test prints for metric parsing. sys.stdout is restored
        # afterwards, even on error.
        with redirect_stdout(output_buffer):
            model = AutoregressiveModel(trainset.num_nodes, trainset.node_feature_dim, trainset.node_info,model_type= config.model_type, config = config,train_dataset = trainset.all_node_array, device = config.gpu, feature_list = trainset.feature_list, hyper_params=hprms,anomaly_filter=config.anomaly_filter,lr=lr)
            start = time.time()
            train(model, trainloader, valloader=valloader, epochs=n_train_epochs, hook=hook)
            duration = time.time() - start
            model.model.testing = True
            if config.task in ['wadi','swat']:
                test_with_normalized_loss(test_dataset,model)
            else:
                # NOTE(review): forwards the caller's hyper_params (None when
                # loaded from file) as before — confirm test() does not need hprms.
                test(config,model,hyperparams=hyper_params)
        output = output_buffer.getvalue()

        if config.task in ['wadi','swat']:
            results = collect_result(output)
        else:
            results = collect_result(output, prefix = "Per graph max normalized loss ")
            results1 = collect_result(output, prefix = "Per node normalized loss ")
            results['Node F1']= results1['F1']

        results['params'] = str(hprms)
        results['time'] = duration
        result_sets.append(results)

    return pd.DataFrame(result_sets)

def hyperparam_run(hyper_params_set ,n_train_epochs, trainset,config,trainloader,valloader,anomaly_filter=0.95,save=False):
    """Train one model per hyper-parameter dict and record what ``train`` returns.

    Args:
        hyper_params_set: iterable of hyper-parameter dicts (JSON-serialisable).
        n_train_epochs: epochs passed to ``train``.
        trainset: training dataset; its node/feature attributes size the model.
        config: experiment config (see ``make_model_config``).
        trainloader, valloader: data loaders passed to ``train``.
        anomaly_filter: forwarded to ``AutoregressiveModel``.
        save: if true, write ``./Finetune_{model_type}_{task}.csv`` after every
            run, so partial results survive an interrupted sweep.

    Returns:
        DataFrame with columns 'val_loss' (``train``'s return value), 'params'
        (JSON string) and 'time' (training seconds). It is empty when
        *hyper_params_set* is empty.
    """
    from contextlib import redirect_stdout

    result_sets = []
    resultdf = pd.DataFrame(result_sets)
    for hprms in hyper_params_set:
        print(hprms,config,file=sys.stderr)
        # Silence train()'s stdout chatter; sys.stdout is restored afterwards.
        with redirect_stdout(StringIO()):
            model = AutoregressiveModel(trainset.num_nodes, trainset.node_feature_dim, trainset.node_info,model_type= config.model_type, lr=1e-4, config = config,train_dataset = trainset.all_node_array, device = config.gpu, feature_list = trainset.feature_list, hyper_params=hprms,anomaly_filter=anomaly_filter)
            start = time.time()
            loss = train(model, trainloader, valloader=valloader, epochs=n_train_epochs, hook=None)
            duration = time.time() - start
        result_sets.append({'val_loss': loss, 'params': json.dumps(hprms), 'time': duration})

        resultdf = pd.DataFrame(result_sets)
        if save:
            resultdf.to_csv(f"./Finetune_{config.model_type}_{config.task}.csv")

    return resultdf


def experiment_measure_at_end(hyper_param_sets,n_train_epochs, trainset,config,trainloader,valloader,test_dataset):
    """Train and test one model per hyper-parameter set; collect final metrics.

    Evaluation:

    - sensor tasks ('swat', 'wadi') use ``test_with_normalized_loss(test_dataset, model)``;
    - other tasks use ``test(config, model, hyperparams=hyper_params)``, the
      same call ``K_experiment_measure_at_end`` makes.

    Metrics are parsed from the captured stdout with ``collect_result``.

    Returns:
        DataFrame with one row per hyper-parameter set. Columns are F1,
        val_loss, Precision and Recall (plus 'Node F1' for non-sensor tasks),
        and 'params'.
    """
    from contextlib import redirect_stdout

    result_sets = []
    for hyper_params in hyper_param_sets:
        output_buffer = StringIO()
        # Capture prints for parsing. sys.stdout is restored afterwards, even on error.
        with redirect_stdout(output_buffer):
            model = AutoregressiveModel(trainset.num_nodes, trainset.node_feature_dim, trainset.node_info,model_type= config.model_type, config = config,train_dataset = trainset.all_node_array, device = config.gpu, feature_list = trainset.feature_list, hyper_params=hyper_params)
            train(model, trainloader, valloader=valloader, epochs=n_train_epochs, hook=None)
            if config.task in ['wadi','swat']:
                test_with_normalized_loss(test_dataset,model)
            else:
                # Pass the trained model, as K_experiment_measure_at_end does;
                # test(config) alone never saw the model trained above.
                test(config, model, hyperparams=hyper_params)
        output = output_buffer.getvalue()

        if config.task in ['wadi','swat']:
            results = collect_result(output)
        else:
            results = collect_result(output, prefix = "Per graph max normalized loss ")
            results1 = collect_result(output, prefix = "Per node normalized loss ")
            results['Node F1']= results1['F1']

        results['params'] = str(hyper_params)
        result_sets.append(results)

    return pd.DataFrame(result_sets)


def valloss_f1_experiment(model,task,masks,epochs=200):
    """Track test-set F1 and validation loss while training a predictive model.

    Sensor-data loaders are built for *task* using the window, subsample and
    batch settings of a 'Reconstructing' config. A 'Predictive' model is then
    trained with a single fixed hyper-parameter set and ``train_hook`` scoring.
    The per-validation results are written to
    ``Pred_Val_F1_during_training_{model}_{task}.csv``.

    The matching reconstructing-model run (which wrote
    ``Reconst_Val_F1_during_training_{model}_{task}.csv``) is currently
    disabled.
    """
    data_config = make_model_config('Reconstructing', task, model, masks=masks, use_pretrained=False)
    train_data, val_data, test_data = make_sensor_datasets(
        name=data_config.task,
        concat_steps=data_config.winsize,
        subsample=data_config.subsample,
    )
    train_loader = GraphDataLoader(train_data, batch_size=data_config.bsz, shuffle=False)
    val_loader = GraphDataLoader(val_data, batch_size=data_config.bsz, shuffle=False)

    param_grid = [{"output_dim": 64, "n_layers": 1, 'emb_dim': 64}]
    learning_rate = 1e-4

    # Reconstructing run (disabled): call experiment_measure_during_training
    # with a 'Reconstructing' config and save it under the Reconst_ file name.

    pred_config = make_model_config('Predictive', task, model, masks=masks, use_pretrained=False)
    pred_df = experiment_measure_during_training(
        param_grid, epochs, train_data, pred_config, train_loader, val_loader,
        test_data, hook=train_hook, lr=learning_rate,
    )
    pred_df.to_csv(f"Pred_Val_F1_during_training_{model}_{task}.csv")


def plot_F1_and_vallosses_during_training_individual():
    """Save four separate during-training plots for the SWaT runs.

    Reads the reconstructing and predictive CSVs and saves one figure each
    for F1 and for ``prediction_val_loss`` per model family. Reconstructing
    is drawn in blue and predictive in green.
    """
    frames = {
        'reconstructing': pd.read_csv("Reconst_Val_F1_during_training.csv",index_col=0),
        'predictive': pd.read_csv("Pred_Val_F1_during_training.csv",index_col=0),
    }
    # (family, column, colour, title, output file)
    plot_specs = [
        ('reconstructing', 'F1', 'b',
         "F1 score on SWaT test set for reconstructing model during training.",
         "Reconstructing_model_f1_during_training.jpg"),
        ('reconstructing', 'prediction_val_loss', 'b',
         "Validation loss on SWaT for reconstructing model during training.",
         "Reconstructing_model_valloss_during_training.jpg"),
        ('predictive', 'F1', 'g',
         "F1 score on SWaT test set for predictive model during training.",
         "Predictive_model_f1_during_training.jpg"),
        ('predictive', 'prediction_val_loss', 'g',
         "Validation loss on SWaT for predictive model during training.",
         "Predictive_model_valloss_during_training.jpg"),
    ]
    for family, column, colour, title, filename in plot_specs:
        plt.figure(figsize=(10,7))
        plt.plot(frames[family][column], colour)
        plt.title(title)
        plt.savefig(filename, dpi=500, bbox_inches='tight')
    
def plot_F1_and_vallosses_during_training():
    """Save combined reconstructing-vs-predictive plots of F1 and validation loss.

    Reads ``Reconst_Val_F1_during_training.csv`` and
    ``Pred_Val_F1_during_training.csv``. It writes an F1 plot (solid lines)
    and a ``prediction_val_loss`` plot (dashed lines), with reconstructing in
    blue and predictive in green.
    """
    reconst_resultdf = pd.read_csv("Reconst_Val_F1_during_training.csv",index_col=0)
    pred_resultdf = pd.read_csv("Pred_Val_F1_during_training.csv",index_col=0)

    plt.figure(figsize=(10,7))
    plt.plot(reconst_resultdf['F1'],'b')
    plt.plot(pred_resultdf['F1'],'g')
    plt.xlabel("Iterations")
    plt.ylabel("F1 score")
    plt.legend(["F1 Reconstructing","Predictive"])
    plt.title("F1 score on SWaT test set during training.")
    plt.savefig("Reconstructing_predictive_model_f1_during_training.jpg",dpi=500, bbox_inches='tight')

    plt.figure(figsize=(10,7))
    # plot() accepts the format string positionally only; there is no fmt=
    # keyword, so colour and dash style are combined into one string.
    plt.plot(reconst_resultdf['prediction_val_loss'],'b--')
    plt.plot(pred_resultdf['prediction_val_loss'],'g--')
    plt.xlabel("Iterations")
    plt.ylabel("Validation Loss")
    plt.legend(["Reconstructing","Predictive"])
    plt.title("Validation loss on SWaT during training.")
    plt.savefig("Reconstructing_predictive_model_valloss_during_training.jpg",dpi=500, bbox_inches='tight')


def plot_F1_val_during_training_oneplot(model,task):
    """Overlay F1 and validation loss of both model families on twin y-axes.

    Reads the ``Reconst_`` and ``Pred_`` during-training CSVs for *model* and
    *task*. F1 is drawn solid on the left axis and ``prediction_val_loss``
    dashed on the right axis. The figure is saved to
    ``Reconstructing_predictive_model_valloss_during_training_merged_{model}_{task}.jpg``.
    Label size comes from the module-level ``fontsize``.
    """
    recon_df = pd.read_csv(f"Reconst_Val_F1_during_training_{model}_{task}.csv",index_col=0)
    pred_df = pd.read_csv(f"Pred_Val_F1_during_training_{model}_{task}.csv",index_col=0)
    text_style = {"fontname": "Arial", "fontsize": fontsize}

    fig, f1_ax = plt.subplots(figsize=(20,10))
    loss_ax = f1_ax.twinx()

    for frame, line_style in ((recon_df, 'b'), (pred_df, 'g')):
        f1_ax.plot(frame['F1'], line_style, linewidth=3)
    f1_ax.set_xlabel("Iterations", **text_style)
    f1_ax.set_ylabel("F1 score", **text_style)
    f1_ax.set_title(f"F1 score and validation loss on {task} test set during training.", **text_style)

    for frame, line_style in ((recon_df, 'b--'), (pred_df, 'g--')):
        loss_ax.plot(frame['prediction_val_loss'], line_style, linewidth=3)
    loss_ax.set_ylabel("Error", **text_style)

    # Narrow the axes to 80% width so the legends anchored at x=1.1 fit.
    box = f1_ax.get_position()
    f1_ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])

    f1_ax.legend(["Reconstruction F1","Masked Predictive F1"], loc='upper left', bbox_to_anchor=(1.1, 1))
    loss_ax.legend(["Reconstruction Val. Loss","Masked Predictive Val. Loss"], loc='upper left', bbox_to_anchor=(1.1, 0.8))

    fig.savefig(f"Reconstructing_predictive_model_valloss_during_training_merged_{model}_{task}.jpg",dpi=250,bbox_inches="tight",pad_inches=0.5)


def sparsity_experiment(masks, repeats, epochs=200, model='transformer', task='swat',sparsity=30,use_pretrained = False,lr=None):
    """Run the K sweep on a sensor task with a non-default subsample rate.

    One predictive config is built per entry of ``masks * repeats``, with
    ``subsample = sparsity``. Each is trained and tested via
    ``K_experiment_measure_at_end`` (hyper-parameters loaded from file).
    The results, with a ``K`` column, are written to
    ``K_experiment_prediction_{model}_{task}_{sparsity}.csv``.

    Raises:
        ValueError: for the game tasks ('gridworld', 'monopoly') or any task
            other than 'swat' / 'wadi'.
    """
    mask_schedule = masks * repeats
    configs = []
    for n_masks in mask_schedule:
        cfg = make_model_config('Predictive', model=model, task=task, masks=n_masks,
                                train_path="/home/yli52/dataset/monopoly/train_new",
                                val_path="/home/yli52/dataset/monopoly/val",
                                use_pretrained=use_pretrained)
        cfg.subsample = sparsity
        cfg.use_json_graph = False
        configs.append(cfg)

    # Every config shares `task`; the last one built is checked.
    if cfg.task in ['gridworld','monopoly']:
        raise ValueError("This experiment is not applicable to this task.")
    if cfg.task not in ['swat','wadi']:
        raise ValueError(f"{cfg.task} Not a valid task")

    train_data, val_data, test_data = make_sensor_datasets(name=cfg.task, concat_steps=cfg.winsize, subsample=cfg.subsample)
    train_loader = GraphDataLoader(train_data, batch_size=cfg.bsz, shuffle=False)
    val_loader = GraphDataLoader(val_data, batch_size=cfg.bsz, shuffle=False)

    results = K_experiment_measure_at_end(None, epochs, train_data, configs, train_loader, val_loader, test_dataset=test_data, lr=lr)
    results['K'] = mask_schedule
    results.to_csv(f"K_experiment_prediction_{model}_{task}_{sparsity}.csv")




def widths_experiment(model,task,masks,widths, repeats, epochs=200):
    '''Compare reconstructing and predictive models across embedding widths.

    Both model families are trained for every width in ``widths * repeats``,
    using the same sensor-data loaders. Results go to
    ``Widths_experiment_reconstruction_{model}_{task}.csv`` and
    ``Widths_experiment_prediction_{model}_{task}.csv``, each with a
    ``widths`` column.

    Purpose: show that predictive models are more reliable because they never
    see the ground truth in their input. Hypothesis: reconstruction
    performance suffers as the information bottleneck shrinks. That effect is
    hard to control and does not show up in the validation loss.

    Raises:
        ValueError: if *model* is not 'mlp' or 'transformer'. This is checked
            before any dataset is loaded.
    '''
    repeated_widths = widths*repeats

    if model == 'mlp':
        hyper_param_sets = [{"out_dim": o, "n_layers": 1, 'h_dim': o, 'dropout_rate': 0.2} for o in repeated_widths]
    elif model == 'transformer':
        hyper_param_sets = [{"output_dim": o, "n_layers": 2, 'emb_dim': o} for o in repeated_widths]
    else:
        # Previously an unknown model left hyper_param_sets unbound and failed
        # only after the datasets had been loaded.
        raise ValueError(f"widths_experiment supports 'mlp' and 'transformer', got {model!r}")

    config = make_model_config('Reconstructing',task,model,masks=masks,use_pretrained=False)
    trainset,val,test_dataset = make_sensor_datasets(name = config.task, concat_steps=config.winsize,subsample=config.subsample)
    trainloader = GraphDataLoader(trainset, batch_size=config.bsz, shuffle=False)
    valloader=GraphDataLoader(val, batch_size=config.bsz, shuffle=False)

    ####--- Reconstruction ---####
    reconst_resultsdf = experiment_measure_at_end(hyper_param_sets,epochs,trainset,config,trainloader,valloader,test_dataset)
    reconst_resultsdf['widths']=repeated_widths
    reconst_resultsdf.to_csv(f"Widths_experiment_reconstruction_{model}_{task}.csv")

    ####--- Predictive ---####
    config = make_model_config('Predictive',task,model,masks=masks,use_pretrained=False)
    pred_resultsdf = experiment_measure_at_end(hyper_param_sets,epochs,trainset,config,trainloader,valloader,test_dataset)
    pred_resultsdf['widths']=repeated_widths
    pred_resultsdf.to_csv(f"Widths_experiment_prediction_{model}_{task}.csv")

    return

def K_experiment(masks, repeats, epochs=200, model='transformer', task='swat',use_pretrained = False,save=True,lr=None,hook=None):
    '''Measure predictive and anomaly-detection performance for varying K.

    K is the number of masks given to the predictive model. Every value in
    *masks* is run *repeats* times. Each run trains a fresh model through
    ``K_experiment_measure_at_end``, which loads the saved hyper-parameters
    per config. When *save* is true, the metrics plus a ``K`` column are
    written to ``K_experiment_prediction_{model}_{t}.csv``, where *t* is
    ``'monopoly_old'`` for monopoly and *task* otherwise.

    Args:
        masks: list of K values.
        repeats: runs per K value.
        epochs: training epochs per run.
        model: model name; must be 'mlp', 'transformer' or 'GDN'.
        task: 'gridworld' / 'monopoly' (game data) or 'swat' / 'wadi' (sensor data).
        use_pretrained: stored in each config.
        save: write the CSV when true.
        lr: learning rate forwarded to the model.
        hook: training hook forwarded to ``train``.

    Raises:
        KeyError: if *model* is not one of the known models.
        ValueError: if ``masks * repeats`` is empty or *task* is unknown.
    '''
    repeated_masks = masks*repeats

    # Reference hyper-parameters per model. They are not used for training:
    # hyper_params stays None, so K_experiment_measure_at_end loads the saved
    # config for each (model, K, task). The table still rejects unknown models early.
    params = {
        'mlp':{'out_dim': 128, 'n_layers': 3, 'h_dim': 32, 'dropout_rate': 0.2},
        'transformer':{"output_dim":64,"n_layers":3,"emb_dim":64, "dropout_rate":0.2},
        'GDN':{"output_dim":32,"dropout_rate":0.2,"topk":5}
    }
    if model not in params:
        raise KeyError(model)
    hyper_params = None

    if not repeated_masks:
        # Without this, `config` below would be unbound (UnboundLocalError).
        raise ValueError("K_experiment needs a non-empty masks list and repeats >= 1")

    all_configs = []
    for m in repeated_masks:
        # Game-data paths; the swat/wadi branch below does not read them.
        config = make_model_config('Predictive',model=model, task=task,masks=m,train_path="/home/yli52/dataset/monopoly/train_new", val_path="/home/yli52/dataset/monopoly/val",use_pretrained=use_pretrained)
        #config = make_model_config('Predictive',model=model, task=task,masks=m,train_path="/home/plymper/data/polycraftv2/normal/jsons", val_path="/home/plymper/data/polycraftv2/normal/val")
        #config = make_model_config('Predictive',model=model, task=task,masks=m,train_path="/home/plymper/data/gridworldsData/train_data", val_path="/home/plymper/data/gridworldsData/novel_nonov/")
        config.use_json_graph=False
        all_configs.append(config)

    ####--- Predictive ---####
    if config.task in ['gridworld','monopoly']:
        dataset = GameDataset(data_path=config.train_dataset_dir, concat_steps=config.winsize, mode='training',ignore_intermediate_nodes= not config.use_json_graph, task = config.task)
        trainloader = GraphDataLoader(dataset, batch_size=config.bsz, shuffle=False)
        val_dataset = GameDataset(data_path=config.val_dataset_dir, concat_steps=config.winsize, mode='validation',ignore_intermediate_nodes=not config.use_json_graph, task = config.task)
        valloader=GraphDataLoader(val_dataset, batch_size=config.bsz, shuffle=False)
        test_dataset=None
    elif config.task in ['swat','wadi']:
        dataset,val,test_dataset = make_sensor_datasets(name = config.task, concat_steps=config.winsize,subsample=config.subsample)
        trainloader = GraphDataLoader(dataset, batch_size=config.bsz, shuffle=False)
        valloader=GraphDataLoader(val, batch_size=config.bsz, shuffle=False)
    else:
        raise ValueError(f"{config.task} Not a valid task")

    pred_resultsdf = K_experiment_measure_at_end(hyper_params,epochs,dataset,all_configs,trainloader,valloader,test_dataset=test_dataset,lr=lr,hook=hook)
    pred_resultsdf['K']=repeated_masks

    t = 'monopoly_old' if task == 'monopoly' else task
    if save:
        pred_resultsdf.to_csv(f"K_experiment_prediction_{model}_{t}.csv")

    return

def anomaly_filter_experiment(masks,filters, repeats, epochs=200, model='transformer', task='swat',use_pretrained = False):
    '''Measure predictive and anomaly-detection performance for varying anomaly_filter.

    A predictive model with a fixed mask count *masks* is trained once per
    entry of ``filters * repeats``, with ``config.anomaly_filter`` set to that
    entry. Runs go through ``K_experiment_measure_at_end``, which loads the
    hyper-parameters from file. Results, with an ``anomaly_filters`` column,
    are written to ``anoomaly_filter_experiment_prediction_{model}_{t}.csv``,
    where *t* is ``'monopoly_old'`` for monopoly and *task* otherwise.

    Raises:
        ValueError: if ``filters * repeats`` is empty or *task* is unknown.
    '''
    repeated_filters = filters*repeats
    hyper_params=None

    if not repeated_filters:
        # Without this, `config` below would be unbound (UnboundLocalError).
        raise ValueError("anomaly_filter_experiment needs a non-empty filters list and repeats >= 1")

    all_configs = []
    for f in repeated_filters:
        # Game-data paths; the swat/wadi branch below does not read them.
        config = make_model_config('Predictive',model=model, task=task,masks=masks,train_path="/home/yli52/dataset/monopoly/train_new", val_path="/home/yli52/dataset/monopoly/val",use_pretrained=use_pretrained,anomaly_filter=f)
        config.use_json_graph=False
        all_configs.append(config)

    ####--- Predictive ---####
    if config.task in ['gridworld','monopoly']:
        dataset = GameDataset(data_path=config.train_dataset_dir, concat_steps=config.winsize, mode='training',ignore_intermediate_nodes= not config.use_json_graph, task = config.task)
        trainloader = GraphDataLoader(dataset, batch_size=config.bsz, shuffle=False)
        val_dataset = GameDataset(data_path=config.val_dataset_dir, concat_steps=config.winsize, mode='validation',ignore_intermediate_nodes=not config.use_json_graph, task = config.task)
        valloader=GraphDataLoader(val_dataset, batch_size=config.bsz, shuffle=False)
        test_dataset=None
    elif config.task in ['swat','wadi']:
        dataset,val,test_dataset = make_sensor_datasets(name = config.task, concat_steps=config.winsize,subsample=config.subsample)
        trainloader = GraphDataLoader(dataset, batch_size=config.bsz, shuffle=False)
        valloader=GraphDataLoader(val, batch_size=config.bsz, shuffle=False)
    else:
        raise ValueError(f"{config.task} Not a valid task")

    pred_resultsdf = K_experiment_measure_at_end(hyper_params,epochs,dataset,all_configs,trainloader,valloader,test_dataset=test_dataset)
    pred_resultsdf['anomaly_filters']=repeated_filters

    t = 'monopoly_old' if task == 'monopoly' else task
    # NOTE(review): "anoomaly" typo kept so existing result files keep their names.
    pred_resultsdf.to_csv(f"anoomaly_filter_experiment_prediction_{model}_{t}.csv")

    return




def finetune(task,model,epochs=200,config_path=None):
    """Run a hyper-parameter sweep for *model* on a sensor task.

    Args:
        task: 'swat' or 'wadi'.
        model: model name; selects ``{model}_0.json`` when *config_path* is None.
        epochs: training epochs per run.
        config_path: JSON file containing a list of hyper-parameter dicts.
            Defaults to the original finetuning config location.

    Returns:
        The DataFrame from ``hyperparam_run`` (val_loss, params, time per run).

    Raises:
        ValueError: if *task* is not 'swat' or 'wadi'. This is checked before
            any file is read.
    """
    if task not in ['swat','wadi']:
        raise ValueError(f"{task} Not a valid task")

    if config_path is None:
        config_path = f"/home/plymper/graph-anomaly-detection-clean/models/model_configs/finetuning/{model}_0.json"
    with open(config_path) as f:
        hyper_params_set = json.load(f)

    config = make_model_config('Predictive',model=model, task=task,masks=0,train_path="./", val_path="./",use_pretrained=False)
    config.use_json_graph=False

    trainset,val,test_dataset = make_sensor_datasets(name = config.task, concat_steps=config.winsize,subsample=config.subsample)
    trainloader = GraphDataLoader(trainset, batch_size=config.bsz, shuffle=False)
    valloader=GraphDataLoader(val, batch_size=config.bsz, shuffle=False)

    # Return the sweep results instead of discarding them (hyperparam_run
    # does not save unless save=True).
    return hyperparam_run(hyper_params_set ,epochs, trainset,config,trainloader,valloader)



def K_plot(masks,model,repeats,task):
    """Plot F1 and validation loss against K from a saved K_experiment CSV.

    Reads ``K_experiment_prediction_{model}_{t}.csv``, where *t* is
    ``'monopoly_old'`` for monopoly and *task* otherwise. It then:

    - averages every numeric column per K;
    - prints a paired t-test of F1 between K = 0 and K = ``masks[-1]``;
    - saves ``F1_loss_masking_predictive.jpg``.

    The plot shows node-level F1 on the left axis, or step-level F1 when the
    CSV has no ``Node F1`` column, and validation loss on the right axis. All
    error bars are standard errors over the *repeats* runs per K.

    Args:
        masks: the K values, ascending (matching the sorted groupby index);
            must contain 0 for the t-test.
        model, task: select the CSV and label the title.
        repeats: runs per K, used for the standard error.
    """
    from scipy.stats import ttest_rel

    # Experiment helpers may have left sys.stdout pointing at a capture buffer.
    sys.stdout = sys.__stdout__

    t = 'monopoly_old' if task == 'monopoly' else task
    resultsdf = pd.read_csv(f"K_experiment_prediction_{model}_{t}.csv",index_col=0)

    # Aggregate numeric columns only. The 'params' column holds strings, and
    # GroupBy.mean()/std() raise on non-numeric data in pandas >= 2.0.
    grouped = resultsdf.select_dtypes(include='number').groupby('K')
    pred_avgd = grouped.mean()
    pred_std = grouped.std()
    sem_scale = np.sqrt(repeats)

    print("Pval:",ttest_rel(grouped.get_group(0)["F1"], grouped.get_group(masks[-1])["F1"]))

    fig,ax = plt.subplots(figsize=(20,10))

    y = pred_avgd['F1'].values
    err = pred_std['F1'].values/sem_scale
    print("Step")
    print(y)
    print(err)
    try:
        y = pred_avgd['Node F1'].values
        err = pred_std['Node F1'].values/sem_scale
        print("Node")
        print(y)
        print(err)
    except KeyError as e:
        print("No node scores")
        print(e)

    # Node F1 when available, otherwise the step F1 computed above.
    ax.errorbar(masks,y,yerr=err, marker='o', color='b', linewidth=3)

    loss_y = pred_avgd['val_loss'].values
    loss_err = pred_std['val_loss'].values/sem_scale
    ax2 = ax.twinx()
    ax2.errorbar(masks,loss_y,yerr=loss_err,marker='o', color='g', linewidth=3)
    ax.set_xticks(masks)
    ax.set_xlabel("K")
    ax.set_ylabel("F1 score")
    ax2.set_ylabel("Prediction Error")

    ax.set_title(f"{task} F1 scores and Validation Loss of {model} with varying K")

    ax.legend(["Feature F1 Score"], loc = 'upper left',bbox_to_anchor=(0, 0.88))
    ax2.legend(["Prediction Error"], loc = 'upper left')
    fig.savefig("F1_loss_masking_predictive.jpg", dpi = 500, bbox_inches='tight')
    plt.close(fig)

    print(pred_avgd['time'],pred_std['time'])

    return

def width_plot(model,task,widths):
    """Plot mean F1 against embedding width for both model families.

    Reads ``Widths_experiment_reconstruction_{model}_{task}.csv`` and
    ``Widths_experiment_prediction_{model}_{task}.csv``. It saves
    ``F1_width_reconst_predictive_{model}_{task}.jpg``, with reconstruction in
    blue and masked predictive in green. Error bars are standard errors over
    the runs recorded for each width.

    Args:
        model, task: select the input/output files.
        widths: kept for interface compatibility; the x values come from the
            CSV's ``widths`` column.
    """
    sys.stdout = sys.__stdout__

    reconst_resultsdf = pd.read_csv(f"Widths_experiment_reconstruction_{model}_{task}.csv",index_col=0)
    pred_resultsdf = pd.read_csv(f"Widths_experiment_prediction_{model}_{task}.csv",index_col=0)

    def _mean_and_sem(df):
        # Select F1 before aggregating so string columns such as 'params'
        # never reach mean()/std() (a TypeError in pandas >= 2.0). The SEM
        # divides by the per-width run count, not by the number of widths.
        f1 = df.groupby('widths')['F1']
        return f1.mean(), f1.std() / np.sqrt(f1.count())

    reconst_mean, reconst_sem = _mean_and_sem(reconst_resultsdf)
    pred_mean, pred_sem = _mean_and_sem(pred_resultsdf)
    x_axis = pred_mean.index.values

    fig,ax = plt.subplots(figsize=(20,10))
    ax.errorbar(x_axis,reconst_mean.values,yerr=reconst_sem.values, marker='o', color='b', linewidth=3)
    ax.errorbar(x_axis,pred_mean.values,yerr=pred_sem.values,marker='o', color='g', linewidth=3)

    # Narrow the axes to 80% of their width (as before) to leave side margin.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.set_xlabel("Embedding Dimension")
    ax.set_ylabel("F1 score")

    ax.legend(["Reconstruction Model","Masked Predictive Model"],loc='lower left')
    ax.set_title("F1 scores of Models with varying embedding dimension")

    fig.savefig(f"F1_width_reconst_predictive_{model}_{task}.jpg", dpi = 500, bbox_inches='tight')
    plt.close(fig)

    return

def main():
    """Run the main K experiment for every model on swat and wadi.

    Seeds torch and numpy with 0 and sets the matplotlib font (Arial, 35).
    For each model in GDN, transformer and mlp, it runs ``K_experiment`` on
    swat and wadi with the task's K values, 10 repeats, 500 epochs, pretrained
    weights and lr 1e-4. A Slack message is sent on completion, or on failure
    before the exception is re-raised.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    from slack_message import send_message

    matplotlib.rc('font', family='Arial', size=35)

    masks_by_task = {
        "monopoly": [0, 19],
        "swat": [0, 50],
        "wadi": [0, 127],
    }
    n_repeats = 10

    try:
        for model_name in ('GDN', 'transformer', 'mlp'):
            for task_name in ('swat', 'wadi'):
                K_experiment(masks_by_task[task_name], n_repeats, model=model_name, epochs=500,
                             task=task_name, use_pretrained=True, lr=0.0001)
    except Exception:
        send_message("Err K")
        raise

    send_message("Done K")

def epoch_time_measure():
    """Run short (30-epoch) K experiments for timing, without writing CSVs.

    Each (task, model) pair is run once with lr 1e-3. The training hook
    always returns ``"save_times"``; presumably ``train`` interprets this as
    "record epoch times". Confirm in train.py.
    """
    n_epochs = 30
    n_repeats = 1  # times are averaged over epochs, so one run suffices
    masks_by_task = {
        "gridworld": [0, 26],
        "monopoly": [0, 9],
        "swat": [0, 50],
        "wadi": [0, 127],
    }

    for task_name in ('swat', 'wadi'):
        for model_name in ('GDN', 'transformer', 'mlp'):
            K_experiment(masks_by_task[task_name], n_repeats, model=model_name, epochs=n_epochs,
                         task=task_name, use_pretrained=True, save=False, lr=0.001,
                         hook=lambda x: "save_times")

    return


def run_sparse_experiment():
    """Repeat the K sweep at several subsample rates and report via Slack.

    Seeds torch and numpy with 0 and sets the matplotlib font (Arial, 35).
    For sparsity 10, 20 and 25 (15 is the default subsample in
    ``make_model_config``, i.e. the main experiment), each sensor task and
    each model, it runs ``sparsity_experiment`` with 10 repeats, 500 epochs
    and lr 1e-3. A Slack message is sent on completion, or on failure before
    the exception is re-raised.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    from slack_message import send_message

    matplotlib.rc('font', family='Arial', size=35)

    task_masks = {"monopoly":[0,19],
                  "swat":[0,50],
                  "wadi":[0,127]
                }

    repeats=10
    try:
        for sparsity in [10,20,25]:
            for task in ['swat','wadi']:
                for model in ['transformer','mlp','GDN']:
                    masks = task_masks[task]
                    sparsity_experiment(masks,repeats,model=model,epochs=500, task=task,sparsity=sparsity,use_pretrained=True,lr=0.001)
    except Exception:
        # Label the failure with this sweep's name (was a copy-pasted "Err K").
        send_message("Err sparsity")
        raise

    send_message("Done sparsity")

def pred_reconst_experiment():
    """Run the transformer/swat during-training and width experiments, then plot them.

    NOTE(review): ``plot_F1_val_during_training_oneplot`` also reads a
    ``Reconst_`` CSV that ``valloss_f1_experiment`` currently does not write.
    That file must already exist.
    """
    model = 'transformer'
    task = 'swat'

    valloss_f1_experiment(model, task, 50, 500)
    plot_F1_val_during_training_oneplot(model, task)

    masks_by_task = {"monopoly": [0, 19], "swat": [0, 50], "wadi": [0, 127]}
    widths = [8, 16, 32, 64, 128]
    n_repeats = 10
    largest_k = masks_by_task[task][-1]
    widths_experiment(model, task, largest_k, widths, n_repeats, 500)
    width_plot(model, task, widths)

# Label/title size used by plot_F1_val_during_training_oneplot, which reads
# this module-level name when it runs.
fontsize = 40

if __name__ == '__main__':
    matplotlib.rc('font', family='Arial', size=35)
    task_masks = {
        "monopoly": [0, 19],
        "swat": [0, 50],
        "wadi": [0, 127],
    }

    main()  # main experiment with this dataset
    # Other entry points:
    # epoch_time_measure()
    # run_sparse_experiment()    # time-subsample experiment
    # pred_reconst_experiment()

    

   