import argparse, os, sys

sys.path.append('../')
from dataset.sensordata import make_sensor_datasets 
from dataset.gamedata import GameDataset
from torch_geometric.loader import DataLoader as GraphDataLoader
import torch
import numpy as np
from sklearn.metrics import roc_auc_score
import numpy as np
from Autoregressive_model import AutoregressiveModel
from train import train
from scipy.stats import iqr
from test_sensors import test_with_normalized_loss
from test_games import test
import types
from io import StringIO
import re
import pandas as pd
import scipy.stats as stats
import tqdm
import matplotlib.pyplot as plt
import ast
import matplotlib
from dataset.json_graph import JsonToGraph

def _last_token_as_float(pattern, output):
    """Return the last whitespace-separated token of the first match, as a float.

    Raises:
        IndexError: ``pattern`` does not occur in ``output``. The message
            names the pattern so the failure can be diagnosed even though
            the captured text itself is not shown.
        ValueError: the last token of the matched line is not a number.
    """
    match = re.search(pattern, output)
    if match is None:
        raise IndexError(f"pattern {pattern!r} not found in captured output")
    # split() (rather than split(' ')) tolerates trailing blanks and tabs.
    return float(match.group(0).split()[-1])


def collect_result(output, prefix=''):
    """Parse the F1 score and validation loss out of captured stdout text.

    Args:
        output: Text to search (the experiment functions pass the captured
            stdout of a train/test run).
        prefix: Literal text expected directly before ``Best F1-Score`` on
            the line of interest. It is matched verbatim, so characters such
            as parentheses are not interpreted as regex syntax.

    Returns:
        ``{"F1": float, "val_loss": float}`` where ``F1`` is the last token
        of the first ``<prefix>Best F1-Score...`` line and ``val_loss`` is the
        last token of the first line fragment starting at lowercase ``best``.

    Raises:
        IndexError: one of the two lines is missing from ``output``.
        ValueError: the last token on a matched line is not numeric.
    """
    # Diagnostic kept from the original: shows which line is being looked for.
    print(f"{prefix}Best F1-Score.*",file=sys.stderr)
    f1 = _last_token_as_float(f"{re.escape(prefix)}Best F1-Score.*", output)
    valloss = _last_token_as_float("best.*", output)
    return {"F1":f1,'val_loss':valloss}

def collect_result_byline(output):
    """Collect the payload of every hook-bearing ``validation scores :`` line.

    Args:
        output: Captured text, one log record per line.

    Returns:
        A list with, for each line that contains both the marker
        ``"validation scores :"`` and the substring ``"hook"``, the stripped
        text that follows the marker. Lines are returned in input order.
    """
    marker = "validation scores :"
    scores = []
    for line in output.split("\n"):
        if marker in line and 'hook' in line:
            # Take what follows the marker instead of slicing by its length:
            # the slice is only correct when the marker starts at column 0,
            # and silently corrupts the payload on indented/prefixed lines.
            scores.append(line.split(marker, 1)[1].strip())
    return scores


