import argparse, os, sys

from sklearn import datasets

sys.path.append('../')
from dataset.sensordata import make_sensor_datasets 
from dataset.gamedata import GameDataset
from torch_geometric.loader import DataLoader as GraphDataLoader
import torch
import numpy as np
from sklearn.metrics import roc_auc_score
import numpy as np
from Autoregressive_model import AutoregressiveModel
from train import train
from scipy.stats import iqr
from test_sensors import test_with_normalized_loss
from test_games import test
import types
from io import StringIO
import re
import pandas as pd
import scipy.stats as stats
import tqdm
import matplotlib.pyplot as plt
import ast
import matplotlib
from dataset.json_graph import JsonToGraph
import time    
import test_json
from dataset import gridworld_jsondata
from dataset import monopoly_jsondata
from pathlib import Path
import json

from slack_message import send_message

# Human-readable labels used in plot titles and legends.
model_name = dict(mlp="MLP", GDN="GDN", transformer="OAT")
task_name = dict(gridworld="Polycraftv2", swat="SWAT", wadi="WADI", monopoly="Monopoly")

def collect_result(output, prefix=''):
    """Parse an F1 score and the best validation loss out of captured stdout.

    Args:
        output: text captured from stdout while training/testing.
        prefix: literal text that must immediately precede ``F1`` (for example
            ``"Localization "``). It is matched literally, not as a regex.

    Returns:
        ``{"F1": float, "val_loss": float}``. Each value is the last
        space-separated token of the first matching line, with any trailing
        ``}`` removed.

    Raises:
        ValueError: if no ``{prefix}F1`` line or no ``best`` line is present.

    Note:
        With an empty prefix the first ``F1`` anywhere in the text is used.
        This can be a prefixed score (e.g. ``Localization F1``) if that line
        is printed first.
    """
    def _last_token(pattern, what):
        match = re.search(pattern, output)
        if match is None:
            raise ValueError(f"could not find {what} in captured output")
        # The number is the last token on the matched line; printed dicts
        # leave a closing brace attached to it.
        return float(match.group(0).split(' ')[-1].strip("}"))

    f1 = _last_token(re.escape(prefix) + "F1.*", f"'{prefix}F1'")
    valloss = _last_token("best.*", "'best' validation loss")
    return {"F1": f1, 'val_loss': valloss}

def collect_result_byline(output):
    """Return the payload of every hook-emitted "validation scores :" line.

    A line qualifies when it contains both ``"validation scores :"`` and
    ``"hook"``. The text after the first occurrence of the marker is returned,
    stripped of surrounding whitespace, so the marker need not start the line.

    Args:
        output: captured stdout text.

    Returns:
        List of payload strings in the order they appear.
    """
    marker = "validation scores :"
    scores = []
    for line in output.split("\n"):
        if marker in line and 'hook' in line:
            scores.append(line.partition(marker)[2].strip())
    return scores


def make_model_config(model_type = "Reconstructing",task='swat', model='transformer', train_path = "./", val_path = "./", masks=50, use_pretrained = False,anomaly_filter=0.95,lr=1e-3, use_val_thresh=False):
    """Build the experiment configuration namespace used by the training code.

    Args:
        model_type: "Reconstructing" (``n_masks`` forced to 0 and
            ``reconstructing=True``) or "Predictive" (``n_masks=masks`` and
            ``reconstructing=False``).
        task: task name stored as ``config.task``.
        model: architecture key stored as ``config.model_type``
            (e.g. "mlp", "GDN", "transformer").
        train_path, val_path: dataset directories stored on the config.
        masks: number of masks for Predictive models; ignored for Reconstructing.
        use_pretrained, anomaly_filter, lr, use_val_thresh: copied unchanged.

    Returns:
        ``types.SimpleNamespace``; all other fields are fixed defaults.

    Raises:
        ValueError: if ``model_type`` is not "Reconstructing" or "Predictive".
    """
    if model_type == 'Reconstructing':
        masks = 0
        recon = True
    elif model_type == "Predictive":
        recon = False
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError(
            f"unknown model_type {model_type!r}; expected 'Reconstructing' or 'Predictive'"
        )

    config = types.SimpleNamespace()
    config.emb_dim = 32
    config.train_dataset_dir = train_path
    config.test_dataset_dir = "./"  # ignored
    config.val_dataset_dir = val_path
    config.model_type = model

    config.bsz = 256
    config.model_save_dir = './saved_models'
    config.task = task
    config.n_masks = masks  # 0 for reconstruction
    config.reconstructing = recon  # makes the model a reconstruction type model
    config.subsample = 20
    config.gpu = 0
    config.validation_step = 1
    config.winsize = 10
    config.model_config_root = "../"
    config.use_pretrained = use_pretrained
    config.anomaly_filter = anomaly_filter
    config.lr = lr
    config.use_val_thresh = use_val_thresh
    config.patience_limit = 5

    return config

def train_hook(model,config):
    """Return a training callback that evaluates a model with ``test_json.test``.

    The callback puts ``current_model.model`` into eval mode, then runs
    ``test_json.test(current_model, task=config.task)``. It then switches
    ``current_model.model`` back to training mode and returns the test results.
    The ``model`` argument passed here is not used; the callback evaluates
    whatever model it is called with.
    """
    def evaluate(current_model):
        current_model.model.eval()
        scores = test_json.test(current_model, task=config.task)
        current_model.model.train()
        return scores

    return evaluate


