import gym
from typing import Optional, List, Dict

from src.rllib.agents.sac.sac_torch_model import SACTorchModel
from src.rllib.models.modelv2 import ModelV2
from src.rllib.policy.sample_batch import SampleBatch
from src.rllib.policy.view_requirement import ViewRequirement
from src.rllib.utils import override, force_list
from src.rllib.utils.framework import try_import_torch
from src.rllib.utils.typing import ModelConfigDict, TensorType

torch, _ = try_import_torch()


class RNNSACTorchModel(SACTorchModel):
    """SAC torch model whose policy and Q sub-networks may be recurrent.

    ``forward`` does no computation of its own: it only repackages the
    inputs. The policy network and the (twin) Q networks are then called
    together with the RNN ``state`` and ``seq_lens`` in
    ``get_policy_output`` / ``get_q_values`` / ``get_twin_q_values``.

    The flat state list returned by ``get_initial_state`` is laid out as
    ``policy_state + q_state + twin_q_state`` (the last part only when a
    twin Q network exists). ``select_state`` splits a state list that
    follows the same layout back into its per-network pieces.
    """

    def __init__(self,
                 obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 num_outputs: Optional[int],
                 model_config: ModelConfigDict,
                 name: str,
                 policy_model_config: Optional[ModelConfigDict] = None,
                 q_model_config: Optional[ModelConfigDict] = None,
                 twin_q: bool = False,
                 initial_alpha: float = 1.0,
                 target_entropy: Optional[float] = None):
        """Builds the SAC sub-networks and registers extra view requirements.

        All arguments are forwarded unchanged to ``SACTorchModel``. In
        addition, ``model_config``, ``policy_model_config`` and
        ``q_model_config`` are searched for the ``"lstm_use_prev_action"``
        and ``"lstm_use_prev_reward"`` flags. If any of them enables a flag,
        the corresponding previous-step column is requested via
        ``view_requirements`` and passed on by ``forward``.
        """
        super().__init__(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name,
            policy_model_config=policy_model_config,
            q_model_config=q_model_config,
            twin_q=twin_q,
            initial_alpha=initial_alpha,
            target_entropy=target_entropy)

        # The sub-model configs default to None and may omit the LSTM
        # flags. Only consult configs that are present, and treat a missing
        # key as "disabled". Indexing a None config would raise TypeError.
        configs = [
            cfg for cfg in (model_config, policy_model_config, q_model_config)
            if cfg
        ]
        self.use_prev_action = any(
            cfg.get("lstm_use_prev_action", False) for cfg in configs)
        self.use_prev_reward = any(
            cfg.get("lstm_use_prev_reward", False) for cfg in configs)

        # Previous action / reward = the ACTIONS / REWARDS columns shifted
        # back by one timestep.
        if self.use_prev_action:
            self.view_requirements[SampleBatch.PREV_ACTIONS] = \
                ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
                                shift=-1)
        if self.use_prev_reward:
            self.view_requirements[SampleBatch.PREV_REWARDS] = \
                ViewRequirement(SampleBatch.REWARDS, shift=-1)

    @override(SACTorchModel)
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        """The common (Q-net and policy-net) forward pass.

        No layers are applied here. The observation (and, if enabled, the
        previous action / reward) is copied into a fresh dict, and ``state``
        is returned unchanged. The recurrent sub-networks consume this dict
        together with the RNN state and ``seq_lens`` later on.

        NOTE: It is not(!) recommended to override this method as it would
        introduce a shared pre-network, which would be updated by both
        actor- and critic optimizers.

        Returns:
            Tuple of (dict with key "obs" and optionally "prev_actions" /
            "prev_rewards", the unchanged ``state`` list).
        """
        model_out = {"obs": input_dict[SampleBatch.OBS]}

        if self.use_prev_action:
            model_out["prev_actions"] = input_dict[SampleBatch.PREV_ACTIONS]
        if self.use_prev_reward:
            model_out["prev_rewards"] = input_dict[SampleBatch.PREV_REWARDS]

        return model_out, state

    @override(SACTorchModel)
    def _get_q_value(self, model_out: Dict[str, TensorType], actions, net,
                     state_in: List[TensorType],
                     seq_lens: TensorType) -> (TensorType, List[TensorType]):
        """Runs Q-network ``net`` on ``model_out`` (plus ``actions``).

        Args:
            model_out: Dict produced by ``forward``. It is NOT modified.
            actions: Optional actions to feed in alongside the observation
                (continuous case). They are concatenated to the observation
                along the last dim, or appended as an extra list element,
                depending on ``self.concat_obs_and_actions``.
            net: The Q model to call (``self.q_net`` or ``self.twin_q_net``).
            state_in: RNN state for ``net``.
            seq_lens: Sequence lengths for ``net``.

        Returns:
            Tuple of (Q-net output, Q-net output state).
        """
        # Work on a shallow copy. Callers reuse the same ``model_out`` for
        # several Q calls (Q, twin-Q, Q on policy actions). Writing into it
        # would append ``actions`` to the observation again on every later
        # call and leak "is_training" into the caller's dict.
        q_input = dict(model_out)

        # Continuous case -> concat actions to model_out.
        if actions is not None:
            if self.concat_obs_and_actions:
                q_input[SampleBatch.OBS] = \
                    torch.cat([q_input[SampleBatch.OBS], actions], dim=-1)
            else:
                q_input[SampleBatch.OBS] = \
                    force_list(q_input[SampleBatch.OBS]) + [actions]

        # Switch on training mode (when getting Q-values, we are usually in
        # training).
        q_input["is_training"] = True

        out, state_out = net(q_input, state_in, seq_lens)
        return out, state_out

    @override(SACTorchModel)
    def get_q_values(self,
                     model_out: Dict[str, TensorType],
                     state_in: List[TensorType],
                     seq_lens: TensorType,
                     actions: Optional[TensorType] = None
                     ) -> (TensorType, List[TensorType]):
        """Returns (Q-values, new Q state) from the first Q network."""
        return self._get_q_value(model_out, actions, self.q_net, state_in,
                                 seq_lens)

    @override(SACTorchModel)
    def get_twin_q_values(self,
                          model_out: Dict[str, TensorType],
                          state_in: List[TensorType],
                          seq_lens: TensorType,
                          actions: Optional[TensorType] = None
                          ) -> (TensorType, List[TensorType]):
        """Returns (Q-values, new twin-Q state) from the twin Q network."""
        return self._get_q_value(model_out, actions, self.twin_q_net, state_in,
                                 seq_lens)

    @override(SACTorchModel)
    def get_policy_output(
            self, model_out: Dict[str, TensorType], state_in: List[TensorType],
            seq_lens: TensorType) -> (TensorType, List[TensorType]):
        """Returns (policy-net output, new policy state)."""
        return self.action_model(model_out, state_in, seq_lens)

    @override(ModelV2)
    def get_initial_state(self):
        """Returns policy_state + q_state (+ twin_q_state if a twin exists).

        Each sub-network provides its own initial state. A fresh list is
        built, so the lists returned by the sub-networks are never modified
        in place.
        """
        policy_initial_state = list(self.action_model.get_initial_state())
        q_initial_state = list(self.q_net.get_initial_state())
        twin_q_initial_state = (
            list(self.twin_q_net.get_initial_state())
            if self.twin_q_net is not None else [])
        return policy_initial_state + q_initial_state + twin_q_initial_state

    def select_state(self, state_batch: List[TensorType],
                     net: List[str]) -> Dict[str, List[TensorType]]:
        """Splits a flat state list into per-network pieces.

        Args:
            state_batch: Flat list laid out like ``get_initial_state()``:
                policy states, then Q states, then twin-Q states (if any).
            net: Names of the pieces to extract. Each must be one of
                "policy", "q" or "twin_q".

        Returns:
            Dict mapping each requested name to its slice of
            ``state_batch``. "twin_q" maps to ``[]`` when there is no twin
            Q network.
        """
        assert all(n in ["policy", "q", "twin_q"] for n in net), \
            "Selected state must be either for policy, q or twin_q network"
        policy_state_len = len(self.action_model.get_initial_state())
        q_state_len = len(self.q_net.get_initial_state())
        q_end = policy_state_len + q_state_len

        selected_state = {}
        for n in net:
            if n == "policy":
                selected_state[n] = state_batch[:policy_state_len]
            elif n == "q":
                selected_state[n] = state_batch[policy_state_len:q_end]
            elif n == "twin_q":
                if self.twin_q_net is not None:
                    selected_state[n] = state_batch[q_end:]
                else:
                    selected_state[n] = []
        return selected_state