def make_model_config(model_type = "Reconstructing",task='swat', model='transformer', train_path = "./", val_path = "./", masks=50):
    """Build the configuration namespace shared by the experiments in this file.

    Args:
        model_type: ``"Reconstructing"`` or ``"Predictive"``. Reconstructing
            forces ``n_masks`` to 0 and sets ``reconstructing=True``;
            Predictive keeps ``masks`` and sets ``reconstructing=False``.
        task: Stored as ``config.task``.
        model: Stored as ``config.model_type`` (the architecture name, not the
            reconstructing/predictive family).
        train_path: Stored as ``config.train_dataset_dir``.
        val_path: Stored as ``config.val_dataset_dir``.
        masks: Number of masks for the Predictive family; ignored for
            Reconstructing.

    Returns:
        A ``types.SimpleNamespace`` holding the settings above plus the fixed
        values set below.

    Raises:
        ValueError: ``model_type`` is not one of the two supported families.
            (Previously this surfaced later as an ``UnboundLocalError``.)
    """
    if model_type == 'Reconstructing':
        n_masks, reconstructing = 0, True
    elif model_type == "Predictive":
        n_masks, reconstructing = masks, False
    else:
        raise ValueError(
            f"model_type must be 'Reconstructing' or 'Predictive', got {model_type!r}"
        )

    config = types.SimpleNamespace()
    config.emb_dim = 32
    config.train_dataset_dir = train_path
    config.test_dataset_dir = "./"  # ignored
    config.val_dataset_dir = val_path
    config.model_type = model

    config.bsz = 200
    config.model_save_dir = './hyperparam_plots_results'
    config.task = task
    config.n_masks = n_masks  # 0 for reconstruction
    config.reconstructing = reconstructing  # makes the model a reconstruction type model
    config.subsample = 15
    config.gpu = 0  # passed as ``device`` to the model by the experiment functions
    config.validation_step = 250
    config.winsize = 5
    config.model_config_root = "../"

    return config

def train_hook(test_dataset):
    """Return a one-argument callable that scores a model on ``test_dataset``.

    The returned function takes a model and returns whatever
    ``test_with_normalized_loss(test_dataset, model)`` returns.
    """
    def evaluate_on_test_set(model):
        return test_with_normalized_loss(test_dataset, model)

    return evaluate_on_test_set


def experiment_measure_during_training(hyper_param_sets,n_train_epochs, trainset,config,trainloader,valloader,test_dataset, hook=None):
    """Train one model per hyper-parameter set and collect the per-validation records.

    While each model trains and is tested, stdout is captured into a string
    buffer; every ``validation scores :`` line that mentions ``hook`` is then
    parsed with ``ast.literal_eval`` and becomes one row of the result.

    Args:
        hyper_param_sets: Iterable of hyper-parameter dicts, one model each.
        n_train_epochs: Passed to ``train`` as ``epochs``.
        trainset: Dataset object; its ``num_nodes``, ``node_feature_dim``,
            ``node_info``, ``all_node_array`` and ``feature_list`` attributes
            are forwarded to ``AutoregressiveModel``.
        config: Namespace from ``make_model_config``.
        trainloader, valloader: Loaders handed to ``train``.
        test_dataset: Passed to ``hook`` and to ``test_with_normalized_loss``.
        hook: Optional factory ``hook(test_dataset) -> callable(model)``; the
            product is passed to ``train`` as its ``hook``. With ``None`` no
            hook is installed (and, presumably, no hook lines are logged, so
            the result is empty).

    Returns:
        A ``pandas.DataFrame`` with one row per parsed line, across all
        hyper-parameter sets, with the ``hook_result`` column renamed to ``F1``.
    """
    result_sets = []
    for hyper_params in hyper_param_sets:
        # Capture everything printed during training/testing, and always put
        # the previous stream back so a failure does not leave stdout
        # pointing at a discarded buffer.
        previous_stdout = sys.stdout
        sys.stdout = output_buffer = StringIO()
        try:
            model = AutoregressiveModel(trainset.num_nodes, trainset.node_feature_dim, trainset.node_info,model_type= config.model_type, config = config,train_dataset = trainset.all_node_array, device = config.gpu, feature_list = trainset.feature_list, hyper_params=hyper_params)
            # The default hook=None used to be called unconditionally.
            train_time_hook = hook(test_dataset) if hook is not None else None
            train(model, trainloader, valloader=valloader, epochs=n_train_epochs, hook=train_time_hook)

            test_with_normalized_loss(test_dataset,model)
        finally:
            sys.stdout = previous_stdout

        output = output_buffer.getvalue()
        # The parsed summary is not part of the returned frame; the call is
        # kept so that a run whose output lacks the final score lines still
        # fails here, as it did before.
        collect_result(output)
        result_sets.extend(ast.literal_eval(line) for line in collect_result_byline(output))

    resultdf = pd.DataFrame(result_sets).rename(columns={"hook_result":"F1"})

    print(resultdf, file=sys.stderr)
    return resultdf