def experiment_measure_during_training(hyper_params,n_train_epochs, trainset,config,trainloader,valloader,test_dataset, hook=None):
    """Train one model and collect the validation scores the hook prints during training.

    stdout is captured while the model trains and is tested, so that:
      - lines containing "validation scores :" and "hook" can be parsed with
        ``collect_result_byline``; and
      - the F1 lines printed by ``test_json.test`` can be parsed with
        ``collect_result``.
    stdout is restored before returning, even if training or testing raises.

    Args:
        hyper_params: model hyper-parameters, or None to load them with
            ``load_model_configs``.
        n_train_epochs: number of training epochs.
        trainset: training dataset; supplies node counts/features to the model.
        config: configuration from ``make_model_config``.
        trainloader, valloader: loaders passed to ``train``.
        test_dataset: unused; kept for interface compatibility.
        hook: optional factory ``hook(model, config)`` returning the callback
            handed to ``train`` (e.g. ``train_hook``); None trains without one.

    Returns:
        DataFrame with one row per hook-reported validation score dict.
    """
    if hyper_params is None:
        hyper_params = load_model_configs(config.model_type, config)

    original_stdout = sys.stdout
    sys.stdout = output_buffer = StringIO()
    try:
        # NOTE(review): lr is fixed to 1e-4 here instead of using config.lr.
        model = AutoregressiveModel(trainset.num_nodes, trainset.node_feature_dim, trainset.node_info, model_type=config.model_type, config=config, train_dataset=trainset.all_node_array, device=config.gpu, feature_list=trainset.feature_list, hyper_params=hyper_params, lr=0.0001)
        train_callback = hook(model, config) if hook is not None else None
        train(model, trainloader, valloader=valloader, epochs=n_train_epochs, hook=train_callback)

        model.model.eval()
        # Called for its printed output, which is parsed below.
        test_json.test(model, task=config.task)
        output = output_buffer.getvalue()
    finally:
        sys.stdout = original_stdout

    # Keep the localization and step-level scores under distinct keys
    # (the previous dict.update overwrote the localization F1).
    localization = collect_result(output, prefix="Localization ")
    overall = collect_result(output, prefix="")
    print({"Localization F1": localization["F1"], "F1": overall["F1"], "val_loss": overall["val_loss"]}, file=sys.stderr)

    # literal_eval only accepts Python literals, so parsing captured text is safe.
    scores = [ast.literal_eval(line) for line in collect_result_byline(output)]
    resultdf = pd.DataFrame(scores)

    print(resultdf, file=sys.stderr)
    return resultdf

def load_model_configs(model_type,config):
    """Load the stored JSON hyper-parameters for ``model_type``.

    Reads ``../models/model_configs/{model_type}_{config.n_masks}_{config.task}.json``
    relative to the current working directory and returns the decoded JSON.
    """
    config_path = Path("../models/model_configs") / f"{model_type}_{config.n_masks}_{config.task}.json"
    with config_path.open('r') as handle:
        return json.load(handle)

def K_experiment_measure_at_end(hyper_params,n_train_epochs, trainset,config_sets,trainloader,valloader,test_dataset,lr=None,hook=None):
    """Train one model per config and record its test results after training.

    Each config gets a freshly built ``AutoregressiveModel``, which is trained
    and then tested:
      - for 'wadi'/'swat', via ``test_with_normalized_loss``, with scores parsed
        from captured stdout by ``collect_result``;
      - otherwise via ``test_json.test``.
    stdout is captured while each model trains and is tested, and restored
    afterwards even if a run fails.

    Args:
        hyper_params: hyper-parameters shared by every run, or None to load
            them per config with ``load_model_configs``.
        n_train_epochs: number of training epochs.
        trainset: training dataset; supplies node counts/features to the model.
        config_sets: iterable of configs from ``make_model_config``.
        trainloader, valloader: loaders passed to ``train``.
        test_dataset: passed to ``test_with_normalized_loss`` for wadi/swat.
        lr: learning rate passed to the model.
        hook: passed unchanged to ``train``.

    Returns:
        DataFrame with one row per config: test metrics plus ``val_loss``
        (as returned by ``train``), ``params`` and ``time`` (training seconds).
    """
    result_sets = []
    for config in config_sets:
        if hyper_params is None:
            hprms = load_model_configs(config.model_type, config)
            print(hprms, config, file=sys.stderr)
        else:
            # Previously ``hprms`` was left undefined here (NameError).
            hprms = hyper_params

        original_stdout = sys.stdout
        sys.stdout = output_buffer = StringIO()
        try:
            model = AutoregressiveModel(trainset.num_nodes, trainset.node_feature_dim, trainset.node_info, model_type=config.model_type, config=config, train_dataset=trainset.all_node_array, device=config.gpu, feature_list=trainset.feature_list, hyper_params=hprms, lr=lr)
            start = time.time()
            val_loss = train(model, trainloader, valloader=valloader, epochs=n_train_epochs, hook=hook)
            duration = time.time() - start

            model.model.testing = True
            if config.task in ['wadi', 'swat']:
                test_with_normalized_loss(test_dataset, model)
                results = collect_result(output_buffer.getvalue())
            else:
                model.model.eval()
                results = test_json.test(model, task=config.task)
        finally:
            sys.stdout = original_stdout

        results['val_loss'] = val_loss
        # NOTE(review): records the argument, i.e. "None" when hyper-parameters
        # were loaded from file; kept as-is for compatibility with existing CSVs.
        results['params'] = str(hyper_params)
        results['time'] = duration
        result_sets.append(results)

    return pd.DataFrame(result_sets)

