from itertools import chain

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence

from erl_lib.agent.model_based.modules.gaussian_mlp import Model, GaussianMLP

# from erl_lib.model_based.modules.gaussian_process import GPsModel
from erl_lib.agent.module.layer import NormalizedEnsembleLinear
from erl_lib.base import (
    KEY_REW_SCALE,
    KEY_REW_SIGMA_SCALE,
    KEY_STATE_SCALE,
    KEY_SIGMA_SCALE,
    CTX_EXPLR,
    KL,
)
from erl_lib.util.misc import SymLog, soft_bound


class BehaviorModel(Model):
    """Abstract base for models that add a learned behavior head to ``Model``.

    Subclasses are expected to assign ``hidden_v_layers`` (the network whose
    parameters are exposed by :meth:`variational_parameters`) and to override
    :meth:`optimistic_scales`.
    """

    # Network producing the raw behavior scales; assigned by subclasses.
    hidden_v_layers: nn.Module

    def optimistic_scales(
        self, x, mu, logstd_epi, info, with_pred=True, log=False, **_
    ):
        """Return behavior-adjusted distribution parameters (abstract).

        NOTE(review): the subclasses in this file disagree on whether ``info``
        is accepted and on how many values are returned -- confirm the
        intended contract before relying on this signature.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def variational_mu_var_info(
        self,
        mu,
        logstd_epi,
        mu_v_scale,
        log_scale_v,
        info=None,
        with_pred=True,
    ):
        """Combine prior parameters with behavior scales (abstract).

        ``info`` defaults to ``None`` rather than a shared ``{}`` literal so that
        an implementation copying this signature cannot leak state between
        calls through a mutable default.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def variational_parameters(self, recurse: bool = True):
        """Return an iterator over the parameters of ``hidden_v_layers``."""
        return self.hidden_v_layers.parameters(recurse)


