import os
import numpy
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from matplotlib.pyplot import MultipleLocator
import argparse
import re
import numpy as np
import pandas as pd
import pdb


def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what a
            plain ``parse_args()`` call has always done.

    Returns:
        argparse.Namespace with the attributes ``env``, ``epochs``,
        ``steps_per_epoch``, ``filename``, ``target_keys``, ``dir_names``,
        ``exp_names``, ``save_dirs`` and ``base_save_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Walker2d-v3')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--steps-per-epoch', type=int, default=4000)
    parser.add_argument('--filename', type=str, default='progress.txt')
    # nargs='+' keeps target_keys a list whether or not the flag is given.
    # Without it, `-k LossQ` produced the bare string 'LossQ', which the
    # consumers of this namespace iterate over and concatenate to a list.
    parser.add_argument('-k', '--target-keys', type=str, nargs='+',
                        default=['AverageTestEpRet', 'LossQ', 'LossPi'])
    parser.add_argument('--dir-names', nargs='+', default=[])
    parser.add_argument('--exp-names', nargs='+', default=[])
    parser.add_argument('--save-dirs', nargs='+', default=[])
    parser.add_argument('--base-save-dir', type=str, default='results')
    return parser.parse_args(argv)


def extract_data(args, dir_name):
    """Read the target metrics of one epoch from a tab-separated log file.

    The file ``os.path.join(dir_name, args.filename)`` is expected to hold a
    header line with column names followed by one line per epoch.

    Args:
        args: Object providing ``filename`` (log file name), ``target_keys``
            (column names to extract) and ``epochs`` (1-based index of the
            data row to read).
        dir_name: Directory containing the log file.

    Returns:
        A float64 ``numpy.ndarray`` with one value per entry of
        ``args.target_keys``, taken from data row ``args.epochs``. Empty
        fields are read as 0. An empty array is returned when the file has
        no content or has fewer than ``args.epochs`` data rows.

    Raises:
        KeyError: If a target key is not present in the header.
        ValueError: If the selected row is too short to contain a target
            column, or a target field is not numeric.
        OSError: If the log file cannot be opened.
    """
    target_path = os.path.join(dir_name, args.filename)
    with open(target_path, 'r') as f:
        # Only strip the line terminator before splitting: a bare strip()
        # would also eat leading/trailing tabs and thereby drop empty first
        # or last fields, shifting or shortening the row.
        rows = [[field.strip() for field in line.rstrip('\r\n').split('\t')]
                for line in f]
    # Ignore blank lines (e.g. a trailing newline at the end of the file).
    rows = [row for row in rows if any(row)]

    empty = np.array([], dtype='float64')
    if not rows:
        return empty

    keys, body = rows[0], rows[1:]
    column_indices = []
    for target_key in args.target_keys:
        if target_key not in keys:
            raise KeyError('column {!r} not found in {}'.format(target_key, target_path))
        # index() returns the first match, like the original lookup did.
        column_indices.append(keys.index(target_key))

    if len(body) < args.epochs:
        # The run has not logged the requested epoch.
        return empty

    row = body[args.epochs - 1]
    target_values = []
    for target_key, column_index in zip(args.target_keys, column_indices):
        if column_index >= len(row):
            raise ValueError('row {} of {} has no field for column {!r}'.format(
                args.epochs, target_path, target_key))
        field = row[column_index]
        # Empty fields count as 0.
        target_values.append(float(field) if field else 0.0)
    return np.array(target_values, dtype='float64')


def get_dirs(dir_name):
    """Find the run directories two levels below ``dir_name``.

    Each second-level directory name is decoded into numbers: ``True`` and
    ``False`` are first rewritten to ``1`` and ``0``, then every numeric
    token in the name is collected. The alphabetic tokens of the first
    directory encountered provide the key names (underscores removed, a
    lone ``e`` discarded so that exponent markers are not treated as keys).

    Args:
        dir_name: Root directory to scan.

    Returns:
        A tuple ``(run_dirs, keys, values)`` where ``run_dirs`` is the list
        of second-level directory paths, ``keys`` is the list of key names
        (``None`` if no second-level directory exists) and ``values`` is a
        list of float arrays, one per entry of ``run_dirs``.
    """
    key_pattern = r'[A-Za-z_]+'
    number_pattern = r'[0-9]+\.*[0-9]*e*[+-]*[0-9]*'

    run_dirs = []
    values = []
    keys = None
    for group_name in os.listdir(dir_name):
        group_path = os.path.join(dir_name, group_name)
        if not os.path.isdir(group_path):
            continue
        for run_name in os.listdir(group_path):
            run_path = os.path.join(group_path, run_name)
            if not os.path.isdir(run_path):
                continue
            run_dirs.append(run_path)
            # Booleans in the name are turned into digits so that they are
            # picked up by the numeric pattern below.
            encoded = run_name.replace('True', '1').replace('False', '0')
            if keys is None:
                keys = [token.replace('_', '')
                        for token in re.findall(key_pattern, encoded)
                        if token != 'e']
            numbers = re.findall(number_pattern, encoded)
            values.append(np.array(numbers, dtype='float'))
    return run_dirs, keys, values


