import argparse
import warnings
import os
import pickle as pkl
from typing import List, Callable

import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sparse import COO
from matplotlib import animation

from config import Config
from environment_generator import EnvironmentDataset
from figure_utils import board_to_image
from state_violins import PlotGenerator
from policy.policy_evaluation import policy_evaluation
from model import reward_model
from model import head
from model.reward_model import load_all_saved_models, model_to_reward_function
from driving_gridworld.actions import ACTION_NAMES, LEFT, RIGHT
from driving_gridworld.road import Road

def obtain_average_speed(
        pi: np.array,
        env: EnvironmentDataset,
):
    """
    Evaluate the policy using the speed of each state as the reward.

    The speed of a state is the argmax of the speed component returned by
    ``env.board_to_state`` for that state's board; the action taken is
    ignored. The output of ``policy_evaluation`` is multiplied by
    ``(1 - gamma)``, where ``gamma`` comes from a fresh ``Config``.

    :param pi: The policy to evaluate.
    :param env: The environment providing the transition matrix and boards.
    :return: ``policy_evaluation(...) * (1 - gamma)``.
    """
    def speed_reward(state_index, action):
        # The reward depends on the state only; `action` is part of the
        # callback signature but unused.
        raw_board = env.obtain_board_representation(state_index)
        _, speed_encoding, _, _ = env.board_to_state(raw_board)
        return np.argmax(speed_encoding.numpy())

    transitions = COO(np.array(env.prob_trans_mat))
    settings = Config()
    evaluation = policy_evaluation(pi, transitions, speed_reward)
    return evaluation * (1 - settings.gamma)

def obtain_average_distance_travelled(
        pi: np.array,
        env: EnvironmentDataset,
):
    """
    Evaluate the policy using the distance covered per step as the reward.

    The per-step distance is the state's speed (argmax of the speed
    component returned by ``env.board_to_state``), reduced by one when the
    action is ``LEFT`` or ``RIGHT`` and clamped so it never goes below zero.
    The output of ``policy_evaluation`` is multiplied by ``(1 - gamma)``,
    where ``gamma`` comes from a fresh ``Config``.

    :param pi: The policy to evaluate.
    :param env: The environment providing the transition matrix and boards.
    :return: ``policy_evaluation(...) * (1 - gamma)``.
    """
    def distance_reward(state_index, action):
        raw_board = env.obtain_board_representation(state_index)
        _, speed_encoding, _, _ = env.board_to_state(raw_board)
        speed = np.argmax(speed_encoding.numpy())
        # A sideways move takes one unit away from the forward distance.
        moves_sideways = (action == LEFT or action == RIGHT)
        return max(speed - moves_sideways, 0)

    transitions = COO(np.array(env.prob_trans_mat))
    settings = Config()
    evaluation = policy_evaluation(pi, transitions, distance_reward)
    return evaluation * (1 - settings.gamma)

def obtain_average_collision_rate(
        pi: np.array,
        env: EnvironmentDataset,
):
    """
    Evaluate the policy using the expected collision count as the reward.

    Assumes that the only obstacles are Bump obstacles.
    See driving_gridworld/road/safety_information for the 2 index choice.

    For a state-action pair the reward is the expectation, over next states
    with positive transition probability, of entry ``2`` of the matching
    ``safety_information`` record. The output of ``policy_evaluation`` is
    multiplied by ``(1 - gamma)``, where ``gamma`` comes from a fresh
    ``Config``.

    :param pi: The policy to evaluate.
    :param env: The environment providing transitions and safety information.
    :return: ``policy_evaluation(...) * (1 - gamma)``.
    """
    transitions = COO(np.array(env.prob_trans_mat))
    settings = Config()
    safety_info, _ = env.road.safety_information()

    def collision_reward(state_index, action):
        next_state_probs = env.prob_trans_mat[state_index][action]
        # Safety records are only looked up for reachable next states.
        return sum(
            prob * safety_info[state_index][action][next_state][2]
            for next_state, prob in enumerate(next_state_probs)
            if prob > 0
        )

    evaluation = policy_evaluation(pi, transitions, collision_reward)
    return evaluation * (1 - settings.gamma)

