"""Variational Exploration."""

from typing import NamedTuple, Callable
from contextlib import nullcontext, contextmanager
from dataclasses import dataclass
from tqdm import trange
import hydra
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam

from erl_lib.agent.svg import SVGAgent, ContextModules
from erl_lib.agent.model_based.modules.gaussian_mlp import Model
from erl_lib.agent.module.critic.two_heads import TwoHeadsCritic
from erl_lib.agent.module.actor.two_heads import TwoHeadsDGA

# from erl_lib.base.constant import
from erl_lib.agent.model_based.modules.variational_model import VariationalModel

# from erl_lib.base.module import WithoutReplacementBuffer
from erl_lib.util.misc import (
    calc_grad_norm,
    soft_update_params,
    ReplayBuffer,
)
from erl_lib.base import (
    OBS,
    ACTION,
    REWARD,
    NEXT_OBS,
    MASK,
    Q_MEAN,
    Q_STD,
    KEY_ACTOR_LOSS,
    KEY_CRITIC_LOSS,
    CTX_EXPLR,
    KL,
)


# @dataclass
# class ContextModules:
#     actor: torch.nn.Module
#     actor_optimizer: torch.optim.Optimizer
#     pred_q: Callable
#     pred_target_q: Callable
#     pred_terminal_q: Callable
#     critic: torch.nn.Module
#     critic_target: torch.nn.Module
#     critic_optimizer: torch.optim.Optimizer
#     alpha: torch.Tensor
#     model_step: Callable
#     buffer: ReplayBuffer
#     # is_explore: bool


