from collections import OrderedDict

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
import numpy as np
from train.stats import Stats
from matplotlib.ticker import ScalarFormatter


# Hard-coded per-prefix frequencies for one agent (plotted under the
# "RL fine tuning" label together with agent2).
# Keys are ">" (the root, fixed at 1.0) followed by every prefix of the
# consensus string "SSSSSSSPPPPPLVNLAAAAYLQQQFKKKED", in increasing length;
# the values never increase as the prefix grows.
# plot_consensus_likelihood averages this table with agent2, and uses the last
# character of each key here as the x-axis tick label.
# NOTE(review): the run that produced these numbers is not recorded here —
# confirm provenance before reusing them.
agent1 = {
    ">": 1.0,
    "S": 0.589,
    "SS": 0.562,
    "SSS": 0.562,
    "SSSS": 0.548,
    "SSSSS": 0.521,
    "SSSSSS": 0.479,
    "SSSSSSS": 0.452,
    "SSSSSSSP": 0.452,
    "SSSSSSSPP": 0.452,
    "SSSSSSSPPP": 0.452,
    "SSSSSSSPPPP": 0.452,
    "SSSSSSSPPPPP": 0.452,
    "SSSSSSSPPPPPL": 0.452,
    "SSSSSSSPPPPPLV": 0.452,
    "SSSSSSSPPPPPLVN": 0.438,
    "SSSSSSSPPPPPLVNL": 0.438,
    "SSSSSSSPPPPPLVNLA": 0.397,
    "SSSSSSSPPPPPLVNLAA": 0.397,
    "SSSSSSSPPPPPLVNLAAA": 0.397,
    "SSSSSSSPPPPPLVNLAAAA": 0.397,
    "SSSSSSSPPPPPLVNLAAAAY": 0.356,
    "SSSSSSSPPPPPLVNLAAAAYL": 0.356,
    "SSSSSSSPPPPPLVNLAAAAYLQ": 0.110,
    "SSSSSSSPPPPPLVNLAAAAYLQQ": 0.027,
    "SSSSSSSPPPPPLVNLAAAAYLQQQ": 0.014,
    "SSSSSSSPPPPPLVNLAAAAYLQQQF": 0.014,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFK": 0.014,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKK": 0.014,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKKK": 0.014,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKKKE": 0.009,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKKKED": 0.0012
}

# Hard-coded per-prefix frequencies for a second agent. It has the same key
# set and key order as agent1: ">" (root, 1.0) followed by every prefix of
# "SSSSSSSPPPPPLVNLAAAAYLQQQFKKKED". The values never increase as the
# prefix grows.
# plot_consensus_likelihood averages it with agent1 and draws the result as
# "RL fine tuning".
# NOTE(review): provenance of these numbers is not recorded here — confirm.
agent2 = {
    ">": 1.0,
    "S": 0.635,
    "SS": 0.635,
    "SSS": 0.625,
    "SSSS": 0.573,
    "SSSSS": 0.552,
    "SSSSSS": 0.521,
    "SSSSSSS": 0.479,
    "SSSSSSSP": 0.479,
    "SSSSSSSPP": 0.479,
    "SSSSSSSPPP": 0.479,
    "SSSSSSSPPPP": 0.479,
    "SSSSSSSPPPPP": 0.479,
    "SSSSSSSPPPPPL": 0.479,
    "SSSSSSSPPPPPLV": 0.479,
    "SSSSSSSPPPPPLVN": 0.479,
    "SSSSSSSPPPPPLVNL": 0.479,
    "SSSSSSSPPPPPLVNLA": 0.458,
    "SSSSSSSPPPPPLVNLAA": 0.458,
    "SSSSSSSPPPPPLVNLAAA": 0.448,
    "SSSSSSSPPPPPLVNLAAAA": 0.448,
    "SSSSSSSPPPPPLVNLAAAAY": 0.354,
    "SSSSSSSPPPPPLVNLAAAAYL": 0.354,
    "SSSSSSSPPPPPLVNLAAAAYLQ": 0.083,
    "SSSSSSSPPPPPLVNLAAAAYLQQ": 0.052,
    "SSSSSSSPPPPPLVNLAAAAYLQQQ": 0.031,
    "SSSSSSSPPPPPLVNLAAAAYLQQQF": 0.031,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFK": 0.01,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKK": 0.01,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKKK": 0.01,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKKKE": 0.007,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKKKED": 0.0009
}