def hyperparams_run(hyper_params_set,n_train_epochs, trainset,config,trainloader,valloader,save=False,test=False):
    """Train one model per hyper-parameter set and collect the results.

    stdout is captured during each run (to keep training/test chatter off the
    console) and restored afterwards, even if a run fails.

    Args:
        hyper_params_set: iterable of hyper-parameter dicts (must be JSON-serialisable).
        n_train_epochs: number of training epochs.
        trainset: training dataset; supplies node counts/features to the model.
        config: configuration from ``make_model_config``; ``config.lr`` is used.
        trainloader, valloader: loaders passed to ``train``.
        save: if True, rewrite ``Finetune_{config.model_type}_{config.task}.csv``
            after every run so partial results survive a crash.
        test: if True, also run ``test_json.test`` and merge its metrics.

    Returns:
        DataFrame with one row per hyper-parameter set (empty if none given).
    """
    result_sets = []
    for hyper_params in hyper_params_set:
        original_stdout = sys.stdout
        sys.stdout = StringIO()
        try:
            print(hyper_params, file=sys.stderr)
            model = AutoregressiveModel(trainset.num_nodes, trainset.node_feature_dim, trainset.node_info, model_type=config.model_type, config=config, train_dataset=trainset.all_node_array, device=config.gpu, feature_list=trainset.feature_list, hyper_params=hyper_params, lr=config.lr)
            start = time.time()
            val_loss = train(model, trainloader, valloader=valloader, epochs=n_train_epochs, hook=None)
            duration = time.time() - start

            model.model.testing = True
            results = {'val_loss': val_loss, 'params': json.dumps(hyper_params), 'time': duration}
            if test:
                model.model.eval()
                results.update(test_json.test(model, task=config.task))
        finally:
            sys.stdout = original_stdout

        result_sets.append(results)
        if save:
            pd.DataFrame(result_sets).to_csv(f"Finetune_{config.model_type}_{config.task}.csv")

    # Built after the loop so an empty sweep returns an empty frame
    # (previously an UnboundLocalError).
    return pd.DataFrame(result_sets)



def valloss_f1_experiment(model,masks,task,epochs=200):
    """Record hook-reported validation scores during training, for both model families.

    A Reconstructing model is trained first, then a Predictive one, each on the
    JSON game dataset for ``task`` with fixed hyper-parameters. Results are
    saved to ``Reconst_Val_F1_during_training_{model}_{task}.csv`` and
    ``Pred_Val_F1_during_training_{model}_{task}.csv``.

    Args:
        model: architecture key (e.g. "mlp", "GDN", "transformer").
        masks: mask count for the Predictive run (Reconstructing forces 0).
        task: 'gridworld' or 'monopoly'.
        epochs: training epochs per model.

    Raises:
        ValueError: if ``task`` is not 'gridworld' or 'monopoly'.
    """
    if task == 'gridworld':
        dataset_root = Path("../dataset/polycraft")
        game_dataset_cls = gridworld_jsondata.GameDataset
    elif task == 'monopoly':
        dataset_root = Path("../dataset/monopoly")
        game_dataset_cls = monopoly_jsondata.GameDataset
    else:
        raise ValueError(f"unsupported task {task!r}; expected 'gridworld' or 'monopoly'")

    hyper_params = {"output_dim": 256, "emb_dim": 256, "n_layers": 2, "dropout_rate": 0.2}

    for model_family, csv_prefix in (("Reconstructing", "Reconst"), ("Predictive", "Pred")):
        config = make_model_config(model_family, model=model, task=task, masks=masks, train_path="./", val_path="./", use_pretrained=False)

        # Fresh datasets per run, as before, in case training mutates them.
        train_set = game_dataset_cls(data_path=dataset_root / "json_normal_train_data.pkl", concat_steps=config.winsize, mode="training")
        train_loader = GraphDataLoader(train_set, batch_size=config.bsz)
        valid_set = game_dataset_cls(data_path=dataset_root / "json_normal_valid_data.pkl", concat_steps=config.winsize, mode="validation")
        valid_loader = GraphDataLoader(valid_set, batch_size=50)

        resultdf = experiment_measure_during_training(hyper_params, epochs, train_set, config, train_loader, valid_loader, None, hook=train_hook)
        print(resultdf, file=sys.stderr)
        resultdf.to_csv(f"{csv_prefix}_Val_F1_during_training_{model}_{task}.csv")


def plot_F1_and_vallosses_during_training_individual(model=None, task=None):
    """Save four figures: F1 and validation loss over training for each model family.

    Reads ``Reconst_Val_F1_during_training{suffix}.csv`` and
    ``Pred_Val_F1_during_training{suffix}.csv`` (columns ``F1`` and
    ``prediction_val_loss``). ``suffix`` is ``_{model}_{task}`` when both are
    given, matching the names ``valloss_f1_experiment`` writes. Otherwise it is
    empty, which reproduces the original file names. The same suffix is added
    to the output image names.

    Args:
        model: model key used in the file names, or None.
        task: task key used in the file names and titles, or None
            (titles then say "SWaT").
    """
    suffix = f"_{model}_{task}" if model is not None and task is not None else ""
    dataset_label = task_name.get(task, "SWaT")

    reconst_resultdf = pd.read_csv(f"Reconst_Val_F1_during_training{suffix}.csv", index_col=0)
    pred_resultdf = pd.read_csv(f"Pred_Val_F1_during_training{suffix}.csv", index_col=0)

    panels = [
        (reconst_resultdf['F1'], 'b',
         f"F1 score on {dataset_label} test set for reconstructing model during training.",
         f"Reconstructing_model_f1_during_training{suffix}.jpg"),
        (reconst_resultdf['prediction_val_loss'], 'b',
         f"Validation loss on {dataset_label} for reconstructing model during training.",
         f"Reconstructing_model_valloss_during_training{suffix}.jpg"),
        (pred_resultdf['F1'], 'g',
         f"F1 score on {dataset_label} test set for predictive model during training.",
         f"Predictive_model_f1_during_training{suffix}.jpg"),
        (pred_resultdf['prediction_val_loss'], 'g',
         f"Validation loss on {dataset_label} for predictive model during training.",
         f"Predictive_model_valloss_during_training{suffix}.jpg"),
    ]
    for series, colour, title, fname in panels:
        fig = plt.figure(figsize=(10, 7))
        plt.plot(series, colour)
        plt.title(title)
        plt.savefig(fname, dpi=500, bbox_inches='tight')
        # Close each figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    
