from os import read
import numpy as np
import collections

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from numpy import e, exp, genfromtxt
import itertools



def load(path):
    """Load a NumPy-serialized object from ``path``.

    Thin wrapper around ``np.load`` with pickling enabled, so pickled object
    arrays can be restored in addition to plain numeric arrays.

    NOTE(review): ``allow_pickle=True`` means a crafted file can execute
    arbitrary code while loading; only use this on trusted files.
    """
    loaded = np.load(path, allow_pickle=True)
    return loaded

'params'
# Log file path. Not referenced by the script code visible in this file;
# other inputs that were used at some point:
#   './advancedAI/output/svhn.csv'
#   './train_output/bak/out_edg_random_pair.txt'
#   './train_output/out.txt'
datapath = './train_output/bak/out_first_edg.txt'

# Destination of the rendered figure
# (earlier alternative: './advancedAI/output/tmp.png').
figurepath = './train_output/images/tmp.png'

# Folder that is scanned for one sub-directory per experiment run.
# Other sweeps that were analysed with this script:
#   'EXPS/grid-holdout_fraction-env_density'
#   'EXPS/grid-env_number-env_distance-env_sample_number'
exps_folder = 'EXPS/new_compare'

'init'
# Endless rotation of plot marker symbols (x, +, ., o, *).
marker = itertools.cycle(['x', '+', '.', 'o', '*'])

'read data'
def read_data(datapath):
    """Parse evaluation records from a text log into a DataFrame.

    The log is scanned for lines containing the marker ``"eval results"``.
    Each marker is expected to be followed by a whitespace-separated line of
    metric names and then a whitespace-separated line of numeric values; each
    such pair contributes one value to every named column.

    Args:
        datapath: Path of the log file to read. The path is also printed to
            stdout.

    Returns:
        A ``pandas.DataFrame`` with one column per metric name and one row
        per evaluation record, in file order. Empty when no record is found.

    Raises:
        ValueError: If an entry on a values line cannot be parsed as a float.
        IndexError: If a values line has fewer entries than its names line.
    """
    with open(datapath, 'r') as f:
        lines = f.readlines()
    metric_dict = collections.defaultdict(list)
    print(datapath)
    for i, line in enumerate(lines):
        # A record needs the two lines after the marker. The bound used to be
        # `i + 2 < len(lines) - 1`, which was off by one and silently dropped
        # a record whose values line was the last line of the file.
        # NOTE(review): if that was meant to skip a half-written final line
        # of a log that is still being produced, reinstate it explicitly.
        if "eval results" not in line or i + 2 >= len(lines):
            continue
        keys = lines[i + 1].split()
        vals = [float(each) for each in lines[i + 2].split()]
        # Index by position so a short values line fails loudly instead of
        # producing columns of different lengths.
        for k_i, key in enumerate(keys):
            metric_dict[key].append(vals[k_i])
    return pd.DataFrame.from_dict(metric_dict)

# df = read_data('./train_output/bak/out_first_edg.txt') # basic edg
# df_erm = read_data('./train_output/bak/out_erm.txt') # basic erm
# df_edg_random_pair = read_data('./train_output/bak/out_edg_random_pair.txt') # edg random pair - sampleShuffle(random shuffle for each sampling)
# df_edg_random_fixed_pair = read_data('./train_output/bak/out_edg_random_fixed_pair.txt') # edg random pair- initShuffle(fixed for all sampling)
# df_edg_all_data_fraction001 = read_data('./train_output/bak/out_edg_all_data_fraction=001.txt') # edg using all data for training without split
# Experiment-run records; starts empty and is appended to further down.
df_list = []

def dict_to_str(d):
    """Flatten a mapping into ``key#value`` pairs joined by ``-``.

    Keys must already be strings (they are concatenated directly); values are
    converted with ``str``. Pairs follow the mapping's iteration order, and an
    empty mapping yields ``''``.

    Example: ``{'a': 1, 'b': 2}`` -> ``'a#1-b#2'``.
    """
    pairs = [key + '#' + str(value) for key, value in d.items()]
    return '-'.join(pairs)

def decode_dir_name(dir):
    """Extract run hyper-parameters from an experiment directory name.

    The name is split on ``"--"`` and the first piece is discarded. The
    remaining pieces are read purely by position (piece 0 is not inspected):

    * piece 1 -> ``algorithm``: the text after its first 9 characters
    * piece 2 -> ``env_distance``: integer after the last ``#``
    * piece 3 -> ``env_number``: integer after the last ``#``
    * piece 4 -> ``env_sample_number``: integer after the last ``#``
    * piece 5 -> ``seed``: integer after the last ``#``

    NOTE(review): the 9-character offset presumably strips a leading
    ``algorithm`` label -- confirm against the real directory names.

    Returns:
        Dict with keys ``env_distance``, ``env_number``,
        ``env_sample_number``, ``seed`` and ``algorithm``, in that order.

    Raises:
        IndexError: If the name has too few ``--``-separated pieces.
        ValueError: If a numeric piece does not end in an integer.
    """
    # TODO simplify
    fields = dir.split("--")[1:]

    def trailing_int(field):
        # Integer that follows the last '#' in a piece.
        return int(field.rsplit('#', 1)[-1])

    algorithm = fields[1][9:]
    env_distance = trailing_int(fields[2])
    env_number = trailing_int(fields[3])
    env_sample_number = trailing_int(fields[4])
    seed = trailing_int(fields[5])
    return dict(
        env_distance=env_distance,
        env_number=env_number,
        env_sample_number=env_sample_number,
        seed=seed,
        algorithm=algorithm,
    )

