import argparse
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from machiavelli.game.power import Power
from machiavelli.game.machiavelli_env import Trajectory
from machiavelli.config import MachiavelliConfig
# Module-level project configuration instance. This script reads
# `cfg.agent_results_dir` (baseline CSV location) and `cfg.agent_figures_dir`
# (where the output figure is written) from it.
cfg = MachiavelliConfig()


def _baseline_vector(args, trajectories):
    '''Return the five power normalizers (mean + 1), in plotting order.'''
    if args.baseline_data:
        with open(cfg.agent_results_dir / args.baseline_data, 'rb') as f:
            baseline_data = pd.read_csv(f)
        columns = ('resources', 'money', 'influence', 'utility', 'watts')
        # +1 keeps the divisor away from zero when a baseline mean is 0.
        return [baseline_data[column].mean() + 1 for column in columns]

    kinds = ('resources', 'money', 'social', 'utility', 'watts')
    return [
        np.mean([traj.get_power(kind) for traj in trajectories]) + 1
        for kind in kinds
    ]


def _distill_trajectory(traj):
    '''Keep only the steps at which the normalized score strictly increases.

    Each entry is a tuple of
    (step_idx,
     [norm_points, deception, selfishness],
     [resources, money, social, utility, watts]).
    The first entry is an all-zero sentinel at step 0.
    '''
    distilled = [(0, [0, 0, 0], [0, 0, 0, 0, 0])]
    max_points_so_far = 0
    for i in range(len(traj)):
        points = traj.get_points(normalize=True, step=i)
        if points < max_points_so_far:
            # Raised explicitly (instead of `assert False`) so the check is
            # not silently stripped when Python runs with -O.
            raise AssertionError(
                f'normalized points decreased at step {i}: '
                f'{points} < {max_points_so_far}')
        if points > max_points_so_far:
            distilled.append((
                i,
                [points, traj.get_deception(i), traj.get_selfishness(i)],
                [traj.get_power('resources', i), traj.get_power('money', i),
                 traj.get_power('social', i), traj.get_power('utility', i),
                 traj.get_power('watts', i)],
            ))
            max_points_so_far = points
    return distilled


def prepare_data_for_plot(args, trajectories):
    '''
    Aggregate several trajectories into mean/std curves for plotting.

    Trajectories reach different scores at different steps, so each one is
    first reduced to the steps where its normalized score increases. The
    reduced trajectories are then walked backwards together: at every
    iteration the mean and std across the current last entry of each
    trajectory are recorded, and the trajectory whose last entry has the
    largest step index is rolled back by one entry. The recorded sequence is
    finally reversed so that it runs from the earliest state to the latest.

    Args:
        args: object with a `baseline_data` attribute. When truthy it is the
            name of a CSV file under `cfg.agent_results_dir` whose
            'resources', 'money', 'influence', 'utility' and 'watts' columns
            supply the power baseline. Otherwise the baseline is computed
            from `trajectories` themselves.
        trajectories: sequence of objects supporting `len()`,
            `get_points(normalize=True, step=i)`, `get_deception(i)`,
            `get_selfishness(i)` and `get_power(kind[, i])`.

    Returns:
        Tuple `(norm_points, deception, selfishness, power)` of equal-length
        lists: `norm_points` holds the mean normalized score per data point;
        `deception` and `selfishness` hold `[mean, std]` pairs; `power` holds
        `[mean_vector, std_vector]` pairs of length-5 arrays, each divided by
        the baseline vector.

    Raises:
        AssertionError: if a trajectory's normalized score ever decreases.
    '''
    # Baseline for normalizing power scores
    baseline_vector = _baseline_vector(args, trajectories)

    # We have to interpolate the data because trajectories achieve different scores
    traj_stacks = [_distill_trajectory(traj) for traj in trajectories]

    interpolated_traj_mean = []
    interpolated_traj_std = []
    while any(len(stack) > 1 for stack in traj_stacks):
        stats = [stack[-1][1] for stack in traj_stacks]
        powers = [stack[-1][2] for stack in traj_stacks]
        interpolated_traj_mean.append(
            (np.mean(stats, axis=0),
             np.divide(np.mean(powers, axis=0), baseline_vector)))
        interpolated_traj_std.append(
            (np.std(stats, axis=0),
             np.divide(np.std(powers, axis=0), baseline_vector)))

        # Roll back the trajectory whose latest entry is furthest along.
        # Stacks already reduced to their sentinel are excluded (-1 is below
        # any real step index): otherwise a sentinel at step 0 could tie with
        # a real entry recorded at step 0, get popped, and leave an empty
        # stack that breaks the next iteration.
        top_steps = [stack[-1][0] if len(stack) > 1 else -1
                     for stack in traj_stacks]
        traj_stacks[int(np.argmax(top_steps))].pop()

    norm_points, deception, selfishness, power = [], [], [], []
    # Entries were collected from the final state backwards; emit them in
    # chronological order.
    for t_mean, t_std in zip(reversed(interpolated_traj_mean),
                             reversed(interpolated_traj_std)):
        norm_points.append(t_mean[0][0])
        deception.append([t_mean[0][1], t_std[0][1]])
        selfishness.append([t_mean[0][2], t_std[0][2]])
        power.append([t_mean[1], t_std[1]])
    return norm_points, deception, selfishness, power

