import io
import os
import pathlib
import warnings
from copy import deepcopy
import time
from typing import Union, Tuple, Optional, Dict, Any, List

import gym
import numpy as np
import torch
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.save_util import save_to_pkl, load_from_pkl
from stable_baselines3.common.vec_env import VecEnv
from torch.nn import functional as F

from rl.sac.policies import SACCustomPolicy
from stable_baselines3 import HerReplayBuffer
from stable_baselines3.common.buffers import ReplayBuffer, DictReplayBuffer
from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
from stable_baselines3.common.type_aliases import GymEnv, Schedule, MaybeCallback, TrainFreq, TrainFrequencyUnit, \
    RolloutReturn
from stable_baselines3.common.utils import polyak_update, should_collect_more_steps, safe_mean


class SACGNNBase(BaseAlgorithm):
    """
    Soft Actor-Critic base class for policies built around a features extractor.

    The rollout and replay-buffer logic is adapted from stable-baselines3's
    off-policy algorithms. When ``policy.is_frozen_features_extractor`` is true,
    the replay buffer stores features-extractor embeddings instead of raw
    observations, and the embedding of the previous step is passed back to
    ``policy.extract_features`` as ``prev_slots``.

    Subclasses must implement :meth:`train`.
    """

    def __init__(
            self,
            policy: SACCustomPolicy,
            env: Union[GymEnv, str],
            learning_rate: Union[float, Schedule] = 3e-4,
            buffer_size: int = 1_000_000,  # 1e6
            learning_starts: int = 100,
            batch_size: int = 256,
            tau: float = 0.005,
            gamma: float = 0.99,
            train_freq: Union[int, Tuple[int, str]] = 1,
            gradient_steps: int = 1,
            action_noise: Optional[ActionNoise] = None,
            replay_buffer_class: Optional[ReplayBuffer] = None,
            replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
            replay_buffer: Optional[ReplayBuffer] = None,
            optimize_memory_usage: bool = False,
            ent_coef: Union[str, float] = "auto",
            target_update_interval: int = 1,
            target_entropy: Union[str, float] = "auto",
            use_sde: bool = False,
            sde_sample_freq: int = -1,
            use_sde_at_warmup: bool = False,
            tensorboard_log: Optional[str] = None,
            create_eval_env: bool = False,
            policy_kwargs: Optional[Dict[str, Any]] = None,
            verbose: int = 0,
            seed: Optional[int] = None,
            device: Union[torch.device, str] = "auto",
            _init_setup_model: bool = True,
            transition_loss_coef: float = 1.0,
            reward_loss_coef: float = 1.0,
            support_multi_env: bool = True,
            monitor_wrapper: bool = True,
            sde_support: bool = True,
            rollout_logging_full_buffer: bool = False,
    ):
        """
        :param policy: An already-constructed policy instance; it is moved to ``device``.
        :param env: The environment to learn from (or its registered id).
        :param replay_buffer: Optional pre-built replay buffer. If ``None``, one is
            created in ``_setup_model`` from ``replay_buffer_class``.
        :param ent_coef: Entropy coefficient. ``"auto"`` or ``"auto_<init>"`` makes it
            learnable (``<init>`` is its initial value, default 1.0). Otherwise it is
            converted to a fixed float.
        :param target_entropy: Target entropy for the learnable ``ent_coef``. ``"auto"``
            derives it from the action space.
        :param transition_loss_coef: Stored for subclasses; presumably weights a
            transition-model loss in ``train`` (TODO confirm in subclass).
        :param reward_loss_coef: Stored for subclasses; presumably weights a
            reward-model loss in ``train`` (TODO confirm in subclass).
        :param sde_support: If True, ``use_sde`` is forwarded through ``policy_kwargs``.
        :param rollout_logging_full_buffer: If True, episode reward/length and success
            rate are only logged once their info buffers are full.
        :param _init_setup_model: Whether to call ``_setup_model`` at the end of construction.

        The remaining parameters have the same meaning as in stable-baselines3's SAC.
        """
        super(SACGNNBase, self).__init__(
            policy=policy,
            env=env,
            policy_base=SACCustomPolicy,
            learning_rate=learning_rate,
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            support_multi_env=support_multi_env,
            create_eval_env=create_eval_env,
            monitor_wrapper=monitor_wrapper,
            seed=seed,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            supported_action_spaces=(gym.spaces.Box, gym.spaces.Discrete),
        )
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.learning_starts = learning_starts
        self.tau = tau
        self.gamma = gamma
        self.gradient_steps = gradient_steps
        self.action_noise = action_noise
        self.optimize_memory_usage = optimize_memory_usage
        self.replay_buffer_class = replay_buffer_class
        if replay_buffer_kwargs is None:
            replay_buffer_kwargs = {}
        self.replay_buffer_kwargs = replay_buffer_kwargs
        self._episode_storage = None
        # Embedding of ``self._last_obs``; only used with a frozen features extractor.
        self._last_embedding = None

        # Save train freq parameter, will be converted later to TrainFreq object
        self.train_freq = train_freq

        self.actor = None  # type: Optional[torch.nn.Module]
        self.replay_buffer = replay_buffer
        # Update policy keyword arguments
        if sde_support:
            self.policy_kwargs["use_sde"] = self.use_sde
        # For gSDE only
        self.use_sde_at_warmup = use_sde_at_warmup

        self.policy = policy.to(self.device)
        self.target_entropy = target_entropy
        self.log_ent_coef = None  # type: Optional[torch.Tensor]
        # Entropy coefficient / Entropy temperature
        # Inverse of the reward scale
        self.ent_coef = ent_coef
        self.target_update_interval = target_update_interval
        self.ent_coef_optimizer = None
        self.transition_loss_coef = transition_loss_coef
        self.reward_loss_coef = reward_loss_coef
        self.rollout_logging_full_buffer = rollout_logging_full_buffer

        if _init_setup_model:
            self._setup_model()

    def _convert_train_freq(self) -> None:
        """
        Convert the `train_freq` parameter (int or tuple) to a TrainFreq object.

        An int ``n`` is interpreted as ``(n, "step")``.

        :raises ValueError: if the unit is not 'step'/'episode' or the frequency is not an int.
        """
        if not isinstance(self.train_freq, TrainFreq):
            train_freq = self.train_freq

            # The value of the train frequency will be checked later
            if not isinstance(train_freq, tuple):
                train_freq = (train_freq, "step")

            try:
                train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
            except ValueError as err:
                raise ValueError(
                    f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
                ) from err

            if not isinstance(train_freq[0], int):
                raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")

            self.train_freq = TrainFreq(*train_freq)

    def _setup_model(self) -> None:
        """
        Create the replay buffer, aliases and entropy-coefficient machinery.

        Called at the end of ``__init__`` when ``_init_setup_model`` is True.
        """
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        # Use DictReplayBuffer if needed
        if self.replay_buffer_class is None:
            if isinstance(self.observation_space, gym.spaces.Dict):
                self.replay_buffer_class = DictReplayBuffer
            else:
                self.replay_buffer_class = ReplayBuffer

        elif self.replay_buffer_class == HerReplayBuffer:
            assert self.env is not None, "You must pass an environment when using `HerReplayBuffer`"

            # If using offline sampling, we need a classic replay buffer too
            if self.replay_buffer_kwargs.get("online_sampling", True):
                replay_buffer = None
            else:
                replay_buffer = DictReplayBuffer(
                    self.buffer_size,
                    self.observation_space,
                    self.action_space,
                    device=self.device,
                    optimize_memory_usage=self.optimize_memory_usage,
                )

            self.replay_buffer = HerReplayBuffer(
                self.env,
                self.buffer_size,
                device=self.device,
                replay_buffer=replay_buffer,
                **self.replay_buffer_kwargs,
            )

        if self.replay_buffer is None:
            self.replay_buffer = self.replay_buffer_class(
                self.buffer_size,
                self.observation_space,
                self.action_space,
                device=self.device,
                n_envs=self.n_envs,
                optimize_memory_usage=self.optimize_memory_usage,
                **self.replay_buffer_kwargs,
            )

        self.policy = self.policy.to(self.device)

        # Convert train freq parameter to TrainFreq object
        self._convert_train_freq()

        self._create_aliases()
        # Target entropy is used when learning the entropy coefficient
        if self.target_entropy == "auto":
            # automatically set target entropy if needed
            if isinstance(self.action_space, gym.spaces.Box):
                # Use self.action_space (as checked above) so this also works when env is None.
                self.target_entropy = -np.prod(self.action_space.shape).astype(np.float32)
            elif isinstance(self.action_space, gym.spaces.Discrete):
                self.target_entropy = -np.log(1.0 / self.action_space.n) * 0.98
        else:
            # Force conversion
            # this will also throw an error for unexpected string
            self.target_entropy = float(self.target_entropy)

        # The entropy coefficient or entropy can be learned automatically
        # see Automating Entropy Adjustment for Maximum Entropy RL section
        # of https://arxiv.org/abs/1812.05905
        if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"):
            # Default initial value of ent_coef when learned
            init_value = 1.0
            if "_" in self.ent_coef:
                init_value = float(self.ent_coef.split("_")[1])
                assert init_value > 0.0, "The initial value of ent_coef must be greater than 0"

            # Note: we optimize the log of the entropy coeff which is slightly different from the paper
            # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
            self.log_ent_coef = torch.log(torch.ones(1, device=self.device) * init_value).requires_grad_(True)
            self.ent_coef_optimizer = torch.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
        else:
            # Force conversion to float
            # this will throw an error if a malformed string (different from 'auto')
            # is passed
            self.ent_coef_tensor = torch.tensor(float(self.ent_coef)).to(self.device)

    def _embed_observation(
            self,
            obs: np.ndarray,
            prev_slots: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Run ``policy.extract_features`` on a batch of raw observations for storage/rollouts.

        Both inputs are moved to ``self.device``, and the call runs under
        ``torch.no_grad()`` so the result can be converted to numpy.

        :param obs: Batch of raw observations.
        :param prev_slots: Optional embedding from the previous step. It is forwarded as
            ``prev_slots``. When ``None``, ``extract_features`` is called without it.
        :return: The embedding as a CPU numpy array.
        """
        with torch.no_grad():
            obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
            if prev_slots is None:
                embedding = self.policy.extract_features(obs=obs_tensor)
            else:
                embedding = self.policy.extract_features(
                    obs=obs_tensor,
                    prev_slots=torch.as_tensor(prev_slots, dtype=torch.float32, device=self.device),
                )
        return embedding.cpu().numpy()

    def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
        """
        Save the replay buffer as a pickle file.

        :param path: Path to the file where the replay buffer should be saved.
            if path is a str or pathlib.Path, the path is automatically created if necessary.
        """
        assert self.replay_buffer is not None, "The replay buffer is not defined"
        save_to_pkl(path, self.replay_buffer, self.verbose)

    def load_replay_buffer(
        self,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        truncate_last_traj: bool = True,
    ) -> None:
        """
        Load a replay buffer from a pickle file.

        Note: unpickling executes arbitrary code; only load files you trust.

        :param path: Path to the pickled replay buffer.
        :param truncate_last_traj: When using ``HerReplayBuffer`` with online sampling:
            If set to ``True``, we assume that the last trajectory in the replay buffer was finished
            (and truncate it).
            If set to ``False``, we assume that we continue the same trajectory (same episode).
        """
        self.replay_buffer = load_from_pkl(path, self.verbose)
        assert isinstance(self.replay_buffer, ReplayBuffer), "The replay buffer must inherit from ReplayBuffer class"

        # Backward compatibility with SB3 < 2.1.0 replay buffer
        # Keep old behavior: do not handle timeout termination separately
        if not hasattr(self.replay_buffer, "handle_timeout_termination"):  # pragma: no cover
            self.replay_buffer.handle_timeout_termination = False
            self.replay_buffer.timeouts = np.zeros_like(self.replay_buffer.dones)

        if isinstance(self.replay_buffer, HerReplayBuffer):
            assert self.env is not None, "You must pass an environment at load time when using `HerReplayBuffer`"
            self.replay_buffer.set_env(self.get_env())
            if truncate_last_traj:
                self.replay_buffer.truncate_last_trajectory()

    def _setup_learn(
        self,
        total_timesteps: int,
        eval_env: Optional[GymEnv],
        callback: MaybeCallback = None,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        tb_log_name: str = "run",
    ) -> Tuple[int, BaseCallback]:
        """
        cf `BaseAlgorithm`.

        Extra behavior: may mark the last stored transition as done (memory-efficient
        buffer), and (re)computes the embedding of the current observation when the
        features extractor is frozen.
        """
        # Prevent continuity issue by truncating trajectory
        # when using memory efficient replay buffer
        # see https://github.com/DLR-RM/stable-baselines3/issues/46

        # Special case when using HerReplayBuffer,
        # the classic replay buffer is inside it when using offline sampling
        if isinstance(self.replay_buffer, HerReplayBuffer):
            replay_buffer = self.replay_buffer.replay_buffer
        else:
            replay_buffer = self.replay_buffer

        truncate_last_traj = (
            self.optimize_memory_usage
            and reset_num_timesteps
            and replay_buffer is not None
            and (replay_buffer.full or replay_buffer.pos > 0)
        )

        if truncate_last_traj:
            warnings.warn(
                "The last trajectory in the replay buffer will be truncated, "
                "see https://github.com/DLR-RM/stable-baselines3/issues/46. "
                "You should use `reset_num_timesteps=False` or `optimize_memory_usage=False` "
                "to avoid that issue."
            )
            # Go to the previous index
            pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size
            replay_buffer.dones[pos] = True

        total_timesteps, callback = super()._setup_learn(
            total_timesteps,
            eval_env,
            callback,
            eval_freq,
            n_eval_episodes,
            log_path,
            reset_num_timesteps,
            tb_log_name,
        )

        if self.policy.is_frozen_features_extractor:
            # In SB3 the base ``_setup_learn`` resets the env (replacing ``_last_obs``)
            # when ``reset_num_timesteps`` is True. The cached embedding must follow it,
            # otherwise stale embeddings get paired with the new observation.
            if reset_num_timesteps or self._last_embedding is None:
                self._last_embedding = self._embed_observation(self._last_obs)

        return total_timesteps, callback

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "run",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> "SACGNNBase":
        """
        Alternate between collecting rollouts and calling :meth:`train` until
        ``total_timesteps`` environment steps have been taken.

        Training only starts once ``num_timesteps`` exceeds ``learning_starts``.
        A negative ``gradient_steps`` means "as many gradient steps as environment
        steps collected in the last rollout".

        :return: ``self``.
        """
        total_timesteps, callback = self._setup_learn(
            total_timesteps,
            eval_env,
            callback,
            eval_freq,
            n_eval_episodes,
            eval_log_path,
            reset_num_timesteps,
            tb_log_name,
        )

        callback.on_training_start(locals(), globals())

        while self.num_timesteps < total_timesteps:
            rollout = self.collect_rollouts(
                self.env,
                train_freq=self.train_freq,
                action_noise=self.action_noise,
                callback=callback,
                learning_starts=self.learning_starts,
                replay_buffer=self.replay_buffer,
                log_interval=log_interval,
            )

            if rollout.continue_training is False:
                break

            if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
                # If no `gradient_steps` is specified,
                # do as many gradients steps as steps performed during the rollout
                gradient_steps = self.gradient_steps if self.gradient_steps >= 0 else rollout.episode_timesteps
                # Special case when the user passes `gradient_steps=0`
                if gradient_steps > 0:
                    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)

        callback.on_training_end()

        return self

    def _create_aliases(self) -> None:
        """Expose the policy's actor/critic networks as attributes of the algorithm."""
        self.actor = self.policy.actor
        self.critic = self.policy.critic
        self.critic_target = self.policy.critic_target

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """Perform ``gradient_steps`` updates; must be implemented by subclasses."""
        raise NotImplementedError()

    def predict(
            self,
            observation: np.ndarray,
            state: Optional[Tuple[np.ndarray, ...]] = None,
            episode_start: Optional[np.ndarray] = None,
            deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get the policy action for an observation.

        The observation is converted to a float32 tensor on ``self.device``. If the
        policy has a features extractor, the observation is embedded first. The
        result is delegated to ``policy.actor.predict``.

        :return: Whatever ``policy.actor.predict`` returns (action and next state).
        """
        embedding = torch.as_tensor(observation, dtype=torch.float32, device=self.device)
        if self.policy.features_extractor is not None:
            # Inference only: avoid building an autograd graph for the embedding.
            # NOTE(review): unlike _store_transition, no prev_slots are passed here — confirm intended.
            with torch.no_grad():
                embedding = self.policy.extract_features(obs=embedding)

        return self.policy.actor.predict(embedding, state, episode_start, deterministic)

    def _excluded_save_params(self) -> List[str]:
        """Exclude the actor/critic aliases; they are saved as part of ``policy``."""
        return super(SACGNNBase, self)._excluded_save_params() + ["actor", "critic", "critic_target"]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """
        List the torch state dicts and raw tensors to save.

        Either the learnable ``log_ent_coef`` with its optimizer, or the fixed
        ``ent_coef_tensor``.
        """
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
        if self.ent_coef_optimizer is not None:
            saved_pytorch_variables = ["log_ent_coef"]
            state_dicts.append("ent_coef_optimizer")
        else:
            saved_pytorch_variables = ["ent_coef_tensor"]
        return state_dicts, saved_pytorch_variables

    def _sample_action(
        self,
        learning_starts: int,
        action_noise: Optional[ActionNoise] = None,
        n_envs: int = 1,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Sample an action according to the exploration policy.

        This is done in one of three ways:

        - sampling the probability distribution of the policy;
        - sampling a random action (uniform over the action space) during warm-up;
        - adding noise to the deterministic output.

        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param action_noise: Action noise that will be used for exploration
            Required for deterministic policy (e.g. TD3). This can also be used
            in addition to the stochastic policy for SAC.
        :param n_envs: Number of parallel environments (one action per env).
        :return: action to take in the environment
            and scaled action that will be stored in the replay buffer.
            The two differs when the action space is not normalized (bounds are not [-1, 1]).
        """
        # Select action randomly or according to policy
        if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
            # Warmup phase
            unscaled_action = np.array([self.action_space.sample() for _ in range(n_envs)])
        else:
            # Note: when using continuous actions,
            # we assume that the policy uses tanh to scale the action
            # We use non-deterministic action in the case of SAC, for TD3, it does not matter
            unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

        # Rescale the action from [low, high] to [-1, 1]
        if isinstance(self.action_space, gym.spaces.Box):
            scaled_action = self.policy.scale_action(unscaled_action)

            # Add noise to the action (improve exploration)
            if action_noise is not None:
                scaled_action = np.clip(scaled_action + action_noise(), -1, 1)

            # We store the scaled action in the buffer
            buffer_action = scaled_action
            action = self.policy.unscale_action(scaled_action)
        else:
            # Discrete case, no need to normalize or clip
            buffer_action = unscaled_action
            action = buffer_action
        return action, buffer_action

    def _dump_logs(self) -> None:
        """
        Record rollout/time statistics and dump them to the logger.

        When ``rollout_logging_full_buffer`` is True, episode reward/length and the
        success rate are only recorded once their respective buffers are full.
        Nothing is recorded for a buffer that is empty.
        """
        time_elapsed = time.time() - self.start_time
        fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8))
        self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")

        ep_info_ready = (
            not self.rollout_logging_full_buffer
            or len(self.ep_info_buffer) == self.ep_info_buffer.maxlen
        )
        # Check emptiness before peeking at the first entry (e.g. no Monitor wrapper).
        if ep_info_ready and len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
            self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
            self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
        self.logger.record("time/fps", fps)
        self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
        self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
        if self.use_sde:
            self.logger.record("train/std", (self.actor.get_std()).mean().item())

        success_ready = (
            not self.rollout_logging_full_buffer
            or len(self.ep_success_buffer) == self.ep_success_buffer.maxlen
        )
        # Skip when empty instead of logging a NaN mean.
        if success_ready and len(self.ep_success_buffer) > 0:
            self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))
        # Pass the number of timesteps for tensorboard
        self.logger.dump(step=self.num_timesteps)

    def _on_step(self) -> None:
        """
        Method called after each step in the environment.
        It is meant to trigger DQN target network update
        but can be used for other purposes
        """
        pass

    def _store_transition(
        self,
        replay_buffer: ReplayBuffer,
        buffer_action: np.ndarray,
        new_obs: Union[np.ndarray, Dict[str, np.ndarray]],
        reward: np.ndarray,
        dones: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        """
        Store a transition in the replay buffer and advance ``_last_obs``.

        Terminal observations are also handled here, because the VecEnv resets
        automatically. With a frozen features extractor, embeddings are stored
        instead of raw observations, and ``_last_embedding`` is advanced.

        :param replay_buffer: Replay buffer object where to store the transition.
        :param buffer_action: normalized action
        :param new_obs: next observation in the current episode
            or first observation of the episode (when dones is True)
        :param reward: reward for the current transition
        :param dones: Termination signal
        :param infos: List of additional information about the transition.
            It may contain the terminal observations and information about timeout.
        :raises NotImplementedError: if a VecNormalize wrapper is used, or if dict
            observations reach an episode end with a terminal observation.
        """
        # Raise rather than ``assert``: asserts are stripped under ``python -O``,
        # which would fall through into unsupported code paths.
        if self._vec_normalize_env is not None:
            raise NotImplementedError('Do not support VecNormalize wrapper')

        # Without normalization, the raw observation is also the "original" one.
        self._last_original_obs = self._last_obs

        # Deep copy so patching terminal observations below does not modify ``new_obs``,
        # which becomes ``self._last_obs`` (the first observation of the next episode).
        next_obs = deepcopy(new_obs)
        frozen = self.policy.is_frozen_features_extractor
        if frozen:
            new_obs_embedding = self._embed_observation(new_obs, prev_slots=self._last_embedding)
            next_obs_embedding = new_obs_embedding.copy()

        # As the VecEnv resets automatically, new_obs is already the
        # first observation of the next episode
        for i, done in enumerate(dones):
            if not done or infos[i].get("terminal_observation") is None:
                continue
            if isinstance(next_obs, dict):
                raise NotImplementedError('Do not support dict observations')
            next_obs[i] = infos[i]["terminal_observation"]
            if frozen:
                # new_obs[i] starts a fresh episode: embed it without previous slots.
                new_obs_embedding[i:i + 1] = self._embed_observation(new_obs[i:i + 1])
                # The stored next state is the terminal observation, embedded with the
                # previous step's embedding as ``prev_slots``.
                next_obs_embedding[i:i + 1] = self._embed_observation(
                    next_obs[i:i + 1],
                    prev_slots=self._last_embedding[i:i + 1],
                )

        if frozen:
            replay_buffer.add(
                self._last_embedding,
                next_obs_embedding,
                buffer_action,
                reward,
                dones,
                infos,
            )
            self._last_embedding = new_obs_embedding
        else:
            replay_buffer.add(
                self._last_original_obs,
                next_obs,
                buffer_action,
                reward,
                dones,
                infos,
            )

        self._last_obs = new_obs

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        train_freq: TrainFreq,
        replay_buffer: ReplayBuffer,
        action_noise: Optional[ActionNoise] = None,
        learning_starts: int = 0,
        log_interval: Optional[int] = None,
    ) -> RolloutReturn:
        """
        Collect experiences and store them into a ``ReplayBuffer``.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param train_freq: How much experience to collect
            by doing rollouts of current policy.
            Either ``TrainFreq(<n>, TrainFrequencyUnit.STEP)``
            or ``TrainFreq(<n>, TrainFrequencyUnit.EPISODE)``
            with ``<n>`` being an integer greater than 0.
        :param replay_buffer: Buffer the transitions are written to.
        :param action_noise: Action noise that will be used for exploration
            Required for deterministic policy (e.g. TD3). This can also be used
            in addition to the stochastic policy for SAC.
        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param log_interval: Log data every ``log_interval`` episodes
        :return: Number of env steps (summed over envs) and episodes collected, and
            whether training should continue (False if a callback asked to stop).
        """
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        num_collected_steps, num_collected_episodes = 0, 0

        assert isinstance(env, VecEnv), "You must pass a VecEnv"
        assert train_freq.frequency > 0, "Should at least collect one step or episode."

        if env.num_envs > 1:
            assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training."

        # Vectorize action noise if needed
        if action_noise is not None and env.num_envs > 1 and not isinstance(action_noise, VectorizedActionNoise):
            action_noise = VectorizedActionNoise(action_noise, env.num_envs)

        if self.use_sde:
            self.actor.reset_noise(env.num_envs)

        callback.on_rollout_start()
        continue_training = True

        while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
            if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.actor.reset_noise(env.num_envs)

            # Select action randomly or according to policy
            actions, buffer_actions = self._sample_action(learning_starts, action_noise, env.num_envs)

            # Rescale and perform action
            new_obs, rewards, dones, infos = env.step(actions)

            self.num_timesteps += env.num_envs
            num_collected_steps += 1

            # Give access to local variables
            callback.update_locals(locals())
            # Only stop training if return value is False, not when it is None.
            if callback.on_step() is False:
                return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False)

            # Retrieve reward and episode length if using Monitor wrapper
            self._update_info_buffer(infos, dones)

            # Store data in replay buffer (normalized action and unnormalized observation)
            self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos)

            self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

            # For DQN, check if the target network should be updated
            # and update the exploration schedule
            # For SAC/TD3, the update is dones as the same time as the gradient update
            # see https://github.com/hill-a/stable-baselines/issues/900
            self._on_step()

            for idx, done in enumerate(dones):
                if done:
                    # Update stats
                    num_collected_episodes += 1
                    self._episode_num += 1

                    if action_noise is not None:
                        kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
                        action_noise.reset(**kwargs)

                    # Log training infos
                    if log_interval is not None and self._episode_num % log_interval == 0:
                        self._dump_logs()

        callback.on_rollout_end()

        return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)


