import os
import sys
import warnings
import argparse
import pickle as pkl
from collections import defaultdict
from itertools import product
from typing import Iterable, Tuple, List, Callable

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sparse import COO

import model.reward_model as reward_model
import model.head as head
from config import Config
from figure_utils import make_board_from_info, standardize_y_axis
from environment_generator import EnvironmentDataset
from figure_utils import board_to_image
from policy.policy_evaluation import policy_evaluation
from model.reward_model import model_to_reward_function
from driving_gridworld.actions import ACTION_NAMES, ACTIONS


class PlotGenerator:
    """
    Generates and saves reward-distribution figures for hand-picked
    driving-gridworld states, one figure per state.
    """

    # Column layout used by the scenario helpers below: columns 0 and 3
    # are treated as the ditch, columns 1 and 2 as the road.
    _DITCH_COLS = (0, 3)
    _ROAD_COLS = (1, 2)
    _ALL_COLS = range(4)
    # Rows in which the scenario helpers place an obstacle.
    _OBSTACLE_ROWS = range(3)

    def __init__(
            self,
            env: "EnvironmentDataset",
            models: List[Callable],
            figs_dir: str,
    ):
        """
        :param env: The environment dataset which we'll use for a
        variety of convenience functions, and as a source of truth
        for mapping state indices to their boards.
        :param models: The models which we'll use to generate
        the reward distributions.
        :param figs_dir: The location to store all of the figures.
        """
        self.env = env
        self.models = models
        # One reward function per model, each built against env.
        self.reward_functions = [model_to_reward_function(model, env) for model in models]
        self.config = Config()
        self.figs_dir = figs_dir

    def add_true_reward_line(
            self,
            state_index: int,
            ax: "Axis",
            c: str = "white",
    ) -> "Axis":
        """
        Draws a short horizontal segment at the true reward of every
        action for the given state onto ax.
        :param state_index: The state index for which we want
        to obtain the true rewards for.
        :param ax: The axis to draw on.
        :param c: The color of the line, using matplotlib's setup.
        "white", "black", "blue", etc will probably suffice.
        :return: ax, so the call can be chained.
        """
        true_rewards = self.env.true_reward[state_index]
        for action in range(self.env.true_reward.shape[1]):
            reward = true_rewards[action]
            ax.plot(
                [action - 0.1, action + 0.1],
                [reward, reward],
                c=c,
                # Keep the true-reward marker above the other plot artists.
                zorder=20,
            )
        return ax

    def obtain_reward_predictions(
            self,
            state_index: int,
    ) -> np.ndarray:
        """
        Queries every model's reward function for every action at state_index.
        :param state_index: The state to evaluate.
        :return: An array with one row per reward function (model) and one
        column per action, in ACTIONS order.
        """
        return np.array([
            # NOTE(review): assumes each reward function returns a tensor whose
            # [0][0] entry is the scalar reward — confirm against
            # model_to_reward_function. `.cpu()` is a no-op for CPU tensors
            # and makes `.numpy()` work for tensors on an accelerator.
            [reward_function(state_index, action).detach().cpu().numpy()[0][0]
             for action in ACTIONS]
            for reward_function in self.reward_functions
        ])

    def create_image_figure(
            self,
            board: "Board",
    ) -> Tuple["Figure", "Axes"]:
        """
        Creates a two-row figure whose top axis shows the rendered board
        (with axis decorations hidden); the bottom axis is left empty for
        the caller to draw on.
        :param board: The board to render.
        :return: The figure and its two axes.
        """
        fig, axs = plt.subplots(2, 1)
        axs[0].imshow(board_to_image(board))
        axs[0].axis('off')
        return fig, axs

    def construct_violin_plot(
            self,
            reward_predictions: "Array",
            ax: "Axis",
    ):
        """
        Draws onto ax a violin plot of the reward distribution over all
        models for each action, plus the individual model predictions as a
        jittered scatter. Each violin is labelled with the action name and
        the mean predicted reward.
        :param reward_predictions: Predictions indexed as [model, action],
        e.g. the output of obtain_reward_predictions.
        :param ax: The axis to draw on.
        """
        num_actions = len(ACTION_NAMES)
        action_to_pred = pd.DataFrame.from_dict({
            f'{name}\n{np.mean(reward_predictions[:, action]):.5f}': reward_predictions[:, action]
            for action, name in enumerate(ACTION_NAMES)
        })

        # NOTE(review): seaborn's categorical plots normally place violins at
        # x = 0, 1, ..., so the +0.5 offset puts each action's scatter between
        # two violins — confirm this is intended.
        actions = np.repeat(np.arange(num_actions) + 0.5, reward_predictions.shape[0])
        actions = actions + np.random.uniform(-0.05, 0.05, size=actions.shape)
        rewards = np.concatenate([reward_predictions[:, action] for action in range(num_actions)])
        ax.scatter(x=actions, y=rewards, s=0.1)
        sns.violinplot(data=action_to_pred, ax=ax, dodge=True)
        # Lay out the figure that owns ax, not whatever pyplot considers current.
        ax.figure.tight_layout()

    def construct_bar_plot(
            self,
            values: np.ndarray,
            ax: "Axis",
            disp_legend: bool = True,
    ):
        """
        Constructs a grouped bar plot onto ax: one group per action, one
        bar per row of values.
        :param values: Rows of per-action values; row i is labelled with
        the i-th entry of config.k.
        :param ax: The axis to draw on.
        :param disp_legend: Whether or not to display the legend.
        """
        width = 0.1
        xs = np.arange(len(ACTION_NAMES))
        for i, heights in enumerate(values):
            ax.bar(
                xs + width * i,
                heights,
                width,
                label=f'{self.config.k[i]}-of-{self.config.n}',
            )

        # Pin the ticks at the centre of each group explicitly; relying on
        # matplotlib's automatic tick positions misaligns the labels.
        ax.set_xticks(xs + width * max(len(values) - 1, 0) / 2)
        ax.set_xticklabels(list(ACTION_NAMES))
        if disp_legend:
            ax.legend(bbox_to_anchor=(1.3, 1))
        ax.figure.tight_layout()

    def state_info_to_board(
            self,
            obstacle_row: int,
            obstacle_col: int,
            car_col: int,
            car_speed: int,
    ) -> "Board":
        """
        Builds a board from a compact state description.
        :param obstacle_row: Row of the single obstacle, or None for no obstacle.
        :param obstacle_col: Column of the single obstacle, or None for no obstacle.
        :param car_col: Column of the car.
        :param car_speed: Speed of the car.
        """
        # NOTE(review): the meaning of the None slots in the obstacle record
        # is defined by make_board_from_info — confirm there.
        if obstacle_row is not None and obstacle_col is not None:
            obstacles = [[None, obstacle_row, obstacle_col, None, None]]
        else:
            obstacles = []
        return make_board_from_info(
            car_col,
            car_speed,
            obstacles,
            headlight_range=self.config.headlight_range,
        )

    def generate_plots_and_save(
            self,
            info: Iterable[Tuple[int, int, int, int]],
            violin_root_dir: str,
            bar_root_dir: str,
            returns_root_dir: str,
    ):
        """
        Builds one violin-plot figure per state described by info, passes
        all of them to standardize_y_axis together, then saves each under
        violin_root_dir. All created figures are closed, even on error.
        :param info: An iterable of
        (obstacle_row, obstacle_col, car_col, car_speed) tuples; the
        obstacle entries may be None for a state without an obstacle.
        :param violin_root_dir: Directory the violin plots are saved into;
        it is created if missing.
        :param bar_root_dir: Not currently used by this method.
        :param returns_root_dir: Not currently used by this method.
        """
        reward_figs = []
        try:
            for obstacle_row, obstacle_col, car_col, car_speed in info:
                board = self.state_info_to_board(
                    obstacle_row,
                    obstacle_col,
                    car_col,
                    car_speed,
                )
                state_index = self.env.board_to_state_index(board)

                fig, axs = self.create_image_figure(board)
                # Register the figure immediately so `finally` closes it even
                # if the plotting below raises.
                reward_figs.append((
                    fig,
                    f'{violin_root_dir}/obstacle_row_{obstacle_row}_obstacle_col_{obstacle_col}'
                    f'_car_col_{car_col}_car_speed_{car_speed}.png',
                ))
                reward_predictions = self.obtain_reward_predictions(state_index)
                self.add_true_reward_line(state_index, axs[1], c='black')
                self.construct_violin_plot(reward_predictions, axs[1])

            if reward_figs:
                # Every figure is handed over at once, before any is saved.
                standardize_y_axis([fig for fig, _ in reward_figs])
                os.makedirs(violin_root_dir, exist_ok=True)
                for fig, fname in reward_figs:
                    fig.savefig(fname)
        finally:
            for fig, _ in reward_figs:
                plt.close(fig)

    def _generate_scenario_plots(
            self,
            info: Iterable[Tuple[int, int, int, int]],
            scenario: str,
    ) -> None:
        """Runs generate_plots_and_save with the standard output directories for scenario."""
        self.generate_plots_and_save(
            info,
            violin_root_dir=f'{self.figs_dir}/violin_plots/{scenario}',
            bar_root_dir=f'{self.figs_dir}/policy/bar_plots/{scenario}',
            returns_root_dir=f'{self.figs_dir}/returns/{scenario}',
        )

    def no_obstacle_states(self) -> None:
        """
        Generates violin plots for every car column and speed with no
        obstacle on the board.
        """
        info = product(
            [None],
            [None],
            self._ALL_COLS,
            range(self.config.speed_limit + 1),
        )
        self._generate_scenario_plots(info, 'no_obstacle')

    def car_on_road_obstacle_in_ditch(self) -> None:
        """
        Generates violin plots for all states where the car is on the
        road and the obstacle is in the ditch.
        """
        info = product(
            self._OBSTACLE_ROWS,
            self._DITCH_COLS,
            self._ROAD_COLS,
            range(self.config.speed_limit + 1),
        )
        self._generate_scenario_plots(info, 'car_road_obstacle_ditch')

    def obstacle_car_in_ditch(self) -> None:
        """
        Generates violin plots for all states where the car and the
        obstacle are in the same ditch column.
        """
        info = [
            (obstacle_row, ditch_col, ditch_col, car_speed)
            for obstacle_row in self._OBSTACLE_ROWS
            for ditch_col in self._DITCH_COLS
            for car_speed in range(self.config.speed_limit + 1)
        ]
        self._generate_scenario_plots(info, 'car_ditch_obstacle_ditch')

    def obstacle_car_on_road(self) -> None:
        """
        Generates violin plots for all states where
        the obstacle and car are on the road.
        """
        info = product(
            self._OBSTACLE_ROWS,
            self._ROAD_COLS,
            self._ROAD_COLS,
            range(self.config.speed_limit + 1),
        )
        self._generate_scenario_plots(info, 'car_road_obstacle_road')

    def car_ditch_obstacle_road(self) -> None:
        """
        Generates violin plots for all states where
        the car is in the ditch and the obstacle is on the road.
        """
        info = product(
            self._OBSTACLE_ROWS,
            self._ROAD_COLS,
            self._DITCH_COLS,
            range(self.config.speed_limit + 1),
        )
        self._generate_scenario_plots(info, 'car_ditch_obstacle_road')

