"""
Policy function for SAC
Reference: https://github.com/haarnoja/sac/blob/master/sac/algos/sac.py
"""

from typing import Tuple
import torch
import gym
import numpy as np
import copy

import torch.nn.functional as F
from torch import nn

from expground.types import Dict, DataArray, Any, Tensor
from expground.algorithms import misc
from expground.algorithms.base_policy import Policy
from expground.common.distributions import Distribution, make_proba_distribution
from expground.algorithms.sac.config import DEFAULT_CONFIG
from expground.common.models import get_model

# Module for auto alpha.
class SAC_NN_MODULE(nn.Module):
    """Hold the log of the SAC temperature ``alpha`` as a one-element parameter.

    Args:
        init_alpha: Initial (non-log) temperature; stored as ``log(init_alpha)``.
        device: Device on which the parameter tensor is allocated.
        learnable: Any truthy value makes the parameter require gradients
            (used for automatic temperature tuning); falsy values freeze it.
    """

    def __init__(self, init_alpha, device, learnable=False):
        super().__init__()
        initial_value = torch.ones(1, device=device) * np.log(init_alpha)
        # ``bool`` collapses config values such as ``None`` to False.
        self._log_alpha = nn.Parameter(initial_value, requires_grad=bool(learnable))

    def forward(self):
        """Return the log-temperature parameter object itself (not a copy)."""
        return self._log_alpha