def experiment_measure_at_end(hyper_param_sets,n_train_epochs, trainset,config,trainloader,valloader,test_dataset):
    """Train one model per hyper-parameter set and record its final scores.

    Stdout is captured while each model is trained and tested; the scores
    are then parsed out of the captured text with ``collect_result``.

    Args:
        hyper_param_sets: Iterable of hyper-parameter dicts, one model each.
        n_train_epochs: Passed to ``train`` as ``epochs``.
        trainset: Dataset object whose ``num_nodes``, ``node_feature_dim``,
            ``node_info``, ``all_node_array`` and ``feature_list`` attributes
            are forwarded to ``AutoregressiveModel``.
        config: Namespace from ``make_model_config``; ``config.task`` selects
            the test routine.
        trainloader, valloader: Loaders handed to ``train``.
        test_dataset: Used only for the ``wadi``/``swat`` tasks.

    Returns:
        A ``pandas.DataFrame`` with one row per hyper-parameter set and the
        columns ``F1``, ``val_loss``, ``params`` (``str`` of the dict) and,
        for tasks other than ``wadi``/``swat``, ``Node F1``.
    """
    is_sensor_task = config.task in ['wadi','swat']

    result_sets = []
    for hyper_params in hyper_param_sets:
        # Capture everything printed during training/testing, and always put
        # the previous stream back so a failure (or the end of the sweep)
        # does not leave stdout pointing at a discarded buffer.
        previous_stdout = sys.stdout
        sys.stdout = output_buffer = StringIO()
        try:
            model = AutoregressiveModel(trainset.num_nodes, trainset.node_feature_dim, trainset.node_info,model_type= config.model_type, config = config,train_dataset = trainset.all_node_array, device = config.gpu, feature_list = trainset.feature_list, hyper_params=hyper_params)
            train(model, trainloader, valloader=valloader, epochs=n_train_epochs, hook=None)

            if is_sensor_task:
                test_with_normalized_loss(test_dataset,model)
            else:
                test(config)
        finally:
            sys.stdout = previous_stdout

        output = output_buffer.getvalue()
        if is_sensor_task:
            results = collect_result(output)
        else:
            # The game test output carries two prefixed score lines:
            # the per-graph one becomes "F1", the per-node one "Node F1".
            results = collect_result(output, prefix = "Per graph max normalized loss ")
            results1 = collect_result(output, prefix = "Per node normalized loss ")
            results['Node F1']= results1['F1']

        results['params'] = str(hyper_params)
        result_sets.append(results)

    return pd.DataFrame(result_sets)


def valloss_f1_experiment(epochs=200):
    """Track F1 and validation loss during training for both model families.

    Runs ``experiment_measure_during_training`` once with a Reconstructing
    config and once with a Predictive config, on the same datasets, loaders
    and single hyper-parameter set, and writes each result frame to
    ``Reconst_Val_F1_during_training.csv`` / ``Pred_Val_F1_during_training.csv``.

    Args:
        epochs: Number of training epochs for each run.
    """
    hyper_params = [{"output_dim":64,"n_layers":1, 'emb_dim':64}]

    # Datasets and loaders are built once, from the Reconstructing config,
    # and reused for the Predictive run.
    recon_config = make_model_config('Reconstructing')
    trainset, val, test_dataset = make_sensor_datasets(
        name=recon_config.task, concat_steps=5, subsample=recon_config.subsample
    )
    trainloader = GraphDataLoader(trainset, batch_size=recon_config.bsz, shuffle=False)
    valloader = GraphDataLoader(val, batch_size=recon_config.bsz, shuffle=False)

    # --- Reconstructing ---
    reconst_resultdf = experiment_measure_during_training(
        hyper_params, epochs, trainset, recon_config,
        trainloader, valloader, test_dataset, hook=train_hook,
    )
    print(reconst_resultdf, file=sys.stderr)
    reconst_resultdf.to_csv("Reconst_Val_F1_during_training.csv")

    # --- Predictive ---
    pred_config = make_model_config('Predictive')
    pred_resultdf = experiment_measure_during_training(
        hyper_params, epochs, trainset, pred_config,
        trainloader, valloader, test_dataset, hook=train_hook,
    )
    pred_resultdf.to_csv("Pred_Val_F1_during_training.csv")


