"""
Usage:
Step 1: python summary_graphs_wandb.py --url username/project-name/
"""

import argparse
import wandb
import numpy as np
import seaborn as sns
import math
import matplotlib.pyplot as plt
import re
import pickle
from scipy import signal


def transform_title(env_name, seperator=" "):
    """Turn an environment id into a human-readable plot title.

    Everything from the first '-' on (the version suffix) is dropped.
    Snake_case parts are joined in CamelCase, and the CamelCase words are then
    joined with ``seperator``.

    Example: ``transform_title("FingerTurn_easy-v0")`` returns ``"Finger Turn Easy"``.

    :param env_name: environment id such as ``"CheetahRun-v0"``.
    :param seperator: string placed between words (e.g. a newline for stacked
        tick labels). The misspelled name is kept for backward compatibility.
    :return: the title string.
    """
    env_name = env_name.split('-')[0]
    if '_' in env_name:
        # Upper-case only the first letter of each part. str.capitalize() would also
        # lower-case the rest ("FingerTurn" -> "Fingerturn") and lose word boundaries.
        env_name = "".join(part[:1].upper() + part[1:] for part in env_name.split('_'))
    # One match per capitalised word, or per run of capitals forming an acronym.
    return seperator.join(
        re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', env_name))


if __name__ == '__main__':
    # NOTE: ArgumentParser's first positional argument is `prog` (the program name shown
    # in the usage line), so the description has to be passed by keyword.
    parser = argparse.ArgumentParser(
        description='Extract episode-reward curves from Weights & Biases runs and plot summary figures.')
    parser.add_argument('--url', type=str, default="",
                        help='wandb project path, e.g. username/project-name/')
    parser.add_argument('--cache', type=str, default='data.p',
                        help='pickle file the extracted run data is written to / read from')
    parser.add_argument('--from-cache', action='store_true',
                        help='skip the wandb API and load previously extracted data from --cache')
    parser.add_argument('--max-points', type=int, default=80,
                        help='maximum number of logged points kept per run')

    args = parser.parse_args()

    # data[env_name][explore_mode][auto_entropy_tuning][run_id][metric_key]
    #   -> {'x': 'global_step' values, 'y': metric values, 'runtime': '_runtime' values}
    metric_key = 'test/no_search/episode_reward'
    if args.from_cache:
        # NOTE: unpickling can execute arbitrary code; only load cache files you created yourself.
        with open(args.cache, 'rb') as cache_file:
            data = pickle.load(cache_file)
    else:
        api = wandb.Api()
        data = {}
        for run in api.runs(args.url):
            env_name = run.config['env_name']
            explore_mode = run.config['explore_mode']
            auto_entropy_tuning = run.config['automatic_entropy_tuning']

            history = run.history(keys=[metric_key, "global_step", "_runtime"], samples=1000)
            try:
                columns = history.columns
                run_data = {'x': history.values[:, columns.get_loc('global_step')][:args.max_points],
                            'y': history.values[:, columns.get_loc(metric_key)][:args.max_points],
                            'runtime': history.values[:, columns.get_loc('_runtime')][:args.max_points]}
            except Exception as e:
                # Best effort: skip runs that lack one of the columns, but report which run.
                print(f"Skipping run {run.id}: {e!r}")
                continue

            # Register the run only after extraction succeeded. Otherwise an empty group
            # would be left behind, and the min() over its runs fails when plotting.
            tune_runs = (data.setdefault(env_name, {})
                             .setdefault(explore_mode, {})
                             .setdefault(auto_entropy_tuning, {}))
            tune_runs[run.id] = {metric_key: run_data}

        # Cache the extraction so later runs can re-plot with --from-cache.
        with open(args.cache, 'wb') as cache_file:
            pickle.dump(data, cache_file)

    # Environments shown in the per-task grid and the grouped bar chart. The list is
    # sorted, and the grid is filled row by row in this order.
    group_8_envs = sorted([
        "AcrobotSwingup-v0",
        "Ball_in_cupCatch-v0",
        "CartpoleBalance-v0", "CartpoleBalance_sparse-v0",
        "CartpoleSwingup-v0", "CartpoleSwingup_sparse-v0",
        "CheetahRun-v0",
        "FingerSpin-v0", "FingerTurn_easy-v0", "FingerTurn_hard-v0",
        # "FishUpright-v0", "FishSwim-v0",
        "HopperHop-v0", "HopperStand-v0",
        "PendulumSwingup-v0",
        "QuadrupedRun-v0", "QuadrupedWalk-v0",
        "ReacherEasy-v0", "ReacherHard-v0",
        "WalkerRun-v0", "WalkerStand-v0", "WalkerWalk-v0",
    ])

    # Plot Data
    # matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and 3.8 removed
    # the old names, so use whichever spelling this installation provides.
    grid_style = ('seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available
                  else 'seaborn-whitegrid')
    plt.style.use(grid_style)
    n_cols = 4
    n_rows = math.ceil(len(group_8_envs) / n_cols)
    # squeeze=False keeps the axes array 2-D even when all tasks fit into one row.
    grid_8_plot, grid_8_plot_axes = plt.subplots(n_rows, n_cols, figsize=(8, min(10, 2 * n_rows)),
                                                 squeeze=False)

    # set width of bar (used by the grouped bar chart)
    barWidth = 0.22

    # Final-step mean / std score, keyed by (explore_mode, label) and then by env_name.
    bars = {'mean': {}, 'std': {}}
    # Per (explore_mode, label): the list of per-task mean curves ('y' / 'runtime') and
    # the longest step axis seen so far ('x' / 'steps'), for the cross-task plots.
    across_task_mean = {'x': {}, 'y': {}}
    across_task_runtime_mean = {'runtime': {}, 'steps': {}}

    # Note: Refer here for color names : https://www.w3schools.com/cssref/css_colors.asp
    candidate_colors = ['coral', 'orchid', 'springgreen', 'turquoise', 'thistle',
                        'yellow', 'chartreuse', 'darkcyan', 'darkmagenta']

    legend_colors = {}
    metric = "test/no_search/episode_reward"
    for env_i, env_name in enumerate(group_8_envs):
        row_i, col_i = divmod(env_i, n_cols)
        ax = grid_8_plot_axes[row_i, col_i]

        # Environments without any fetched runs get an empty (but titled) panel.
        for explore_mode, explore_data in data.get(env_name, {}).items():
            if explore_mode == 'mcts+fixed':
                # 'mcts+fixed' runs are excluded from these per-task panels.
                continue
            for tune_data in explore_data.values():
                run_metrics = [run_data[metric] for run_data in tune_data.values()]
                if not run_metrics:
                    continue
                # Truncate every seed to the shortest one so they can be averaged step-wise.
                max_step_i = min(len(m["x"]) for m in run_metrics)
                if max_step_i == 0:
                    continue
                x_data = run_metrics[0]["x"][:max_step_i]
                score_data = np.array([m["y"][:max_step_i] for m in run_metrics])
                score_mean = score_data.mean(axis=0)
                score_std = score_data.std(axis=0)
                runtime_mean = np.array([m["runtime"][:max_step_i] for m in run_metrics]).mean(axis=0)

                if explore_mode not in legend_colors:
                    legend_colors[explore_mode] = candidate_colors.pop(0)
                color = legend_colors[explore_mode]
                _label = "dreamer" if explore_mode == 'no-search' else "dreamer + " + explore_mode

                # Savitzky-Golay smoothing (window 9, cubic) needs at least 9 points.
                _y = score_mean if len(score_mean) < 9 else signal.savgol_filter(score_mean, 9, 3)
                sns.lineplot(x=x_data, y=_y, color=color, ax=ax, label=_label)
                # Shade +/- one standard deviation across seeds around the unsmoothed mean.
                ax.fill_between(x_data, score_mean + score_std, score_mean - score_std,
                                facecolor=color, alpha=0.5)

                key = (explore_mode, _label)
                bars['mean'].setdefault(key, {})[env_name] = score_mean[-1]
                bars['std'].setdefault(key, {})[env_name] = score_std[-1]
                across_task_mean['y'].setdefault(key, []).append(score_mean)
                across_task_runtime_mean['runtime'].setdefault(key, []).append(runtime_mean)
                # Keep the longest step axis seen for this mode; cross-task curves use it.
                if key not in across_task_mean['x'] or len(x_data) >= len(across_task_mean['x'][key]):
                    across_task_mean['x'][key] = x_data
                    across_task_runtime_mean['steps'][key] = x_data

        ax.set_title(transform_title(env_name), y=1.0, fontsize=10)
        ax.margins(x=0.01, y=0.01)
        if env_i == len(group_8_envs) - 1:
            # One shared legend, drawn on the last task's panel.
            ax.legend(loc='lower right', prop={'size': 8})
        elif ax.get_legend() is not None:
            # Panels without plotted data have no legend to remove.
            ax.get_legend().remove()

    # Hide the spare panels when the number of tasks is not a multiple of n_cols.
    for spare_i in range(len(group_8_envs), n_rows * n_cols):
        grid_8_plot_axes[divmod(spare_i, n_cols)].set_axis_off()

    grid_8_plot.tight_layout(pad=0.01)
    # NOTE(review): the file name says "st_error" but the shaded band is +/- one std.
    grid_8_plot.savefig('summary_2million_st_error.png')

    # Grouped bar chart: final-step mean score (error bar = std across seeds) for each
    # environment, with one bar per explore mode in legend_order.
    group_bar_plot, group_bar_plot_axes = plt.subplots(1, 1, figsize=(8, 2.5), )
    legend_order = [('no-search', 'dreamer'),
                    ('rollout', 'dreamer + rollout'),
                    ('mcts', 'dreamer + mcts')]
    # One group per environment. Bars are always len(group_8_envs) long, so the
    # positions must be too (not the env count of whichever mode covers the most envs).
    group_positions = np.arange(len(group_8_envs))
    n_plotted = 0
    for label_i, key in enumerate(legend_order):
        explore_mode, _label = key
        if key not in bars['mean']:
            # No runs for this mode. Its slot is kept empty so the other bars stay aligned.
            continue
        # Environments without runs for this mode are drawn as zero-height bars.
        bar_values = [bars['mean'][key].get(env, 0) for env in group_8_envs]
        bar_errs = [bars['std'][key].get(env, 0) for env in group_8_envs]
        group_bar_plot_axes.bar(group_positions + label_i * barWidth, bar_values, yerr=bar_errs, capsize=2,
                                ecolor='black',
                                color=legend_colors[explore_mode], width=barWidth,
                                edgecolor='white', label=_label)
        n_plotted += 1

    # Add xticks on the middle of the group bars
    group_bar_plot_axes.set_xticks(group_positions + barWidth * (len(legend_order) - 1) / 2)
    group_bar_plot_axes.set_xticklabels(
        ['Ball In \n CupCatch' if "cup" in env else transform_title(env, "\n") for env in group_8_envs],
        ha='center',
        rotation=90)

    # Create legend & Show graphic
    group_bar_plot_axes.legend(loc='center', bbox_to_anchor=(0.5, 1.1), ncol=max(1, n_plotted))
    group_bar_plot.tight_layout(pad=0.01)
    group_bar_plot.savefig('group_bar.png')

    grid_8_plot.tight_layout(pad=0.01)
    grid_8_plot.savefig('summary.png')

    # expected performance & expected time plots across tasks.
    # Each explore mode's per-task mean curves are averaged into one curve, which is then
    # plotted as score vs. env. steps, hours vs. env. steps, and score vs. hours.
    across_tasks_score_plot, across_tasks_score_plot_axes = plt.subplots(1, 1, figsize=(3, 2.5), )
    across_tasks_time_plot, across_tasks_time_plot_axes = plt.subplots(1, 1, figsize=(3, 2.5), )
    across_tasks_score_time_plot, across_tasks_score_time_plot_axes = plt.subplots(1, 1, figsize=(3, 2.5), )
    seconds_per_hour = 60 * 60
    for mode_key in legend_order:
        task_curves = across_task_mean['y'].get(mode_key)
        if task_curves is None:
            continue
        explore_mode, _label = mode_key
        mode_color = legend_colors[explore_mode]
        task_runtimes = across_task_runtime_mean['runtime'][mode_key]

        # Curves can differ in length; each step is averaged over the tasks that reach it.
        longest = max(len(curve) for curve in task_curves)
        perf_total = np.zeros(longest)
        time_total = np.zeros(longest)
        n_tasks = np.zeros(longest)
        for task_i, curve in enumerate(task_curves):
            reach = len(curve)
            perf_total[:reach] += np.asarray(curve, dtype=float)
            time_total[:reach] += np.asarray(task_runtimes[task_i], dtype=float)
            n_tasks[:reach] += 1
        mean_perf = perf_total / n_tasks
        mean_hours = time_total / n_tasks / seconds_per_hour

        # The spread across tasks is only defined where every task has data.
        shortest = min(len(curve) for curve in task_curves)
        perf_spread = np.array([curve[:shortest] for curve in task_curves]).std(axis=0)
        steps = across_task_mean['x'][mode_key]

        across_tasks_score_plot_axes.plot(steps, mean_perf, color=mode_color, label=_label)
        across_tasks_score_plot_axes.fill_between(steps[:shortest],
                                                  mean_perf[:shortest] + perf_spread,
                                                  mean_perf[:shortest] - perf_spread,
                                                  facecolor=mode_color, alpha=0.5)
        across_tasks_time_plot_axes.plot(across_task_runtime_mean['steps'][mode_key], mean_hours,
                                         color=mode_color, label=_label)
        across_tasks_score_time_plot_axes.plot(mean_hours, mean_perf, color=mode_color, label=_label)

    across_tasks_score_plot_axes.legend()
    axis_labels = (
        (across_tasks_score_plot_axes, 'episode reward', 'env. steps'),
        (across_tasks_time_plot_axes, 'time(hours)', 'env. steps'),
        (across_tasks_score_time_plot_axes, 'episode rewards', 'time(hours)'),
    )
    for axes, y_text, x_text in axis_labels:
        axes.set_ylabel(y_text)
        axes.set_xlabel(x_text)

    output_figures = (
        (across_tasks_score_plot, 'across_tasks_score_with_std.png'),
        (across_tasks_time_plot, 'across_tasks_time.png'),
        (across_tasks_score_time_plot, 'across_tasks_score_time.png'),
    )
    for figure, _ in output_figures:
        figure.tight_layout(pad=1)
    for figure, file_name in output_figures:
        figure.savefig(file_name)


    # Mean of mcts fixed and progressive, averaged over a subset of tasks.
    tree_envs = ["AcrobotSwingup-v0",
                 "CheetahRun-v0",
                 "ReacherHard-v0",
                 "QuadrupedRun-v0",
                 "FingerSpin-v0",
                 "HopperHop-v0",
                 "WalkerStand-v0",
                 "WalkerWalk-v0"]

    tree_comp_tasks_score_plot, tree_comp_tasks_score_plot_axes = plt.subplots(1, 1, figsize=(3, 2.5), )
    # Per variant: the longest step axis seen ('x') and the list of per-task mean curves ('y').
    tree_mean_scores = {'mcts+fixed': {'x': [], 'y': []}, 'mcts': {'x': [], 'y': []}}

    for env_name in tree_envs:
        for explore_mode, explore_data in data.get(env_name, {}).items():
            # Collect only the variants being compared. Any other mode, including other
            # 'mcts*' variants, has no slot in tree_mean_scores.
            if explore_mode not in tree_mean_scores:
                continue
            for tune_data in explore_data.values():
                run_metrics = [run_data["test/no_search/episode_reward"] for run_data in tune_data.values()]
                if not run_metrics:
                    continue
                # Truncate every seed to the shortest one so they can be averaged step-wise.
                max_step_i = min(len(m["x"]) for m in run_metrics)
                x_data = run_metrics[0]["x"][:max_step_i]
                score_mean = np.array([m["y"][:max_step_i] for m in run_metrics]).mean(axis=0)
                tree_mean_scores[explore_mode]['y'].append(score_mean)
                if len(x_data) >= len(tree_mean_scores[explore_mode]['x']):
                    tree_mean_scores[explore_mode]['x'] = x_data

    for _key, _scores in tree_mean_scores.items():
        mean_values = _scores['y']
        if not mean_values:
            # No runs of this variant on any selected task; nothing to plot.
            continue

        # Average each step over the tasks whose curve reaches it.
        longest = max(len(v) for v in mean_values)
        perf_sum = np.zeros(longest)
        _valids = np.zeros(longest)
        for v in mean_values:
            perf_sum[:len(v)] += np.asarray(v, dtype=float)
            _valids[:len(v)] += 1
        _mean_perf = perf_sum / _valids

        # Pad shorter curves with the tail of the longest curve, so that a spread exists
        # at every step of the longest curve.
        _, max_i = max((len(x), i) for i, x in enumerate(mean_values))
        padded = [list(x) + list(mean_values[max_i][len(x):]) for x in mean_values]
        _mean_perf_std = np.array(padded, dtype=float).std(axis=0)

        _color = legend_colors[_key] if _key in legend_colors else candidate_colors.pop(0)
        _label = 'mcts:fixed' if 'fixed' in _key else 'mcts:progressive'
        tree_comp_tasks_score_plot_axes.plot(_scores['x'], _mean_perf, color=_color, label=_label)
        tree_comp_tasks_score_plot_axes.fill_between(_scores['x'], _mean_perf + _mean_perf_std,
                                                     _mean_perf - _mean_perf_std,
                                                     facecolor=_color, alpha=0.5)
    tree_comp_tasks_score_plot_axes.legend()
    tree_comp_tasks_score_plot_axes.set_ylabel('episode reward')
    tree_comp_tasks_score_plot_axes.set_xlabel('env. steps')
    tree_comp_tasks_score_plot.tight_layout(pad=1)

    tree_comp_tasks_score_plot.savefig('tree-search-summary' + '.png')







