"""
Implementation of IRIS (https://arxiv.org/abs/1911.05321).
"""
import numpy as np
from collections import OrderedDict
from copy import deepcopy

import torch

import agents.models.robomimic.utils.tensor_utils as TensorUtils
import agents.models.robomimic.utils.obs_utils as ObsUtils
from agents.models.robomimic.config.config import Config
from agents.models.robomimic.algo import register_algo_factory_func, algo_name_to_factory_func, HBC, ValuePlanner, ValueAlgo, GL_VAE


@register_algo_factory_func("iris")
def algo_config_to_class(algo_config):
    """
    Resolve the algo classes that make up IRIS from the algo config.

    The actor class is looked up through the "bc" factory using @algo_config.actor,
    the planner class through the "gl" factory using @algo_config.value_planner.planner,
    and the value class through the "bcq" factory using @algo_config.value_planner.value.
    Only the class returned by each factory is kept; the second element of each
    factory's return value is discarded.

    Args:
        algo_config (Config instance): algo config

    Returns:
        algo_class: subclass of Algo (this function always returns IRIS)
        algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm, holding
            the resolved "policy_algo_class", "planner_algo_class" and "value_algo_class"
    """
    # low-level policy (actor)
    actor_factory = algo_name_to_factory_func("bc")
    actor_class, _ = actor_factory(algo_config.actor)

    # subgoal planner
    planner_factory = algo_name_to_factory_func("gl")
    planner_class, _ = planner_factory(algo_config.value_planner.planner)

    # value function used together with the planner
    value_factory = algo_name_to_factory_func("bcq")
    value_class, _ = value_factory(algo_config.value_planner.value)

    algo_kwargs = {
        "policy_algo_class": actor_class,
        "planner_algo_class": planner_class,
        "value_algo_class": value_class,
    }
    return IRIS, algo_kwargs


