import pandas as pd
import numpy as np
from tabulate import tabulate

import table_prototype
# Display names used in the generated tables, keyed by the model / task
# identifiers that appear in the result file names.
model_name = dict(mlp="MLP", GDN="GDN", transformer="OAT")
task_name = dict(gridworld="Polycraftv2", swat="SWAT", wadi="WADI", monopoly="Monopoly")

def format_score(y,err,sig=False,decimals=3):
    """Format a score and its error as a ``"value±error"`` table cell.

    Both numbers are rounded with ``np.round`` to ``decimals`` places.
    Trailing zeros are not padded, so 0.1 is printed as ``0.1``.

    Args:
        y: The score, e.g. a mean F1.
        err: Its uncertainty, e.g. a standard error.
        sig: When True, wrap the score in ``\\textbf{...}`` to mark it as
            significant.
        decimals: Number of decimal places to round to. Defaults to 3.

    Returns:
        The formatted cell string.
    """
    value = np.round(y, decimals=decimals)
    error = np.round(err, decimals=decimals)
    if sig:
        # Doubled braces emit literal LaTeX braces around the value.
        return f"\\textbf{{{value}}}±{error}"
    return f"{value}±{error}"

def get_results(task,model,repeats=10,step=True, with_names= True):
    """Summarise one model/task experiment as two mean±SEM table cells.

    Reads ``K_experiment_prediction_{model}_{task}.csv`` from the working
    directory and averages one F1 column per mask setting ``K``. It then
    compares the unmasked run (``K == 0``) with the largest ``K`` using a
    paired t-test.

    Args:
        task: Task key. For 'swat'/'wadi' the 'F1' column is used.
        model: Model key, looked up in ``model_name`` for display.
        repeats: Number of repeats. The standard error is the standard
            deviation divided by sqrt(repeats).
        step: For other tasks, use 'Best F1' when True and
            'Localization Best F1' when False.
        with_names: When True, return ``[[name, cell], [masked_name, cell]]``.
            Otherwise return only the two cells.

    Returns:
        Rows (or cells) for the unmasked model (K == 0) and the masked model
        (largest K). The masked cell is bold when the paired t-test gives
        p < 0.05.
    """
    from scipy.stats import ttest_rel

    if task in ['swat','wadi']:
        score_key = "F1"
    elif step:
        score_key = 'Best F1'
    else:
        score_key = 'Localization Best F1'

    resultsdf = pd.read_csv(f"K_experiment_prediction_{model}_{task}.csv",index_col=0)

    # Group by the column name, not a one-element list, so get_group()
    # accepts a scalar K. Select the score column first so non-numeric
    # columns do not break mean()/std().
    scores_by_k = resultsdf.groupby('K')[score_key]
    masked_k = np.unique(resultsdf['K'])[-1]

    pval = ttest_rel(scores_by_k.get_group(0), scores_by_k.get_group(masked_k)).pvalue

    # Index the statistics by K label so the reported masked cell is the same
    # K that was tested, even when more than two K values are present.
    means = scores_by_k.mean()
    sems = scores_by_k.std() / np.sqrt(repeats)

    print(model,task,step,pval)
    name = model_name[model]
    res = [
        [name, format_score(means.loc[0], sems.loc[0], False)],
        [f"Masked {name}", format_score(means.loc[masked_k], sems.loc[masked_k], pval<0.05)],
    ]
    if not with_names:
        res = [row[1] for row in res]
    return res

def get_full_results(task,model,repeats=10,step=True):
    """Collect F1 / Precision / Recall cells for one model on one task.

    Reads ``K_experiment_prediction_{model}_{task}.csv`` from the working
    directory. Columns are averaged per mask setting ``K``. The unmasked
    model is ``K == 0`` and the masked model is the largest ``K``.

    Args:
        task: Task key. 'swat'/'wadi' use the 'F1', 'Precision' and 'Recall'
            columns. Other tasks use 'Best F1', 'Precision' and 'Recall',
            prefixed with 'Localization ' when ``step`` is False.
        model: Model key, looked up in ``model_name`` for display.
        repeats: Number of repeats. The standard error is the standard
            deviation divided by sqrt(repeats).
        step: Selects step-level vs localization metrics for non-SWaT/WADI
            tasks.

    Returns:
        ``{task_display: {model_display: {metric: cell},
        "Masked " + model_display: {metric: cell}}}``. Metric names have the
        'Localization ' prefix stripped. Only the masked model's F1 cell is
        significance-tested: a paired t-test of K == 0 vs the largest K, bold
        when p < 0.05.
    """
    from scipy.stats import ttest_rel

    if task in ['swat','wadi']:
        prefix = ''
        score_keys = ['F1','Precision','Recall']
    else:
        prefix = '' if step else 'Localization '
        score_keys = [f"{prefix}Best F1",f"{prefix}Precision",f"{prefix}Recall"]

    resultsdf = pd.read_csv(f"K_experiment_prediction_{model}_{task}.csv",index_col=0)

    # Group by the column name, not a one-element list, so get_group()
    # accepts a scalar K. Restrict the statistics to the metric columns.
    by_k = resultsdf.groupby('K')
    masked_k = np.unique(resultsdf['K'])[-1]
    means = by_k[score_keys].mean()
    sems = by_k[score_keys].std() / np.sqrt(repeats)

    f1_key = score_keys[0]
    pval = ttest_rel(by_k.get_group(0)[f1_key], by_k.get_group(masked_k)[f1_key]).pvalue

    plain_name = model_name[model]
    plain_cells = {}
    masked_cells = {}
    for key in score_keys:
        metric = key[len(prefix):]
        plain_cells[metric] = format_score(means.at[0, key], sems.at[0, key], False)
        # Look up the masked statistics by K label so they match the tested K.
        # Only the F1 metric carries a significance marker.
        masked_cells[metric] = format_score(
            means.at[masked_k, key],
            sems.at[masked_k, key],
            key == f1_key and pval < 0.05,
        )

    return {task_name[task]: {plain_name: plain_cells, f"Masked {plain_name}": masked_cells}}