def plot_F1_and_vallosses_during_training(model=None, task=None):
    """Plot F1 (one figure) and validation loss (another) for both model families.

    Reads ``Reconst_Val_F1_during_training{suffix}.csv`` and
    ``Pred_Val_F1_during_training{suffix}.csv`` (columns ``F1`` and
    ``prediction_val_loss``). ``suffix`` is ``_{model}_{task}`` when both are
    given, matching ``valloss_f1_experiment``'s output. Otherwise it is empty,
    reproducing the original file names. The suffix is also added to the
    saved image names.

    Args:
        model: model key used in the file names, or None.
        task: task key used in the file names and titles, or None ("SWaT").
    """
    suffix = f"_{model}_{task}" if model is not None and task is not None else ""
    dataset_label = task_name.get(task, "SWaT")

    reconst_resultdf = pd.read_csv(f"Reconst_Val_F1_during_training{suffix}.csv", index_col=0)
    pred_resultdf = pd.read_csv(f"Pred_Val_F1_during_training{suffix}.csv", index_col=0)

    fig = plt.figure(figsize=(10, 7))
    plt.plot(reconst_resultdf['F1'], 'b')
    plt.plot(pred_resultdf['F1'], 'g')
    plt.xlabel("Iterations")
    plt.ylabel("F1 score")
    plt.legend(["F1 Reconstructing", "Predictive"])
    plt.title(f"F1 score on {dataset_label} test set during training.")
    plt.savefig(f"Reconstructing_predictive_model_f1_during_training{suffix}.jpg", dpi=500, bbox_inches='tight')
    plt.close(fig)

    fig = plt.figure(figsize=(10, 7))
    # Dashed style goes in the format string; plot() has no ``fmt`` keyword
    # (passing fmt='--' raised an error from Line2D).
    plt.plot(reconst_resultdf['prediction_val_loss'], 'b--')
    plt.plot(pred_resultdf['prediction_val_loss'], 'g--')
    plt.xlabel("Iterations")
    plt.ylabel("Validation Loss")
    plt.legend(["Reconstructing", "Predictive"])
    plt.title(f"Validation loss on {dataset_label} during training.")
    plt.savefig(f"Reconstructing_predictive_model_valloss_during_training{suffix}.jpg", dpi=500, bbox_inches='tight')
    plt.close(fig)


def plot_F1_val_during_training_oneplot(model=None, task=None):
    """Plot F1 (solid, left axis) and validation loss (dashed, right axis) on one figure.

    Reads ``Reconst_Val_F1_during_training{suffix}.csv`` and
    ``Pred_Val_F1_during_training{suffix}.csv``. ``suffix`` is
    ``_{model}_{task}`` when both are given, and empty otherwise (the original
    file names). Text sizes use the module-level ``fontsize``.

    Args:
        model: model key used in the file names, or None.
        task: task key used in the file names and title, or None ("SWaT").
    """
    suffix = f"_{model}_{task}" if model is not None and task is not None else ""
    dataset_label = task_name.get(task, "SWaT")
    text_style = {"fontname": "Arial", "fontsize": fontsize}

    reconst_resultdf = pd.read_csv(f"Reconst_Val_F1_during_training{suffix}.csv", index_col=0)
    pred_resultdf = pd.read_csv(f"Pred_Val_F1_during_training{suffix}.csv", index_col=0)
    fig, ax1 = plt.subplots(figsize=(20, 10))
    ax2 = ax1.twinx()

    ax1.plot(reconst_resultdf['F1'], 'b', linewidth=3)
    ax1.plot(pred_resultdf['F1'], 'g', linewidth=3)
    ax1.set_xlabel("Iterations", **text_style)
    ax1.set_ylabel("F1 score", **text_style)
    ax1.set_title(f"F1 score and validation loss on {dataset_label} test set during training.", **text_style)

    ax2.plot(reconst_resultdf['prediction_val_loss'], 'b--', linewidth=3)
    ax2.plot(pred_resultdf['prediction_val_loss'], 'g--', linewidth=3)
    ax2.set_ylabel("Error", **text_style)

    # Shrink the axes to make room for the legends placed to the right.
    box = ax1.get_position()
    ax1.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax1.legend(["Reconstruction F1", "Masked Predictive F1"], loc='upper left', bbox_to_anchor=(1.1, 1))
    ax2.legend(["Reconstruction Val. Loss", "Masked Predictive Val. Loss"], loc='upper left', bbox_to_anchor=(1.1, 0.8))

    fig.savefig(f"Reconstructing_predictive_model_valloss_during_training_merged{suffix}.jpg", dpi=500, bbox_inches="tight", pad_inches=0.5)
    plt.close(fig)


def K_experiment(masks, repeats, epochs=200, model='transformer', task='swat', use_pretrained = False, save=True,lr=None,hook=None):
    '''Measure prediction and anomaly-detection performance for varying numbers of masks K.

    Each value in ``masks`` is trained ``repeats`` times (``masks * repeats``
    runs) as a Predictive model on the JSON game dataset for ``task``. Every
    run uses ``anomaly_filter=0.95`` and ``use_val_thresh=True``. The mean
    "Localization F1" and "F1" per K are printed to stderr.

    Args:
        masks: list of K values (mask counts).
        repeats: runs per K value.
        epochs: training epochs per run.
        model: architecture key (e.g. "mlp", "GDN", "transformer").
        task: 'gridworld' or 'monopoly'.
        use_pretrained: stored on each config.
        save: write per-run results to ``K_experiment_prediction_{model}_{task}.csv``.
        lr: learning rate passed to each model.
        hook: passed to ``train`` for every run.

    Raises:
        ValueError: if ``task`` is not 'gridworld' or 'monopoly'. The default
            ``task='swat'`` is therefore rejected; pass a game task explicitly.
    '''
    if task == 'gridworld':
        dataset_root = Path("../dataset/polycraft")
        game_dataset_cls = gridworld_jsondata.GameDataset
    elif task == 'monopoly':
        dataset_root = Path("../dataset/monopoly")
        game_dataset_cls = monopoly_jsondata.GameDataset
    else:
        raise ValueError(f"unsupported task {task!r}; expected 'gridworld' or 'monopoly'")

    repeated_masks = masks * repeats

    all_configs = []
    for m in repeated_masks:
        config = make_model_config('Predictive', model=model, task=task, masks=m, train_path="./", val_path="./", use_pretrained=use_pretrained, anomaly_filter=0.95, use_val_thresh=True)
        config.use_json_graph = False
        all_configs.append(config)

    # Window and batch sizes do not depend on K, so the last config is used for the loaders.
    train_set = game_dataset_cls(data_path=dataset_root / "json_normal_train_data.pkl", concat_steps=config.winsize, mode="training")
    train_loader = GraphDataLoader(train_set, batch_size=config.bsz)
    valid_set = game_dataset_cls(data_path=dataset_root / "json_normal_valid_data.pkl", concat_steps=config.winsize, mode="validation")
    valid_loader = GraphDataLoader(valid_set, batch_size=50)

    pred_resultsdf = K_experiment_measure_at_end(None, epochs, train_set, all_configs, train_loader, valid_loader, test_dataset=valid_loader, lr=lr, hook=hook)

    pred_resultsdf['K'] = repeated_masks
    per_k = pred_resultsdf.groupby("K").mean()
    print(per_k['Localization F1'], per_k['F1'], file=sys.stderr)
    if save:
        pred_resultsdf.to_csv(f"K_experiment_prediction_{model}_{task}.csv")


