
import dataclasses as dc
import math
import numpy as np
import torch
import torch.nn.functional as F
import typing as ty

from copy import copy, deepcopy
from gym import spaces
from pathlib import Path

from .base import polyak_update
from .networks import ActionValueNetwork4ContinuousAction
from .policy import SquashedGaussianPolicy, RealNVPPolicy, SequantialGaussRealNVPPolicy

from .sac_clean import SoftQNetwork
from .sac import SoftActorCritic
from .networks import ActionValueNetwork4ContinuousAction



@dc.dataclass
class Param4MDAC:
    """Hyper-parameter bundle for ``MirrorDescentActorCritic``.

    Every field has the same name as a keyword argument of
    ``MirrorDescentActorCritic.__init__``.  Several defaults here differ
    from the constructor's own defaults (e.g. ``batch_size``, ``lr``,
    ``n_hidden``), so an instance of this class is not equivalent to
    calling the constructor without arguments.
    """

    # --- common ---
    batch_size      : int = 256
    buffer_size     : int = int(1e7)
    lr              : float = 3e-3
    gamma           : float = 0.99
    seed            : int = 7
    grad_clipping   : ty.Optional[float] = None   # None disables clipping
    cuda            : bool = True
    n_hidden        : ty.List[int] = dc.field(default_factory=list)  # fresh list per instance
    layer_norm      : bool = False
    use_double      : bool = False
    polyak          : float = 0.995

    # --- actor ---
    lr_policy: float = 3e-4
    policy_type: str = 'gaussian'
    explorer_type: str = 'rnvp'
    n_mixtures: int = 1
    single_head_gmn: bool = False
    min_bonus: float = -1.0
    bound_f: str = 'rclip'
    bound_g: str = 'rclip'
    coeff_bmixture: float = 0.1
    autotune_kl_coeff: bool = False
    coeff_kl: float = 0.999
    max_coeff_kl: float = 1e6
    trust_region: float = 0.1
    autotune_ent_coeff: bool = True
    coeff_entropy: float = 0.05
    max_coeff_entropy: float = 1e6
    tune_explorer_entropy: bool = False
    tune_explorer_kl: bool = False
    autotune_redual: bool = True
    sampler: str = 'uniform'
    n_particles: int = 10
    n_particles_per_dim: ty.Optional[int] = None  # when set, overrides n_particles
    log_std_min: float = -5.
    log_std_max: float = 2.
    min_action: ty.Optional[float] = None
    max_action: ty.Optional[float] = None
    action_scaler: str = 'spinningup'
    qf_emulate_clean_rl: bool = True              # was annotated `str`, but it is a flag
    explore_by_aux: bool = False