# Per-prefix values for human demonstrations, drawn by
# plot_consensus_likelihood as "human demonstrations". Keys match agent1:
# ">" (root, 1.0) followed by every prefix of "SSSSSSSPPPPPLVNLAAAAYLQQQFKKKED".
# The values are all multiples of 1/117 (e.g. 0.99145... = 116/117), so they
# are presumably a fraction of 117 demonstrations — TODO confirm.
# NOTE(review): unlike agent1/agent2, these values are not monotone in prefix
# length (e.g. "SSSSSSS" 0.79 < "SSSSSSSP" 0.96). They may therefore measure
# something other than a cumulative prefix-match rate — confirm before
# comparing directly.
human_stats = {
    ">": 1.0,
    "S": 0.9914529914529915,
    "SS": 0.9914529914529915,
    "SSS": 0.9914529914529915,
    "SSSS": 0.9914529914529915,
    "SSSSS": 0.9829059829059829,
    "SSSSSS": 0.8803418803418803,
    "SSSSSSS": 0.7948717948717948,
    "SSSSSSSP": 0.9572649572649573,
    "SSSSSSSPP": 0.9572649572649573,
    "SSSSSSSPPP": 0.9572649572649573,
    "SSSSSSSPPPP": 0.9401709401709402,
    "SSSSSSSPPPPP": 0.9145299145299145,
    "SSSSSSSPPPPPL": 0.9572649572649573,
    "SSSSSSSPPPPPLV": 0.9572649572649573,
    "SSSSSSSPPPPPLVN": 0.9572649572649573,
    "SSSSSSSPPPPPLVNL": 0.9316239316239316,
    "SSSSSSSPPPPPLVNLA": 0.9487179487179487,
    "SSSSSSSPPPPPLVNLAA": 0.9487179487179487,
    "SSSSSSSPPPPPLVNLAAA": 0.9487179487179487,
    "SSSSSSSPPPPPLVNLAAAA": 0.9401709401709402,
    "SSSSSSSPPPPPLVNLAAAAY": 0.905982905982906,
    "SSSSSSSPPPPPLVNLAAAAYL": 0.7435897435897436,
    "SSSSSSSPPPPPLVNLAAAAYLQ": 0.8034188034188035,
    "SSSSSSSPPPPPLVNLAAAAYLQQ": 0.8034188034188035,
    "SSSSSSSPPPPPLVNLAAAAYLQQQ": 0.8034188034188035,
    "SSSSSSSPPPPPLVNLAAAAYLQQQF": 0.8376068376068376,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFK": 0.7435897435897436,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKK": 0.7435897435897436,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKKK": 0.7435897435897436,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKKKE": 0.7435897435897436,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKKKED": 0.5726495726495726
}

# Round-number per-prefix values, presumably a manual estimate of the human
# curve (TODO confirm). The keys match agent1: ">" (root, 1.0) followed by
# every prefix of "SSSSSSSPPPPPLVNLAAAAYLQQQFKKKED". The values never increase
# as the prefix grows.
# plot_consensus_likelihood does not currently draw this table.
human_estimate = {
    ">": 1.0,
    "S": 1.0,
    "SS": 1.0,
    "SSS": 1.0,
    "SSSS": 1.0,
    "SSSSS": 0.9,
    "SSSSSS": 0.88,
    "SSSSSSS": 0.85,
    "SSSSSSSP": 0.82,
    "SSSSSSSPP": 0.80,
    "SSSSSSSPPP": 0.79,
    "SSSSSSSPPPP": 0.78,
    "SSSSSSSPPPPP": 0.78,
    "SSSSSSSPPPPPL": 0.77,
    "SSSSSSSPPPPPLV": 0.75,
    "SSSSSSSPPPPPLVN": 0.75,
    "SSSSSSSPPPPPLVNL": 0.74,
    "SSSSSSSPPPPPLVNLA": 0.7,
    "SSSSSSSPPPPPLVNLAA": 0.7,
    "SSSSSSSPPPPPLVNLAAA": 0.7,
    "SSSSSSSPPPPPLVNLAAAA": 0.68,
    "SSSSSSSPPPPPLVNLAAAAY": 0.68,
    "SSSSSSSPPPPPLVNLAAAAYL": 0.68,
    "SSSSSSSPPPPPLVNLAAAAYLQ": 0.67,
    "SSSSSSSPPPPPLVNLAAAAYLQQ": 0.67,
    "SSSSSSSPPPPPLVNLAAAAYLQQQ": 0.66,
    "SSSSSSSPPPPPLVNLAAAAYLQQQF": 0.66,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFK": 0.64,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKK": 0.64,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKKK": 0.64,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKKKE": 0.63,
    "SSSSSSSPPPPPLVNLAAAAYLQQQFKKKED": 0.6
}


