from contextlib import contextmanager, nullcontext
import hydra
from dataclasses import dataclass
import time
from tqdm import trange
import numpy as np
import torch
from torch.optim import Adam

from erl_lib.agent.svg import SVGAgent
from erl_lib.agent.svg import ContextModules as BaseContextModules
from erl_lib.base import OBS, Q_MEAN
from erl_lib.util.misc import ReplayBuffer


# Diagnostic key names that are not referenced by the code below
# (possibly used by importers of this module).
Q_EPI_STD, Q_DIFF_STD = "q_epi_std", "q_diff_std"
V_PRIOR = "v_prior"
IMPRV, REL_IMPRV = "improvement", "relative_improvement"
XVE_TARGET, XTRME_Q = "xve_target", "q_extreme"

# Keys written into the agent's ``_info`` dict by ``PCMLP.eval_rollout`` and
# read back by the eta update rules.
UNBIASED_Q, OPTIMISTIC_Q = "q_value_mean", "u_value_mean"
U_STD, Q_SAMPLE_STD = "u_sample_std", "q_sample_std"
MEAN_RATIO = "mean_ratio"
# Width keys (only referenced from commented-out code).
U_WIDTH, Q_WIDTH = "u_width", "q_width"


class PCMLP(SVGAgent):
    """SVG agent whose critic has two heads: an uncertainty-bonus value ``u`` and ``q``.

    The critic output is viewed as ``[..., 2, C]``. Index 0 along dim 2 is the
    bonus value and index 1 is the Q value. The actor objective bootstraps with
    ``q + eta * u``, where ``eta = exp(raw_eta)``. ``eta`` is adapted online by
    the rule selected with ``eta_update``:

    * ``"opt_ratio*"``: compares ``|smoothed bonus statistic * eta|`` with
      ``|smoothed baseline|``. The suffix ``std``/``mean``/other picks the
      statistic.
    * ``"opt_speed*"``: compares the short-term change of a smoothed Q estimate
      with its long-term change. The suffix ``abs``/``std``/other picks the
      loss scale. When ``"inner"`` appears in the name, eta is updated on every
      actor step rather than only on logged steps.

    In both rules ``eta`` is pushed down when the measured quantity is above
    its target and up when it is below.

    With ``tune_after=True``, ``observe`` periodically trains a separate
    ``tuned_actor`` with ``eta = 0``. That actor is used for deterministic
    (``sample=False``) actions.
    """

    trange_kv = dict(Eta="eta", **SVGAgent.trange_kv)

    def __init__(
        self,
        actor,
        critic,
        *args,
        target_improvement=0.2,
        eta_init=1.0,
        lr_eta=3e-4,
        lr_norm: float = 0.01,
        lr_norm2: float = 0.001,
        moving_min: bool = True,
        eta_min=1e-8,
        eta_wd=0.0,
        critic_improve: bool = True,
        smooth_baseline: bool = True,
        eta_update: str = "opt_ratio_std",
        long_change_min: float = 1.0,
        eta_reset: bool = False,
        tune_after: bool = False,
        both_values: bool = False,
        **kwargs,
    ):
        """Configure the eta adaptation rule and the two-head critic.

        Args:
            actor: Actor config forwarded to ``SVGAgent``. When ``tune_after``
                is set, it is also passed to ``hydra.utils.instantiate`` to build
                ``tuned_actor``.
            critic: Critic config forwarded to ``SVGAgent``.
            target_improvement: Target ratio for the ``opt_ratio`` rules (floored
                at 1e-8). It also scales ``target_speed`` for ``opt_speed``.
            eta_init: Initial eta. ``None`` means ``eta_min``.
            lr_eta: Adam learning rate for ``raw_eta``. eta is never adapted
                when this is <= 0.
            lr_norm: EMA rate of the short-term baselines.
            lr_norm2: EMA rate of the long-term baseline. ``opt_speed`` rules
                require ``lr_norm2 < lr_norm``.
            moving_min: Stored as an attribute; not read in this class.
            eta_min: Lower bound on eta; also used as a floor for denominators.
            eta_wd: Weight decay of the eta optimizer.
            critic_improve: Not used in this class; kept for config compatibility.
            smooth_baseline: For ``opt_ratio``, smooth the bonus statistic with
                an EMA instead of using the latest value.
            eta_update: Name of the eta update rule (see class docstring).
            long_change_min: Floor on the long-term change in ``opt_speed``.
            eta_reset: Reset eta to ``eta_init`` whenever ``init_optimizer`` runs.
            tune_after: Train and act with a separate, bonus-free ``tuned_actor``.
            both_values: For ``opt_speed``, add ``eta * u`` to the short-term value.
            **kwargs: Forwarded to ``SVGAgent``. Must contain ``dynamics_model``
                and ``device``.

        Raises:
            ValueError: If an ``opt_speed`` rule is used with
                ``lr_norm2 >= lr_norm``, or the critic ensemble size is odd.
        """
        self.target_improvement = max(target_improvement, 1e-8)
        if eta_update.startswith("opt_speed"):
            # target_speed divides by (lr_norm - lr_norm2), so equality must be
            # rejected too; otherwise it fails with ZeroDivisionError below.
            if lr_norm <= lr_norm2:
                raise ValueError(
                    f"Need lr_norm2 < lr_norm, but got "
                    f"lr_norm2={lr_norm2} >= lr_norm={lr_norm}"
                )
            # Value that `fixed_speed` steers `change_speed` towards.
            self.target_speed = lr_norm2 / (lr_norm - lr_norm2) * target_improvement
            print(f"Target improvement speed ratio: {self.target_speed:.3f}")
            self.update_eta_inner = "inner" in eta_update

        self.lr_eta = lr_eta
        self.lr_norm = lr_norm
        self.lr_norm2 = lr_norm2
        self.eta_min = eta_min
        self.eta_wd = eta_wd
        self.moving_min = moving_min
        self.smooth_baseline = smooth_baseline
        self.eta_update = eta_update
        self.long_change_min = long_change_min
        self.eta_reset = eta_reset
        self.tune_after = tune_after
        self.both_values = both_values

        self.num_model_members = kwargs["dynamics_model"].num_members
        # eta is parameterised in log space (eta = exp(raw_eta)) to keep it positive.
        # NOTE(review): these are set before `super().__init__`, presumably because
        # the base constructor calls the overridden `init_optimizer`, which reads
        # them. Confirm before reordering.
        self.eta_init = torch.tensor(
            np.log(eta_min) if eta_init is None else np.log(eta_init),
            dtype=torch.float32,
            device=kwargs["device"],
        )
        self.raw_eta = torch.nn.Parameter(self.eta_init.clone(), requires_grad=True)

        super(PCMLP, self).__init__(*args, actor=actor, critic=critic, **kwargs)

        # Actor
        if self.tune_after:
            self.tuned_actor = hydra.utils.instantiate(actor).to(self.device)
            self.tuned_rollout_buffer = ReplayBuffer(
                self.capacity,
                self.device,
                split_section_dict={OBS: self.dim_obs},
            )
        else:
            # Without tuning, the "tuned" objects are the regular ones.
            self.tuned_actor = self.actor
            self.tuned_raw_alpha = self.raw_alpha
            self.tuned_rollout_buffer = self.rollout_buffer

        # Critic: its output is viewed as [..., 2, C] (bonus head, Q head),
        # so the member count is split in half.
        if self.num_critic_ensemble % 2 != 0:
            raise ValueError(
                f"Critic ensemble size must be even (bonus and Q heads), "
                f"got {self.num_critic_ensemble}"
            )
        self.num_critic_ensemble = self.critic.num_members // 2
        self.init_optimizer()

        # Running bounds of the bonus value (for bounded/scaled critics).
        self._u_lb, self._u_ub, self._u_width, self._u_center = None, None, None, None
        # EMA baselines of the eta update rules (lazily initialised).
        self.eta_norm = None
        self.short_norm = None
        self.long_norm = None

    @staticmethod
    def _clone_state_dict(module):
        """Return a deep copy of ``module.state_dict()``."""
        return {key: value.clone() for key, value in module.state_dict().items()}

    def observe(self, obs, action, reward, next_obs, terminated, truncated, info):
        """Record a transition and, with ``tune_after``, refresh ``tuned_actor``.

        The refresh is triggered when ``total_iters`` equals ``seed_iters``, or
        is a multiple of ``iters_per_epoch`` after seeding (with ``iter > 0``).
        It runs ``self.update(tune=True)``, copies the resulting actor into
        ``tuned_actor``, then restores the actor, critic, target critic and
        ``raw_alpha`` to their pre-tuning values.
        """
        super().observe(obs, action, reward, next_obs, terminated, truncated, info)
        last_step = (
            self.seed_iters <= self.total_iters
            and 0 < self.iter
            and (
                self.total_iters == self.seed_iters
                or self.total_iters % self.iters_per_epoch == 0
            )
        )
        if not (self.tune_after and last_step):
            return

        self.tuned_rollout_buffer.clear()

        # Snapshot everything the tuning pass overwrites.
        actor_state = self._clone_state_dict(self.actor)
        critic_state = self._clone_state_dict(self.critic)
        critic_target_state = self._clone_state_dict(self.critic_target)
        raw_alpha = self.raw_alpha.data.clone()

        # NOTE(review): this also rebuilds the main (and eta) optimizers, and
        # resets eta when `eta_reset` is set; their state is not restored below.
        self.init_optimizer()
        # --------------- Agent Training -----------------
        self.update(tune=True)

        self.tuned_actor.load_state_dict(self.actor.state_dict())

        self.actor.load_state_dict(actor_state)
        self.critic.load_state_dict(critic_state)
        self.critic_target.load_state_dict(critic_target_state)
        self.raw_alpha.data.copy_(raw_alpha)

        self.logger.append(
            "policy_tune",
            {"iteration": self.total_iters},
            {Q_MEAN: self._info[UNBIASED_Q]},
        )

    @torch.no_grad()
    def _act(self, obs, sample=False):
        """Return a numpy action.

        With ``sample=True`` the action is sampled from ``actor``. Otherwise it
        is the mean of ``tuned_actor``'s distribution. Numpy observations are
        converted to float32 tensors and normalised when ``normalize_input``
        is set.
        """
        if isinstance(obs, np.ndarray):
            obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)

        if self.normalize_input:
            obs = self.input_normalizer.normalize(obs)

        if sample:
            action = self.actor(obs).sample()
        else:
            action = self.tuned_actor(obs).mean
        return action.cpu().numpy()

    def eval_rollout(
        self,
        ctx_modules,
        batch_sa,
        log_pis,
        batch_masks,
        rewards,
        last_sa,
        discounts,
    ):
        """Build bootstrapped targets and critic predictions for a model rollout.

        Side effects:

        * Refreshes the running bounds used by ``bounded_critic`` and
          ``scaled_critic``.
        * Adds an eta-differentiable term to ``rewards`` in place. The term is
          exactly zero in value.
        * Stores value statistics in ``self._info`` for the eta update rules.

        Args:
            ctx_modules: Context yielded by ``policy_evaluation_context``.
            batch_sa: [M, (L+1)B, D]
            log_pis: [L+1, M, B]
            batch_masks: [L+1, M, B]
            rewards: [L, M, B]; modified in place.
            last_sa: [M, B, D]
            discounts: [L, L+1]

        Returns:
            target_values: [L, B, 2, 1] (index 0: bonus, 1: Q); computed without grad.
            pred_values: [L, B, 2, C]
        """
        # Per-step state entropies recorded by `model_step_context` act as the bonus.
        epi_bonus = torch.stack(self._stack_state_entropies)
        if self.bounded_critic or self.scaled_critic:
            with torch.no_grad():
                reward_pi = rewards - ctx_modules.alpha * log_pis[:-1, :]
                reward_lb, reward_ub = torch.quantile(reward_pi, self.q_th)
                (
                    self._q_lb,
                    self._q_ub,
                    self._q_center,
                    self._q_width,
                ) = self.update_critic_bound(
                    self._q_lb, self._q_ub, reward_lb, reward_ub
                )
                epi_lb, epi_ub = torch.quantile(epi_bonus, self.q_th)
                (
                    self._u_lb,
                    self._u_ub,
                    self._u_center,
                    self._u_width,
                ) = self.update_critic_bound(self._u_lb, self._u_ub, epi_lb, epi_ub)

        # Zero in value, but carries d/d(eta) into anything computed from `rewards`.
        u_bonus_term = ctx_modules.eta * epi_bonus
        u_bonus_term = u_bonus_term - u_bonus_term.detach()
        rewards.add_(u_bonus_term)
        with torch.no_grad():
            q_values = ctx_modules.pred_target_q(last_sa)
            # Per-step rewards of both heads, followed by the bootstrap values.
            target_rewards = torch.stack([epi_bonus, rewards], -1)[..., None]
            target_rewards = torch.cat([target_rewards, q_values])  # [L+1, B, 2, 1]
            # Entropy regularisation applies to the Q head only.
            target_rewards[:, :, 1, 0] -= log_pis * ctx_modules.alpha.detach()
            target_values = target_rewards * batch_masks[..., None, None].expand(
                batch_masks.shape + (2, 1)
            )
            target_values = torch.sum(
                discounts[..., None, None, None] * target_values[None], 1
            )

        pred_values = ctx_modules.pred_q(batch_sa.detach())  # [L, B, 2, C]
        horizon = discounts.shape[1] - discounts.shape[0]
        # Use an explicit end index: `batch_masks[:-horizon]` is empty when horizon == 0.
        num_pred_steps = batch_masks.shape[0] - horizon
        mask_ensemble = batch_masks[:num_pred_steps, ..., None, None]
        pred_values *= mask_ensemble
        with torch.no_grad():
            u_value, q_value = pred_values.mean(3).chunk(2, dim=2)
            mean_ratio = (u_value.abs() / q_value.abs().clamp_min(self.eta_min)).mean()
            if self.eta_update.endswith("median"):
                q_mean = torch.median(q_value)
                u_mean = torch.median(u_value)
            else:
                q_mean = q_value.mean()
                u_mean = u_value.mean()

            self._info.update(
                **{
                    UNBIASED_Q: q_mean,
                    OPTIMISTIC_Q: u_mean,
                    U_STD: u_value.std(),
                    Q_SAMPLE_STD: q_value.std(),
                    MEAN_RATIO: mean_ratio,
                }
            )
        return target_values, pred_values

    def pred_q_value(self, obs_action, batch_size=None, train_mode=True):
        """Predict both critic heads with the online critic.

        Args:
            obs_action: [B, D]
            batch_size: Batch size used to reshape the output. Defaults to
                ``self.batch_size``.
            train_mode: Mode the critic is put in before the forward pass.
                The default is ``True``.

        Returns:
            pred_q: [L, B, 2, C] (index 0: bonus, 1: Q), de-normalised with the
            running centres/widths when ``scaled_critic`` is set.
        """
        self.critic.train(train_mode)
        if batch_size is None:
            batch_size = self.batch_size
        shape = (-1, batch_size, 2, self.num_critic_ensemble)
        pred_q = self.critic(obs_action).view(shape)
        if self.scaled_critic:
            pred_q[:, :, 0, :] = pred_q[:, :, 0, :] * self._u_width + self._u_center
            pred_q[:, :, 1, :] = pred_q[:, :, 1, :] * self._q_width + self._q_center
        return pred_q

    def pred_target_q_value(self, obs_action, batch_size=None):
        """Predict both heads with the target critic, reduced over the ensemble.

        Args:
            obs_action: [B, D]
            batch_size: Batch size used to reshape the output. Defaults to
                ``self.batch_size``.

        Returns:
            target_q: [1, B, 2, 1] (index 0: bonus, 1: Q). It is de-normalised
            when ``scaled_critic`` is set and clipped to the running
            ``[lb, ub]`` bounds when ``bounded_critic`` is set.
        """
        self.critic_target.train(False)
        if batch_size is None:
            batch_size = self.batch_size
        shape = (1, batch_size, 2, self.num_critic_ensemble)
        target_q = self.critic_target(obs_action).view(shape)
        target_q = self._reduce(target_q, dim=3, keepdim=True)
        if self.scaled_critic:
            target_q[:, :, 0, :] = target_q[:, :, 0, :] * self._u_width + self._u_center
            target_q[:, :, 1, :] = target_q[:, :, 1, :] * self._q_width + self._q_center
        if self.bounded_critic:
            # ub - relu(ub - x) == min(x, ub); lb + relu(x - lb) == max(x, lb).
            target_q[:, :, 0, :] = self._u_ub - torch.relu(
                self._u_ub - target_q[:, :, 0, :]
            )
            target_q[:, :, 0, :] = self._u_lb + torch.relu(
                target_q[:, :, 0, :] - self._u_lb
            )
            target_q[:, :, 1, :] = self._q_ub - torch.relu(
                self._q_ub - target_q[:, :, 1, :]
            )
            target_q[:, :, 1, :] = self._q_lb + torch.relu(
                target_q[:, :, 1, :] - self._q_lb
            )
        return target_q

    def pred_terminal_q(self, obs_action):
        """Predict both heads used to bootstrap the actor objective (eval mode)."""
        # Pass the mode through: calling `critic.train(False)` here would be undone
        # by `pred_q_value`, which sets the mode itself.
        return self.pred_q_value(obs_action, train_mode=False)

    def _actor_loss(self, ctx_modules, log_pis, batch_sa, rewards, masks, log=False):
        """Return the negative soft return of the rollout, bootstrapped with q + eta*u.

        Terminal values are predicted for the last ``batch_size`` rows of
        ``batch_sa``. Step values are weighted by the first row of
        ``discount_mat``. ``log`` is accepted for interface compatibility and
        is not used.
        """
        terminal_sa = batch_sa[..., -self.batch_size :, :]
        u, q = ctx_modules.pred_terminal_q(terminal_sa).chunk(2, dim=2)
        pred_qss = q.squeeze(2) + u.squeeze(2) * ctx_modules.eta
        pred_qs = self._reduce(pred_qss, self.actor_reduction, dim=2, keepdim=False)
        mc_q_pred = torch.cat([rewards, pred_qs])
        # Entropy regularisation.
        mc_q_pred.sub_(ctx_modules.alpha.detach() * log_pis)
        mc_v_pred = self.discount_mat[:1, :].mm(mc_q_pred * masks).t()
        return -mc_v_pred.mean()

    def model_step_context(self, action, model_state, **kwargs):
        """Step the model and record its state entropy (the bonus in `eval_rollout`)."""
        predicts = super().model_step_context(action, model_state, **kwargs)
        self._stack_state_entropies.append(self.dynamics_model.state_entropy)
        return predicts

    def update_actor(
        self,
        ctx_modules,
        batch_sa,
        batch_mask,
        log_pis,
        rewards,
        log=False,
    ):
        """Run the base actor update, then adapt eta with the configured rule."""
        super().update_actor(ctx_modules, batch_sa, batch_mask, log_pis, rewards, log)
        # eta is not adapted when lr_eta <= 0, or when the context carries
        # eta == 0 (as the `tune` context does).
        if 0.0 < self.lr_eta and 0 < ctx_modules.eta:
            if self.eta_update.startswith("opt_ratio"):
                self.fixed_ratio_bonus_eta()
            elif self.eta_update.startswith("opt_speed"):
                # "inner" variants update every step; others only on logged steps.
                if self.update_eta_inner or log:
                    self.fixed_speed(log)

            self._info.update(eta=self.eta)

    @torch.no_grad()
    def fixed_ratio_bonus_eta(self):
        """Take one optimizer step on ``raw_eta`` towards the target bonus ratio.

        The statistics depend on the ``eta_update`` suffix:

        * ``std``: bonus-value std against Q-value std.
        * ``mean``: mean per-sample ``|u| / |q|`` against a constant 1.
        * otherwise: the bonus and Q value statistics stored by ``eval_rollout``.

        The baseline is an EMA with rate ``lr_norm``. The bonus statistic is
        also an EMA when ``smooth_baseline`` is set.
        """
        q_value = self._info[UNBIASED_Q]
        q_std = self._info[Q_SAMPLE_STD]
        u_value = self._info[OPTIMISTIC_Q]
        u_std = self._info[U_STD]
        mean_ratio = self._info[MEAN_RATIO]

        if self.eta_update.endswith("std"):
            baseline, u_range = q_std, u_std
        elif self.eta_update.endswith("mean"):
            # A constant baseline of 1. `ones_like` stays finite even if q_std is NaN.
            baseline, u_range = torch.ones_like(q_std), mean_ratio
        else:
            baseline, u_range = q_value, u_value

        if self.eta_norm is None:
            # Clone so the in-place EMA updates never mutate tensors held by `_info`.
            self.eta_norm = baseline.clone()
            self.eta_norm2 = u_range.clone()
        else:
            self.eta_norm.lerp_(baseline, self.lr_norm)
            if self.smooth_baseline:
                self.eta_norm2.lerp_(u_range, self.lr_norm)
            else:
                self.eta_norm2 = u_range

        uncertainty_bonus_term = self.eta_norm2 * self.eta

        eta_loss, ratio = self.eta_relative_loss(
            uncertainty_bonus_term, self.eta_norm, self.raw_eta
        )

        self.eta_optimizer.zero_grad()
        eta_loss.backward()
        self.eta_optimizer.step()

        self._info.update(
            **{
                "baseline": baseline,
                "bonus_ratio": ratio,
                "eta_loss": eta_loss.detach(),
                "uncertainty_bonus_term": uncertainty_bonus_term,
            }
        )
        self.raw_eta.data.clamp_(min=np.log(self.eta_min))

    @torch.enable_grad()
    def eta_relative_loss(self, uncertainty_term, q_value, raw_eta):
        """Return ``(eta * (ratio - target), ratio)``.

        ``ratio = |uncertainty_term| / max(|q_value|, eta_min)``. Only ``eta``
        is expected to carry gradient; ``ratio`` is returned detached.
        """
        eta = raw_eta.exp()
        denom = torch.abs(q_value).clamp_min(self.eta_min)
        ratio = uncertainty_term.abs() / denom
        return (
            eta * (ratio - self.target_improvement),
            ratio.detach(),
        )

    @torch.no_grad()
    def fixed_speed(self, last_step=False):
        """Take one optimizer step on ``raw_eta`` towards ``target_speed``.

        ``eta_norm`` is an EMA (rate ``lr_norm``) of the Q statistic. When
        ``last_step`` is set, ``short_norm`` and ``long_norm`` track
        ``eta_norm`` with rates ``lr_norm`` and ``lr_norm2``. The change speed
        is ``|value - short_norm| / max(|short_norm - long_norm|,
        long_change_min)``, clipped to ``[0, 10 * target_speed]``.
        """
        q_value = self._info[UNBIASED_Q]
        q_std = self._info[Q_SAMPLE_STD]
        u_value = self._info[OPTIMISTIC_Q]

        if self.eta_norm is None:
            self.eta_norm = q_value.clone()
            self.short_norm = q_value.clone()
            self.long_norm = q_value.clone()
        else:
            self.eta_norm.lerp_(q_value, self.lr_norm)
            if last_step:
                self.short_norm.lerp_(self.eta_norm, self.lr_norm)
                self.long_norm.lerp_(self.eta_norm, self.lr_norm2)

        long_improvement = torch.clamp_min(
            torch.abs(self.short_norm - self.long_norm), self.long_change_min
        )
        last_value = self.eta_norm
        if self.both_values:
            last_value = last_value + u_value * self.eta.detach()
        short_improvement = torch.abs(last_value - self.short_norm)

        change_speed = short_improvement / long_improvement
        change_speed.clamp_(min=0, max=10 * self.target_speed)

        if self.eta_update.endswith("abs"):
            scale = q_value.abs()
        elif self.eta_update.endswith("std"):
            scale = q_std
        else:
            scale = 1.0
        eta_loss, speed_diff = self.eta_speed_loss(change_speed, self.raw_eta, scale)

        self.eta_optimizer.zero_grad()
        eta_loss.backward()
        self.eta_optimizer.step()
        self.raw_eta.data.clamp_(min=np.log(self.eta_min))

        self._info.update(
            **{
                "eta_loss": eta_loss.detach(),
                "change_speed": change_speed,
                "speed_diff": speed_diff,
                "long_improvement": long_improvement,
                "short_improvement": short_improvement,
                # Snapshots: the EMA tensors themselves are updated in place later.
                "short_term_ema": self.short_norm.clone(),
                "long_term_ema": self.long_norm.clone(),
            }
        )

    @torch.enable_grad()
    def eta_speed_loss(self, change_speed, raw_eta, scale):
        """Return ``(eta * (change_speed - target_speed) * scale, speed_diff)``."""
        eta = raw_eta.exp()
        speed_diff = change_speed - self.target_speed
        eta_loss = eta * speed_diff * scale
        return eta_loss, speed_diff.detach()

    @contextmanager
    def policy_evaluation_context(self, detach=False, tune=False, **kwargs):
        """Yield the ``ContextModules`` for one policy-evaluation pass.

        With ``tune=True`` the context uses ``tuned_rollout_buffer`` and
        ``eta = 0.0``. Otherwise it uses ``rollout_buffer`` and the current
        ``eta``. The per-rollout list of state entropies is reset. Extra
        keyword arguments are accepted and ignored.
        """
        if tune:
            rollout_buffer, eta = self.tuned_rollout_buffer, 0.0
        else:
            rollout_buffer, eta = self.rollout_buffer, self.eta

        context_modules = ContextModules(
            self.actor,
            self.actor_optimizer,
            self.pred_q_value,
            self.pred_target_q_value,
            self.pred_terminal_q,
            self.critic,
            self.critic_target,
            self.critic_optimizer,
            self.alpha,
            self.model_step_context,
            rollout_buffer,
            detach,
            eta=eta,
        )

        self._stack_state_entropies = []
        yield context_modules

    def init_optimizer(self):
        """Rebuild the base optimizers and the Adam optimizer of ``raw_eta``.

        When ``eta_reset`` is set, ``raw_eta`` is first reset to ``eta_init``.
        """
        super().init_optimizer()
        if self.eta_reset:
            self.raw_eta.data.copy_(self.eta_init)
        self.eta_optimizer = torch.optim.Adam(
            [self.raw_eta], lr=self.lr_eta, weight_decay=self.eta_wd
        )

    @property
    def eta(self):
        """Bonus weight ``exp(raw_eta)`` (differentiable w.r.t. ``raw_eta``)."""
        return self.raw_eta.exp()


class ContextModules(BaseContextModules):
    """``BaseContextModules`` extended with the bonus weight ``eta``.

    ``PCMLP.policy_evaluation_context`` constructs it with ``eta=0.0`` when
    tuning and ``eta=self.eta`` otherwise.
    """

    # Default bonus weight (no bonus).
    # NOTE(review): `dataclass` is imported in this module but this subclass is
    # not decorated. If `BaseContextModules` is a dataclass, its generated
    # `__init__` would not accept `eta=`. Confirm how `eta` reaches instances.
    eta: torch.Tensor = 0
