# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import *

import sys
if sys.version_info < (3, 8):
    from typing_extensions import Protocol

import abc
import attrs
import itertools
import contextlib
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.kl import kl_divergence

from ..networks import ActivationKind, ValueModel, ActorModel, PostModel, EliteModel
from ..base import AgentBase
from ..utils import lambda_return, n_step_return
from .base import OptimizerCtorCallable, BaseLearning

from ...memory import ExperienceReplay


def value_model_parser(dense_activation_fn, hidden_size, *,
                       world_model: AgentBase):
    r"""
    Build a :class:`ValueModel` without an action input, sized to ``world_model``'s latent features.

    :param dense_activation_fn: activation forwarded to :class:`ValueModel`.
    :param hidden_size: hidden layer width forwarded to :class:`ValueModel`.
    :param world_model: provides ``transition_model.x_belief_size`` / ``x_state_size``.
    :returns: a freshly constructed :class:`ValueModel` with ``action_size=None``.
    """
    transition = world_model.transition_model
    belief_size, state_size = transition.x_belief_size, transition.x_state_size
    return ValueModel(belief_size, state_size, hidden_size, dense_activation_fn,
                      action_size=None)


@attrs.define(kw_only=True, auto_attribs=True)
class ValueModelConfig:
    r"""
    Config describing how to build a :class:`ValueModel` (no action input) via :func:`value_model_parser`.

    ``_target_`` / ``_partial_`` appear to follow Hydra-style ``instantiate`` conventions
    (presumably; confirm against the config loader). Instantiating this config is expected
    to yield a callable that still needs ``world_model``; see :class:`InstantiatedT`.
    """
    # Dotted path of the factory function; the attrs Factory evaluates the lambda per instance.
    _target_: str = attrs.Factory(lambda: f"{value_model_parser.__module__}.{value_model_parser.__qualname__}")
    # Build a partial: `world_model` is presumably supplied later by the caller.
    _partial_: bool = True

    class InstantiatedT(Protocol):  # for typing
        def __call__(self, *, world_model: AgentBase) -> ValueModel: ...

    dense_activation_fn: ActivationKind = ActivationKind.elu  # activation forwarded to ValueModel
    hidden_size: int = attrs.field(default=400, validator=attrs.validators.gt(0))  # hidden_layer size



def elite_model_parser(dense_activation_fn, hidden_size, *,
                       world_model: AgentBase):
    r"""
    Build an :class:`EliteModel` sized to ``world_model``'s latent features and action space.

    :param dense_activation_fn: activation forwarded to :class:`EliteModel`.
    :param hidden_size: hidden layer width forwarded to :class:`EliteModel`.
    :param world_model: provides ``transition_model`` sizes and ``actor_model.action_size``.
    :returns: a freshly constructed :class:`EliteModel`.
    """
    transition = world_model.transition_model
    n_actions = world_model.actor_model.action_size
    return EliteModel(transition.x_belief_size, transition.x_state_size, hidden_size,
                      dense_activation_fn, action_size=n_actions)


@attrs.define(kw_only=True, auto_attribs=True)
class EliteModelConfig:
    r"""
    Config describing how to build an :class:`EliteModel` via :func:`elite_model_parser`.

    ``_target_`` / ``_partial_`` appear to follow Hydra-style ``instantiate`` conventions
    (presumably; confirm against the config loader). Instantiating this config is expected
    to yield a callable that still needs ``world_model``; see :class:`InstantiatedT`.
    """
    # Dotted path of the factory function; the attrs Factory evaluates the lambda per instance.
    _target_: str = attrs.Factory(lambda: f"{elite_model_parser.__module__}.{elite_model_parser.__qualname__}")
    # Build a partial: `world_model` is presumably supplied later by the caller.
    _partial_: bool = True

    class InstantiatedT(Protocol):  # for typing
        def __call__(self, *, world_model: AgentBase) -> EliteModel: ...

    dense_activation_fn: ActivationKind = ActivationKind.elu  # activation forwarded to EliteModel
    # Note: smaller default width (100) than the value/Q/post model configs (400).
    hidden_size: int = attrs.field(default=100, validator=attrs.validators.gt(0))  # hidden_layer size

def q_model_parser(dense_activation_fn, hidden_size, *,
                   world_model: AgentBase):
    r"""
    Build a :class:`ValueModel` that also takes an action input (a Q-style model).

    :param dense_activation_fn: activation forwarded to :class:`ValueModel`.
    :param hidden_size: hidden layer width forwarded to :class:`ValueModel`.
    :param world_model: provides ``transition_model`` sizes and ``actor_model.action_size``.
    :returns: a freshly constructed :class:`ValueModel` with ``action_size`` set.
    """
    transition = world_model.transition_model
    n_actions = world_model.actor_model.action_size
    return ValueModel(transition.x_belief_size, transition.x_state_size, hidden_size,
                      dense_activation_fn, action_size=n_actions)