def convert_to_plot(summary_stats, consensus='SSSSSSSPPPPPLVNLAAAAYLQQQFKKKED'):
    """Project ``summary_stats`` onto the fixed consensus-prefix axis used for plotting.

    The axis is ``'>'`` (the root) followed by every non-empty prefix of
    ``consensus``, in increasing length. That order matches the key order of the
    hard-coded tables in this module, so the resulting values line up point for
    point with theirs.

    Args:
        summary_stats: Mapping from prefix string to value. Entries whose key is
            not on the axis are ignored.
        consensus: Full consensus string whose prefixes make up the axis.
            Defaults to the consensus this script was written for.

    Returns:
        An ``OrderedDict`` over the axis. ``'>'`` defaults to 1.0 and every
        prefix defaults to 0.0; any of these is overwritten by the value from
        ``summary_stats`` when that key is present there.
    """
    axis = OrderedDict()
    axis['>'] = 1.0
    # Prefixes of length 1..len(consensus); missing prefixes plot as 0.0.
    for end in range(1, len(consensus) + 1):
        axis[consensus[:end]] = 0.0

    # Only overwrite existing slots so the axis length/order never changes.
    for key, value in summary_stats.items():
        if key in axis:
            axis[key] = value
    return axis


def plot_consensus_likelihood():
    """Plot per-prefix consensus frequencies and save the figure as EPS.

    One point is drawn per consensus prefix; the x tick labels are the last
    character of each ``agent1`` key, so the root is ``'>'``. The y-axis is
    symlog. The series drawn are:

    * "RL fine tuning": the key-wise mean of the ``agent1`` and ``agent2`` tables;
    * "BC": ensemble stats loaded via ``Stats.calculate_ensemble_stats`` from
      ``tmp/experiments/minerl/meta/eval_bc/stats``;
    * "exact consensus path": a dashed black reference line at 1.0;
    * "human demonstrations": the ``human_stats`` table.

    The figure is written to ``tmp/consensus_metric.eps``; nothing is returned.
    """
    # Average by key lookup rather than zipping two dicts, so the result does not
    # silently depend on both literals listing their keys in the same order.
    prev_agents = {key: (value + agent2[key]) / 2 for key, value in agent1.items()}
    bc_agents = convert_to_plot(
        Stats.calculate_ensemble_stats(['tmp/experiments/minerl/meta/eval_bc/stats'], 'SS')['summary'])

    fig = plt.figure(figsize=(10, 4))

    labels = [key[-1] for key in agent1]
    x = np.arange(len(labels))
    plt.xticks(x, labels)

    # Plot order is kept stable so the default colour cycle assigns the same colours.
    plt.plot(x, np.array(list(prev_agents.values())), label='RL fine tuning', marker='o')
    plt.plot(x, np.array(list(bc_agents.values())), label='BC', marker='o')
    plt.plot(x, np.ones(len(labels)), '--', label='exact consensus path', linewidth=2, color='black')
    plt.plot(x, np.array(list(human_stats.values())), label='human demonstrations', marker='o')

    # Matplotlib 3.3 renamed the symlog keyword ``linthreshy`` to ``linthresh`` and
    # deprecated the old spelling (later releases reject it), so choose by version.
    mpl_version = tuple(int(part) for part in mpl.__version__.split('.')[:2])
    threshold_kw = 'linthresh' if mpl_version >= (3, 3) else 'linthreshy'
    plt.yscale('symlog', **{threshold_kw: 0.015})

    plt.ylabel('frequency')
    plt.xlabel('consensus')
    plt.ylim(bottom=0)
    plt.xlim(left=0)
    plt.legend(bbox_to_anchor=(0.5, -0.5), frameon=False, loc='lower left', ncol=2, mode="expand", borderaxespad=0.)
    plt.tight_layout()

    ax = plt.gca()
    ax.set_facecolor((1, 1, 1))
    ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
    # Emphasise the y=0 baseline across the plotted x range.
    plt.hlines(y=0, xmin=0, xmax=len(labels) - 1, colors='black', linestyles='-', lw=2)

    # ``frameon`` is no longer a savefig parameter; ``transparent=True`` already
    # makes the figure background transparent.
    plt.savefig('tmp/consensus_metric.eps', dpi=300, transparent=True, format='eps',
                bbox_inches='tight', pad_inches=0.0)
    plt.close(fig)


if __name__ == "__main__":
    # Running the module directly regenerates tmp/consensus_metric.eps.
    plot_consensus_likelihood()