class VariationalExploration(SVGAgent):
    """SVG agent that trains a greedy policy and an exploratory "behavior" policy.

    On top of ``SVGAgent`` this class keeps a second set of actor, entropy
    temperature, critic and optimizer handles used while exploring, plus a
    coefficient ``beta`` (stored as ``raw_beta = log(beta)``) that weights a
    KL term reported by the dynamics model when that model is a
    ``VariationalModel``.
    """

    # Class-level flag; it is not read anywhere inside this class.
    greedy_exploration = False
    # Dynamics model handle; ``None`` until assigned
    # (presumably by the parent class -- TODO confirm).
    dynamics_model: Model = None
    # beta_lr: float = None
    model_type: str

    def __init__(
        self,
        actor,
        lr,
        # VIE specific arguments
        v_lr_ratio=1,
        beta_init=1e-1,
        beta_min=0.0,
        lr_beta: float = 0,
        improvement_threshold=0.05,
        std_scaled_improve: bool = False,
        epsilon: float = 1e-8,
        ema_factor: float = 0.1,
        greedy_behavior_ratio: float = 1.0,
        sep_act_optim: bool = True,
        separate_rollout_buffer=False,
        target_kl_coef: float = 0.1,
        **kwargs,
    ):
        """Build the agent and its exploration-specific state.

        Args:
            actor: Actor config. Forwarded to the parent constructor and, when
                the resulting actor is not a ``TwoHeadsDGA``, instantiated a
                second time through hydra as the behavior actor.
            lr: Base learning rate, forwarded to the parent constructor.
            v_lr_ratio: Multiplier on ``lr`` giving the learning rate of the
                dynamics model's variational parameters.
            beta_init: Initial ``beta``; floored at ``epsilon`` and stored in
                log space.
            beta_min: Lower bound applied to ``beta`` by ``_update``.
            lr_beta: Learning rate of the ``raw_beta`` parameter group, also
                used as the mixing factor of the ``beta`` update in
                ``_update``.
            improvement_threshold: Offset subtracted from the normalized
                improvement in ``_update``.
            std_scaled_improve: If true, ``_update`` tracks the exploratory
                value std instead of the greedy one.
            epsilon: Small positive floor used for ``beta_init`` and for the
                improvement scale.
            ema_factor: Exponential-moving-average factor used in ``_update``.
            greedy_behavior_ratio: Probability used by ``observe`` when drawing
                ``greedy_behavior``.
            sep_act_optim: Give the behavior actor its own Adam optimizer.
                Forced to true when the behavior actor is a separate network.
            separate_rollout_buffer: Allocate a second replay buffer for
                exploration rollouts.
            target_kl_coef: ``target_kl`` is ``0.5 * dim_obs * target_kl_coef``;
                ``raw_beta`` is trainable only when this is positive.
            **kwargs: Forwarded to the parent constructor.
        """
        # These are assigned before the parent constructor runs
        # (presumably because it reads them -- TODO confirm).
        self.v_lr = lr * v_lr_ratio
        self.lr_beta = lr_beta
        self.beta_min = beta_min
        self.improvement_threshold = improvement_threshold
        self.std_v_improve = std_scaled_improve
        self.greedy_behavior_ratio = greedy_behavior_ratio
        self.separate_rollout_buffer = separate_rollout_buffer

        super(VariationalExploration, self).__init__(
            lr=lr,
            actor=actor,
            **kwargs,
        )
        self.kl_constraint = isinstance(self.dynamics_model, VariationalModel)

        # A two-headed actor serves both policies; any other actor type gets
        # an independent copy for the behavior policy.
        self.separate_actor = not isinstance(self.actor, TwoHeadsDGA)
        if self.separate_actor:
            self.behavior_actor = hydra.utils.instantiate(actor).to(self.device)
        else:
            self.behavior_actor = self.actor
        self.sep_act_optim = sep_act_optim or self.separate_actor

        self.behavior_raw_alpha = (
            self.raw_alpha.clone().detach().requires_grad_(self.learnable_alpha)
        )

        self.is_explore = False
        # Value statistics cached by ``_actor_loss`` and consumed by
        # ``_update``. All four are defined here so that reading them before
        # the first actor update yields None instead of an AttributeError.
        self._last_variational_q_mean = None
        self._last_variational_q_std = None
        self._last_greedy_q_mean = None
        self._last_greedy_q_std = None

        self.model_out_size = self.dynamics_model.dim_output

        self.target_kl = 0.5 * self.dim_obs * target_kl_coef

        # beta is kept in log space; flooring at epsilon avoids log(0).
        raw_init_beta = np.log(max(beta_init, epsilon))
        self.raw_beta = nn.Parameter(
            torch.tensor(raw_init_beta, dtype=torch.float32, device=self.device),
            requires_grad=0 < self.target_kl,
        )

        # Called again here, now that the behavior actor, alpha and beta exist.
        self.init_optimizer()

        self.epsilon = torch.tensor(epsilon, device=self.device)
        self.ema_imprv_scale = None
        self.ema_var_std = None
        self.greedy_behavior = False
        self.ema_factor = torch.tensor(ema_factor, device=self.device)
        self._last_kl = self.epsilon.clone()

        if self.separate_rollout_buffer:
            self.rollout_buffer_v = ReplayBuffer(
                self.capacity,
                self.device,
                split_section_dict={OBS: self.dim_obs},
            )

    def observe(self, obs, action, reward, next_obs, terminated, truncated, info):
        """Record one transition via the parent class.

        When ``self.step`` is 0 and ``greedy_behavior_ratio`` is positive,
        ``greedy_behavior`` is redrawn first: True with probability
        ``greedy_behavior_ratio``, False otherwise.
        """
        greedy_prob = self.greedy_behavior_ratio
        if self.step == 0 and greedy_prob > 0.0:
            self.greedy_behavior = np.random.choice(
                [True, False], p=(greedy_prob, 1 - greedy_prob)
            )
        transition = (obs, action, reward, next_obs, terminated, truncated, info)
        return super().observe(*transition)

    def _act(self, obs, sample=False):
        with torch.no_grad():
            if isinstance(obs, np.ndarray):
                obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)

            if self.normalize_input:
                obs = self.input_normalizer.normalize(obs)

            # if sample and self.greedy_behavior_ratio < 1.0:
            #     p = self.greedy_behavior_ratio
            #     sample = np.random.choice([True, False], p=(p, 1 - p))
            # sample &= not self.greedy_behavior

            if sample:
                if not self.separate_actor:
                    self.actor.is_optimistic = True
                dist = self.actor(obs)
                action = dist.sample()
            else:
                if not self.separate_actor:
                    self.actor.is_optimistic = False
                dist = self.actor(obs)
                action = dist.mean
            action = action.cpu().numpy()
        return action

    def init_optimizer(self):
        """Create the optimizers used while exploring.

        After the parent optimizers are built, and only when a dynamics model
        is present, the variational parameters, the behavior temperature and
        (if trainable) ``raw_beta`` are either placed in a dedicated Adam
        optimizer together with the behavior actor (``sep_act_optim``) or
        appended as extra parameter groups to the greedy actor optimizer.
        The behavior critic optimizer is a new Adam instance only when a
        separate critic exists; otherwise it is the greedy critic optimizer.
        """
        super().init_optimizer()

        if self.dynamics_model is not None:
            variational_group = {
                "params": self.dynamics_model.variational_parameters(),
                "lr": self.v_lr,
            }
            alpha_group = {"params": [self.behavior_raw_alpha]}
            # raw_beta joins an optimizer only when it is trainable.
            beta_groups = (
                [{"params": [self.raw_beta], "lr": self.lr_beta}]
                if self.raw_beta.requires_grad
                else []
            )

            if self.sep_act_optim:
                actor_group = {"params": self.behavior_actor.parameters()}
                self.behavior_actor_optimizer = Adam(
                    [variational_group, actor_group, alpha_group, *beta_groups],
                    lr=self.lr,
                )
            else:
                # Shared optimizer: the actor parameters are already in it.
                for group in (variational_group, alpha_group, *beta_groups):
                    self.actor_optimizer.add_param_group(group)
                self.behavior_actor_optimizer = self.actor_optimizer

        self.behavior_critic_optimizer = (
            Adam(self.critic_v.parameters(), lr=self.lr * self.critic_lr_ratio)
            if self.separate_critic
            else self.critic_optimizer
        )

    def build_critics(self, critic_cfg):
        """Build the critics and expose the exploration-side handles.

        With a ``TwoHeadsCritic`` the greedy and exploratory critics are the
        same modules, so ``critic_v`` / ``critic_target_v`` simply alias the
        greedy critic and its target, and ``separate_critic`` is False.

        Args:
            critic_cfg: Critic config, forwarded to the parent class.

        Raises:
            NotImplementedError: If the critic built by the parent class is
                not a ``TwoHeadsCritic``. Independent exploration critics are
                not supported.
        """
        super().build_critics(critic_cfg)
        if not isinstance(self.critic, TwoHeadsCritic):
            # The statements that used to follow this raise could never run;
            # they have been removed.
            raise NotImplementedError(
                "VariationalExploration requires a TwoHeadsCritic, "
                f"got {type(self.critic).__name__}"
            )
        self.critic_v = self.critic
        self.critic_target_v = self.critic_target
        self.separate_critic = False

    def _actor_loss(
        self, ctx_modules, log_pis, batch_sa, rewards, masks, info, log=False
    ):
        # mc_v_target = super()._actor_loss(*args, log=False, **kwargs)
        pred_qs = ctx_modules.pred_terminal_q(batch_sa[..., -self.batch_size :, :])
        mc_q_pred = torch.cat([rewards, pred_qs])
        mc_q_pred.sub_(ctx_modules.alpha.detach() * log_pis)
        mc_v_target = self.discount_mat[:1, :].mm(mc_q_pred * masks)
        with torch.no_grad():
            value_mean = mc_v_target.mean()
            value_std = mc_v_target.std()
            if self.is_explore:
                self._last_variational_q_mean = value_mean
                self._last_variational_q_std = value_std
            else:
                self._last_greedy_q_mean = value_mean
                self._last_greedy_q_std = value_std
        return -mc_v_target.mean()

    def _update(self, log=False, **kwargs):
        """Run one parent update, then refresh improvement statistics and beta.

        Log entries are namespaced around the parent update: entries present
        before an exploratory update (keys not starting with ``"v_"``) are
        restored afterwards, and entries produced by an exploratory update
        are copied under a ``"v_"`` prefix.

        Once both the greedy and the exploratory actor loss have been
        evaluated at least once, the normalized improvement of the
        exploratory value over the greedy value is computed. When
        ``target_kl == 0``, ``lr_beta > 0`` and the KL constraint is active,
        ``beta`` is moved towards ``relu(improvement - threshold)`` scaled by
        a running value std.

        Args:
            log: Forwarded to the parent update. Also forces the statistics
                to be refreshed during an exploratory phase.
            **kwargs: Forwarded to the parent update.

        Returns:
            The agent's ``_info`` dictionary.
        """
        info_old = {}
        if self.is_explore:
            # Keep the greedy-phase entries so the exploratory update cannot
            # overwrite them.
            for key, value in self._info.items():
                if not key.startswith("v_"):
                    info_old[key] = value

        super(VariationalExploration, self)._update(log=log, **kwargs)

        if not self.is_explore:
            # Was explored after exiting context
            for key, value in self._info.items():
                if not key.startswith("v_"):
                    info_old[f"v_{key}"] = value
        self._info.update(**info_old)

        actor_updated = (
            self._last_variational_q_mean is not None
            and self._last_greedy_q_mean is not None
        )
        # Both statistics must exist before they are used. Previously
        # ``log=True`` alone entered this block and failed on the missing
        # values.
        if actor_updated and (log or not self.is_explore):
            with torch.no_grad():
                greedy_q_mean = self._last_greedy_q_mean
                greedy_q_std = self._last_greedy_q_std
                greedy_q_abs = greedy_q_mean.abs()
                variational_q_mean = self._last_variational_q_mean
                variational_q_std = self._last_variational_q_std

                if self.ema_imprv_scale is None:
                    # Initial normalization scale, floored like every later
                    # update so the division below cannot hit zero.
                    self.ema_imprv_scale = greedy_q_abs.clamp_min(self.epsilon)
                else:
                    # Otherwise, take exponential moving average
                    self.ema_imprv_scale.mul_(1 - self.ema_factor)
                    self.ema_imprv_scale.add_(self.ema_factor * greedy_q_abs)
                    self.ema_imprv_scale.clamp_min_(self.epsilon)

                improvement = (
                    variational_q_mean - greedy_q_mean
                ) / self.ema_imprv_scale
                # Update beta
                if (
                    0 == self.target_kl
                    and 0 < self.lr_beta
                    and self.kl_constraint is True
                ):
                    new_beta = torch.relu(improvement - self.improvement_threshold)

                    if self.std_v_improve:
                        ema_metric = variational_q_std
                    else:
                        ema_metric = greedy_q_std
                    if self.ema_var_std is None:
                        # Clone: the running average is updated in place and
                        # must not alias the cached / logged std tensor.
                        self.ema_var_std = ema_metric.clone()
                    else:
                        self.ema_var_std.mul_(1 - self.ema_factor)
                        self.ema_var_std.add_(self.ema_factor * ema_metric)
                    new_beta *= self.ema_var_std
                    beta_ = (1 - self.lr_beta) * self.beta + self.lr_beta * new_beta
                    # Shifted relu keeps beta at or above beta_min.
                    beta = torch.relu(beta_ - self.beta_min) + self.beta_min
                    # beta == 0 is stored as log(0) = -inf, which exp() maps
                    # back to exactly 0.
                    self.raw_beta.copy_(beta.log())

                self._info["v_beta"] = self.beta
                self._info["v_improvement"] = improvement
                self._info[Q_MEAN] = greedy_q_mean
                self._info[Q_STD] = greedy_q_std
                self._info["v_q_mean"] = variational_q_mean
                self._info["v_q_std"] = variational_q_std

        return self._info

    @contextmanager
    def policy_evaluation_context(self, explore=None, detach=False, **kwargs):
        """Yield the module set for either the exploratory or the greedy policy.

        Args:
            explore: ``None`` keeps the current ``is_explore`` mode; any other
                value sets ``is_explore`` to ``bool(explore)`` first.
            detach: Passed through unchanged as the last ``ContextModules``
                argument.
            **kwargs: Accepted but not used in this method.

        Yields:
            A ``ContextModules`` bundling actor, optimizers, Q predictors,
            critics, temperature, model step function and rollout buffer for
            the selected mode.

        On exit (including on error) ``is_explore`` is flipped, so consecutive
        contexts entered with ``explore=None`` alternate between the two
        modes, and a shared actor is switched back to its optimistic head.
        """
        if explore is None:
            explore = self.is_explore
        else:
            self.is_explore = bool(explore)

        if explore:
            # Exploration: behavior actor / temperature and the "v" critics.
            actor = self.behavior_actor
            alpha = self.behavior_alpha
            critic = self.critic_v
            critic_target = self.critic_target_v

            # NOTE(review): this terminal predictor pops the KL entry from the
            # dynamics model even when kl_constraint is False -- confirm the
            # non-variational model provides it.
            pred_terminal_q = self.pred_terminal_v_q

            if self.kl_constraint:
                # critic_svg = self.wrap_critic_svg(critic)
                # Model step that adds the KL term to the rewards.
                model_step = self.model_v_step_context
            else:
                # critic_svg = critic
                model_step = self.model_h_step_context

            actor_optimizer = self.behavior_actor_optimizer
            critic_optimizer = self.behavior_critic_optimizer

            # Shared two-headed modules select their head through flags.
            if not self.separate_actor:
                actor.is_optimistic = True

            if not self.separate_critic:
                self.critic_v.set_latent_state(True)
                self.critic_target_v.set_latent_state(True)

            if self.separate_rollout_buffer:
                rollout_buffer = self.rollout_buffer_v
                # rollout_buffer = self.rollout_buffer
            else:
                rollout_buffer = self.rollout_buffer
        else:
            # Greedy: the parent's actor, temperature, critics and model step.
            actor = self.actor
            alpha = self.alpha
            model_step = self.model_step_context

            critic = self.critic
            critic_target = self.critic_target

            pred_terminal_q = self.pred_terminal_q

            actor_optimizer = self.actor_optimizer
            critic_optimizer = self.critic_optimizer

            if not self.separate_actor:
                actor.is_optimistic = False

            if not self.separate_critic:
                critic.set_latent_state(False)
                critic_target.set_latent_state(False)

            rollout_buffer = self.rollout_buffer

        # Positional order must match the ContextModules definition.
        context_modules = ContextModules(
            actor,
            actor_optimizer,
            self.pred_q_value,
            self.pred_target_q_value,
            pred_terminal_q,
            critic,
            critic_target,
            critic_optimizer,
            alpha,
            model_step,
            rollout_buffer,
            detach,
        )
        try:
            yield context_modules
        finally:
            # self.model_env.model.pred_context = original_ctx
            # Flip the mode for the next context and restore the optimistic
            # head on a shared actor.
            self.is_explore = not self.is_explore
            if not self.separate_actor:
                self.actor.is_optimistic = True

    # def wrap_critic_svg(self, critic):
    #     def critic_forward(obs_action):
    #         pred_q_value = critic(obs_action)
    #
    #         self.dynamics_model.moment_dist(
    #             obs_action, with_pred=False, pred_context=CTX_EXPLR
    #         )
    #
    #         kl = self.dynamics_model._info.pop(KL)
    #         kl_ = kl - self.target_kl
    #         kl_term = kl_.detach() * self.beta - kl_ * self.beta.detach()
    #
    #         pred_q_value = pred_q_value + kl_term
    #
    #         # self._last_kl.mul_(self.discount)
    #         # self._last_kl.add_((1 - self.discount) * kl.detach().mean())
    #         return pred_q_value
    #
    #     return critic_forward

    def pred_terminal_v_q(self, obs_action):
        """Return the reduced terminal Q-value plus the KL term, transposed.

        The KL value is read (and removed) from the dynamics model's ``_info``
        after evaluating it in the exploration context. The added term
        ``(kl - target_kl).detach() * beta - (kl - target_kl) * beta.detach()``
        is numerically zero, so it changes gradients only.
        """
        q_value = self._reduce(self.pred_q_value(obs_action), self.actor_reduction)

        model = self.dynamics_model
        model.moment_dist(obs_action, with_pred=False, pred_context=CTX_EXPLR)
        kl_excess = model._info.pop(KL) - self.target_kl

        beta = self.beta
        kl_term = kl_excess.detach() * beta - kl_excess * beta.detach()

        return (q_value + kl_term).t()

    def model_v_step_context(self, action, obs, log=False, **kwargs):
        """Sample one exploration step from the dynamics model and add the KL term.

        The KL entry is removed from the returned ``info``. The term
        ``(kl - target_kl).detach() * beta - (kl - target_kl) * beta.detach()``
        is added to the rewards in place. With ``log`` set, the mean KL and
        mean KL term are written to ``self._info``.

        Returns:
            ``(next_obs, rewards, done_i, info)``.
        """
        step_out = self.dynamics_model.sample(
            action, obs, pred_context=CTX_EXPLR, log=log, **kwargs
        )
        next_obs, rewards, done_i, info = step_out

        kl = info.pop(KL)
        kl_excess = kl - self.target_kl
        beta = self.beta
        kl_term = kl_excess.detach() * beta - kl_excess * beta.detach()
        rewards += kl_term

        if log:
            self._info["kl_term"] = kl_term.detach().mean()
            self._info[KL] = kl.detach().mean()
        return next_obs, rewards, done_i, info

    def model_h_step_context(self, action, obs, **kwargs):
        """Step the model environment in the exploration context.

        Returns:
            ``(next_obs, rewards, done_i, info)`` as produced by
            ``self.model_env.step``.
        """
        transition = self.model_env.step(
            action, obs, pred_context=CTX_EXPLR, **kwargs
        )
        next_obs, rewards, done_i, info = transition
        return next_obs, rewards, done_i, info

    def distribution_rollout(self, num_rollout_samples=None, **rollout_kwargs):
        """Run the parent rollout once, or twice with separate buffers.

        With ``separate_rollout_buffer`` the parent rollout runs first with
        the exploration flag set and then with it cleared; otherwise it runs
        a single time with no flag.
        """
        # NOTE(review): ``rollout_kwargs`` is accepted but never forwarded,
        # and the flag is passed under the keyword ``explor`` -- check both
        # against the parent signature.
        if not self.separate_rollout_buffer:
            super().distribution_rollout(num_rollout_samples=num_rollout_samples)
            return

        for explore_flag in (True, False):
            super().distribution_rollout(
                num_rollout_samples=num_rollout_samples, explor=explore_flag
            )

    def update_critic_bound(self, reward_lb, reward_ub):
        if not self.is_explore:
            super().update_critic_bound(reward_lb, reward_ub)

    @property
    def beta(self):
        return self.raw_beta.exp()

    @property
    def behavior_alpha(self):
        return self.behavior_raw_alpha.exp()
