import copy
import os
import gym
import torch
import numpy as np
import torch.nn.functional as F

from torch.distributions import Categorical

from expground.types import DataArray, Dict, Tuple
from expground.logger import Log
from expground.algorithms import misc
from expground.algorithms.base_policy import Policy, Action, ActionDist, Logits
from expground.algorithms.ddpg.config import DEFAULT_CONFIG
from expground.common.models import get_model


class DDPG(Policy):
    """DDPG policy holding an actor, a critic and a target copy of each.

    For ``gym.spaces.Discrete`` action spaces the critic is built to consume an
    action vector of length ``action_space.n`` and actions are produced through
    ``misc.gumbel_softmax_sample``. For any other action space the critic
    consumes the action space as-is.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        model_config: Dict,
        custom_config: Dict,
    ):
        """Build the online/target networks and register them as policy state.

        Args:
            observation_space (gym.spaces.Space): Observation space fed to the actor
                and, under the key ``"observation"``, to the critic.
            action_space (gym.spaces.Space): Action space of the actor.
            model_config (Dict): Overrides applied (via ``dict.update``) on top of a
                copy of ``DEFAULT_CONFIG["model_config"]``. The ``"actor"`` and
                ``"critic"`` entries are handed to ``get_model``.
            custom_config (Dict): Overrides applied on top of a copy of
                ``DEFAULT_CONFIG["custom_config"]``. Keys read here:
                ``"start_noise_exploration"`` (default 100) and
                ``"use_stochastic_decision"`` (default False).
        """
        # Overlay the caller's settings on copies of the defaults so that
        # DEFAULT_CONFIG itself is never mutated.
        _model_config = copy.deepcopy(DEFAULT_CONFIG["model_config"])
        _model_config.update(model_config)
        _custom_config = copy.deepcopy(DEFAULT_CONFIG["custom_config"])
        _custom_config.update(custom_config)

        super().__init__(
            observation_space, action_space, _model_config, _custom_config
        )

        self._discrete_action = isinstance(action_space, gym.spaces.Discrete)

        def build_actor():
            # Maps observations to logits over (or values in) the action space.
            return get_model(self._model_config["actor"])(
                observation_space=observation_space,
                action_space=action_space,
            )

        def build_critic():
            # The critic scores an (observation, action) pair with one scalar.
            # Discrete actions are presented to it as a vector with one entry
            # per action, each entry in [0, 1].
            model_cls = get_model(self._model_config["critic"])
            if self._discrete_action:
                critic_action_space = gym.spaces.Box(
                    low=0.0, high=1.0, shape=(action_space.n,)
                )
            else:
                critic_action_space = action_space
            return model_cls(
                observation_space=gym.spaces.Dict(
                    {
                        "observation": observation_space,
                        "action": critic_action_space,
                    }
                ),
                action_space=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(1,)),
            )

        # Construction order (actor, critic, target actor, target critic) is kept
        # as-is so that any random initialisation consumes the RNG identically.
        self._actor = build_actor()
        self._critic = build_critic()
        self._target_actor = build_actor()
        self._target_critic = build_critic()

        # Start the target networks as exact copies of the online networks.
        self._target_critic.load_state_dict(self._critic.state_dict())
        self._target_actor.load_state_dict(self._actor.state_dict())

        self._step_count = 0
        self._exploration_start = self._custom_config.get(
            "start_noise_exploration", 100
        )
        self._use_stochastic_decision = self._custom_config.get(
            "use_stochastic_decision", False
        )

        self.register_state(self._actor, "actor")
        self.register_state(self._critic, "critic")
        self.register_state(self._target_actor, "target_actor")
        self.register_state(self._target_critic, "target_critic")

        # NOTE(review): the targets were already copied above; this call relies on
        # the default `tau` of `misc.soft_update`, which is defined outside this
        # file -- confirm whether it is still needed.
        with torch.no_grad():
            misc.soft_update(self.target_actor, self.actor)
            misc.soft_update(self.target_critic, self.critic)

    def value_function(self, states):
        """Evaluate the online critic.

        Args:
            states: Input forwarded unchanged to the critic (presumably a mapping
                with ``"observation"`` and ``"action"`` entries, matching the space
                the critic was built with -- confirm against the critic model).

        Returns:
            Whatever the critic model returns for ``states``.
        """
        return self._critic(states)

    def update_target(self, tau: float = 1.0):
        """Update target critic and actor.

        Args:
            tau (float, optional): Soft migration factor. Defaults to 1.0, which is equivalent to hard update.
        """

        # Target updates must never be recorded by autograd; this mirrors the
        # guarded call in `__init__`.
        with torch.no_grad():
            misc.soft_update(self._target_actor, self._actor, tau)
            misc.soft_update(self._target_critic, self._critic, tau)

    def compute_actions(
        self,
        observation,
        use_target: bool = False,
        action_mask=None,
        explore: bool = False,
    ) -> Tuple[Action, ActionDist, Logits]:
        """Compute actions when training. The actor used to compute action
        is determined by `use_target`, i.e., use target actor or not.

        No ``torch.no_grad`` is applied here, so the outputs stay attached to the
        autograd graph of the selected actor.

        Args:
            observation: The batched observations tensor.
            use_target (bool, optional): Use target actor or not. Defaults to False.
            action_mask (optional): Action mask forwarded to ``misc.masked_logits``.
                Defaults to None, in which case the logits are left untouched.
            explore (bool, optional): Exploration flag forwarded to
                ``misc.masked_logits`` and ``misc.gumbel_softmax_sample``.
                Defaults to False.

        Returns:
            Tuple[Action, ActionDist, Logits]: ``(actions, pi, logits)``.
            Discrete case: ``pi`` is the Gumbel-softmax sample and ``actions`` is the
            same tensor; unless ``use_stochastic_decision`` is set, its forward value
            is replaced by the one-hot of ``misc.onehot_from_logits`` while gradients
            still flow through the soft sample.
            Otherwise: ``pi`` is ``softmax(logits)`` and ``actions`` is ``tanh(logits)``.
        """

        policy_net = self.target_actor if use_target else self.actor
        logits = policy_net(observation)

        if action_mask is not None:
            logits = misc.masked_logits(logits, action_mask, explore)

        if self._discrete_action:
            pi = misc.gumbel_softmax_sample(
                logits=logits, temperature=1.0, explore=explore
            )
            if not self._use_stochastic_decision:
                # Straight-through estimator: the forward value is the one-hot,
                # the backward pass sees only `pi`.
                one_hot = misc.onehot_from_logits(pi)
                pi = (one_hot - pi).detach() + pi
            actions = pi
        else:
            pi = torch.softmax(logits, dim=-1)
            actions = torch.tanh(logits)

        return actions, pi, logits

    def compute_action(
        self, observation: DataArray, action_mask: DataArray, evaluate: bool
    ) -> Tuple[Action, ActionDist, Logits]:
        """Compute action in rollout stage.

        Runs entirely under ``torch.no_grad``. Only discrete action spaces are
        supported.

        Args:
            observation (DataArray): Observation forwarded unchanged to the actor.
            action_mask (DataArray): Action mask as a numpy array or tensor, or None
                to skip masking.
            evaluate (bool): When True, exploration is switched off in
                ``misc.masked_logits`` and ``misc.gumbel_softmax_sample``.

        Raises:
            NotImplementedError: If the action space is not ``gym.spaces.Discrete``.

        Returns:
            Tuple[Action, ActionDist, Logits]: ``(action, pi, logits)`` where
            ``action`` is a Python int and ``pi`` / ``logits`` are numpy arrays.
            NOTE(review): ``.item()`` and the sum-to-one check only hold for a
            single action distribution, so this is effectively called with one
            observation at a time -- confirm with the rollout caller.
        """

        if not self._discrete_action:
            # Fail before spending a forward pass on an unsupported configuration.
            raise NotImplementedError(
                "compute_action only supports discrete action spaces"
            )

        explore = not evaluate
        with torch.no_grad():
            logits = self._actor(observation)

            if action_mask is not None:
                # `as_tensor` accepts numpy arrays (sharing memory, exactly like
                # `from_numpy`) as well as masks that already are tensors.
                action_mask = torch.as_tensor(action_mask)
                logits = misc.masked_logits(logits, action_mask, explore=explore)

            pi = misc.gumbel_softmax_sample(
                logits,
                temperature=1.0,
                explore=explore,
            )
            if self._use_stochastic_decision:
                # Sample from the soft distribution; `.item()` keeps the return
                # type identical to the deterministic branch (a Python int).
                action = Categorical(probs=pi).sample().item()
            else:
                # Same expression as in `compute_actions`; with gradients disabled
                # it simply evaluates to the one-hot vector (up to float rounding).
                one_hot = misc.onehot_from_logits(pi)
                pi = (one_hot - pi).detach() + pi
                action = pi.argmax(-1).item()

            # `.cpu()` is a no-op for CPU tensors and makes CUDA tensors convertible.
            pi = pi.cpu().numpy()
            assert np.isclose(np.sum(pi), 1.0), (pi, logits, action_mask)
            if action_mask is not None:
                assert action_mask[action] == 1.0, (action, action_mask, pi, logits)

        return action, pi, logits.detach().cpu().numpy()

    def save(self, path, global_step=0, hard: bool = False):
        """Write the actor and critic weights to ``path``.

        The target networks are not stored; ``load`` rebuilds them from the
        online networks.

        Args:
            path: Destination file path.
            global_step (int, optional): Currently unused. Defaults to 0.
            hard (bool, optional): Overwrite an existing file. When False and the
                file already exists, a warning is logged and nothing is written.
                Defaults to False.
        """
        file_exist = os.path.exists(path)
        if file_exist:
            Log.warning("\t! detected existing model with path: {}".format(path))
            if not hard:
                return
        torch.save(
            {
                "actor": self._actor.state_dict(),
                "critic": self._critic.state_dict(),
            },
            path,
        )

    def load(self, path: str):
        """Restore actor and critic weights from ``path`` and re-sync the targets.

        Args:
            path (str): File previously written by ``save``; it must contain the
                keys ``"actor"`` and ``"critic"``.
        """
        state_dict = torch.load(path, map_location="cuda" if self.use_cuda else "cpu")
        self._actor.load_state_dict(state_dict=state_dict["actor"])
        self._critic.load_state_dict(state_dict=state_dict["critic"])
        # tau=1.0 copies the freshly loaded weights into the target networks.
        with torch.no_grad():
            misc.soft_update(self._target_actor, self._actor, tau=1.0)
            misc.soft_update(self._target_critic, self._critic, tau=1.0)
