import numpy as np
import torch
from k_level_policy_gradients.src.policy.policy import ParametricPolicy


class CEMPolicy(ParametricPolicy):
    """
    Policy that selects actions with a cross-entropy-method (CEM) search
    over the Q-values predicted by an approximator.

    At every call, ``N`` candidate actions are sampled from a diagonal
    Gaussian, squashed with ``tanh`` and scored by ``self._approximator``.
    The ``Ne`` best-scoring candidates are used to refit the mean and the
    standard deviation of the Gaussian, and the procedure is repeated
    ``max_its`` times. The best candidate of the last iteration is returned,
    with additive Gaussian noise when the policy is in ``"train"`` mode.

    NOTE(review): ``self._approximator`` is not assigned in this class; it is
    presumably provided by the base class or by the agent that owns the
    policy -- confirm against the caller.
    """

    def __init__(self, action_space, pi_sigma=0.1, N=64, Ne=6, max_its=2):
        """
        Constructor.

        Args:
            action_space: the action space; this class reads its ``shape``,
                ``low`` and ``high`` attributes and calls its ``sample()``
                method;
            pi_sigma (float, 0.1): scale of the Gaussian exploration noise
                added to the selected action in ``"train"`` mode;
            N (int, 64): number of candidate actions sampled at every CEM
                iteration;
            Ne (int, 6): number of best-scoring candidates used to refit the
                sampling distribution;
            max_its (int, 2): number of CEM iterations per drawn action.

        """
        super().__init__()

        self._action_space = action_space
        self._pi_sigma = pi_sigma
        self._N = N
        self._Ne = Ne
        self._max_its = max_its
        self._mode = "train"  # options are random, train, test

        # _pi_sigma is registered too, so that a restored policy can still
        # draw noisy actions in "train" mode.
        self._add_save_attr(
            _pi_sigma="primitive",
            _N="primitive",
            _Ne="primitive",
            _max_its="primitive",
            _action_space="mushroom",
            _mode="primitive",
        )

    def draw_action(self, state):
        """
        Draw an action for the given state.

        Args:
            state: the current state; it is converted to a float32 tensor and
                is assumed to be one-dimensional -- TODO confirm.

        Returns:
            In ``"random"`` mode, the value returned by the ``sample()``
            method of the action space. Otherwise, a numpy array holding the
            best action found by the CEM search (with added Gaussian noise in
            ``"train"`` mode), clipped to the bounds of the action space.

        Raises:
            ValueError: if the mode is not one of ``"random"``, ``"train"``
                or ``"test"``, or if ``max_its`` is smaller than 1.

        """
        if self._mode == "random":
            return self._action_space.sample()

        if self._mode not in ("train", "test"):
            raise ValueError(
                f"Unknown mode {self._mode!r}; "
                "expected 'random', 'train' or 'test'."
            )
        if self._max_its < 1:
            raise ValueError("max_its must be at least 1 to draw an action.")

        action_dim = self._action_space.shape[0]

        # The result is converted to numpy, so no gradient is ever needed.
        with torch.no_grad():
            state = torch.as_tensor(state, dtype=torch.float32)
            # The same state is paired with every candidate action.
            state_expanded = state.unsqueeze(0).expand(self._N, -1).contiguous()

            mu = torch.zeros(action_dim, dtype=torch.float32)
            std = torch.ones(action_dim, dtype=torch.float32)
            for _ in range(self._max_its):
                actions = torch.distributions.Normal(mu, std).sample((self._N,))
                actions_prime = torch.tanh(actions)
                # Assumed to hold one Q-value per candidate, i.e. to have
                # shape (N, 1) or (N,) -- TODO confirm.
                qs = self._approximator.predict(
                    state_expanded, actions_prime, output_tensor=True
                )
                # Indices are sorted by decreasing Q-value (topk default).
                _, topk_idxs = torch.topk(qs, self._Ne, dim=0)
                elite_idxs = topk_idxs.reshape(-1).long()

                # The distribution is refit on the pre-tanh samples.
                elites = actions[elite_idxs]
                mu = elites.mean(dim=0)
                std = elites.std(dim=0)

            # Best candidate of the last iteration.
            action_prime = actions_prime[elite_idxs[0]]

            if self._mode == "train":
                # add noise
                action_prime = action_prime + self._pi_sigma * torch.randn_like(
                    action_prime
                )

        # Bounds are passed positionally: the min/max keywords of np.clip are
        # only available in recent NumPy releases.
        return np.clip(
            action_prime.cpu().numpy(),
            self._action_space.low,
            self._action_space.high,
        )
