import matplotlib.pyplot as plt
import sys
import pandas as pd
import numpy as np
import matplotlib

def K_plot(masks, model, repeats, task):
    """Plot masked vs. unmasked F1 of one model across subsampling levels.

    For each subsampling level in ``[10, 15, 20, 25]`` this reads
    ``K_experiment_prediction_{model}_{task}[_{sparsity}].csv`` (the level-15
    file carries no sparsity suffix), groups the rows by the ``K`` column and
    collects the mean F1 and its standard error for ``K == masks[-1]``
    ("Masked") and ``K == 0`` ("Unmasked").  A paired t-test between those two
    groups is printed for every level.  The figure is saved to
    ``sparsity_{task}_{model}.jpg``.

    Args:
        masks: K values used in the experiment; ``masks[-1]`` is the masked
            setting compared against ``K == 0``.
        model: model key used in the file names and in ``model_name``.
        repeats: number of repeated runs per K; the standard error is
            ``std / sqrt(repeats)``.
        task: task key used in the file names and in ``task_name``.

    Raises:
        KeyError: if a CSV lacks ``K == 0`` / ``K == masks[-1]`` rows or the
            ``F1`` / ``time`` columns.
    """
    from scipy.stats import ttest_rel

    # NOTE(review): presumably stdout is redirected elsewhere in the project;
    # this restores the interpreter's original stream so prints reach the console.
    sys.stdout = sys.__stdout__

    sparse_levels = [10, 15, 20, 25]
    masked_k = masks[-1]
    sem_scale = np.sqrt(repeats)

    masked, unmasked = [], []
    masked_err, unmasked_err = [], []
    for sparsity in sparse_levels:
        # The level-15 results were saved without a sparsity suffix.
        suffix = "" if sparsity == 15 else f"_{sparsity}"
        results = pd.read_csv(
            f"K_experiment_prediction_{model}_{task}{suffix}.csv", index_col=0
        )

        # Aggregate numeric columns only: GroupBy.mean raises on non-numeric
        # columns in pandas >= 2.0.  Grouping by the scalar "K" (not ["K"])
        # keeps get_group(<scalar>) free of the pandas 2.x deprecation.
        grouped = results.select_dtypes(include="number").groupby("K")
        pred_avgd = grouped.mean()
        pred_std = grouped.std()

        # Paired test: runs are paired by position within each K group.
        print("Pval:", ttest_rel(grouped.get_group(0)["F1"].to_numpy(),
                                 grouped.get_group(masked_k)["F1"].to_numpy()))

        # Select by K label so the plotted points are exactly the groups
        # compared by the t-test, regardless of how many K values exist.
        masked.append(pred_avgd.loc[masked_k, "F1"])
        unmasked.append(pred_avgd.loc[0, "F1"])
        masked_err.append(pred_std.loc[masked_k, "F1"] / sem_scale)
        unmasked_err.append(pred_std.loc[0, "F1"] / sem_scale)

        print("Step")
        print(pred_avgd["F1"].values)
        print(pred_std["F1"].values / sem_scale)
        try:
            node_mean = pred_avgd["Node F1"].values
            # Standard error over the repeated runs, same as for F1.
            node_err = pred_std["Node F1"].values / sem_scale
        except KeyError as e:
            print("No node scores")
            print(e)
        else:
            print("Node")
            print(node_mean)
            print(node_err)

    # One figure for the whole sweep (previously one was leaked per level).
    fig, ax = plt.subplots(figsize=(20, 10))
    ax.errorbar(sparse_levels, masked, yerr=masked_err, marker='o', color='b',
                linewidth=3, label="Masked")
    ax.errorbar(sparse_levels, unmasked, yerr=unmasked_err, marker='o', color='g',
                linewidth=3, label="Unmasked")
    ax.set_xticks(sparse_levels)
    ax.set_xlabel("Subsampling rate")
    ax.set_ylabel("F1 score")
    ax.set_title(f"{task_name[task]} F1 scores of {model_name[model]} with varying subsampling rate")
    ax.legend()
    fig.savefig(f"sparsity_{task}_{model}.jpg", dpi=500, bbox_inches='tight')
    plt.close(fig)

    # Timing summary of the last subsampling level read in the loop above.
    print(pred_avgd['time'], pred_std['time'])

    return



# Display names for plot titles and legends, keyed by the task / model
# identifiers that also appear in the result file names.
task_name = dict(swat='SWaT', wadi='WADI')
model_name = dict(transformer='OAT', mlp="MLP", GDN='GDN')
# One colour per model, shared by its masked and unmasked curves.
model_colors = dict(transformer="blue", mlp="red", GDN="green")

