import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal, kl_divergence

from erl_lib.agent.model_based.modules.gaussian_mlp import Model, GaussianMLP
from erl_lib.util.misc import soft_bound


class VariationalModel(GaussianMLP):
    """Base class adding auxiliary "variational" networks to a ``GaussianMLP``.

    The constructor builds ``layers_e`` (always) and ``layers_q`` (only when
    ``separate_qr`` is true) through ``self.build_model``. ``build_model`` is
    not defined in this class; the subclasses in this file provide it.
    ``_stack_layers`` is the shared helper those implementations use to
    assemble the MLP layer list.
    """

    def __init__(
        self,
        *args,
        rollout_horizon,
        separate_qr=False,
        rich_input=False,
        hidden_depth_v: int = 3,
        hid_size_v: int = 256,
        layer_norm: bool = False,
        **kwargs,
    ):
        """Build the variational network(s) on top of the parent model.

        Args:
            *args: Forwarded unchanged to the parent constructor.
            rollout_horizon: Multiplied with ``self.batch_size`` to size the
                pre-allocated ``terminal_flag`` buffer (``separate_qr=False``).
            separate_qr: If true, build a second network ``layers_q`` instead
                of appending a terminal-flag feature to the input.
            rich_input: Stored on the instance; not otherwise used here.
            hidden_depth_v: Number of hidden layers of the variational nets.
            hid_size_v: Hidden width of the variational nets.
            layer_norm: Whether the variational nets use ``nn.LayerNorm``.
                Being a named parameter, it is NOT part of ``kwargs`` and is
                therefore not forwarded to the parent constructor.
            **kwargs: Forwarded to the parent constructor. Must contain
                ``dim_input`` (read by key below).
        """
        super().__init__(*args, **kwargs)
        self.hidden_depth_v = hidden_depth_v
        self.hid_size_v = hid_size_v
        self.separate_qr = separate_qr
        self.rich_input = rich_input

        dim_input_v = kwargs["dim_input"]
        if not separate_qr:
            # One extra input feature: the terminal/non-terminal flag.
            dim_input_v += 1

        # build_model's positional order is (dim_input, hid_size,
        # hidden_depth, layer_norm). The width must be ``hid_size_v``; the
        # previous code passed ``hidden_depth_v`` twice, yielding a network
        # only ``hidden_depth_v`` units wide.
        self.layers_e = self.build_model(
            dim_input_v, hid_size_v, hidden_depth_v, layer_norm
        )
        if separate_qr:
            # ``layer_norm`` is bound to the named parameter and can never be
            # in ``kwargs``; reading ``kwargs["layer_norm"]`` raised KeyError.
            self.layers_q = self.build_model(
                dim_input_v, hid_size_v, hidden_depth_v, layer_norm
            )
        else:
            # NOTE(review): ``self.batch_size`` and ``self.device`` are
            # presumably set by the parent constructor - confirm.
            self.terminal_flag = (
                torch.ones(
                    (self.batch_size * rollout_horizon, 1),
                    dtype=torch.float32,
                    device=self.device,
                )
                * 0.5
            )

        # Extra diagnostics handed back (as a copy) by ``sample``.
        self._info = {}

    def _stack_layers(
        self,
        dim_input: int,
        dim_output: int,
        hid_size: int,
        hidden_depth: int,
        layer_norm: bool,
    ):
        """Return the list of modules forming a SiLU MLP.

        Layout: ``Linear -> SiLU``, then ``hidden_depth - 1`` repetitions of
        ``[LayerNorm] -> Linear -> SiLU``, then ``[LayerNorm] -> Linear`` to
        ``dim_output``. The ``LayerNorm`` entries (without affine parameters)
        are present only when ``layer_norm`` is true.

        Returns:
            A plain Python list of ``nn.Module`` objects (not yet wrapped in
            ``nn.Sequential`` and not moved to any device).
        """

        def maybe_norm():
            # A fresh module per call so no two positions share an instance.
            if layer_norm:
                return [nn.LayerNorm(hid_size, elementwise_affine=False)]
            return []

        hidden_layers = [nn.Linear(dim_input, hid_size), nn.SiLU()]
        for _ in range(hidden_depth - 1):
            hidden_layers.extend(maybe_norm())
            hidden_layers.extend([nn.Linear(hid_size, hid_size), nn.SiLU()])
        hidden_layers.extend(maybe_norm())
        hidden_layers.append(nn.Linear(hid_size, dim_output))
        return hidden_layers

    def sample(self, *args, **kwargs):
        """Call the parent ``sample`` and swap its 4th result for ``_info``.

        Returns:
            ``(next_obs, reward, terminated, info)`` where ``info`` is a
            shallow copy of ``self._info``.
        """
        next_obs, reward, terminated, _ = super().sample(*args, **kwargs)
        return next_obs, reward, terminated, self._info.copy()

    def variational_parameters(self, recurse: bool = True):
        """Return the parameters of the variational network(s) as a list.

        Includes ``layers_q`` only when ``separate_qr`` is true. A list is
        returned in both cases so the result can be iterated more than once.
        """
        params = list(self.layers_e.parameters(recurse))
        if self.separate_qr:
            params += list(self.layers_q.parameters(recurse))
        return params