def plot_F1_and_vallosses_during_training_individual():
    """Save four single-curve figures from the during-training CSV files.

    Reads ``Reconst_Val_F1_during_training.csv`` and
    ``Pred_Val_F1_during_training.csv`` and writes one JPEG per
    (model, metric) pair: F1 and validation loss, for the reconstructing
    model (blue) and the predictive model (green).
    """
    reconst_resultdf = pd.read_csv("Reconst_Val_F1_during_training.csv",index_col=0)
    pred_resultdf = pd.read_csv("Pred_Val_F1_during_training.csv",index_col=0)

    # (frame, column, line format, title, output file) for each figure, in
    # the order the figures are produced.
    figure_specs = (
        (reconst_resultdf, 'F1', 'b',
         "F1 score on SWaT test set for reconstructing model during training.",
         "Reconstructing_model_f1_during_training.jpg"),
        (reconst_resultdf, 'prediction_val_loss', 'b',
         "Validation loss on SWaT for reconstructing model during training.",
         "Reconstructing_model_valloss_during_training.jpg"),
        (pred_resultdf, 'F1', 'g',
         "F1 score on SWaT test set for predictive model during training.",
         "Predictive_model_f1_during_training.jpg"),
        (pred_resultdf, 'prediction_val_loss', 'g',
         "Validation loss on SWaT for predictive model during training.",
         "Predictive_model_valloss_during_training.jpg"),
    )

    for frame, column, line_format, title, filename in figure_specs:
        plt.figure(figsize=(10,7))
        plt.plot(frame[column], line_format)
        plt.title(title)
        plt.savefig(filename, dpi=500, bbox_inches='tight')
    
def plot_F1_and_vallosses_during_training():
    """Save two comparison figures from the during-training CSV files.

    Reads ``Reconst_Val_F1_during_training.csv`` and
    ``Pred_Val_F1_during_training.csv`` and writes:

    * ``Reconstructing_predictive_model_f1_during_training.jpg`` - F1 of both
      models (solid lines), and
    * ``Reconstructing_predictive_model_valloss_during_training.jpg`` -
      validation loss of both models (dashed lines).

    Blue is the reconstructing model, green the predictive one.
    """
    reconst_resultdf = pd.read_csv("Reconst_Val_F1_during_training.csv",index_col=0)
    pred_resultdf = pd.read_csv("Pred_Val_F1_during_training.csv",index_col=0)

    # --- F1 ---
    plt.figure(figsize=(10,7))
    plt.plot(reconst_resultdf['F1'],'b')
    plt.plot(pred_resultdf['F1'],'g')
    plt.xlabel("Iterations")
    plt.ylabel("F1 score")
    plt.legend(["F1 Reconstructing","Predictive"])
    plt.title("F1 score on SWaT test set during training.")

    plt.savefig("Reconstructing_predictive_model_f1_during_training.jpg",dpi=500, bbox_inches='tight')

    # --- Validation loss ---
    plt.figure(figsize=(10,7))
    # The dash style belongs in the positional format string; ``fmt`` is not
    # a keyword of ``plt.plot`` and is rejected as an unknown line property.
    plt.plot(reconst_resultdf['prediction_val_loss'],'b--')
    plt.plot(pred_resultdf['prediction_val_loss'],'g--')
    plt.xlabel("Iterations")
    plt.ylabel("Validation Loss")
    plt.legend(["Reconstructing","Predictive"])
    plt.title("Validation loss on SWaT during training.")

    plt.savefig("Reconstructing_predictive_model_valloss_during_training.jpg",dpi=500, bbox_inches='tight')