class MirrorDescentActorCritic(SoftActorCritic):

    '''
        Mirror Descent Actor Critic
    '''

    def __init__(
        self,
        observation_space, action_space,
        batch_size: int = 16,
        lr: float = 1e-2,
        gamma: float = 0.99,
        seed: int = 13,
        buffer_size: int = int(1e4),
        grad_clipping: ty.Optional[float] = None,
        cuda: bool = False,
        n_hidden: ty.Optional[ty.List[int]] = None,  # None -> [256, 256]
        layer_norm: bool = False,
        use_double: bool = False,
        polyak: float = 0.995,
        # actor
        lr_policy: float = 1e-3,
        policy_type: str = 'gaussian',
        explorer_type: str = 'gaussian',
        n_mixtures: int = 1,
        single_head_gmn: bool = False,
        min_bonus: float = - 1.0,
        bound_f: str = 'rclip',
        bound_g: str = 'rclip',
        coeff_bmixture: float = 0.1,
        autotune_kl_coeff: bool = False,
        coeff_kl: float = 0.999,
        max_coeff_kl: float = 1e6,
        trust_region: float = 0.1,
        autotune_ent_coeff: bool = True,
        coeff_entropy: float = 0.9,
        max_coeff_entropy: float = 1e6,
        tune_explorer_entropy: bool = False,
        tune_explorer_kl: bool = False,
        autotune_redual: bool = True,
        sampler: str = 'uniform',
        n_particles: int = 10,
        n_particles_per_dim: ty.Optional[int] = None,
        log_std_min: float = -5.,
        log_std_max: float = 2.,
        min_action: float = -1.,
        max_action: float = 1.,
        action_scaler: str = 'spinningup',
        qf_emulate_clean_rl: bool = True,
        explore_by_aux: bool = False,
    ):
        """Build the agent on top of ``SoftActorCritic``.

        The shared hyper-parameters are forwarded unchanged to
        ``SoftActorCritic.__init__``.  This constructor then adds:

        * the KL coefficient: a learnable ``log_coeff_kl`` with its own Adam
          optimiser when ``autotune_kl_coeff`` is set, otherwise the fixed
          tensor ``coeff_kl``;
        * ``trust_region`` and ``max_kappa = 1 - (1 - gamma)**2`` as tensors;
        * the particle count for the dual loss (``n_particles_per_dim * dim_ac``
          when ``n_particles_per_dim`` is given, else ``n_particles``);
        * a separate ``explorer`` policy and optimiser, but only when
          ``explorer_type`` differs from the policy type;
        * bookkeeping for the bonus-bounding modes ``bound_f`` / ``bound_g``.

        Raises:
            ValueError: if a separate explorer is required but
                ``explorer_type`` is not 'gaussian', 'rnvp' or 'sgrnvp'.
        """
        # A list literal as a default would be one object shared by every
        # instance; build a private copy instead.
        n_hidden = [256, 256] if n_hidden is None else list(n_hidden)

        super().__init__(
            observation_space, action_space,
            batch_size=batch_size, lr=lr, gamma=gamma, seed=seed,
            buffer_size=buffer_size, grad_clipping=grad_clipping, cuda=cuda,
            n_hidden=n_hidden, layer_norm=layer_norm, use_double=use_double,
            polyak=polyak, lr_policy=lr_policy, policy_type=policy_type,
            n_mixtures=n_mixtures, single_head_gmn=single_head_gmn,
            autotune_ent_coeff=autotune_ent_coeff,
            coeff_entropy=coeff_entropy, max_coeff_entropy=max_coeff_entropy,
            log_std_min=log_std_min, log_std_max=log_std_max,
            min_action=min_action, max_action=max_action,
            action_scaler=action_scaler,
            qf_emulate_clean_rl=qf_emulate_clean_rl,
        )

        self.autotune_redual = autotune_redual
        self.autotune_kl_coeff = autotune_kl_coeff
        self.trust_region = torch.tensor([trust_region], dtype=torch.float32, device=self.device)
        self.max_kappa = torch.tensor([1 - (1 - self.gamma)**2], dtype=torch.float32, device=self.device)
        if self.autotune_kl_coeff:
            self.max_coeff_kl = max_coeff_kl
            # Leaf tensor owned by optimizer_kl; it must only ever be updated
            # in place so the optimiser keeps pointing at it.
            self.log_coeff_kl = torch.zeros(1, requires_grad=True, device=self.device)
            self.optimizer_kl = torch.optim.Adam([self.log_coeff_kl], lr=lr)
            print(f'trust_region = {self.trust_region}')
            print(f'log_coeff_kl = {self.log_coeff_kl}')

        else:
            self.coeff_kl = torch.tensor([coeff_kl], dtype=torch.float32, device=self.device)

        # Constant log-density assigned to uniformly sampled actions.
        # NOTE(review): the uniform sampler draws from a box of half-width
        # action_scale, whose log-density would be -sum(log(2 * action_scale));
        # confirm whether omitting the factor 2 is intentional.
        self.log_unif = - torch.log(self.action_scale).sum()
        self.sampler = sampler
        if n_particles_per_dim is None:
            self.n_particles = n_particles
        else:
            self.n_particles = n_particles_per_dim * self.dim_ac
        self.min_bonus = min_bonus
        self.bound_f = bound_f
        self.bound_g = bound_g
        self.coeff_bmixture = coeff_bmixture

        self.explorer_type = explorer_type
        if self.explorer_type != self.policy_type:
            if self.explorer_type == 'gaussian':
                self.explorer = SquashedGaussianPolicy(
                    observation_space, action_space,
                    action_scaler=action_scaler,
                    min_action=min_action, max_action=max_action,
                    n_hidden=n_hidden, device=self.device,
                    representation='NN',
                    log_std_min=log_std_min, log_std_max=log_std_max,
                ).to(self.device)
            elif self.explorer_type == 'rnvp':
                self.explorer = RealNVPPolicy(
                    observation_space, action_space,
                    n_hidden=n_hidden, device=self.device,
                    noise_type='gaussian',
                ).to(self.device)
            elif self.explorer_type == 'sgrnvp':
                self.explorer = SequantialGaussRealNVPPolicy(
                    observation_space, action_space,
                    n_hidden=n_hidden, device=self.device,
                ).to(self.device)
            else:
                # Without this branch the failure surfaced one line later as
                # an AttributeError on self.explorer.
                raise ValueError(
                    f'unknown explorer_type: {self.explorer_type!r} '
                    "(expected 'gaussian', 'rnvp' or 'sgrnvp')"
                )
            self.optimizer_explorer = torch.optim.Adam(self.explorer.parameters(), lr=lr_policy)
            print(f'explorer = {self.explorer}')

        self.tune_explorer_entropy = tune_explorer_entropy
        self.tune_explorer_kl = tune_explorer_kl
        if self.bound_f == 'vscale' or self.bound_g == 'vscale':
            self.vscale = 0.
        # 'clip_t' and 'clip_t.<horizon>' both contain the substring 'clip_t'.
        if self.bound_g == 'clip_logt' or 'clip_t' in self.bound_g:
            self.gstep = 0

        self.explore_by_aux = explore_by_aux


    def update_vscale(self, rew: torch.Tensor) -> None:
        """Refresh ``self.vscale`` from a batch of rewards.

        Stores ``max(|min(rew)|, |max(rew)|) / (1 - gamma)`` in
        ``self.vscale``.  Nothing is returned; the previous return annotation
        claimed a tensor.
        """
        largest_magnitude = torch.max(rew.min().abs(), rew.max().abs())
        self.vscale = largest_magnitude / (1 - self.gamma)


    # def b(self, x: torch.tensor) -> torch.tensor:
    #     if self.bounding == None:
    #         return x
    #     elif self.bounding == 'squash':
    #         return torch.tanh(x*0.1)
    #     elif self.bounding == 'vscale':
    #         return torch.tanh(x*0.1) * self.vscale


    def act(
        self,
        observation: ty.Union[torch.tensor, np.ndarray],
        greedy: bool=False, as_tensor=False,
    ) -> ty.Union[torch.tensor, np.ndarray]:
        """Choose an action for ``observation``.

        Non-tensor observations are converted to a float32 tensor on
        ``self.device`` (a bare ``int`` is first wrapped in a 1-element
        array).  For non-discrete observation spaces, an observation whose
        shape equals the space's shape gets a leading batch axis.

        The acting network is ``self.policy``, except when ``greedy`` is
        false, the explorer type differs from the policy type and
        ``explore_by_aux`` is set; then ``self.explorer`` acts.

        The result is squeezed; a 0-d result is reshaped to shape ``(1,)``.
        """
        if not isinstance(observation, torch.Tensor):
            raw = np.array([observation]) if isinstance(observation, int) else observation
            observation = torch.as_tensor(raw, dtype=torch.float32, device=self.device)

        is_discrete = isinstance(self.observation_space, spaces.Discrete)
        if not is_discrete and observation.shape == self.observation_space.shape:
            observation = torch.unsqueeze(observation, 0)

        # Pick the acting network and the greedy flag it is called with.
        if greedy:
            actor, be_greedy = self.policy, True
        elif self.policy_type != self.explorer_type and self.explore_by_aux:
            actor, be_greedy = self.explorer, False
        else:
            actor, be_greedy = self.policy, False
        action = actor.act(observation, greedy=be_greedy, as_tensor=as_tensor)

        squeezed = action.squeeze()
        if squeezed.shape == ():
            # keep at least one axis so callers can always index the result
            squeezed = squeezed.reshape((1,))
        return squeezed



    def get_coeff_entropy(self):
        r"""Return the combined coefficient \tau = \alpha + \beta.

        * ``autotune_ent_coeff`` and ``autotune_redual``:
          ``exp(log_coeff_entropy)`` clamped to ``[0, max_coeff_entropy]`` is
          \tau itself.
        * ``autotune_ent_coeff`` without ``autotune_redual``:
          ``log_coeff_entropy`` parameterises \alpha and ``log_coeff_kl``
          parameterises \beta; their clamped exponentials are summed.
        * otherwise ``self.coeff_entropy`` is returned unchanged.

        NOTE(review): \beta is clamped with ``max_coeff_entropy`` here but
        with ``max_coeff_kl`` in ``get_coeff_kl`` -- confirm which is meant.
        """
        # The docstring is a raw string: in a normal string "\tau", "\alpha"
        # and "\beta" would become tab / BEL / backspace characters.
        if not self.autotune_ent_coeff:
            return self.coeff_entropy

        upper = self.max_coeff_entropy
        tau_or_alpha = self.log_coeff_entropy.exp().clamp(0, upper)
        if self.autotune_redual:
            return tau_or_alpha

        beta = self.log_coeff_kl.exp().clamp(0, upper)
        return tau_or_alpha + beta


    def get_coeff_kl(self):
        r"""Return the mixing ratio \kappa = \beta / (\alpha + \beta).

        * ``autotune_kl_coeff`` and ``autotune_redual``:
          ``sigmoid(log_coeff_kl)``, scaled by ``max_kappa``.
        * ``autotune_kl_coeff`` without ``autotune_redual``: \alpha and \beta
          are the exponentials of ``log_coeff_entropy`` / ``log_coeff_kl``
          clamped to ``[0, max_coeff_kl]``; their ratio is scaled by
          ``max_kappa``.
        * otherwise the fixed ``self.coeff_kl`` is returned without scaling.
        """
        # The docstring is a raw string: "\k" is an invalid escape sequence
        # in a normal string literal.
        if not self.autotune_kl_coeff:
            return self.coeff_kl

        if self.autotune_redual:
            kappa = torch.sigmoid(self.log_coeff_kl)
        else:
            alpha = self.log_coeff_entropy.exp().clamp(0, self.max_coeff_kl)
            beta = self.log_coeff_kl.exp().clamp(0, self.max_coeff_kl)
            kappa = beta / (alpha + beta)
        return kappa * self.max_kappa


    def get_alpha_beta(self):
        """Return the pair ``(alpha, beta)`` from the learnable log-coefficients.

        With ``autotune_redual`` the stored quantities are tau and kappa, so
        ``alpha = tau * (1 - kappa)`` and ``beta = tau * kappa``, with kappa
        scaled by ``max_kappa``.  Otherwise alpha and beta are the clamped
        exponentials of ``log_coeff_entropy`` and ``log_coeff_kl``.
        """
        ceiling = self.max_coeff_entropy
        # tau in the re-parameterised case, alpha otherwise
        leading = self.log_coeff_entropy.exp().clamp(0, ceiling)

        if not self.autotune_redual:
            return leading, self.log_coeff_kl.exp().clamp(0, ceiling)

        kappa = torch.sigmoid(self.log_coeff_kl) * self.max_kappa
        return leading * (1 - kappa), leading * kappa


    def compute_target(self,
                       ob: torch.tensor,
                       ac: torch.tensor,
                       rew: torch.tensor,
                       next_ob: torch.tensor,
                       done: torch.tensor):
        """Compute the bootstrapped critic target for one mini-batch.

        Everything runs under ``torch.no_grad()``.  With ``tau`` from
        ``get_coeff_entropy()`` and ``kappa`` from ``get_coeff_kl()`` the
        result is::

            rew + kappa * f(tau * log_pi(ac | ob))
                + gamma * (1 - done) * (Q_target(next_ob, a') - g(tau * log_pi(a' | next_ob)))

        where ``a'`` is sampled at ``next_ob`` and ``f`` / ``g`` are the
        bounding functions selected by ``bound_f`` / ``bound_g``.  The
        log-probabilities come from ``self.explorer`` when the explorer type
        differs from the policy type, otherwise from ``self.policy``.

        A ``bound_f`` / ``bound_g`` name that matches no branch leaves the
        bonus unchanged.

        Side effects: the 'clip_logt' / 'clip_t' variants of ``bound_g``
        increment ``self.gstep``; when ``self.logged_reward`` is not None,
        diagnostics are appended to the ``self.logged_*`` lists.

        Returns:
            Tensor of shape ``(batch_size, 1)`` (asserted before returning).
        """

        batch_size = ob.shape[0]

        with torch.no_grad():

            ''' entropy & KL bonus for soft policy iteration '''
            coeff_entropy = self.get_coeff_entropy()
            coeff_kl = self.get_coeff_kl()

            # Network whose log-probabilities define the bonus terms.
            if self.policy_type != self.explorer_type:
                policy = self.explorer
            else:
                policy = self.policy

            log_pi = policy.log_prob(ob, ac)
            log_pi = log_pi.view(batch_size, 1)

            # f(tau * log_pi): bonus on the taken action, bounded per bound_f.
            bonus = coeff_entropy * log_pi
            if self.bound_f == 'identity':
                pass
            elif self.bound_f == 'nclip':
                """ naive clip """
                bonus = bonus.clamp(-1, 1)
            elif self.bound_f == 'rclip':
                # rescale by 0.1 first, then clip
                bonus = (bonus*0.1).clamp(-1, 1)
            elif self.bound_f == 'squash' or self.bound_f == 'ntanh' or self.bound_f == 'tanh':
                bonus = torch.tanh(bonus)
            elif self.bound_f == 'rsquash' or self.bound_f == 'rtanh':
                bonus = torch.tanh(bonus*0.1)
            elif self.bound_f == 'binary':
                # keep only the sign (+1 for non-negative values)
                bonus = torch.where(bonus >= 0, 1., -1.)
            elif self.bound_f == 'nconst':
                bonus = -1 * torch.ones_like(bonus)
            elif self.bound_f == 'vscale':
                bonus = torch.tanh(bonus*0.1) * self.vscale

            assert rew.shape == log_pi.shape
            bonus = coeff_kl * bonus
            rew_augumented = rew + bonus

            # Optional diagnostics; the bonus is logged before bounding.
            if self.logged_reward is not None:
                self.logged_reward.append(rew.cpu())
                self.logged_log_pi.append(log_pi.cpu())
                self.logged_tau.append(coeff_entropy.cpu())
                self.logged_kappa.append(coeff_kl.cpu())

                bonus_for_log = coeff_entropy * log_pi
                # 1. where the clip modes saturated, 0. elsewhere
                if self.bound_f == 'nclip':
                    clipped = torch.where((bonus_for_log < -1) | (bonus_for_log > 1), 1., 0.)
                elif self.bound_f == 'rclip':
                    clipped = torch.where((bonus_for_log*0.1 < -1) | (bonus_for_log*0.1 > 1), 1., 0.)
                else:
                    clipped = torch.zeros_like(bonus_for_log)
                self.logged_clipped.append(clipped.cpu())
                self.logged_bonus.append(bonus_for_log.cpu())

            # Sample the next action and its log-probability.
            next_ac_pi, next_log_pi = policy.act(next_ob, as_tensor=True, with_log_prob=True)
            next_log_pi = next_log_pi.view(batch_size, 1)

            ''' double q '''
            next_q_sa = self.target_q_net(next_ob, next_ac_pi)
            if self.use_double:
                # element-wise minimum of the two target critics
                next_q_sa2 = self.target_q_net2(next_ob, next_ac_pi)
                next_q_sa = torch.min(next_q_sa, next_q_sa2)
            if not self.qf_emulate_clean_rl:
                # presumably this critic variant returns shape (batch,) -- confirm
                next_q_sa = next_q_sa.unsqueeze(-1)
            assert next_q_sa.shape == (batch_size, 1), f'{next_q_sa.shape} != {(batch_size, 1)}'

            # g(tau * next_log_pi): bonus on the next action, bounded per bound_g.
            next_bonus = coeff_entropy * next_log_pi
            if self.bound_g == 'identity':
                pass
            elif self.bound_g == 'nclip':
                next_bonus = (next_bonus).clamp(-1, 1)
            elif self.bound_g == 'rclip':
                next_bonus = (next_bonus*0.1).clamp(-1, 1)
            elif self.bound_g == 'squash' or self.bound_g == 'ntanh' or self.bound_g == 'tanh':
                next_bonus = torch.tanh(next_bonus)
            elif self.bound_g == 'rsquash' or self.bound_g == 'rtanh':
                next_bonus = torch.tanh(next_bonus*0.1)
            elif self.bound_g == 'binary':
                next_bonus = torch.where(next_bonus >= 0, 1., -1.)
            elif self.bound_g == 'nconst':
                next_bonus = -1 * torch.ones_like(next_bonus)
            elif self.bound_g == 'vscale':
                next_bonus = torch.tanh(next_bonus*0.1) * self.vscale
            elif self.bound_g == 'clip_logt':
                # Scale and clip level grow with log10 of the call count.
                # On the first call log10(1) == 0, so the scale is 0.
                self.gstep += 1
                log_gstep = math.log10(self.gstep)
                clip_scale = log_gstep / (log_gstep + 5)
                clip_level = 1 + log_gstep
                next_bonus = (next_bonus*clip_scale).clamp(-clip_level, clip_level)
            elif self.bound_g == 'clip_t':
                # Linear schedule with a fixed horizon of 1e6 calls.
                self.gstep += 1
                gstep = (self.gstep+1e6) / 1e6
                clip_scale = gstep / (gstep + 10)
                clip_level = gstep
                next_bonus = (next_bonus*clip_scale).clamp(-clip_level, clip_level)
            elif 'clip_t' in self.bound_g:
                # 'clip_t.<horizon>': same schedule, horizon parsed from the name.
                assert self.bound_g.split('.')[0] == 'clip_t'
                horizon = float(self.bound_g.split('.')[1])
                self.gstep += 1
                gstep = (self.gstep+horizon) / horizon
                clip_scale = gstep / (gstep + 10)
                clip_level = gstep
                next_bonus = (next_bonus*clip_scale).clamp(-clip_level, clip_level)

            next_soft_q_sa = next_q_sa - next_bonus

            # (1 - done) removes the bootstrap term on terminal transitions.
            target = rew_augumented + self.gamma * (1.0 - done) * next_soft_q_sa

            if self.logged_reward is not None:
                self.logged_next_q.append(next_q_sa.cpu())
                self.logged_next_log_pi.append(next_log_pi.cpu())

                next_bonus_for_log = coeff_entropy * next_log_pi
                if self.bound_g == 'nclip':
                    next_clipped = torch.where((next_bonus_for_log < -1) | (next_bonus_for_log > 1), 1., 0.)
                elif self.bound_g == 'rclip':
                    next_clipped = torch.where((next_bonus_for_log*0.1 < -1) | (next_bonus_for_log*0.1 > 1), 1., 0.)
                else:
                    next_clipped = torch.zeros_like(next_bonus_for_log)
                self.logged_next_clipped.append(next_clipped.cpu())
                self.logged_next_bonus.append(next_bonus_for_log.cpu())

            assert target.shape == (batch_size, 1)

        return target


    def compute_actor_loss(
            self,
            ob: torch.tensor,
            policy: torch.nn.Module,
        ):
        """Return the soft-greedy loss ``mean(tau * log_pi(a|ob) - Q(ob, a))``.

        ``a`` is sampled from ``policy`` at ``ob``; ``tau`` comes from
        ``get_coeff_entropy()`` evaluated without gradients.  With
        ``use_double`` the element-wise minimum of the two critics is used.
        When ``self.logged_curr_q`` is not None the detached Q-values are
        appended to it.
        """
        n_rows = ob.shape[0]

        # Sample actions together with their log-probabilities.
        sampled_ac, sampled_log_prob = policy.act(ob, as_tensor=True, with_log_prob=True)

        with torch.no_grad():
            tau = self.get_coeff_entropy()

        q_value = self.q_net(ob, sampled_ac)
        if self.use_double:
            q_value = torch.min(q_value, self.q_net2(ob, sampled_ac))

        sampled_log_prob = sampled_log_prob.view(n_rows, 1)
        if not self.qf_emulate_clean_rl:
            q_value = q_value.view(n_rows, 1)
        loss = (tau * sampled_log_prob - q_value).mean()

        if self.logged_curr_q is not None:
            self.logged_curr_q.append(q_value.detach().cpu())

        return loss


    def update_actor(
        self,
        ob: torch.tensor,
    ):
        """Take one optimiser step on ``self.policy``; return the loss as a float."""
        optimizer = self.optimizer_policy
        optimizer.zero_grad()

        loss = self.compute_actor_loss(ob, self.policy)
        loss.backward()
        self.grad_clip_actor()
        optimizer.step()

        return loss.detach().item()


    def update_explorer(
        self,
        ob: torch.tensor,
    ):
        """Take one optimiser step on ``self.explorer``; return the loss as a float.

        NOTE(review): unlike ``update_actor`` no gradient clipping is applied
        here -- confirm this is intentional.
        """
        optimizer = self.optimizer_explorer
        optimizer.zero_grad()

        loss = self.compute_actor_loss(ob, self.explorer)
        loss.backward()
        optimizer.step()

        return loss.detach().item()


    def compute_kl_loss(self, ob: torch.tensor):
        """Return ``mean(tau * kappa * (log_pi - log_rho - trust_region))``.

        Actions are sampled from ``self.policy`` and scored by both the policy
        (``log_pi``) and ``self.anchor`` (``log_rho``) without gradients, so
        only the coefficients ``tau`` and ``kappa`` can receive a gradient.
        """
        with torch.no_grad():
            sampled_ac, log_pi = self.policy.act(ob, as_tensor=True, with_log_prob=True)
            log_rho = self.anchor.log_prob(ob, sampled_ac)

        weight = self.get_coeff_entropy() * self.get_coeff_kl()
        excess = (log_pi - log_rho - self.trust_region).detach()
        return (weight * excess).mean()


    def grad_clip_kl(self):
        """Clip the gradient norm of ``log_coeff_kl`` when ``grad_clipping`` is set."""
        if self.grad_clipping is None:
            return
        torch.nn.utils.clip_grad_norm_([self.log_coeff_kl], self.grad_clipping)


    def update_kl_coeff(
        self,
        ob: torch.tensor,
    ):
        """Take one step on ``log_coeff_kl`` using the dual loss; return the loss as a float."""
        optimizer = self.optimizer_kl
        optimizer.zero_grad()

        dual_loss = self.compute_optimal_dual_loss(ob)
        dual_loss.backward()
        self.grad_clip_kl()
        optimizer.step()

        return dual_loss.detach().item()


    def compute_joint_entropy_loss(self, ob: torch.tensor):
        """Return ``-mean(tau * (1 - kappa) * (log_pi + target_entropy))``.

        ``log_pi`` is sampled without gradients from ``self.explorer`` when
        ``tune_explorer_entropy`` is set and the explorer type differs from
        the policy type, otherwise from ``self.policy``.  When
        ``self.logged_log_pi_H`` is not None, ``log_pi`` is appended to it.
        """
        use_explorer = self.tune_explorer_entropy and self.policy_type != self.explorer_type
        source = self.explorer if use_explorer else self.policy
        with torch.no_grad():
            _, log_pi = source.act(ob, as_tensor=True, with_log_prob=True)

        tau = self.get_coeff_entropy()
        kappa = self.get_coeff_kl()
        alpha = tau * (1 - kappa)
        gap = (log_pi + self.target_entropy).detach()
        loss = - (alpha * gap).mean()

        if self.logged_log_pi_H is not None:
            self.logged_log_pi_H.append(log_pi.cpu())

        return loss


    def compute_entropy_loss(self, ob: torch.tensor):
        """Delegate to ``compute_joint_entropy_loss`` and return its result."""
        joint_loss = self.compute_joint_entropy_loss(ob)
        return joint_loss


    def compute_optimal_dual_loss(self, ob: torch.Tensor) -> torch.Tensor:
        """Return the dual loss used to tune the temperature coefficients.

        For every observation, ``n_particles`` actions are drawn, either from
        the acting network (``sampler == 'actor'``) or uniformly over the
        action box (``sampler == 'uniform'``).  They are scored by the
        critic(s) without gradients.  A soft value is then estimated as
        ``V = c * logsumexp(Q / c - log_q)`` over the particles, with
        ``c = tau`` (re-parameterised case) or ``c = alpha + beta``.

        * ``autotune_redual``:
          ``mean(tau * (kappa - 1) * target_entropy + tau * kappa * trust_region + V)``
        * otherwise:
          ``mean(beta * trust_region - alpha * target_entropy + V)``

        NOTE(review): the logsumexp is not normalised by ``log(n_particles)``
        -- confirm this is intended.

        Raises:
            ValueError: if ``self.sampler`` is neither 'actor' nor 'uniform'.
        """

        batch_size = ob.shape[0]
        assert ob.shape == (batch_size, self.dim_ob)

        with torch.no_grad():

            # Each observation is repeated n_particles times, consecutively.
            ob_repeat = ob.unsqueeze(1).repeat(1, self.n_particles, 1).view(batch_size * self.n_particles, self.dim_ob)
            if self.sampler == 'actor':
                if self.policy_type != self.explorer_type:
                    policy = self.explorer
                else:
                    policy = self.policy
                ac, log_q = policy.act(ob_repeat, as_tensor=True, with_log_prob=True)
                log_q = log_q.view(batch_size, self.n_particles)

            elif self.sampler == 'uniform':
                # Draw in [-1, 1), then map with action_scale / action_bias.
                ac = torch.rand(
                    (batch_size, self.n_particles, self.dim_ac),
                    dtype=torch.float32, device=self.device) * 2 - 1
                ac = ac * self.action_scale + self.action_bias
                ac = ac.view(batch_size * self.n_particles, self.dim_ac)
                # The proposal density is the same constant for every particle.
                log_q = self.log_unif * torch.ones(
                    (batch_size, self.n_particles),
                    dtype=torch.float32, device=self.device)

            else:
                # Previously this fell through to an UnboundLocalError on `ac`.
                raise ValueError(
                    f"unknown sampler: {self.sampler!r} (expected 'actor' or 'uniform')"
                )

            Q_soft = self.q_net(ob_repeat, ac).view(batch_size, self.n_particles)
            if self.use_double:
                Q_soft2 = self.q_net2(ob_repeat, ac).view(batch_size, self.n_particles)
                Q_soft = torch.min(Q_soft, Q_soft2)

        if self.autotune_redual:
            tau = self.get_coeff_entropy()
            kappa = self.get_coeff_kl()

            # log_q is the importance-sampling correction.
            V_soft = tau * torch.logsumexp(Q_soft / tau - log_q, dim=1, keepdim=True)
            assert V_soft.shape == (batch_size, 1)
            # tau * (kappa - 1) * epsilon + tau * kappa * delta
            return (tau * (kappa - 1) * self.target_entropy
                    + tau * kappa * self.trust_region
                    + V_soft).mean()

        alpha = self.log_coeff_entropy.exp().clamp(0, self.max_coeff_entropy)
        beta = self.log_coeff_kl.exp().clamp(0, self.max_coeff_entropy)

        # log_q is the importance-sampling correction.
        V_soft = (alpha + beta) * torch.logsumexp(Q_soft / (alpha + beta) - log_q, dim=1, keepdim=True)
        assert V_soft.shape == (batch_size, 1)
        # beta * delta - alpha * epsilon
        return (beta * self.trust_region - alpha * self.target_entropy + V_soft).mean()


    def update_target_network(self, update_policy: bool = True):
        '''
            Blend the online critic(s) into their target copies through
            polyak_update, without tracking gradients.

            Presumably theta' <- polyak * theta' + (1 - polyak) * theta; the
            exact rule lives in polyak_update -- confirm there.
            update_policy is accepted but not used here.
        '''
        critic_pairs = [(self.q_net, self.target_q_net)]
        if self.use_double:
            critic_pairs.append((self.q_net2, self.target_q_net2))

        with torch.no_grad():
            for online_net, target_net in critic_pairs:
                polyak_update(online_net, target_net, self.polyak)



    def update(self, use_all_samples: bool = False, update_policy: bool = True):
        """Run one training iteration on a mini-batch.

        Order: refresh ``vscale`` if a 'vscale' bound is in use, update the
        critic, then (only when ``update_policy``) the explorer if it is a
        separate network, the actor, and whichever coefficients are
        auto-tuned.  The target networks are updated last.

        Returns:
            The ``(critic_loss, error)`` pair from ``update_critic``.
        """
        ob, ac, ob2, rew, done = self.get_mini_batch(use_all_samples)

        if 'vscale' in (self.bound_f, self.bound_g):
            self.update_vscale(rew)

        critic_loss, error = self.update_critic(ob, ac, rew, ob2, done)

        if update_policy:
            has_separate_explorer = self.policy_type != self.explorer_type
            if has_separate_explorer:
                self.update_explorer(ob)
            self.update_actor(ob)

            # The losses returned by these steps are not reported to the caller.
            if self.autotune_ent_coeff:
                self.update_entropy_coeff(ob)
            if self.autotune_kl_coeff:
                self.update_kl_coeff(ob)

        self.update_target_network(update_policy)

        return critic_loss, error


    def save(self, model_dir: Path, tag: str, verbose=False):
        """Write each network's ``state_dict`` to ``model_dir/<name>_<tag>.pt``.

        Saved in order: policy, explorer (only when it is a separate
        network), q_net, target_q_net and, with ``use_double``, the second
        critic pair.  With ``verbose`` each written path is printed.
        """
        targets = [(self.policy, model_dir/f'policy_{tag}.pt')]
        if self.policy_type != self.explorer_type:
            targets.append((self.explorer, model_dir/f'explorer_{tag}.pt'))
        targets.append((self.q_net, model_dir/f'q_net_{tag}.pt'))
        targets.append((self.target_q_net, model_dir/f'target_q_net_{tag}.pt'))
        if self.use_double:
            targets.append((self.q_net2, model_dir/f'q_net2_{tag}.pt'))
            targets.append((self.target_q_net2, model_dir/f'target_q_net2_{tag}.pt'))

        for module, save_path in targets:
            torch.save(module.state_dict(), save_path)
            if verbose:
                print(f'saved as {save_path}')


    def load(self, model_dir: Path, tag: str, verbose=False):
        """Restore each network from ``model_dir/<name>_<tag>.pt``.

        Mirrors ``save``: policy, explorer (only when it is a separate
        network), q_net, target_q_net and, with ``use_double``, the second
        critic pair.  With ``verbose`` each loaded path is printed.
        """

        def load_one_model(module, load_path):
            # map_location puts the stored tensors on this agent's device, so
            # a checkpoint written on a CUDA machine also loads on a CPU-only
            # host instead of failing while deserialising.
            state = torch.load(load_path, map_location=self.device)
            module.load_state_dict(state)
            if verbose:
                print(f'loaded from {load_path}')

        load_one_model(self.policy, model_dir/f'policy_{tag}.pt')
        if self.policy_type != self.explorer_type:
            load_one_model(self.explorer, model_dir/f'explorer_{tag}.pt')
        load_one_model(self.q_net, model_dir/f'q_net_{tag}.pt')
        load_one_model(self.target_q_net, model_dir/f'target_q_net_{tag}.pt')
        if self.use_double:
            load_one_model(self.q_net2, model_dir/f'q_net2_{tag}.pt')
            load_one_model(self.target_q_net2, model_dir/f'target_q_net2_{tag}.pt')


    def get_state_dict(self) -> dict:
        """Collect the training state into one checkpoint dictionary.

        Contains the replay buffer object, the critic / policy (and optional
        explorer) state dicts, their optimiser states and, for each
        auto-tuned coefficient, the log-coefficient tensor plus its optimiser
        state.  The buffer and the log-coefficient tensors are stored by
        reference, not copied.
        """
        checkpoint = {
            'buffer': self.buffer,
            'q_net_state_dict': self.q_net.state_dict(),
            'target_q_net_state_dict': self.target_q_net.state_dict(),
        }
        if self.use_double:
            checkpoint.update({
                'q_net2_state_dict': self.q_net2.state_dict(),
                'target_q_net2_state_dict': self.target_q_net2.state_dict(),
            })
        checkpoint.update({
            'optimizer_state_dict': self.optimizer.state_dict(),
            'policy': self.policy.state_dict(),
            'optimizer_policy': self.optimizer_policy.state_dict(),
        })

        if self.policy_type != self.explorer_type:
            checkpoint.update({
                'explorer': self.explorer.state_dict(),
                'optimizer_explorer': self.optimizer_explorer.state_dict(),
            })
        if self.autotune_ent_coeff:
            checkpoint.update({
                'log_coeff_entropy': self.log_coeff_entropy,
                'optimizer_entropy': self.optimizer_entropy.state_dict(),
            })
        if self.autotune_kl_coeff:
            checkpoint.update({
                'log_coeff_kl': self.log_coeff_kl,
                'optimizer_kl': self.optimizer_kl.state_dict(),
            })

        return checkpoint


    def set_state_dict(self, checkpoint: dict):
        """Restore the training state produced by ``get_state_dict``.

        Networks and optimisers are restored through ``load_state_dict``.
        The learnable log-coefficients are restored by copying the stored
        values *into* the existing tensors.  Rebinding the attributes to the
        checkpoint tensors would leave the optimisers updating the old
        tensors, so a resumed run would never change the coefficients.
        """
        self.buffer = checkpoint['buffer']

        self.q_net.load_state_dict(checkpoint['q_net_state_dict'])
        self.target_q_net.load_state_dict(checkpoint['target_q_net_state_dict'])
        if self.use_double:
            self.q_net2.load_state_dict(checkpoint['q_net2_state_dict'])
            self.target_q_net2.load_state_dict(checkpoint['target_q_net2_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        self.policy.load_state_dict(checkpoint['policy'])
        self.optimizer_policy.load_state_dict(checkpoint['optimizer_policy'])
        if self.policy_type != self.explorer_type:
            self.explorer.load_state_dict(checkpoint['explorer'])
            self.optimizer_explorer.load_state_dict(checkpoint['optimizer_explorer'])
        if self.autotune_ent_coeff:
            # presumably optimizer_entropy holds this exact tensor, as
            # optimizer_kl does for log_coeff_kl -- hence the in-place copy
            with torch.no_grad():
                self.log_coeff_entropy.copy_(checkpoint['log_coeff_entropy'])
            self.optimizer_entropy.load_state_dict(checkpoint['optimizer_entropy'])
        if self.autotune_kl_coeff:
            # optimizer_kl was built around this tensor object in __init__;
            # copy_ also moves the values to the tensor's device if needed
            with torch.no_grad():
                self.log_coeff_kl.copy_(checkpoint['log_coeff_kl'])
            self.optimizer_kl.load_state_dict(checkpoint['optimizer_kl'])