def merged_plot(masks, models, repeats, task):
    """Plot masked vs. unmasked F1 of several models on a single figure.

    For every model in ``models`` and every subsampling level in
    ``[10, 15, 20, 25]`` this reads
    ``K_experiment_prediction_{model}_{task}[_{sparsity}].csv`` (the level-15
    file carries no sparsity suffix), groups rows by the ``K`` column and
    plots the mean F1 with its standard error for ``K == masks[-1]`` (solid
    line, "Masked <model>") and ``K == 0`` (dashed line, "<model>").  A paired
    t-test between the two K groups is printed for every model and level.
    The figure is saved to ``sparsity_{task}_all.jpg``.

    Args:
        masks: K values used in the experiment; ``masks[-1]`` is the masked
            setting compared against ``K == 0``.
        models: model keys; each must be present in ``model_name`` and
            ``model_colors``.
        repeats: number of repeated runs per K; the standard error is
            ``std / sqrt(repeats)``.
        task: task key used in the file names and in ``task_name``.

    Raises:
        KeyError: if a CSV lacks ``K == 0`` / ``K == masks[-1]`` rows or the
            ``F1`` column.
    """
    from scipy.stats import ttest_rel

    # NOTE(review): presumably stdout is redirected elsewhere in the project;
    # this restores the interpreter's original stream so prints reach the console.
    sys.stdout = sys.__stdout__

    sparse_levels = [10, 15, 20, 25]
    masked_k = masks[-1]
    sem_scale = np.sqrt(repeats)

    fig, ax = plt.subplots(figsize=(20, 10))
    for model in models:
        masked, unmasked = [], []
        masked_err, unmasked_err = [], []
        for sparsity in sparse_levels:
            # The level-15 results were saved without a sparsity suffix.
            suffix = "" if sparsity == 15 else f"_{sparsity}"
            results = pd.read_csv(
                f"K_experiment_prediction_{model}_{task}{suffix}.csv", index_col=0
            )

            # Aggregate numeric columns only: GroupBy.mean raises on
            # non-numeric columns in pandas >= 2.0.  Grouping by the scalar
            # "K" keeps get_group(<scalar>) free of the pandas 2.x deprecation.
            grouped = results.select_dtypes(include="number").groupby("K")
            pred_avgd = grouped.mean()
            pred_std = grouped.std()

            # Paired test: runs are paired by position within each K group.
            print("Pval:", ttest_rel(grouped.get_group(0)["F1"].to_numpy(),
                                     grouped.get_group(masked_k)["F1"].to_numpy()))

            # Select by K label so the plotted points are exactly the groups
            # compared by the t-test, regardless of how many K values exist.
            masked.append(pred_avgd.loc[masked_k, "F1"])
            unmasked.append(pred_avgd.loc[0, "F1"])
            masked_err.append(pred_std.loc[masked_k, "F1"] / sem_scale)
            unmasked_err.append(pred_std.loc[0, "F1"] / sem_scale)

            print("Step")
            print(pred_avgd["F1"].values)
            print(pred_std["F1"].values / sem_scale)
            try:
                node_mean = pred_avgd["Node F1"].values
                # Standard error over the repeated runs, same as for F1.
                node_err = pred_std["Node F1"].values / sem_scale
            except KeyError as e:
                print("No node scores")
                print(e)
            else:
                print("Node")
                print(node_mean)
                print(node_err)

        print(sparse_levels, masked)
        color = model_colors[model]
        # Labels are attached to each artist so the legend cannot drift out
        # of sync with the plotting order.
        ax.errorbar(sparse_levels, masked, yerr=masked_err, marker='o', linewidth=3,
                    color=color, label="Masked " + model_name[model])
        ax.errorbar(sparse_levels, unmasked, yerr=unmasked_err, marker='o', linewidth=3,
                    linestyle=(0, (10, 4)), color=color, label=model_name[model])

    ax.set_title(f"{task_name[task]} F1 scores with varying subsampling rate")
    ax.set_xticks(sparse_levels)
    ax.set_xlabel("Subsampling gap")
    ax.set_ylabel("F1 score")
    ax.legend()

    fig.savefig(f"sparsity_{task}_all.jpg", dpi=500, bbox_inches='tight')
    # Release the figure; callers loop over tasks and would otherwise
    # accumulate open figures.
    plt.close(fig)

    return






if __name__ == '__main__':
    # K values per task; the last entry is the masked setting that the
    # plotting functions compare against K == 0.
    task_masks = dict(
        monopoly=[0, 19],
        swat=[0, 50],
        wadi=[0, 127],
    )

    matplotlib.rc('font', family='Arial', size=28)

    # Per-model figures can be produced instead with:
    #     K_plot(task_masks[task], model, 10, task)
    # for model in ['transformer', 'mlp', 'GDN'] and task in ['swat', 'wadi'].
    for task in ['swat', 'wadi']:
        merged_plot(task_masks[task], ['transformer', 'mlp', 'GDN'], 10, task)