from functools import partial
from rlf.args import str2bool
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import defaultdict

from rlf.policies.base_policy import ActionData
import rlf.policies.utils as putils
import rlf.rl.utils as rutils
from rlf.policies.actor_critic.base_actor_critic import ActorCritic
import numpy as np

# class DistActorCritic(ActorCritic):
#     """
#     Defines an actor/critic where the actor outputs an action distribution
#     """

#     def __init__(self,
#                  get_actor_fn=None,
#                  get_dist_fn=None,
#                  get_critic_fn=None,
#                  get_critic_head_fn=None,
#                  fuse_states=[],
#                  use_goal=False,
#                  get_base_net_fn=None,
#                  var_statedependent=False):
#         super().__init__(get_critic_fn, get_critic_head_fn, use_goal,
#                 fuse_states, get_base_net_fn)
#         """
#         - get_actor_fn: (obs_space : (int), input_shape : (int) ->
#           rlf.rl.model.BaseNet)
#         """

#         if get_actor_fn is None:
#             get_actor_fn = putils.get_def_actor
#         self.get_actor_fn = get_actor_fn

#         if get_dist_fn is None:
#             get_dist_fn = putils.get_def_dist
#         self.get_dist_fn = get_dist_fn
#         self.var_statedependent = var_statedependent

#     def init(self, obs_space, action_space, args):
#         super().init(obs_space, action_space, args)
#         self.actor = self.get_actor_fn(
#             rutils.get_obs_shape(obs_space, args.policy_ob_key),
#             self._get_base_out_shape())
#         self.dist = self.get_dist_fn(
#             self.actor.output_shape, self.action_space, var_statedependent=self.var_statedependent)

#     def get_action(self, state, add_state, hxs, masks, step_info):
#         dist, value, hxs = self.forward(state, add_state, hxs, masks)
#         if self.args.deterministic_policy:
#             action = dist.mode()
#         else:
#             action = dist.sample()

#         action_log_probs = dist.log_probs(action)
#         dist_entropy = dist.entropy()

#         return ActionData(value, action, action_log_probs, hxs, {
#             'dist_entropy': dist_entropy
#         })

#     def forward(self, state, add_state, hxs, masks):
#         base_features, hxs = self._apply_base_net(state, add_state, hxs, masks)
#         base_features = self._fuse_base_out(base_features, add_state)

#         value = self._get_value_from_features(base_features, hxs, masks)

#         actor_features, _ = self.actor(base_features, hxs, masks)
#         dist = self.dist(actor_features)

#         return dist, value, hxs

#     def evaluate_actions(self, state, add_state, hxs, masks, action):
#         dist, value, hxs = self.forward(state, add_state, hxs, masks)

#         action_log_probs = dist.log_probs(action)
#         dist_entropy = dist.entropy()
#         return {
#             'value': value,
#             'log_prob': action_log_probs,
#             'ent': dist_entropy,
#         }

#     def get_actor_params(self):
#         return super().get_actor_params() + list(self.dist.parameters())