class VariationalDiscreteModel(VariationalModel):
    """Variational model producing log-probabilities over ensemble members."""

    def build_model(
        self,
        dim_input: int,
        hid_size: int,
        hidden_depth: int,
        layer_norm: bool,
    ):
        """Build an MLP that ends in a log-softmax over ensemble members.

        The MLP emits ``num_members - 1`` values per row; the final layer
        appends one more column derived from the input times zero and
        normalises across members.

        Returns:
            An ``nn.Sequential`` moved to ``self.device``.
        """

        class SoftmaxLayer(nn.Module):
            """Append a zeroed column, transpose, log-softmax along dim 0."""

            def __init__(self, batch_size, device):
                # ``batch_size`` and ``device`` are accepted but unused.
                super().__init__()
                self.log_softmax = nn.LogSoftmax(dim=0)

            def forward(self, x):
                # Extra column built from the detached input times zero, so
                # it carries no gradient and matches x's dtype/device.
                extra_column = x.detach()[..., :1] * 0
                padded = torch.cat([x, extra_column], -1)
                # After the transpose dim 0 indexes the concatenated columns.
                return self.log_softmax(padded.t())

        modules = self._stack_layers(
            dim_input, self.num_members - 1, hid_size, hidden_depth, layer_norm
        )
        modules.append(SoftmaxLayer(self.batch_size, self.device))
        return nn.Sequential(*modules).to(self.device)

    def variational_terms(self, ensemble, x, is_terminal=False):
        """Compute the probability-weighted ensemble value and its KL term.

        Args:
            ensemble: [M, B] per-member values to be mixed.
            x: [B, D] network input; it is detached before use.
            is_terminal: Selects the terminal variant: the sign of the
                appended flag feature (+0.5 / -0.5) when ``separate_qr`` is
                false, or ``layers_q`` instead of ``layers_e`` otherwise.

        Returns:
            ``(variational_value, kl)``, each with a trailing singleton dim.
            ``kl`` is ``sum(p * log p) + log(num_members)``, i.e. the KL
            divergence from the predicted distribution to a uniform one.
        """
        if self.separate_qr:
            network = self.layers_q if is_terminal else self.layers_e
            features = x
        else:
            network = self.layers_e
            flag_value = 0.5 if is_terminal else -0.5
            flag_column = torch.full_like(x[:, :1], flag_value)
            features = torch.cat((x, flag_column), 1)

        log_probs = network(features.detach())
        probs = log_probs.exp()

        # Mixture over dim 0 of the network output.
        weighted = (probs * ensemble).sum(0)
        variational_value = weighted[..., None]

        neg_entropy = (probs * log_probs).sum(0)
        kl = neg_entropy[..., None] + np.log(self.num_members)
        return variational_value, kl

    def internal_reward(self, x: torch.Tensor, idx=None, is_terminal=False):
        """Return the variational reward value and the KL term for ``x``.

        The first channel along dim 2 of ``self.layers(x)`` is taken as the
        per-member reward prediction. The mixed value is multiplied by
        ``self.output_normalizer.std[0, 0]``. ``idx`` is accepted but unused.
        """
        member_outputs = self.layers(x)
        reward_channel = member_outputs.tensor_split([1], dim=2)[0]
        member_rewards = reward_channel.squeeze(-1)

        value, kl = self.variational_terms(member_rewards, x, is_terminal)
        reward_std = self.output_normalizer.std[0, 0]
        return value * reward_std, kl