class VariationalModel(BehaviorModel):
    """Behavior model that shifts/rescales the prior epistemic Gaussian.

    ``hidden_v_layers`` is a stack of ``NormalizedEnsembleLinear`` layers built
    with a first argument of 2; its output is unpacked along the first axis
    into a raw mean scale and a raw log-std scale (presumably one per ensemble
    member -- confirm against ``NormalizedEnsembleLinear``).

    Three modes are selected by the constructor flags:

    * ``optimistic_reward_only=False``: every output dimension is adjusted
      (:meth:`_variational_model_scales`).
    * ``optimistic_reward_only=True`` and ``eta <= 0``: only the first output
      column (treated as the reward) is adjusted
      (:meth:`_variational_reward_scales`).
    * ``optimistic_reward_only=True`` and ``eta > 0``: the reward is adjusted
      variationally and the remaining columns get a bounded deterministic
      shift from ``additional_hidden_v_layers``
      (:meth:`_variational_hallucinated_scales`).

    The KL divergence between the adjusted and prior Gaussians is stored in
    ``self._info[KL]`` on every call through :meth:`optimistic_scales`.
    """

    def __init__(
        self,
        dim_input: int,
        dim_output: int,
        device,
        learned_reward=True,
        hidden_depth_v: int = 2,
        hid_size_v: int = 256,
        reward_mu_scale_bias: float = 1e-2,
        reward_ub: float = 1e4,
        std_scale_ub: float = 1.0,
        std_scale_lb: float = -4.0,
        optimistic_reward_only=False,
        exp_reward_scale: bool = True,
        eta: float = 0.0,
        **kwargs,
    ):
        """Build the behavior networks and activation helpers.

        Args:
            dim_input: Input width of the behavior networks.
            dim_output: Output width of the prior model; overridden to 1 for
                ``hidden_v_layers`` when ``optimistic_reward_only`` is set.
            device: Device the module (and ``eta``) is moved to.
            learned_reward: Forwarded to the base class; when true the first
                output column is handled as the reward.
            hidden_depth_v: Number of hidden layers in the behavior networks.
            hid_size_v: Hidden width of the behavior networks.
            reward_mu_scale_bias: Offset subtracted from the activated reward
                mean scale.
            reward_ub: Upper bound passed to ``soft_bound`` for the reward
                mean scale.
            std_scale_ub: Upper bound passed to ``soft_bound`` for the log-std
                scale.
            std_scale_lb: Lower bound passed to ``soft_bound`` for the log-std
                scale.
            optimistic_reward_only: Adjust only the reward column.
            exp_reward_scale: Use ``exp`` (true) or ``softplus`` (false) as
                the reward mean-scale activation.
            eta: Bound of the deterministic observation shift; must not be
                positive unless ``optimistic_reward_only`` is set.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: if ``eta > 0`` while ``optimistic_reward_only`` is false.
        """
        super().__init__(dim_input, dim_output, device, learned_reward=learned_reward)

        self.eta = None
        separate_model_dim_output = 0
        if optimistic_reward_only:
            if 0 < eta:
                # Variational reward params and hallucinated observations
                self.eta = torch.tensor([eta], dtype=torch.float32, device=device)
                separate_model_dim_output = dim_output - 1
            dim_output = 1
        elif 0 < eta:
            # This branch is only reached with optimistic_reward_only == False.
            raise ValueError("Eta should be 0 when optimistic_reward_only==False")

        hidden_layers = [
            NormalizedEnsembleLinear(
                2,
                dim_input,
                dim_output=hid_size_v,
                activation=nn.SiLU(),
            )
        ]
        for _ in range(hidden_depth_v - 1):
            hidden_layers.append(
                NormalizedEnsembleLinear(
                    2,
                    hid_size_v,
                    normalize_eps=1e-5,
                    activation=nn.SiLU(),
                )
            )
        hidden_layers.append(
            NormalizedEnsembleLinear(
                2,
                hid_size_v,
                dim_output=dim_output,
                normalize_eps=1e-5,
            )
        )
        self.hidden_v_layers = nn.Sequential(*hidden_layers)

        if separate_model_dim_output:
            # Plain MLP producing the deterministic shift of the non-reward columns.
            hidden_layers = [nn.Linear(dim_input, hid_size_v), nn.SiLU()]
            for _ in range(hidden_depth_v - 1):
                hidden_layers.extend(
                    [
                        nn.LayerNorm(hid_size_v, elementwise_affine=False, eps=1e-4),
                        nn.Linear(hid_size_v, hid_size_v),
                        nn.SiLU(),
                    ]
                )
            hidden_layers.extend(
                [
                    nn.LayerNorm(hid_size_v, elementwise_affine=False, eps=1e-4),
                    nn.Linear(hid_size_v, separate_model_dim_output),
                ]
            )
            self.additional_hidden_v_layers = nn.Sequential(*hidden_layers)

        self.act_obs = SymLog()
        self.act_std_scale = nn.Softplus()
        self.std_scale_ub = std_scale_ub
        self.std_scale_lb = std_scale_lb
        if std_scale_ub is not None and std_scale_lb is not None:
            assert std_scale_lb < std_scale_ub
            self.bias_std_scale = (
                0 if 0 < std_scale_ub else (std_scale_ub + std_scale_lb) * 0.5
            )
        else:
            self.bias_std_scale = 0.0

        # Inverse softplus of the bias, so softplus(0 + bias) == reward_mu_scale_bias.
        bias = np.log(np.exp(reward_mu_scale_bias) - 1)

        # Both values are read by _variational_hallucinated_scales regardless of
        # which activation is selected below, so they are always stored.
        self.reward_mu_scale_bias = reward_mu_scale_bias
        self.reward_mu_scale_log_bias = np.log(reward_mu_scale_bias)

        if exp_reward_scale:

            def act_reward_scale(raw_scale):
                mu_reward_scale = raw_scale.exp() - reward_mu_scale_bias
                return soft_bound(mu_reward_scale, ub=reward_ub)

        else:

            def act_reward_scale(raw_scale):
                mu_reward_scale = F.softplus(raw_scale + bias) - reward_mu_scale_bias
                return soft_bound(mu_reward_scale, ub=reward_ub)

        self.act_reward_scale = act_reward_scale

        self.reward_ub = reward_ub
        self.optimistic_reward_only = optimistic_reward_only

        self.to(device)

        # Diagnostics of the most recent call (KL and optional logging scalars).
        self._info = {}

    def optimistic_scales(self, x, mu, logstd_epi, log=False, with_pred=True, **_):
        """Dispatch to the scale computation selected at construction time.

        Args:
            x: Input of the behavior networks.
            mu: Prior mean.
            logstd_epi: Prior epistemic spread parameter; it is forwarded as
                ``log_var_epi`` to the model/reward variants.
            log: When true, also store summary scalars in ``self._info``.
            with_pred: Forwarded to the selected variant.

        Returns:
            A ``(mean, log-spread)`` pair. The reward-only variant returns
            ``(None, None)`` when ``with_pred`` is false.
        """
        if not self.optimistic_reward_only:
            return self._variational_model_scales(
                x, mu, logstd_epi, log=log, with_pred=with_pred
            )
        if self.eta is None:
            return self._variational_reward_scales(
                x, mu, logstd_epi, log=log, with_pred=with_pred
            )
        # The hallucinated variant needs an info dict and returns it as a third
        # element; pass the shared dict and drop it from the result so every
        # branch returns a pair.
        result = self._variational_hallucinated_scales(
            x, mu, logstd_epi, self._info, log=log, with_pred=with_pred
        )
        return result[0], result[1]

    def _variational_model_scales(
        self, x, mu, log_var_epi, log=False, with_pred=True, **_
    ):
        """Adjust mean and epistemic variance of every output column.

        The KL term is computed from detached prior parameters, so its gradient
        only reaches the behavior network.

        Returns:
            ``(mu_v, log_var_epi_v)``; built from the non-detached prior when
            ``with_pred`` is true, otherwise from the detached prior.
        """
        x = x.detach()
        raw_scales = self.hidden_v_layers(x)
        # Down-scale raw outputs so the activations start near their origin.
        mu_scale, log_std_scale = raw_scales * 0.1
        log_std_scale = soft_bound(
            log_std_scale, lb=self.std_scale_lb, ub=self.std_scale_ub
        )

        if self.learned_reward:
            mu_reward_scale, mu_obs_scale = torch.tensor_split(mu_scale, [1], dim=1)
            mu_reward_v_scale = self.act_reward_scale(mu_reward_scale)
            mu_obs_scale = self.act_obs(mu_obs_scale)
            mu_v_scale = torch.hstack([mu_reward_v_scale, mu_obs_scale])
        else:
            mu_v_scale = self.act_obs(mu_scale)

        # Combine with prior model's prediction
        log_std_epi = log_var_epi * 0.5
        std_epi = log_std_epi.exp()

        mu_detach = mu.detach()
        std_epi_detach = std_epi.detach()
        log_std_epi_detach = log_std_epi.detach()

        mu_v_detach = mu_detach + std_epi_detach * mu_v_scale
        log_std_epi_v_detach = log_std_epi_detach + log_std_scale
        std_epi_v_detach = log_std_epi_v_detach.exp()

        p_dist = Normal(mu_detach, std_epi_detach)
        q_dist = Normal(mu_v_detach, std_epi_v_detach)
        self._info[KL] = kl_divergence(q_dist, p_dist).sum(-1, keepdim=True)

        if log:
            with torch.no_grad():
                if self.learned_reward:
                    self._info[KEY_REW_SCALE] = mu_reward_v_scale.mean()
                    self._info[KEY_STATE_SCALE] = mu_obs_scale.std(0).mean()
                else:
                    self._info[KEY_STATE_SCALE] = mu_v_scale.std(0).mean()
                self._info[KEY_SIGMA_SCALE] = log_std_scale.mean()

        if with_pred:
            mu_v = mu + std_epi * mu_v_scale
            log_var_epi_v = log_var_epi + log_std_scale * 2
            return mu_v, log_var_epi_v
        return mu_v_detach, log_std_epi_v_detach * 2

    def _variational_reward_scales(
        self, x, mu, log_var_epi, log=False, with_pred=True, **_
    ):
        """Adjust only the first (reward) column; other columns pass through.

        Unlike the other variants, the raw network outputs are not multiplied
        by 0.1 here.

        Returns:
            ``(mu_v, log_var_epi_v)`` when ``with_pred`` is true, otherwise
            ``(None, None)`` (only ``self._info`` is updated).
        """
        x = x.detach()
        mu_scale_reward, log_std_scale_reward = self.hidden_v_layers(x)

        log_std_scale_reward = soft_bound(
            log_std_scale_reward, lb=self.std_scale_lb, ub=self.std_scale_ub
        )
        mu_reward_v_scale = self.act_reward_scale(mu_scale_reward)

        # Combine with prior model's prediction
        mu_reward, mu_obs = torch.tensor_split(mu, [1], dim=1)
        log_var_epi_reward, log_var_epi_obs = torch.tensor_split(
            log_var_epi, [1], dim=1
        )
        std_epi_reward = (log_var_epi_reward * 0.5).exp()

        mu_v_detach_reward = (
            mu_reward.detach() + std_epi_reward.detach() * mu_reward_v_scale
        )
        log_std_epi_v_detach_reward = (
            log_var_epi_reward.detach() * 0.5 + log_std_scale_reward
        )
        std_epi_v_detach_reward = log_std_epi_v_detach_reward.exp()

        p_dist = Normal(mu_reward.detach(), std_epi_reward.detach())
        q_dist = Normal(mu_v_detach_reward, std_epi_v_detach_reward)
        self._info[KL] = kl_divergence(q_dist, p_dist).sum(-1, keepdim=True)

        if log:
            with torch.no_grad():
                self._info[KEY_REW_SCALE] = mu_reward_v_scale.mean()
                self._info[KEY_REW_SIGMA_SCALE] = log_std_scale_reward.mean()

        if not with_pred:
            return None, None

        mu_v_reward = mu_reward + std_epi_reward * mu_reward_v_scale
        log_var_epi_v_reward = log_var_epi_reward + log_std_scale_reward * 2.0

        mu_v = torch.cat([mu_v_reward, mu_obs], dim=1)
        log_var_epi_v = torch.cat([log_var_epi_v_reward, log_var_epi_obs], dim=1)
        return mu_v, log_var_epi_v

    def _variational_hallucinated_scales(
        self, x, mu, logstd_epi, info=None, log=False, with_pred=True, **_
    ):
        """Variational reward adjustment plus a bounded observation shift.

        The reward column gets a learned mean shift and spread scale with a KL
        term; the remaining columns get a ``tanh``-bounded mean shift scaled by
        ``eta`` and a spread of ``-inf`` in log space. ``with_pred`` is accepted
        but not used.

        NOTE(review): this method exponentiates ``logstd_epi`` directly (log
        standard deviation), while the sibling variants receive the same
        caller argument as a log-variance and halve it first -- confirm which
        convention the caller uses.

        Args:
            info: Dict receiving the KL term and logging scalars; defaults to
                ``self._info`` when omitted.

        Returns:
            ``(mu_v, logstd_epi_v, info)``.
        """
        if info is None:
            info = self._info

        mu_scale_reward, log_var_scale_reward = self.hidden_v_layers(x.detach()) * 0.1
        mu_scale_obs = self.additional_hidden_v_layers(x) * 0.1

        # Log scale for reward's sigma param: softplus-based two-sided bound
        # (upper bound first, then lower bound).
        upper_bounded = self.std_scale_ub - self.act_std_scale(
            self.std_scale_ub - log_var_scale_reward
        )
        log_var_scale_reward = self.std_scale_lb + self.act_std_scale(
            upper_bounded - self.std_scale_lb
        )

        # Scale for reward's mu param
        mu_v_scale_reward = (
            mu_scale_reward + self.reward_mu_scale_log_bias
        ).exp() - self.reward_mu_scale_bias

        # Bounded scale for obs's param
        mu_v_scale_obs = torch.tanh(mu_scale_obs) * self.eta

        # Combine with prior model's prediction
        mu_reward, mu_obs = torch.tensor_split(mu, [1], dim=1)
        logstd_epi_reward, logstd_epi_obs = torch.tensor_split(logstd_epi, [1], dim=1)
        std_epi_reward = logstd_epi_reward.exp()
        std_epi_obs = logstd_epi_obs.exp()

        # Behavior mu
        mu_v_reward = mu_reward + std_epi_reward * mu_v_scale_reward
        mu_v_obs = mu_obs + std_epi_obs * mu_v_scale_obs
        mu_v = torch.cat([mu_v_reward, mu_v_obs], dim=1)
        # Behavior sigma: observations become deterministic.
        logstd_epi_v_reward = logstd_epi_reward + log_var_scale_reward
        logstd_epi_obs = torch.full_like(logstd_epi_obs, -torch.inf)
        logstd_epi_v = torch.cat([logstd_epi_v_reward, logstd_epi_obs], dim=1)

        # KL constraint only for reward, computed against the detached prior.
        mu_detach_reward = mu_reward.detach()
        std_epi_detach_reward = std_epi_reward.detach()
        mu_v_detach_reward = (
            mu_detach_reward + std_epi_detach_reward * mu_v_scale_reward
        )
        std_epi_v_detach_reward = (
            logstd_epi_reward.detach() + log_var_scale_reward
        ).exp()
        p_dist = Normal(mu_detach_reward, std_epi_detach_reward)
        q_dist = Normal(mu_v_detach_reward, std_epi_v_detach_reward)
        info[KL] = kl_divergence(q_dist, p_dist).sum(-1, keepdim=True)

        if log:
            with torch.no_grad():
                info[KEY_REW_SCALE] = mu_v_scale_reward.mean()
                info[KEY_STATE_SCALE] = mu_v_scale_obs.std(0).mean()

        return mu_v, logstd_epi_v, info

    def base_dist(self, *args, **kwargs):
        """Placeholder; currently does nothing and returns ``None``."""

    def variational_parameters(self, recurse: bool = True):
        """Return the trainable behavior parameters.

        Includes ``additional_hidden_v_layers`` when ``eta`` is set.

        Raises:
            RuntimeError: if ``eta`` is set but the additional network is missing.
        """
        params = super().variational_parameters(recurse=recurse)
        if self.eta is None:
            return params
        if not hasattr(self, "additional_hidden_v_layers"):
            raise RuntimeError(f"{self} doesn't have additional_hidden_v_layers")
        return chain(params, self.additional_hidden_v_layers.parameters())