@attrs.define(kw_only=True, auto_attribs=True)
class QModelConfig:
    r"""
    Config describing how to build an action-conditioned :class:`ValueModel` via :func:`q_model_parser`.

    ``_target_`` / ``_partial_`` appear to follow Hydra-style ``instantiate`` conventions
    (presumably; confirm against the config loader). Instantiating this config is expected
    to yield a callable that still needs ``world_model``; see :class:`InstantiatedT`.
    """
    # Dotted path of the factory function; the attrs Factory evaluates the lambda per instance.
    _target_: str = attrs.Factory(lambda: f"{q_model_parser.__module__}.{q_model_parser.__qualname__}")
    # Build a partial: `world_model` is presumably supplied later by the caller.
    _partial_: bool = True

    class InstantiatedT(Protocol):  # for typing
        def __call__(self, *, world_model: AgentBase) -> ValueModel: ...

    dense_activation_fn: ActivationKind = ActivationKind.elu  # activation forwarded to ValueModel
    hidden_size: int = attrs.field(default=400, validator=attrs.validators.gt(0))  # hidden_layer size



def post_model_parser(dense_activation_fn, hidden_size, *,
                      world_model: AgentBase):
    r"""
    Build a :class:`PostModel` sized to ``world_model``'s latent features and action space.

    :param dense_activation_fn: activation forwarded to :class:`PostModel`.
    :param hidden_size: hidden layer width forwarded to :class:`PostModel`.
    :param world_model: provides ``transition_model`` sizes and ``actor_model.action_size``.
    :returns: a freshly constructed :class:`PostModel`.
    """
    transition = world_model.transition_model
    n_actions = world_model.actor_model.action_size
    # Positional order differs from ValueModel/EliteModel: the action size precedes the activation.
    return PostModel(transition.x_belief_size, transition.x_state_size, hidden_size,
                     n_actions, dense_activation_fn)


@attrs.define(kw_only=True, auto_attribs=True)
class PostModelConfig:
    r"""
    Config describing how to build a :class:`PostModel` via :func:`post_model_parser`.

    ``_target_`` / ``_partial_`` appear to follow Hydra-style ``instantiate`` conventions
    (presumably; confirm against the config loader). Instantiating this config is expected
    to yield a callable that still needs ``world_model``; see :class:`InstantiatedT`.
    """
    # Dotted path of the factory function; the attrs Factory evaluates the lambda per instance.
    _target_: str = attrs.Factory(lambda: f"{post_model_parser.__module__}.{post_model_parser.__qualname__}")
    # Build a partial: `world_model` is presumably supplied later by the caller.
    _partial_: bool = True

    class InstantiatedT(Protocol):  # for typing
        # `post_model_parser` constructs a PostModel, so that is what the partial returns.
        def __call__(self, *, world_model: AgentBase) -> PostModel: ...

    dense_activation_fn: ActivationKind = ActivationKind.elu  # activation forwarded to PostModel
    hidden_size: int = attrs.field(default=400, validator=attrs.validators.gt(0))  # hidden_layer size


    
class BasePolicyLearning(BaseLearning):
    r"""
    Abstract base for algorithms that train a world model's policy.

    Construction registers two optimizers:

    * ``'post-actor'`` over ``world_model.actor_model`` and ``world_model.post_model`` parameters;
    * ``'value'`` over ``world_model.value_model`` parameters.

    Subclasses implement :meth:`train_step`.
    """

    @attrs.define(kw_only=True, auto_attribs=True)
    class Config:
        _kind_: ClassVar[str]
        _target_: str
        _partial_: bool = True

        class InstantiatedT(Protocol):  # for typing
            def __call__(self, *, world_model: AgentBase, optimizer_ctor: OptimizerCtorCallable,
                         actor_grad_clip_norm: Optional[float] = None) -> 'BasePolicyLearning': ...

        discount: float = attrs.field(default=0.99, validator=[attrs.validators.gt(0), attrs.validators.lt(1)])
        actor_lr: float = attrs.field(default=3e-4, validator=attrs.validators.gt(0))
        actor_grad_clip_norm: Optional[float] = attrs.field(default=100, validator=attrs.validators.optional(attrs.validators.gt(0)))
        value_grad_clip_norm: Optional[float] = attrs.field(default=100, validator=attrs.validators.optional(attrs.validators.gt(0)))

    discount: float

    def __init__(self, discount: float, actor_lr: float, value_lr: float, *,
                 world_model: AgentBase, optimizer_ctor: OptimizerCtorCallable,
                 actor_grad_clip_norm: Optional[float] = None,
                 value_grad_clip_norm: Optional[float] = None):
        # NOTE(review): `value_lr` is required here but is not a field of `Config` above, and
        # DirectProbabilisticControl / SoftActorCritic call this initializer without it; those
        # subclasses also step an 'actor' optimizer that is not registered here. Confirm the
        # intended contract before relying on those subclasses.
        super().__init__(optimizer_ctor=optimizer_ctor, device=world_model.device)
        self.discount = discount

        # Actor and posterior model share a single optimizer.
        joint_params = [
            *world_model.actor_model.parameters(),
            *world_model.post_model.parameters(),
        ]
        self.add_optimizer('post-actor', parameters=joint_params,
                           lr=actor_lr, grad_clip_norm=actor_grad_clip_norm)

        critic_params = world_model.value_model.parameters()
        self.add_optimizer('value', parameters=critic_params,
                           lr=value_lr, grad_clip_norm=value_grad_clip_norm)

    @abc.abstractmethod
    def train_step(self, data: ExperienceReplay.Data, train_out: Optional[AgentBase.TrainOutput],
                   world_model: AgentBase) -> Dict[str, torch.Tensor]:
        pass


