from typing import Tuple
import torch
import gym
import numpy as np
import copy

import torch.nn.functional as F
from torch import nn

from expground.types import Dict, DataArray, Any, Tensor
from expground.algorithms import misc
from expground.algorithms.base_policy import Policy
from expground.common.distributions import Distribution, make_proba_distribution
from expground.algorithms.ppo.config import DEFAULT_CONFIG
from expground.common.models import get_model


class PPO(Policy):
    """PPO policy with an actor, a value critic and a target value critic.

    User-supplied ``model_config`` / ``custom_config`` are merged on top of
    ``DEFAULT_CONFIG`` before being handed to the ``Policy`` base class.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        model_config: Dict[str, Any],
        custom_config: Dict[str, Any],
    ):
        """Build the actor/critic networks and the action distribution.

        Args:
            observation_space: Observation space given to actor and critics.
            action_space: Action space given to the actor and the
                action-distribution factory.
            model_config: Overrides for ``DEFAULT_CONFIG["model_config"]``;
                must resolve to ``"actor"`` and ``"critic"`` entries.
            custom_config: Overrides for ``DEFAULT_CONFIG["custom_config"]``.
        """
        # Merge overrides into deep copies so DEFAULT_CONFIG is never mutated.
        _model_config = copy.deepcopy(DEFAULT_CONFIG["model_config"])
        _model_config.update(model_config or {})
        _custom_config = copy.deepcopy(DEFAULT_CONFIG["custom_config"])
        _custom_config.update(custom_config or {})

        # Hand the *merged* configs to the base class; passing the raw ones
        # would silently discard the defaults computed just above.
        super(PPO, self).__init__(
            observation_space,
            action_space,
            _model_config,
            _custom_config,
        )

        self._is_discrete_action = isinstance(action_space, gym.spaces.Discrete)

        self.use_cuda = self.custom_config.get("use_cuda", False)

        self._actor = get_model(self._model_config["actor"])(
            observation_space=observation_space, action_space=action_space
        ).to(device=self.device)
        self._critic_1 = self._gen_critic(observation_space)
        self._target_critic_1 = self._gen_critic(observation_space)

        # The target critic starts as an exact copy of the online critic.
        self._target_critic_1.load_state_dict(self._critic_1.state_dict())

        self.register_state(self._actor, "actor")
        self.register_state(self._critic_1, "critic1")
        self.register_state(self._target_critic_1, "target_critic1")
        self.register_state(self._is_discrete_action, "is_discrete_action")

        # "use_sed" is a legacy misspelling of "use_sde"; honour either key.
        use_sde = bool(
            self.custom_config.get("use_sde", False)
            or self.custom_config.get("use_sed", False)
        )
        self._action_dist_handler = make_proba_distribution(
            action_space,
            use_sde=use_sde,
            dist_kwargs=self.custom_config.get("dist_kwargs", None),
        )

    def _gen_critic(self, observation_space: gym.Space):
        """Build a value-function network (one scalar output) on ``self.device``."""
        return get_model(self._model_config["critic"])(
            observation_space=observation_space,
            action_space=1,
        ).to(device=self.device)

    @property
    def is_discrete_action(self) -> bool:
        """Whether the action space is ``gym.spaces.Discrete``."""
        return self._is_discrete_action

    @property
    def action_dist_handler(self) -> Distribution:
        """The action-distribution helper built by ``make_proba_distribution``."""
        return self._action_dist_handler

    @property
    def critic1(self):
        """The online value critic."""
        return self._critic_1

    def value_function(self, critic_state: Tensor) -> Tuple[Tensor, Tensor]:
        """Evaluate the online critic on ``critic_state``."""
        return self.critic1(critic_state)

    def target_value_function(self, critic_state: Tensor) -> Tuple[Tensor, Tensor]:
        """Evaluate the target critic on ``critic_state``."""
        return self._target_critic_1(critic_state)

    def update_target(self, tau: float = 1.0):
        """Soft-update the target critic towards the online critic by ``tau``.

        ``tau=1.0`` (the default) is passed straight to ``misc.soft_update``;
        presumably it means a hard copy -- confirm against ``misc``.
        """
        misc.soft_update(self._target_critic_1, self._critic_1, tau)

    def compute_action(
        self, observation: DataArray, action_mask: DataArray, evaluate: bool
    ):
        """Sample a single action without tracking gradients.

        Args:
            observation: Actor input (``self.actor`` is presumably provided by
                the ``Policy`` base class).
            action_mask: Optional mask (numpy array or tensor) applied to the
                logits via ``misc.masked_logits``; ``None`` disables masking.
            evaluate: Evaluation mode; ``not evaluate`` is forwarded to
                ``misc.masked_logits`` as its third argument.

        Returns:
            ``(action, probs, logits)``: the sampled action as a Python scalar
            (``.item()`` requires a single-element action), and the action
            probabilities and logits as numpy arrays.
        """
        with torch.no_grad():
            action_logits = self.actor(observation)
            if action_mask is not None:
                # as_tensor accepts numpy arrays and tensors alike and puts the
                # mask on the same device as the logits.
                action_mask = torch.as_tensor(action_mask, device=action_logits.device)
                action_logits = misc.masked_logits(
                    action_logits, action_mask, not evaluate
                )
            dist = self._action_dist_handler.proba_distribution(action_logits)
            probs = dist.distribution.probs
            action = self._action_dist_handler.sample()

        # Move to CPU before .numpy(); .numpy() raises on CUDA tensors.
        return (
            action.detach().cpu().item(),
            probs.detach().cpu().numpy(),
            action_logits.detach().cpu().numpy(),
        )

    def actor_parameters(self):
        raise NotImplementedError

    def critic_parameters(self):
        raise NotImplementedError

    def to(self, device: str):
        """Move actor, critic and target critic to ``device``; return ``self``.

        ``use_cuda`` is set to whether ``device`` names a CUDA device, so it
        always reflects where the networks actually live.
        """
        device_str = str(device)
        self.use_cuda = "cuda" in device_str

        # nn.Module.to is a no-op when already on the target device, so move
        # unconditionally. Gating on the previous use_cuda value could leave
        # the networks stranded on CUDA after a later to("cpu").
        self._actor = self._actor.to(device)
        self._critic_1 = self._critic_1.to(device)
        self._target_critic_1 = self._target_critic_1.to(device)

        return self

    def compute_actions(
        self,
        observation: Tensor,
        action_mask=None,
        explore: bool = False,
        with_log_pis: bool = True,
    ):
        """Compute a batch of actions from ``observation`` (discrete only for now).

        Args:
            observation: Batched actor input.
            action_mask: Optional mask forwarded to ``misc.masked_logits``.
            explore: Sample stochastically when True; otherwise request the
                deterministic action from the distribution handler.
            with_log_pis: Also return log-probabilities and entropy.

        Returns:
            ``action`` (one-hot ``(batch, n_actions)`` for discrete spaces) or,
            when ``with_log_pis`` is True, ``(action, log_pi, entropy)``, where
            ``log_pi`` is the log-softmax of the logits clamped to ``[-25, 1]``.
        """
        logits = self.actor(observation)
        if action_mask is not None:
            logits = misc.masked_logits(logits, action_mask, explore=explore)
        action, _ = self._action_dist_handler.log_prob_from_params(
            logits, deterministic=not explore
        )

        if self._is_discrete_action:
            # Allocate the one-hot on the logits' device/dtype. The previous
            # code built it on CPU and discarded the result of `.to(...)`.
            index = action.reshape((-1, 1)).long()
            action = logits.new_zeros((logits.shape[0], logits.shape[1])).scatter_(
                1, index, 1.0
            )

        if not with_log_pis:
            return action

        log_pi = torch.nn.functional.log_softmax(logits, dim=1)
        # Clamp to keep very negative log-probs of masked actions finite.
        log_pi = torch.clamp(log_pi, -25, 1)
        # Presumably the entropy of the distribution set by log_prob_from_params.
        entropy = self._action_dist_handler.entropy()
        return action, log_pi, entropy
