import numpy as np
import argparse
import pandas as pd
import logging
import os
import glob
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator


# Repository root: two levels above the directory containing this script,
# expressed with a trailing '/' so callers can append sub-paths directly.
ROOT = f"{os.path.dirname(os.path.abspath(__file__))}/../../"
logger = logging.getLogger(__name__)
# Algorithm sub-paths searched for under the deploy log folder, in table order.
ALGORITHMS = [
    'Continuous/PPO',
    'Discrete/PPO',
    'Continuous/TD3',
    'Discrete/DQN',
]

def tf_deploy_stats(base_folder="Default"):
    """Aggregate deployment TensorBoard logs into a CSV file and a LaTeX table.

    For every entry of ``ALGORITHMS`` this recursively searches
    ``ROOT/tensorboard/Deploy/<base_folder>`` for paths ending in that
    algorithm (e.g. ``.../Continuous/PPO``). Every sub-directory of a match is
    treated as one deployment run. For each run, the mean of the
    ``benchmark_deploy/env_reward`` scalar and of one safety scalar is
    computed. The mean and standard deviation of those per-run means are then
    written via ``write_to_csv`` / ``write_tex_block`` into
    ``ROOT/data/Deploy/<base_folder>/``.

    The safety scalar depends on the path. Paths below the search root that
    contain ``'Baseline'`` use ``benchmark_deploy/is_safety_violation``
    (columns 4-5). All other paths use ``benchmark_deploy/safety_activity``
    (columns 2-3). The unused safety column pair stays NaN.

    Args:
        base_folder: Sub-folder below ``tensorboard/Deploy`` (input) and
            ``data/Deploy`` (output).
    """
    tag_base = 'benchmark_deploy/'
    reward_tag = tag_base + 'env_reward'
    log_root = ROOT + 'tensorboard/Deploy/' + base_folder
    out_dir = ROOT + 'data/Deploy/' + base_folder + '/'
    f_csv = init_csv(out_dir)
    f_tex = init_tex(out_dir)
    try:
        for alg in ALGORITHMS:
            # glob order is filesystem-dependent; sort for reproducible rows.
            group = sorted(glob.glob(log_root + f'/**/{alg}', recursive=True))
            tuple_list = []
            method_list = []
            data_list = []
            for g in group:
                # Only inspect the part below log_root, so a 'Baseline' in
                # ROOT itself cannot mark every configuration as baseline.
                is_baseline = 'Baseline' in os.path.relpath(g, log_root)
                if is_baseline:
                    safety_tag = tag_base + 'is_safety_violation'
                else:
                    safety_tag = tag_base + 'safety_activity'

                # Keep only run sub-directories; stray files are skipped.
                run_dirs = sorted(
                    entry for entry in os.listdir(g)
                    if os.path.isdir(os.path.join(g, entry))
                )
                rewards, safety = _run_means(g, run_dirs, reward_tag, safety_tag)

                # Columns: reward mean/std, safety-activity mean/std,
                # safety-violation mean/std; the unused safety pair stays NaN.
                smooth_data = np.full([1, 6], np.nan)
                smooth_data[0, 0] = np.mean(rewards)
                smooth_data[0, 1] = np.std(rewards)
                col = 4 if is_baseline else 2
                smooth_data[0, col] = np.mean(safety)
                smooth_data[0, col + 1] = np.std(safety)

                # The last four path components are read as
                # <method>/<tuple>/<space>/<algo>.
                head, algo = os.path.split(g)
                head, space = os.path.split(head)
                head, tuple_name = os.path.split(head)
                _, method = os.path.split(head)
                write_to_csv(f_csv, algo, space, tuple_name, method, smooth_data)
                tuple_list.append(tuple_name)
                method_list.append(method)
                data_list.append(smooth_data)
            # One LaTeX block per algorithm, even if no runs were found.
            write_tex_block(f_tex, alg, tuple_list, method_list, data_list)
        write_tex_end(f_tex)
    finally:
        # write_tex_end already closes f_tex on success; close() is
        # idempotent, so this only matters when an exception aborts the loop.
        f_csv.close()
        f_tex.close()
    print("Done.")