class SACWMGNN(SACGNNBase):
    """Soft Actor-Critic variant whose critic also carries world-model heads.

    Besides the standard SAC critic / actor / entropy-coefficient updates, ``train``
    fits every model in ``critic.transition_models`` to predict the next observation
    and every model in ``critic.reward_models`` to predict the reward, weighting the
    two losses by ``transition_loss_coef`` and ``reward_loss_coef``. When
    ``policy.use_wm_optimizer`` is set, the Q loss and the world-model loss are
    stepped by separate critic optimizers (``optimizer_value`` / ``optimizer_wm``);
    otherwise their sum is stepped by ``critic.optimizer``.
    """

    def __init__(
            self,
            policy,
            env: Union[GymEnv, str],
            learning_rate: Union[float, Schedule] = 3e-4,
            buffer_size: int = 1_000_000,  # 1e6
            learning_starts: int = 100,
            batch_size: int = 256,
            tau: float = 0.005,
            gamma: float = 0.99,
            train_freq: Union[int, Tuple[int, str]] = 1,
            gradient_steps: int = 1,
            action_noise: Optional[ActionNoise] = None,
            replay_buffer_class: Optional[ReplayBuffer] = None,
            replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
            replay_buffer: Optional[ReplayBuffer] = None,
            optimize_memory_usage: bool = False,
            ent_coef: Union[str, float] = "auto",
            target_update_interval: int = 1,
            target_entropy: Union[str, float] = "auto",
            use_sde: bool = False,
            sde_sample_freq: int = -1,
            use_sde_at_warmup: bool = False,
            tensorboard_log: Optional[str] = None,
            create_eval_env: bool = False,
            policy_kwargs: Optional[Dict[str, Any]] = None,
            verbose: int = 0,
            seed: Optional[int] = None,
            device: Union[torch.device, str] = "auto",
            _init_setup_model: bool = True,
            transition_loss_coef: float = 1.0,
            reward_loss_coef: float = 1.0,
            rollout_logging_full_buffer: bool = False,
    ):
        # All arguments up to reward_loss_coef are forwarded positionally, in the
        # order SACGNNBase.__init__ expects them.
        super().__init__(
            policy, env, learning_rate, buffer_size, learning_starts, batch_size,
            tau, gamma, train_freq, gradient_steps, action_noise,
            replay_buffer_class, replay_buffer_kwargs, replay_buffer,
            optimize_memory_usage, ent_coef, target_update_interval, target_entropy,
            use_sde, sde_sample_freq, use_sde_at_warmup,
            tensorboard_log, create_eval_env, policy_kwargs, verbose, seed, device,
            _init_setup_model, transition_loss_coef, reward_loss_coef,
            rollout_logging_full_buffer=rollout_logging_full_buffer,
        )

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """Perform ``gradient_steps`` SAC + world-model updates on replayed batches.

        Per step: sample ``batch_size`` transitions, optionally tune the entropy
        coefficient, regress the critics onto the soft TD target, fit the
        transition and reward heads, update the actor, and Polyak-average the
        target critic every ``target_update_interval`` steps. Mean statistics are
        recorded under ``train/*``.
        """
        self.policy.set_training_mode(True)

        split_optimizers = self.policy.use_wm_optimizer
        if split_optimizers:
            critic_optimizers = [self.critic.optimizer_wm, self.critic.optimizer_value]
        else:
            critic_optimizers = [self.critic.optimizer]
        scheduled = [self.actor.optimizer] + critic_optimizers
        if self.ent_coef_optimizer is not None:
            scheduled.append(self.ent_coef_optimizer)
        # Push the current value of the learning-rate schedule into every optimizer.
        self._update_learning_rate(scheduled)

        ent_coefs, ent_coef_losses = [], []
        actor_losses, q_losses, critic_losses = [], [], []
        transition_losses, reward_losses, wm_losses = [], [], []

        for step in range(gradient_steps):
            batch = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # gSDE: resample exploration noise because log_std may have moved since the last step.
            if self.use_sde:
                self.actor.reset_noise()

            obs = torch.as_tensor(batch.observations, dtype=torch.float32, device=self.device)
            pi_actions, pi_log_prob = self.actor.action_log_prob(obs)
            pi_log_prob = pi_log_prob.reshape(-1, 1)

            if self.ent_coef_optimizer is None:
                ent_coef = self.ent_coef_tensor
            else:
                # The coefficient fed to the other losses is detached so that only
                # ent_coef_loss moves log_ent_coef
                # (see https://github.com/rail-berkeley/softlearning/issues/60).
                ent_coef = torch.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(self.log_ent_coef * (pi_log_prob + self.target_entropy).detach()).mean()
                ent_coef_losses.append(ent_coef_loss.item())
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()
            ent_coefs.append(ent_coef.item())

            with torch.no_grad():
                next_obs = torch.as_tensor(batch.next_observations, dtype=torch.float32, device=self.device)
                next_actions, next_log_prob = self.actor.action_log_prob(next_obs)
                target_qs = torch.stack(self.critic_target(next_obs, next_actions, self.actor), dim=1)
                # Pessimistic minimum over the target critics, minus the entropy bonus.
                soft_next_value = (target_qs.min(dim=1, keepdim=True).values
                                   - ent_coef * next_log_prob.reshape(-1, 1))
                td_target = batch.rewards + (1 - batch.dones) * self.gamma * soft_next_value

            # Replayed actions are used as stored when the actor squashes its output,
            # otherwise dim 1 is squeezed away.
            buffer_actions = batch.actions if self.actor.squash_output else batch.actions.squeeze(dim=1)
            flat_target = td_target.squeeze(dim=1)
            q_loss = 0.5 * sum(F.mse_loss(q, flat_target) for q in self.critic(obs, buffer_actions, self.actor))
            q_losses.append(q_loss.item())

            if split_optimizers:
                self.critic.optimizer_value.zero_grad()
                q_loss.backward()
                self.critic.optimizer_value.step()

            transition_loss = 0.5 * sum(
                F.mse_loss(model(obs, buffer_actions), next_obs) for model in self.critic.transition_models)
            transition_losses.append(transition_loss.item())

            reward_target = batch.rewards.squeeze(dim=1)
            reward_loss = 0.5 * sum(
                F.mse_loss(model(obs, buffer_actions), reward_target) for model in self.critic.reward_models)
            reward_losses.append(reward_loss.item())

            wm_loss = self.transition_loss_coef * transition_loss + self.reward_loss_coef * reward_loss
            wm_losses.append(wm_loss.item())

            if split_optimizers:
                self.critic.optimizer_wm.zero_grad()
                wm_loss.backward()
                self.critic.optimizer_wm.step()

            # Always logged; only back-propagated when a single critic optimizer is used.
            critic_loss = q_loss + wm_loss
            critic_losses.append(critic_loss.item())

            if not split_optimizers:
                self.critic.optimizer.zero_grad()
                critic_loss.backward()
                self.critic.optimizer.step()

            q_pi = torch.stack(self.critic(obs, pi_actions, self.actor), dim=1)
            actor_loss = (ent_coef * pi_log_prob - q_pi.min(dim=1, keepdim=True).values).mean()
            actor_losses.append(actor_loss.item())

            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            if step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)

        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        for key, history in (
                ("train/ent_coef", ent_coefs),
                ("train/actor_loss", actor_losses),
                ("train/q_loss", q_losses),
                ("train/critic_loss", critic_losses),
                ("train/transition_loss", transition_losses),
                ("train/reward_loss", reward_losses),
                ("train/wm_loss", wm_losses),
        ):
            self.logger.record(key, np.mean(history))
        if ent_coef_losses:
            self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))

    def learn(
            self,
            total_timesteps: int,
            callback: MaybeCallback = None,
            log_interval: int = 4,
            eval_env: Optional[GymEnv] = None,
            eval_freq: int = -1,
            n_eval_episodes: int = 5,
            tb_log_name: str = "SACWMGNN",
            eval_log_path: Optional[str] = None,
            reset_num_timesteps: bool = True,
    ) -> SACGNNBase:
        """Delegate to ``SACGNNBase.learn``; this override only sets the default ``tb_log_name``."""
        forwarded = dict(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )
        return super().learn(**forwarded)

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """Name the state-dict objects and raw tensors to persist on save.

        The critic optimizers listed depend on ``policy.use_wm_optimizer``. The entropy
        coefficient is saved as ``log_ent_coef`` (plus its optimizer) when it is learned,
        or as ``ent_coef_tensor`` when it is fixed.
        """
        if self.policy.use_wm_optimizer:
            critic_optimizers = ["critic.optimizer_wm", "critic.optimizer_value"]
        else:
            critic_optimizers = ["critic.optimizer"]
        state_dicts = ["policy", "actor.optimizer", *critic_optimizers]

        if self.ent_coef_optimizer is None:
            return state_dicts, ["ent_coef_tensor"]
        return state_dicts + ["ent_coef_optimizer"], ["log_ent_coef"]


