import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
from collections import deque

import exp_utils as PQ
import rl_utils


class ExplorationPolicy(nn.Module, rl_utils.BasePolicy):
    """Wrap a stochastic policy and filter its actions through ``safe_invariant``.

    For a single input state, ``forward`` draws ``n_samples`` candidate actions
    around the wrapped policy's mean with progressively smaller noise and
    returns the first candidate whose ``safe_invariant.U`` value is at most 1.
    If no candidate qualifies, it returns the action proposed by
    ``safe_invariant.policy`` instead.

    NOTE(review): ``L`` and ``U`` look like certificate values with 1 as the
    acceptance threshold -- confirm against the ``safe_invariant`` implementation.

    Bookkeeping attributes updated on every ``forward`` call:

    * ``last_L`` / ``last_U`` -- the ``L`` value of the latest state and the
      ``U`` value of the action returned for it (plain Python floats).
    * ``traj`` -- the most recent ``(L, U)`` pairs (at most 1000).
    * ``n_opt_failure`` / ``n_uncertainty_failure`` -- counters for the two
      failure categories recognised when the returned action has ``U > 1``.
    """

    def __init__(self, policy, safe_invariant, n_samples: int = 100):
        """Store the wrapped components and reset the bookkeeping state.

        Args:
            policy: Callable mapping a batch of states to a
                ``rl_utils.distributions.TanhGaussian``.
            safe_invariant: Object providing ``L(states)``,
                ``U(states, actions)`` and ``policy(state)``.
            n_samples: Number of candidate actions drawn per call.

        Raises:
            ValueError: If ``n_samples`` is smaller than 1.
        """
        super().__init__()
        if n_samples < 1:
            raise ValueError(f"n_samples must be at least 1, got {n_samples}")
        self.policy = policy
        self.safe_invariant = safe_invariant
        self.n_samples = n_samples
        self.last_L = 0
        self.last_U = 0
        self.n_opt_failure = 0
        self.n_uncertainty_failure = 0
        self.traj = deque(maxlen=1000)

    @torch.no_grad()
    def forward(self, states: torch.Tensor) -> torch.Tensor:
        """Return one action for the single state in ``states``.

        Args:
            states: Batch containing exactly one state; it is tiled with
                ``repeat([n, 1])``, so it must be 2-D.

        Returns:
            The selected candidate (``actions[index]``), or the output of
            ``safe_invariant.policy(states[0])`` when no candidate has
            ``U <= 1``.
        """
        device = states.device
        assert len(states) == 1, "ExplorationPolicy handles exactly one state per call"
        dist = self.policy(states)
        assert isinstance(dist, rl_utils.distributions.TanhGaussian)

        mean, std = dist.mean, dist.stddev
        # Keep the previous step's values: the failure report compares them
        # with the values computed for the current step.
        last_L, last_U = self.last_L, self.last_U
        current_L = self.last_L = self.safe_invariant.L(states).item()

        n = self.n_samples
        states = states.repeat([n, 1])
        # Noise scale shrinks log-uniformly from 1 to 1e-3, so later candidates
        # lie ever closer to the distribution mean.
        decay = torch.logspace(0, -3, n, base=10., device=device)
        noise = torch.randn([n, *mean.shape[1:]], device=device)
        actions = (mean + noise * std * decay[:, None]).tanh()
        all_U = self.safe_invariant.U(states, actions).detach().cpu().numpy()

        # np.min propagates NaN, so a NaN anywhere in all_U fails this test and
        # routes the call to the fallback branch.
        if np.min(all_U) <= 1:
            # np.nonzero lists indices in ascending row order, so the first
            # entry is the least-shrunk candidate that passes.
            index = int(np.nonzero(all_U <= 1)[0][0])
            action = actions[index]
            # .item() yields a Python float, matching the fallback branch and
            # keeping last_U / traj entries of a single, formattable type.
            current_U = self.last_U = all_U[index].item()
            PQ.meters['expl/backup'] += index
        else:
            action = self.safe_invariant.policy(states[0])
            current_U = self.last_U = self.safe_invariant.U(states[0], action).item()
            PQ.meters['expl/backup'] += n

        self.traj.append((current_L, current_U))

        if current_U > 1:
            self._record_failure(last_L, last_U, current_L, current_U)
        return action

    def _record_failure(self, last_L, last_U, current_L, current_U):
        """Classify a step whose returned action has ``U > 1`` and log it.

        Increments the matching failure counter; steps that fit neither
        category are labelled 'unknown' and are not logged.
        """
        if current_L < 1 < current_U:
            self.n_opt_failure += 1
            failure_type = 'compatibility'
        elif last_U < 1 < current_L:
            self.n_uncertainty_failure += 1
            failure_type = 'uncertainty'
        else:
            failure_type = 'unknown'

        if failure_type != 'unknown':
            PQ.log.info(f"[explore] can't sample an action. "
                        f"[L = {last_L:.6f}, U = {last_U:.6f}] => [L = {current_L:.6f}, U = {current_U:.6f}], "
                        f"failure type = {failure_type}")

    @rl_utils.torch_utils.maybe_numpy
    def get_actions(self, states):
        """Return the action for ``states`` by calling this module.

        NOTE(review): ``maybe_numpy`` presumably converts NumPy inputs/outputs
        to and from tensors -- confirm in ``rl_utils.torch_utils``.
        """
        return self(states)