def _run_means(group_dir, run_dirs, reward_tag, safety_tag):
    """Return two arrays: per-run mean of ``reward_tag`` and of ``safety_tag``."""
    rewards = np.zeros(len(run_dirs))
    safety = np.zeros(len(run_dirs))
    for i, run in enumerate(run_dirs):
        accumulator = EventAccumulator(os.path.join(group_dir, run)).Reload()
        rewards[i] = _scalar_mean(accumulator, reward_tag)
        safety[i] = _scalar_mean(accumulator, safety_tag)
    return rewards, safety


def _scalar_mean(accumulator, tag):
    """Return the mean of all scalar values logged under ``tag``; NaN if none."""
    events = accumulator.Scalars(tag)
    if not events:
        # Same logger object as the module-level ``logger``.
        logging.getLogger(__name__).warning("No scalar events for tag %s", tag)
        return np.nan
    return np.mean([event.value for event in events])


def init_csv(path):
    """Create ``path`` if needed, open ``deploy_stats.csv`` in it, and write the header.

    Args:
        path: Output directory (with or without a trailing separator).

    Returns:
        The open text-mode file handle; the caller is responsible for closing it.
    """
    header = 'configuration, mean_reward, std_reward, mean_safety_activity, std_safety_activity, mean_safety_violation, std_safety_violation \n'  # noqa: E501
    file = 'deploy_stats.csv'
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(path, exist_ok=True)
    # join() also works when ``path`` lacks a trailing separator.
    file_path = os.path.join(path, file)
    print("Writing to file " + file_path + " ...")
    f_csv = open(file_path, "w")
    f_csv.write(header)
    return f_csv


def write_to_csv(
    f,
    algo,
    space,
    tuple,
    method,
    data
):
    """Write one CSV row describing a single configuration.

    The row has three parts:

    - A label ``<algo>_<space>_<tuple>_<method>``.
    - The six values ``data[0, 0]`` through ``data[0, 5]``, each preceded by
      ``", "``.
    - A trailing space and a newline.
    """
    values = ", ".join(f"{data[0, col]}" for col in range(6))
    f.write(f"{algo}_{space}_{tuple}_{method}, {values} \n")


def init_tex(path):
    """Create ``path`` if needed, open ``deploy_stats_table.tex`` and write the table preamble.

    The preamble opens the ``table``/``tabularx`` environments and writes the
    two header rows (Reward, Safety Activity, Safety Violation; each with
    Mean and Std. Dev.).

    Args:
        path: Output directory (with or without a trailing separator).

    Returns:
        The open text-mode file handle. In this module it is closed by
        ``write_tex_end``.
    """
    file = 'deploy_stats_table.tex'
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(path, exist_ok=True)
    # join() also works when ``path`` lacks a trailing separator.
    file_path = os.path.join(path, file)
    print("Writing to file " + file_path + " ...")
    f_tex = open(file_path, "w")
    begin_str = "\
\\begin{table}[t] \n \
    \\caption{Mean and standard deviation of N deployments.} \n \
    \\label{tab:deployment} \n \
    \\vskip 0.15in \n \
    \\begin{center} \n \
    \\begin{small} \n \
    \\begin{sc} \n \
    \\begin{tabularx}{\\textwidth}{l *{2}{Y}Z*{2}{Y}Z*{2}{Y}} %l *{8}{Y} \n \
        \\toprule \n \
        \\multirow{2}[2]{*}{ \\normalsize{\\textbf{Approach}}} \n \
        & \\multicolumn{2}{c}{ \\normalsize{\\textbf{Reward}}} \n \
        && \\multicolumn{2}{c}{ \\normalsize{\\textbf{Safety Activity}}} \n \
        && \\multicolumn{2}{c}{ \\normalsize{\\textbf{Safety Violation}}}\\\\ \n \
        \\cmidrule(lr){2-3} \\cmidrule(lr){5-6} \\cmidrule(lr){8-9} \n \
        & Mean & Std. Dev. && Mean & Std. Dev. && Mean & Std. Dev. \\\\ \n"
    f_tex.write(begin_str)
    return f_tex


