from contextlib import contextmanager, nullcontext
from collections import defaultdict, deque
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from erl_lib.agent.svg import SVGAgent, ContextModules

# from erl_lib.base import KEY_CRITIC_LOSS, KEY_ACTOR_LOSS, Q_MEAN
# from erl_lib.util.misc import calc_grad_norm, soft_update_params
# from erl_lib.agent.model_based.modules.gaussian_mlp import PS_TS1, PS_INF

from erl_lib.base import (
    OBS,
    Q_MEAN,
    Q_STD,
    KEY_ACTOR_LOSS,
    KEY_CRITIC_LOSS,
)

# Key under which the mean model state entropy is logged into ``self._info``
# by ``SoftModelValueGradient.update_actor``.
STATE_H = "state_entropy"


class SoftModelValueGradient(SVGAgent):
    """SVG agent whose model rollouts carry a learned state-entropy bonus.

    Every simulated step through :meth:`model_step_context` adds
    ``eta.detach() * state_entropy`` to the model reward, where
    ``state_entropy`` is read from ``self.dynamics_model.state_entropy``.
    The coefficient ``eta = exp(raw_eta)`` is trained with Adam after each
    actor update (see :meth:`update_actor`) according to ``eta_update``:

    * ``"opt_speed"``: ``change_degree = |Q - ema_fast| / |ema_fast - ema_slow|``
      (``Q`` is ``self._info[Q_MEAN]``); the loss
      ``(change_degree - lambda_slow * target_change_ratio) * eta`` shrinks
      ``eta`` when ``change_degree`` exceeds its target and grows it otherwise.
    * ``"opt_ratio"``: the ratio between the horizon-scaled entropy bonus and a
      value baseline (slow EMA or moving minimum of ``Q``) is compared with
      ``target_change_ratio`` through the loss ``eta * (ratio - target)``.
    * ``"opt_ratio_std"``: like ``"opt_ratio"`` with the loss multiplied by
      ``|baseline|``.
    """

    trange_kv = dict(Eta="eta", **SVGAgent.trange_kv)

    def __init__(
        self,
        *args,
        eta_min: float = 1e-8,
        propagate_ent=True,
        eta_update="opt_speed",
        lr_eta=0.01,
        lambda_slow=3e-4,
        lambda_fast=0.01,
        target_change_ratio: float = 1.0,
        dim_averaged_entropy: bool = False,
        value_baseline: bool = False,
        eta_init=None,
        moving_min: bool = False,
        only_first_entropy: bool = False,
        **kwargs,
    ):
        """Configure the entropy coefficient ``eta`` and its update rule.

        Args:
            eta_min: Lower bound for ``eta`` (``raw_eta`` is clamped to
                ``log(eta_min)`` after each step), the default initial ``eta``,
                and the floor used for denominators in the eta losses.
            propagate_ent: If True the bonus is added to the model reward as is.
                If False ``term - term.detach()`` is added instead, which leaves
                the reward value unchanged while keeping the bonus's gradient.
            eta_update: One of ``"opt_speed"``, ``"opt_ratio"``,
                ``"opt_ratio_std"``.
            lr_eta: Adam learning rate for ``raw_eta``; values ``<= 0`` skip the
                eta update in :meth:`update_actor`.
            lambda_slow: Rate of the slow EMA of the Q mean. ``int(1 / lambda_slow)``
                is also the window length when ``moving_min`` is set.
            lambda_fast: Rate of the fast EMA of the Q mean (``"opt_speed"``).
            target_change_ratio: Target of the ratio losses; the speed loss uses
                ``lambda_slow * target_change_ratio`` as its target.
            dim_averaged_entropy: Stored only; not read by the active code here.
            value_baseline: Stored only; not read by the active code here.
            eta_init: Initial ``eta``; defaults to ``eta_min``.
            moving_min: Use the minimum Q mean over a sliding window instead of
                the slow EMA as the baseline of the ratio modes.
            only_first_entropy: Use only the first collected entropy tensor for
                the eta update instead of stacking all of them.
            **kwargs: Forwarded to ``SVGAgent``; must contain ``"device"``.

        Raises:
            ValueError: If ``eta_update`` is not a supported mode.
        """
        self.eta_min = eta_min
        self.propagate_ent = propagate_ent
        if eta_update not in ("opt_speed", "opt_ratio", "opt_ratio_std"):
            raise ValueError(
                f"'eta_update' should be 'opt_speed' or 'opt_ratio' or 'opt_ratio_std' but {eta_update}"
            )
        self.eta_update = eta_update
        self.lr_eta = lr_eta
        self.target_change_ratio = target_change_ratio
        self.dim_averaged_entropy = dim_averaged_entropy
        self.value_baseline = value_baseline
        self.moving_min = moving_min
        self.only_first_entropy = only_first_entropy

        if moving_min:
            self.slow_que = deque(maxlen=int(1 / lambda_slow))

        if eta_init is None:
            eta_init = eta_min
        # eta is parameterised in log space so it stays positive.
        # Assigned before super().__init__ because init_optimizer() reads it.
        self.raw_eta = nn.Parameter(
            torch.as_tensor(
                np.log(eta_init), dtype=torch.float32, device=kwargs["device"]
            ),
            requires_grad=eta_update.startswith("opt"),
        )
        super().__init__(*args, **kwargs)

        # EMAs of the Q mean; lazily initialised on the first eta update.
        self.ema_slow = None
        self.ema_fast = None
        self.lambda_slow = lambda_slow
        self.lambda_fast = lambda_fast
        self.target_change = lambda_slow * target_change_ratio

    def model_step_context(self, action, model_state, **kwargs):
        """Step the dynamics model and add the eta-weighted state-entropy bonus.

        The entropy tensor (with a trailing singleton axis appended) is also
        recorded in ``self._stack_state_entropies`` for the next eta update.

        Returns:
            The ``(obs, reward, done, info)`` tuple from
            ``self.dynamics_model.sample``, with ``reward`` modified in place.
        """
        obs, reward, done, info = self.dynamics_model.sample(
            action, model_state, **kwargs
        )
        state_entropy = self.dynamics_model.state_entropy[..., None]
        # eta is detached: it is trained only by its own loss in update_actor.
        state_entropy_term = self.eta.detach() * state_entropy
        if self.propagate_ent:
            reward += state_entropy_term
        else:
            # Adds zero to the value but keeps the gradient path of the bonus.
            reward += state_entropy_term - state_entropy_term.detach()

        self._stack_state_entropies.append(state_entropy)
        return obs, reward, done, info

    @contextmanager
    def policy_evaluation_context(self, detach=False, **kwargs):
        """Yield the ``ContextModules`` used for policy evaluation.

        Uses the rollout buffer when ``rollout_horizon > 0`` and the replay
        buffer otherwise. Also resets the per-evaluation entropy stack that
        :meth:`model_step_context` fills.
        """
        buffer = self.rollout_buffer if 0 < self.rollout_horizon else self.replay_buffer

        self._stack_state_entropies = []

        yield ContextModules(
            self.actor,
            self.actor_optimizer,
            self.pred_q_value,
            self.pred_target_q_value,
            self.pred_terminal_q,
            self.critic,
            self.critic_target,
            self.critic_optimizer,
            self.alpha,
            self.model_step_context,
            buffer,
            detach,
        )

    def update_actor(
        self,
        ctx_modules,
        batch_sa,
        batch_mask,
        log_pis,
        rewards,
        log=False,
    ):
        """Run the parent actor update, then take one Adam step on ``raw_eta``.

        Uses the entropies collected by :meth:`model_step_context` and the Q
        mean stored in ``self._info[Q_MEAN]``. Loss diagnostics and the current
        ``eta`` are merged into ``self._info``.
        """
        super().update_actor(ctx_modules, batch_sa, batch_mask, log_pis, rewards, log)
        if not 0.0 < self.lr_eta:
            # eta adaptation disabled.
            return

        info = {}
        with torch.no_grad():
            if self.only_first_entropy:
                # Detach so the eta loss (built under enable_grad) cannot
                # backpropagate into the model/actor graph.
                state_entropy = self._stack_state_entropies[0].detach()
            else:
                # vstack under no_grad already yields a graph-free tensor.
                state_entropy = torch.vstack(self._stack_state_entropies)
            state_mean_entropy = torch.mean(state_entropy)
            # Logged bonus, evaluated with eta *before* this step's update.
            state_entropy_term = self.eta * state_mean_entropy
            # The eta losses must only differentiate w.r.t. raw_eta.
            current_value = self._info[Q_MEAN].detach()

            if self.eta_update == "opt_speed":
                eta_loss = self._eta_speed_step(current_value, info)
            else:
                eta_loss = self._eta_ratio_step(state_entropy, current_value, info)

            self.eta_optimizer.zero_grad()
            eta_loss.backward()
            self.eta_optimizer.step()

            info.update(
                {
                    STATE_H: state_mean_entropy,
                    "eta_loss": eta_loss.detach(),
                    "state_entropy_term": state_entropy_term,
                    "state_entropy_std": state_entropy.std(),
                }
            )

            # Enforce eta >= eta_min (in-place on the leaf is fine under no_grad).
            self.raw_eta.clamp_(min=np.log(self.eta_min))
        self._info.update(**info, eta=self.eta.detach())

    def _eta_speed_step(self, current_value, info):
        """Update both Q-mean EMAs and return the ``"opt_speed"`` eta loss."""
        if self.ema_slow is None:
            self.ema_slow = current_value.clone()
            self.ema_fast = current_value.clone()
        else:
            self.ema_slow.lerp_(current_value, self.lambda_slow)
            self.ema_fast.lerp_(current_value, self.lambda_fast)

        # Gap between the two EMAs, floored to avoid division by zero.
        norm = torch.abs(self.ema_fast - self.ema_slow).clamp(min=self.eta_min)
        change_degree = torch.abs(current_value - self.ema_fast) / norm
        info["change_degree"] = change_degree
        return self.eta_speed_loss(change_degree, self.raw_eta)

    def _eta_ratio_step(self, state_entropy, current_value, info):
        """Update the value baseline and return the ratio-mode eta loss."""
        if self.ema_slow is None:
            self.ema_slow = current_value.clone()
            if self.moving_min:
                # Keep the first observation in the moving-min window too.
                self.slow_que.append(current_value)
            baseline = current_value
        elif self.moving_min:
            self.slow_que.append(current_value)
            baseline = min(self.slow_que)
        else:
            self.ema_slow.lerp_(current_value, self.lambda_slow)
            baseline = self.ema_slow

        eta_loss, entropy_term_ratio = self.eta_relative_loss(
            state_entropy, baseline, self.raw_eta
        )
        if self.eta_update == "opt_ratio_std":
            scale = baseline.abs().clamp(min=self.eta_min)
            # The caller runs under no_grad; re-enable grad so the scaling is
            # recorded and actually reaches raw_eta's gradient.
            with torch.enable_grad():
                eta_loss = eta_loss * scale
        info["state_entropy_ratio"] = entropy_term_ratio
        return eta_loss

    @torch.enable_grad()
    def eta_speed_loss(self, change_degree, raw_eta):
        """Return ``(change_degree - target_change) * exp(raw_eta)``.

        Its gradient w.r.t. ``raw_eta`` has the sign of
        ``change_degree - target_change``, so descent shrinks eta when the
        change exceeds the target.
        """
        return (change_degree - self.target_change) * raw_eta.exp()

    @torch.enable_grad()
    def eta_relative_loss(self, state_entropy_term, value, raw_eta):
        """Return the ratio loss and the (detached) entropy-to-value ratio.

        ``state_entropy_term`` receives the raw state entropy; it is scaled by
        the detached ``eta`` and by ``1 / (1 - discount)``. The ratio of its
        mean magnitude to ``|value|`` (floored at ``eta_min``) is compared with
        ``target_change_ratio`` through the loss ``eta * (ratio - target)``.
        """
        eta = raw_eta.exp()
        entropy_term = eta.detach() * state_entropy_term / (1 - self.discount)
        denom = torch.abs(value).clamp_min(self.eta_min)
        entropy_term_ratio = torch.abs(entropy_term).mean() / denom
        return (
            eta * (entropy_term_ratio - self.target_change_ratio),
            entropy_term_ratio.detach(),
        )

    def init_optimizer(self):
        """Create the parent optimizers plus an Adam optimizer for ``raw_eta``."""
        super().init_optimizer()
        if self.eta_update.startswith("opt"):
            self.eta_optimizer = torch.optim.Adam([self.raw_eta], lr=self.lr_eta)

    @property
    def eta(self):
        """Current entropy coefficient ``exp(raw_eta)`` (graph-attached)."""
        return self.raw_eta.exp()

    # def pred_target_q_value(self, obs_action):
    #     target_q = super().pred_target_q_value(obs_action)
    #     if self.propagate_ent:
    #         self.dynamics_model.forward(obs_action)
    #         state_entropy = self.dynamics_model.state_entropy.sum(-1, keepdim=True)
    #         if self.dim_averaged_entropy:
    #             state_entropy /= self.dim_obs + 1
    #         target_q += self.eta.detach() * state_entropy
    #     return target_q

    # def pred_terminal_q(self, obs_action):
    #     terminal_q = super().pred_terminal_q(obs_action)
    #     self.dynamics_model.forward(obs_action)
    #     state_entropy = self.dynamics_model.state_entropy.sum(-1, keepdim=True)
    #     if self.dim_averaged_entropy:
    #         state_entropy /= self.dim_obs + 1
    #     state_entropy_term = self.eta.detach() * state_entropy.t()
    #     if self.propagate_ent:
    #         terminal_q += state_entropy_term
    #     else:
    #         terminal_q += state_entropy_term - state_entropy_term.detach()
    #
    #     self._stack_state_entropies.append(state_entropy)
    #     self._info.update(
    #         **{
    #             STATE_H: state_entropy.detach().mean(),
    #             "state_entropy_term": state_entropy_term.detach().mean(),
    #         }
    #     )
    #     return terminal_q

    # def rollout(
    #     self,
    #     ctx_modules,
    #     obs,
    #     action,
    #     rollout_horizon: int,
    #     done=None,
    #     log_pi=None,
    #     log=False,
    #     prediction_strategy=None,
    #     **kwargs,
    # ):
    #     if log_pi is not None:
    #         obs_action = torch.cat([obs, action], -1)
    #         self.dynamics_model.forward(obs_action)
    #         state_entropy = self.dynamics_model.state_entropy.sum(-1)
    #         if self.dim_averaged_entropy:
    #             state_entropy /= self.dim_obs + 1
    #         # log_pi += self.eta.detach() * state_entropy
    #         self._stack_state_entropies.append(state_entropy)
    #
    #     return super().rollout(
    #         ctx_modules,
    #         obs,
    #         action,
    #         rollout_horizon,
    #         done,
    #         log_pi,
    #         log,
    #         prediction_strategy,
    #         **kwargs,
    #     )

    # def model_step_context(self, action, obs, **kwargs):
    #     # if action is None:
    #     #     action, log_pi, pi = self.sample_action(self.actor, obs, **kwargs)
    #     # else:
    #     #     log_pi = None
    #     obs, reward, done, info = self.dynamics_model.sample(action, obs, **kwargs)
    #     # Additional state entropy
    #     state_entropy = self.dynamics_model.state_entropy.sum(-1)
    #     if self.dim_averaged_entropy:
    #         state_entropy /= self.dim_obs + 1
    #     # state_entropy_term = self.eta.detach() * state_entropy
    #     # if self.propagate_ent:
    #     #     reward += state_entropy_term
    #     # else:
    #     #     reward += state_entropy_term - state_entropy_term.detach()
    #
    #     self._stack_state_entropies.append(state_entropy)
    #     # self._info.update(
    #     #     **{
    #     #         # STATE_H: state_entropy.detach().mean(),
    #     #         "state_entropy_term": state_entropy_term.detach().mean(),
    #     #     }
    #     # )
    #     return obs, reward, done, info

    # def eval_rollout(
    #     self,
    #     batch_sa,
    #     log_pis,
    #     batch_masks,
    #     rewards,
    #     alpha,
    #     last_sa,
    #     discounts,
    #     critic,
    #     critic_target,
    # ):
    #     with torch.no_grad():
    #         if self.bounded_critic or self.scaled_critic:
    #             reward_pi = (
    #                 rewards
    #                 - alpha * log_pis[:-1, :]
    #                 + self.eta * torch.vstack(self._stack_state_entropies)[:-1, :]
    #             )
    #             reward_lb, reward_ub = torch.quantile(reward_pi, self.q_th)
    #             self.update_critic_bound(reward_lb, reward_ub)
    #         q_values = critic_target(last_sa)
    #         target_rewards = torch.cat([rewards, q_values[None, ..., 0]])
    #         action_entropy = -alpha.detach() * log_pis[1:, ...]
    #         state_entropy = (
    #             self.eta * torch.vstack(self._stack_state_entropies)[1:, ...]
    #         )
    #         target_rewards[1:, ...].add_(state_entropy + action_entropy)
    #         target_values = discounts.mm(target_rewards * batch_masks).unsqueeze(-1)
    #
    #     pred_values = critic(batch_sa.detach())
    #     pred_values = pred_values.view(
    #         self.mve_horizon, self.batch_size, self.num_critic_ensemble
    #     )
    #     deviation = discounts.shape[1] - discounts.shape[0]
    #     pred_values.mul_(batch_masks[:-deviation, :, None])
    #     return target_values, pred_values

    # def _actor_loss(self, ctx_modules, log_pis, batch_sa, rewards, masks, log=False):
    #     pred_qs = ctx_modules.pred_terminal_q(batch_sa[..., -self.batch_size :, :])
    #     mc_q_pred = torch.cat([rewards, pred_qs])
    #     action_entropy = -ctx_modules.alpha.detach() * log_pis
    #     self._state_entropy = torch.vstack(self._stack_state_entropies)
    #     state_entropy_term = self.eta.detach() * self._state_entropy
    #     mc_q_pred.add_(state_entropy_term + action_entropy)
    #     mc_v_pred = self.discount_mat[:1, :].mm(mc_q_pred * masks).t()
    #     return -mc_v_pred.mean()