def str_to_bool(value) -> bool:
    """
    Parses a command-line boolean.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including "False" -- becomes True. The flags below
    use this parser instead.
    :param value: A bool, or a string such as "true"/"false", "yes"/"no",
    "1"/"0" (case-insensitive, surrounding whitespace ignored).
    :return: The parsed boolean.
    :raises argparse.ArgumentTypeError: If value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}.')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # `--flag` alone means True; `--flag True` / `--flag False` are still
    # accepted so existing invocations keep working.
    bool_flag = dict(type=str_to_bool, nargs='?', const=True, default=False)
    parser.add_argument('--head', **bool_flag)
    parser.add_argument('--no-obstacle', **bool_flag)
    parser.add_argument('--car-road-obstacle-ditch', **bool_flag)
    parser.add_argument('--car-ditch-obstacle-ditch', **bool_flag)
    parser.add_argument('--car-road-obstacle-road', **bool_flag)
    parser.add_argument('--car-ditch-obstacle-road', **bool_flag)
    args = parser.parse_args()

    config = Config()
    env = EnvironmentDataset.obtain_test_env()

    if args.head:
        native_models = reward_model.load_all_saved_models(amount_to_load=config.head_model_seed_amount)
        # NOTE(review): env's state encoding is replaced with one built from
        # the native models before the head models are loaded — presumably
        # the heads consume those encodings; confirm.
        env.state_to_tensor = head.head_state_to_tensor(native_models)
        models = head.load_all_saved_models()
        figs_dir = config.head_figs_dir
    else:
        models = reward_model.load_all_saved_models()
        figs_dir = config.model_figs_dir

    plot_generator = PlotGenerator(env, models, figs_dir)
    # (enabled, generator) pairs, run in this fixed order.
    scenarios = [
        (args.no_obstacle, plot_generator.no_obstacle_states),
        (args.car_road_obstacle_ditch, plot_generator.car_on_road_obstacle_in_ditch),
        (args.car_ditch_obstacle_ditch, plot_generator.obstacle_car_in_ditch),
        (args.car_road_obstacle_road, plot_generator.obstacle_car_on_road),
        (args.car_ditch_obstacle_road, plot_generator.car_ditch_obstacle_road),
    ]
    for enabled, generate in scenarios:
        if enabled:
            generate()