def zip_strict(*iterables: Iterable) -> Iterable:
    r"""
    ``zip()`` function but enforces that iterables are of equal length.
    Raises ``ValueError`` if iterables not of equal length.
    Code inspired by Stackoverflow answer for question #32954486.

    Padding is detected by *identity* against a private sentinel, so element types
    with unusual ``__eq__`` (tensors, numpy arrays, objects comparing equal to
    everything) are never compared and cannot trigger false positives or errors.

    :param \*iterables: iterables to ``zip()``
    :returns: generator of tuples, one element from each iterable.
    :raises ValueError: as soon as one iterable is exhausted before the others.
    """
    # A fresh object can never appear in user data, so it safely marks padding
    # even when the iterables contain None.
    sentinel = object()
    for combo in itertools.zip_longest(*iterables, fillvalue=sentinel):
        # `sentinel in combo` would call each element's __eq__ (elementwise for arrays,
        # which then fails on truth-testing); an identity check avoids that entirely.
        if any(item is sentinel for item in combo):
            raise ValueError("Iterables have different lengths")
        yield combo


def polyak_update(
    params: Iterable[torch.nn.Parameter],
    target_params: Iterable[torch.nn.Parameter],
    tau: float,
) -> None:
    r"""
    https://github.com/DLR-RM/stable-baselines3/blob/e24147390d2ce3b39cafc954e079d693a1971330/stable_baselines3/common/utils.py#L410

    Soft (Polyak) update of ``target_params`` towards ``params``, performed in place:

        target <- (1 - tau) * target + tau * param

    ``tau=1`` copies ``params`` into the targets; ``tau=0`` leaves them unchanged.
    The update runs under ``torch.no_grad()`` and writes into the existing storage,
    so no intermediate tensors or autograd graph are created.
    See https://github.com/DLR-RM/stable-baselines3/issues/93

    :param params: parameters to use to update the target params
    :param target_params: parameters to update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
    :raises ValueError: if the two parameter iterables differ in length.
    """
    retain = 1 - tau
    with torch.no_grad():
        # zip_strict (unlike zip) raises if the two parameter lists do not line up.
        for online, target in zip_strict(params, target_params):
            target.data.mul_(retain).add_(online.data, alpha=tau)


class RLAlgorithmInput(NamedTuple):
    r"""
    A batch of one-step transitions consumed by ``GeneralRLLearning.rl_algorithm_step``.

    ``nonterminal=None`` means no termination mask is applied: every transition is
    bootstrapped from its next observation.
    """
    observation: torch.Tensor  # features of the current step
    action: torch.Tensor  # action paired with the transition
    reward: torch.Tensor  # reward paired with the transition
    next_observation: torch.Tensor  # features of the following step
    nonterminal: Optional[torch.Tensor]  # optional bootstrap mask; None disables masking