def obtain_average_collision_rate_speed(
        pi: np.array,
        env: EnvironmentDataset,
):
    """
    Evaluate the policy using a weighted expected collision count as reward.

    Assumes that the only obstacles are Bump obstacles.
    See driving_gridworld/road/safety_information for the 2 index choice.

    Identical to ``obtain_average_collision_rate`` except that each term is
    additionally multiplied by entry ``4`` of the ``safety_information``
    record (presumably the speed, going by this function's name -- verify
    against driving_gridworld). The output of ``policy_evaluation`` is
    multiplied by ``(1 - gamma)``, where ``gamma`` comes from a fresh
    ``Config``.

    :param pi: The policy to evaluate.
    :param env: The environment providing transitions and safety information.
    :return: ``policy_evaluation(...) * (1 - gamma)``.
    """
    transitions = COO(np.array(env.prob_trans_mat))
    settings = Config()
    safety_info, _ = env.road.safety_information()

    def collision_reward(state_index, action):
        next_state_probs = env.prob_trans_mat[state_index][action]
        # Safety records are only looked up for reachable next states.
        return sum(
            prob
            * safety_info[state_index][action][next_state][2]
            * safety_info[state_index][action][next_state][4]
            for next_state, prob in enumerate(next_state_probs)
            if prob > 0
        )

    evaluation = policy_evaluation(pi, transitions, collision_reward)
    return evaluation * (1 - settings.gamma)

def obtain_true_return(
        pi: np.array,
        env: EnvironmentDataset,) -> np.array:
    """
    Evaluate the policy against the environment's true reward.

    The reward for a state-action pair is ``env.true_reward[state][action]``.
    Unlike the per-timestep metrics in this file, the result is returned
    without the ``(1 - gamma)`` scaling.

    :param pi: The policy to evaluate.
    :param env: The environment providing transitions and the true reward.
    :return: The output of ``policy_evaluation``.
    """
    def true_reward(state_index, action):
        reward_row = env.true_reward[state_index]
        return reward_row[action]

    transitions = COO(np.array(env.prob_trans_mat))
    return policy_evaluation(pi, transitions, true_reward)

def obtain_gif(
        pi: np.array,
        env: EnvironmentDataset,
        save_loc: str,
        num_frames: int,
        use_head_models: bool,
) -> None:
    """
    Record a sampled rollout of the policy as an animation.

    Every frame draws a 2x2 grid for the current state: the board image
    (top-left), a violin plot of the saved models' reward predictions
    (top-right), the policy's action probabilities (bottom-left) and the
    true-reward evaluation ``q`` for each action (bottom-right). After
    drawing, an action is sampled from the policy and the next state from
    the environment's transition probabilities.

    The rollout always starts from state 0. The animation is written with
    matplotlib's ``ffmpeg`` writer at 1 frame per second, and the figure is
    closed afterwards even if saving fails.

    :param pi: The policy to record.
    :param env: The environment that the policy is acting on.
    :param save_loc: The location to save the gif.
    :param num_frames: The amount of frames to obtain for the gif.
    :param use_head_models: Whether or not this is for the head models.
    """
    warnings.warn("The starting state is hardcoded to be 0 here; should make \
that a configurable, or have the source of truth be the env.")
    warnings.warn("Kind of a hacky way of maintaining the current state; \
there has to be a better solution...")
    config = Config()
    models = (head.load_all_saved_models()
              if use_head_models else reward_model.load_all_saved_models())
    plot_generator = PlotGenerator(
        env,
        models,
        config.head_figs_dir if use_head_models else config.model_figs_dir,
    )
    q = policy_evaluation(
        pi,
        COO(np.array(env.prob_trans_mat)),
        env.true_reward_callable,
    )

    state_indices = np.arange(pi.shape[0])
    action_indices = np.arange(pi.shape[1])
    # Tick positions are pinned explicitly so each action name sits under its
    # bar instead of depending on where the automatic locator places ticks.
    tick_positions = np.arange(len(ACTION_NAMES))

    fig, axs = plt.subplots(2, 2)
    fig.set_size_inches(10, 7)
    # One-element list so the nested callback can update the rollout state.
    current_state = [0]

    def update(i):
        state = current_state[0]

        # Clear every panel, including the board, so artists from previous
        # frames do not pile up on the axes.
        for ax in axs.flat:
            ax.clear()

        board = env.obtain_board_representation(state)
        axs[0][0].imshow(board_to_image(board))
        action = np.random.choice(action_indices, p=pi[state])
        plot_generator.construct_violin_plot(
            plot_generator.obtain_reward_predictions(state),
            axs[0][1],
        )
        axs[1][0].bar(action_indices, pi[state])
        axs[1][1].bar(action_indices, q[state])
        for ax in (axs[1][0], axs[1][1]):
            ax.set_xticks(tick_positions)
            ax.set_xticklabels(ACTION_NAMES)

        current_state[0] = np.random.choice(
            state_indices,
            p=env.prob_trans_mat[state][action],
        )

    try:
        writer = animation.writers['ffmpeg'](fps=1)
        anim = animation.FuncAnimation(
            fig, update, frames=np.arange(num_frames), interval=200)
        anim.save(save_loc, writer=writer)
    finally:
        # This function is called once per policy type by the script below;
        # release the figure so they do not accumulate.
        plt.close(fig)