def plot_F1_val_during_training_oneplot():
    """Save one figure with F1 (left axis) and validation loss (right axis).

    Reads ``Reconst_Val_F1_during_training.csv`` and
    ``Pred_Val_F1_during_training.csv``. F1 curves are solid, validation-loss
    curves dashed; blue is the reconstructing model, green the predictive
    one. Text size comes from the module-level ``fontsize``.
    """
    reconst_resultdf = pd.read_csv("Reconst_Val_F1_during_training.csv",index_col=0)
    pred_resultdf = pd.read_csv("Pred_Val_F1_during_training.csv",index_col=0)

    text_style = {"fontname": "Arial", "fontsize": fontsize}
    curves = ((reconst_resultdf, 'b'), (pred_resultdf, 'g'))

    fig, f1_axis = plt.subplots(figsize=(20,10))
    loss_axis = f1_axis.twinx()

    # Left axis: F1 scores.
    for frame, colour in curves:
        f1_axis.plot(frame['F1'], colour, linewidth=3)
    f1_axis.set_xlabel("Iterations", **text_style)
    f1_axis.set_ylabel("F1 score", **text_style)
    f1_axis.set_title("F1 score and validation loss on SWaT test set during training.", **text_style)

    # Right axis: validation losses, dashed.
    for frame, colour in curves:
        loss_axis.plot(frame['prediction_val_loss'], colour + '--', linewidth=3)
    loss_axis.set_ylabel("Error", **text_style)

    # Narrow the plot area to 80% width so both legends fit to its right.
    area = f1_axis.get_position()
    f1_axis.set_position([area.x0, area.y0, area.width * 0.8, area.height])

    f1_axis.legend(["Reconstruction F1","Masked Predictive F1"], loc='upper left', bbox_to_anchor=(1.1, 1))
    loss_axis.legend(["Reconstruction Val. Loss","Masked Predictive Val. Loss"], loc='upper left', bbox_to_anchor=(1.1, 0.8))

    fig.savefig("Reconstructing_predictive_model_valloss_during_training_merged.jpg",dpi=500,bbox_inches="tight",pad_inches=0.5)


def widths_monopoly_experiment(widths, repeats, epochs=200):
    '''
    Measure reconstructing and predictive performance on Monopoly for different widths.

    Our purpose is to show that predictive models are more reliable because they are
    not given the ground truth in their input. Our hypothesis is that reconstruction
    model performance is going to suffer as the information bottleneck gets smaller,
    but this is difficult to control and is not reflected in the validation loss.

    Args:
        widths: List of layer widths; each is used as both ``out_dim`` and ``h_dim``.
        repeats: How many times the whole list of widths is repeated.
        epochs: Training epochs per model.

    Writes ``Monopoly_Widths_experiment_reconstruction.csv`` and
    ``Monopoly_Widths_experiment_prediction.csv``; returns ``None``.
    '''
    train_path = "/home/yli52/dataset/monopoly/train_new"
    val_path = "/home/yli52/dataset/monopoly/val"

    def build_config(model_family):
        # Both families need the same extras. The Predictive config used to
        # be rebuilt without ``use_json_graph``, unlike the Reconstructing
        # one and unlike ``dropout_experiment``.
        # NOTE(review): whether ``test(config)`` reads this attribute is not
        # visible from this file - confirm.
        family_config = make_model_config(model_family,model='mlp', task='monopoly',masks=19,train_path=train_path, val_path=val_path)
        family_config.use_json_graph=False
        return family_config

    config = build_config('Reconstructing')

    trainset = GameDataset(data_path=config.train_dataset_dir, concat_steps=config.winsize, mode='training',ignore_intermediate_nodes= not config.use_json_graph, task = config.task)
    trainloader = GraphDataLoader(trainset, batch_size=config.bsz, shuffle=False)
    val_dataset = GameDataset(data_path=config.val_dataset_dir, concat_steps=config.winsize, mode='validation',ignore_intermediate_nodes=not config.use_json_graph, task = config.task)
    valloader=GraphDataLoader(val_dataset, batch_size=config.bsz, shuffle=False)

    # widths is a list, so this is list repetition: [w1, w2, ...] * repeats.
    repeated_widths = widths*repeats

    hyper_param_sets = [{"out_dim":o,"n_layers":1, 'h_dim':o, 'dropout_rate':0.2}  for o in repeated_widths]

    ####--- Reconstruction ---####

    reconst_resultsdf = experiment_measure_at_end(hyper_param_sets,epochs,trainset,config,trainloader,valloader,test_dataset=None)

    reconst_resultsdf['widths']=repeated_widths
    reconst_resultsdf.to_csv("Monopoly_Widths_experiment_reconstruction.csv")

    ####--- Predictive ---####

    config = build_config('Predictive')
    pred_resultsdf = experiment_measure_at_end(hyper_param_sets,epochs,trainset,config,trainloader,valloader,test_dataset=None)

    pred_resultsdf['widths']=repeated_widths
    pred_resultsdf.to_csv("Monopoly_Widths_experiment_prediction.csv")

    return