class HallucinatedModel(BehaviorModel):
    """Behavior model that deterministically shifts the prior mean.

    A plain MLP (``hidden_v_layers``) outputs a bounded scale that moves the
    prior mean by a multiple of the prior epistemic standard deviation. With
    ``optimistic_reward_only`` only the first (reward) column is shifted.
    """

    def __init__(
        self,
        dim_input: int,
        dim_output: int,
        device,
        learned_reward=True,
        hidden_depth_v: int = 2,
        hid_size_v: int = 256,
        base_log_scale: float = 0.0,
        optimistic_reward_only: bool = False,
        eta: float = 1.0,
        **kwargs,
    ):
        """Build the behavior MLP.

        Args:
            dim_input: Input width of the MLP.
            dim_output: Output width; overridden to 1 when
                ``optimistic_reward_only`` is set.
            device: Device the module (and ``eta``) is moved to.
            learned_reward: Forwarded to the base class; when true the first
                output column is handled as the reward.
            hidden_depth_v: Number of hidden layers.
            hid_size_v: Hidden width.
            base_log_scale: When positive, multiplies the prior std to form the
                returned spread; otherwise the returned spread is ``-inf``.
            optimistic_reward_only: Shift only the reward column.
            eta: Multiplier of the mean shift in the full-model variant.
            **kwargs: Accepted and ignored.
        """
        super().__init__(dim_input, dim_output, device, learned_reward=learned_reward)

        self.optimistic_reward_only = optimistic_reward_only
        self.base_log_scale = base_log_scale
        self.eta = torch.tensor([eta], dtype=torch.float32, device=device)

        if optimistic_reward_only:
            dim_output = 1

        hidden_layers = [nn.Linear(dim_input, hid_size_v), nn.SiLU()]
        for _ in range(hidden_depth_v - 1):
            hidden_layers.extend(
                [
                    nn.LayerNorm(hid_size_v, elementwise_affine=False, eps=1e-4),
                    nn.Linear(hid_size_v, hid_size_v),
                    nn.SiLU(),
                ]
            )

        hidden_layers.extend(
            [
                nn.LayerNorm(hid_size_v, elementwise_affine=False, eps=1e-4),
                nn.Linear(hid_size_v, dim_output),
            ]
        )
        self.hidden_v_layers = nn.Sequential(*hidden_layers)
        self.act_obs = nn.Tanh()

        self.to(device)

        # Default sink for logging scalars when the caller passes no ``info``.
        self._info = {}

    def optimistic_scales(
        self,
        x,
        mu,
        log_var_epi,
        info=None,
        base_mu_scale=1,
        with_pred=True,
        log=False,
        **_,
    ):
        """Dispatch to the reward-only or full-model shift.

        Args:
            x: Input of the behavior MLP.
            mu: Prior mean.
            log_var_epi: Prior epistemic spread parameter.
            info: Dict receiving logging scalars; defaults to ``self._info``.
            base_mu_scale: Accepted but not used.
            with_pred: Accepted but not used.
            log: When true, store summary scalars in ``info``.

        Returns:
            ``(mu_v, log-spread, info)``.
        """
        if info is None:
            info = self._info
        if self.optimistic_reward_only:
            return self._optimistic_reward_scales(x, mu, log_var_epi, info, log=log)
        return self._optimistic_model_scales(x, mu, log_var_epi, info, log=log)

    def _optimistic_reward_scales(self, x, mu, logstd_epi, info, log=False, **_):
        """Shift only the reward mean; its spread becomes ``-inf`` in log space.

        NOTE(review): ``logstd_epi`` is exponentiated directly here, whereas
        ``_optimistic_model_scales`` halves the same caller argument first
        (log-variance) -- confirm the intended convention.
        """
        # Down-scale the raw output so the sigmoid starts near its linear region.
        mu_scale_reward = self.hidden_v_layers(x) * 0.1
        mu_v_scale_reward = torch.sigmoid(mu_scale_reward)

        # Combine with prior model's prediction
        mu_reward, mu_obs = torch.tensor_split(mu, [1], dim=1)
        logstd_epi_reward, logstd_epi_obs = torch.tensor_split(logstd_epi, [1], dim=1)
        std_epi_reward = logstd_epi_reward.exp()
        # Mu scale for reward deterministic function
        mu_v_reward = mu_reward + std_epi_reward * mu_v_scale_reward
        logstd_epi_v_reward = torch.full_like(logstd_epi_reward, -torch.inf)

        mu_v = torch.cat([mu_v_reward, mu_obs], dim=1)
        logstd_epi_v = torch.cat([logstd_epi_v_reward, logstd_epi_obs], dim=1)

        if log:
            with torch.no_grad():
                info[KEY_REW_SCALE] = mu_v_scale_reward.mean()
        return mu_v, logstd_epi_v, info

    def _optimistic_model_scales(self, x, mu, log_var_epi, info, log=False, **_):
        """Shift every output mean by ``eta * scale * std_epi``.

        Returns:
            ``(mu_v, log_scale, info)`` where ``log_scale`` is ``-inf``
            everywhere when ``base_log_scale <= 0`` and
            ``log(std_epi * base_log_scale)`` otherwise.
        """
        # Down-scale the raw output so the activations start near their origin.
        mu_scale = self.hidden_v_layers(x) * 0.1

        if self.learned_reward:
            mu_reward_scale, mu_obs_scale = torch.tensor_split(mu_scale, [1], dim=1)
            mu_reward_v_scale = torch.sigmoid(mu_reward_scale)
            mu_obs_v_scale = self.act_obs(mu_obs_scale)
            mu_v_scale = torch.hstack([mu_reward_v_scale, mu_obs_v_scale])

            if log:
                info[KEY_REW_SCALE] = mu_reward_v_scale.detach().mean()
                info[KEY_STATE_SCALE] = mu_obs_v_scale.detach().std(0).mean()
        else:
            mu_v_scale = self.act_obs(mu_scale)
            if log:
                info[KEY_STATE_SCALE] = mu_v_scale.detach().std(0).mean()

        # Combine with prior model's prediction
        std_epi = (log_var_epi * 0.5).exp()
        mu_v = mu + std_epi * mu_v_scale * self.eta

        if self.base_log_scale <= 0:
            # Deterministic prediction. A fresh tensor is returned instead of
            # overwriting the caller's tensor in place; the previous code also
            # referenced an undefined local name on this path.
            log_scale_epi = torch.full_like(log_var_epi, -torch.inf)
        else:
            log_scale_epi = (std_epi * self.base_log_scale).log()

        return mu_v, log_scale_epi, info


