import os
import random
import warnings
from typing import Dict, Iterator, List, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.optim import Adam
from typing_extensions import deprecated

from lambda_ac.nn.common import hard_update, soft_update
from lambda_ac.nn.encoder_models import IdentityEncoder
from lambda_ac.nn.model_loss import ModelLossModule
from lambda_ac.replay.replay_memory import DataBuffer
from lambda_ac.rl_types import (
    EncoderActorModule,
    EncoderCriticModule,
    EncoderModelBasedActorCriticAgent,
    EncoderModelNetwork,
    ModelTrajectory,
    PlanningStrategy,
)
from lambda_ac.util.model_util import (
    get_model_target,
    get_real_target,
    rollout_model_with_actions,
    rollout_with_vf_target,
)
from lambda_ac.util.pixel_util import batch_to_cuda
from lambda_ac.util.schedulers import LinearSchedule
from lambda_ac.util.torch_util import get_gradient_angle


class QPN(EncoderModelBasedActorCriticAgent):
    """Model-based soft actor-critic agent that acts and learns on latent features.

    The latent model and the critic head are trained jointly in
    :meth:`update_model`. The actor and the entropy temperature are trained in
    :meth:`update_actor`. Actions are produced by ``planning_strategy`` in
    :meth:`select_action`.
    """

    def __init__(
        self,
        critic: EncoderCriticModule,
        critic_target: EncoderCriticModule,
        actor: EncoderActorModule,
        model: EncoderModelNetwork,
        planning_strategy: PlanningStrategy,
        actor_horizon_scheduler: LinearSchedule,
        critic_horizon_scheduler: LinearSchedule,
        agent_hidden_dim: int,
        agent_hidden_layers: int,
        encoder_normalize: bool,
        critic_spectral_norm: bool,
        model_spectral_norm: bool,
        lr: float = 1e-4,
        actor_lr: float = 1e-4,
        model_lr: float = 1e-4,
        encoder_lr: float = 1e-4,
        critic_gamma: float = 0.99,
        actor_gamma: float = 1.0,
        tau: float = 0.005,
        target_update_interval: int = 1,
        model_grad_clip: float = 1e8,
        rl_grad_clip: float = 1e8,
        n_step_target_depth: int = 1,
        actor_rollout_depth: int = 1,
        critic_rollout_depth: int = 1,
        model_train_depth: int = 1,
        rho: float = 1.0,
        automatic_entropy_tuning: bool = True,
        alpha: float = 1.0,
        share_encoder: bool = False,
        update_encoder_model: bool = True,
        update_encoder_critic: bool = True,
        update_encoder_actor: bool = True,
        use_svg_policy_update: bool = False,
        use_model_value_expansion: bool = False,
        exploration_noise_model: bool = False,
        td_average: bool = False,
        update_old_values: bool = False,
        encoder_delayed_target: bool = False,
        update_encoder_every_n_steps: int = 1,
        adversarial_regularization: bool = False,
        discretize_done_actor: bool = False,
        discretize_done_critic: bool = False,
        depend_on_hidden: bool = False,
        predict_done: bool = False,
        robust_regularization: bool = False,
        robust_lambda: float = 0.1,
        model_based_search: bool = False,
        ensemble_size: int = 1,
        do_redq: bool = False,
        use_muzero_target: bool = False,
        start_model_target_from_zero: bool = False,
        model_losses: Union[List[Tuple[str, float]], None] = None,
        device: str = "cuda",
    ):
        """Build the agent, wire up shared encoders and create the optimizers.

        Args:
            critic: Twin-Q critic. ``critic.head(features, action)`` returns two
                Q estimates. When ``share_encoder`` is True, its ``encoder`` is
                given to the actor and the model.
            critic_target: Target critic. It is hard-copied from ``critic`` here,
                soft-updated in :meth:`update_critic`, and used for value targets.
            actor: Policy. ``actor.head.forward_sample_log_prob`` is used for
                updates, and ``actor.action_dim`` sets the target entropy.
            model: Latent model. Its ``encoder`` and ``latent_network`` are
                trained in :meth:`update_model`.
            planning_strategy: Chooses actions in :meth:`select_action`.
            actor_horizon_scheduler: Maps the update count to the model rollout
                depth used by :meth:`update_actor`.
            critic_horizon_scheduler: Maps the update count to the rollout depth
                used for the value target in :meth:`update_model`.
            agent_hidden_dim, agent_hidden_layers, encoder_normalize,
                model_spectral_norm, depend_on_hidden: Accepted for configuration
                compatibility; not used by this class.
            critic_spectral_norm: Stored. In this class it is only counted in
                the multiple-regularization warning.
            lr: Learning rate of the critic head and of the temperature.
            actor_lr: Learning rate of the actor.
            model_lr: Learning rate of the model and the model-loss parameters.
            encoder_lr: Learning rate of the shared encoder.
            critic_gamma: Discount used for the value target in
                :meth:`update_model`.
            actor_gamma: Discount passed to the actor updates.
            tau: Soft target-update coefficient.
            target_update_interval: Stored; not read by this class.
            model_grad_clip: Gradient-norm clip for the model latent network and
                the shared encoder.
            rl_grad_clip: Gradient-norm clip for the actor and critic heads.
            n_step_target_depth, actor_rollout_depth, critic_rollout_depth,
                model_train_depth: Stored. The rollout depths actually used come
                from the horizon schedulers.
            rho: Per-step loss weight; step ``n`` is weighted by ``rho**n``.
            automatic_entropy_tuning: Learn the temperature alpha.
            alpha: Initial temperature, stored as ``log_alpha``.
            share_encoder: Give the actor and the model the critic's encoder.
            update_encoder_model, update_encoder_critic, update_encoder_actor:
                With a shared encoder, False detaches the features for that
                part. Only the critic flag is read in this class.
            use_svg_policy_update: Allow model-based actor updates (see
                :meth:`update_actor`).
            use_model_value_expansion: Stored; not read by this class.
            exploration_noise_model: Forwarded as ``add_exploration_noise``.
            td_average: Stored as ``use_td_average``; not read by this class.
            update_old_values: Stored; not read by this class.
            encoder_delayed_target: Encode model-loss targets with the target
                critic's encoder instead of the model's encoder.
            update_encoder_every_n_steps: Stored; not read by this class.
            adversarial_regularization, robust_regularization, robust_lambda:
                Stored. In this class the flags are only counted in the warning.
            discretize_done_actor: Forwarded to the model-based actor target.
            discretize_done_critic: Stored; not read by this class.
            predict_done: Forwarded to the model-based actor target.
            model_based_search: Stored; not read by this class.
            ensemble_size: Length of ``dropout_matrix``.
            do_redq: Stored; not read by this class.
            use_muzero_target: Forwarded as ``use_muzero`` to the value target.
            start_model_target_from_zero: Forwarded as ``start_from_zero`` to
                the value target.
            model_losses: ``(loss_type, weight)`` pairs to register on the model
                loss module. Defaults to no extra losses.
            device: Torch device for all networks and tensors.
        """
        super().__init__(critic, critic_target, actor, model)

        # Normalize once so that every later `.to(...)` sees the same device object.
        self.device = torch.device(device)

        # actor critic hyper parameters
        self.lr = lr
        self.actor_lr = actor_lr
        self.model_lr = model_lr
        self.encoder_lr = encoder_lr
        self.critic_gamma = critic_gamma
        self.actor_gamma = actor_gamma
        self.tau = tau
        self.rho = rho
        self.target_update_interval = target_update_interval
        self.model_grad_clip = model_grad_clip
        self.rl_grad_clip = rl_grad_clip
        self.ensemble_size = ensemble_size
        self.do_redq = do_redq
        # Copy, so later mutation of the caller's list cannot leak into the agent.
        self.model_losses = list(model_losses) if model_losses is not None else []
        self.use_muzero_target = use_muzero_target
        self.start_model_target_from_zero = start_model_target_from_zero

        # depth parameters
        self.n_step_target_depth = n_step_target_depth
        self.actor_rollout_depth = actor_rollout_depth
        self.critic_rollout_depth = critic_rollout_depth
        self.model_train_depth = model_train_depth

        # SAC target entropy settings
        self.automatic_entropy_tuning = automatic_entropy_tuning
        self.target_alpha = -float(self.actor.action_dim)
        self.log_alpha = torch.tensor(np.log(alpha)).to(self.device)
        self.log_alpha.requires_grad = True

        # encoder settings
        self.share_encoder = share_encoder
        if self.share_encoder:
            self.actor.init_encoder(self.critic.encoder)
            self.model.init_encoder(self.critic.encoder)
        self.detach_encoder_model = not update_encoder_model and self.share_encoder
        self.detach_encoder_critic = not update_encoder_critic and self.share_encoder
        self.detach_encoder_actor = not update_encoder_actor and self.share_encoder

        # model usage settings
        self.encoder_delayed_target = encoder_delayed_target
        self.update_encoder_every_n_steps = update_encoder_every_n_steps
        self.use_svg_policy_update = use_svg_policy_update
        self.use_model_value_expansion = use_model_value_expansion
        # Probability of using model rollouts. It starts at 0 and is not changed
        # inside this class, so any schedule must be applied by the caller.
        self.model_data_ratio = 0.0
        self.exploration_noise_model = exploration_noise_model
        self.use_td_average = td_average

        # regularization settings
        self.critic_spectral_norm = critic_spectral_norm
        self.robust_regularization = robust_regularization
        self.robust_lambda = robust_lambda
        self.adversarial_regularization = adversarial_regularization

        # done discretization settings
        self.discretize_done_actor = discretize_done_actor
        self.discretize_done_critic = discretize_done_critic
        self.predict_done = predict_done

        # planning settings
        self.model_based_search = model_based_search
        self.planning_strategy = planning_strategy
        self.planning_strategy.set_networks(self.actor, self.critic, self.model)

        self.actor_horizon_scheduler = actor_horizon_scheduler
        self.critic_horizon_scheduler = critic_horizon_scheduler

        # bools sum as ints: warn when more than one regularizer is switched on
        if (
            self.critic_spectral_norm
            + self.robust_regularization
            + self.adversarial_regularization
            > 1
        ):
            warnings.warn(
                "Multiple regularization methods are used. This is not recommended.",
                UserWarning,
            )

        self.model_loss_module = ModelLossModule(
            self.model, self.actor, self.critic_target
        )

        # lambda_ac settings
        self.update_old_values = update_old_values
        self.update_idx = 0
        self.dropout_matrix = torch.ones(self.ensemble_size).to(self.device)

        self.model.to(self.device)
        self.critic.to(self.device)
        self.actor.to(self.device)
        self.critic_target.to(self.device)
        self.model_loss_module.to(self.device)

        for loss_type, weight in self.model_losses:
            self.model_loss_module.register_loss(loss_type, weight)

        # Optimizers are created after the loss terms are registered, because the
        # loss-module parameters are included in the model optimizer.
        self._init_optimizers()

        hard_update(self.critic_target, self.critic)

    def _init_optimizers(self):
        """(Re)create every optimizer for the current parameter sets.

        This is called from ``__init__`` and again from :meth:`load_checkpoint`
        after the shared encoder is re-attached.

        Both branches define ``critic_optim``, ``model_optim``, ``actor_optim``
        and ``encoder_optim``. ``encoder_optim`` is None when no separate encoder
        optimizer exists. All four are needed because :meth:`update_model` and
        the checkpoint helpers use them.
        """
        if self.share_encoder:
            self.critic_optim = Adam(self.critic.head.parameters(), lr=self.lr)
            # The shared encoder has its own optimizer, so the model optimizer
            # only covers the latent network and the extra loss parameters.
            self.model_optim = Adam(
                list(self.model.latent_network.parameters())
                + self.model_loss_module.parameter_list,
                lr=self.model_lr,
            )
            if not isinstance(self.critic.encoder, IdentityEncoder):
                self.encoder_optim = Adam(
                    self.critic.encoder.parameters(), lr=self.encoder_lr
                )
            else:
                self.encoder_optim = None
            self.actor_optim = Adam(self.actor.head.parameters(), lr=self.actor_lr)
        else:
            # Each network keeps its own encoder here. The model (including its
            # encoder) is optimized together with the loss-module parameters.
            # The critic head gets its own optimizer, as in the shared branch,
            # because update_model steps `critic_optim` and `model_optim` separately.
            self.critic_optim = Adam(self.critic.head.parameters(), lr=self.lr)
            self.model_optim = Adam(
                list(self.model.parameters()) + self.model_loss_module.parameter_list,
                lr=self.model_lr,
            )
            self.encoder_optim = None
            self.actor_optim = Adam(self.actor.parameters(), lr=self.actor_lr)

        if self.automatic_entropy_tuning:
            self.alpha_optim = Adam([self.log_alpha], lr=self.lr)
        else:
            self.alpha_optim = None

    def register_model_loss(self, loss_type: str, weight: float = 1.0):
        """Register an additional loss term of ``loss_type`` on the model loss module.

        Parameters of a term added after construction are not in ``model_optim``
        until the optimizers are recreated.
        """
        self.model_loss_module.register_loss(loss_type, weight)

    @property
    def alpha(self):
        """Entropy temperature ``exp(log_alpha)``; differentiable w.r.t. ``log_alpha``."""
        return torch.exp(self.log_alpha)

    @torch.no_grad()
    def select_action(
        self,
        state: torch.Tensor,
        eval=False,
        step: int = 0,
        episode: int = 0,
    ) -> torch.Tensor:
        """
        Select an action for the given state using the planning strategy.

        A 1-D or 3-D ``state`` is treated as a single unbatched observation and
        gets a leading batch dimension.

        Args:
            state (torch.Tensor): Input state to select action for
            eval (bool, optional): Whether to use the mean of the distribution or sample from it. Defaults to False.
            step (int, optional): Forwarded to the planner. Defaults to 0.
            episode (int, optional): Forwarded to the planner. Defaults to 0.

        Returns:
            torch.Tensor: Action returned by ``planning_strategy.plan``.
        """
        # NOTE(review): actor and model are switched to eval mode here, and nothing
        # in this class switches them back to train mode. Confirm the caller does.
        self.actor.eval()
        self.model.eval()
        if state.dim() == 1 or state.dim() == 3:
            state = state.unsqueeze(0)
        state = state.float().to(self.device)

        return self.planning_strategy.plan(
            state=state, alpha=self.alpha, eval=eval, step=step, episode=episode
        )

    def update_critic(self, memory: DataBuffer, updates: int):
        """
        Soft-update the target critic towards the critic.

        The critic itself is trained jointly with the model in
        :meth:`update_model`, so this method only moves the target network.

        Args:
            memory (DataBuffer): Replay buffer (unused).
            updates (int): Number of updates (unused).

        Returns:
            dict: Always an empty logging dict.
        """
        soft_update(
            self.critic_target,
            self.critic,
            self.tau,
            update_idx=0,
        )
        return {}

    def _step_actor(self, actor_loss: torch.Tensor) -> torch.Tensor:
        """Backpropagate ``actor_loss`` and take one clipped optimizer step on the actor head.

        Returns the gradient norm reported by ``clip_grad_norm_``.
        """
        self.actor_optim.zero_grad()
        actor_loss.backward()
        actor_grad_norm = clip_grad_norm_(
            self.actor.head.parameters(), self.rl_grad_clip, error_if_nonfinite=True
        )
        self.actor_optim.step()
        self.actor_optim.zero_grad()
        return actor_grad_norm

    def _update_alpha(self, log_pi: torch.Tensor) -> torch.Tensor:
        """Take one temperature step towards the target entropy ``-action_dim``.

        Args:
            log_pi: Log-probabilities of freshly sampled actions (treated as constants).

        Returns:
            The alpha loss, or a zero tensor when automatic entropy tuning is off.
        """
        if not self.automatic_entropy_tuning:
            return torch.tensor(0.0).to(self.device)
        alpha_loss = -(self.log_alpha * (log_pi + self.target_alpha).detach()).mean()
        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        self.alpha_optim.zero_grad()
        return alpha_loss

    def update_actor_real(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        done: torch.Tensor,
        discount: torch.Tensor,
    ):
        """
        Run a SAC policy update on latent states encoded from real observations.

        ``state`` is encoded with ``self.model.encoder`` and passed through
        ``rollout_model_with_actions(..., return_first=True)`` without gradients.
        The actor head is then trained to maximize
        ``min(Q1, Q2) - alpha * log_pi`` on the resulting ``trajectory.states``.

        Args:
            state (torch.Tensor): Batch of observations.
            action (torch.Tensor): Replay-buffer actions, forwarded to the rollout.
            done (torch.Tensor): Done flags, forwarded to the rollout.
            discount (torch.Tensor): Unused; kept for parity with update_actor_model.

        Returns:
            Dict[str, float]: Logging statistics.
        """
        with torch.no_grad():
            trajectory = rollout_model_with_actions(
                self.model.encoder(state),
                action,
                done,
                self.model,
                discretize_done=True,
                predict_done=False,
                detach_encoder=True,
                return_first=True,
            )
        features = trajectory.states
        action, log_pi, _ = self.actor.head.forward_sample_log_prob(features)
        qf1, qf2 = self.critic.head(features, action)
        # Clipped double-Q: use the pessimistic minimum of the two critic heads.
        qf = torch.min(qf1, qf2)
        target = qf - self.alpha.detach() * log_pi

        actor_loss = -target.mean()
        actor_grad_norm = self._step_actor(actor_loss)

        # Re-sample from the updated policy for the temperature step and the stats.
        with torch.no_grad():
            action, log_pi, dist = self.actor.head.forward_sample_log_prob(features)
        alpha_loss = self._update_alpha(log_pi)

        return {
            "actor/loss": actor_loss.item(),
            "actor/grad_norm": actor_grad_norm,
            "actor/alpha_loss": alpha_loss.item(),
            "actor_stats/avg_log_pi": torch.mean(log_pi).item(),
            "actor_stats/alpha": self.alpha.item(),
            "actor_stats/avg_action": torch.mean(action).item(),
            "actor_stats/avg_abs_action": torch.mean(action.abs()).item(),
            "actor_stats/action_variance": torch.mean(dist.log_std).item(),
        }

    def update_actor_model(
        self,
        state: torch.Tensor,
        done: torch.Tensor,
        depth: int,
        discount: torch.Tensor,
    ):
        """Run a SAC policy update through a ``depth``-step model rollout.

        ``get_model_target`` (called with ``get_policy_target=True`` and
        ``torch.min`` over the twin Q heads) supplies the objective. The actor
        maximizes that objective. It also returns the features used for the
        temperature step.

        Returns:
            Dict[str, float]: Logging statistics, including the current model data ratio.
        """
        target, features, _ = get_model_target(
            state,
            done,
            discount,
            self.alpha,
            self.model,
            self.actor,
            self.critic,
            depth,
            torch.min,
            get_policy_target=True,
            discretize_done=self.discretize_done_actor,
            predict_done=self.predict_done,
            detach_encoder=True,
            add_exploration_noise=self.exploration_noise_model,
            use_td_average=False,
        )

        actor_loss = -target.mean()
        actor_grad_norm = self._step_actor(actor_loss)

        # Re-sample from the updated policy for the temperature step and the stats.
        with torch.no_grad():
            action, log_pi, dist = self.actor.head.forward_sample_log_prob(features)
        alpha_loss = self._update_alpha(log_pi)

        return {
            "actor/loss": actor_loss.item(),
            "actor/grad_norm": actor_grad_norm,
            "actor/alpha_loss": alpha_loss.item(),
            "actor/mdr": self.model_data_ratio,
            "actor_stats/avg_log_pi": torch.mean(log_pi).item(),
            "actor_stats/alpha": self.alpha.item(),
            "actor_stats/avg_action": torch.mean(action).item(),
            "actor_stats/avg_abs_action": torch.mean(action.abs()).item(),
            "actor_stats/action_variance": torch.mean(dist.log_std).item(),
        }

    def update_actor(self, memory: DataBuffer, updates: int):
        """
        Sample a batch and run one actor (and temperature) update.

        With ``use_svg_policy_update`` enabled, and when the scheduled horizon is
        positive, the model-based update is chosen with probability
        ``self.model_data_ratio``. Otherwise the update on encoded real states
        is used.

        Args:
            memory (DataBuffer): Replay iterator; ``next(memory)`` yields a batch.
            updates (int): Update counter fed to ``actor_horizon_scheduler``.

        Returns:
            Dict[str, float]: Logging statistics of the chosen update.
        """
        with torch.no_grad():
            batch = next(memory)
            # NOTE(review): update_model unpacks the same tuple as
            # (state, action, reward, next_state, done, timelimit_mask, weights, idx),
            # so slot 5, bound to `done` here, is the raw timelimit mask rather
            # than the done flag in slot 4. Confirm this is intentional.
            (state, action, reward, _, _, done, _, _) = batch_to_cuda(batch)
            discount = torch.ones_like(reward[:, 0]) * self.actor_gamma
        mdr = self.model_data_ratio if self.use_svg_policy_update else 0.0

        depth = int(self.actor_horizon_scheduler(updates))

        if random.random() < mdr and depth > 0:
            return self.update_actor_model(state, done, depth, discount)
        return self.update_actor_real(state, action, done, discount)

    @torch.no_grad()
    def update_priorities(
        self,
        memory: Iterator[Tuple[torch.Tensor, ...]],
        buffer: DataBuffer,
        updates: int,
    ):
        """Placeholder for prioritized-replay weight updates; currently a no-op."""
        # TODO: compute per-sample critic errors and call `buffer.weight_update`.

    def update_model(self, memory: Iterator[Tuple[torch.Tensor, ...]], updates: int):
        """Jointly update the latent model, the critic head and (if shared) the encoder.

        A value target ``Q`` is built by ``rollout_with_vf_target``, using model
        rollouts with probability ``model_data_ratio``. The critic head is
        regressed onto ``Q`` at the rollout step ``idx`` it returns. The
        multi-step model losses from :meth:`multi_step_model_loss` are added.

        Args:
            memory: Replay iterator; ``next(memory)`` yields a batch.
            updates: Update counter fed to ``critic_horizon_scheduler``.

        Returns:
            Dict of logging statistics.
        """
        with torch.no_grad():
            batch = next(memory)
            (
                state,
                action,
                reward,
                next_state,
                done,
                timelimit_mask,
                weights,
                _,
            ) = batch_to_cuda(batch)
        # Invert the buffer flag so that True marks steps kept in the loss masks below.
        timelimit_mask = torch.logical_not(timelimit_mask)
        weights = weights.squeeze(-1)

        mdr = self.model_data_ratio

        depth = int(self.critic_horizon_scheduler(updates))
        use_model = random.random() < mdr and depth > 0

        discount = torch.ones_like(reward[:, 0]) * self.critic_gamma

        Q, Q_action, trajectory, idx = rollout_with_vf_target(
            state,
            action,
            reward,
            next_state,
            done,
            timelimit_mask,
            discount,
            self.alpha,
            self.model,
            self.actor,
            self.critic_target,
            depth,
            torch.min,
            use_model=use_model,
            use_muzero=self.use_muzero_target,
            add_exploration_noise=self.exploration_noise_model,
            start_from_zero=self.start_model_target_from_zero,
        )

        (
            model_loss,
            model_reward_loss,
            model_done_loss,
            info,
        ) = self.multi_step_model_loss(
            trajectory,
            action,
            reward,
            next_state,
            done,
            weights,
            timelimit_mask,
        )

        # Pick, per batch element, the latent state at the step `idx` returned by
        # the target rollout, and weight its TD error by rho**idx and the step mask.
        idx_aux = torch.arange(state.shape[0]).to(state.device)
        mask = timelimit_mask[idx_aux, idx]
        weighing = self.rho**idx * mask.view(-1) * weights[:, 0]
        feature = trajectory.states[idx_aux, idx]
        if self.detach_encoder_critic:
            feature = feature.detach()

        qf1, qf2 = self.critic.head.forward(feature, Q_action.detach())
        qf1_loss = torch.sum((qf1 - Q) ** 2, dim=-1) * weighing
        qf2_loss = torch.sum((qf2 - Q) ** 2, dim=-1) * weighing
        total_qf_loss = qf1_loss + qf2_loss

        self.model_optim.zero_grad()
        self.critic_optim.zero_grad()

        # retain_graph: the critic loss below backpropagates through the same
        # model rollout (`feature`) unless the critic features are detached.
        (model_loss + model_reward_loss + model_done_loss).backward(retain_graph=True)
        # Discard any critic-head gradient from the model losses, so the critic
        # head is driven only by the TD loss.
        self.critic_optim.zero_grad()
        (0.1 * total_qf_loss).mean().backward()
        clip_grad_norm_(
            self.model.latent_network.parameters(),
            self.model_grad_clip,
            error_if_nonfinite=True,
        )
        clip_grad_norm_(
            self.critic.head.parameters(),
            self.rl_grad_clip,
            error_if_nonfinite=True,
        )
        self.critic_optim.step()
        self.model_optim.step()
        self.model_optim.zero_grad()
        self.critic_optim.zero_grad()

        # The shared encoder accumulated gradients from both backward passes above.
        if self.share_encoder and not isinstance(self.critic.encoder, IdentityEncoder):
            clip_grad_norm_(
                self.critic.encoder.parameters(),
                self.model_grad_clip,
                error_if_nonfinite=True,
            )
            self.encoder_optim.step()
            self.encoder_optim.zero_grad()

        return_dict = {
            "model/model_loss": model_loss.item(),
            "model/reward_loss": model_reward_loss.item(),
            "model/done_loss": model_done_loss.item(),
            "critic/loss": torch.mean(total_qf_loss).item(),
            "q_stats/avg_q_value": torch.mean(qf1).item(),
            "q_stats/avg_batch_reward": torch.mean(reward).item(),
            "q_stats/q_state_diversity": torch.mean(torch.var(qf1, dim=0)).item(),
        }
        return_dict.update(info)

        return return_dict

    def multi_step_model_loss(
        self,
        trajectory: ModelTrajectory,
        action: torch.Tensor,
        reward: torch.Tensor,
        next_state: torch.Tensor,
        done: torch.Tensor,
        weights: torch.Tensor,
        timelimit_mask: torch.Tensor,
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Union[float, torch.Tensor]]
    ]:
        """
        Accumulate per-step model, reward and done losses over the batch horizon.

        For every step ``n < action.shape[1]``:

        - The predicted next encoding ``trajectory.encodings[n + 1]`` is compared
          with the encoding of the real ``next_state[:, n]``. That target
          encoding comes from ``critic_target.encoder`` when
          ``encoder_delayed_target`` is set and from ``model.encoder``
          otherwise, and is computed without gradient.
        - Each step is weighted by ``rho**n * weights[:, n] * timelimit_mask[:, n]``.

        Each sum is divided by the number of steps and averaged over the batch.

        Expected shapes (author's notes; not checked here):
            action: (batch_size x nstep x action_dim)
            reward: (batch_size x nstep x 1)
            next_state: (batch_size x nstep x obs_dim)

        Returns:
            (model_loss, model_reward_loss, model_done_loss, info). The first three
            are scalar loss tensors. ``info`` holds detached logging values.
        """
        nstep = action.shape[1]

        model_loss = torch.zeros_like(reward[:, 0])
        model_reward_loss = torch.zeros_like(reward[:, 0])
        model_done_loss = torch.zeros_like(reward[:, 0])

        # logging losses
        vaml_loss = torch.zeros_like(reward[:, 0])
        norm_vaml_loss = torch.zeros_like(reward[:, 0])
        mse_loss = torch.zeros_like(reward[:, 0])
        classification_accuracy = 0.0

        for n in range(nstep):
            with torch.no_grad():
                if self.encoder_delayed_target:
                    next_feature_encoding = self.critic_target.encoder(next_state[:, n])
                else:
                    next_feature_encoding = self.model.encoder(next_state[:, n])

            model_next_encoding = trajectory.encodings[n + 1]
            model_reward = trajectory.rewards[:, n]
            model_done = trajectory.done_predictions[:, n]

            weighing = (
                self.rho**n
                * weights[:, n].view(-1, 1)
                * timelimit_mask[:, n].view(-1, 1)
            )

            model_loss += weighing * self.model_loss_module(
                model_next_encoding,
                next_feature_encoding,
                next_state[:, n],
            )

            model_reward_loss += weighing * (model_reward - reward[:, n]) ** 2

            # done_predictions are treated as logits.
            model_done_loss += weighing * F.binary_cross_entropy_with_logits(
                model_done, done[:, n].float(), reduction="none"
            )

            # Diagnostic-only losses: never backpropagated.
            with torch.no_grad():
                vaml_loss += weighing * self.model_loss_module._vaml_loss(
                    model_next_encoding,
                    next_feature_encoding,
                    next_state[:, n],
                )
                norm_vaml_loss += (
                    weighing
                    * self.model_loss_module._normalized_vaml_loss(
                        model_next_encoding,
                        next_feature_encoding,
                        next_state[:, n],
                    )
                )

                mse_loss += weighing * (
                    self.model_loss_module._mse_loss(
                        model_next_encoding,
                        next_feature_encoding,
                        next_state[:, n],
                    )
                )
                # Done-prediction accuracy over unmasked steps only. A positive
                # logit counts as a "done" prediction.
                classification_accuracy += torch.mean(
                    ((model_done > 0.0) == done[:, n])[
                        torch.where(timelimit_mask[:, n])
                    ].float()
                )

        model_loss = torch.mean(model_loss / nstep)
        vaml_loss = torch.mean(vaml_loss / nstep)
        norm_vaml_loss = torch.mean(norm_vaml_loss / nstep)
        mse_loss = torch.mean(mse_loss / nstep)
        model_reward_loss = torch.mean(model_reward_loss / nstep)
        model_done_loss = torch.mean(model_done_loss / nstep)
        classification_accuracy = classification_accuracy / nstep

        return (
            model_loss,
            model_reward_loss,
            model_done_loss,
            {
                "model/vaml_loss": vaml_loss.mean(),
                "model/norm_vaml_loss": norm_vaml_loss.mean(),
                "model/mse_loss": mse_loss.mean(),
                # Detached so the logging dict does not hold on to the autograd graph.
                "model/encoder_diversity": trajectory.states[:, 0]
                .encoded.std(dim=0)
                .mean()
                .detach(),
                "model/classification_accuracy": classification_accuracy,
            },
        )

    # Save model parameters
    def save_checkpoint(self, env_name, suffix="", ckpt_path=None):
        """Save networks, optimizers, temperature and bookkeeping state.

        Args:
            env_name: Used in the default file name.
            suffix: Used in the default file name.
            ckpt_path: Target file. Defaults to
                ``checkpoints/sac_checkpoint_{env_name}_{suffix}``. Its parent
                directory is created if missing.
        """
        if ckpt_path is None:
            ckpt_path = "checkpoints/sac_checkpoint_{}_{}".format(env_name, suffix)
        ckpt_dir = os.path.dirname(ckpt_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(
            {
                "actor_state_dict": self.actor.state_dict(),
                "critic_state_dict": self.critic.state_dict(),
                "critic_target_state_dict": self.critic_target.state_dict(),
                "actor_optimizer_state_dict": self.actor_optim.state_dict(),
                "critic_optimizer_state_dict": self.critic_optim.state_dict(),
                "decoder_state_dict": self.model_loss_module.decoder.state_dict(),
                "model_state_dict": self.model.state_dict(),
                "model_optimizer_state_dict": self.model_optim.state_dict(),
                "encoder_optimizer_state_dict": self.encoder_optim.state_dict()
                if self.encoder_optim is not None
                else None,
                # Entropy temperature, so that resumed training keeps the learned alpha.
                "log_alpha": self.log_alpha.detach(),
                "alpha_optimizer_state_dict": self.alpha_optim.state_dict()
                if self.alpha_optim is not None
                else None,
                "update_idx": self.update_idx,
                "update_matrix": self.dropout_matrix,
            },
            ckpt_path,
        )

    # Load model parameters
    def load_checkpoint(self, env_name, suffix="", ckpt_path=None, evaluate=False):
        """Restore state written by :meth:`save_checkpoint`.

        Tensors are mapped onto ``self.device``, so a GPU checkpoint can be
        loaded on CPU. Keys missing from older checkpoints (temperature, alpha
        optimizer) and ``None`` optimizer entries are skipped.

        Args:
            env_name: Used in the default file name.
            suffix: Used in the default file name.
            ckpt_path: Source file. Defaults to
                ``checkpoints/sac_checkpoint_{env_name}_{suffix}``.
            evaluate: Accepted for interface compatibility; not used.
        """
        if ckpt_path is None:
            ckpt_path = "checkpoints/sac_checkpoint_{}_{}".format(env_name, suffix)
        print("Loading models from {}".format(ckpt_path))

        checkpoint = torch.load(ckpt_path, map_location=self.device)
        self.actor.load_state_dict(checkpoint["actor_state_dict"])
        self.critic.load_state_dict(checkpoint["critic_state_dict"])
        self.critic_target.load_state_dict(checkpoint["critic_target_state_dict"])
        self.model_loss_module.decoder.load_state_dict(checkpoint["decoder_state_dict"])
        self.model.load_state_dict(checkpoint["model_state_dict"])

        self.update_idx = checkpoint["update_idx"]
        self.dropout_matrix = checkpoint["update_matrix"]

        # Copy in place so that `self.log_alpha` stays the tensor the alpha optimizer tracks.
        if checkpoint.get("log_alpha") is not None:
            with torch.no_grad():
                self.log_alpha.copy_(checkpoint["log_alpha"])

        # Re-attach the shared encoder, then rebuild the optimizers over the
        # resulting parameter sets before loading their states.
        if self.share_encoder:
            self.actor.init_encoder(self.critic.encoder)
            self.model.init_encoder(self.critic.encoder)

        self._init_optimizers()

        self.actor_optim.load_state_dict(checkpoint["actor_optimizer_state_dict"])
        self.critic_optim.load_state_dict(checkpoint["critic_optimizer_state_dict"])
        self.model_optim.load_state_dict(checkpoint["model_optimizer_state_dict"])
        encoder_state = checkpoint.get("encoder_optimizer_state_dict")
        if self.encoder_optim is not None and encoder_state is not None:
            self.encoder_optim.load_state_dict(encoder_state)
        alpha_state = checkpoint.get("alpha_optimizer_state_dict")
        if self.alpha_optim is not None and alpha_state is not None:
            self.alpha_optim.load_state_dict(alpha_state)