from dataclasses import dataclass
from contextlib import contextmanager, nullcontext
from hydra.utils import instantiate
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from erl_lib.agent.svg import SVGAgent, ContextModules

from erl_lib.base import KEY_CRITIC_LOSS, KEY_ACTOR_LOSS, Q_MEAN
from erl_lib.util.misc import calc_grad_norm, soft_update_params
from erl_lib.agent.model_based.modules.gaussian_mlp import PS_TS1, PS_INF
from erl_lib.agent.module.actor.two_heads import TwoHeadsDGA


# String keys for metrics stored in the agents' ``_info`` dictionaries.
# Q_EPI_STD and XTRME_Q are written/read by the agent classes in this module;
# the remaining keys are not referenced by those classes (presumably consumed
# elsewhere).
Q_EPI_STD: str = "q_epi_std"
Q_DIFF_STD: str = "q_diff_std"
V_PRIOR: str = "v_prior"
IMPRV: str = "improvement"
REL_IMPRV: str = "relative_improvement"
XVE_TARGET: str = "xve_target"
XTRME_Q: str = "q_extreme"


class VariationalCriticEnsemble(SVGAgent):
    """SVG agent whose critic ensemble is grouped per dynamics-model member.

    The critic holds ``M * C`` members, where ``M`` is ``num_members`` of the
    dynamics model and ``C`` (``self.num_critic_ensemble`` after construction)
    is the number of critics per model member. Q predictions are reshaped to
    ``[..., B, M, C]`` so each model member is paired with its own group of
    critics. A scalar ``beta`` is tracked as a moving average during actor
    updates.
    """

    def __init__(
        self,
        *args,
        target_improvement=0.1,
        beta_init=1.0,
        lr_beta=0.2,
        sample_std: bool = True,
        extreme_epi_q: bool = False,
        **kwargs,
    ):
        """
        Args:
            target_improvement: Target used to scale the beta update; clamped
                below at 1e-8 so it can safely be divided by.
            beta_init: Initial value of ``beta``.
            lr_beta: Interpolation rate of the moving-average beta update.
            sample_std: If True, ``mb_policy_evaluation`` logs the std over all
                terminal target Q values; otherwise the mean of the std taken
                along dim 1.
            extreme_epi_q: If True, ``mb_policy_evaluation`` does not log
                ``Q_EPI_STD`` (it is expected to be logged elsewhere).
            *args, **kwargs: Forwarded to ``SVGAgent``; ``kwargs`` must contain
                ``dynamics_model`` exposing ``num_members``.

        Raises:
            ValueError: If the critic ensemble size is not a multiple of the
                number of dynamics-model members.
        """
        # Set before super().__init__ -- presumably the base constructor already
        # invokes hooks (e.g. build_critics overrides) that read it.
        self.num_model_members = kwargs["dynamics_model"].num_members
        super(VariationalCriticEnsemble, self).__init__(*args, **kwargs)
        if self.num_critic_ensemble % self.num_model_members != 0:
            raise ValueError(
                f"Critic ensemble size {self.num_critic_ensemble} must be a "
                f"multiple (C x M) of the number of model members "
                f"{self.num_model_members}"
            )
        # From here on ``num_critic_ensemble`` means critics per model member (C).
        self.num_critic_ensemble = int(self.critic.num_members / self.num_model_members)
        size_exp = self.num_model_members, self.batch_size
        # Done flags broadcast over model members: [M, B].
        self.done_act = self.done.unsqueeze(0).expand(size_exp)
        self.beta = torch.tensor(beta_init, dtype=torch.float32, device=self.device)
        self.lr_beta = lr_beta
        self.sample_std = sample_std
        self.extreme_epi_q = extreme_epi_q
        self.target_improvement = max(target_improvement, 1e-8)
        # Lazily initialised running statistics used by the beta update.
        self.beta_norm = None
        self.beta_std = None

        self.init_optimizer()

    # noinspection LanguageDetectionInspection
    def eval_rollout(
        self,
        batch_sa,
        log_pis,
        batch_masks,
        rewards,
        alpha,
        last_sa,
        discounts,
        critic,
        critic_target,
    ):
        """Build multi-step value targets and masked critic predictions.

        Args:
            batch_sa: [M, (L+1)B, D]
            log_pis: [L+1, M, B]
            batch_masks: [L+1, M, B]
            rewards: [L, M, B]
            alpha: Entropy temperature (detached here).
            last_sa: [M, B, D]
            discounts: [L, L+1]
            critic: Critic producing the predictions (receives gradients).
            critic_target: Target critic used for the bootstrap value.

        Returns:
            target_values: [L, B, M, 1]
            pred_values: [L, B, M, C]
        """
        with torch.no_grad():
            q_values = critic_target(last_sa)
            target_rewards = torch.cat(
                [rewards.transpose(1, 2)[..., None], q_values]
            )  # [L+1, B, M, 1]
            # Soft (entropy-regularised) rewards.
            target_rewards.sub_(alpha.detach() * log_pis.transpose(1, 2)[..., None])
            target_values = target_rewards * batch_masks.transpose(1, 2)[..., None]
            # Row i of ``discounts`` weights the steps contributing to target i.
            target_values = torch.sum(
                discounts[..., None, None, None] * target_values[None], 1
            )

        pred_values = critic(batch_sa.detach())  # [L, B, M, C]
        horizon = discounts.shape[1] - discounts.shape[0]
        # Explicit end index: ``[:-horizon]`` would be empty when horizon == 0.
        num_pred_steps = batch_masks.shape[0] - horizon
        mask_ensemble = batch_masks[:num_pred_steps].transpose(1, 2)[..., None]
        pred_values *= mask_ensemble
        return target_values, pred_values

    def pred_q_value(self, obs_action, batch_size=None):
        """Predict Q values for every (model member, critic) pair.

        Args:
            obs_action: [M, B, D]
            batch_size: Batch size B used for the reshape; defaults to
                ``self.batch_size``.

        Returns:
            pred_q: [L, B, M, C], rescaled to real units when ``scaled_critic``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        shape = (-1, batch_size, self.num_model_members, self.num_critic_ensemble)
        # Each model member's inputs are fed to its C critics: [MC, B, D].
        obs_action = obs_action.repeat_interleave(self.num_critic_ensemble, 0)
        pred_q = self.critic(obs_action).view(shape)
        if self.scaled_critic:
            pred_q = pred_q * self._q_width + self._q_center
        return pred_q

    def pred_target_q_value(self, obs_action):
        """Target-critic Q value, reduced over the C critics of each member.

        Args:
            obs_action: [M, B, D]

        Returns:
            target_q: [1, B, M, 1], rescaled when ``scaled_critic`` and clipped
            to ``[_q_lb, _q_ub]`` when ``bounded_critic``.
        """
        # -> [MC, B, D]
        obs_action = obs_action.repeat_interleave(self.num_critic_ensemble, 0)
        shape = (1, self.batch_size, self.num_model_members, self.num_critic_ensemble)
        target_q = self.critic_target(obs_action).view(shape)
        target_q = self._reduce(target_q, dim=3, keepdim=True)
        if self.scaled_critic:
            target_q = target_q * self._q_width + self._q_center
        if self.bounded_critic:
            # Clip into [lb, ub] using relu (ub - relu(ub - q) == min(q, ub)).
            target_q = self._q_ub - torch.relu(self._q_ub - target_q)
            target_q = self._q_lb + torch.relu(target_q - self._q_lb)
        return target_q

    def pred_terminal_q(self, obs_action):
        """Terminal Q used by the actor: ``pred_q_value`` reduced over the C
        critics with ``self.actor_reduction``, with a leading axis prepended."""
        pred_qss = self.pred_q_value(obs_action)  # [L, B, M, C]
        pred_qs = self._reduce(pred_qss, self.actor_reduction, dim=3, keepdim=False)
        return pred_qs.unsqueeze(0)

    def actor_loss(
        self,
        ctx_modules,
        batch_sa,
        batch_mask,
        log_pis,
        rewards,
        log=False,
    ):
        """Compute the soft model-based actor loss, tune ``beta`` and step.

        Args:
            ctx_modules: Context bundle providing ``pred_terminal_q``,
                ``alpha``, ``actor`` and ``actor_optimizer``.
            batch_sa: Terminal state-action batch, [M, B, D].
            batch_mask: Per-step continuation masks (broadcast against the
                reward tensor).
            log_pis: [L+1, M, B]
            rewards: Per-step model rewards for the first L steps.
            log: If True, also record gradient norm and Q statistics.
        """
        pred_qs = ctx_modules.pred_terminal_q(batch_sa)
        # Append the terminal Q estimate as the last step, then subtract the
        # entropy term from every step.
        rewards = torch.cat([rewards, pred_qs])
        rewards.sub_(ctx_modules.alpha.detach() * log_pis)
        pred_qs = self.discount_mat[0, :, None, None] * rewards * batch_mask
        pred_qs = pred_qs.sum(0)

        softmax_q = pred_qs
        q_loss = -softmax_q.mean()
        loss_actor = q_loss

        entropy = -log_pis[0].detach().mean()
        self._info.update(**{KEY_ACTOR_LOSS: q_loss.detach(), "entropy": entropy})

        if self.learnable_alpha:
            alpha_loss = -ctx_modules.alpha * (self.target_entropy - entropy)
            # Out-of-place on purpose: ``q_loss.detach()`` logged above shares
            # storage with ``q_loss``, so ``+=`` would silently alter the log.
            loss_actor = loss_actor + alpha_loss

            self._info.update(
                **{
                    "alpha_loss": alpha_loss.detach(),
                    "alpha_value": ctx_modules.alpha.detach(),
                }
            )

        # Take a SGD step
        ctx_modules.actor_optimizer.zero_grad()
        loss_actor.backward()

        # Beta tuning
        with torch.no_grad():
            baseline = pred_qs.mean()
            abs_baseline = torch.abs(baseline).clamp(min=1e-5)
            if self.beta_norm is None:
                self.beta_norm = abs_baseline
            else:
                self.beta_norm.lerp_(abs_baseline, 0.1)

            # Q part of the objective only (the temperature loss is excluded).
            soft_q_mean = -q_loss
            # NOTE(review): ``softmax_q`` is ``pred_qs`` here, so this
            # improvement is identically zero; kept for logging parity.
            improvement = (soft_q_mean - baseline) / self.beta_norm
            q_var = pred_qs.var(0).mean()

            new_beta = q_var / (2 * self.target_improvement * self.beta_norm)
            self.beta = (1 - self.lr_beta) * self.beta + self.lr_beta * new_beta

        if log:
            self._info["actor_grad_norm"] = calc_grad_norm(ctx_modules.actor)
            with torch.no_grad():
                raw_q_std = q_var.sqrt()
                relative_q_std = raw_q_std / abs_baseline
                self._info.update(
                    **{
                        "soft_q_mean": soft_q_mean,
                        "improvement_v": improvement,
                        "soft_q_std": raw_q_std,
                        "relative_soft_q_std": relative_q_std,
                        "beta": self.beta,
                    }
                )

        ctx_modules.actor_optimizer.step()

    def mb_policy_evaluation(
        self,
        ctx_modules,
        obs,
        log=False,
        **kwargs,
    ):
        """Run the base model-based policy evaluation once per model member.

        ``obs`` is expanded along a new leading model-member axis and the base
        evaluation is run with ``prediction_strategy="full"``. Unless
        ``extreme_epi_q`` is set, the spread of the terminal target Q values
        is logged under ``Q_EPI_STD``. Extra ``kwargs`` are accepted but not
        forwarded.

        Returns:
            Whatever the base ``mb_policy_evaluation`` returns; element 4 is
            read here as the target Q values.
        """
        obs = obs.expand((self.num_model_members,) + obs.shape)
        returns = super().mb_policy_evaluation(
            ctx_modules,
            obs,
            done=self.done_act.clone(),
            log=log,
            prediction_strategy="full",
        )
        target_q = returns[4]
        if not self.extreme_epi_q:
            if self.sample_std:
                self._info[Q_EPI_STD] = target_q[0].detach().std()
            else:
                self._info[Q_EPI_STD] = target_q[0].detach().std(1).mean()

        return returns

    def distribution_rollout(self, **rollout_kwargs):
        """Delegate to the base rollout, forcing the ``PS_TS1`` strategy."""
        rollout_kwargs["prediction_strategy"] = PS_TS1
        # Propagate the base result like the other overrides in this class do.
        return super().distribution_rollout(**rollout_kwargs)

    def _update(self, log=False, **ctx_kwargs):
        """Delegate to the base update, forcing ``prediction_strategy="full"``."""
        return super()._update(log=log, **dict(ctx_kwargs, prediction_strategy="full"))


class VariationalValueGradient(VariationalCriticEnsemble):
    """Value-gradient agent with variational rollout rewards and terminal bonus.

    Extends :class:`VariationalCriticEnsemble`. During the actor update, each
    rollout step receives an intrinsic reward and the terminal step a Q bonus,
    both computed from the dynamics model's variational terms. Their KL parts
    enter only through their gradient (``x - x.detach()``), weighted by an
    adaptively tuned ``beta``.
    """

    trange_kv = dict(Beta="beta", **SVGAgent.trange_kv)

    def __init__(
        self,
        actor,
        *args,
        v_lr_ratio: float = 1.0,
        softpluss_beta: float = 0.0,
        beta_min: float = 1e-8,
        beta_update: str = "softmax",
        mvve_improvement: bool = False,
        max_beta_inc: float = 2.0,
        discount_beta: bool = False,
        # Unused
        detached_std: bool = False,
        state_ent_balance: float = 0.0,
        separate_actor: bool = False,
        **kwargs,
    ):
        """
        Args:
            actor: Actor module, forwarded to the base class as ``actor=``.
            v_lr_ratio: Learning-rate multiplier for the dynamics model's
                variational parameters.
            softpluss_beta: If > 0, the raw beta target is passed through
                ``softplus`` with this sharpness.
            beta_min: Lower clamp for ``beta`` and for the baseline magnitude.
            beta_update: ``"std"`` uses the smoothed Q std; any other value uses
                the improvement-based rule (with ``"relu_std"``/``"relu_square"``
                scaling it by the std / squared std).
            mvve_improvement: If True, the improvement statistics are computed
                over the whole rollout value instead of the terminal Q only.
            max_beta_inc: Caps the beta target at ``max_beta_inc * beta /
                lr_beta``; falsy disables the cap.
            discount_beta: If True, scale the raw beta target by
                ``1 - discount``.
            detached_std, state_ent_balance: Stored but unused.
            separate_actor: Ignored; recomputed from the actor type.

        Raises:
            ValueError: If ``batch_size`` is not a multiple of the number of
                dynamics-model members.
        """
        # Set before super().__init__: VariationalCriticEnsemble.__init__ calls
        # init_optimizer(), whose override below reads ``v_lr_ratio``.
        self.v_lr_ratio = v_lr_ratio
        self.beta_min = beta_min
        self.beta_update = beta_update
        self.softpluss_beta = softpluss_beta
        self.mvve_improvement = mvve_improvement
        self.max_beta_inc = max_beta_inc
        self.detached_std = detached_std
        self.discount_beta = discount_beta
        self.state_ent_balance = state_ent_balance
        self.separate_actor = separate_actor

        super().__init__(*args, actor=actor, **kwargs)

        # The constructor argument is overridden by the actual actor type.
        self.separate_actor = isinstance(self.actor, TwoHeadsDGA)
        if self.separate_actor:
            self._act = self._act_separate

        # The batch is split evenly among model members; otherwise the masks
        # below would be shorter than the batch and fail later with an
        # obscure broadcasting error.
        if self.batch_size % self.num_model_members != 0:
            raise ValueError(
                f"batch_size ({self.batch_size}) must be a multiple of the number "
                f"of model members ({self.num_model_members})"
            )

        # idx[b] = model member assigned to batch element b, as contiguous chunks.
        idx = (
            torch.arange(self.num_model_members)
            .repeat_interleave(self.batch_size // self.num_model_members)
            .to(self.device)
        )
        # One-hot member selector per step: [L+1, M, B].
        self.idx_mask = (
            F.one_hot(idx, self.num_model_members)
            .to(torch.float32)
            .t()[None]
            .repeat_interleave(self.training_rollout_horizon + 1, dim=0)
        )
        # Same selector over flattened (step, batch) rows: [M, (L+1)B, 1].
        self.sa_idx_mask = F.one_hot(
            idx.repeat(self.training_rollout_horizon + 1), self.num_model_members
        ).t()[..., None]

    def _act_separate(self, obs, sample=False):
        """Select an action with a two-headed actor.

        Sampling uses the actor with ``is_optimistic = True``; deterministic
        selection uses ``is_optimistic = False`` and returns the distribution
        mean. Returns a NumPy array.
        """
        with torch.no_grad():
            if isinstance(obs, np.ndarray):
                obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
            if self.normalize_input:
                obs = self.input_normalizer.normalize(obs)

            if sample:
                self.actor.is_optimistic = True
                dist = self.actor(obs)
                action = dist.sample()
            else:
                self.actor.is_optimistic = False
                dist = self.actor(obs)
                action = dist.mean
            action = action.cpu().numpy()
        return action

    def _actor_loss(self, ctx_modules, log_pis, batch_sa, rewards, masks, log=False):
        """Actor loss hook; always uses the single-actor formulation."""
        return self._single_actor_loss(
            ctx_modules, log_pis, batch_sa, rewards, masks, log
        )

    def _single_actor_loss(
        self, ctx_modules, log_pis, batch_sa, rewards, masks, log=False
    ):
        """Select each batch element's assigned model member, then compute the
        variational actor loss.

        Args:
            ctx_modules: Unused here.
            log_pis: [L+1, M, B]
            batch_sa: [M, (L+1)B, S+A]
            rewards: [L, M, B]
            masks: [L+1, M, B]
            log: If True, record extra statistics.

        Returns:
            Scalar actor loss.
        """
        # Summing against the one-hot masks keeps only the assigned member.
        log_pis = torch.sum(log_pis * self.idx_mask, 1)  # [L+1, B]
        masks = torch.sum(masks * self.idx_mask, 1)  # [L+1, B]
        rewards = torch.sum(rewards * self.idx_mask[:-1], 1)  # [L, B]
        sas = torch.sum(batch_sa * self.sa_idx_mask, 0)  # [(L+1)B, S+A]
        sa, last_sa = torch.tensor_split(sas, [-self.batch_size], 0)
        return self._add_x_rewards(sa, last_sa, log_pis, rewards, masks, log)

    def _add_x_rewards(self, sa, last_sa, log_pis, rewards, masks, log=False):
        """Augment rollout rewards with variational bonuses and return the loss.

        Args:
            sa: Non-terminal state-actions, [LB, D].
            last_sa: Terminal state-actions, [B, D].
            log_pis: [L+1, B]
            rewards: [L, B]
            masks: [L+1, B]
            log: If True, record bonus/KL statistics.

        Returns:
            Scalar actor loss ``-mean(discount_mat @ (augmented rewards * masks))``.
        """
        base_scale, mu_scale, kl = self.dynamics_model.internal_reward(
            sa, self.sa_idx_mask[:, : -self.batch_size, :]
        )
        reward_x = base_scale * mu_scale
        kl_term = self.beta * kl
        # Zero-valued in the forward pass; only contributes the KL gradient.
        reward_x -= kl_term - kl_term.detach()
        reward_x = (
            reward_x.view(self.training_rollout_horizon, self.batch_size) + rewards
        )

        pred_q_x, pred_q = self.pred_terminal_q(last_sa, log=log)

        reward_x = torch.cat([reward_x, pred_q_x])
        reward_x = reward_x - self.alpha.detach() * log_pis

        vmve = self.discount_mat.mm(reward_x * masks)
        actor_loss = -vmve.mean()

        with torch.no_grad():
            if self.mvve_improvement:
                if not self.separate_actor:
                    # Baseline: same rollout value without variational bonuses
                    # and without the first-step entropy term.
                    rewards = torch.cat([rewards, pred_q]).clone()
                    rewards[1:] -= self.alpha * log_pis[1:]
                    baseline = self.discount_mat.mm(rewards * masks)
                    self._info[Q_MEAN] = baseline.mean()

                # Add back the first-step entropy term to compare like with like.
                sdqe = vmve + self.alpha * log_pis[:1, :]
                self._info[XTRME_Q] = sdqe.mean()
                if self.extreme_epi_q:
                    self._info[Q_EPI_STD] = sdqe.std()

            if log:
                self._info.update(
                    **{
                        "mu_reward_scale": mu_scale.mean(),
                        "kl": kl.mean(),
                        "kl_term": kl_term.mean(),
                    }
                )
        return actor_loss

    def pred_terminal_q(self, obs_action, log=False, batch_size=None):
        """Terminal Q value with a variational bonus.

        Args:
            obs_action: Terminal state-action batch, [B, D].
            log: If True, also record bonus/KL statistics in ``self._info``.
            batch_size: Number of rows B; defaults to ``self.batch_size``.

        Returns:
            Tuple ``(q_x, q_mu)``, each of shape [1, B]: the bonus-augmented Q
            value and the member-averaged Q value (rescaled to real units when
            ``scaled_critic``).
        """
        if batch_size is None:
            batch_size = self.batch_size

        # [M, C, B]
        q_mcb = (
            self.critic(obs_action)
            .t()
            .view(self.num_model_members, self.num_critic_ensemble, batch_size)
        )
        q_mb = q_mcb.mean(1)  # [M, B]: mean over each member's critics

        q_mu = q_mb.mean(0)[..., None]  # [B, 1]
        q_std = q_mb.std(0)[..., None]  # [B, 1]: spread across model members
        q_v_scale, q_kl = self.dynamics_model.variational_terms(
            q_mu, q_std, obs_action, is_terminal=True
        )
        q_kl_term = self.beta * q_kl / (1 - self.discount)
        if self.actor_reduction == "ub":
            base_scale = q_mcb.std((0, 1))[..., None]
        else:
            base_scale = q_std

        q_bonus = base_scale * q_v_scale

        if self.scaled_critic:
            q_bonus *= self._q_width
            q_mu = q_mu * self._q_width + self._q_center
        # Zero-valued in the forward pass; only contributes the KL gradient.
        q_bonus -= q_kl_term - q_kl_term.detach()
        q_x = q_mu + q_bonus

        with torch.no_grad():
            info = {}
            if not self.mvve_improvement:
                info[XTRME_Q] = q_x.mean()
                if self.extreme_epi_q:
                    info[Q_EPI_STD] = q_x.std()
                else:
                    epi_std = q_mb.mean(0).std()
                    # Only normalised critic outputs need rescaling to real
                    # units (same guard as every other ``_q_width`` use).
                    if self.scaled_critic:
                        epi_std = epi_std * self._q_width
                    info[Q_EPI_STD] = epi_std

            if log:
                info.update(
                    **{
                        "q_bonus": q_bonus.mean(),
                        "q_scale": q_v_scale.mean(),
                        "q_kl": q_kl.mean(),
                        "q_kl_term": q_kl_term.mean(),
                    }
                )

        self._info.update(**info)
        return q_x.t(), q_mu.t()

    def update_actor(
        self,
        ctx_modules,
        batch_sa,
        batch_mask,
        log_pis,
        rewards,
        log=False,
    ):
        """Run the ``SVGAgent`` actor update, then adapt ``beta``.

        ``beta`` moves (rate ``lr_beta``) towards a target computed either from
        the smoothed Q std (``beta_update == "std"``) or from how far the
        relative improvement of ``XTRME_Q`` over the smoothed baseline exceeds
        ``target_improvement``. The result is clamped below by ``beta_min``.
        Assumes ``Q_MEAN``, ``Q_EPI_STD`` and ``XTRME_Q`` were written to
        ``self._info`` during the base update -- TODO confirm for every config.
        """
        super(VariationalCriticEnsemble, self).update_actor(
            ctx_modules, batch_sa, batch_mask, log_pis, rewards, log
        )
        # Beta tuning
        with torch.no_grad():
            baseline = self._info[Q_MEAN]
            q_std = self._info[Q_EPI_STD]
            if self.beta_norm is None:
                # Copy: the running stats are updated in place with ``lerp_``
                # and must not alias tensors stored in ``self._info``.
                self.beta_norm = baseline.detach().clone()
                self.beta_std = q_std.detach().clone()
            else:
                self.beta_norm.lerp_(baseline, 0.1)
                self.beta_std.lerp_(q_std, 0.1)

            improvement = (self._info[XTRME_Q] - baseline) / self.beta_norm.abs().clamp(
                self.beta_min
            )
            if self.beta_update == "std":
                new_beta = (
                    self.beta_std.square()
                    * (1 - self.discount)
                    * 0.5
                    / self.target_improvement
                )
            else:
                raw_beta = improvement - self.target_improvement
                if 0 < self.softpluss_beta:
                    raw_beta = F.softplus(raw_beta, beta=self.softpluss_beta)

                if self.beta_update == "relu_std":
                    raw_beta *= self.beta_std
                elif self.beta_update == "relu_square":
                    raw_beta *= self.beta_std.square()
                if self.discount_beta:
                    raw_beta *= 1 - self.discount

                # Limits how much one update can increase beta.
                max_beta = (
                    self.max_beta_inc * self.beta / self.lr_beta
                    if self.max_beta_inc
                    else None
                )
                new_beta = torch.clamp(raw_beta, min=0, max=max_beta)

            self.beta.lerp_(new_beta, self.lr_beta).clamp_(min=self.beta_min)
            # Log snapshots: ``beta`` and ``beta_norm`` are mutated in place on
            # later updates, which would otherwise rewrite the logged values.
            self._info.update(
                **{
                    "relative_q_std": self.beta_std
                    / self.beta_norm.abs().clamp(self.beta_min),
                    "improvement_v": improvement,
                    "beta": self.beta.clone(),
                    "baseline": self.beta_norm.clone(),
                }
            )

    def build_critics(self, critic_cfg):
        """Build critics via the grandparent (skipping ``SVGAgent``'s version).

        When ``weighted_critic`` is set, also draws per-(batch, member, critic)
        Poisson(1) loss weights of shape [1, B, M, C].
        """
        super(SVGAgent, self).build_critics(critic_cfg)
        num_critic_ensemble = int(self.critic.num_members / self.num_model_members)

        if self.weighted_critic:
            size = (
                self.batch_size,
                self.num_model_members,
                num_critic_ensemble,
            )
            weight_rate = torch.ones(size, device=self.device)
            self.critic_loss_weight = torch.poisson(weight_rate)[None]

    def init_optimizer(self):
        """(Re)create the actor and critic Adam optimizers.

        The actor optimizer also covers ``raw_alpha`` (at ``lr_alpha``), the
        dynamics model's variational parameters when available (at
        ``lr * v_lr_ratio``) and a ``behavior_actor`` when present. Also
        resets ``self._info``.
        """
        self._info = defaultdict(list)
        actor_params = [
            {"params": self.actor.parameters()},
            {"params": [self.raw_alpha], "lr": self.lr_alpha},
        ]
        # hasattr guard: presumably the dynamics model may not be built yet
        # when this is first called from a base constructor.
        if hasattr(self, "dynamics_model") and hasattr(
            self.dynamics_model, "variational_parameters"
        ):
            actor_params.append(
                {
                    "params": self.dynamics_model.variational_parameters(),
                    "lr": self.lr * self.v_lr_ratio,
                }
            )
        if hasattr(self, "behavior_actor"):
            actor_params.append(
                {"params": self.behavior_actor.parameters(), "lr": self.lr}
            )

        self.actor_optimizer = torch.optim.Adam(actor_params, lr=self.lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.lr)

    @contextmanager
    def policy_evaluation_context(self, detach=False, explore=False, **kwargs):
        """Yield the modules used for policy evaluation.

        Uses the rollout buffer when ``rollout_horizon > 0``, else the replay
        buffer. With ``explore=True`` the actor's ``exploring`` flag is set for
        the duration of the context; it is always reset to False on exit.
        Extra ``kwargs`` are ignored.
        """
        if 0 < self.rollout_horizon:
            buffer = self.rollout_buffer
        else:
            buffer = self.replay_buffer

        if explore:
            self.actor.exploring = True

        try:
            # Built inside ``try`` so the flag is reset even if this raises.
            context_modules = ContextModules(
                self.actor,
                self.actor_optimizer,
                self.pred_q_value,
                self.pred_target_q_value,
                self.pred_terminal_q,
                self.critic,
                self.critic_target,
                self.critic_optimizer,
                self.alpha,
                self.model_step_context,
                buffer,
                detach,
            )
            yield context_modules
        finally:
            self.actor.exploring = False