def anomaly_filter_experiment(masks,anomaly_filters, repeats, epochs=200, model='transformer', task='swat', use_pretrained = False, save=True,lr=0.001):
    '''Measure prediction and anomaly-detection performance for varying anomaly_filter values.

    Each value in ``anomaly_filters`` is trained ``repeats`` times
    (``anomaly_filters * repeats`` runs) as a Predictive model with ``masks``
    masks on the JSON game dataset for ``task``. A Slack message is sent when
    the experiment finishes.

    Args:
        masks: mask count used for every run.
        anomaly_filters: list of anomaly_filter values to sweep.
        repeats: runs per filter value.
        epochs: training epochs per run.
        model: architecture key (e.g. "mlp", "GDN", "transformer").
        task: 'gridworld' or 'monopoly'.
        use_pretrained: stored on each config.
        save: write per-run results to
            ``anomaly_filter_experiment_prediction_{model}_{task}.csv``.
        lr: learning rate stored on the configs and passed to each model.

    Raises:
        ValueError: if ``task`` is not 'gridworld' or 'monopoly' (this
            includes the default ``task='swat'``).
    '''
    if task == 'gridworld':
        dataset_root = Path("../dataset/polycraft")
        game_dataset_cls = gridworld_jsondata.GameDataset
    elif task == 'monopoly':
        dataset_root = Path("../dataset/monopoly")
        game_dataset_cls = monopoly_jsondata.GameDataset
    else:
        raise ValueError(f"unsupported task {task!r}; expected 'gridworld' or 'monopoly'")

    repeated_filters = anomaly_filters * repeats

    all_configs = []
    for f in repeated_filters:
        config = make_model_config('Predictive', model=model, task=task, masks=masks, train_path="./", val_path="./", use_pretrained=use_pretrained, anomaly_filter=f, lr=lr)
        config.use_json_graph = False
        all_configs.append(config)

    # Window and batch sizes are the same for every config; the last one is used.
    train_set = game_dataset_cls(data_path=dataset_root / "json_normal_train_data.pkl", concat_steps=config.winsize, mode="training")
    train_loader = GraphDataLoader(train_set, batch_size=config.bsz)
    valid_set = game_dataset_cls(data_path=dataset_root / "json_normal_valid_data.pkl", concat_steps=config.winsize, mode="validation")
    valid_loader = GraphDataLoader(valid_set, batch_size=50)

    pred_resultsdf = K_experiment_measure_at_end(None, epochs, train_set, all_configs, train_loader, valid_loader, test_dataset=valid_loader, lr=lr)

    pred_resultsdf['anomaly_filter'] = repeated_filters
    if save:
        pred_resultsdf.to_csv(f"anomaly_filter_experiment_prediction_{model}_{task}.csv")

    send_message("Finished anomaly filter experiment")
def K_plot(masks,model,repeats,task):
    """Plot F1 and validation loss against the number of masks K.

    Reads ``K_experiment_prediction_{model}_{task}.csv`` (written by
    ``K_experiment``). It prints paired t-tests of "F1" and "Localization F1"
    between K=0 and K=``masks[-1]``, then saves ``F1_loss_masking_predictive.jpg``.
    The F1 curve uses "Localization F1" when present, otherwise "F1".
    Error bars are standard errors of the mean, std / sqrt(repeats).

    Args:
        masks: K values of the experiment; ``masks[-1]`` is compared with K=0.
        model: model key (CSV name and title).
        repeats: number of runs per K.
        task: task key (CSV name and title).
    """
    from scipy.stats import ttest_rel

    # Undo any stdout capture left behind by the experiment helpers.
    sys.stdout = sys.__stdout__

    resultsdf = pd.read_csv(f"K_experiment_prediction_{model}_{task}.csv", index_col=0)
    by_k = resultsdf.groupby(['K'])
    pred_avgd = by_k.mean()
    pred_std = by_k.std()
    sem_scale = np.sqrt(repeats)

    # NOTE(review): ttest_rel pairs runs by position, but runs for different K
    # are independent trainings; ttest_ind may be the intended test.
    for column in ("F1", "Localization F1"):
        print("Pval:", ttest_rel(by_k.get_group(0)[column], by_k.get_group(masks[-1])[column]))

    fig, ax = plt.subplots(figsize=(20, 10))

    y = pred_avgd['F1'].values
    err = pred_std['F1'].values / sem_scale
    print("Step")
    print(y)
    print(err)
    try:
        y = pred_avgd['Localization F1'].values
        # Standard error over the repeats (previously divided by the number of K values).
        err = pred_std['Localization F1'].values / sem_scale
        print("Node")
        print(y)
        print(err)
    except KeyError as e:
        print("No node scores")
        print(e)

    # Use the grouped (sorted) K values so x lines up with the aggregated rows.
    k_values = pred_avgd.index.values
    ax.errorbar(k_values, y, yerr=err, marker='o', color='b', linewidth=3)

    loss = pred_avgd['val_loss'].values
    loss_err = pred_std['val_loss'].values / sem_scale
    ax2 = ax.twinx()
    ax2.errorbar(k_values, loss, yerr=loss_err, marker='o', color='g', linewidth=3)

    ax.set_xticks(k_values)
    ax.set_xlabel("K")
    ax.set_ylabel("F1 score")
    ax2.set_ylabel("Prediction Error")
    ax.set_title(f"{task_name[task]} F1 scores and Validation Loss of {model_name[model]} with varying K")

    ax.legend(["Feature F1 Score"], loc='upper left', bbox_to_anchor=(0, 0.88))
    ax2.legend(["Prediction Error"], loc='upper left')
    fig.savefig("F1_loss_masking_predictive.jpg", dpi=500, bbox_inches='tight')

    print(pred_avgd['time'], pred_std['time'])