class SACGNN(SACGNNBase):
    """Plain Soft Actor-Critic on top of ``SACGNNBase`` (no world-model heads).

    ``train`` optimizes only the critic (soft TD) loss, the actor loss and, when
    the entropy coefficient is learned, the entropy-coefficient loss.
    """

    def __init__(
            self,
            policy,
            env: Union[GymEnv, str],
            learning_rate: Union[float, Schedule] = 3e-4,
            buffer_size: int = 1_000_000,  # 1e6
            learning_starts: int = 100,
            batch_size: int = 256,
            tau: float = 0.005,
            gamma: float = 0.99,
            train_freq: Union[int, Tuple[int, str]] = 1,
            gradient_steps: int = 1,
            action_noise: Optional[ActionNoise] = None,
            replay_buffer_class: Optional[ReplayBuffer] = None,
            replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
            replay_buffer: Optional[ReplayBuffer] = None,
            optimize_memory_usage: bool = False,
            ent_coef: Union[str, float] = "auto",
            target_update_interval: int = 1,
            target_entropy: Union[str, float] = "auto",
            use_sde: bool = False,
            sde_sample_freq: int = -1,
            use_sde_at_warmup: bool = False,
            tensorboard_log: Optional[str] = None,
            create_eval_env: bool = False,
            policy_kwargs: Optional[Dict[str, Any]] = None,
            verbose: int = 0,
            seed: Optional[int] = None,
            device: Union[torch.device, str] = "auto",
            _init_setup_model: bool = True,
            rollout_logging_full_buffer: bool = False,
    ):
        # Positional forwarding in SACGNNBase.__init__ order; the world-model
        # coefficients are left at the base-class defaults.
        super().__init__(
            policy, env, learning_rate, buffer_size, learning_starts, batch_size,
            tau, gamma, train_freq, gradient_steps, action_noise,
            replay_buffer_class, replay_buffer_kwargs, replay_buffer,
            optimize_memory_usage, ent_coef, target_update_interval, target_entropy,
            use_sde, sde_sample_freq, use_sde_at_warmup,
            tensorboard_log, create_eval_env, policy_kwargs, verbose, seed, device,
            _init_setup_model,
            rollout_logging_full_buffer=rollout_logging_full_buffer,
        )

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """Perform ``gradient_steps`` SAC updates on batches of ``batch_size`` transitions.

        Each step optionally tunes the entropy coefficient, fits the critics to the
        soft TD target, updates the actor, and Polyak-averages the target critic
        every ``target_update_interval`` steps. Means are recorded under ``train/*``.
        """
        self.policy.set_training_mode(True)

        scheduled = [self.actor.optimizer, self.critic.optimizer]
        if self.ent_coef_optimizer is not None:
            scheduled.append(self.ent_coef_optimizer)
        # Push the current value of the learning-rate schedule into every optimizer.
        self._update_learning_rate(scheduled)

        ent_coefs, ent_coef_losses = [], []
        actor_losses, q_losses = [], []

        for step in range(gradient_steps):
            batch = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # gSDE: resample exploration noise because log_std may have moved since the last step.
            if self.use_sde:
                self.actor.reset_noise()

            obs = torch.as_tensor(batch.observations, dtype=torch.float32, device=self.device)
            pi_actions, pi_log_prob = self.actor.action_log_prob(obs)
            pi_log_prob = pi_log_prob.reshape(-1, 1)

            if self.ent_coef_optimizer is None:
                ent_coef = self.ent_coef_tensor
            else:
                # The coefficient fed to the other losses is detached so that only
                # ent_coef_loss moves log_ent_coef
                # (see https://github.com/rail-berkeley/softlearning/issues/60).
                ent_coef = torch.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(self.log_ent_coef * (pi_log_prob + self.target_entropy).detach()).mean()
                ent_coef_losses.append(ent_coef_loss.item())
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()
            ent_coefs.append(ent_coef.item())

            with torch.no_grad():
                next_obs = torch.as_tensor(batch.next_observations, dtype=torch.float32, device=self.device)
                next_actions, next_log_prob = self.actor.action_log_prob(next_obs)
                target_qs = torch.stack(self.critic_target(next_obs, next_actions, self.actor), dim=1)
                # Pessimistic minimum over the target critics, minus the entropy bonus.
                soft_next_value = (target_qs.min(dim=1, keepdim=True).values
                                   - ent_coef * next_log_prob.reshape(-1, 1))
                td_target = batch.rewards + (1 - batch.dones) * self.gamma * soft_next_value

            flat_target = td_target.squeeze(dim=1)
            q_estimates = self.critic(obs.detach(), batch.actions, self.actor)
            q_loss = 0.5 * sum(F.mse_loss(q, flat_target) for q in q_estimates)
            q_losses.append(q_loss.item())

            self.critic.optimizer.zero_grad()
            q_loss.backward()
            self.critic.optimizer.step()

            q_pi = torch.stack(self.critic(obs, pi_actions, self.actor), dim=1)
            actor_loss = (ent_coef * pi_log_prob - q_pi.min(dim=1, keepdim=True).values).mean()
            actor_losses.append(actor_loss.item())

            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            if step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)

        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        for key, history in (
                ("train/ent_coef", ent_coefs),
                ("train/actor_loss", actor_losses),
                ("train/q_loss", q_losses),
        ):
            self.logger.record(key, np.mean(history))
        if ent_coef_losses:
            self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))

    def learn(
            self,
            total_timesteps: int,
            callback: MaybeCallback = None,
            log_interval: int = 4,
            eval_env: Optional[GymEnv] = None,
            eval_freq: int = -1,
            n_eval_episodes: int = 5,
            tb_log_name: str = "SACGNN",
            eval_log_path: Optional[str] = None,
            reset_num_timesteps: bool = True,
    ) -> SACGNNBase:
        """Delegate to ``SACGNNBase.learn``; this override only sets the default ``tb_log_name``."""
        forwarded = dict(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )
        return super().learn(**forwarded)