class GeneralRLLearning(BasePolicyLearning):
    r"""
    Base class for model-free RL algorithms that learn from world-model latent features.

    :meth:`train_step` turns a training batch into flattened one-step transitions with
    :meth:`create_input` and passes them, together with ``world_model.actor_model``, to the
    algorithm-specific :meth:`rl_algorithm_step`.
    """

    @attrs.define(kw_only=True, auto_attribs=True)
    class Config(BasePolicyLearning.Config):
        actor_lr: float = attrs.field(default=3e-4, validator=attrs.validators.gt(0))
        # Unlike the base config, actor gradient clipping is disabled by default.
        actor_grad_clip_norm: Optional[float] = attrs.field(default=None, validator=attrs.validators.optional(attrs.validators.gt(0)))

    def create_input(self, data: ExperienceReplay.Data, train_out: Optional[AgentBase.TrainOutput]) -> RLAlgorithmInput:
        r"""
        Convert an `s`-space trajectory into flattened `x`-space transitions.

        Everything is computed under ``torch.no_grad()`` (and latent features are detached),
        so the returned tensors carry no gradient back into the world model.

        :param data: replay batch; ``data.action`` supplies the transition actions.
        :param train_out: world-model training output (must not be ``None``); supplies the
            posterior latent states and the denoised reward predictions.
        :returns: :class:`RLAlgorithmInput` with the two leading dims flattened into one
            (presumably time x batch; confirm against the replay layout).
        :raises NotImplementedError: if ``data.next_observation_nonfirststep`` is not ``None``.
        """
        assert train_out is not None
        with torch.no_grad():
            posterior_latent_state = train_out.transition_output.posterior_latent_state
            posterior_latent_state = posterior_latent_state.new_emptydim(y=True, z=True).detach()
            full_obs = posterior_latent_state.x_feature.detach()
            full_rew = train_out.reward_denoised_prediction_mean

            # Transition t pairs features at t with features at t+1 and uses the action/reward
            # stored at index t+1 (presumably the action taken at t and the reward received on
            # reaching t+1; confirm against the replay layout).
            obs = full_obs[:-1].flatten(0, 1)
            next_obs = full_obs[1:].flatten(0, 1)
            act = data.action[1:].flatten(0, 1)
            rew = full_rew[1:].flatten(0, 1)

            # For `nonterminal`, we assume that the env only ends by timing out. So any transition is
            # non-terminal really, except for the potential (a_T, s'_next_traj_0) in DenoisedMDP data.
            if data.next_observation_nonfirststep is not None:
                raise NotImplementedError
            nonterminal = None

            return RLAlgorithmInput(
                observation=obs,
                action=act,
                reward=rew,
                next_observation=next_obs,
                nonterminal=nonterminal,
            )

    @abc.abstractmethod
    def rl_algorithm_step(self, input: RLAlgorithmInput, actor_model: ActorModel) -> Dict[str, torch.Tensor]:
        r"""Run one update of the concrete algorithm on ``input``; return named losses."""
        pass

    def train_step(self, data: ExperienceReplay.Data, train_out: Optional[AgentBase.TrainOutput],
                   world_model: AgentBase) -> Dict[str, torch.Tensor]:
        r"""Build transitions from ``data``/``train_out`` and delegate to :meth:`rl_algorithm_step`."""
        return self.rl_algorithm_step(
            self.create_input(data, train_out),
            world_model.actor_model,
        )