class VariationalContinuousModel(VariationalModel):
    """Variational model predicting a mean scale and a log-std scale."""

    def __init__(
        self,
        *args,
        analytic_kl=True,
        exponential_act: bool = True,
        reward_mu_scale_bias: float = 1e-0,
        std_scale_lb: float = -4.0,
        std_scale_ub: float = 1.0,
        **kwargs,
    ):
        """Configure the KL mode, log-std bounds and mean-scale activation.

        Args:
            *args, **kwargs: Forwarded to ``VariationalModel.__init__``.
            analytic_kl: Use the closed-form KL in ``variational_terms``
                instead of ``torch.distributions.kl_divergence``.
            exponential_act: If true the mean-scale activation is
                ``exp(raw + log(bias)) - exp(std_scale_lb)``; otherwise it is
                ``soft_bound(raw + log(exp(bias) - 1), lb=0.0)``.
            reward_mu_scale_bias: The ``bias`` used in the formulas above.
            std_scale_lb: Lower bound passed to ``soft_bound`` for the
                log-std scale (also used as the offset noted above).
            std_scale_ub: Upper bound passed to ``soft_bound``.
        """
        super().__init__(*args, **kwargs)

        self.analytic_kl = analytic_kl
        self.std_scale_lb = std_scale_lb
        self.std_scale_ub = std_scale_ub

        if not exponential_act:
            shift = np.log(np.exp(reward_mu_scale_bias) - 1.0)

            def reward_scale_act(raw_scale):
                return soft_bound(raw_scale + shift, lb=0.0)

        else:
            shift = np.log(reward_mu_scale_bias)
            offset = np.exp(std_scale_lb)

            def reward_scale_act(raw_scale):
                shifted = raw_scale + shift
                return shifted.exp() - offset

        self.reward_scale_act = reward_scale_act

    def build_model(
        self,
        dim_input: int,
        hid_size: int,
        hidden_depth: int,
        layer_norm: bool,
    ):
        """Build an MLP with two outputs (raw mean scale, raw log-std scale).

        Returns:
            An ``nn.Sequential`` moved to ``self.device``.
        """
        modules = self._stack_layers(dim_input, 2, hid_size, hidden_depth, layer_norm)
        return nn.Sequential(*modules).to(self.device)

    def variational_terms(self, mu, sigma, x, is_terminal=False):
        """Compute the mean scale and the KL term of the variational posterior.

        Args:
            mu: Prior mean; only read when ``analytic_kl`` is false.
            sigma: Prior std; only read when ``analytic_kl`` is false.
            x: Network input, indexed as 2-D (``x[:, :1]``); detached before
                it is fed to the network.
            is_terminal: Selects the terminal variant: the sign of the
                appended flag feature (+0.5 / -0.5) when ``separate_qr`` is
                false, or ``layers_q`` instead of ``layers_e`` otherwise.

        Returns:
            ``(mu_scale, kl)``.
        """
        if self.separate_qr:
            network = self.layers_q if is_terminal else self.layers_e
            features = x
        else:
            network = self.layers_e
            flag_value = 0.5 if is_terminal else -0.5
            flag_column = torch.full_like(x[:, :1], flag_value)
            features = torch.cat([x, flag_column], 1)

        raw_mu_scale, raw_log_std = network(features.detach()).chunk(2, dim=1)

        log_std_scale = soft_bound(
            raw_log_std, lb=self.std_scale_lb, ub=self.std_scale_ub
        )
        mu_scale = self.reward_scale_act(raw_mu_scale)

        if not self.analytic_kl:
            prior_mu = mu.detach()
            prior_std = sigma.detach()
            posterior_mu = prior_mu + prior_std * mu_scale
            posterior_std = (prior_std.log() + log_std_scale).exp()

            prior = Normal(prior_mu, prior_std)
            posterior = Normal(posterior_mu, posterior_std)
            kl = kl_divergence(posterior, prior).sum(-1, keepdim=True)
        else:
            # Closed form of KL(N(mu_scale, exp(log_std_scale)) || N(0, 1)).
            std_scale = log_std_scale.exp()
            quadratic = mu_scale.square() + std_scale.square() - 1
            kl = quadratic * 0.5 - log_std_scale

        return mu_scale, kl

    def internal_reward(self, x: torch.Tensor, idx=None, is_terminal=False):
        """Return the base reward scale, the variational mean scale and KL.

        The std over dim 0 of the first output channel of ``self.layers(x)``
        is used as the prior std. ``idx`` multiplies the outputs before they
        are summed over dim 0; despite its ``None`` default it must be
        something that can be multiplied with a tensor.

        Returns:
            ``(base_scale, mu_v_scale, kl)`` where ``base_scale`` is
            ``self.output_normalizer.std[0, 0]`` times that prior std.
        """
        member_outputs = self.layers(x)

        reward_std_across_members = member_outputs[..., :1].std(0)
        selected = (member_outputs * idx).sum(0)
        selected_reward = selected.tensor_split([1], dim=1)[0]

        mu_v_scale, kl = self.variational_terms(
            selected_reward, reward_std_across_members, x, is_terminal
        )

        normalizer_std = self.output_normalizer.std
        base_scale = normalizer_std[0, 0] * reward_std_across_members
        return base_scale, mu_v_scale, kl


