from typing import Sequence, Optional, Dict

import torch
import torch.optim as optim
from torch.nn import functional as F
import torch.nn as nn

from .base_agent import Action, EnvAction, EnvObservation, IAgent
from .replay import OffPolicyMemory


class SAC:
    def __init__(
        self,
        batch_size: int,
        target_entropy: torch.Tensor,  # -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item() TODO: CORRECT with float
        discount: float,
        tau: float,
        reparametrized_sample: bool,
        device: torch.device = torch.device("cuda"),
        temperature_lr: float = 3e-4,  # TODO: discussion
        reward_scale: int = 1,
        max_grad_norm: float = 0.5,
        policy_frequency: int = 2,
    ):
        """Soft Actor Critic with two Q-networks and adjusted temperature.

        Args:
            batch_size (int): Batch size of the optimization process
            target_entropy (torch.Tensor): Target entropy for the adjusted temperature
            discount (float): Discount applied to the bootstrapped next-state value
            tau (float): Value of the soft update
            reparametrized_sample (bool): If Normal distribution is use you should set this value at True.
                NOTE(review): stored but never read; ``update`` always calls ``rsample()``.
            device (torch.device, optional): Device on which the log-temperature parameter is created.
                Defaults to torch.device("cuda").
            temperature_lr (float, optional): Learning rate for adjusted temperature. Defaults to 3e-4.
            reward_scale (int, optional): NOTE(review): accepted but currently unused. Defaults to 1.
            max_grad_norm (float, optional): Max norm used to clip the critic and actor gradients.
                Defaults to 0.5.
            policy_frequency (int, optional): The actor and the temperature are updated on every
                ``policy_frequency``-th call to ``update``, with ``policy_frequency`` gradient
                steps each time. Must be at least 1. Defaults to 2.

        Raises:
            ValueError: If ``policy_frequency`` is lower than 1.
        """
        # Fail early: 0 would raise ZeroDivisionError in ``update`` and a negative
        # value would leave the policy losses undefined.
        if policy_frequency < 1:
            raise ValueError(f"policy_frequency must be >= 1, got {policy_frequency}")

        self.batch_size = batch_size
        self.target_entropy = target_entropy
        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=temperature_lr, eps=1e-7)
        self.alpha = self.log_alpha.exp().item()
        self.discount = discount
        self.tau = tau
        self.reparametrized_sample = reparametrized_sample
        self.max_grad_norm = max_grad_norm
        self.policy_frequency = policy_frequency
        self.step = 1

    def update(
        self,
        replay: OffPolicyMemory,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_actor: optim.Optimizer,
        optimizer_critic: optim.Optimizer,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
    ) -> Dict[str, float]:
        """Update with one gradient step.

        The critic is updated on every call. The actor and the temperature are
        only updated when the internal step counter is a multiple of
        ``policy_frequency``.

        Args:
            replay (OffPolicyMemory): Replay buffer
            agent (IAgent[EnvObservation, EnvAction, Action]): Implemented agent (sac candidate)
            optimizer_actor (optim.Optimizer): Actor's optimizer
            optimizer_critic (optim.Optimizer): Critic's optimizer
            scheduler (Optional[optim.lr_scheduler._LRScheduler], optional): Learning rate scheduler.
                NOTE(review): currently unused, it is never stepped here. Defaults to None.

        Returns:
            Dict[str, float]: Information on loss
        """
        batch = replay.sample(self.batch_size)

        next_q_value = self._td_target(batch, agent)
        qf_loss, qf_grad_norm = self._update_critic(
            batch, agent, optimizer_critic, next_q_value
        )

        # Stay None (and are left out of the info dict) on non-policy steps.
        policy_loss = None
        actor_grad_norm = None
        alpha_loss = None
        if self.step % self.policy_frequency == 0:
            # Delayed update: compensate the skipped calls with as many steps.
            for _ in range(self.policy_frequency):
                policy_loss, actor_grad_norm = self._update_actor(
                    batch, agent, optimizer_actor
                )
                alpha_loss = self._update_temperature(batch, agent)

        # Add cond soft update
        agent.soft_update(tau=self.tau)  # soft update for q_networks

        info = self._update_info_loss(
            policy_loss=policy_loss,
            q_loss=qf_loss,
            alpha=self.alpha,
            alpha_loss=alpha_loss,
            actor_grad_norm=actor_grad_norm,
            critic_grad_norm=qf_grad_norm,
        )
        self.step += 1
        return info

    def _td_target(
        self, batch, agent: IAgent[EnvObservation, EnvAction, Action]
    ) -> torch.Tensor:
        """Return the soft Bellman target, computed without gradient tracking."""
        with torch.no_grad():
            next_mass = agent.action(observation=batch.next_state())
            next_state_actions = next_mass.sample()
            next_state_log_pi = next_mass.log_prob(action=next_state_actions)

            q_next_target = agent.target_q_value(
                observation=batch.next_state(), action=next_state_actions
            )

            # TODO: Real LOSS RAVI
            min_qf_next_target = (
                torch.min(*q_next_target) - self.alpha * next_state_log_pi
            )
            # (1 - done) removes the bootstrap term on terminal transitions.
            return batch.reward() + (1 - batch.done()) * self.discount * (
                min_qf_next_target
            )

    def _update_critic(
        self,
        batch,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_critic: optim.Optimizer,
        next_q_value: torch.Tensor,
    ) -> Sequence[torch.Tensor]:
        """Take one critic step; return (mean critic loss, pre-clip gradient norm)."""
        qf1_a_values, qf2_a_values = agent.q_value(
            observation=batch.state(), action=batch.action()
        )
        qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
        qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
        qf_loss = (qf1_loss + qf2_loss) / 2

        optimizer_critic.zero_grad()
        qf_loss.backward()
        qf_grad_norm = nn.utils.clip_grad_norm_(
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_critic.step()
        return qf_loss, qf_grad_norm

    def _update_actor(
        self,
        batch,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_actor: optim.Optimizer,
    ) -> Sequence[torch.Tensor]:
        """Take one actor step; return (policy loss, pre-clip gradient norm)."""
        mass = agent.action(observation=batch.state())
        # rsample keeps the sampled action differentiable w.r.t. the actor.
        pi = mass.rsample()
        log_pi = mass.log_prob(action=pi)
        qf1_pi, qf2_pi = agent.q_value(observation=batch.state(), action=pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        optimizer_actor.zero_grad()
        policy_loss.backward()
        actor_grad_norm = nn.utils.clip_grad_norm_(
            agent.parameters_actor(), self.max_grad_norm
        )
        optimizer_actor.step()
        return policy_loss, actor_grad_norm

    def _update_temperature(
        self, batch, agent: IAgent[EnvObservation, EnvAction, Action]
    ) -> torch.Tensor:
        """Take one temperature step, refresh ``self.alpha`` and return the loss."""
        # Fresh actions from the just-updated actor; only log_alpha gets gradients.
        with torch.no_grad():
            mass = agent.action(observation=batch.state())
            pi = mass.sample()
            log_pi = mass.log_prob(action=pi)

        alpha_loss = (-self.log_alpha * (log_pi + self.target_entropy)).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        self.alpha = self.log_alpha.exp().item()
        return alpha_loss

    @classmethod
    def _update_info_loss(
        cls,
        policy_loss: Optional[torch.Tensor],
        q_loss: torch.Tensor,
        alpha: float,
        alpha_loss: Optional[torch.Tensor],
        actor_grad_norm: Optional[torch.Tensor],
        critic_grad_norm: torch.Tensor,
    ) -> Dict[str, float]:
        """Convert the update statistics to plain floats.

        Optional entries that are None are left out of the returned dict.
        """
        info_dict = {}
        if policy_loss is not None:
            info_dict["policy_loss"] = policy_loss.item()

        info_dict["q_loss"] = q_loss.item()
        info_dict["alpha"] = alpha

        if alpha_loss is not None:
            info_dict["alpha_loss"] = alpha_loss.item()
        if actor_grad_norm is not None:
            info_dict["actor_grad_norm"] = actor_grad_norm.item()

        info_dict["critic_grad_norm"] = critic_grad_norm.item()
        return info_dict


def sac_candidate_agent(
    agent: IAgent, mock_observation: torch.Tensor, mock_action: torch.Tensor
):
    """Check that ``agent`` exposes the outputs the SAC update expects.

    Runs the actor, both critics and both target critics once on the mock
    inputs and asserts on the shape of what comes back.

    The checks are plain ``assert`` statements, so they are skipped when
    Python runs with ``-O``.

    Args:
        agent (IAgent): Agent to validate
        mock_observation (torch.Tensor): Observation batch fed to the actor and the critics
        mock_action (torch.Tensor): Action batch fed to the critics

    Raises:
        AssertionError: If an output has the wrong rank or trailing dimension, if the
            Q-values are not a sequence of two tensors, or if the actor or critic
            parameters are not set.
    """
    mass = agent.action(observation=mock_observation)

    action = mass.sample()

    log_probs = mass.log_prob(action=action)

    qf = agent.q_value(observation=mock_observation, action=mock_action)
    qf_target = agent.target_q_value(observation=mock_observation, action=mock_action)

    assert (
        action.dim() == 2
    ), "Max dimensions for actions is 2 (Batch_size, action number)"

    assert log_probs.dim() == 2, "Max dimensions for log probs is two (Batch_size, 1)"
    assert log_probs.size()[1] == 1, "Last dimension for log probs should equal to 1"

    assert isinstance(qf, Sequence), "Q functions should be an Sequence"
    assert len(qf) == 2, "The numbers of Q values should be equals 2"

    assert qf[0].dim() == 2, "Max dimensions for Q functions is two (Batch_size, 1)"
    assert qf[0].size()[1] == 1, "Last dimension for Q functions should equal to 1"

    assert qf[1].dim() == 2, "Max dimensions for Q functions 2 is two (Batch_size, 1)"
    assert qf[1].size()[1] == 1, "Last dimension for Q functions 2 should equal to 1"

    assert isinstance(qf_target, Sequence), "Q functions target should be an Sequence"
    assert len(qf_target) == 2, "The numbers of Q values target should be equals 2"

    assert (
        qf_target[0].dim() == 2
    ), "Max dimensions for Q function target 1 is two (Batch_size, 1)"
    assert (
        qf_target[0].size()[1] == 1
    ), "Last dimension for Q functions target 1 should equal to 1"

    assert (
        qf_target[1].dim() == 2
    ), "Max dimensions for Q function target 2 is two (Batch_size, 1)"
    assert (
        qf_target[1].size()[1] == 1
    ), "Last dimension for Q functions target 2 should equal to 1"

    # ``isinstance(...) is not None`` was always true (isinstance returns a bool),
    # so compare the returned parameters themselves against None.
    assert agent.parameters_actor() is not None, "Parameter actor should be set"

    assert agent.parameters_critic() is not None, "Parameter critic should be set"

    # TODO ADD good soft update agents


class SACMPC:
    def __init__(
        self,
        batch_size: int,
        target_entropy: torch.Tensor,  # -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item() TODO: CORRECT with float
        discount: float,
        tau: float,
        reparametrized_sample: bool,
        device: torch.device = torch.device("cuda"),
        temperature_lr: float = 3e-4,  # TODO: discussion
        reward_scale: int = 1,
        max_grad_norm: float = 0.5,
        policy_frequency: int = 2,
        mpc_weight: float = 0.1,
        mpc_time_step: int = 4,
    ):
        """Soft Actor Critic with two Q-networks, adjusted temperature and a next-state prediction loss.

        Args:
            batch_size (int): Batch size of the optimization process
            target_entropy (torch.Tensor): Target entropy for the adjusted temperature
            discount (float): Discount applied to the bootstrapped next-state value
            tau (float): Value of the soft update
            reparametrized_sample (bool): If Normal distribution is use you should set this value at True.
                NOTE(review): stored but never read; ``update`` always calls ``rsample()``.
            device (torch.device, optional): Device on which the log-temperature parameter is created.
                Defaults to torch.device("cuda").
            temperature_lr (float, optional): Learning rate for adjusted temperature. Defaults to 3e-4.
            reward_scale (int, optional): NOTE(review): accepted but currently unused. Defaults to 1.
            max_grad_norm (float, optional): Max norm used for every gradient clipping in ``update``.
                Defaults to 0.5.
            policy_frequency (int, optional): The actor and the temperature are updated on every
                ``policy_frequency``-th call to ``update``, with ``policy_frequency`` gradient
                steps each time. Must be at least 1. Defaults to 2.
            mpc_weight (float, optional): NOTE(review): stored but not read by ``update``. Defaults to 0.1.
            mpc_time_step (int, optional): NOTE(review): stored but not read by ``update``. Defaults to 4.

        Raises:
            ValueError: If ``policy_frequency`` is lower than 1.
        """
        # Fail early: 0 would raise ZeroDivisionError in ``update`` and a negative
        # value would leave the policy losses undefined.
        if policy_frequency < 1:
            raise ValueError(f"policy_frequency must be >= 1, got {policy_frequency}")

        self.batch_size = batch_size
        self.target_entropy = target_entropy
        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=temperature_lr, eps=1e-7)
        self.alpha = self.log_alpha.exp().item()
        self.discount = discount
        self.tau = tau
        self.reparametrized_sample = reparametrized_sample
        self.max_grad_norm = max_grad_norm
        self.policy_frequency = policy_frequency
        self.step = 1
        self.mpc_weight = mpc_weight
        self.mpc_time_step = mpc_time_step

    def update(
        self,
        replay: OffPolicyMemory,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_actor: optim.Optimizer,
        optimizer_critic: optim.Optimizer,
        optimizer_mpc: optim.Optimizer,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
    ) -> Dict[str, float]:
        """Update with one gradient step.

        The critic and the next-state prediction are updated on every call. The
        actor and the temperature are only updated when the internal step
        counter is a multiple of ``policy_frequency``.

        Args:
            replay (OffPolicyMemory): Replay buffer
            agent (IAgent[EnvObservation, EnvAction, Action]): Implemented agent (sac candidate)
            optimizer_actor (optim.Optimizer): Actor's optimizer
            optimizer_critic (optim.Optimizer): Critic's optimizer
            optimizer_mpc (optim.Optimizer): Optimizer stepped on the next-state prediction loss
            scheduler (Optional[optim.lr_scheduler._LRScheduler], optional): Learning rate scheduler.
                NOTE(review): currently unused, it is never stepped here. Defaults to None.

        Returns:
            Dict[str, float]: Information on loss
        """
        batch = replay.sample(self.batch_size)

        next_q_value = self._td_target(batch, agent)
        qf_loss, qf_grad_norm = self._update_critic(
            batch, agent, optimizer_critic, next_q_value
        )

        # MODEL PREDICTIVE CONTROL
        mpc_loss, mpc_grad_norm = self._update_mpc(batch, agent, optimizer_mpc)

        # Stay None (and are left out of the info dict) on non-policy steps.
        policy_loss = None
        actor_grad_norm = None
        alpha_loss = None
        if self.step % self.policy_frequency == 0:
            # Delayed update: compensate the skipped calls with as many steps.
            for _ in range(self.policy_frequency):
                policy_loss, actor_grad_norm = self._update_actor(
                    batch, agent, optimizer_actor
                )
                alpha_loss = self._update_temperature(batch, agent)

        # Add cond soft update
        agent.soft_update(tau=self.tau)  # soft update for q_networks

        info = self._update_info_loss(
            policy_loss=policy_loss,
            q_loss=qf_loss,
            mpc_loss=mpc_loss,
            alpha=self.alpha,
            alpha_loss=alpha_loss,
            actor_grad_norm=actor_grad_norm,
            critic_grad_norm=qf_grad_norm,
            mpc_grad_norm=mpc_grad_norm,
        )
        self.step += 1
        return info

    def _td_target(
        self, batch, agent: IAgent[EnvObservation, EnvAction, Action]
    ) -> torch.Tensor:
        """Return the soft Bellman target, computed without gradient tracking."""
        with torch.no_grad():
            next_mass = agent.action(observation=batch.next_state())  # type: ignore
            next_state_actions = next_mass.sample()
            next_state_log_pi = next_mass.log_prob(action=next_state_actions)

            q_next_target = agent.target_q_value(
                observation=batch.next_state(), action=next_state_actions  # type: ignore
            )

            # TODO: Real LOSS RAVI
            min_qf_next_target = (
                torch.min(*q_next_target) - self.alpha * next_state_log_pi
            )
            # (1 - done) removes the bootstrap term on terminal transitions.
            return batch.reward() + (1 - batch.done()) * self.discount * (
                min_qf_next_target
            )

    def _update_critic(
        self,
        batch,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_critic: optim.Optimizer,
        next_q_value: torch.Tensor,
    ) -> Sequence[torch.Tensor]:
        """Take one critic step; return (mean critic loss, pre-clip gradient norm)."""
        qf1_a_values, qf2_a_values = agent.q_value(
            observation=batch.state(), action=batch.action()  # type: ignore
        )
        qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
        qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
        qf_loss = (qf1_loss + qf2_loss) / 2

        optimizer_critic.zero_grad()
        qf_loss.backward()
        qf_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_critic.step()
        return qf_loss, qf_grad_norm

    def _update_mpc(
        self,
        batch,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_mpc: optim.Optimizer,
    ) -> Sequence[torch.Tensor]:
        """Take one step on the next-state prediction loss; return (loss, gradient norm)."""
        # Run a fresh forward pass: the graph built for the critic loss was freed
        # by its backward(), so reusing those tensors raises "Trying to backward
        # through the graph a second time". The Q-values themselves are not
        # needed here; presumably this call refreshes agent.infos().
        agent.q_value(
            observation=batch.state(), action=batch.action()  # type: ignore
        )
        infos_mpc = agent.infos()
        predicted_next_state_1 = infos_mpc["predicted_next_state_1"]
        predicted_next_state_2 = infos_mpc["predicted_next_state_2"]

        mpc_loss_1 = F.mse_loss(predicted_next_state_1, batch.next_state())
        mpc_loss_2 = F.mse_loss(predicted_next_state_2, batch.next_state())
        mpc_loss = (mpc_loss_1 + mpc_loss_2) / 2

        optimizer_mpc.zero_grad()
        mpc_loss.backward()
        # NOTE(review): this clips the critic parameters, not a dedicated set of
        # model parameters — confirm this is intended.
        mpc_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_mpc.step()
        return mpc_loss, mpc_grad_norm

    def _update_actor(
        self,
        batch,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_actor: optim.Optimizer,
    ) -> Sequence[torch.Tensor]:
        """Take one actor step; return (policy loss, pre-clip gradient norm)."""
        mass = agent.action(observation=batch.state())  # type: ignore
        # rsample keeps the sampled action differentiable w.r.t. the actor.
        pi = mass.rsample()
        log_pi = mass.log_prob(action=pi)
        qf1_pi, qf2_pi = agent.q_value(observation=batch.state(), action=pi)  # type: ignore
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        optimizer_actor.zero_grad()
        policy_loss.backward()
        actor_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_actor(), self.max_grad_norm
        )
        optimizer_actor.step()
        return policy_loss, actor_grad_norm

    def _update_temperature(
        self, batch, agent: IAgent[EnvObservation, EnvAction, Action]
    ) -> torch.Tensor:
        """Take one temperature step, refresh ``self.alpha`` and return the loss."""
        # Fresh actions from the just-updated actor; only log_alpha gets gradients.
        with torch.no_grad():
            mass = agent.action(observation=batch.state())  # type: ignore
            pi = mass.sample()
            log_pi = mass.log_prob(action=pi)

        alpha_loss = (-self.log_alpha * (log_pi + self.target_entropy)).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        self.alpha = self.log_alpha.exp().item()
        return alpha_loss

    @classmethod
    def _update_info_loss(
        cls,
        policy_loss: Optional[torch.Tensor],
        q_loss: torch.Tensor,
        mpc_loss: Optional[torch.Tensor],
        alpha: float,
        alpha_loss: Optional[torch.Tensor],
        actor_grad_norm: Optional[torch.Tensor],
        critic_grad_norm: torch.Tensor,
        mpc_grad_norm: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """Convert the update statistics to plain floats.

        Optional entries that are None are left out of the returned dict.
        """
        info_dict = {}
        if policy_loss is not None:
            info_dict["policy_loss"] = policy_loss.item()

        info_dict["q_loss"] = q_loss.item()
        if mpc_loss is not None:
            info_dict["mpc_loss"] = mpc_loss.item()
        info_dict["alpha"] = alpha

        if alpha_loss is not None:
            info_dict["alpha_loss"] = alpha_loss.item()
        if actor_grad_norm is not None:
            info_dict["actor_grad_norm"] = actor_grad_norm.item()

        info_dict["critic_grad_norm"] = critic_grad_norm.item()
        if mpc_grad_norm is not None:
            info_dict["mpc_grad_norm"] = mpc_grad_norm.item()
        return info_dict


class SACMPCHydra:
    def __init__(
        self,
        batch_size: int,
        target_entropy: torch.Tensor,  # -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item() TODO: CORRECT with float
        discount: float,
        tau: float,
        reparametrized_sample: bool,
        device: torch.device = torch.device("cuda"),
        temperature_lr: float = 3e-4,  # TODO: discussion
        reward_scale: int = 1,
        max_grad_norm: float = 0.5,
        policy_frequency: int = 2,
        mpc_weight: float = 0.1,
        mpc_time_step: int = 4,
    ):
        """Soft Actor Critic with two Q-networks, adjusted temperature, a next-state
        prediction loss and an extra pair of "no reg" Q-values.

        Args:
            batch_size (int): Batch size of the optimization process
            target_entropy (torch.Tensor): Target entropy for the adjusted temperature
            discount (float): Discount applied to the bootstrapped next-state value
            tau (float): Value of the soft update
            reparametrized_sample (bool): If Normal distribution is use you should set this value at True.
                NOTE(review): stored but never read; ``update`` always calls ``rsample()``.
            device (torch.device, optional): Device on which the log-temperature parameter is created.
                Defaults to torch.device("cuda").
            temperature_lr (float, optional): Learning rate for adjusted temperature. Defaults to 3e-4.
            reward_scale (int, optional): NOTE(review): accepted but currently unused. Defaults to 1.
            max_grad_norm (float, optional): Max norm used for every gradient clipping in ``update``.
                Defaults to 0.5.
            policy_frequency (int, optional): The actor and the temperature are updated on every
                ``policy_frequency``-th call to ``update``, with ``policy_frequency`` gradient
                steps each time. Must be at least 1. Defaults to 2.
            mpc_weight (float, optional): NOTE(review): stored but not read by ``update``. Defaults to 0.1.
            mpc_time_step (int, optional): NOTE(review): stored but not read by ``update``. Defaults to 4.

        Raises:
            ValueError: If ``policy_frequency`` is lower than 1.
        """
        # Fail early: 0 would raise ZeroDivisionError in ``update`` and a negative
        # value would leave the policy losses undefined.
        if policy_frequency < 1:
            raise ValueError(f"policy_frequency must be >= 1, got {policy_frequency}")

        self.batch_size = batch_size
        self.target_entropy = target_entropy
        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=temperature_lr, eps=1e-7)
        self.alpha = self.log_alpha.exp().item()
        self.discount = discount
        self.tau = tau
        self.reparametrized_sample = reparametrized_sample
        self.max_grad_norm = max_grad_norm
        self.policy_frequency = policy_frequency
        self.step = 1
        self.mpc_weight = mpc_weight
        self.mpc_time_step = mpc_time_step

    def update(
        self,
        replay: OffPolicyMemory,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_actor: optim.Optimizer,
        optimizer_critic: optim.Optimizer,
        optimizer_mpc: optim.Optimizer,
        optimizer_critic_no_reg: optim.Optimizer,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
    ) -> Dict[str, float]:
        """Update with one gradient step.

        The critic, the next-state prediction and the "no reg" Q-values are
        updated on every call. The actor and the temperature are only updated
        when the internal step counter is a multiple of ``policy_frequency``.

        Args:
            replay (OffPolicyMemory): Replay buffer
            agent (IAgent[EnvObservation, EnvAction, Action]): Implemented agent (sac candidate)
            optimizer_actor (optim.Optimizer): Actor's optimizer
            optimizer_critic (optim.Optimizer): Critic's optimizer
            optimizer_mpc (optim.Optimizer): Optimizer stepped on the next-state prediction loss
            optimizer_critic_no_reg (optim.Optimizer): Optimizer stepped on the "no reg" Q loss
            scheduler (Optional[optim.lr_scheduler._LRScheduler], optional): Learning rate scheduler.
                NOTE(review): currently unused, it is never stepped here. Defaults to None.

        Returns:
            Dict[str, float]: Information on loss
        """
        batch = replay.sample(self.batch_size)

        next_q_value, next_state_actions = self._td_target(batch, agent)
        qf_loss, qf_grad_norm = self._update_critic(
            batch, agent, optimizer_critic, next_q_value
        )

        # MODEL PREDICTIVE CONTROL
        mpc_loss, mpc_grad_norm = self._update_mpc(batch, agent, optimizer_mpc)

        # CRITIC NO REG
        q_no_reg_loss, q_no_reg_grad_norm = self._update_critic_no_reg(
            batch, agent, optimizer_critic_no_reg, next_state_actions
        )

        # Stay None (and are left out of the info dict) on non-policy steps.
        policy_loss = None
        actor_grad_norm = None
        alpha_loss = None
        if self.step % self.policy_frequency == 0:
            # Delayed update: compensate the skipped calls with as many steps.
            for _ in range(self.policy_frequency):
                policy_loss, actor_grad_norm = self._update_actor(
                    batch, agent, optimizer_actor
                )
                alpha_loss = self._update_temperature(batch, agent)

        # Add cond soft update
        agent.soft_update(tau=self.tau)  # soft update for q_networks

        info = self._update_info_loss(
            policy_loss=policy_loss,
            q_loss=qf_loss,
            mpc_loss=mpc_loss,
            q_no_reg_grad_norm=q_no_reg_grad_norm,
            q_no_reg_loss=q_no_reg_loss,
            alpha=self.alpha,
            alpha_loss=alpha_loss,
            actor_grad_norm=actor_grad_norm,
            critic_grad_norm=qf_grad_norm,
            mpc_grad_norm=mpc_grad_norm,
        )
        self.step += 1
        return info

    def _td_target(
        self, batch, agent: IAgent[EnvObservation, EnvAction, Action]
    ) -> Sequence[torch.Tensor]:
        """Return (soft Bellman target, sampled next actions), computed without gradients."""
        with torch.no_grad():
            next_mass = agent.action(observation=batch.next_state())  # type: ignore
            next_state_actions = next_mass.sample()
            next_state_log_pi = next_mass.log_prob(action=next_state_actions)

            q_next_target = agent.target_q_value(
                observation=batch.next_state(), action=next_state_actions  # type: ignore
            )

            # TODO: Real LOSS RAVI
            min_qf_next_target = (
                torch.min(*q_next_target) - self.alpha * next_state_log_pi
            )
            # (1 - done) removes the bootstrap term on terminal transitions.
            next_q_value = batch.reward() + (1 - batch.done()) * self.discount * (
                min_qf_next_target
            )
        return next_q_value, next_state_actions

    def _update_critic(
        self,
        batch,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_critic: optim.Optimizer,
        next_q_value: torch.Tensor,
    ) -> Sequence[torch.Tensor]:
        """Take one critic step; return (mean critic loss, pre-clip gradient norm)."""
        qf1_a_values, qf2_a_values = agent.q_value(
            observation=batch.state(), action=batch.action()  # type: ignore
        )
        qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
        qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
        qf_loss = (qf1_loss + qf2_loss) / 2

        optimizer_critic.zero_grad()
        qf_loss.backward()
        qf_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_critic.step()
        return qf_loss, qf_grad_norm

    def _update_mpc(
        self,
        batch,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_mpc: optim.Optimizer,
    ) -> Sequence[torch.Tensor]:
        """Take one step on the next-state prediction loss; return (loss, gradient norm)."""
        # Run a fresh forward pass: the graph built for the critic loss was freed
        # by its backward(), so reusing those tensors raises "Trying to backward
        # through the graph a second time". The Q-values themselves are not
        # needed here; presumably this call refreshes agent.infos().
        agent.q_value(
            observation=batch.state(), action=batch.action()  # type: ignore
        )
        infos_mpc = agent.infos()
        predicted_next_state_1 = infos_mpc["predicted_next_state_1"]
        predicted_next_state_2 = infos_mpc["predicted_next_state_2"]

        mpc_loss_1 = F.mse_loss(predicted_next_state_1, batch.next_state())
        mpc_loss_2 = F.mse_loss(predicted_next_state_2, batch.next_state())
        mpc_loss = (mpc_loss_1 + mpc_loss_2) / 2

        optimizer_mpc.zero_grad()
        mpc_loss.backward()
        # NOTE(review): this clips the critic parameters, not a dedicated set of
        # model parameters — confirm this is intended.
        mpc_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_mpc.step()
        return mpc_loss, mpc_grad_norm

    def _update_critic_no_reg(
        self,
        batch,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_critic_no_reg: optim.Optimizer,
        next_state_actions: torch.Tensor,
    ) -> Sequence[torch.Tensor]:
        """Take one step on the "no reg" Q loss; return (loss, gradient norm).

        The target is reward plus discounted minimum target value, without the
        ``alpha * log_pi`` term used for the main critic.
        """
        # Both forward passes are run for the values they expose through
        # agent.infos(); their return values are not needed here.
        agent.q_value(observation=batch.state(), action=batch.action())  # type: ignore
        agent.target_q_value(
            observation=batch.next_state(), action=next_state_actions  # type: ignore
        )
        infos_mpc = agent.infos()
        q_no_reg_value_1 = infos_mpc["q_no_reg_value_1"]
        q_no_reg_value_2 = infos_mpc["q_no_reg_value_2"]

        target_q_no_reg_value_1 = infos_mpc["target_q_no_reg_value_1"]
        target_q_no_reg_value_2 = infos_mpc["target_q_no_reg_value_2"]
        # Results computed under no_grad carry no graph, so no detach is needed.
        with torch.no_grad():
            min_qf_no_reg_next_target = torch.min(
                target_q_no_reg_value_1, target_q_no_reg_value_2
            )
            next_q_value = (
                batch.reward()
                + (1 - batch.done()) * self.discount * min_qf_no_reg_next_target
            )

        q_no_reg_loss_1 = F.mse_loss(q_no_reg_value_1, next_q_value)
        q_no_reg_loss_2 = F.mse_loss(q_no_reg_value_2, next_q_value)
        q_no_reg_loss = (q_no_reg_loss_1 + q_no_reg_loss_2) / 2

        optimizer_critic_no_reg.zero_grad()
        q_no_reg_loss.backward()
        # NOTE(review): clips agent.parameters_critic(); confirm these are the
        # parameters optimizer_critic_no_reg steps.
        q_no_reg_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_critic_no_reg.step()
        return q_no_reg_loss, q_no_reg_grad_norm

    def _update_actor(
        self,
        batch,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_actor: optim.Optimizer,
    ) -> Sequence[torch.Tensor]:
        """Take one actor step; return (policy loss, pre-clip gradient norm)."""
        mass = agent.action(observation=batch.state())  # type: ignore
        # rsample keeps the sampled action differentiable w.r.t. the actor.
        pi = mass.rsample()
        log_pi = mass.log_prob(action=pi)
        qf1_pi, qf2_pi = agent.q_value(observation=batch.state(), action=pi)  # type: ignore
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        optimizer_actor.zero_grad()
        policy_loss.backward()
        actor_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_actor(), self.max_grad_norm
        )
        optimizer_actor.step()
        return policy_loss, actor_grad_norm

    def _update_temperature(
        self, batch, agent: IAgent[EnvObservation, EnvAction, Action]
    ) -> torch.Tensor:
        """Take one temperature step, refresh ``self.alpha`` and return the loss."""
        # Fresh actions from the just-updated actor; only log_alpha gets gradients.
        with torch.no_grad():
            mass = agent.action(observation=batch.state())  # type: ignore
            pi = mass.sample()
            log_pi = mass.log_prob(action=pi)

        alpha_loss = (-self.log_alpha * (log_pi + self.target_entropy)).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        self.alpha = self.log_alpha.exp().item()
        return alpha_loss

    @classmethod
    def _update_info_loss(
        cls,
        policy_loss: Optional[torch.Tensor],
        q_loss: torch.Tensor,
        mpc_loss: Optional[torch.Tensor],
        q_no_reg_loss: Optional[torch.Tensor],
        alpha: float,
        alpha_loss: Optional[torch.Tensor],
        actor_grad_norm: Optional[torch.Tensor],
        critic_grad_norm: torch.Tensor,
        q_no_reg_grad_norm: Optional[torch.Tensor],
        mpc_grad_norm: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """Convert the update statistics to plain floats.

        Optional entries that are None are left out of the returned dict.
        """
        info_dict = {}
        if policy_loss is not None:
            info_dict["policy_loss"] = policy_loss.item()

        info_dict["q_loss"] = q_loss.item()
        if q_no_reg_loss is not None:
            info_dict["q_no_reg_loss"] = q_no_reg_loss.item()
        if mpc_loss is not None:
            info_dict["mpc_loss"] = mpc_loss.item()
        info_dict["alpha"] = alpha

        if alpha_loss is not None:
            info_dict["alpha_loss"] = alpha_loss.item()
        if actor_grad_norm is not None:
            info_dict["actor_grad_norm"] = actor_grad_norm.item()

        info_dict["critic_grad_norm"] = critic_grad_norm.item()
        if q_no_reg_grad_norm is not None:
            info_dict["q_no_reg_grad_norm"] = q_no_reg_grad_norm.item()
        if mpc_grad_norm is not None:
            info_dict["mpc_grad_norm"] = mpc_grad_norm.item()
        return info_dict