def obtain_percentile_plot(pis):
    """
    Plot, for each policy, its sorted start-state value under every model.

    Each policy is evaluated once per saved reward model. The value plotted
    is ``sum(pi[0] * q[0]) / num_models``; the values are sorted and drawn
    against ``i / num_models``. The figure is saved to
    ``{config.model_figs_dir}/percentiles.png`` and the ``(xs, ys)`` pair of
    the LAST policy only is pickled to
    ``{config.pkls_git_dir}/percentiles.pkl``.

    NOTE: this function reads the module-level names ``env`` and
    ``prob_trans_mat``, which are only assigned inside the
    ``if __name__ == "__main__"`` section of this file, so it works only
    when the file is run as a script.

    NOTE(review): the script section passes a list of per-seed policy lists,
    while this function treats every element of ``pis`` as a single policy
    (``pi[0]``) and labels it with ``config.k[j]`` -- confirm the intended
    input.

    :param pis: The policies to plot, one curve each.
    """
    config = Config()
    # The saved models do not depend on the policy: load them once instead
    # of re-reading them from disk for every policy.
    models = list(load_all_saved_models(config.models_dir))

    # Defined up front so the pickle below is well-formed for empty `pis`.
    xs, ys = [], []
    for j, pi in enumerate(pis):
        qs = [
            policy_evaluation(
                pi,
                prob_trans_mat,
                model_to_reward_function(model, env),
            )
            for model in tqdm(models)
        ]
        num_models = len(qs)
        xs = [i / num_models for i in range(num_models)]
        ys = sorted(np.sum(pi[0] * q[0]) / num_models for q in qs)
        plt.plot(xs, ys, label=f'{config.k[j]}-of-{config.n}')
    plt.legend()
    plt.tight_layout()
    # Context manager so the pickle file is flushed and closed.
    with open(f'{config.pkls_git_dir}/percentiles.pkl', 'wb') as pkl_file:
        pkl.dump((xs, ys), pkl_file)
    plt.savefig(f'{config.model_figs_dir}/percentiles.png', dpi=300)
    plt.clf()

def obtain_barplot(
        pis: List[List["Policy"]],
        metric: Callable[["Policy", "Environment"], float],
        title: str,
        figname: str,
        sampled_pis: List["Policy"],
):
    """
    Generates a simple barplot with error bars using the metric given.
    Uses the given title and saves it in config.model_figs_dir/{figname}.

    Every policy is scored by its expected metric value at state 0,
    ``sum(pi[0] * metric(pi, env)[0])``, on the test environment. One bar
    series is drawn per seed index; a series holds one bar per entry of
    ``pis`` followed by a final bar with the mean score of ``sampled_pis``.
    The standard deviation of the sampled scores is drawn as an error bar on
    the middle series only. The plotted series are also pickled to
    ``{config.pkls_git_dir}/{figname}.pkl``.

    NOTE(review): bars are positioned by their index within a series, so the
    inner lists of ``pis`` are expected to have equal length; ragged input
    would misalign bars and tick labels.

    :param pis: One list of policies (one per seed) for every policy type.
    :param metric: Maps ``(policy, env)`` to an array indexable by state.
    :param title: Plot title; only drawn when ``config.barplot_title`` is set.
    :param figname: Base name (no extension) of the saved figure and pickle.
    :param sampled_pis: Policies whose scores are summarised by mean and std.
    """
    config = Config()
    env = EnvironmentDataset.obtain_test_env()

    def start_state_value(pi):
        # Expected metric value at state 0 under the policy's action
        # distribution for that state.
        return np.sum(pi[0] * metric(pi, env)[0])

    # One series per seed index. Sized by the longest seed list (not by the
    # number of policy types) so the indexing below cannot go out of range.
    num_series = max((len(pi_list) for pi_list in pis), default=0)
    metric_per_seed = [[] for _ in range(num_series)]
    for pi_list in pis:
        for i, pi in enumerate(pi_list):
            metric_per_seed[i].append(start_state_value(pi))

    # Sampled policies are scored exactly like the others, so the "Sampled"
    # bar is comparable with the rest of the plot.
    sampled_pi_metrics = [start_state_value(pi) for pi in tqdm(sampled_pis)]
    sampled_mean = np.mean(sampled_pi_metrics)
    sampled_std = np.std(sampled_pi_metrics)

    # Series are spread symmetrically around each tick; with five series
    # this is the offset (i - 2) / len(pis) and the error bar is on series 2.
    center = (num_series - 1) / 2
    error_bar_series = num_series // 2
    for i, metric_list in enumerate(metric_per_seed):
        metric_list.append(sampled_mean)
        yerr = [0] * len(metric_list)
        if i == error_bar_series:
            yerr[-1] = sampled_std
        bars = plt.bar(np.arange(len(metric_list)) + (i - center) / len(pis),
                       metric_list,
                       yerr=yerr,
                       width=1 / len(pis))
        # The second-to-last bar is the "Averaged" entry; give it the same
        # colour in every series.
        bars[-2].set_color('C0')

    if config.barplot_title:
        plt.title(title)
    # Pin tick positions before labelling so labels cannot drift with the
    # automatic tick locator.
    labels = [str(label)
              for label in config.k[:-1] + ["Averaged", "Sampled"]]
    axis = plt.gca()
    axis.set_xticks(np.arange(len(labels)))
    axis.set_xticklabels(labels)
    # Persist every plotted series (not just the last loop variable), and
    # close the file once written.
    with open(f'{config.pkls_git_dir}/{figname}.pkl', 'wb') as pkl_file:
        pkl.dump(metric_per_seed, pkl_file)
    plt.savefig(f'{config.model_figs_dir}/{figname}.png', dpi=300)
    plt.clf()