def widths_experiment(widths, repeats, epochs=200):
    '''
    Measure reconstructing and predictive performance on the sensor task for different widths.

    The aim is to show that predictive models are more reliable because their input
    does not contain the ground truth. The hypothesis is that reconstruction
    performance degrades as the information bottleneck narrows, which is hard to
    control and does not show up in the validation loss.

    Args:
        widths: List of widths; each is used as both ``output_dim`` and ``emb_dim``.
        repeats: How many times the whole list of widths is repeated.
        epochs: Training epochs per model.

    Writes ``Widths_experiment_reconstruction.csv`` and
    ``Widths_experiment_prediction.csv``; returns ``None``.
    '''
    # Data is loaded once, using the Reconstructing config's task settings,
    # and shared by both model families.
    recon_config = make_model_config('Reconstructing')
    trainset, val, test_dataset = make_sensor_datasets(
        name=recon_config.task, concat_steps=5, subsample=recon_config.subsample
    )
    trainloader = GraphDataLoader(trainset, batch_size=recon_config.bsz, shuffle=False)
    valloader = GraphDataLoader(val, batch_size=recon_config.bsz, shuffle=False)

    repeated_widths = widths*repeats
    n_layers = 2
    hyper_param_sets = [
        {"output_dim":width,"n_layers":n_layers, 'emb_dim':width}
        for width in repeated_widths
    ]

    # Reconstruction first, then Predictive; same sweep, separate CSV files.
    runs = (
        (recon_config, "Widths_experiment_reconstruction.csv"),
        (make_model_config('Predictive'), "Widths_experiment_prediction.csv"),
    )
    for family_config, csv_name in runs:
        family_resultsdf = experiment_measure_at_end(
            hyper_param_sets, epochs, trainset, family_config,
            trainloader, valloader, test_dataset,
        )
        family_resultsdf['widths'] = repeated_widths
        family_resultsdf.to_csv(csv_name)