class DiscreteSACWMGNN(SACGNNBase):
    """Discrete-action Soft Actor-Critic with world-model heads on the critic.

    The actor returns ``(log_prob_action, prob, log_prob)`` and the critics return
    one Q-value per action. The Q-value of the replayed action is selected with
    ``gather``. In addition, ``critic.transition_models`` and ``critic.reward_models``
    are fit to the next (extracted) features and to the reward.

    ``contrastive_learning_kwargs``, if given, must contain ``'hinge'`` and ``'coef'``.
    It enables a hinge loss ``max(0, hinge - d)``, where ``d`` is the mean squared
    distance between each sample's features and the next-state features of a randomly
    permuted sample. The term is added to the world-model loss with weight ``coef``.
    """

    def __init__(
            self,
            policy,
            env: Union[GymEnv, str],
            learning_rate: Union[float, Schedule] = 3e-4,
            buffer_size: int = 1_000_000,  # 1e6
            learning_starts: int = 100,
            batch_size: int = 256,
            tau: float = 0.005,
            gamma: float = 0.99,
            train_freq: Union[int, Tuple[int, str]] = 1,
            gradient_steps: int = 1,
            action_noise: Optional[ActionNoise] = None,
            replay_buffer_class: Optional[ReplayBuffer] = None,
            replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
            replay_buffer: Optional[Dict[str, Any]] = None,
            optimize_memory_usage: bool = False,
            ent_coef: Union[str, float] = "auto",
            target_update_interval: int = 1,
            target_entropy: Union[str, float] = "auto",
            use_sde: bool = False,
            sde_sample_freq: int = -1,
            use_sde_at_warmup: bool = False,
            tensorboard_log: Optional[str] = None,
            create_eval_env: bool = False,
            policy_kwargs: Optional[Dict[str, Any]] = None,
            verbose: int = 0,
            seed: Optional[int] = None,
            device: Union[torch.device, str] = "auto",
            _init_setup_model: bool = True,
            transition_loss_coef: float = 1.0,
            reward_loss_coef: float = 1.0,
            rollout_logging_full_buffer: bool = False,
            contrastive_learning_kwargs=None,
    ):
        # Validate before the (potentially expensive) model setup in the base class.
        # An explicit raise is used because ``assert`` is stripped under ``python -O``.
        if contrastive_learning_kwargs is not None:
            missing = [key for key in ('hinge', 'coef') if key not in contrastive_learning_kwargs]
            if missing:
                raise ValueError(f"contrastive_learning_kwargs is missing required key(s): {missing}")

        super(DiscreteSACWMGNN, self).__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise,
            replay_buffer_class,
            replay_buffer_kwargs,
            replay_buffer,
            optimize_memory_usage,
            ent_coef,
            target_update_interval,
            target_entropy,
            use_sde,
            sde_sample_freq,
            use_sde_at_warmup,
            tensorboard_log,
            create_eval_env,
            policy_kwargs,
            verbose,
            seed,
            device,
            _init_setup_model,
            transition_loss_coef,
            reward_loss_coef,
            rollout_logging_full_buffer=rollout_logging_full_buffer,
        )

        self.contrastive_learning_kwargs = contrastive_learning_kwargs

    @staticmethod
    def _min_over_critics(values):
        """Return the element-wise minimum over a non-empty sequence of per-critic tensors.

        For exactly two tensors this is ``torch.minimum(a, b)``, including its gradient
        semantics. It also works for one critic or for more than two.
        """
        from functools import reduce
        return reduce(torch.minimum, values)

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """Perform ``gradient_steps`` discrete-SAC + world-model updates.

        If ``policy.is_frozen_features_extractor`` is False, observations are first
        passed through ``policy.extract_features``. The critic and actor then consume
        *detached* features. The transition / reward / contrastive losses use the
        non-detached features, so their gradients can reach the extractor.

        Besides the losses, the mean and std of (policy-averaged Q - state value) are
        logged for each critic and for the minimum over critics.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizers learning rate
        optimizers = [self.actor.optimizer]
        if self.policy.use_wm_optimizer:
            optimizers.append(self.critic.optimizer_wm)
            optimizers.append(self.critic.optimizer_value)
        else:
            optimizers.append(self.critic.optimizer)

        if self.ent_coef_optimizer is not None:
            optimizers += [self.ent_coef_optimizer]

        # Update learning rate according to lr schedule
        self._update_learning_rate(optimizers)

        ent_coef_losses, ent_coefs = [], []
        actor_losses, q_losses, transition_losses, reward_losses, wm_losses, critic_losses = [], [], [], [], [], []
        contrastive_losses = []
        # One list per critic, plus a trailing list for the minimum over critics.
        batch_averaged_q_values = [[] for _ in range(self.critic.n_critics + 1)]
        batch_v_values = [[] for _ in range(self.critic.n_critics + 1)]

        for gradient_step in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # We need to sample because `log_std` may have changed between two gradient steps
            if self.use_sde:
                self.actor.reset_noise()

            features = torch.as_tensor(replay_data.observations, dtype=torch.float32, device=self.device)
            if not self.policy.is_frozen_features_extractor:
                features = self.policy.extract_features(features)

            # Action by the current actor for the sampled state
            log_prob_action, prob, log_prob = self.actor(features.detach())
            log_prob_action = log_prob_action.reshape(-1, 1)

            ent_coef_loss = None
            if self.ent_coef_optimizer is not None:
                # Important: detach the variable from the graph
                # so we don't change it with other losses
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = torch.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(self.log_ent_coef * (log_prob_action + self.target_entropy).detach()).mean()
                ent_coef_losses.append(ent_coef_loss.item())
            else:
                ent_coef = self.ent_coef_tensor

            ent_coefs.append(ent_coef.item())

            # Optimize entropy coefficient, also called
            # entropy temperature or alpha in the paper
            if ent_coef_loss is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            # Computed outside no_grad: the world-model losses below use these as targets
            # and, with a trainable extractor, differentiate through them as well.
            next_features = torch.as_tensor(replay_data.next_observations, dtype=torch.float32, device=self.device)
            if not self.policy.is_frozen_features_extractor:
                next_features = self.policy.extract_features(next_features)

            with torch.no_grad():
                # Select action according to policy
                _, next_prob, next_log_prob = self.actor(next_features)
                # Compute the next Q values: min over all critics targets
                next_q_values = self._min_over_critics(self.critic_target(next_features, self.actor))
                # Soft state value: expectation over the policy of (Q - alpha * log pi)
                next_q_values = next_prob * (next_q_values - ent_coef * next_log_prob)
                next_q_values = next_q_values.sum(dim=1, keepdim=True)
                # td error + entropy term
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

            # Get current Q-values estimates for each critic network
            # using action from the replay buffer
            current_q_values = self.critic(features.detach(), self.actor)

            # Compute critic loss on the Q-value of the replayed action
            q_loss = 0.5 * sum(
                [F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values])
            q_losses.append(q_loss.item())

            if self.policy.use_wm_optimizer:
                self.critic.optimizer_value.zero_grad()
                q_loss.backward()
                self.critic.optimizer_value.step()

            next_features_predictions = \
                [transition_model(features, replay_data.actions.squeeze(dim=1)) for transition_model in
                 self.critic.transition_models]
            transition_loss = 0.5 * sum(
                [F.mse_loss(next_features_prediction, next_features) for next_features_prediction in
                 next_features_predictions])
            transition_losses.append(transition_loss.item())

            rewards_predictions = \
                [reward_model(features, replay_data.actions.squeeze(dim=1)) for reward_model in self.critic.reward_models]
            reward_loss = 0.5 * sum(
                [F.mse_loss(rewards_prediction, replay_data.rewards.squeeze(dim=1)) for rewards_prediction in
                 rewards_predictions])
            reward_losses.append(reward_loss.item())

            wm_loss = self.transition_loss_coef * transition_loss + self.reward_loss_coef * reward_loss
            if self.contrastive_learning_kwargs is not None:
                # Local name: do not shadow the `batch_size` argument used by the next sample().
                n_samples = features.size(0)
                # Negatives: next-state features of randomly permuted samples.
                distance_to_negative_samples = (next_features[torch.randperm(n_samples)] - features) ** 2
                # Averaging over dims 1 and 2 assumes 3-D features — NOTE(review): presumably
                # (batch, nodes, feature); confirm against the features extractor.
                contrastive_loss = F.relu(
                    self.contrastive_learning_kwargs['hinge'] - distance_to_negative_samples.mean(dim=(1, 2))
                ).mean()
                contrastive_losses.append(contrastive_loss.item())
                wm_loss = wm_loss + self.contrastive_learning_kwargs['coef'] * contrastive_loss
            wm_losses.append(wm_loss.item())

            if self.policy.use_wm_optimizer:
                self.critic.optimizer_wm.zero_grad()
                wm_loss.backward()
                self.critic.optimizer_wm.step()

            # Always logged; only back-propagated when a single critic optimizer is used.
            critic_loss = q_loss + wm_loss
            critic_losses.append(critic_loss.item())

            if not self.policy.use_wm_optimizer:
                self.critic.optimizer.zero_grad()
                critic_loss.backward()
                self.critic.optimizer.step()

            # Compute actor loss: expectation over the policy of (alpha * log pi - min Q)
            action_values = self.critic(features.detach(), self.actor)
            min_qf_pi = self._min_over_critics(action_values)
            actor_loss = torch.sum(prob * (ent_coef * log_prob - min_qf_pi), dim=1).mean()
            actor_losses.append(actor_loss.item())

            # Optimize the actor
            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            # Update target networks
            if gradient_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)

            # Diagnostics only: compare policy-averaged Q-values with the critic's
            # state-value heads. No gradients are needed, so skip building a graph
            # (with a trainable extractor, `features` would otherwise drag one along).
            with torch.no_grad():
                for q_values, q_value in zip(batch_averaged_q_values, action_values):
                    q_values.append(torch.sum(prob * q_value, dim=1).detach())
                batch_averaged_q_values[-1].append(torch.sum(prob * min_qf_pi, dim=1).detach())

                state_values = [value_model(features) for value_model in self.critic.value_models]
                for v_values, value in zip(batch_v_values, state_values):
                    v_values.append(value.detach())
                batch_v_values[-1].append(self._min_over_critics(state_values).detach())

        self._n_updates += gradient_steps
        # NOTE(review): the subtraction below assumes value_model outputs have the same
        # shape as the (batch,) averaged Q-values; a (batch, 1) output would broadcast.
        batch_v_values_diff = []
        for v_values, averaged_q_values in zip(batch_v_values, batch_averaged_q_values):
            v_values = torch.cat(v_values, dim=0)
            averaged_q_values = torch.cat(averaged_q_values, dim=0)
            batch_v_values_diff.append(averaged_q_values - v_values)

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/ent_coef", np.mean(ent_coefs))
        self.logger.record("train/actor_loss", np.mean(actor_losses))
        self.logger.record("train/q_loss", np.mean(q_losses))
        self.logger.record("train/critic_loss", np.mean(critic_losses))
        self.logger.record("train/transition_loss", np.mean(transition_losses))
        self.logger.record("train/reward_loss", np.mean(reward_losses))
        self.logger.record("train/wm_loss", np.mean(wm_losses))
        if len(contrastive_losses) > 0:
            self.logger.record("train/contrastive_loss", np.mean(contrastive_losses))

        for i, value_diff in enumerate(batch_v_values_diff[:-1]):
            self.logger.record(f"train/mean(averaged_q-v)_{i}", torch.mean(value_diff).item())
            self.logger.record(f"train/std(averaged_q-v)_{i}", torch.std(value_diff).item())

        self.logger.record("train/mean(averaged_q-v)_min", torch.mean(batch_v_values_diff[-1]).item())
        self.logger.record("train/std(averaged_q-v)_min", torch.std(batch_v_values_diff[-1]).item())

        if len(ent_coef_losses) > 0:
            self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))

    def learn(
            self,
            total_timesteps: int,
            callback: MaybeCallback = None,
            log_interval: int = 4,
            eval_env: Optional[GymEnv] = None,
            eval_freq: int = -1,
            n_eval_episodes: int = 5,
            tb_log_name: str = "SACWMGNN",
            eval_log_path: Optional[str] = None,
            reset_num_timesteps: bool = True,
    ) -> SACGNNBase:
        """Delegate to ``SACGNNBase.learn`` with this class's default ``tb_log_name``."""
        # NOTE(review): the default tb_log_name is "SACWMGNN", identical to SACWMGNN's;
        # kept unchanged for compatibility, pass tb_log_name explicitly to separate runs.
        return super(DiscreteSACWMGNN, self).learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """Name the state-dict objects and raw tensors to persist on save.

        The critic optimizers listed depend on ``policy.use_wm_optimizer``. The entropy
        coefficient is saved as ``log_ent_coef`` (plus its optimizer) when learned,
        or as ``ent_coef_tensor`` when fixed.
        """
        state_dicts = ["policy", "actor.optimizer"]
        if self.policy.use_wm_optimizer:
            state_dicts.append("critic.optimizer_wm")
            state_dicts.append("critic.optimizer_value")
        else:
            state_dicts.append("critic.optimizer")

        if self.ent_coef_optimizer is not None:
            saved_pytorch_variables = ["log_ent_coef"]
            state_dicts.append("ent_coef_optimizer")
        else:
            saved_pytorch_variables = ["ent_coef_tensor"]
        return state_dicts, saved_pytorch_variables

    def load_checkpoint(self, path):
        """Restore policy weights, entropy coefficient and optimizer states from ``path``.

        ``path`` is a directory of ``<name>.pth`` files named after the entries of
        ``_get_torch_save_params`` (``policy.pth``, ``actor.optimizer.pth``,
        ``pytorch_variables.pth``, ...). Only the files matching the current
        configuration (``policy.use_wm_optimizer``, learned vs. fixed entropy
        coefficient) are read. All files are loaded before anything is applied, so a
        missing file leaves the model untouched. Tensors are mapped onto
        ``self.device`` so checkpoints can move between CPU and GPU.
        """
        def _load(name):
            return torch.load(os.path.join(path, f'{name}.pth'), map_location=self.device)

        learn_ent_coef = self.ent_coef_optimizer is not None

        policy_weights = _load('policy')
        pytorch_variables = _load('pytorch_variables')
        actor_optimizer_weights = _load('actor.optimizer')
        if self.policy.use_wm_optimizer:
            critic_optimizer_states = {
                'optimizer_value': _load('critic.optimizer_value'),
                'optimizer_wm': _load('critic.optimizer_wm'),
            }
        else:
            critic_optimizer_states = {'optimizer': _load('critic.optimizer')}
        ent_coef_optimizer_weights = _load('ent_coef_optimizer') if learn_ent_coef else None

        self.policy.load_state_dict(policy_weights)
        with torch.no_grad():
            if learn_ent_coef:
                self.log_ent_coef[0] = pytorch_variables['log_ent_coef'][0]
            else:
                # Fixed coefficient, saved under "ent_coef_tensor" by _get_torch_save_params.
                self.ent_coef_tensor.copy_(pytorch_variables['ent_coef_tensor'])

        self.policy.actor.optimizer.load_state_dict(actor_optimizer_weights)
        for optimizer_name, state in critic_optimizer_states.items():
            getattr(self.policy.critic, optimizer_name).load_state_dict(state)
        if learn_ent_coef:
            self.ent_coef_optimizer.load_state_dict(ent_coef_optimizer_weights)