def anomaly_filter_plot(filters,model,repeats,task, step=True):
    """Plot F1 and validation loss against the anomaly filter threshold r for one task.

    The sweep in ``anomaly_filter_experiment_prediction_{model}_{task}.csv``
    is combined with the non-zero-K runs of
    ``K_experiment_prediction_{model}_{task}.csv``. ``K_experiment`` runs with
    ``anomaly_filter=0.95``, so these rows supply the r=0.95 point and replace
    any r=0.95 rows of the sweep. Error bars are std / sqrt(repeats), which
    assumes both experiments used the same number of repeats.

    Args:
        filters: ignored; the x-values are the distinct r values in the data.
        model: model key (file names and title).
        repeats: runs per r value.
        task: task key (file names and title).
        step: plot "F1" if True, else "Localization F1" (falling back to
            "F1" if that column is missing).
    """
    # Undo any stdout capture left behind by the experiment helpers.
    sys.stdout = sys.__stdout__

    sweep = pd.read_csv(f"anomaly_filter_experiment_prediction_{model}_{task}.csv", index_col=0)
    sweep = sweep[sweep['anomaly_filter'] != 0.95]
    # index_col=0 here too, so no stray "Unnamed: 0" column is averaged in.
    standard = pd.read_csv(f"K_experiment_prediction_{model}_{task}.csv", index_col=0)
    standard = standard[standard['K'] != 0].assign(anomaly_filter=0.95)
    # pd.concat replaces DataFrame.append, which pandas 2.0 removed.
    grouped = pd.concat([sweep, standard]).groupby(['anomaly_filter'])
    pred_avgd = grouped.mean()
    pred_std = grouped.std()
    sem_scale = np.sqrt(repeats)
    prefix = '' if step else 'Localization '

    filters = pred_avgd.index.values

    y = pred_avgd['F1'].values
    err = pred_std['F1'].values / sem_scale
    print("Step")
    print(y)
    print(err)
    if not step:
        try:
            y = pred_avgd[f'{prefix}F1'].values
            err = pred_std[f'{prefix}F1'].values / sem_scale
            print("Node")
            print(y)
            print(err)
        except KeyError as e:
            print("No node scores")
            print(e)

    fig, ax = plt.subplots(figsize=(20, 10))
    ax.set_ylim(0, 1)
    ax.errorbar(filters, y, yerr=err, marker='o', color='b', linewidth=3)

    loss = pred_avgd['val_loss'].values
    loss_err = pred_std['val_loss'].values / sem_scale
    ax2 = ax.twinx()
    ax2.errorbar(filters, loss, yerr=loss_err, marker='o', color='g', linewidth=3)

    xlabels = filters.tolist()
    xlabels[-1] = ">1"  # largest r is labelled ">1" (alternative label: "No Filter")
    ax.set_xticks(filters, labels=xlabels)
    ax.set_xlabel("$r$")
    ax.set_ylabel("Score")
    ax2.set_ylabel("Validation Loss")
    ax.set_title(f"{task_name[task]} F1 scores and Validation Loss of Masked {model_name[model]} with varying $r$")

    ax.legend([f"{prefix}F1 Score"], loc='lower left', bbox_to_anchor=(0, 0.08))
    ax2.legend(["Validation Loss"], loc='lower left')
    fig.savefig(f"filter_F1_loss_masking_predictive_{model}_{task}.jpg", dpi=500, bbox_inches='tight')

    print(pred_avgd['time'], pred_std['time'])