def dropout_experiment(dropouts, repeats, epochs=200):
    '''
    Measure predictive and anomaly-detection performance on Monopoly for varying dropout.

    Args:
        dropouts: List of dropout rates to sweep.
        repeats: How many times the whole list of rates is repeated.
        epochs: Training epochs per model.

    Writes ``dropout_experiment_prediction.csv``; returns ``None``.
    '''
    config = make_model_config(
        'Predictive', model='mlp', task='monopoly', masks=19,
        train_path="/home/yli52/dataset/monopoly/train_new",
        val_path="/home/yli52/dataset/monopoly/val",
    )
    config.use_json_graph = False
    skip_intermediate = not config.use_json_graph

    dataset = GameDataset(
        data_path=config.train_dataset_dir, concat_steps=config.winsize,
        mode='training', ignore_intermediate_nodes=skip_intermediate, task=config.task,
    )
    trainloader = GraphDataLoader(dataset, batch_size=config.bsz, shuffle=False)

    val_dataset = GameDataset(
        data_path=config.val_dataset_dir, concat_steps=config.winsize,
        mode='validation', ignore_intermediate_nodes=skip_intermediate, task=config.task,
    )
    valloader = GraphDataLoader(val_dataset, batch_size=config.bsz, shuffle=False)

    # dropouts is a list, so this is list repetition.
    repeated_dropouts = dropouts*repeats
    hyper_param_sets = [
        {"out_dim":64,"n_layers":1, 'h_dim':64, 'dropout_rate':rate}
        for rate in repeated_dropouts
    ]

    pred_resultsdf = experiment_measure_at_end(
        hyper_param_sets, epochs, dataset, config, trainloader, valloader, test_dataset=None
    )
    pred_resultsdf['dropout_rate'] = repeated_dropouts
    pred_resultsdf.to_csv("dropout_experiment_prediction.csv")
def dropout_plot(dropouts,repeats):
    """Plot node-level F1 and validation loss against dropout rate.

    Reads ``dropout_experiment_prediction.csv``, averages ``Node F1`` and
    ``val_loss`` per ``dropout_rate`` and draws both with error bars of
    ``std / sqrt(repeats)`` on a shared x axis (F1 left, loss right). The
    figure is saved as ``F1_loss_dropout_predictive.jpg``.

    Args:
        dropouts: x positions, one per distinct dropout rate. Group results
            come back sorted by rate, so pass the rates in ascending order.
        repeats: Number of runs per dropout rate, used for the error bars.

    Side effect: resets ``sys.stdout`` to ``sys.__stdout__``.
    """
    # Undo any stdout redirection left behind by an earlier experiment run.
    sys.stdout = sys.__stdout__

    dropout_resultsdf = pd.read_csv("dropout_experiment_prediction.csv",index_col=0)

    # Aggregate only the two numeric columns that are plotted. The frame
    # also holds the string column ``params``, on which mean()/std() raise
    # in pandas >= 2.0 instead of silently dropping it.
    per_rate = dropout_resultsdf.groupby(['dropout_rate'])[['Node F1', 'val_loss']]
    pred_avgd = per_rate.mean()
    pred_std = per_rate.std()

    fig,ax = plt.subplots(figsize=(20,10))

    # Left axis: node-level F1.
    y = pred_avgd['Node F1'].values
    err = pred_std['Node F1'].values/np.sqrt(repeats)
    ax.errorbar(dropouts,y,yerr =err, marker= 'o', color='b', linewidth=3)

    # Right axis: validation loss.
    y = pred_avgd['val_loss'].values
    err = pred_std['val_loss'].values/np.sqrt(repeats)

    print(y)
    print(err)
    print(dropouts)

    ax2 = ax.twinx()
    ax2.errorbar(dropouts,y,yerr =err,marker= 'o', color='g', linewidth=3)
    ax.set_xticks(dropouts)
    ax.set_xlabel("Dropout Rate")
    ax.set_ylabel("F1 score")
    ax2.set_ylabel("Prediction Error")

    ax.set_title("Monopoly F1 scores and Validation Loss of MLP with varying Dropout")

    ax.legend(["Feature F1 Score"], loc = 'upper left',bbox_to_anchor=(0, 0.88))
    ax2.legend(["Prediction Error"], loc = 'upper left')
    fig.savefig("F1_loss_dropout_predictive.jpg", dpi = 500, bbox_inches='tight')

    return