class IRIS(HBC, ValueAlgo):
    """
    Implementation of IRIS (https://arxiv.org/abs/1911.05321).

    Holds a ValuePlanner (self.planner) and an actor (self.actor). The actor's goal
    observation keys and shapes are overridden to match the planner's subgoal shapes.
    """
    def __init__(
        self,
        planner_algo_class,
        value_algo_class,
        policy_algo_class,
        algo_config,
        obs_config,
        global_config,
        obs_key_shapes,
        ac_dim,
        device,
    ):
        """
        Args:
            planner_algo_class (Algo class): algo class for the planner

            value_algo_class (Algo class): algo class for the value network, passed
                through to the ValuePlanner

            policy_algo_class (Algo class): algo class for the policy

            algo_config (Config object): instance of Config corresponding to the algo section
                of the config

            obs_config (Config object): instance of Config corresponding to the observation
                section of the config

            global_config (Config object): global training config

            obs_key_shapes (OrderedDict): dictionary that maps input/output observation keys to shapes

            ac_dim (int): action dimension

            device: torch device
        """
        # keep references to the configs and basic settings
        self.algo_config = algo_config
        self.obs_config = obs_config
        self.global_config = global_config

        self.ac_dim = ac_dim
        self.device = device

        # subgoal bookkeeping
        self._subgoal_step_count = 0  # current step count for deciding when to update subgoal
        self._current_subgoal = None  # latest subgoal
        self._subgoal_update_interval = algo_config.subgoal_update_interval  # subgoal update frequency
        self._subgoal_horizon = algo_config.value_planner.planner.subgoal_horizon
        self._actor_horizon = algo_config.actor.rnn.horizon

        self._algo_mode = algo_config.mode
        assert self._algo_mode in ("separate", "cascade")

        # planner + value function, both driven by the value_planner config sections
        self.planner = ValuePlanner(
            planner_algo_class=planner_algo_class,
            value_algo_class=value_algo_class,
            algo_config=algo_config.value_planner,
            obs_config=obs_config.value_planner,
            global_config=global_config,
            obs_key_shapes=obs_key_shapes,
            ac_dim=ac_dim,
            device=device,
        )

        self.actor_goal_shapes = self.planner.subgoal_shapes
        assert not algo_config.latent_subgoal.enabled, "IRIS does not support latent subgoals"

        # only for the actor: override goal modalities and shapes to match the subgoal set by the planner.
        # work on a copy so the caller's obs_key_shapes is left untouched
        actor_obs_key_shapes = deepcopy(obs_key_shapes)
        # a key that already exists must keep its shape - we only add new keys here
        for key, goal_shape in self.actor_goal_shapes.items():
            if key in actor_obs_key_shapes:
                assert actor_obs_key_shapes[key] == goal_shape
        actor_obs_key_shapes.update(self.actor_goal_shapes)

        # group the subgoal keys by observation modality (every known modality gets a list,
        # possibly empty)
        goal_modalities = {}
        for modality in ObsUtils.OBS_MODALITY_CLASSES:
            goal_modalities[modality] = []
        for key in self.actor_goal_shapes:
            key_modality = ObsUtils.OBS_KEYS_TO_MODALITIES[key]
            goal_modalities[key_modality].append(key)

        # the actor gets its own obs config copy, with the goal section replaced
        actor_obs_config = deepcopy(obs_config.actor)
        with actor_obs_config.unlocked():
            actor_obs_config["goal"] = Config(**goal_modalities)

        self.actor = policy_algo_class(
            algo_config=algo_config.actor,
            obs_config=actor_obs_config,
            global_config=global_config,
            obs_key_shapes=actor_obs_key_shapes,
            ac_dim=ac_dim,
            device=device,
        )

    def process_batch_for_training(self, batch):
        """
        Processes input batch from a data loader to filter out
        relevant information and prepare the batch for training.

        The planner and the actor each process the batch themselves; the results are
        stored under the "planner" and "actor" keys. The actor's "goal_obs" entry is
        then set either from a randomly indexed entry of batch["next_obs"] (when
        algo_config.actor_use_random_subgoals is set) or from the planner's
        "target_subgoals".

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader

        Returns:
            input_batch (dict): processed and filtered batch that
                will be used for training
        """
        planner_batch = self.planner.process_batch_for_training(batch)
        actor_batch = self.actor.process_batch_for_training(batch)
        input_batch = {"planner": planner_batch, "actor": actor_batch}

        if not self.algo_config.actor_use_random_subgoals:
            # use planner subgoal target as goal for the policy
            actor_batch["goal_obs"] = planner_batch["planner"]["target_subgoals"]
        else:
            # optionally use a randomly sampled step as policy goal: one index in
            # [0, seq_length) is drawn per batch element and used to index next_obs
            seq_length = self.global_config.train.seq_length
            num_samples = batch["actions"].shape[0]
            policy_subgoal_indices = torch.randint(low=0, high=seq_length, size=(num_samples,))
            sampled_goals = TensorUtils.gather_sequence(batch["next_obs"], policy_subgoal_indices)
            sampled_goals = TensorUtils.to_device(sampled_goals, self.device)
            sampled_goals = TensorUtils.to_float(sampled_goals)
            actor_batch["goal_obs"] = sampled_goals

        # we move to device first before float conversion because image observation modalities will be uint8 -
        # this minimizes the amount of data transferred to GPU
        batch_on_device = TensorUtils.to_device(input_batch, self.device)
        return TensorUtils.to_float(batch_on_device)

    def get_state_value(self, obs_dict, goal_dict=None):
        """
        Get state value outputs. Delegates to the planner.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Returns:
            value (torch.Tensor): value tensor
        """
        state_value = self.planner.get_state_value(obs_dict=obs_dict, goal_dict=goal_dict)
        return state_value

    def get_state_action_value(self, obs_dict, actions, goal_dict=None):
        """
        Get state-action value outputs. Delegates to the planner.

        Args:
            obs_dict (dict): current observation
            actions (torch.Tensor): action
            goal_dict (dict): (optional) goal

        Returns:
            value (torch.Tensor): value tensor
        """
        state_action_value = self.planner.get_state_action_value(
            obs_dict=obs_dict,
            actions=actions,
            goal_dict=goal_dict,
        )
        return state_action_value