def anomaly_filter_plot_unified(filters,model,repeats,tasks, step=True):
    """Plot F1 against the anomaly filter threshold r for several tasks on one axis.

    For each task:
      - the sweep in ``anomaly_filter_experiment_prediction_{model}_{task}.csv``
        is combined with the non-zero-K runs of
        ``K_experiment_prediction_{model}_{task}.csv``;
      - those K-experiment runs (made with ``anomaly_filter=0.95``) replace any
        r=0.95 rows of the sweep.
    Error bars are std / sqrt(repeats). The figure is saved to
    ``filter_F1_loss_masking_predictive_{model}_monopoly_gridworld.jpg``.

    Args:
        filters: replaced by the r values found in the data.
        model: model key (file names and title).
        repeats: runs per r value.
        tasks: task keys to plot; each needs a colour below
            ('monopoly' or 'gridworld').
        step: plot "F1" if True, else "Localization F1" (falling back to
            "F1" if that column is missing).
    """
    # Undo any stdout capture left behind by the experiment helpers.
    sys.stdout = sys.__stdout__

    task_colors = {"monopoly": 'b', "gridworld": 'g'}
    prefix = '' if step else 'Localization '
    sem_scale = np.sqrt(repeats)

    fig, ax = plt.subplots(figsize=(20, 10))
    legend = []
    for task in tasks:
        sweep = pd.read_csv(f"anomaly_filter_experiment_prediction_{model}_{task}.csv", index_col=0)
        sweep = sweep[sweep['anomaly_filter'] != 0.95]
        standard = pd.read_csv(f"K_experiment_prediction_{model}_{task}.csv", index_col=0)
        standard = standard[standard['K'] != 0].assign(anomaly_filter=0.95)
        # pd.concat replaces DataFrame.append, which pandas 2.0 removed.
        grouped = pd.concat([sweep, standard]).groupby(['anomaly_filter'])
        pred_avgd = grouped.mean()
        pred_std = grouped.std()

        filters = pred_avgd.index.values

        y = pred_avgd['F1'].values
        err = pred_std['F1'].values / sem_scale
        print("Step")
        print(y)
        print(err)
        if not step:
            try:
                y = pred_avgd[f'{prefix}F1'].values
                err = pred_std[f'{prefix}F1'].values / sem_scale
                print("Node")
                print(y)
                print(err)
            except KeyError as e:
                print("No node scores")
                print(e)

        ax.errorbar(filters, y, yerr=err, marker='o', color=task_colors[task], linewidth=3)
        legend.append(f"{task_name[task]} {prefix}F1 Score")

    # Ticks come from the last task's r values.
    xlabels = np.asarray(filters).tolist()
    xlabels[-1] = ">1"  # largest r is labelled ">1" (alternative label: "No Filter")
    ax.set_xticks(filters, labels=xlabels)
    ax.set_xlabel("$r$")
    ax.set_ylabel("Score")

    task_labels = " and ".join(task_name[t] for t in tasks)
    ax.set_title(f"{task_labels} F1 scores of Masked {model_name[model]} with varying $r$")
    ax.legend(legend, loc='upper left')
    fig.savefig(f"filter_F1_loss_masking_predictive_{model}_monopoly_gridworld.jpg", dpi=250, bbox_inches='tight')

def widths_experiment(model,task,masks,widths, repeats, epochs=200,save=True):
    '''Compare reconstructing and masked predictive models across embedding widths.

    The goal is to show that predictive models are more reliable because they
    are not given the ground truth in their input. The hypothesis is that
    reconstruction performance degrades as the information bottleneck shrinks.
    This is hard to control and is not reflected in the validation loss.

    Every width in ``widths`` is trained ``repeats`` times for each family. When
    ``save`` is True, results go to ``Widths_experiment_reconstruction_{model}_{task}.csv``
    and ``Widths_experiment_prediction_{model}_{task}.csv``.

    Args:
        model: 'mlp' or 'transformer'.
        task: 'gridworld' or 'monopoly'.
        masks: mask count for the predictive runs.
        widths: list of embedding/output widths.
        repeats: runs per width.
        epochs: training epochs per run.
        save: whether to write the CSV files.

    Raises:
        ValueError: for an unsupported ``model`` or ``task``.
    '''
    repeated_widths = widths * repeats
    if model == 'mlp':
        hyper_param_sets = [{"out_dim": o, "n_layers": 1, 'h_dim': o, 'dropout_rate': 0.2} for o in repeated_widths]
    elif model == 'transformer':
        hyper_param_sets = [{"output_dim": o, "n_layers": 2, 'emb_dim': o} for o in repeated_widths]
    else:
        # Previously hyper_param_sets was left undefined (NameError) for e.g. 'GDN'.
        raise ValueError(f"unsupported model {model!r}; expected 'mlp' or 'transformer'")

    if task == 'gridworld':
        dataset_root = Path("../dataset/polycraft")
        game_dataset_cls = gridworld_jsondata.GameDataset
    elif task == 'monopoly':
        dataset_root = Path("../dataset/monopoly")
        game_dataset_cls = monopoly_jsondata.GameDataset
    else:
        raise ValueError(f"unsupported task {task!r}; expected 'gridworld' or 'monopoly'")

    config = make_model_config('Reconstructing', task, model, masks=masks, use_pretrained=False, lr=1e-3)

    train_set = game_dataset_cls(data_path=dataset_root / "json_normal_train_data.pkl", concat_steps=config.winsize, mode="training")
    train_loader = GraphDataLoader(train_set, batch_size=config.bsz)
    valid_set = game_dataset_cls(data_path=dataset_root / "json_normal_valid_data.pkl", concat_steps=config.winsize, mode="validation")
    valid_loader = GraphDataLoader(valid_set, batch_size=50)

    # --- Reconstruction ---
    reconst_resultsdf = hyperparams_run(hyper_param_sets, epochs, train_set, config, train_loader, valid_loader, save=False, test=True)
    reconst_resultsdf['widths'] = repeated_widths
    if save:
        reconst_resultsdf.to_csv(f"Widths_experiment_reconstruction_{model}_{task}.csv")

    # --- Predictive ---
    config = make_model_config('Predictive', task, model, masks=masks, use_pretrained=False)
    pred_resultsdf = hyperparams_run(hyper_param_sets, epochs, train_set, config, train_loader, valid_loader, save=False, test=True)
    pred_resultsdf['widths'] = repeated_widths

    send_message(f"Widths experiment finished for {model}_{task}")
    if save:
        pred_resultsdf.to_csv(f"Widths_experiment_prediction_{model}_{task}.csv")