def deduct_df_list(df_list):
    """Merge runs that differ only by seed into one step-averaged entry.

    Entries are grouped by their hyper-parameters with ``'seed'`` left out.
    For each group the member DataFrames are concatenated and the numeric
    columns are averaged per ``'step'`` value.

    Args:
        df_list: Iterable of dicts, each holding an ``'hparams'`` dict and a
            ``'df'`` DataFrame that has a ``'step'`` column. The input dicts
            are not modified.

    Returns:
        A list with one dict per group, in first-seen order:
        ``'name'`` (the ``dict_to_str`` form of the seed-less
        hyper-parameters), ``'df'`` (the averaged frame, indexed by step and
        also carrying a ``'step'`` column) and ``'hparams'`` (the seed-less
        hyper-parameters of the group's first member). Each group name is
        printed to stdout.
    """
    groups = collections.defaultdict(list)
    for each in df_list:
        # Work on a seed-less copy instead of deleting from the caller's
        # dict: the in-place delete made a second call fail with KeyError.
        hparams = {k: v for k, v in each["hparams"].items() if k != 'seed'}
        groups[dict_to_str(hparams)].append((hparams, each['df']))

    res = []
    for name, members in groups.items():
        df_mean = pd.concat([df for _, df in members])
        # numeric_only=True keeps text columns (e.g. an 'algorithm' column)
        # from raising TypeError on pandas >= 2, where the default changed.
        df_mean = df_mean.groupby("step").mean(numeric_only=True)
        df_mean['step'] = df_mean.index

        res.append({
            'name': name,
            'df': df_mean,
            'hparams': members[0][0],
        })
        print(name)
    return res

def expand_hparams_columns(df_list):
    """Write each entry's hyper-parameters into its DataFrame as columns.

    For every entry, each ``'hparams'`` key becomes a column of the entry's
    ``'df'`` holding that hyper-parameter's value on every row. The frames
    are modified in place and the same list object is returned.
    """
    for entry in df_list:
        frame = entry['df']
        for column, value in entry['hparams'].items():
            frame[column] = value
    return df_list

import os

# Directory name that looks like it belongs to the holdout_fraction /
# env_density sweep; nothing in the visible script reads this list.
exps_list = ["grid-holdout_fraction-env_density--holdout_fraction0-6--env_density9"]


# Register every run directory located directly inside `exps_folder`.
# os.walk visits the whole tree, but only the top-level listing is used, and
# only sub-directories whose name contains "seed" are treated as runs.
for root, dirs, _files in os.walk(exps_folder, topdown=False):
    if root != exps_folder:
        continue
    for run_dir in dirs:
        if "seed" not in run_dir:
            continue
        df_list.append(dict(
            name=run_dir,
            path=os.path.join(root, run_dir, "out.txt"),
            hparams=decode_dir_name(run_dir),
        ))

# Parse each run's log and keep the resulting frame under the 'df' key.
for each in df_list:
    each['df'] = read_data(each['path'])

# Copy hyper-parameters into DataFrame columns so rows can be grouped by them.
df_list = expand_hparams_columns(df_list)
df_total = pd.concat([each['df'] for each in df_list])

# Mean of the remaining columns per (configuration, step) combination.
df_mean = df_total.groupby(
    ['algorithm', 'env_number', 'env_distance', 'env_sample_number', 'step']
).mean()

# Rows of the averaged table whose step equals 4900.
# NOTE(review): 4900 is hard-coded -- presumably the last logged step; confirm.
last_step = df_mean.iloc[df_mean.index.get_level_values('step') == 4900]

# Collapse runs that share the same hyper-parameters into one entry each.
df_list = deduct_df_list(df_list)





'process'
# df = pd.read_csv(datapath)
# df_std = df.loc[df.name=="std"].drop(columns="name").T
# df_avg = df.loc[df.name=="avg"].drop(columns="name").T
# columns = df_std.iloc[0]
# # indexes = ["1K", "2K", "3K", "4K", "5K", "6K"] # for svhn
# indexes = ["2K", "4K", "6K", "8K", "10K", "12K"] # for cifar10
# df_std = df_std.drop("baselines")
# df_avg = df_avg.drop("baselines")
# df_std.columns = columns # set columns into algorithm name
# df_avg.columns = columns
# df_avg.index = indexes
# df_avg.columns.name = None
# df_avg.index.name = "round"


################## plot ##################
# NOTE: the banner above used to end with "plt.clf()" as part of the comment,
# so the figure was never cleared here; that is left unchanged.
fig, ax1 = plt.subplots(1, 1, sharey=True, figsize=[20, 20])

# One 'query_acc' vs 'epoch' curve per entry of df_list, labelled by its name.
for each in df_list:
    curve = each['df']
    ax1.plot(curve['epoch'], curve['query_acc'], label=each['name'])

# Earlier, now-disabled variants plotted individual logs (df, df_erm,
# df_edg_random_pair, df_edg_random_fixed_pair, df_edg_all_data_fraction001)
# against 'epoch' or 'step', taking a symbol from `marker` for each line.

ax1.legend()  # show legend

# Remove the axes border on all four sides.
for side in ('top', 'right', 'bottom', 'left'):
    ax1.spines[side].set_visible(False)

# Light gray background with a white grid along both axes.
ax1.set_facecolor('#f0f0f0')
for grid_axis in ('y', 'x'):
    ax1.grid(axis=grid_axis, color='white', linewidth=2)

# Empty title plus axis labels.
ax1.set(title='', xlabel='Epoch', ylabel='Accuracy')

# Optional fixed ranges:
# plt.xlim((-5, 5))
# plt.ylim((0.8, 1.0))
plt.savefig(figurepath)
##########################################