# class eXtremeModel(GaussianMLP):
#     def __init__(
#         self,
#         *args,
#         hidden_depth_v: int = 3,
#         hid_size_v: int = 256,
#         batch_size: int = 256,
#         dim_input_v: int = None,
#         reward_mu_scale_bias: float = 1e-0,
#         mu_scale_ub: float = 100.0,
#         std_scale_lb: float = -4.0,
#         std_scale_ub: float = 1.0,
#         exponential_act: bool = True,
#         layer_norm: bool = False,
#         **kwargs,
#     ):
#         super().__init__(*args, **kwargs)
#         self.hidden_depth_v = hidden_depth_v
#         self.hid_size_v = hid_size_v
#         self.batch_size = batch_size
#         self.reward_mu_scale_bias = reward_mu_scale_bias
#
#         self.mu_scale_ub = mu_scale_ub
#         self.std_scale_lb, self.std_scale_ub = std_scale_lb, std_scale_ub
#         dim_input = self.dim_input if dim_input_v is None else dim_input_v
#
#         self.layers_e = self.build_model(
#             dim_input, hidden_depth_v, hidden_depth_v, layer_norm
#         )
#
#         #
#         if exponential_act:
#             reward_mu_scale_log_bias = np.log(reward_mu_scale_bias)
#
#             def reward_scale_act(raw_scale):
#                 scale = (raw_scale + reward_mu_scale_log_bias).exp()
#                 return scale
#                 # return soft_bound(scale, ub=self.mu_scale_ub)
#
#         else:
#             reward_mu_scale_log_bias = np.log(np.exp(reward_mu_scale_bias) - 1.0)
#
#             def reward_scale_act(raw_scale):
#                 return soft_bound(
#                     raw_scale + reward_mu_scale_log_bias,
#                     lb=0.0,
#                     # ub=self.mu_scale_ub
#                 )
#
#         self.reward_scale_act = reward_scale_act
#
#         num_repeat = batch_size // self.num_members
#         idx_row = torch.arange(self.num_members, device=self.device).repeat_interleave(
#             num_repeat
#         )
#         idx_col = torch.arange(batch_size, device=self.device)
#         self.mask = torch.zeros(
#             (self.num_members, batch_size), device=self.device, dtype=torch.float32
#         )
#         self.mask[idx_row, idx_col] = 1.0
#         self.mask = self.mask[..., None].repeat_interleave(self.dim_output, 2)
#
#         self._info = {}
#
#     def build_model(
#         self,
#         dim_input: int,
#         hid_size: int,
#         hidden_depth: int,
#         layer_norm: bool,
#     ):
#         # Network for variational scale parameter for reward
#         hidden_layers = [
#             nn.Linear(dim_input, hid_size),
#             nn.SiLU(),
#         ]
#         for i in range(hidden_depth - 1):
#             if layer_norm:
#                 hidden_layers.append(
#                     nn.LayerNorm(hid_size, elementwise_affine=False),
#                 )
#             hidden_layers.extend([nn.Linear(hid_size, hid_size), nn.SiLU()])
#         if layer_norm:
#             hidden_layers.append(
#                 nn.LayerNorm(hid_size, elementwise_affine=False),
#             )
#         hidden_layers.append(nn.Linear(hid_size, 2))
#         return nn.Sequential(*hidden_layers).to(self.device)
#
#     # def forward(self, x: torch.Tensor, prediction_strategy=None, log=False, **kwargs):
#     #     if prediction_strategy == CTX_EXPLR:
#     #         return self._forward(x, log)
#     #     else:
#     #         return super().forward(x, prediction_strategy, **kwargs)
#
#     def _forward(self, x: torch.Tensor, log=False):
#         reward_x, mus, reward_scale = self.internal_reward(x)
#
#         mu = torch.sum(mus * self.mask, 0)
#         log_std = torch.sum(self._log_noise * self.mask, 0)
#
#         # self._info[EXT_REW] = reward_x
#         mu[:, :1] += reward_x
#         if log:
#             with torch.no_grad():
#                 reward_x_unormed = reward_x.mean() * self.output_normalizer.std[0, 0]
#                 self._info.update(
#                     **{
#                         "extreme_reward_scale": reward_scale.mean(),
#                         "extreme_reward": reward_x_unormed,
#                     }
#                 )
#
#         return mu, log_std.exp()
#
#     def sample(self, *args, **kwargs):
#         next_obs, reward, terminated, _ = super().sample(*args, **kwargs)
#         return next_obs, reward, terminated, self._info.copy()
#
#     def variational_parameters(self, recurse: bool = True):
#         if self.separate_qr:
#             return list(self.layers_e.parameters(recurse)) + list(
#                 self.layers_q.parameters(recurse)
#             )
#         else:
#             return self.layers_e.parameters(recurse)