def _build_results_table(args, dir_name):
    """Return a DataFrame with one row per finished run under ``dir_name``, or None if there are no runs."""
    target_dirs, keys, values = get_dirs(dir_name)
    if keys is None:
        return None
    total_keys = keys + list(args.target_keys)
    rows = []
    for target_dir, value in zip(target_dirs, values):
        exp_results = extract_data(args, target_dir)
        # An empty result means the run has no data for the requested epoch.
        if len(exp_results) > 0:
            rows.append(list(np.concatenate((value, exp_results))))
    table = pd.DataFrame(rows, columns=total_keys)
    # Sort only by the columns that exist: directory names are not
    # guaranteed to encode both 'env' and 's'.
    sort_columns = [column for column in ('env', 's') if column in table.columns]
    if sort_columns:
        table = table.sort_values(by=sort_columns, ascending=True)
    return table


def _print_summary(base_save_dir):
    """Print 'mean±std' of AverageTestEpRet per 'env' value for every CSV file in ``base_save_dir``."""
    for file_name in os.listdir(base_save_dir):
        file_path = os.path.join(base_save_dir, file_name)
        if os.path.isdir(file_path) or not file_name.endswith('.csv'):
            continue
        data = pd.read_csv(file_path)
        # Files lacking the grouping or the metric column cannot be summarised.
        if 'env' not in data.columns or 'AverageTestEpRet' not in data.columns:
            continue
        if 'noise' in data.columns:
            data = data[data['noise'] == 0]
        print(file_name)
        for env in np.sort(np.unique(np.array(data['env']))):
            env_data = np.array(data[data['env'] == env]['AverageTestEpRet'])
            print('{:.1f}±{:.1f}'.format(np.mean(env_data), np.std(env_data)), end=' & ')
        print('')


def collect(args=None):
    """Aggregate run logs into CSV files and print a per-environment summary.

    For every experiment directory a CSV named ``<basename>.csv`` is written
    to its save directory. Afterwards every CSV file found directly in
    ``args.base_save_dir`` is summarised on stdout.

    Args:
        args: Namespace as produced by ``parse_args()``. When ``None`` the
            command line is parsed at call time. The attributes read here
            are ``dir_names``, ``save_dirs`` and ``base_save_dir``;
            ``target_keys``, ``filename`` and ``epochs`` are used while
            reading the logs. ``exp_names`` is accepted but not used.

    Returns:
        None.
    """
    # Parse lazily: a default of ``parse_args()`` in the signature would be
    # evaluated once at import time and read sys.argv of whatever process
    # happens to import this module.
    if args is None:
        args = parse_args()

    # Work on copies so that the caller's namespace is not modified.
    dir_names = list(args.dir_names)
    save_dirs = list(args.save_dirs)
    if not dir_names:
        # No explicit experiments: use every non-file entry of the base dir.
        for entry in os.listdir(args.base_save_dir):
            entry_path = os.path.join(args.base_save_dir, entry)
            if os.path.isfile(entry_path):
                continue
            dir_names.append(entry_path)
    # Experiments without an explicit save directory fall back to the base
    # directory instead of being silently dropped by zip().
    missing = len(dir_names) - len(save_dirs)
    if missing > 0:
        save_dirs.extend([args.base_save_dir] * missing)
    for save_dir in save_dirs:
        os.makedirs(save_dir, exist_ok=True)

    for dir_name, save_dir in zip(dir_names, save_dirs):
        table = _build_results_table(args, dir_name)
        if table is None:
            continue
        table.to_csv(os.path.join(save_dir, os.path.basename(dir_name) + '.csv'), index=False)

    _print_summary(args.base_save_dir)


if __name__ == '__main__':
    # Script entry point: parse the command line, then run the collection.
    cli_args = parse_args()
    collect(cli_args)