def parse_bool_flag(value):
    """
    Convert a command-line token into a bool.

    ``argparse``'s ``type=bool`` treats every non-empty string, including
    ``"False"``, as True; this parser maps the usual spellings explicitly.

    :param value: The raw command-line string (or an actual bool).
    :return: The parsed boolean.
    :raises argparse.ArgumentTypeError: If the token is not recognised.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value, got {value!r}')


# NOTE: this section deliberately stays at module level rather than in a
# main() function: obtain_percentile_plot reads `env` and `prob_trans_mat`
# as module globals.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Each switch accepts `--flag`, `--flag True` or `--flag False`.
    for flag in ('--test',
                 '--use-true-reward',
                 '--percentile',
                 '--gif',
                 '--speed',
                 '--average-return',
                 '--distance',
                 '--collision',
                 '--collision-speed',
                 '--collision-given-speed',
                 '--speed-given-collision'):
        parser.add_argument(flag, type=parse_bool_flag, nargs='?',
                            const=True, default=False)
    parser.add_argument('--seeds', type=int, default=[0], nargs='+')

    args = parser.parse_args()
    config = Config()
    env = (EnvironmentDataset.obtain_test_env() if args.test
           else EnvironmentDataset.obtain_train_env())
    prob_trans_mat = np.array(env.prob_trans_mat)
    if args.use_true_reward:
        config.k, config.n = [1], 1

    # pis[j] holds one policy per seed for the j-th policy type.
    pis = []
    for k in config.k:
        if k == config.n:
            continue

        policy_type = ('true_reward' if args.use_true_reward else f'{k}_of_{config.n}') \
            + ("" if args.test else "_train")
        pi_list = []
        for seed in args.seeds:
            pi = np.load(f'{config.policies_dir}/{policy_type}_{seed}.npy')
            pi_list.append(pi)
        pis.append(pi_list)

        if args.gif:
            # Only the policy of the last seed is rendered.
            obtain_gif(
                pi,
                env,
                os.path.join(config.model_figs_dir, f'policy/{policy_type}.gif'),
                config.num_gif_frames,
                False)

    # The averaged policy has no seed; repeat it once per seed so every bar
    # series in obtain_barplot gets an "Averaged" entry.
    pis.append([np.load(f'{config.policies_dir}/averaged.npy')] * len(args.seeds))

    sampled_dir = f'{config.policies_dir}/sampled'
    sampled_pis = [np.load(os.path.join(sampled_dir, x))
                   for x in sorted(os.listdir(sampled_dir))]

    # (enabled, metric, title, figname) for every optional barplot, in the
    # order they are generated.
    barplots = [
        (args.speed,
         obtain_average_speed,
         'Average speed value per timestep',
         'speeds'),
        (args.distance,
         obtain_average_distance_travelled,
         'Average distance travelled per timestep',
         'distance_travelled'),
        (args.average_return,
         obtain_true_return,
         'True return on environment',
         'true_returns'),
        (args.collision,
         obtain_average_collision_rate,
         'Average collisions per timestep',
         'collision'),
        (args.collision_speed,
         obtain_average_collision_rate_speed,
         'Average collisions per timestep weighed by speed',
         'collision_speed'),
        (args.collision_given_speed,
         lambda pi, env: obtain_average_collision_rate_speed(pi, env)
         / obtain_average_speed(pi, env),
         'Average collisions per timestep given speed',
         'collision_given_speed'),
        (args.speed_given_collision,
         lambda pi, env: obtain_average_collision_rate_speed(pi, env)
         / obtain_average_collision_rate(pi, env),
         'Average speed per timestep given collision occurs',
         'speed_given_collision'),
    ]
    for enabled, metric, title, figname in barplots:
        if enabled:
            obtain_barplot(pis, metric, title, figname, sampled_pis)

    if args.percentile:
        # NOTE(review): `pis` is a list of per-seed policy lists here, while
        # obtain_percentile_plot indexes each element as a single policy
        # (pi[0]) -- confirm which input it expects.
        obtain_percentile_plot(pis)