class DirectProbabilisticControl(GeneralRLLearning):
    r"""
    Quantile-based policy learning on latent features (config kind ``'dpc'``).

    Each :meth:`rl_algorithm_step` performs three updates:

    1. ``'vae'``: trains ``U_model`` (a distribution over the bootstrapped return given
       ``(x, a)``) and ``p_model`` (an action distribution conditioned on ``(x, U)``). It
       minimizes ``KL(p_model || policy) - log U_model(target_U | x, z)``, where ``z`` is
       sampled from ``p_model``.
    2. ``'elite'``: trains ``e_model`` to output actions that maximize the
       ``quantile_parameter``-quantile of ``U_model``.
    3. ``'actor'``: minimizes ``KL(policy || p_model(x, elite quantile))``.

    A Polyak-averaged copy ``U_model_target`` provides the bootstrap target.
    """

    @attrs.define(kw_only=True, auto_attribs=True)
    class Config(GeneralRLLearning.Config):
        _kind_: ClassVar[str] = 'dpc'
        _target_: str = attrs.Factory(lambda: f"{DirectProbabilisticControl.__module__}.{DirectProbabilisticControl.__qualname__}")

        target_update_interval: int = attrs.field(default=1, validator=attrs.validators.gt(0))
        tau: float = attrs.field(default=0.005, validator=[attrs.validators.gt(0), attrs.validators.lt(1)])
        quantile_parameter: float = attrs.field(default=0.5, validator=[attrs.validators.gt(0), attrs.validators.lt(1)])
        num_samples: int = attrs.field(default=1, validator=attrs.validators.gt(0))
        vae_lr: float = attrs.field(default=8e-5, validator=attrs.validators.gt(0))
        elite_lr: float = attrs.field(default=3e-4, validator=attrs.validators.gt(0))
        q: QModelConfig = attrs.Factory(QModelConfig)
        p: PostModelConfig = attrs.Factory(PostModelConfig)
        e: EliteModelConfig = attrs.Factory(EliteModelConfig)

    plan_horizon_discount: torch.Tensor

    def __init__(self, discount: float,
                 actor_lr: float, actor_grad_clip_norm: Optional[float],
                 target_update_interval: int, tau: float,
                 q: QModelConfig.InstantiatedT, quantile_parameter: float, num_samples: int,
                 p: PostModelConfig.InstantiatedT, vae_lr: float,
                 e: EliteModelConfig.InstantiatedT, elite_lr: float,
                 *,
                 world_model: AgentBase, optimizer_ctor: OptimizerCtorCallable):
        # NOTE(review): BasePolicyLearning.__init__ in this file also requires `value_lr`, which
        # is not passed here; confirm the intended base-class signature.
        super().__init__(
            discount=discount,
            actor_lr=actor_lr, actor_grad_clip_norm=actor_grad_clip_norm,
            world_model=world_model, optimizer_ctor=optimizer_ctor)
        self.target_update_interval = target_update_interval
        self.tau = tau
        self.quantile_parameter = quantile_parameter
        self.num_samples = num_samples

        self.U_model: ValueModel = q(world_model=world_model).to(self.device)
        self.U_model_target: ValueModel = q(world_model=world_model).to(self.device)
        # Start the target as an exact copy of the online network (as SoftActorCritic does).
        # Otherwise the Polyak average begins from an unrelated random initialization.
        self.U_model_target.load_state_dict(self.U_model.state_dict())

        self.p_model: PostModel = p(world_model=world_model).to(self.device)

        self.e_model: EliteModel = e(world_model=world_model).to(self.device)

        self.add_optimizer(
            'vae',
            parameters=itertools.chain(self.U_model.parameters(), self.p_model.parameters()),
            lr=vae_lr)

        self.add_optimizer(
            'elite',
            parameters=self.e_model.parameters(),
            lr=elite_lr)

        self.num_steps = 0

    def trainable_parameters(self):
        yield from self.U_model.parameters()
        yield from self.p_model.parameters()
        yield from self.e_model.parameters()

    def _quantile(self, dist: torch.distributions.Distribution) -> torch.Tensor:
        r"""
        Return ``mean + z * stddev`` of ``dist``, where ``z`` is the standard-normal
        ``quantile_parameter``-quantile. This is the exact quantile when ``dist`` is Gaussian.
        """
        # Standard-normal inverse CDF: z = sqrt(2) * erfinv(2q - 1).
        z_score = math.sqrt(2.0) * torch.erfinv(
            torch.tensor(2.0 * self.quantile_parameter - 1.0, dtype=torch.float64)).item()
        return dist.mean + dist.stddev * z_score

    def rl_algorithm_step(self, input: RLAlgorithmInput, actor_model: ActorModel) -> Dict[str, torch.Tensor]:
        r"""
        Run the 'vae', 'elite' and 'actor' updates on ``input``, then Polyak-update the target.

        :returns: dict with keys ``'vae_loss'``, ``'elite_loss'`` and ``'actor_loss'``.
        """
        losses: Dict[str, torch.Tensor] = {}

        # Bootstrapped target: r + discount * U'(x', a'), with a' ~ pi(.|x').
        with torch.no_grad():
            next_act_distn = actor_model.get_action_distn(input.next_observation)
            next_on_pi_act: torch.Tensor = next_act_distn.rsample()

            next_U = self.U_model_target(input.next_observation, next_on_pi_act).rsample()

            if input.nonterminal is None:
                target_U = input.reward.unsqueeze(-1) + self.discount * next_U
            else:
                target_U = input.reward.unsqueeze(-1) + input.nonterminal * self.discount * next_U
            assert next_U.size() == target_U.size()

        # --- 'vae' update: U_model and p_model -------------------------------------------
        policy = actor_model.get_action_distn(input.observation)
        posterior = self.p_model.get_action_distn(input.observation, target_U)
        kl_loss = kl_divergence(posterior, policy).sum(-1).mean()

        z = posterior.rsample((self.num_samples, ))
        # Tile observations to match the leading sample dimension of `z`
        # (assumes `input.observation` is 2-D, as produced by create_input's flatten).
        repeated_observation = input.observation.unsqueeze(0).repeat(self.num_samples, 1, 1)

        U_dist = self.U_model(repeated_observation, z)

        expanded_target_U = target_U.unsqueeze(0).repeat(self.num_samples, 1, 1)
        U_loss = U_dist.log_prob(expanded_target_U).mean()

        vae_loss = kl_loss - U_loss

        with self.optimizers['vae'].update_context():
            vae_loss.backward()
        losses['vae_loss'] = vae_loss

        # --- 'elite' update: rank actions by the q-quantile of U --------------------------
        new_action = self.e_model(input.observation)
        U_dist = self.U_model(input.observation, new_action)
        Uq = self._quantile(U_dist)
        Q_value = U_dist.mean.detach()

        elite_loss = -Uq.mean()

        with self.optimizers['elite'].update_context():
            elite_loss.backward()
        losses['elite_loss'] = elite_loss

        # --- 'actor' update: project the policy onto the elite-conditioned posterior -------
        with torch.no_grad():
            elite_action = self.e_model(input.observation)
            elite_U_dist = self.U_model(input.observation, elite_action)
            elite_Uq = self._quantile(elite_U_dist)
            posterior = self.p_model.get_action_distn(input.observation, elite_Uq)

        policy = actor_model.get_action_distn(input.observation)
        actor_loss = kl_divergence(policy, posterior).sum(-1).mean()

        # NOTE(review): BasePolicyLearning registers 'post-actor', not 'actor'; confirm that an
        # 'actor' optimizer exists (e.g. registered by BaseLearning) before relying on this.
        with self.optimizers['actor'].update_context():
            actor_loss.backward()
        losses['actor_loss'] = actor_loss

        # Update target networks
        self.num_steps += 1
        if self.num_steps % self.target_update_interval == 0:
            polyak_update(self.U_model.parameters(), self.U_model_target.parameters(), self.tau)

        if self.num_steps % 100 == 0:
            print(f"Q value {Q_value.mean()}")
            print(f"vae_loss {vae_loss}")
            print(f"elite_loss {elite_loss}")
            print(f"actor_loss {actor_loss}")

        return losses