class DistActorCritic(ActorCritic):
    """
    Defines an actor/critic where the actor outputs an action distribution.

    The actor network's features are passed through a distribution head
    (built by `get_dist_fn`) from which actions are sampled. When
    `squash_action` is enabled for a Box action space whose bounds are not
    [-1, 1], sampled actions are treated as living in [-1, 1] and are
    linearly rescaled to [low, high] before being returned to the caller.
    """

    def __init__(self,
                 get_actor_fn=None,
                 get_dist_fn=None,
                 get_critic_fn=None,
                 get_critic_head_fn=None,
                 fuse_states=None,
                 use_goal=False,
                 get_base_net_fn=None,
                 var_statedependent=False,
                 add_dropout=False):
        """
        - get_actor_fn: (obs_space : (int), input_shape : (int) ->
          rlf.rl.model.BaseNet). Defaults to `putils.get_def_actor` with
          `add_dropout` bound.
        - get_dist_fn: called in `init` as
          `get_dist_fn(actor.output_shape, action_space,
          var_statedependent=...)`. Defaults to `putils.get_def_dist`.
        - get_critic_fn, get_critic_head_fn, use_goal, fuse_states,
          get_base_net_fn, add_dropout: forwarded to `ActorCritic.__init__`.
          `fuse_states=None` is treated as an empty list.
        - var_statedependent: forwarded to `get_dist_fn` in `init`.
        """
        # Build a fresh list per instance rather than sharing one mutable
        # default object across every policy constructed with the default.
        if fuse_states is None:
            fuse_states = []
        super().__init__(get_critic_fn, get_critic_head_fn, use_goal,
                fuse_states, get_base_net_fn, add_dropout=add_dropout)

        if get_actor_fn is None:
            get_actor_fn = partial(putils.get_def_actor, add_dropout=add_dropout)
        self.get_actor_fn = get_actor_fn

        if get_dist_fn is None:
            get_dist_fn = putils.get_def_dist
        self.get_dist_fn = get_dist_fn
        self.var_statedependent = var_statedependent

    def init(self, obs_space, action_space, args):
        """
        Build the actor network and distribution head and decide whether
        actions need to be rescaled.

        Reads `args.policy_ob_key` and `args.squash_action`.
        NOTE(review): relies on `super().init` to set `self.action_space`
        (and `self.args`, used in `get_action`) -- confirm against the base
        class.
        """
        super().init(obs_space, action_space, args)
        self.actor = self.get_actor_fn(
            rutils.get_obs_shape(obs_space, args.policy_ob_key),
            self._get_base_out_shape())
        self.dist = self.get_dist_fn(
            self.actor.output_shape, self.action_space,
            var_statedependent=self.var_statedependent)

        # Rescale only when requested, the action space is a Box, and its
        # bounds are not already [-1, 1]. "Not [-1, 1]" means that at least
        # one of the two bounds differs, i.e. not (low == -1 and high == 1);
        # requiring BOTH bounds to differ would skip spaces like [-1, 2].
        # The checks short-circuit so `low`/`high` are only touched on a Box.
        is_box = self.action_space.__class__.__name__ == "Box"
        self.squash_action = bool(
            args.squash_action
            and is_box
            and not (np.allclose(self.action_space.low, -1)
                     and np.allclose(self.action_space.high, 1)))

    def _bounds_matching(self, ref):
        """
        Return the action space's `(low, high)` in the same container type
        as `ref`.

        :param ref: numpy array or torch tensor whose type (and, for
            tensors, dtype and device) the bounds should match.
        :return: Tuple `(low, high)`.
        :raises NotImplementedError: if `ref` is neither a numpy array nor a
            torch tensor.
        """
        if isinstance(ref, np.ndarray):
            return self.action_space.low, self.action_space.high
        if isinstance(ref, torch.Tensor):
            low = torch.tensor(self.action_space.low, dtype=ref.dtype, device=ref.device)
            high = torch.tensor(self.action_space.high, dtype=ref.dtype, device=ref.device)
            return low, high
        raise NotImplementedError(
            "Unsupported action type: %s" % type(ref).__name__)

    def scale_action(self, action):
        """
        Rescale the action from [low, high] to [-1, 1]
        (no need for symmetric action space)

        The mapping divides by `high - low`, so it is only well defined for
        finite bounds with `high != low`.

        :param action: Action to scale (numpy array or torch tensor)
        :return: Scaled action
        :raises NotImplementedError: for any other input type
        """
        low, high = self._bounds_matching(action)
        return 2.0 * ((action - low) / (high - low)) - 1.0

    def unscale_action(self, scaled_action):
        """
        Rescale the action from [-1, 1] to [low, high]
        (no need for symmetric action space)

        :param scaled_action: Action to un-scale (numpy array or torch tensor)
        :return: Action expressed in the action space's own bounds
        :raises NotImplementedError: for any other input type
        """
        low, high = self._bounds_matching(scaled_action)
        return low + (0.5 * (scaled_action + 1.0) * (high - low))

    def get_action(self, state, add_state, hxs, masks, step_info):
        """
        Pick an action for the given state.

        Uses the distribution's mode when `self.args.deterministic_policy`
        is set, otherwise draws a sample. `step_info` is not used here.

        :return: `ActionData` holding the value estimate, the action, its
            log-probability, the hidden state, and the distribution entropy
            under the `'dist_entropy'` key.
        """
        dist, value, hxs = self.forward(state, add_state, hxs, masks)
        if self.args.deterministic_policy:
            action = dist.mode()
        else:
            action = dist.sample()

        # Log-probabilities are computed on the raw distribution output,
        # before any rescaling; `evaluate_actions` undoes the rescaling so
        # both code paths score actions in the same space.
        action_log_probs = dist.log_probs(action)
        dist_entropy = dist.entropy()

        if self.squash_action:
            action = self.unscale_action(action)
        # When not squashing, the action is returned as sampled; it is not
        # clipped to the action-space bounds here.

        return ActionData(value, action, action_log_probs, hxs, {
            'dist_entropy': dist_entropy
        })

    def forward(self, state, add_state, hxs, masks):
        """
        Run the base network, critic and actor.

        :return: Tuple `(dist, value, hxs)` -- the action distribution, the
            critic's value estimate, and the hidden state returned by the
            base network.
        """
        base_features, hxs = self._apply_base_net(state, add_state, hxs, masks)
        base_features = self._fuse_base_out(base_features, add_state)

        value = self._get_value_from_features(base_features, hxs, masks)

        # The actor's second return value is discarded; only the base
        # network's `hxs` is propagated.
        actor_features, _ = self.actor(base_features, hxs, masks)
        dist = self.dist(actor_features)

        return dist, value, hxs

    def evaluate_actions(self, state, add_state, hxs, masks, action):
        """
        Score previously taken actions under the current policy.

        When squashing is enabled, `action` is mapped from [low, high] back
        to [-1, 1] first, mirroring the rescaling done in `get_action`.

        :return: Dict with keys `'value'`, `'log_prob'` and `'ent'`.
        """
        if self.squash_action:
            action = self.scale_action(action)

        dist, value, hxs = self.forward(state, add_state, hxs, masks)

        action_log_probs = dist.log_probs(action)
        dist_entropy = dist.entropy()
        return {
            'value': value,
            'log_prob': action_log_probs,
            'ent': dist_entropy,
        }

    def get_actor_params(self):
        """
        Return the base class's actor parameters plus those of the
        distribution head.
        """
        return super().get_actor_params() + list(self.dist.parameters())

    def get_add_args(self, parser):
        """
        Register this policy's command-line options on `parser`, after the
        base class's options.
        """
        super().get_add_args(parser)
        parser.add_argument("--squash-action",
                            type=str2bool,
                            default=False,
                            help='whether to squash the action to [-1, 1] for box action space (default: False)')
        # NOTE(review): this flag defaults to True while the constructor's
        # `var_statedependent` parameter defaults to False, and this class
        # never reads the parsed value itself -- confirm the intended
        # default with the callers that forward it.
        parser.add_argument("--var-statedependent",
                            type=str2bool,
                            default=True,
                            help='whether to use state-dependent variance for the policy (default: True)')
        

