"""
This code is adapted from mbrl-lib by Facebook.
See https://github.com/freiberg-roman/mbrl-lib/blob/pddm/mbrl/diagnostics/visualize_model_preds.py.
"""

import argparse
from pathlib import Path
from typing import Generator, List, Optional, Tuple, cast

import hydra
from omegaconf import OmegaConf

import gym.wrappers
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch

import rlkit.samplers as samplers
import rlkit.samplers.util.model_rollout_functions as model_rollout_functions
import rlkit.util as util

import rlkit.policies as policies
from rlkit.policies import MPCPolicy, create_mpc_agent_from_config
from rlkit.policies.base.base import Policy
from rlkit.policies.base.simple import RandomPolicy
from rlkit.envs.env_util import make_env_from_cfg, get_dim
from rlkit.envs.model_env import ModelEnv
from rlkit.util.common import create_dynamics_and_reward_model
from rlkit.torch import pytorch_util as ptu

VisData = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
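# VisData holds (real_obses, real_rewards, model_obses, model_rewards, actions),
# matching the return order of Visualizer.get_obs_rewards_and_actions().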

class Visualizer:
    def __init__(
        self,
        lookahead: int,                         # How many steps to predict into the future using the model
        results_dir: str,                       # The directory where the original experiment was run
        agent_dir: Optional[str],               # The directory where the agent configuration and data are stored
        num_steps: Optional[int] = None,        # How many steps to run for visualization
        num_model_samples: int = 1,             # Number of samples from the model, to visualize uncertainty
        model_subdir: Optional[str] = None,     # Can be used to point to models generated by other diagnostics tools
    ):
        self.lookahead = lookahead
        self.results_path = Path(results_dir)
        self.model_path = self.results_path
        self.vis_path = self.results_path / "diagnostics"
        if model_subdir:
            self.model_path /= model_subdir
            # If model subdir is child of diagnostics, remove "diagnostics" before
            # appending to vis_path. This can happen, for example, if Finetuner
            # generated this model with a model_subdir
            if "diagnostics" in model_subdir:
                model_subdir = Path(model_subdir).name
            self.vis_path /= model_subdir
        self.vis_path.mkdir(parents=True, exist_ok=True)

        self.num_model_samples = num_model_samples
        self.num_steps = num_steps

        # TODO: need to use .yaml file for config
        self.cfg = util.common.load_hydra_cfg(self.results_path)

        self.env, term_func, reward_func = make_env_from_cfg(self.cfg)

        self.reward_func = reward_func

        obs_dim = get_dim(self.env.observation_space)
        action_dim = get_dim(self.env.action_space)
        input_dim = obs_dim + action_dim
        self.obs_preproc = hydra.utils.get_method(self.cfg.env.get('obs_preproc')) \
            if self.cfg.env.get('obs_preproc', False) else None
        self.obs_postproc = hydra.utils.get_method(self.cfg.env.get('obs_postproc')) \
            if self.cfg.env.get('obs_postproc', False) else None
        if self.obs_preproc:
            input_dim = action_dim + self.obs_preproc(np.zeros((1, obs_dim))).shape[-1]

        self.cfg.overrides.dynamics.ensemble_size = self.cfg.algorithm.dynamics.ensemble_size
        self.dynamics_model = create_dynamics_and_reward_model(
            cfg=self.cfg,
            obs_dim=obs_dim,
            input_dim=input_dim,
            action_dim=action_dim,
        )

        """
        For MPC algorithms, we simply need to load trained models (dynamics, reward, value,    
        """
        self.model_env = ModelEnv(
            self.env,
            self.dynamics_model,
            term_func,
            reward_func,
            generator=torch.Generator(ptu.device),
            obs_preproc=self.obs_preproc,
            obs_postproc=self.obs_postproc,
        )

        self.agent: Policy
        if agent_dir is None:
            self.agent = RandomPolicy(self.env.action_space)
        else:
            if agent_dir == 'same_as_exp_dir':
                agent_cfg = util.common.load_hydra_cfg(self.results_path)
            else:
                agent_cfg = util.common.load_hydra_cfg(agent_dir)
            if (
                agent_cfg.algorithm.agent._target_
                == "rlkit.policies.mpc.mpc.MPCPolicy"
            ):
                agent_cfg.algorithm.agent.horizon = lookahead
                agent_cfg.algorithm.agent.num_particles = 1
                agent_cfg.algorithm.agent.population_size = self.num_model_samples
                self.agent = create_mpc_agent_from_config(
                    env=self.model_env,
                    dynamics_model=self.dynamics_model,
                    cfg=agent_cfg,
                    input_dim=input_dim,
                    cache_dir=self.results_path,
                    reward_func=reward_func,
                    obs_preproc=self.obs_preproc,
                    obs_postproc=self.obs_postproc,
                )
            else:
                self.agent = policies.util.load_agent(agent_dir, self.env)

        self.fig = None
        self.axs: List[plt.Axes] = []
        self.lines: List[plt.Line2D] = []
        self.writer = animation.FFMpegWriter(
            fps=15, metadata=dict(artist="Me"), bitrate=1800
        )

        # The total reward obtained while building the visualization
        self.total_reward = 0

    def get_obs_rewards_and_actions(
        self, obs: np.ndarray, use_mpc: bool = False
    ) -> VisData:
        """
        Initially, a single observation array is given, from which rollouts in the model environment and the true
        environment occur.
        """
        if use_mpc:
            # When using MPC, rollout model trajectories to see the controller actions
            model_obses, model_rewards, actions = model_rollout_functions.rollout_model_env(
                self.model_env,
                obs,
                plan=None,
                agent=self.agent,
                num_samples=self.num_model_samples,
            )
            # Then evaluate in the environment
            real_obses, real_rewards, _ = samplers.util.rollout_plan_mujoco_env(
                cast(gym.wrappers.TimeLimit, self.env),
                obs,
                self.lookahead,
                agent=None,
                plan=actions,
            )
        else:
            # When not using MPC, rollout the agent on the environment and get its actions
            real_obses, real_rewards, actions = samplers.util.rollout_plan_mujoco_env(
                cast(gym.wrappers.TimeLimit, self.env),
                obs,
                self.lookahead,
                agent=self.agent,
            )
            # Then see what the model would predict for this
            model_obses, model_rewards, _ = model_rollout_functions.rollout_model_env(
                self.model_env,
                obs,
                agent=None,
                plan=actions,
                num_samples=self.num_model_samples,
            )
        return real_obses, real_rewards, model_obses, model_rewards, actions

    def vis_rollout(self, use_mpc: bool = False) -> Generator[VisData, None, None]:
        obs = self.env.reset()      # A single observation array returned from the true environment
        done = False
        i = 0
        while not done:
            vis_data = self.get_obs_rewards_and_actions(obs, use_mpc=use_mpc)
            action, _ = self.agent.get_action(obs)
            next_obs, reward, done, _ = self.env.step(action)
            self.total_reward += reward
            obs = next_obs
            i += 1
            if self.num_steps and i == self.num_steps:
                break

            yield vis_data

    def set_data_lines_idx(
        self,
        plot_idx: int,
        data_idx: int,
        real_data: np.ndarray,
        model_data: np.ndarray,
    ):
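        # real_data is nominally of shape (lookahead + 1, obs_dim) from the true environment;
        # model_data is nominally (lookahead + 1, num_model_samples, obs_dim), with axis 1
        # indexing the sampled model rollouts. Shorter and 1-D arrays are handled below.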
        def adjust_ylim(ax, array):
            ymin, ymax = ax.get_ylim()
            real_ymin = array.min() - 0.5 * np.abs(array.min())
            real_ymax = array.max() + 0.5 * np.abs(array.max())
            if real_ymin < ymin or real_ymax > ymax:
                ax.set_ylim(min(ymin, real_ymin), max(ymax, real_ymax))
                ax.figure.canvas.draw()

        def fix_array_len(array):
            """
            When a rollout has terminated before reaching self.lookahead, the array's length will be shorter.
            Append the last observation to fill the remaining entries.
            """
            if len(array) < self.lookahead + 1:
                new_array = np.ones((self.lookahead + 1,) + tuple(array.shape[1:]))
                new_array *= array[-1]
                new_array[: len(array)] = array
                return new_array
            return array

        x_data = range(self.lookahead + 1)
        # Handle cases when the state space dimension is 1
        if real_data.ndim == 1:
            real_data = real_data[:, None]
        if model_data.ndim == 2:
            model_data = model_data[:, :, None]

        # Handle cases of shorter rollouts
        real_data = fix_array_len(real_data)
        model_data = fix_array_len(model_data)

        # Plot data
        adjust_ylim(self.axs[plot_idx], real_data[:, data_idx])
        adjust_ylim(self.axs[plot_idx], model_data.mean(1)[:, data_idx])

        self.lines[4 * plot_idx].set_data(x_data, real_data[:, data_idx])
        model_obs_mean = model_data[:, :, data_idx].mean(axis=1)
        model_obs_min = model_data[:, :, data_idx].min(axis=1)
        model_obs_max = model_data[:, :, data_idx].max(axis=1)
        self.lines[4 * plot_idx + 1].set_data(x_data, model_obs_mean)
        self.lines[4 * plot_idx + 2].set_data(x_data, model_obs_min)
        self.lines[4 * plot_idx + 3].set_data(x_data, model_obs_max)

    def plot_func(self, data: VisData):
        real_obses, real_rewards, model_obses, model_rewards, actions = data

        num_plots = len(real_obses[0]) + 1                              # number of state dimensions + 1 for the reward
        assert len(self.lines) == 4 * num_plots                         # 4 lines will be plotted per plot
        for i in range(num_plots - 1):                                  # Plot state observation rollouts
            self.set_data_lines_idx(i, i, real_obses, model_obses)
        self.set_data_lines_idx(num_plots - 1, 0, real_rewards, model_rewards)

        return self.lines

    def create_axes(self):
        num_plots = self.env.observation_space.shape[0] + 1
        num_cols = int(np.ceil(np.sqrt(num_plots)))
        num_rows = int(np.ceil(num_plots / num_cols))

        fig, axs = plt.subplots(num_rows, num_cols)
        fig.text(
            0.5, 0.04, f"Time step (lookahead of {self.lookahead} steps)", ha="center"
        )
        fig.text(
            0.04,
            0.17,
            "Predictions (blue/red) and ground truth (black).",
            ha="center",
            rotation="vertical",
        )

        axs = axs.reshape(-1)
        lines = []
        for i, ax in enumerate(axs):
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            ax.set_xlim(0, self.lookahead)
            if i < num_plots:
                (real_line,) = ax.plot([], [], "k")
                (model_mean_line,) = ax.plot([], [], "r" if i == num_plots - 1 else "b")
                (model_ub_line,) = ax.plot(
                    [], [], "r" if i == num_plots - 1 else "b", linewidth=0.5
                )
                (model_lb_line,) = ax.plot(
                    [], [], "r" if i == num_plots - 1 else "b", linewidth=0.5
                )
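                # Append in the order expected by set_data_lines_idx:
                # real, model mean, model min (lower bound), model max (upper bound).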
                lines.append(real_line)
                lines.append(model_mean_line)
                lines.append(model_lb_line)
                lines.append(model_ub_line)

        self.fig = fig

        self.axs = axs
        self.lines = lines

    def run(self, use_mpc: bool):
        # Note: the torch device is expected to be configured (see ptu.set_gpu_mode in __main__)
        # before this method runs.
        self.create_axes()
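        # Each animation frame consumes one VisData tuple yielded by vis_rollout();
        # plot_func() updates the plotted lines for that frame.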
        ani = animation.FuncAnimation(
            self.fig,
            self.plot_func,
            frames=lambda: self.vis_rollout(use_mpc=use_mpc),
            blit=True,
            interval=100,
            save_count=self.num_steps,
            repeat=False,
        )
        save_path = self.vis_path / f"rollout_{type(self.agent).__name__}_policy.mp4"
        ani.save(save_path, writer=self.writer)
        print(f"Video saved at {save_path}.")
        print(f"Total reward obtained: {self.total_reward}.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--experiments_dir",
        type=str,
        default=None,
        help="The directory where the original experiment was run.",
    )
    parser.add_argument("--lookahead", type=int, default=25)
    parser.add_argument(                    # TODO: which agent config to refer to?
        "--agent_dir",
        type=str,
        default=None,
        help="The directory where the agent configuration and data are stored. "
        "If not provided, a random agent will be used.",
    )
    parser.add_argument("--num_steps", type=int, default=200)
    parser.add_argument(
        "--model_subdir",
        type=str,
        default=None,
        help="Can be used to point to models generated by other diagnostics tools.",
    )
    parser.add_argument(
        "--num_model_samples",
        type=int,
        default=35,
        help="Number of samples from the model, to visualize uncertainty.",
    )
    args = parser.parse_args()

    # Set up the gpu mode
    gpu_mode = torch.cuda.is_available()
    ptu.set_gpu_mode(gpu_mode)

    # Instantiate the visualizer
    visualizer = Visualizer(
        lookahead=args.lookahead,
        results_dir=args.experiments_dir,
        agent_dir=args.agent_dir,
        num_steps=args.num_steps,
        num_model_samples=args.num_model_samples,
        model_subdir=args.model_subdir,
    )
    use_mpc = isinstance(visualizer.agent, MPCPolicy)

    # Run the visualization
    visualizer.run(use_mpc=use_mpc)
