import itertools
import pickle
import string

import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler

from experiments import args_parser


def _load_experiment_log(game, graphon, solver, eta):
    """Load the pickled ``logs.pkl`` of one experiment run.

    The experiment directory is resolved by
    ``args_parser.generate_config_from_kw``; apart from the four arguments,
    every run is looked up with the same fixed settings (exact simulator,
    evaluator and evaluation solver, 250 iterations, 500 total iterations).

    Args:
        game: Game identifier passed through to the config generator.
        graphon: Graphon identifier passed through to the config generator.
        solver: Solver identifier passed through to the config generator.
        eta: Value of the ``eta`` setting for this run.

    Returns:
        The object unpickled from ``<experiment_directory>logs.pkl``.

    Raises:
        OSError: If the log file cannot be opened.
    """
    args = args_parser.generate_config_from_kw(
        game=game,
        graphon=graphon,
        solver=solver,
        simulator='exact',
        evaluator='exact',
        eval_solver='exact',
        iterations=250,
        total_iterations=500,
        eta=eta,
        results_dir=None,
        exp_name=None,
        verbose=0,
    )
    # NOTE(review): pickle.load can execute arbitrary code from the file.
    # Presumably these logs are written by this project's own experiment
    # runs -- do not point this at untrusted result directories.
    # The directory string is concatenated as-is, so it is expected to end
    # with a path separator already.
    with open(args['experiment_directory'] + 'logs.pkl', 'rb') as f:
        return pickle.load(f)


def _exploitability_gaps(result, last_n=10):
    """Return the ``eval_opt`` minus ``eval_pi`` mean-return gaps of the last entries.

    Args:
        result: Log indexable by integer position, where each entry holds
            ``['eval_opt']['eval_mean_returns']`` and
            ``['eval_pi']['eval_mean_returns']``.
        last_n: Number of trailing entries to consider.

    Returns:
        A list with one gap per considered entry, in log order. It has fewer
        than ``last_n`` items when the log itself is shorter.
    """
    total = len(result)
    # Clamp the window start: a negative start would wrap around to the end
    # of a list (double-counting entries) or raise for very short logs.
    start = max(total - last_n, 0)
    return [
        result[t]['eval_opt']['eval_mean_returns'] - result[t]['eval_pi']['eval_mean_returns']
        for t in range(start, total)
    ]


def plot():
    """Draw the two-panel exploitability-versus-eta figure, save it and show it.

    One panel is drawn per game ('SIS-Graphon', 'Investment-Graphon'). For
    each graphon ('unif-att', 'rank-att', 'er') the logs of every eta on the
    game's grid are loaded, and the min / mean / max of the
    ``eval_opt - eval_pi`` mean-return gap over the last 10 log entries is
    plotted against eta. The eta = 0 point is taken from the 'exact' solver
    instead of the Boltzmann solver. For 'SIS-Graphon' the hard-coded
    uniform-policy value is added as a horizontal reference line.

    The figure is written to ``./figures/exploitability.pdf`` and
    ``./figures/exploitability.png`` and then displayed. Text is rendered
    through LaTeX (``text.usetex`` is enabled).
    """
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "sans-serif",
        "font.size": 18,
        "font.sans-serif": ["Helvetica"]})

    # Hard-coded reference values for the uniform policy. Only the
    # 'SIS-Graphon' entries are drawn below.
    uniform_performance = {
        'SIS-Graphon': {
            'unif-att': 8.818052870102195,
            'rank-att': 20.819854740697597,
            'er': 6.6561950457623915,
        },
        'Investment-Graphon': {
            'unif-att': 31.94247706262201,
            'rank-att': 36.19500591691547,
            'er': 33.726102257573864,
        },
    }

    # Non-zero eta grid per game; eta = 0 is prepended for every game below.
    eta_grids = {
        'SIS-Graphon': np.arange(0.001, 0.305, 0.005),
        'Investment-Graphon': np.arange(0.05, 0.3, 0.05),
    }
    # (graphon identifier, legend text) in plotting order.
    graphons = (
        ('unif-att', 'unif graphon'),
        ('rank-att', 'rank graphon'),
        ('er', 'ER graphon'),
    )
    games = ['SIS-Graphon', 'Investment-Graphon']

    for panel, game in enumerate(games):
        # Restart the colour cycle so both panels use the same colour per graphon.
        colors = itertools.cycle(cycler(color='rbgcmyk'))
        ax = plt.subplot(1, len(games), panel + 1)
        # Panel tag "(a)", "(b)", ... just above the top-left corner of the axes.
        ax.text(-0.01, 1.06, '(' + string.ascii_lowercase[panel] + ')', transform=ax.transAxes,
                size=22, weight='bold')

        etas = np.concatenate([[0.], eta_grids[game]])
        # Entry count of the most recently loaded log; shown in the panel title.
        num_logged = 0

        for graphon, graphon_label in graphons:
            for solver in ['boltzmann']:
                max_eps = []
                min_eps = []
                mean_eps = []

                for eta in etas:
                    # eta = 0 is represented by the exact solver's run.
                    run_solver = 'exact' if eta == 0 and solver in ['boltzmann'] else solver
                    result = _load_experiment_log(game, graphon, run_solver, eta)
                    num_logged = len(result)

                    gaps = _exploitability_gaps(result)
                    max_eps.append(max(gaps))
                    min_eps.append(min(gaps))
                    mean_eps.append(np.mean(gaps))

                label = r'$\eta$-Boltzmann, ' + graphon_label
                color = next(colors)['color']

                # Dashed envelope (min / max), solid mean, shaded band in between.
                plt.plot(etas, max_eps, '--', color=color, label='_nolabel_', alpha=0.5)
                plt.plot(etas, min_eps, '--', color=color, label='_nolabel_', alpha=0.5)
                plt.plot(etas, mean_eps, color=color, label=label, alpha=0.85)
                plt.fill_between(etas, min_eps, max_eps, color=color, alpha=0.15)

                if game in ['SIS-Graphon']:
                    plt.plot(etas, [uniform_performance[game][graphon]] * len(etas), '-.', color=color,
                             label='Uniform policy, ' + graphon_label, alpha=0.85)

        plt.legend()
        plt.grid(True)
        plt.xlabel(r'$\eta$', fontsize=22)
        plt.ylabel(r'$\Delta J(\pi)$', fontsize=22)
        plt.title('%s, %d iterations' % (game, num_logged))

    plt.gcf().set_size_inches(22, 6)
    plt.savefig('./figures/exploitability.pdf', bbox_inches='tight', transparent=True, pad_inches=0)
    plt.savefig('./figures/exploitability.png', bbox_inches='tight', transparent=True, pad_inches=0)
    plt.show()


# Generate the figure only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    plot()