class DiscreteSAC(SACGNNBase):
    def __init__(
            self,
            policy,
            env: Union[GymEnv, str],
            learning_rate: Union[float, Schedule] = 3e-4,
            buffer_size: int = 1_000_000,  # 1e6
            learning_starts: int = 100,
            batch_size: int = 256,
            tau: float = 0.005,
            gamma: float = 0.99,
            train_freq: Union[int, Tuple[int, str]] = 1,
            gradient_steps: int = 1,
            action_noise: Optional[ActionNoise] = None,
            replay_buffer_class: Optional[ReplayBuffer] = None,
            replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
            replay_buffer: Optional[ReplayBuffer] = None,
            optimize_memory_usage: bool = False,
            ent_coef: Union[str, float] = "auto",
            target_update_interval: int = 1,
            target_entropy: Union[str, float] = "auto",
            use_sde: bool = False,
            sde_sample_freq: int = -1,
            use_sde_at_warmup: bool = False,
            tensorboard_log: Optional[str] = None,
            create_eval_env: bool = False,
            policy_kwargs: Optional[Dict[str, Any]] = None,
            verbose: int = 0,
            seed: Optional[int] = None,
            device: Union[torch.device, str] = "auto",
            _init_setup_model: bool = True,
            transition_loss_coef: float = 1.0,
            reward_loss_coef: float = 1.0,
            rollout_logging_full_buffer: bool = False,
    ):
        """Discrete-action SAC built on ``SACGNNBase``.

        Every argument is forwarded unchanged to ``SACGNNBase.__init__``.
        ``replay_buffer`` is an (optional) pre-built replay buffer, matching the
        base-class annotation.

        ``transition_loss_coef`` and ``reward_loss_coef`` are forwarded to the base
        class; ``DiscreteSAC.train`` itself does not read them.

        NOTE: all arguments except ``rollout_logging_full_buffer`` are passed
        positionally, so their order here must stay in sync with the base-class
        signature.
        """
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise,
            replay_buffer_class,
            replay_buffer_kwargs,
            replay_buffer,
            optimize_memory_usage,
            ent_coef,
            target_update_interval,
            target_entropy,
            use_sde,
            sde_sample_freq,
            use_sde_at_warmup,
            tensorboard_log,
            create_eval_env,
            policy_kwargs,
            verbose,
            seed,
            device,
            _init_setup_model,
            transition_loss_coef,
            reward_loss_coef,
            rollout_logging_full_buffer=rollout_logging_full_buffer,
        )

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """Run ``gradient_steps`` discrete-SAC updates on minibatches from the replay buffer.

        Each step, in order:
        1. updates the entropy coefficient (only when it is learned);
        2. regresses every critic head onto a soft TD target built from the target
           critics and the current policy's action distribution;
        3. updates the actor against the element-wise minimum over critic heads;
        4. Polyak-averages the target critics every ``target_update_interval`` steps.

        Mean losses and the entropy coefficient are written to ``self.logger``.

        :param gradient_steps: number of gradient updates to perform.
        :param batch_size: number of transitions per sampled minibatch.
        """
        import functools

        def _min_over_critics(q_values):
            # Element-wise minimum over all critic heads. Identical to
            # ``torch.minimum(q1, q2)`` for two critics, but unlike
            # ``torch.minimum(*q_values)`` it also works for n_critics != 2.
            return functools.reduce(torch.minimum, q_values)

        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)

        # Update optimizers learning rate according to the lr schedule
        optimizers = [self.actor.optimizer, self.critic.optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers.append(self.ent_coef_optimizer)
        self._update_learning_rate(optimizers)

        ent_coef_losses, ent_coefs = [], []
        actor_losses, q_losses = [], []

        for gradient_step in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # We need to sample because `log_std` may have changed between two gradient steps
            if self.use_sde:
                self.actor.reset_noise()

            # Log-prob of a sampled action, plus per-action probabilities and
            # log-probabilities, for the current observations.
            log_prob_action, prob, log_prob = self.actor(replay_data.observations)
            log_prob_action = log_prob_action.reshape(-1, 1)

            ent_coef_loss = None
            if self.ent_coef_optimizer is not None:
                # Important: detach the variable from the graph
                # so we don't change it with other losses
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = torch.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(self.log_ent_coef * (log_prob_action + self.target_entropy).detach()).mean()
                ent_coef_losses.append(ent_coef_loss.item())
            else:
                ent_coef = self.ent_coef_tensor

            ent_coefs.append(ent_coef.item())

            # Optimize entropy coefficient, also called
            # entropy temperature or alpha in the paper
            if ent_coef_loss is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            with torch.no_grad():
                # Action distribution of the current policy at the next observations
                _, next_prob, next_log_prob = self.actor(replay_data.next_observations)
                # Pessimistic next Q-values: min over all target critic heads
                next_q_values = _min_over_critics(self.critic_target(replay_data.next_observations, self.actor))
                # Soft state value: sum over actions of pi(a|s') * (Q(s', a) - alpha * log pi(a|s'))
                next_q_values = (next_prob * (next_q_values - ent_coef * next_log_prob)).sum(dim=1, keepdim=True)
                # td error + entropy term
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

            # Per-action Q-values for each critic head; gather picks the Q-value of
            # the action actually taken in the replay buffer.
            current_q_values = self.critic(replay_data.observations, self.actor)
            q_loss = 0.5 * sum(
                [F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values])
            q_losses.append(q_loss.item())

            self.critic.optimizer.zero_grad()
            q_loss.backward()
            self.critic.optimizer.step()

            # Actor loss: expected (alpha * log pi - min Q) under the policy's own
            # action distribution.
            min_qf_pi = _min_over_critics(self.critic(replay_data.observations, self.actor))
            actor_loss = torch.sum(prob * (ent_coef * log_prob - min_qf_pi), dim=1).mean()
            actor_losses.append(actor_loss.item())

            # Optimize the actor
            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            # Update target networks
            if gradient_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)

        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/ent_coef", np.mean(ent_coefs))
        self.logger.record("train/actor_loss", np.mean(actor_losses))
        self.logger.record("train/q_loss", np.mean(q_losses))
        if len(ent_coef_losses) > 0:
            self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
            self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))

    def learn(
            self,
            total_timesteps: int,
            callback: MaybeCallback = None,
            log_interval: int = 4,
            eval_env: Optional[GymEnv] = None,
            eval_freq: int = -1,
            n_eval_episodes: int = 5,
            tb_log_name: str = "SACWMGNN",
            eval_log_path: Optional[str] = None,
            reset_num_timesteps: bool = True,
    ) -> SACGNNBase:
        """Train the agent by delegating to the parent class's ``learn``.

        Every argument is forwarded unchanged as a keyword argument; this override
        only supplies its own defaults (notably ``tb_log_name``).

        NOTE(review): the default run name "SACWMGNN" does not match this class's
        name (the sibling ``DiscreteSACWMMLP`` uses its own class name); confirm
        it is intended before changing it, since log directories depend on it.

        :return: the value returned by the parent ``learn``.
        """
        forwarded = dict(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )
        return super().learn(**forwarded)

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """Name the torch objects to persist for plain discrete SAC.

        Unlike the world-model variants there is always exactly one critic
        optimizer. The entropy-coefficient state saved depends on whether that
        coefficient is learned.

        :return: ``(state_dicts, saved_pytorch_variables)`` -- dotted attribute
            names of objects saved via ``state_dict``, and names of plain torch
            variables saved directly.
        """
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
        if self.ent_coef_optimizer is None:
            # Fixed coefficient: only the constant tensor needs saving.
            return state_dicts, ["ent_coef_tensor"]
        state_dicts.append("ent_coef_optimizer")
        return state_dicts, ["log_ent_coef"]