class SoftActorCritic(GeneralRLLearning):
    r"""
    https://github.com/DLR-RM/stable-baselines3/blob/e24147390d2ce3b39cafc954e079d693a1971330/stable_baselines3/sac/sac.py

    Soft Actor-Critic on latent features. It uses twin Q-networks with Polyak-averaged
    targets and automatic entropy tuning, with ``target_entropy = -action_size``.
    """

    @attrs.define(kw_only=True, auto_attribs=True)
    class Config(GeneralRLLearning.Config):
        _kind_: ClassVar[str] = 'sac'
        _target_: str = attrs.Factory(lambda: f"{SoftActorCritic.__module__}.{SoftActorCritic.__qualname__}")

        target_update_interval: int = attrs.field(default=1, validator=attrs.validators.gt(0))
        tau: float = attrs.field(default=0.005, validator=[attrs.validators.gt(0), attrs.validators.lt(1)])
        q: QModelConfig = attrs.Factory(QModelConfig)
        q_ent_coef_lr: float = attrs.field(default=3e-4, validator=attrs.validators.gt(0))
        # use ent_coef = auto, target_entropy = auto

    plan_horizon_discount: torch.Tensor

    def __init__(self, discount: float,
                 actor_lr: float, actor_grad_clip_norm: Optional[float],
                 target_update_interval: int, tau: float,
                 q: QModelConfig.InstantiatedT, q_ent_coef_lr: float,
                 *,
                 world_model: AgentBase, optimizer_ctor: OptimizerCtorCallable):
        # NOTE(review): BasePolicyLearning.__init__ in this file also requires `value_lr`, which
        # is not passed here; confirm the intended base-class signature.
        super().__init__(
            discount=discount,
            actor_lr=actor_lr, actor_grad_clip_norm=actor_grad_clip_norm,
            world_model=world_model, optimizer_ctor=optimizer_ctor)
        self.target_update_interval = target_update_interval
        self.tau = tau

        self.q_model_0: ValueModel = q(world_model=world_model).to(self.device)
        self.q_model_1: ValueModel = q(world_model=world_model).to(self.device)

        # Targets start as exact copies of the online critics.
        self.q_model_0_target: ValueModel = q(world_model=world_model).to(self.device)
        self.q_model_0_target.load_state_dict(self.q_model_0.state_dict())
        self.q_model_1_target: ValueModel = q(world_model=world_model).to(self.device)
        self.q_model_1_target.load_state_dict(self.q_model_1.state_dict())

        self.add_optimizer(
            'q',
            parameters=itertools.chain(self.q_model_0.parameters(), self.q_model_1.parameters()),
            lr=q_ent_coef_lr)

        # Entropy coefficient is learned in log-space; target entropy is -|A|.
        self.log_ent_coef = nn.Parameter(torch.zeros((), device=self.device, requires_grad=True))
        self.target_entropy = -float(world_model.actor_model.action_size)

        self.add_optimizer(
            'ent_coef',
            parameters=[self.log_ent_coef],
            lr=q_ent_coef_lr)

        self.num_steps = 0

    def trainable_parameters(self):
        yield from self.q_model_0.parameters()
        yield from self.q_model_1.parameters()
        yield self.log_ent_coef

    def rl_algorithm_step(self, input: RLAlgorithmInput, actor_model: ActorModel) -> Dict[str, torch.Tensor]:
        r"""
        One SAC update. It runs the critic loss, the actor loss and the entropy-coefficient loss
        in that order. Every ``target_update_interval`` steps it also Polyak-updates the target
        critics.

        :returns: dict with keys ``'q0_loss'``, ``'q1_loss'``, ``'actor_loss'`` and ``'ent_coef_loss'``.
        """
        act_distn = actor_model.get_action_distn(input.observation)
        on_pi_act: torch.Tensor = act_distn.rsample()
        on_pi_act_log_prob: torch.Tensor = act_distn.log_prob(on_pi_act)

        losses: Dict[str, torch.Tensor] = {}

        # Entropy coefficient, treated as a constant in the critic and actor losses.
        ent_coef = self.log_ent_coef.detach().exp()

        # q loss
        with torch.no_grad():
            # Soft Bellman target: r + discount * (min_i Q_i'(x', a') - alpha * log pi(a'|x')).
            next_act_distn = actor_model.get_action_distn(input.next_observation)
            next_on_pi_act: torch.Tensor = next_act_distn.rsample()
            # The entropy term must score the action sampled at x' (a'), not the current action.
            next_on_pi_act_log_prob: torch.Tensor = next_act_distn.log_prob(next_on_pi_act)

            next_q = torch.min(
                self.q_model_0_target(input.next_observation, next_on_pi_act),
                self.q_model_1_target(input.next_observation, next_on_pi_act),
            )
            assert next_on_pi_act_log_prob.ndim == next_q.ndim == 1
            next_q = next_q - ent_coef * next_on_pi_act_log_prob
            if input.nonterminal is None:
                target_q = input.reward + self.discount * next_q
            else:
                target_q = input.reward + input.nonterminal * self.discount * next_q
            assert target_q.ndim == 1

        data_q0 = self.q_model_0(input.observation, input.action)
        data_q1 = self.q_model_1(input.observation, input.action)
        q0_loss = 0.5 * F.mse_loss(data_q0, target_q)
        q1_loss = 0.5 * F.mse_loss(data_q1, target_q)
        with self.optimizers['q'].update_context():
            q0_loss.backward()
            q1_loss.backward()
        losses['q0_loss'] = q0_loss
        losses['q1_loss'] = q1_loss

        # actor loss: alpha * log pi(a|x) - min_i Q_i(x, a), with a ~ pi reparameterized.
        current_q = torch.min(
            self.q_model_0(input.observation, on_pi_act),
            self.q_model_1(input.observation, on_pi_act),
        )
        actor_loss = (ent_coef * on_pi_act_log_prob - current_q)
        assert actor_loss.ndim == 1
        actor_loss = actor_loss.mean()
        # NOTE(review): BasePolicyLearning registers 'post-actor', not 'actor'; confirm that an
        # 'actor' optimizer exists (e.g. registered by BaseLearning) before relying on this.
        with self.optimizers['actor'].update_context():
            actor_loss.backward()
        losses['actor_loss'] = actor_loss

        # ent_coef loss: drive the policy entropy towards target_entropy.
        with self.optimizers['ent_coef'].update_context():
            ent_coef_loss = -(self.log_ent_coef * (on_pi_act_log_prob.detach() + self.target_entropy)).mean()
            ent_coef_loss.backward()
        losses['ent_coef_loss'] = ent_coef_loss

        # Update target networks
        self.num_steps += 1
        if self.num_steps % self.target_update_interval == 0:
            polyak_update(self.q_model_0.parameters(), self.q_model_0_target.parameters(), self.tau)
            polyak_update(self.q_model_1.parameters(), self.q_model_1_target.parameters(), self.tau)

        return losses