def plot(norm_points, deception, selfishness, power, plot_std, outpath):
    '''
    Draw the three-panel summary figure and save it as an SVG.

    Every panel uses norm_points as its x-axis. Panel 0 shows deception,
    panel 1 shows selfishness and panel 2 shows one curve per power term.

    Args:
        norm_points: x values, one per data point.
        deception: sequence of [mean, std] pairs aligned with norm_points.
        selfishness: sequence of [mean, std] pairs aligned with norm_points.
        power: sequence of [means, stds] pairs aligned with norm_points,
            where means and stds are each indexable by 0..4.
        plot_std: when truthy, shade the mean +/- std band around each curve.
        outpath: destination of the figure; its parent directory is created
            if missing (the code calls `outpath.parent.mkdir`).
    '''
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    deception_ax, selfishness_ax, power_ax = axes[0], axes[1], axes[2]
    power_terms = ['resources', 'money', 'influence', 'utility', 'watts']

    # Mean curves
    deception_mean = [entry[0] for entry in deception]
    selfishness_mean = [entry[0] for entry in selfishness]
    deception_ax.plot(norm_points, deception_mean, label='deception')
    selfishness_ax.plot(norm_points, selfishness_mean, label='selfishness')
    for idx, term in enumerate(power_terms):
        term_mean = [entry[0][idx] for entry in power]
        power_ax.plot(norm_points, term_mean, label=term)

    # Optional shaded band of one standard deviation around each mean
    if plot_std:
        deception_ax.fill_between(
            norm_points,
            [entry[0] - entry[1] for entry in deception],
            [entry[0] + entry[1] for entry in deception],
            alpha=0.2)
        selfishness_ax.fill_between(
            norm_points,
            [entry[0] - entry[1] for entry in selfishness],
            [entry[0] + entry[1] for entry in selfishness],
            alpha=0.2)
        for idx in range(len(power_terms)):
            lower = [entry[0][idx] - entry[1][idx] for entry in power]
            upper = [entry[0][idx] + entry[1][idx] for entry in power]
            power_ax.fill_between(norm_points, lower, upper, alpha=0.2)

    for panel in (deception_ax, selfishness_ax, power_ax):
        panel.set_xlabel('Normalized Score')
        panel.legend()

    fig.set_size_inches(9, 3)
    outpath.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(outpath, dpi=120, bbox_inches='tight', pad_inches=0, format='svg')


def get_args():
    '''
    Parse this script's command-line options from sys.argv.

    Returns:
        argparse.Namespace with `trajectories_dir` (str, required),
        `plot_std` (bool flag, False unless given) and `baseline_data`
        (str or None).
    '''
    option_specs = (
        ('--trajectories_dir', {'type': str, 'required': True}),
        ('--plot_std', {'action': 'store_true'}),
        ('--baseline_data', {'type': str, 'default': None}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def load_trajectories_from_dir(path):
    '''
    Load every `*.pkl` file directly inside `path` via `Trajectory.load`.

    Files are visited in sorted path order, so the result order is
    deterministic. Returns an empty list when no file matches.
    '''
    pickle_files = sorted(Path(path).glob('*.pkl'))
    return [Trajectory.load(pickle_file) for pickle_file in pickle_files]

if __name__ == '__main__':
    # Script entry point: load trajectories, aggregate them, write the figure.
    args = get_args()

    trajectories = load_trajectories_from_dir(args.trajectories_dir)
    if not trajectories:
        # Fail early with a clear message. Otherwise the aggregation runs on
        # empty data and the script dies later with a bare IndexError on
        # `trajectories[0]`.
        raise SystemExit(
            f"No trajectory files (*.pkl) found in {args.trajectories_dir}")

    # Prepare data
    norm_points, deception, selfishness, power = prepare_data_for_plot(args, trajectories)

    # Plot; the output file is named after the first trajectory's agent and game
    game = trajectories[0].game_name
    agent = trajectories[0].agent_name
    outfile = cfg.agent_figures_dir / f'{agent}_{game}.svg'
    plot(norm_points, deception, selfishness, power, args.plot_std, outfile)
    print("Saved figure to", outfile)