import collections.abc

def update(d, u):
    """Recursively merge mapping ``u`` into ``d`` in place and return ``d``.

    When a value in ``u`` is a mapping, it is merged into ``d``'s entry for
    that key. The entry starts as a new empty dict if ``d`` lacks the key.
    Any other value simply overwrites ``d``'s entry.
    """
    for key, value in u.items():
        if not isinstance(value, collections.abc.Mapping):
            d[key] = value
            continue
        d[key] = update(d.get(key, {}), value)
    return d


def make_full_table(tasks,models,step=True):
    """Build the combined F1 / Precision / Recall table for tasks and models.

    Args:
        tasks: Task keys. Their display names (via ``task_name``) are passed
            to the table builder in this order.
        models: Model keys passed to ``get_full_results``.
        step: Forwarded to ``get_full_results``. Selects step-level vs
            localization metrics for non-SWaT/WADI tasks.

    Returns:
        The result of ``table_prototype.make_prototype`` on the merged
        results. It is printed by the caller; presumably LaTeX table text,
        so confirm against table_prototype.
    """
    res = {}
    for m in models:
        for t in tasks:
            # Deep-merge so each model's entries are added under the task
            # instead of replacing the ones collected for earlier models.
            res = update(res, get_full_results(t, m, step=step))

    taskns = [task_name[t] for t in tasks]
    return table_prototype.make_prototype(taskns,res)

def get_model_name(model,masks):
    """Return the display name of ``model``, prefixed with "Masked " when
    ``masks`` is truthy."""
    display = model_name[model]
    return f"Masked {display}" if masks else display


def epoch_time_table(tasks,models,task_masks=None):
    """Print (and return) a table of mean/std epoch times per task and model.

    For each task, model and mask this reads
    ``{model}_{mask}_{task}_epoch_times.txt`` from the working directory.
    The file holds one float per line; blank lines are ignored. The mean
    and standard deviation, formatted to two decimals, are stored in
    ``times[task_display][model_display][masked]`` under "mean" and "std".
    ``masked`` is 0 for mask 0 and 1 otherwise.

    Args:
        tasks: Task keys. Each must be a key of ``task_masks``.
        models: Model keys, looked up in ``model_name``.
        task_masks: Optional mapping from task key to the list of mask values
            to read. Defaults to the masks hard-coded for the experiments.

    Returns:
        The value produced by ``table_prototype.make_epoch_time``, which is
        also printed.
    """
    if task_masks is None:
        task_masks = {"monopoly":[0,9],
                      "swat":[0,50],
                      "wadi":[0,127],
                      "gridworld":[0,26]
                    }

    times = {}
    for tsk in tasks:
        task_times = times.setdefault(task_name[tsk], {})
        for mdel in models:
            model_times = task_times.setdefault(model_name[mdel], {})
            for mask in task_masks[tsk]:
                stats = model_times.setdefault(int(mask!=0), {})
                with open(f"{mdel}_{mask}_{tsk}_epoch_times.txt") as f:
                    # Skip blank lines (e.g. a trailing newline), which float() rejects.
                    epoch_times = [float(line) for line in f if line.strip()]
                stats["mean"] = f"{np.mean(epoch_times):.2f}"
                stats["std"] = f"{np.std(epoch_times):.2f}"

    print(times)
    # NOTE(review): `times` is passed as both arguments, as in the original
    # code. Confirm what make_epoch_time expects as its second argument.
    table = table_prototype.make_epoch_time(times,times)
    print(table)
    return table




if __name__ =='__main__':
    # SWaT/WADI variant, run separately when needed:
    #   print(make_full_table(['swat', 'wadi'], ['mlp', 'transformer', "GDN"], step=True))
    tasks = ['gridworld','monopoly']
    models = ['mlp','transformer',"GDN"]
    # Step-level metrics first, then localization metrics.
    for use_step in (True, False):
        print(make_full_table(tasks,models,step=use_step))

    # Inputs for the epoch-time table (currently disabled).
    all_tasks = ['swat','wadi','gridworld','monopoly']
    all_models = ['mlp','transformer',"GDN"]
    # epoch_time_table(all_tasks, all_models)