import gym
import torch
import copy
import numpy as np
import os

from expground.types import DataArray, Dict, Any
from expground.logger import Log
from expground.algorithms import misc
from expground.algorithms.base_policy import Policy
from expground.algorithms.dqn.config import DEFAULT_CONFIG
from expground.common.models import get_model


class DQN(Policy):
    """Deep Q-Network policy over a discrete action space.

    Holds an online critic and a target critic, both built from
    ``model_config["critic"]``, and acts epsilon-greedily on the critic's
    outputs. User-supplied configs are merged over ``DEFAULT_CONFIG``.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        model_config: Dict[str, Any],
        custom_config: Dict[str, Any],
    ):
        """Build the critic / target critic and sync their weights.

        Args:
            observation_space (gym.spaces.Space): Observation space of the env.
            action_space (gym.spaces.Space): Must be ``gym.spaces.Discrete``.
            model_config (Dict[str, Any]): Overrides for
                ``DEFAULT_CONFIG["model_config"]``; its ``"critic"`` entry is
                passed to ``get_model``.
            custom_config (Dict[str, Any]): Overrides for
                ``DEFAULT_CONFIG["custom_config"]``; ``"use_cuda"`` selects the
                device.
        """
        # Deep-copy the defaults so updates never mutate the shared DEFAULT_CONFIG.
        _model_config = copy.deepcopy(DEFAULT_CONFIG["model_config"])
        _model_config.update(model_config)
        _custom_config = copy.deepcopy(DEFAULT_CONFIG["custom_config"])
        _custom_config.update(custom_config)

        super(DQN, self).__init__(
            observation_space, action_space, _model_config, _custom_config
        )

        assert isinstance(action_space, gym.spaces.Discrete)

        self._critic = get_model(self._model_config["critic"])(
            observation_space=observation_space,
            action_space=action_space,
        )
        self._target_critic = get_model(self._model_config["critic"])(
            observation_space=observation_space, action_space=action_space
        )

        self.use_cuda = self.custom_config.get("use_cuda", False)

        if self.use_cuda:
            self._critic = self._critic.to("cuda")
            self._target_critic = self._target_critic.to("cuda")

        # Exploration rate for epsilon-greedy acting; starts fully random.
        self._eps = 1.0

        # NOTE(review): a float is registered here by value; presumably
        # register_state reads the attribute by name later — verify, otherwise
        # updates through the ``eps`` setter are not reflected.
        self.register_state(self._eps, "_eps")
        self.register_state(self._critic, "critic")
        self.register_state(self._target_critic, "target_critic")

        # Start the target network as an exact (hard) copy of the critic, the
        # same sync that ``load`` performs.
        with torch.no_grad():
            misc.soft_update(self.target_critic, self.critic, tau=1.0)

    @property
    def eps(self) -> float:
        """Current epsilon used for epsilon-greedy exploration."""
        return self._eps

    @eps.setter
    def eps(self, value: float):
        self._eps = value

    def compute_action(
        self, observation: DataArray, action_mask: DataArray, evaluate: bool
    ):
        """Compute action in rollout stage. Do not support vector mode yet.

        Greedy actions are the argmax of a (masked) Gumbel-softmax over the
        critic outputs. When ``evaluate`` is False, with probability
        ``self.eps`` the whole batch instead takes uniformly random legal
        actions.

        Args:
            observation (DataArray): The observation batched data with shape=(n_batch, *obs_shape).
            action_mask (DataArray): The action mask batched with shape=(n_batch, *mask_shape);
                entries > 0 mark legal actions. May be None (all actions legal).
            evaluate (bool): Turn off exploration or not.

        Returns:
            Tuple of numpy arrays ``(actions, action_probs, logits)``.

        Raises:
            ValueError: If exploring and some mask row has no legal action.
        """
        device = "cuda" if self.use_cuda else "cpu"
        observation = torch.as_tensor(observation, device=device)

        # Only numpy outputs leave this method, so no autograd graph is needed.
        with torch.no_grad():
            logits = self.critic(observation)

            if action_mask is not None:
                # as_tensor also accepts a mask that is already a (CUDA) tensor.
                mask = torch.as_tensor(
                    action_mask, dtype=torch.float32, device=logits.device
                )
                # Validate the shape before the mask is used.
                assert mask.shape == logits.shape, (mask.shape, logits.shape)
                action_probs = misc.masked_gumbel_softmax(logits, mask)
            else:
                action_probs = misc.gumbel_softmax(logits, hard=True)

        logits_np = logits.cpu().numpy()  # cannot squeeze

        if not evaluate and np.random.random() < self.eps:
            actions, random_probs = self._random_actions(
                len(observation), action_mask
            )
            return actions, random_probs, logits_np

        actions = torch.argmax(action_probs, dim=-1)
        return (
            actions.cpu().numpy(),
            action_probs.cpu().numpy(),
            logits_np,
        )

    def _random_actions(self, n_batch: int, action_mask: DataArray):
        """Sample one uniformly random legal action for each of ``n_batch`` rows.

        Returns ``(actions, action_probs)`` numpy arrays; each probability row is
        uniform over that row's legal actions (all actions when no mask). A
        single-row mask is broadcast to every row.
        """
        n_actions = self._action_space.n
        if action_mask is None:
            probs = np.full((n_batch, n_actions), 1.0 / n_actions)
            return np.random.choice(n_actions, n_batch), probs

        if torch.is_tensor(action_mask):
            action_mask = action_mask.detach().cpu().numpy()
        legal = np.broadcast_to(np.asarray(action_mask) > 0, (n_batch, n_actions))

        actions = np.empty(n_batch, dtype=np.int64)
        probs = np.zeros((n_batch, n_actions))
        for row, row_legal in enumerate(legal):
            legal_idx = np.flatnonzero(row_legal)
            if legal_idx.size == 0:
                raise ValueError(
                    "action_mask row {} has no legal action".format(row)
                )
            actions[row] = np.random.choice(legal_idx)
            probs[row, legal_idx] = 1.0 / legal_idx.size
        return actions, probs

    def to(self, device: str):
        """Move critic and target critic to ``device`` and keep ``use_cuda`` in sync.

        ``None`` is a no-op, and so is any device string containing neither
        "cpu" nor "cuda". ``use_cuda`` follows the requested device, so the
        tensors built in ``compute_action`` / ``value_function`` land on the
        networks' device.
        """
        if device is None:
            return self
        if "cpu" in device:
            want_cuda = False
        elif "cuda" in device:
            want_cuda = True
        else:
            return self

        if want_cuda != self.use_cuda:
            self._critic = self._critic.to(device)
            self._target_critic = self._target_critic.to(device)
        self.use_cuda = want_cuda
        return self

    def parameters(self):
        """Return ``{"critic": ..., "target_critic": ...}`` parameter iterators."""
        return {
            "critic": self._critic.parameters(),
            "target_critic": self._target_critic.parameters(),
        }

    def update_parameters(self, parameter_dict):
        """Copy tensors from ``parameter_dict`` into the networks in place.

        ``parameter_dict`` needs ``"critic"`` and ``"target_critic"`` iterables
        of tensors, such as the output of ``parameters()``. Tensors are paired
        positionally; ``zip`` stops at the shorter sequence.
        """
        for tparam, sparam in zip(
            self._target_critic.parameters(), parameter_dict["target_critic"]
        ):
            tparam.data.copy_(sparam.data)
        for tparam, sparam in zip(self._critic.parameters(), parameter_dict["critic"]):
            tparam.data.copy_(sparam.data)

    def compute_actions(self, **kwargs):
        raise NotImplementedError

    def value_function(self, states, action_mask=None) -> np.ndarray:
        """Return the critic outputs for ``states`` as a numpy array.

        Args:
            states: Batched observations accepted by the critic.
            action_mask: Optional numpy index/mask; ``values[action_mask]`` is
                overwritten with -1e9.
        """
        device = "cuda" if self.use_cuda else "cpu"
        states = torch.as_tensor(states, device=device)
        with torch.no_grad():
            values = self.critic(states).cpu().numpy()
        if action_mask is not None:
            # NOTE(review): this overwrites the entries *selected* by the mask,
            # whereas compute_action treats entries > 0 as legal; an integer 0/1
            # mask would also act as fancy row indices here. Confirm the
            # convention callers use.
            values[action_mask] = -1e9
        return values

    def reset(self, **kwargs):
        """Reset critic and target critic parameters via their ``reset()``."""
        self._critic.reset()
        self._target_critic.reset()

    def save(self, path, global_step=0, hard: bool = False):
        """Save the critic's state dict to ``path``.

        If ``path`` already exists a warning is logged, and the file is only
        overwritten when ``hard`` is True. ``global_step`` is not used here.
        """
        file_exist = os.path.exists(path)
        if file_exist:
            Log.warning("\t! detected existing model with path: {}".format(path))
        if hard or not file_exist:
            torch.save(self._critic.state_dict(), path)

    def load(self, path: str):
        """Load critic weights from ``path`` and hard-copy them into the target."""
        state_dict = torch.load(path, map_location="cuda" if self.use_cuda else "cpu")
        self._critic.load_state_dict(state_dict)
        with torch.no_grad():
            misc.soft_update(self._target_critic, self._critic, tau=1.0)