class BehaviorGaussianEnsemble(GaussianMLP):
    """Gaussian ensemble whose predictions can be adjusted by a behavior model.

    Intended to be combined with a ``BehaviorModel`` subclass that provides
    ``optimistic_scales`` and ``_info``.
    """

    def moment_dist(self, x, log=False, with_pred=True, pred_context=None, **kwargs):
        """Predict distribution parameters, optionally behavior-adjusted.

        Called inside model_env.step method and the final step at policy
        improvement.

        Args:
            x: Model input.
            log: Forwarded to ``optimistic_scales``.
            with_pred: Forwarded to ``optimistic_scales``.
            pred_context: When equal to ``CTX_EXPLR`` the prior mean and
                epistemic term are replaced by the behavior model's output.
            **kwargs: Forwarded to both the parent ``moment_dist`` and
                ``optimistic_scales``.

        Returns:
            ``(mu, log_var_epi, log_var_ale)``.
        """
        # Predict prior distribution
        mu, log_var_epi, log_var_ale = super().moment_dist(x, **kwargs)
        # Exploration by variational model
        if pred_context == CTX_EXPLR:
            # Behavior models in this file return either a pair or a triple
            # whose third element is an info dict; only the first two elements
            # are distribution parameters.
            scales = self.optimistic_scales(
                x, mu, log_var_epi, log=log, with_pred=with_pred, **kwargs
            )
            mu, log_var_epi = scales[0], scales[1]
        return mu, log_var_epi, log_var_ale

    def sample(self, *args, **kwargs):
        """Sample from the parent model, replacing its info with ``self._info``."""
        next_obs, reward, terminated, _ = super().sample(*args, **kwargs)
        return next_obs, reward, terminated, self._info


class VariationalGaussianEnsemble(BehaviorGaussianEnsemble, VariationalModel):
    """``BehaviorGaussianEnsemble`` combined with ``VariationalModel``.

    Adds no members of its own; it only fixes the method resolution order.
    """


class HallucinatedGaussianEnsemble(BehaviorGaussianEnsemble, HallucinatedModel):
    """``BehaviorGaussianEnsemble`` combined with ``HallucinatedModel``.

    Adds no members of its own; it only fixes the method resolution order.
    """