class DiscreteSACWMMLP(SACGNNBase):
    def __init__(
            self,
            policy,
            env: Union[GymEnv, str],
            learning_rate: Union[float, Schedule] = 3e-4,
            buffer_size: int = 1_000_000,  # 1e6
            learning_starts: int = 100,
            batch_size: int = 256,
            tau: float = 0.005,
            gamma: float = 0.99,
            train_freq: Union[int, Tuple[int, str]] = 1,
            gradient_steps: int = 1,
            action_noise: Optional[ActionNoise] = None,
            replay_buffer_class: Optional[ReplayBuffer] = None,
            replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
            replay_buffer: Optional[ReplayBuffer] = None,
            optimize_memory_usage: bool = False,
            ent_coef: Union[str, float] = "auto",
            target_update_interval: int = 1,
            target_entropy: Union[str, float] = "auto",
            use_sde: bool = False,
            sde_sample_freq: int = -1,
            use_sde_at_warmup: bool = False,
            tensorboard_log: Optional[str] = None,
            create_eval_env: bool = False,
            policy_kwargs: Optional[Dict[str, Any]] = None,
            verbose: int = 0,
            seed: Optional[int] = None,
            device: Union[torch.device, str] = "auto",
            _init_setup_model: bool = True,
            transition_loss_coef: float = 1.0,
            reward_loss_coef: float = 1.0,
            rollout_logging_full_buffer: bool = False,
    ):
        """Discrete-action SAC whose critic also trains MLP transition and reward models.

        Every argument is forwarded unchanged to ``SACGNNBase.__init__``.
        ``replay_buffer`` is an (optional) pre-built replay buffer, matching the
        base-class annotation. ``train`` reads ``self.transition_loss_coef`` and
        ``self.reward_loss_coef`` to weight the world-model losses.
        TODO(review): confirm the base class stores the two constructor
        coefficients under those attribute names.

        NOTE: all arguments except ``rollout_logging_full_buffer`` are passed
        positionally, so their order here must stay in sync with the base-class
        signature.
        """
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise,
            replay_buffer_class,
            replay_buffer_kwargs,
            replay_buffer,
            optimize_memory_usage,
            ent_coef,
            target_update_interval,
            target_entropy,
            use_sde,
            sde_sample_freq,
            use_sde_at_warmup,
            tensorboard_log,
            create_eval_env,
            policy_kwargs,
            verbose,
            seed,
            device,
            _init_setup_model,
            transition_loss_coef,
            reward_loss_coef,
            rollout_logging_full_buffer=rollout_logging_full_buffer,
        )

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """Run ``gradient_steps`` discrete-SAC + world-model updates on replay minibatches.

        Each step, in order:
        1. updates the entropy coefficient (only when it is learned);
        2. computes the soft-TD critic loss ``q_loss`` and the world-model loss
           ``wm_loss`` (transition models predict the next observation, reward
           models predict the reward, both from observation + one-hot action);
        3. optimizes the critic: with ``policy.use_wm_optimizer`` the two losses
           are stepped by separate optimizers (``optimizer_value`` / ``optimizer_wm``),
           otherwise their sum is stepped by ``critic.optimizer``;
        4. updates the actor against the element-wise minimum over critic heads;
        5. Polyak-averages the target critic every ``target_update_interval`` steps.

        Mean losses and the entropy coefficient are written to ``self.logger``.

        :param gradient_steps: number of gradient updates to perform.
        :param batch_size: number of transitions per sampled minibatch.
        """
        import functools

        def _min_over_critics(q_values):
            # Element-wise minimum over all critic heads. Identical to
            # ``torch.minimum(q1, q2)`` for two critics, but unlike
            # ``torch.minimum(*q_values)`` it also works for n_critics != 2.
            return functools.reduce(torch.minimum, q_values)

        use_wm_optimizer = self.policy.use_wm_optimizer

        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)

        # Update optimizers learning rate according to the lr schedule
        optimizers = [self.actor.optimizer]
        if use_wm_optimizer:
            optimizers += [self.critic.optimizer_wm, self.critic.optimizer_value]
        else:
            optimizers.append(self.critic.optimizer)
        if self.ent_coef_optimizer is not None:
            optimizers.append(self.ent_coef_optimizer)
        self._update_learning_rate(optimizers)

        ent_coef_losses, ent_coefs = [], []
        actor_losses, q_losses, transition_losses, reward_losses, wm_losses, critic_losses = [], [], [], [], [], []

        for gradient_step in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # We need to sample because `log_std` may have changed between two gradient steps
            if self.use_sde:
                self.actor.reset_noise()

            # Log-prob of a sampled action, plus per-action probabilities and
            # log-probabilities, for the current observations.
            log_prob_action, prob, log_prob = self.actor(replay_data.observations)
            log_prob_action = log_prob_action.reshape(-1, 1)

            ent_coef_loss = None
            if self.ent_coef_optimizer is not None:
                # Important: detach the variable from the graph
                # so we don't change it with other losses
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = torch.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(self.log_ent_coef * (log_prob_action + self.target_entropy).detach()).mean()
                ent_coef_losses.append(ent_coef_loss.item())
            else:
                ent_coef = self.ent_coef_tensor

            ent_coefs.append(ent_coef.item())

            # Optimize entropy coefficient, also called
            # entropy temperature or alpha in the paper
            if ent_coef_loss is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            with torch.no_grad():
                # Action distribution of the current policy at the next observations
                _, next_prob, next_log_prob = self.actor(replay_data.next_observations)
                # Pessimistic next Q-values: min over all target critic heads
                next_q_values = _min_over_critics(self.critic_target(replay_data.next_observations, self.actor))
                # Soft state value: sum over actions of pi(a|s') * (Q(s', a) - alpha * log pi(a|s'))
                next_q_values = (next_prob * (next_q_values - ent_coef * next_log_prob)).sum(dim=1, keepdim=True)
                # td error + entropy term
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

            # Per-action Q-values for each critic head; gather picks the Q-value of
            # the action actually taken in the replay buffer.
            current_q_values = self.critic(replay_data.observations, self.actor)
            q_loss = 0.5 * sum(
                [F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values])
            q_losses.append(q_loss.item())

            if use_wm_optimizer:
                self.critic.optimizer_value.zero_grad()
                q_loss.backward()
                self.critic.optimizer_value.step()

            # World-model targets: one-hot encode the taken (integer) actions.
            one_hot_actions = F.one_hot(replay_data.actions.squeeze(dim=1), num_classes=self.critic.num_actions).to(torch.float32)

            next_features_predictions = [
                transition_model(replay_data.observations, one_hot_actions)
                for transition_model in self.critic.transition_models
            ]
            transition_loss = 0.5 * sum(
                [F.mse_loss(next_features_prediction, replay_data.next_observations)
                 for next_features_prediction in next_features_predictions])
            transition_losses.append(transition_loss.item())

            rewards_predictions = [
                reward_model(replay_data.observations, one_hot_actions)
                for reward_model in self.critic.reward_models
            ]
            reward_loss = 0.5 * sum(
                [F.mse_loss(rewards_prediction, replay_data.rewards) for rewards_prediction in rewards_predictions])
            reward_losses.append(reward_loss.item())

            wm_loss = self.transition_loss_coef * transition_loss + self.reward_loss_coef * reward_loss
            wm_losses.append(wm_loss.item())

            if use_wm_optimizer:
                self.critic.optimizer_wm.zero_grad()
                wm_loss.backward()
                self.critic.optimizer_wm.step()

            # With a single critic optimizer both losses are stepped jointly;
            # otherwise this sum is only logged.
            critic_loss = q_loss + wm_loss
            critic_losses.append(critic_loss.item())

            if not use_wm_optimizer:
                self.critic.optimizer.zero_grad()
                critic_loss.backward()
                self.critic.optimizer.step()

            # Actor loss: expected (alpha * log pi - min Q) under the policy's own
            # action distribution.
            min_qf_pi = _min_over_critics(self.critic(replay_data.observations, self.actor))
            actor_loss = torch.sum(prob * (ent_coef * log_prob - min_qf_pi), dim=1).mean()
            actor_losses.append(actor_loss.item())

            # Optimize the actor
            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            # Update target networks
            if gradient_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)

        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/ent_coef", np.mean(ent_coefs))
        self.logger.record("train/actor_loss", np.mean(actor_losses))
        self.logger.record("train/q_loss", np.mean(q_losses))
        self.logger.record("train/critic_loss", np.mean(critic_losses))
        self.logger.record("train/transition_loss", np.mean(transition_losses))
        self.logger.record("train/reward_loss", np.mean(reward_losses))
        self.logger.record("train/wm_loss", np.mean(wm_losses))
        if len(ent_coef_losses) > 0:
            self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))

    def learn(
            self,
            total_timesteps: int,
            callback: MaybeCallback = None,
            log_interval: int = 4,
            eval_env: Optional[GymEnv] = None,
            eval_freq: int = -1,
            n_eval_episodes: int = 5,
            tb_log_name: str = "DiscreteSACWMMLP",
            eval_log_path: Optional[str] = None,
            reset_num_timesteps: bool = True,
    ) -> SACGNNBase:
        """Train the agent by delegating to the parent class's ``learn``.

        Every argument is forwarded unchanged as a keyword argument; this override
        only supplies its own defaults (the TensorBoard run name defaults to this
        class's name).

        :return: the value returned by the parent ``learn``.
        """
        forwarded = dict(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )
        return super().learn(**forwarded)

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """Name the torch objects to persist for the MLP world-model variant.

        With ``policy.use_wm_optimizer`` the critic has separate world-model and
        value optimizers, otherwise a single ``critic.optimizer``. The
        entropy-coefficient state saved depends on whether that coefficient is
        learned.

        :return: ``(state_dicts, saved_pytorch_variables)`` -- dotted attribute
            names of objects saved via ``state_dict``, and names of plain torch
            variables saved directly.
        """
        state_dicts = ["policy", "actor.optimizer"]
        state_dicts.extend(
            ["critic.optimizer_wm", "critic.optimizer_value"]
            if self.policy.use_wm_optimizer
            else ["critic.optimizer"]
        )

        if self.ent_coef_optimizer is None:
            # Fixed coefficient: only the constant tensor needs saving.
            return state_dicts, ["ent_coef_tensor"]
        state_dicts.append("ent_coef_optimizer")
        return state_dicts, ["log_ent_coef"]