def width_plot(widths):
    """Plot F1 against embedding width for both model families.

    Reads ``Widths_experiment_reconstruction.csv`` and
    ``Widths_experiment_prediction.csv``, averages ``F1`` per width and draws
    both curves with standard-error bars on a log2 x axis. The figure is
    saved as ``F1_width_reconst_predictive.jpg``.

    Args:
        widths: Only echoed to stdout; the x positions are taken from the
            ``widths`` column of the CSV files so they always line up with
            the plotted means.

    Side effect: resets ``sys.stdout`` to ``sys.__stdout__``.
    """
    # Undo any stdout redirection left behind by an earlier experiment run.
    sys.stdout = sys.__stdout__

    reconst_resultsdf = pd.read_csv("Widths_experiment_reconstruction.csv",index_col=0)
    pred_resultsdf = pd.read_csv("Widths_experiment_prediction.csv",index_col=0)

    # Select the numeric F1 column before aggregating: the frames also hold
    # the string column ``params``, on which mean()/std() raise in
    # pandas >= 2.0.
    reconst_f1 = reconst_resultsdf.groupby('widths')['F1']
    pred_f1 = pred_resultsdf.groupby('widths')['F1']

    reconst_mean = reconst_f1.mean()
    pred_mean = pred_f1.mean()
    # sem() is std / sqrt(number of runs in the group). The previous code
    # divided by sqrt(number of distinct widths), which is unrelated to the
    # sample size behind each mean.
    reconst_err = reconst_f1.sem().values
    pred_err = pred_f1.sem().values

    # Position each point at log2(width); for widths 4,16,32,64,128 this is
    # the former hard-coded [2,4,5,6,7].
    reconst_x = np.log2(reconst_mean.index.to_numpy(dtype=float))
    pred_x = np.log2(pred_mean.index.to_numpy(dtype=float))

    fig,ax = plt.subplots(figsize=(20,10))

    print(reconst_resultsdf)
    ax.errorbar(reconst_x,reconst_mean.values,yerr =reconst_err, marker= 'o', color='b', linewidth=3)

    print(pred_err)
    ax.errorbar(pred_x,pred_mean.values,yerr =pred_err,marker= 'o', color='g', linewidth=3)

    tick_exponents = list(range(1,8))
    xticks = [f"$2^{w}$" for w in tick_exponents]
    print("-------------",reconst_x)
    print(widths)
    print(xticks)
    # Fix the tick positions before labelling them; labels applied to
    # automatically chosen ticks do not line up with the exponents.
    ax.set_xticks(tick_exponents)
    ax.set_xticklabels(xticks)

    # Narrow the plot area to 80% of its width.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.set_xlabel("Embedding Dimension")
    ax.set_ylabel("F1 score")

    ax.legend(["Reconsruction Model","Masked Predictive Model"],loc='lower left')
    ax.set_title("F1 scores of Models with varying embedding dimension")

    fig.savefig("F1_width_reconst_predictive.jpg", dpi = 500, bbox_inches='tight')

    return

# Text size used by plot_F1_val_during_training_oneplot for labels and title.
fontsize = 40

if __name__ == '__main__':

    import time

    widths = [4, 16, 32, 64, 128]
    repeats = 10

    # Global matplotlib font for every figure produced below.
    matplotlib.rc('font', family='Arial', size=35)

    # Experiments and plots are toggled by (un)commenting the calls below.
    #widths_monopoly_experiment(widths,repeats,epochs=200)
    #width_plot(widths)

    #valloss_f1_experiment(200)
    #plot_F1_val_during_training_oneplot()

    # Alternative sweep: dropouts = [0.0,0.1,0.2,0.3,0.4,0.5]
    dropouts = [0.0, 0.2, 0.4, 0.6, 0.8]
    repeats = 10

    #dropout_experiment(dropouts,repeats, epochs=50)
    dropout_plot(dropouts, repeats)

    from slack_message import send_message
    #send_message("Done dropout")



   