def width_plot(model,task,widths,step=True):
    """Plot F1 against embedding width for the reconstruction and masked predictive models.

    Reads ``Widths_experiment_reconstruction_{model}_{task}.csv`` and
    ``Widths_experiment_prediction_{model}_{task}.csv`` (written by
    ``widths_experiment``). Saves ``F1_width_reconst_predictive_{model}_{task}.jpg``.
    Error bars are the standard error of the mean per width:
    std / sqrt(number of runs at that width).

    Args:
        model: model key (file names).
        task: task key (file names).
        widths: x-axis values; assumed sorted ascending so they line up with
            the grouped (sorted) results.
        step: plot "F1" if True, else "Localization F1".
    """
    # Undo any stdout capture left behind by the experiment helpers.
    sys.stdout = sys.__stdout__

    # Previously ``prefix`` was only set when step was False (NameError by default).
    prefix = '' if step else 'Localization '
    score_col = f'{prefix}F1'

    fig, ax = plt.subplots(figsize=(20, 10))
    for kind, colour in (("reconstruction", 'b'), ("prediction", 'g')):
        resultsdf = pd.read_csv(f"Widths_experiment_{kind}_{model}_{task}.csv", index_col=0)
        scores = resultsdf.groupby(['widths'])[score_col]
        mean = scores.mean().values
        sem = (scores.std() / np.sqrt(scores.count())).values
        print(kind, widths, mean, sem)
        ax.errorbar(widths, mean, yerr=sem, marker='o', color=colour, linewidth=3)

    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.set_xlabel("Embedding Dimension")
    ax.set_ylabel(f"{prefix}F1 score")

    ax.legend(["Reconstruction Model", "Masked Predictive Model"], loc='lower left')
    ax.set_title("F1 scores of Models with varying embedding dimension")

    fig.savefig(f"F1_width_reconst_predictive_{model}_{task}.jpg", dpi=500, bbox_inches='tight')


def finetune(task,model,config_dir="/home/plymper/graph-anomaly-detection-clean/models/model_configs/finetuning",masks=26):
    """Run a hyper-parameter sweep for a masked Predictive model on a JSON game task.

    Candidate hyper-parameter sets are read from ``{config_dir}/{model}_0.json``.
    Presumably this is a JSON list of dicts; each element is passed to
    ``hyperparams_run``. Every set is trained for 200 epochs with lr=1e-4 and
    tested. Results are written to ``Finetune_{model}_{task}.csv``.

    Args:
        task: 'gridworld' or 'monopoly'.
        model: architecture key (also selects the JSON file).
        config_dir: directory with the fine-tuning JSON files. The default is
            the original hard-coded absolute path.
        masks: number of masks for the Predictive model.

    Raises:
        ValueError: if ``task`` is not 'gridworld' or 'monopoly'.
    """
    if task == 'gridworld':
        dataset_root = Path("../dataset/polycraft")
        game_dataset_cls = gridworld_jsondata.GameDataset
    elif task == 'monopoly':
        dataset_root = Path("../dataset/monopoly")
        game_dataset_cls = monopoly_jsondata.GameDataset
    else:
        raise ValueError(f"unsupported task {task!r}; expected 'gridworld' or 'monopoly'")

    with open(Path(config_dir) / f"{model}_0.json") as f:
        hyper_params_sets = json.load(f)

    config = make_model_config('Predictive', model=model, task=task, masks=masks, train_path="./", val_path="./", use_pretrained=False, lr=0.0001)
    config.use_json_graph = False

    train_set = game_dataset_cls(data_path=dataset_root / "json_normal_train_data.pkl", concat_steps=config.winsize, mode="training")
    train_loader = GraphDataLoader(train_set, batch_size=config.bsz)
    valid_set = game_dataset_cls(data_path=dataset_root / "json_normal_valid_data.pkl", concat_steps=config.winsize, mode="validation")
    valid_loader = GraphDataLoader(valid_set, batch_size=50)

    pred_resultsdf = hyperparams_run(hyper_params_sets, 200, train_set, config, train_loader, valid_loader, test=True, save=True)
    pred_resultsdf.to_csv(f"Finetune_{model}_{task}.csv")


def epoch_time_measure():
    """Run a short, unsaved K experiment for GDN on Monopoly with a timing hook.

    ``K_experiment`` receives ``hook=<callable returning "save_times">``. How
    ``train`` uses that hook is not visible in this file.
    """
    task = 'monopoly'
    masks_by_task = dict(
        gridworld=[0, 26],
        monopoly=[0, 9],
        swat=[0, 50],
        wadi=[0, 127],
    )
    masks = masks_by_task[task]
    repeats, epochs = 1, 30

    for model in ("GDN",):
        K_experiment(
            masks, repeats,
            model=model, epochs=epochs, task=task,
            use_pretrained=False, save=False, lr=0.001,
            hook=lambda _: "save_times",
        )


def main(task='gridworld', models=("mlp", 'GDN', "transformer"), repeats=10, epochs=200):
    """Run the K (number of masks) experiment for several models on one JSON game task.

    Seeds torch and numpy with 0 and sets the matplotlib font (Arial, size 35).
    It then calls ``K_experiment`` with ``use_pretrained=True``, ``save=True``
    and ``lr=0.001`` for every model. The Slack message "Err K" is sent before
    re-raising if anything fails; "Done json experiment" is sent on success.

    Args:
        task: key of the task/mask table below; use 'monopoly' for Monopoly.
        models: architecture keys to run.
        repeats: runs per K value.
        epochs: training epochs per run.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    matplotlib.rc('font', family='Arial', size=35)

    # K values compared per task: no masking vs. the task's mask count.
    task_masks = {"gridworld": [0, 26],
                  "monopoly": [0, 9],
                  "swat": [0, 50],
                  "wadi": [0, 127]}

    try:
        masks = task_masks[task]
        for model in models:
            K_experiment(masks, repeats, model=model, epochs=epochs, task=task, use_pretrained=True, save=True, lr=0.001)
    except Exception:
        send_message("Err K")
        raise

    send_message("Done json experiment")



# Global font size read by plot_F1_val_during_training_oneplot.
fontsize = 40

if __name__ == '__main__':
    # NOTE: main() sets its own font (size 35), overriding this setting.
    matplotlib.rc('font', family='Arial', size=30)

    main()

    # Other entry points, kept for reference:
    # epoch_time_measure()
    # exit()
    #
    # anomaly_filters = [1.1, 1.0, 0.95, 0.5, 0.0]
    # anomaly_filter_experiment(masks, anomaly_filters, repeats, epochs=200, model=model, task=task, use_pretrained=True, save=True, lr=0.001)
    # anomaly_filter_plot_unified(anomaly_filters, model, repeats, ["gridworld", "monopoly"], step=False)