class SAC(Policy):
    """Soft Actor-Critic policy.

    Owns an actor, two online critics, two target critics (initialised as
    copies of the online critics) and a log-temperature module
    (``SAC_NN_MODULE``) that is trainable only when ``auto_alpha`` is set.

    For ``gym.spaces.Discrete`` action spaces the critics are built from the
    raw (observation_space, action_space) pair. For any other action space
    they are built with a ``{"observation", "action"}`` dict input space and a
    single-value ``Box(shape=(1,))`` output space.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        model_config: Dict[str, Any],
        custom_config: Dict[str, Any],
    ):
        """Build the networks and register them as policy state.

        Args:
            observation_space: Observation space handed to the model builders.
            action_space: Action space. ``Discrete`` selects the discrete
                critic layout and the discrete target-entropy default.
            model_config: Overrides merged on top of
                ``DEFAULT_CONFIG["model_config"]``.
            custom_config: Overrides merged on top of
                ``DEFAULT_CONFIG["custom_config"]``.
        """
        # Merge user overrides on top of the SAC defaults. Deep copies keep the
        # module-level DEFAULT_CONFIG from being mutated.
        _model_config = copy.deepcopy(DEFAULT_CONFIG["model_config"])
        _model_config.update(model_config)
        _custom_config = copy.deepcopy(DEFAULT_CONFIG["custom_config"])
        _custom_config.update(custom_config)

        # Pass the *merged* configs to the base class so that lookups through
        # ``self._model_config`` / ``self.custom_config`` also see the defaults.
        # Previously the merged dicts were built and then ignored here.
        super(SAC, self).__init__(
            observation_space,
            action_space,
            _model_config,
            _custom_config,
        )

        self._is_discrete_action = isinstance(action_space, gym.spaces.Discrete)

        self._actor = get_model(self._model_config["actor"])(
            observation_space=observation_space, action_space=action_space
        ).to(device=self.device)
        # ``_gen_critic`` already moves each critic to ``self.device``.
        self._critic_1 = self._gen_critic(observation_space, action_space)
        self._critic_2 = self._gen_critic(observation_space, action_space)
        self._target_critic_1 = self._gen_critic(observation_space, action_space)
        self._target_critic_2 = self._gen_critic(observation_space, action_space)

        # Target critics start as exact copies of the online critics.
        self._target_critic_1.load_state_dict(self._critic_1.state_dict())
        self._target_critic_2.load_state_dict(self._critic_2.state_dict())

        init_alpha = _model_config.get("initial_alpha")
        if init_alpha is None:
            init_alpha = 1.0
        # The log-temperature requires grad only when ``auto_alpha`` is truthy.
        self._log_alpha = SAC_NN_MODULE(
            init_alpha, self.device, _model_config.get("auto_alpha")
        )

        if _model_config.get("auto_alpha"):
            target_entropy = _model_config.get("target_entropy")
            if target_entropy is None:
                if self._is_discrete_action:
                    # 0.98 * log(n): 98% of the entropy of a uniform
                    # categorical distribution over n actions.
                    target_entropy = 0.98 * -np.log(1.0 / action_space.n)
                else:
                    # Negative number of action dimensions.
                    target_entropy = -np.prod(action_space.shape).item()
            self._target_entropy = target_entropy
        self.register_state(self._log_alpha, "_log_alpha")

        self.register_state(self._actor, "actor")
        self.register_state(self._critic_1, "critic1")
        self.register_state(self._critic_2, "critic2")
        self.register_state(self._target_critic_1, "target_critic1")
        self.register_state(self._target_critic_2, "target_critic2")

        # Prefer the correctly spelled ``use_sde`` key. Fall back to the
        # historical misspelling ``use_sed`` so existing configs keep working.
        use_sde = self.custom_config.get(
            "use_sde", self.custom_config.get("use_sed", False)
        )
        self._action_dist_handler = make_proba_distribution(
            action_space,
            use_sde=use_sde,
            dist_kwargs=self.custom_config.get("dist_kwargs", None),
        )

    def _gen_critic(self, observation_space: gym.Space, action_space: gym.Space):
        """Build one critic on ``self.device`` using ``model_config["critic"]``.

        Discrete actions: the builder receives the original spaces.
        Otherwise: the input space is a dict of observation and action, and the
        output space is an unbounded 1-dimensional ``Box``.
        """
        if self._is_discrete_action:
            return get_model(self._model_config["critic"])(
                observation_space=observation_space,
                action_space=action_space,
            ).to(device=self.device)
        return get_model(self._model_config["critic"])(
            observation_space=gym.spaces.Dict(
                {"observation": observation_space, "action": action_space}
            ),
            action_space=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(1,)),
        ).to(device=self.device)

    @property
    def is_discrete_action(self) -> bool:
        """Whether the action space is ``gym.spaces.Discrete``."""
        return self._is_discrete_action

    @property
    def action_dist_handler(self) -> Distribution:
        """Distribution handler created by ``make_proba_distribution``."""
        return self._action_dist_handler

    @property
    def critic2(self):
        """Second online critic."""
        return self._critic_2

    @property
    def critic1(self):
        """First online critic."""
        return self._critic_1

    def value_function(self, critic_state: Tensor) -> Tuple[Tensor, Tensor]:
        """Evaluate both online critics on ``critic_state``."""
        return self.critic1(critic_state), self.critic2(critic_state)

    def target_value_function(self, critic_state: Tensor) -> Tuple[Tensor, Tensor]:
        """Evaluate both target critics on ``critic_state``."""
        return self._target_critic_1(critic_state), self._target_critic_2(critic_state)

    def update_target(self, tau: float = 1.0):
        """Move both target critics towards the online critics.

        Delegates to ``misc.soft_update(target, source, tau)``. Presumably
        ``tau=1.0`` is a hard copy — confirm against ``misc.soft_update``.
        """
        misc.soft_update(self._target_critic_1, self._critic_1, tau)
        misc.soft_update(self._target_critic_2, self._critic_2, tau)

    def check_nan(self):
        """Print every entry of ``self.state_dict()`` that contains a NaN.

        Assumes ``state_dict()`` returns a two-level mapping of
        ``{state_name: {param_name: tensor}}``, as the indexing below requires.
        """
        sd = self.state_dict()
        for k in sd:
            for d in sd[k]:
                tensor = sd[k][d]
                # ``any()`` is well-defined for empty tensors; ``max()`` raises.
                if torch.isnan(tensor).any():
                    print("find nan in policy param: ", k, d, tensor.shape)

    def compute_action(
        self, observation: DataArray, action_mask: DataArray, evaluate: bool
    ):
        """Sample one action from the actor without tracking gradients.

        Args:
            observation: Input forwarded unchanged to ``self.actor``.
            action_mask: Optional numpy mask combined with the logits via
                ``misc.masked_logits``. ``None`` disables masking.
            evaluate: Its negation is passed as the explore flag of
                ``misc.masked_logits``. NOTE(review): the action is still
                *sampled* when ``evaluate`` is True — confirm whether greedy
                selection is intended.

        Returns:
            ``(action, probs)``: the sampled action as a Python scalar (via
            ``.item()``, so a single action is expected) and the distribution's
            ``probs`` as a numpy array.

        Raises:
            ValueError: if the distribution cannot be built from the logits.
                The offending logits are printed first.
        """
        with torch.no_grad():
            action_logits = self.actor(observation)
            if action_mask is not None:
                # Put the mask on the logits' device so masking also works
                # when the actor runs on a GPU.
                action_mask = torch.from_numpy(action_mask).to(action_logits.device)
                action_logits = misc.masked_logits(
                    action_logits, action_mask, not evaluate
                )
            try:
                dist = self._action_dist_handler.proba_distribution(action_logits)
            except ValueError:
                # Retrying the identical call cannot succeed; report and re-raise.
                print("catch error when compute action:", action_logits)
                raise

            prob = dist.distribution.probs
            action = self._action_dist_handler.sample()

        # ``.cpu()`` is required before ``.numpy()`` for non-CPU tensors.
        return action.detach().cpu().item(), prob.detach().cpu().numpy()

    def compute_actions(
        self,
        observation: Tensor,
        action_mask=None,
        explore: bool = False,
        with_log_pis: bool = True,
    ):
        """Compute a batch of actions (with gradients, e.g. for training).

        Only discrete action spaces are fully supported. The one-hot encoding
        and the ``log_softmax(dim=1)`` assume 2-D logits of shape
        ``(batch, n_actions)``.

        Args:
            observation: Batch forwarded unchanged to ``self.actor``.
            action_mask: Optional mask applied via ``misc.masked_logits``.
            explore: Sample stochastically when True. When False, the handler
                is asked for a deterministic action.
            with_log_pis: Also return per-action log-probabilities.

        Returns:
            ``(action, log_pi)`` if ``with_log_pis`` else ``action``. For
            discrete spaces ``action`` is a float one-hot tensor on the logits'
            device, and ``log_pi`` holds the log-probability of *every* action,
            clamped to ``[-25, 1]``.
        """
        logits = self.actor(observation)
        if action_mask is not None:
            logits = misc.masked_logits(logits, action_mask, explore=explore)
        # Only the sampled action is used. ``log_pi`` is recomputed below over
        # all actions.
        action, _ = self._action_dist_handler.log_prob_from_params(
            logits, deterministic=not explore
        )

        if self._is_discrete_action:
            index = action.reshape((-1, 1)).long()
            # Allocate directly on the logits' device. The old
            # ``scatter_(...).to(...)`` discarded the moved tensor, so the
            # one-hot always stayed on CPU.
            action1hot = torch.zeros(
                logits.shape[0], logits.shape[1], device=logits.device
            )
            action1hot.scatter_(1, index, 1.0)
            action = action1hot

        if not with_log_pis:
            return action

        log_pi = F.log_softmax(logits, dim=1)
        # Lower bound presumably guards against extreme values produced by
        # masked logits. The upper bound is inactive since log-probs are <= 0.
        log_pi = torch.clamp(log_pi, -25, 1)
        return action, log_pi
