import torch.nn.functional as F
from torch import nn
import torch

from erl_lib.util.misc import weight_init
from erl_lib.agent.module.actor.diag_gaussian import DiagGaussianActor, SquashedNormal
from erl_lib.agent.module.layer import EnsembleLinearLayer


# class TwoHeadsDGA(DiagGaussianActor):
#     """Two-heads diagonal Gaussian policy."""
#
#     def __init__(self, dim_obs, dim_act, dim_hidden, num_hidden):
#         super().__init__(dim_obs, dim_act * 2, dim_hidden, num_hidden)
#         self.info = {}
#
#     def forward(self, obs, log=False, optimistic=False):
#         mu, raw_log_std, mu_e, raw_log_std_e = self.hidden_layers(obs).chunk(4, dim=-1)
#         if optimistic:
#             # mu = mu + mu_e
#             # raw_log_std = raw_log_std + raw_log_std_e
#             mu = mu.detach() + mu_e
#             raw_log_std = raw_log_std.detach() + raw_log_std_e
#
#         std = F.softplus(raw_log_std) + 1e-5
#         dist = SquashedNormal(mu, std)
#
#         if log:
#             with torch.no_grad():
#                 q005, q095 = torch.quantile(
#                     raw_log_std,
#                     torch.tensor([0.05, 0.95], device=obs.device),
#                 )
#                 info = dict(
#                     actor_raw_logstd_005=q005,
#                     actor_raw_logstd_095=q095,
#                     actor_raw_scale=std.mean(),
#                 )
#                 if optimistic:
#                     info = {f"v_{key}": value for key, value in info.items()}
#                 self.info.update(**info)
#         return dist


class TwoHeadsDGA(DiagGaussianActor):
    """Two-heads diagonal Gaussian policy.

    A single MLP emits the parameters of two independent squashed diagonal
    Gaussian heads: a "greedy" head and an "exploration" head. Which head
    ``forward`` uses is selected by the ``exploring`` attribute
    (``False`` -> greedy head, ``True`` -> exploration head).

    Args:
        dim_obs: Size of the last dimension of the observation input.
        dim_act: Action dimensionality of each head.
        dim_hidden: Width of every hidden layer.
        num_hidden: Number of hidden ``Linear + SiLU`` layers; the first one
            is always built, so values below 1 behave like 1.
    """

    def __init__(self, dim_obs, dim_act, dim_hidden, num_hidden):
        # Call nn.Module.__init__ directly instead of DiagGaussianActor.__init__
        # so that only the network defined below is constructed.
        nn.Module.__init__(self)
        layers = [nn.Linear(dim_obs, dim_hidden), nn.SiLU()]
        for _ in range(num_hidden - 1):
            layers += [nn.Linear(dim_hidden, dim_hidden), nn.SiLU()]
        # Output layout along the last dim, each chunk ``dim_act`` wide:
        # [greedy_mu, greedy_raw_std, exp_mu, exp_raw_std].
        layers.append(nn.Linear(dim_hidden, dim_act * 4))
        self.hidden_layers = nn.Sequential(*layers)

        self.apply(weight_init)
        # Head selector read by forward(); presumably toggled by the caller
        # when collecting exploratory data — confirm.
        self.exploring = False

        # Latest logging statistics, written by forward(..., log=True).
        self.info = {}

    def forward(self, obs, log=False):
        """Return the action distribution of the currently selected head.

        Args:
            obs: Observation tensor whose last dimension is ``dim_obs``.
            log: If True, store statistics of the selected head's raw scale
                parameters in ``self.info``. When the exploration head is
                active, every key is prefixed with ``"v_"``.

        Returns:
            ``SquashedNormal(mu, softplus(raw) + 1e-5)`` for the selected head.
        """
        greedy_mu, greedy_log_std, exp_mu, exp_log_std = self.hidden_layers(obs).chunk(
            4, dim=-1
        )
        if self.exploring:
            mu, raw_log_std = exp_mu, exp_log_std
        else:
            mu, raw_log_std = greedy_mu, greedy_log_std

        # ``raw_log_std`` is a softplus pre-activation (not a log-std); the
        # small offset keeps the scale strictly positive.
        std = F.softplus(raw_log_std) + 1e-5
        dist = SquashedNormal(mu, std)

        if log:
            with torch.no_grad():
                # Quantiles over all elements. torch.quantile requires ``q`` to
                # share the input's dtype (and device); ``new_tensor`` inherits
                # both, whereas a bare torch.tensor(...) is always float32 and
                # fails for float64 inputs.
                q005, q095 = torch.quantile(
                    raw_log_std,
                    raw_log_std.new_tensor([0.05, 0.95]),
                )
                info = dict(
                    actor_raw_logstd_005=q005,
                    actor_raw_logstd_095=q095,
                    actor_raw_scale=std.mean(),
                )
                if self.exploring:
                    # Distinguish exploration-head statistics from greedy ones.
                    info = {f"v_{key}": value for key, value in info.items()}
                self.info.update(info)
        return dist