class DynamicsBackpropagateActorCritic(BasePolicyLearning):
    r"""
    Dreamer-style actor-critic that learns from imagined world-model rollouts.

    Each :meth:`train_step` imagines ``planning_horizon`` steps from the posterior latent
    state (calling ``imagine_ahead_noiseless`` with ``freeze_latent_model=True``). It then:

    * steps the ``'post-actor'`` optimizer on ``rpg_loss + reg_loss + kl_loss``, with
      gradients restricted to actor and posterior-model parameters;
    * steps the ``'value'`` optimizer on the negative log-likelihood of the detached
      first-step value target, with gradients restricted to value-model parameters.
    """

    @attrs.define(kw_only=True, auto_attribs=True)
    class Config(BasePolicyLearning.Config):
        _kind_: ClassVar[str] = 'dynamics_backprop'
        _target_: str = attrs.Factory(lambda: f"{DynamicsBackpropagateActorCritic.__module__}.{DynamicsBackpropagateActorCritic.__qualname__}")
        planning_horizon: int = attrs.field(default=15, validator=attrs.validators.gt(0))
        discount: float = attrs.field(default=0.99, validator=[attrs.validators.gt(0), attrs.validators.lt(1)])
        actor_lr: float = attrs.field(default=8e-5, validator=attrs.validators.gt(0))
        actor_grad_clip_norm: Optional[float] = attrs.field(default=100, validator=attrs.validators.optional(attrs.validators.gt(0)))
        value_lr: float = attrs.field(default=8e-5, validator=attrs.validators.gt(0))
        value_grad_clip_norm: Optional[float] = attrs.field(default=100, validator=attrs.validators.optional(attrs.validators.gt(0)))

    def __init__(self, discount: float, actor_lr: float, actor_grad_clip_norm: Optional[float],
                 value_lr: float, value_grad_clip_norm: Optional[float],
                 planning_horizon: int, *,
                 world_model: AgentBase, optimizer_ctor: OptimizerCtorCallable):
        r"""
        :param planning_horizon: number of imagined steps per :meth:`train_step`.

        The remaining arguments are forwarded to :class:`BasePolicyLearning`, which registers
        the ``'post-actor'`` and ``'value'`` optimizers. Clip norms may be ``None``, matching
        the ``Optional[float]`` config fields.
        """
        super().__init__(discount=discount, actor_lr=actor_lr, value_lr=value_lr, actor_grad_clip_norm=actor_grad_clip_norm, value_grad_clip_norm=value_grad_clip_norm,
                         world_model=world_model, optimizer_ctor=optimizer_ctor)
        self.planning_horizon = planning_horizon

        self.num_steps = 0

    def train_step(self, data: ExperienceReplay.Data, train_out: Optional[AgentBase.TrainOutput],
                   world_model: AgentBase) -> Dict[str, torch.Tensor]:
        r"""
        Run one actor/posterior update and one value update on an imagined rollout.

        :param data: replay batch (not read directly by this method).
        :param train_out: world-model training output; must not be ``None``.
        :param world_model: supplies the imagination rollout and the actor/post/value models.
        :returns: dict with ``ap_loss`` (total actor/posterior loss), ``rpg_loss``,
            ``value_loss``, ``kl_loss`` and ``reg_loss``.
        """
        self.num_steps += 1
        assert train_out is not None
        # NOTE(review): this reads `train_out.posterior_latent_state`, whereas
        # GeneralRLLearning.create_input reads `train_out.transition_output.posterior_latent_state`;
        # confirm which attribute AgentBase.TrainOutput actually exposes.
        imagine_out = world_model.imagine_ahead_noiseless(
            previous_latent_state=train_out.posterior_latent_state,
            freeze_latent_model=True,
            planning_horizon=self.planning_horizon,
        )

        # Value distributions at the first and last imagined (latent state, action) pairs.
        value_first_dist = world_model.value_model(imagine_out.latent_states[0], imagine_out.actions[0])
        value_last_dist = world_model.value_model(imagine_out.latent_states[-1], imagine_out.actions[-1])
        value_prediction_last = value_last_dist.rsample()

        # NOTE(review): the exact meaning of n_step_return's three outputs is defined in ..utils;
        # index 0 of each is used below as the estimate for the first imagined step.
        return_prediction, value_target, stddev = n_step_return(
            imagine_out.reward_mean, value_last_dist.mean, value_last_dist.stddev, bootstrap=value_prediction_last, discount=self.discount
        )

        # KL( Normal(return_prediction[0], stddev[0]) || value_first_dist ), averaged.
        reg_loss = kl_divergence(torch.distributions.Normal(return_prediction[0], stddev[0]), value_first_dist).mean()

        # KL(posterior || policy) at the first imagined latent state.
        policy = world_model.actor_model.get_action_distn(imagine_out.latent_states[0])
        posterior = world_model.post_model.get_action_distn(imagine_out.latent_states[0])
        kl_loss = kl_divergence(posterior, policy).mean()

        # Maximize the first-step value target. It is not detached here, so its gradient
        # presumably reaches the actor through the imagined rollout.
        rpg_loss = -value_target[0].mean()

        loss = rpg_loss + reg_loss + kl_loss

        # `inputs=` restricts gradients to actor/post parameters, so the value model is not
        # updated by `loss`. retain_graph=True keeps the graph alive for value_loss below,
        # which shares `value_first_dist` with reg_loss.
        with self.optimizers['post-actor'].update_context():
            torch.autograd.backward(
                loss,
                retain_graph=True,
                inputs=list(world_model.actor_model.parameters()) + list(world_model.post_model.parameters()),
            )

        # Negative log-likelihood of the detached first-step target under the value model.
        value_loss: torch.Tensor = -value_first_dist.log_prob(value_target[0].detach()).mean()

        with self.optimizers['value'].update_context():
            torch.autograd.backward(
                value_loss,
                inputs=list(world_model.value_model.parameters()),
            )

        return dict(ap_loss=loss, rpg_loss=rpg_loss, value_loss=value_loss, kl_loss=kl_loss, reg_loss=reg_loss)