# LaTeX special characters that can appear in directory-derived labels
# (method / tuple names) and would otherwise break table compilation.
_TEX_ESCAPES = str.maketrans({'_': r'\_', '&': r'\&', '%': r'\%', '#': r'\#'})


def write_tex_block(
    f_tex,
    alg,
    tuple_list,
    method_list,
    data_list
):
    """Write one algorithm section of the LaTeX results table.

    The section is a bold header row between two ``\\midrule`` lines. It is
    followed by one row per configuration, of the form
    ``<method> (<tuple>) & mean & std && mean & std && mean & std``.
    The reward and safety-activity columns use 2 decimals; the
    safety-violation columns use 3.

    Args:
        f_tex: Writable text file handle.
        alg: Algorithm key such as ``'Continuous/PPO'``. Unknown keys are
            shown verbatim (escaped) instead of raising ``KeyError``.
        tuple_list: Tuple names, parallel to ``method_list`` and ``data_list``.
        method_list: Method names for each row.
        data_list: Arrays indexed as ``[0, 0..5]`` holding the six statistics.
    """
    tex_headers = {
        'Continuous/PPO': "PPO (continuous)",
        'Discrete/PPO': "PPO (discrete)",
        'Continuous/TD3': "TD3",
        'Discrete/DQN': "DQN"
    }
    header = tex_headers.get(alg, str(alg).translate(_TEX_ESCAPES))
    parts = [
        "        \\midrule \n",
        f"         \\textbf{{{header}}} &&\\\\ \n",
        "         \\midrule \n",
    ]
    for tuple_name, method, data in zip(tuple_list, method_list, data_list):
        label = (
            f"{str(method).translate(_TEX_ESCAPES)} "
            f"({str(tuple_name).translate(_TEX_ESCAPES)})"
        )
        parts.append(
            f"        {label} & {data[0, 0]:.2f} & {data[0, 1]:.2f} && "
            f"{data[0, 2]:.2f} & {data[0, 3]:.2f} && "
            f"{data[0, 4]:.3f} & {data[0, 5]:.3f} \\\\ \n"
        )
    f_tex.write("".join(parts))


def write_tex_end(f_tex):
    """Finish the LaTeX table, close ``f_tex``, and report completion.

    Writes ``\\bottomrule`` and then closes the ``tabularx``, ``sc``,
    ``small``, ``center`` and ``table`` environments opened by ``init_tex``.
    """
    closing_lines = (
        "        \\bottomrule \n",
        "         \\end{tabularx} \n",
        "     \\end{sc} \n",
        "     \\end{small} \n",
        "     \\end{center} \n",
        "     \\vskip -0.1in \n",
        " \\end{table} \n",
    )
    f_tex.write("".join(closing_lines))
    f_tex.close()
    print("Wrote tex file.")


def init_argparse() -> argparse.ArgumentParser:
    """Initialize the argument parser.

    Returns:
        A parser with one optional positional argument ``type``: the
        deployment sub-folder name, defaulting to ``"Default"``.
    """
    # No custom usage string: argparse generates one that matches the
    # actual arguments.
    parser = argparse.ArgumentParser(
        description="Create deployment statistics.",
    )

    parser.add_argument(
        "type",
        # Without nargs="?", a positional is required and `default` is ignored.
        nargs="?",
        help="Type of the environment.",
        type=str,
        default="Default"
    )

    return parser


if __name__ == "__main__":
    parser = init_argparse()
    args = parser.parse_args()
    tf_deploy_stats(args.type)
