import abc
from typing import List

import numpy as np

from rlkit.policies.base import ExplorationPolicy
from rlkit.torch.networks.stochastic.distribution_generator import DistributionGenerator
from rlkit.torch.pytorch_util import torch_ify, elem_or_tuple_to_numpy
from rlkit.torch.distributions import Delta, TanhNormal


class TorchStochasticPolicy(
    DistributionGenerator, ExplorationPolicy, metaclass=abc.ABCMeta
):
    """Exploration policy whose module call returns an action distribution.

    Subclasses provide ``forward``; ``_get_dist_from_np`` invokes the policy
    via ``self(...)`` and the returned object must support ``.sample()``.
    The numpy-facing helpers below convert observations to torch, sample,
    and convert the sampled actions back to numpy.
    """

    def get_action(self, obs_np, time_step=None):
        """Sample an action for a single, unbatched observation.

        Args:
            obs_np: One observation as a numpy array without a batch axis.
            time_step: Optional value forwarded unchanged to ``get_actions``.

        Returns:
            ``(action, info)``: the sampled action with the leading batch
            axis removed, and an empty info dict.
        """
        # Add a leading batch axis of size 1 so the batched path can be reused.
        actions = self.get_actions(obs_np[None], time_step)
        # Drop only the batch axis. ``squeeze()`` would also collapse a
        # length-1 action dimension, turning a 1-D action into a 0-d array.
        return actions[0], {}

    def get_actions(
        self,
        obs_np,
        time_step=None,
    ):
        """Sample actions for a batch of numpy observations.

        Args:
            obs_np: Batched observations (leading axis is the batch).
            time_step: Optional extra positional argument for the policy's
                forward pass; omitted entirely when ``None``.

        Returns:
            The sampled actions converted by ``elem_or_tuple_to_numpy``.
        """
        if time_step is None:
            dist = self._get_dist_from_np(obs_np)
        else:
            # Only pass time_step when supplied, presumably so forward
            # implementations that don't accept it keep working.
            dist = self._get_dist_from_np(obs_np, time_step)
        actions = dist.sample()
        return elem_or_tuple_to_numpy(actions)

    def _get_dist_from_np(self, *args, **kwargs):
        """Convert every argument with ``torch_ify`` and call the policy.

        Returns:
            Whatever ``self(...)`` returns for the converted arguments
            (expected to be a distribution exposing ``.sample()``).
        """
        torch_args = tuple(torch_ify(x) for x in args)
        torch_kwargs = {k: torch_ify(v) for k, v in kwargs.items()}
        dist = self(*torch_args, **torch_kwargs)
        return dist


class MakeDeterministic(TorchStochasticPolicy):
    """Deterministic view of a stochastic distribution generator.

    Each forward pass runs the wrapped generator, takes its
    ``mle_estimate()`` (presumably the distribution's most likely action),
    and wraps that value in a ``Delta`` so that sampling returns it exactly.
    """

    def __init__(
        self,
        action_distribution_generator: DistributionGenerator,
    ):
        super().__init__()
        # The wrapped generator that produces the underlying distribution.
        self._action_distribution_generator = action_distribution_generator

    def forward(self, *args, **kwargs):
        generator = self._action_distribution_generator
        stochastic_dist = generator.forward(*args, **kwargs)
        point_estimate = stochastic_dist.mle_estimate()
        return Delta(point_estimate)


class MakePessimisticDeterministic(TorchStochasticPolicy):
    """Policy that delegates action selection to a trainer.

    ``forward`` passes all of its arguments straight to
    ``trainer.get_pessimistic_action`` and returns the result unchanged.
    The result is expected to be a ``Delta`` distribution; nothing here
    checks that.
    """

    def __init__(self, trainer):
        super().__init__()
        # Stored by reference; the action-selection logic lives on the trainer.
        self.trainer = trainer

    def forward(self, *args, **kwargs):
        # NOTE(review): an earlier revision left `with torch.enable_grad():`
        # commented out here. If get_pessimistic_action needs gradients while
        # this runs under torch.no_grad(), it must enable them itself.
        pessimistic_action_fn = self.trainer.get_pessimistic_action
        action_dist: Delta = pessimistic_action_fn(*args, **kwargs)
        return action_dist


class MixedPolicy(TorchStochasticPolicy):
    """Policy that delegates each forward call to one of two sub-policies.

    On every ``forward`` call, ``policies[0]`` is chosen with probability
    ``expert_weight`` and ``policies[1]`` with probability
    ``1 - expert_weight``; the chosen policy's ``forward`` output is returned
    unchanged. The choice is made once per call, so every row of a batched
    input is handled by the same sub-policy.
    """

    def __init__(
        self, policies: List[TorchStochasticPolicy], expert_weight: float
    ) -> None:
        """
        Args:
            policies: Exactly two policies; index 0 is the one selected with
                probability ``expert_weight`` (presumably the expert).
            expert_weight: Probability in [0, 1] of selecting ``policies[0]``.

        Raises:
            ValueError: If ``policies`` does not contain exactly two entries
                or ``expert_weight`` is outside [0, 1]. Previously these
                cases only failed later, inside ``np.random.choice``.
        """
        if len(policies) != 2:
            raise ValueError(
                f"MixedPolicy expects exactly 2 policies, got {len(policies)}"
            )
        if not 0.0 <= expert_weight <= 1.0:
            raise ValueError(
                f"expert_weight must be in [0, 1], got {expert_weight}"
            )
        super().__init__()
        # NOTE(review): if the base class is a torch nn.Module, a plain list
        # does not register these as submodules, so their parameters are not
        # part of this module's parameters()/state_dict() and .to(device)
        # will not move them. Confirm that is intended (e.g. a frozen expert).
        self.policies = policies
        self.prob = [expert_weight, 1 - expert_weight]

    def forward(self, *args, **kwargs):
        """Pick a sub-policy according to ``self.prob`` and run its forward."""
        # A single scalar draw; consumes the same randomness as size=1.
        idx = np.random.choice(len(self.policies), p=self.prob)
        return self.policies[idx].forward(*args, **kwargs)
