from typing import Sequence, Optional, Dict, Tuple

import torch
import torch.optim as optim
from torch.nn import functional as F
import torch.nn as nn

from .base_agent import Action, EnvAction, EnvObservation, IAgent
from .replay import OffPolicyMemory


class SAC:
    def __init__(
        self,
        batch_size: int,
        target_entropy: torch.Tensor,  # -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item() TODO: CORRECT with float
        discount: float,
        tau: float,
        reparametrized_sample: bool,
        device: torch.device = torch.device("cuda"),
        temperature_lr: float = 3e-4,  # TODO: discussion
        reward_scale: int = 1,
        max_grad_norm: float = 0.5,
        policy_frequency: int = 2,
    ):
        """Soft Actor Critic update rule with twin Q-values and a learned temperature.

        Args:
            batch_size (int): Number of transitions requested from the replay buffer per update
            target_entropy (torch.Tensor): Entropy target used in the temperature loss
            discount (float): Discount factor applied to the bootstrapped next-state value
            tau (float): Coefficient forwarded to ``agent.soft_update``
            reparametrized_sample (bool): Stored on the instance; the update shown here always
                calls ``rsample()`` for the policy loss and never reads this flag
            device (torch.device, optional): Device holding ``log_alpha``. Defaults to torch.device("cuda").
            temperature_lr (float, optional): Learning rate of the temperature optimizer. Defaults to 3e-4.
            reward_scale (int, optional): Accepted but not used by this class. Defaults to 1.
            max_grad_norm (float, optional): Max norm passed to ``clip_grad_norm_``. Defaults to 0.5.
            policy_frequency (int, optional): The actor/temperature are updated every
                ``policy_frequency`` calls, ``policy_frequency`` times in a row. Defaults to 2.

        """
        # Plain hyper-parameters.
        self.batch_size = batch_size
        self.discount = discount
        self.tau = tau
        self.reparametrized_sample = reparametrized_sample
        self.max_grad_norm = max_grad_norm
        self.policy_frequency = policy_frequency

        # Temperature: optimised in log-space, with a cached python float for the losses.
        self.target_entropy = target_entropy
        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=temperature_lr, eps=1e-7)
        self.alpha = self.log_alpha.exp().item()

        # Number of the update call about to run (starts at 1).
        self.step = 1

    def update(
        self,
        replay: OffPolicyMemory,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_actor: optim.Optimizer,
        optimizer_critic: optim.Optimizer,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
    ) -> Dict[str, float]:
        """Run one update: critic step, delayed actor/temperature steps, soft update.

        Args:
            replay (OffPolicyMemory): Replay buffer to sample the batch from
            agent (IAgent[EnvObservation, EnvAction, Action]): Implemented agent (sac candidate)
            optimizer_actor (optim.Optimizer): Actor's optimizer
            optimizer_critic (optim.Optimizer): Critic's optimizer
            scheduler (Optional[optim.lr_scheduler._LRScheduler], optional): Accepted but not
                used by this method. Defaults to None.

        Returns:
            Dict[str, float]: Scalar losses and gradient norms of this update
        """
        batch = replay.sample(self.batch_size)

        # ---- Bootstrapped target (no gradient flows through it) ----
        with torch.no_grad():
            next_dist = agent.action(observation=batch.next_state())
            next_actions = next_dist.sample()
            next_log_prob = next_dist.log_prob(action=next_actions)

            target_qs = agent.target_q_value(
                observation=batch.next_state(), action=next_actions
            )

            # TODO: Real LOSS RAVI
            soft_next_value = torch.min(*target_qs) - self.alpha * next_log_prob
            td_target = (
                batch.reward() + (1 - batch.done()) * self.discount * soft_next_value
            )

        # ---- Critic step: mean of the two MSE losses ----
        q1_pred, q2_pred = agent.q_value(
            observation=batch.state(), action=batch.action()
        )
        q1_loss = F.mse_loss(q1_pred, td_target)
        q2_loss = F.mse_loss(q2_pred, td_target)
        qf_loss = (q1_loss + q2_loss) / 2

        optimizer_critic.zero_grad()
        qf_loss.backward()
        qf_grad_norm = nn.utils.clip_grad_norm_(
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_critic.step()

        # ---- Delayed actor / temperature steps ----
        if self.step % self.policy_frequency != 0:
            policy_loss = actor_grad_norm = alpha_loss = None
        else:
            # Compensate for the skipped calls by repeating the step.
            for _ in range(self.policy_frequency):
                dist = agent.action(observation=batch.state())
                sampled = dist.rsample()
                sampled_log_prob = dist.log_prob(action=sampled)
                q1_pi, q2_pi = agent.q_value(
                    observation=batch.state(), action=sampled
                )
                policy_loss = (
                    (self.alpha * sampled_log_prob) - torch.min(q1_pi, q2_pi)
                ).mean()

                optimizer_actor.zero_grad()
                policy_loss.backward()
                actor_grad_norm = nn.utils.clip_grad_norm_(
                    agent.parameters_actor(), self.max_grad_norm
                )
                optimizer_actor.step()

                # Temperature step, using a fresh sample from the just-updated policy.
                with torch.no_grad():
                    dist = agent.action(observation=batch.state())
                    fresh_log_prob = dist.log_prob(action=dist.sample())

                alpha_loss = (
                    -self.log_alpha * (fresh_log_prob + self.target_entropy)
                ).mean()

                self.alpha_optim.zero_grad()
                alpha_loss.backward()
                self.alpha_optim.step()
                self.alpha = self.log_alpha.exp().item()

        # Add cond soft update
        agent.soft_update(tau=self.tau)  # soft update for q_networks

        self.step += 1
        return self._update_info_loss(
            policy_loss=policy_loss,
            q_loss=qf_loss,
            alpha=self.alpha,
            alpha_loss=alpha_loss,
            actor_grad_norm=actor_grad_norm,
            critic_grad_norm=qf_grad_norm,
        )

    @classmethod
    def _update_info_loss(
        cls,
        policy_loss: Optional[torch.Tensor],
        q_loss: torch.Tensor,
        alpha: float,
        alpha_loss: Optional[torch.Tensor],
        actor_grad_norm: Optional[torch.Tensor],
        critic_grad_norm: torch.Tensor,
    ) -> Dict[str, float]:
        """Convert the update's tensors to floats; ``None`` entries are left out."""
        summary: Dict[str, float] = {}
        if policy_loss is not None:
            summary["policy_loss"] = policy_loss.item()

        summary["q_loss"] = q_loss.item()
        summary["alpha"] = alpha

        optional_entries = (
            ("alpha_loss", alpha_loss),
            ("actor_grad_norm", actor_grad_norm),
        )
        for key, tensor in optional_entries:
            if tensor is not None:
                summary[key] = tensor.item()

        summary["critic_grad_norm"] = critic_grad_norm.item()
        return summary


def sac_candidate_agent(
    agent: "IAgent", mock_observation: torch.Tensor, mock_action: torch.Tensor
) -> None:
    """Assert that ``agent`` exposes the interface the SAC update rules call.

    The agent is probed once with the mock tensors and the results are checked:

    * ``agent.action(observation=...)`` returns an object with ``sample()`` and
      ``log_prob(action=...)``; the sample is 2-D and the log-prob is ``(N, 1)``.
    * ``agent.q_value`` and ``agent.target_q_value`` each return a sequence of
      exactly two 2-D tensors whose last dimension is 1.
    * ``agent.parameters_actor()`` and ``agent.parameters_critic()`` do not
      return ``None``.

    Args:
        agent: Agent to validate.
        mock_observation: Observation batch fed to the agent.
        mock_action: Action batch fed to the Q-functions.

    Raises:
        AssertionError: If any check fails. The checks are ``assert``
            statements, so they are skipped when Python runs with ``-O``.
    """
    mass = agent.action(observation=mock_observation)
    action = mass.sample()
    log_probs = mass.log_prob(action=action)

    qf = agent.q_value(observation=mock_observation, action=mock_action)
    qf_target = agent.target_q_value(observation=mock_observation, action=mock_action)

    assert (
        action.dim() == 2
    ), "Max dimensions for actions is 2 (Batch_size, action number)"

    assert log_probs.dim() == 2, "Max dimensions for log probs is two (Batch_size, 1)"
    assert log_probs.size()[1] == 1, "Last dimension for log probs should equal to 1"

    assert isinstance(qf, Sequence), "Q functions should be an Sequence"
    assert len(qf) == 2, "The numbers of Q values should be equals 2"

    assert qf[0].dim() == 2, "Max dimensions for Q functions is two (Batch_size, 1)"
    assert qf[0].size()[1] == 1, "Last dimension for Q functions should equal to 1"

    assert qf[1].dim() == 2, "Max dimensions for Q functions 2 is two (Batch_size, 1)"
    assert qf[1].size()[1] == 1, "Last dimension for Q functions 2 should equal to 1"

    assert isinstance(qf_target, Sequence), "Q functions target should be an Sequence"
    assert len(qf_target) == 2, "The numbers of Q values target should be equals 2"

    assert (
        qf_target[0].dim() == 2
    ), "Max dimensions for Q function target 1 is two (Batch_size, 1)"
    assert (
        qf_target[0].size()[1] == 1
    ), "Last dimension for Q functions target 1 should equal to 1"

    assert (
        qf_target[1].dim() == 2
    ), "Max dimensions for Q functions target 2 is two (Batch_size, 1)"
    assert (
        qf_target[1].size()[1] == 1
    ), "Last dimension for Q functions target 2 should equal to 1"

    # The previous form, `isinstance(..., nn.Parameter) is not None`, compared a
    # bool with None and therefore could never fail. Check the value itself.
    assert agent.parameters_actor() is not None, "Parameter actor should be set"
    assert agent.parameters_critic() is not None, "Parameter critic should be set"

    # TODO ADD good soft update agents


class SACMPC:
    def __init__(
        self,
        batch_size: int,
        target_entropy: torch.Tensor,  # -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item() TODO: CORRECT with float
        discount: float,
        tau: float,
        reparametrized_sample: bool,
        device: torch.device = torch.device("cuda"),
        temperature_lr: float = 3e-4,  # TODO: discussion
        reward_scale: int = 1,
        max_grad_norm: float = 0.5,
        policy_frequency: int = 2,
        mpc_weight: float = 0.1,
        mpc_time_step: int = 4,
    ):
        """Soft Actor Critic with two Q-networks, adjusted temperature and a next-state prediction loss.

        Args:
            batch_size (int): Batch size of the optimization process
            target_entropy (torch.Tensor): Target entropy for the adjusted temperature
            discount (float): Discount factor applied to the bootstrapped next-state value
            tau (float): Value of the soft update
            reparametrized_sample (bool): If Normal distribution is used you should set this value to True.
                Stored on the instance; the update in this class always calls ``rsample()``.
            device (torch.device, optional): Device which runs the optimization. Defaults to torch.device("cuda").
            temperature_lr (float, optional): Learning rate for adjusted temperature. Defaults to 3e-4.
            reward_scale (int, optional): Accepted but not used by this class. Defaults to 1.
            max_grad_norm (float, optional): Max norm passed to ``clip_grad_norm_``. Defaults to 0.5.
            policy_frequency (int, optional): The actor/temperature are updated every
                ``policy_frequency`` calls, ``policy_frequency`` times in a row. Defaults to 2.
            mpc_weight (float, optional): Stored on the instance; not read by the methods of this class. Defaults to 0.1.
            mpc_time_step (int, optional): Stored on the instance; not read by the methods of this class. Defaults to 4.

        """
        self.batch_size = batch_size
        self.target_entropy = target_entropy
        # The temperature is optimised in log-space; `alpha` caches exp(log_alpha) as a float.
        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=temperature_lr, eps=1e-7)
        self.alpha = self.log_alpha.exp().item()
        self.discount = discount
        self.tau = tau
        self.reparametrized_sample = reparametrized_sample
        self.max_grad_norm = max_grad_norm
        self.policy_frequency = policy_frequency
        # Number of the update call about to run (starts at 1).
        self.step = 1
        self.mpc_weight = mpc_weight
        self.mpc_time_step = mpc_time_step

    def update(
        self,
        replay: OffPolicyMemory,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_actor: optim.Optimizer,
        optimizer_critic: optim.Optimizer,
        optimizer_mpc: optim.Optimizer,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
    ) -> Dict[str, float]:
        """Update with one gradient step.

        Order of operations: critic step, next-state prediction ("MPC") step,
        delayed actor/temperature steps, then ``agent.soft_update``.

        Args:
            replay (OffPolicyMemory): Replay buffer
            agent (IAgent[EnvObservation, EnvAction, Action]): Implemented agent (sac candidate)
            optimizer_actor (optim.Optimizer): Actor's optimizer
            optimizer_critic (optim.Optimizer): Critic's optimizer
            optimizer_mpc (optim.Optimizer): Optimizer stepped after the next-state prediction loss
            scheduler (Optional[optim.lr_scheduler._LRScheduler], optional): Learning rate scheduler.
                Accepted but not used by this method. Defaults to None.

        Returns:
            Dict[str, float]: Information on loss
        """
        batch = replay.sample(self.batch_size)

        # Bootstrapped critic target; computed without gradient tracking.
        with torch.no_grad():
            next_mass = agent.action(observation=batch.next_state())  # type: ignore
            next_state_actions = next_mass.sample()
            next_state_log_pi = next_mass.log_prob(action=next_state_actions)

            q_next_target = agent.target_q_value(
                observation=batch.next_state(), action=next_state_actions  # type: ignore
            )

            # TODO: Real LOSS RAVI
            # Element-wise minimum of the two target Q-values minus the entropy term.
            min_qf_next_target = (
                torch.min(*q_next_target) - self.alpha * next_state_log_pi
            )
            next_q_value = batch.reward() + (1 - batch.done()) * self.discount * (
                min_qf_next_target
            )
        # NOTE: other potential refactors
        # 1. using qf1() instead, which calls forward() by default
        # 2. use F.mse_loss() instead of loss_fn = nn.MSELoss()
        qf1_a_values, qf2_a_values = agent.q_value(
            observation=batch.state(), action=batch.action()  # type: ignore
        )
        qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
        qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
        # The critic loss is the mean of the two MSE losses.
        qf_loss = (qf1_loss + qf2_loss) / 2

        optimizer_critic.zero_grad()
        qf_loss.backward()
        qf_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_critic.step()

        # MODEL PREDICTIVE CONTROL

        # Author's note recording the error hit here:
        # RuntimeError: Trying to backward through the graph a second time (or directly access
        # saved tensors after they have already been freed). Saved intermediate values of
        # the graph are freed when you call .backward() or autograd.grad().
        # Specify retain_graph=True if you need to backward through the graph a second time
        # or if you need to access saved tensors after calling backward.
        # NOTE(review): the Q forward pass is repeated before reading agent.infos(),
        # presumably to rebuild the graph freed by qf_loss.backward() and to refresh
        # the predictions stored by the agent - confirm against the agent implementation.
        # The returned Q-values themselves are not used below.
        qf1_a_values, qf2_a_values = agent.q_value(
            observation=batch.state(), action=batch.action()  # type: ignore
        )
        infos_mpc = agent.infos()
        predicted_next_state_1 = infos_mpc["predicted_next_state_1"]
        predicted_next_state_2 = infos_mpc["predicted_next_state_2"]

        # Mean of the two MSE losses between predicted and sampled next states.
        mpc_loss_1 = F.mse_loss(predicted_next_state_1, batch.next_state())
        mpc_loss_2 = F.mse_loss(predicted_next_state_2, batch.next_state())
        mpc_loss = (mpc_loss_1 + mpc_loss_2) / 2
        optimizer_mpc.zero_grad()
        mpc_loss.backward()
        # Clipping is applied to agent.parameters_critic(); the resulting norm is
        # not included in the returned info.
        mpc_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_mpc.step()

        # Delayed policy update: runs only on calls where step is a multiple of
        # policy_frequency, and then repeats policy_frequency times on the same batch.
        if self.step % self.policy_frequency == 0:
            # Policy
            for _ in range(self.policy_frequency):
                mass = agent.action(observation=batch.state())  # type: ignore
                pi = mass.rsample()
                log_pi = mass.log_prob(action=pi)
                qf1_pi, qf2_pi = agent.q_value(observation=batch.state(), action=pi)  # type: ignore
                min_qf_pi = torch.min(qf1_pi, qf2_pi)
                policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

                optimizer_actor.zero_grad()
                policy_loss.backward()
                actor_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
                    agent.parameters_actor(), self.max_grad_norm
                )
                optimizer_actor.step()
                # alpha
                # The log-prob is re-sampled after the actor step, without gradient
                # tracking, so only log_alpha receives a gradient from alpha_loss.
                with torch.no_grad():
                    mass = agent.action(observation=batch.state())  # type: ignore
                    pi = mass.sample()
                    log_pi = mass.log_prob(action=pi)

                alpha_loss = (-self.log_alpha * (log_pi + self.target_entropy)).mean()

                self.alpha_optim.zero_grad()
                alpha_loss.backward()
                self.alpha_optim.step()
                # Refresh the cached float used by the losses above.
                self.alpha = self.log_alpha.exp().item()
        else:
            # No actor/temperature step on this call: these entries are omitted from the info.
            policy_loss = None
            actor_grad_norm = None
            alpha_loss = None

        # Add cond soft update
        agent.soft_update(tau=self.tau)  # soft update for q_networks

        info = self._update_info_loss(
            policy_loss=policy_loss,
            q_loss=qf_loss,
            mpc_loss=mpc_loss,
            alpha=self.alpha,
            alpha_loss=alpha_loss,
            actor_grad_norm=actor_grad_norm,
            critic_grad_norm=qf_grad_norm,
        )
        self.step += 1
        return info

    @classmethod
    def _update_info_loss(
        cls,
        policy_loss: Optional[torch.Tensor],
        q_loss: torch.Tensor,
        mpc_loss: Optional[torch.Tensor],
        alpha: float,
        alpha_loss: Optional[torch.Tensor],
        actor_grad_norm: Optional[torch.Tensor],
        critic_grad_norm: torch.Tensor,
    ) -> Dict[str, float]:
        """Convert the update's tensors to floats via ``.item()``.

        Arguments that are ``None`` are left out of the returned dict;
        ``q_loss``, ``alpha`` and ``critic_grad_norm`` are always present.
        """
        info_dict = {}
        if policy_loss is not None:
            info_dict["policy_loss"] = policy_loss.item()

        info_dict["q_loss"] = q_loss.item()
        if mpc_loss is not None:
            info_dict["mpc_loss"] = mpc_loss.item()
        info_dict["alpha"] = alpha

        if alpha_loss is not None:
            info_dict["alpha_loss"] = alpha_loss.item()
        if actor_grad_norm is not None:
            info_dict["actor_grad_norm"] = actor_grad_norm.item()

        info_dict["critic_grad_norm"] = critic_grad_norm.item()
        return info_dict


class SACMPCNoReg:
    def __init__(
        self,
        batch_size: int,
        target_entropy: torch.Tensor,  # -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item() TODO: CORRECT with float
        discount: float,
        tau: float,
        reparametrized_sample: bool,
        device: torch.device = torch.device("cuda"),
        temperature_lr: float = 3e-4,  # TODO: discussion
        reward_scale: int = 1,
        max_grad_norm: float = 0.5,
        policy_frequency: int = 2,
        mpc_weight: float = 0.1,
        mpc_time_step: int = 4,
    ):
        """Soft Actor Critic variant with a next-state prediction loss and no entropy term in the critic target.

        The actor and temperature losses still use the temperature ``alpha``;
        only the bootstrapped critic target leaves the entropy term out.

        Args:
            batch_size (int): Batch size of the optimization process
            target_entropy (torch.Tensor): Target entropy for the adjusted temperature
            discount (float): Discount factor applied to the bootstrapped next-state value
            tau (float): Value of the soft update
            reparametrized_sample (bool): If Normal distribution is used you should set this value to True.
                Stored on the instance; the update in this class always calls ``rsample()``.
            device (torch.device, optional): Device which runs the optimization. Defaults to torch.device("cuda").
            temperature_lr (float, optional): Learning rate for adjusted temperature. Defaults to 3e-4.
            reward_scale (int, optional): Accepted but not used by this class. Defaults to 1.
            max_grad_norm (float, optional): Max norm passed to ``clip_grad_norm_``. Defaults to 0.5.
            policy_frequency (int, optional): The actor/temperature are updated every
                ``policy_frequency`` calls, ``policy_frequency`` times in a row. Defaults to 2.
            mpc_weight (float, optional): Stored on the instance; not read by the methods of this class. Defaults to 0.1.
            mpc_time_step (int, optional): Stored on the instance; not read by the methods of this class. Defaults to 4.

        """
        self.batch_size = batch_size
        self.target_entropy = target_entropy
        # The temperature is optimised in log-space; `alpha` caches exp(log_alpha) as a float.
        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=temperature_lr, eps=1e-7)
        self.alpha = self.log_alpha.exp().item()
        self.discount = discount
        self.tau = tau
        self.reparametrized_sample = reparametrized_sample
        self.max_grad_norm = max_grad_norm
        self.policy_frequency = policy_frequency
        # Number of the update call about to run (starts at 1).
        self.step = 1
        self.mpc_weight = mpc_weight
        self.mpc_time_step = mpc_time_step

    def update(
        self,
        replay: OffPolicyMemory,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_actor: optim.Optimizer,
        optimizer_critic: optim.Optimizer,
        optimizer_mpc: optim.Optimizer,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
    ) -> Dict[str, float]:
        """Update with one gradient step.

        Order of operations: critic step, next-state prediction ("MPC") step,
        delayed actor/temperature steps, then ``agent.soft_update``.

        Args:
            replay (OffPolicyMemory): Replay buffer
            agent (IAgent[EnvObservation, EnvAction, Action]): Implemented agent (sac candidate)
            optimizer_actor (optim.Optimizer): Actor's optimizer
            optimizer_critic (optim.Optimizer): Critic's optimizer
            optimizer_mpc (optim.Optimizer): Optimizer stepped after the next-state prediction loss
            scheduler (Optional[optim.lr_scheduler._LRScheduler], optional): Learning rate scheduler.
                Accepted but not used by this method. Defaults to None.

        Returns:
            Dict[str, float]: Information on loss
        """
        batch = replay.sample(self.batch_size)

        # Bootstrapped critic target; computed without gradient tracking.
        with torch.no_grad():
            next_mass = agent.action(observation=batch.next_state())  # type: ignore
            next_state_actions = next_mass.sample()
            # Computed but not used below: the target in this class has no entropy term.
            next_state_log_pi = next_mass.log_prob(action=next_state_actions)

            q_next_target = agent.target_q_value(
                observation=batch.next_state(), action=next_state_actions  # type: ignore
            )

            # TODO: Real LOSS RAVI
            # Element-wise minimum of the two target Q-values, without `- alpha * log_pi`.
            min_qf_next_target = torch.min(*q_next_target)
            next_q_value = batch.reward() + (1 - batch.done()) * self.discount * (
                min_qf_next_target
            )
        # NOTE: other potential refactors
        # 1. using qf1() instead, which calls forward() by default
        # 2. use F.mse_loss() instead of loss_fn = nn.MSELoss()
        qf1_a_values, qf2_a_values = agent.q_value(
            observation=batch.state(), action=batch.action()  # type: ignore
        )
        qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
        qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
        # The critic loss is the mean of the two MSE losses.
        qf_loss = (qf1_loss + qf2_loss) / 2

        optimizer_critic.zero_grad()
        qf_loss.backward()
        qf_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_critic.step()

        # MODEL PREDICTIVE CONTROL

        # Author's note recording the error hit here:
        # RuntimeError: Trying to backward through the graph a second time (or directly access
        # saved tensors after they have already been freed). Saved intermediate values of
        # the graph are freed when you call .backward() or autograd.grad().
        # Specify retain_graph=True if you need to backward through the graph a second time
        # or if you need to access saved tensors after calling backward.
        # NOTE(review): the Q forward pass is repeated before reading agent.infos(),
        # presumably to rebuild the graph freed by qf_loss.backward() and to refresh
        # the predictions stored by the agent - confirm against the agent implementation.
        # The returned Q-values themselves are not used below.
        qf1_a_values, qf2_a_values = agent.q_value(
            observation=batch.state(), action=batch.action()  # type: ignore
        )
        infos_mpc = agent.infos()
        predicted_next_state_1 = infos_mpc["predicted_next_state_1"]
        predicted_next_state_2 = infos_mpc["predicted_next_state_2"]

        # Mean of the two MSE losses between predicted and sampled next states.
        mpc_loss_1 = F.mse_loss(predicted_next_state_1, batch.next_state())
        mpc_loss_2 = F.mse_loss(predicted_next_state_2, batch.next_state())
        mpc_loss = (mpc_loss_1 + mpc_loss_2) / 2
        optimizer_mpc.zero_grad()
        mpc_loss.backward()
        # Clipping is applied to agent.parameters_critic(); the resulting norm is
        # not included in the returned info.
        mpc_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_mpc.step()

        # Delayed policy update: runs only on calls where step is a multiple of
        # policy_frequency, and then repeats policy_frequency times on the same batch.
        if self.step % self.policy_frequency == 0:
            # Policy
            for _ in range(self.policy_frequency):
                mass = agent.action(observation=batch.state())  # type: ignore
                pi = mass.rsample()
                log_pi = mass.log_prob(action=pi)
                qf1_pi, qf2_pi = agent.q_value(observation=batch.state(), action=pi)  # type: ignore
                min_qf_pi = torch.min(qf1_pi, qf2_pi)
                policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

                optimizer_actor.zero_grad()
                policy_loss.backward()
                actor_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
                    agent.parameters_actor(), self.max_grad_norm
                )
                optimizer_actor.step()
                # alpha
                # The log-prob is re-sampled after the actor step, without gradient
                # tracking, so only log_alpha receives a gradient from alpha_loss.
                with torch.no_grad():
                    mass = agent.action(observation=batch.state())  # type: ignore
                    pi = mass.sample()
                    log_pi = mass.log_prob(action=pi)

                alpha_loss = (-self.log_alpha * (log_pi + self.target_entropy)).mean()

                self.alpha_optim.zero_grad()
                alpha_loss.backward()
                self.alpha_optim.step()
                # Refresh the cached float used by the policy loss.
                self.alpha = self.log_alpha.exp().item()
        else:
            # No actor/temperature step on this call: these entries are omitted from the info.
            policy_loss = None
            actor_grad_norm = None
            alpha_loss = None

        # Add cond soft update
        agent.soft_update(tau=self.tau)  # soft update for q_networks

        info = self._update_info_loss(
            policy_loss=policy_loss,
            q_loss=qf_loss,
            mpc_loss=mpc_loss,
            alpha=self.alpha,
            alpha_loss=alpha_loss,
            actor_grad_norm=actor_grad_norm,
            critic_grad_norm=qf_grad_norm,
        )
        self.step += 1
        return info

    @classmethod
    def _update_info_loss(
        cls,
        policy_loss: Optional[torch.Tensor],
        q_loss: torch.Tensor,
        mpc_loss: Optional[torch.Tensor],
        alpha: float,
        alpha_loss: Optional[torch.Tensor],
        actor_grad_norm: Optional[torch.Tensor],
        critic_grad_norm: torch.Tensor,
    ) -> Dict[str, float]:
        """Convert the update's tensors to floats via ``.item()``.

        Arguments that are ``None`` are left out of the returned dict;
        ``q_loss``, ``alpha`` and ``critic_grad_norm`` are always present.
        """
        info_dict = {}
        if policy_loss is not None:
            info_dict["policy_loss"] = policy_loss.item()

        info_dict["q_loss"] = q_loss.item()
        if mpc_loss is not None:
            info_dict["mpc_loss"] = mpc_loss.item()
        info_dict["alpha"] = alpha

        if alpha_loss is not None:
            info_dict["alpha_loss"] = alpha_loss.item()
        if actor_grad_norm is not None:
            info_dict["actor_grad_norm"] = actor_grad_norm.item()

        info_dict["critic_grad_norm"] = critic_grad_norm.item()
        return info_dict




class SACMPCHydra:
    def __init__(
        self,
        batch_size: int,
        target_entropy: torch.Tensor,  # -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item() TODO: CORRECT with float
        discount: float,
        tau: float,
        reparametrized_sample: bool,
        device: torch.device = torch.device("cuda"),
        temperature_lr: float = 3e-4,  # TODO: discussion
        reward_scale: int = 1,
        max_grad_norm: float = 0.5,
        policy_frequency: int = 2,
        mpc_weight: float = 0.1,
        mpc_time_step: int = 4,
    ):
        """Soft Actor Critic with a next-state prediction loss and an extra "no reg" critic loss.

        On top of the entropy-regularised critic update, each call also fits the
        agent's ``q_no_reg_value_*`` outputs to a target without the entropy term.

        Args:
            batch_size (int): Batch size of the optimization process
            target_entropy (torch.Tensor): Target entropy for the adjusted temperature
            discount (float): Discount factor applied to the bootstrapped next-state value
            tau (float): Value of the soft update
            reparametrized_sample (bool): If Normal distribution is used you should set this value to True.
                Stored on the instance; the update in this class always calls ``rsample()``.
            device (torch.device, optional): Device which runs the optimization. Defaults to torch.device("cuda").
            temperature_lr (float, optional): Learning rate for adjusted temperature. Defaults to 3e-4.
            reward_scale (int, optional): Accepted but not used by this class. Defaults to 1.
            max_grad_norm (float, optional): Max norm passed to ``clip_grad_norm_``. Defaults to 0.5.
            policy_frequency (int, optional): The actor/temperature are updated every
                ``policy_frequency`` calls, ``policy_frequency`` times in a row. Defaults to 2.
            mpc_weight (float, optional): Stored on the instance; not read by the methods of this class. Defaults to 0.1.
            mpc_time_step (int, optional): Stored on the instance; not read by the methods of this class. Defaults to 4.

        """
        self.batch_size = batch_size
        self.target_entropy = target_entropy
        # The temperature is optimised in log-space; `alpha` caches exp(log_alpha) as a float.
        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=temperature_lr, eps=1e-7)
        self.alpha = self.log_alpha.exp().item()
        self.discount = discount
        self.tau = tau
        self.reparametrized_sample = reparametrized_sample
        self.max_grad_norm = max_grad_norm
        self.policy_frequency = policy_frequency
        # Number of the update call about to run (starts at 1).
        self.step = 1
        self.mpc_weight = mpc_weight
        self.mpc_time_step = mpc_time_step

    def update(
        self,
        replay: OffPolicyMemory,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_actor: optim.Optimizer,
        optimizer_critic: optim.Optimizer,
        optimizer_mpc: optim.Optimizer,
        optimizer_critic_no_reg: optim.Optimizer,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
    ) -> Dict[str, float]:
        """Update with one gradient step.

        Order of operations: critic step, next-state prediction ("MPC") step,
        "no reg" critic step, delayed actor/temperature steps, then
        ``agent.soft_update``.

        Args:
            replay (OffPolicyMemory): Replay buffer
            agent (IAgent[EnvObservation, EnvAction, Action]): Implemented agent (sac candidate)
            optimizer_actor (optim.Optimizer): Actor's optimizer
            optimizer_critic (optim.Optimizer): Critic's optimizer
            optimizer_mpc (optim.Optimizer): Optimizer stepped after the next-state prediction loss
            optimizer_critic_no_reg (optim.Optimizer): Optimizer stepped after the "no reg" critic loss
            scheduler (Optional[optim.lr_scheduler._LRScheduler], optional): Learning rate scheduler.
                Accepted but not used by this method. Defaults to None.

        Returns:
            Dict[str, float]: Information on loss
        """
        batch = replay.sample(self.batch_size)

        # Bootstrapped critic target; computed without gradient tracking.
        with torch.no_grad():
            next_mass = agent.action(observation=batch.next_state())  # type: ignore
            next_state_actions = next_mass.sample()
            next_state_log_pi = next_mass.log_prob(action=next_state_actions)

            q_next_target = agent.target_q_value(
                observation=batch.next_state(), action=next_state_actions  # type: ignore
            )

            # TODO: Real LOSS RAVI
            # Element-wise minimum of the two target Q-values minus the entropy term.
            min_qf_next_target = (
                torch.min(*q_next_target) - self.alpha * next_state_log_pi
            )
            next_q_value = batch.reward() + (1 - batch.done()) * self.discount * (
                min_qf_next_target
            )
        # NOTE: other potential refactors
        # 1. using qf1() instead, which calls forward() by default
        # 2. use F.mse_loss() instead of loss_fn = nn.MSELoss()
        qf1_a_values, qf2_a_values = agent.q_value(
            observation=batch.state(), action=batch.action()  # type: ignore
        )
        qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
        qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
        # The critic loss is the mean of the two MSE losses.
        qf_loss = (qf1_loss + qf2_loss) / 2

        optimizer_critic.zero_grad()
        qf_loss.backward()
        qf_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_critic.step()

        # MODEL PREDICTIVE CONTROL

        # Author's note recording the error hit here:
        # RuntimeError: Trying to backward through the graph a second time (or directly access
        # saved tensors after they have already been freed). Saved intermediate values of
        # the graph are freed when you call .backward() or autograd.grad().
        # Specify retain_graph=True if you need to backward through the graph a second time
        # or if you need to access saved tensors after calling backward.
        # NOTE(review): the Q forward pass is repeated before reading agent.infos(),
        # presumably to rebuild the graph freed by qf_loss.backward() and to refresh
        # the predictions stored by the agent - confirm against the agent implementation.
        # The returned Q-values themselves are not used below.
        qf1_a_values, qf2_a_values = agent.q_value(
            observation=batch.state(), action=batch.action()  # type: ignore
        )
        infos_mpc = agent.infos()
        predicted_next_state_1 = infos_mpc["predicted_next_state_1"]
        predicted_next_state_2 = infos_mpc["predicted_next_state_2"]

        # Mean of the two MSE losses between predicted and sampled next states.
        mpc_loss_1 = F.mse_loss(predicted_next_state_1, batch.next_state())
        mpc_loss_2 = F.mse_loss(predicted_next_state_2, batch.next_state())
        mpc_loss = (mpc_loss_1 + mpc_loss_2) / 2
        optimizer_mpc.zero_grad()
        mpc_loss.backward()
        # Clipping is applied to agent.parameters_critic(); the resulting norm is
        # not included in the returned info.
        mpc_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_mpc.step()

        # CRITIC NO REG

        # Both forward passes are run only for their effect on agent.infos();
        # their return values are discarded. `next_state_actions` is the sample
        # drawn at the top of this method.
        # NOTE(review): target_q_value is called outside torch.no_grad() here;
        # the target built from it below is computed under no_grad and detached.
        _, _ = agent.q_value(observation=batch.state(), action=batch.action())  # type: ignore

        _ = agent.target_q_value(
            observation=batch.next_state(), action=next_state_actions  # type: ignore
        )
        infos_mpc = agent.infos()
        q_no_reg_value_1 = infos_mpc["q_no_reg_value_1"]
        q_no_reg_value_2 = infos_mpc["q_no_reg_value_2"]

        target_q_no_reg_value_1 = infos_mpc["target_q_no_reg_value_1"]
        target_q_no_reg_value_2 = infos_mpc["target_q_no_reg_value_2"]
        with torch.no_grad():
            # Target without the entropy term; this rebinds `next_q_value`.
            min_qf_no_reg_next_target = torch.min(
                target_q_no_reg_value_1, target_q_no_reg_value_2
            )
            next_q_value = (
                batch.reward()
                + (1 - batch.done())
                * self.discount
                * (min_qf_no_reg_next_target).detach()
            )

        q_no_reg_loss_1 = F.mse_loss(q_no_reg_value_1, next_q_value)
        q_no_reg_loss_2 = F.mse_loss(q_no_reg_value_2, next_q_value)
        q_no_reg_loss = (q_no_reg_loss_1 + q_no_reg_loss_2) / 2
        optimizer_critic_no_reg.zero_grad()
        q_no_reg_loss.backward()
        # Clipping is applied to agent.parameters_critic() here as well.
        q_no_reg_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_critic_no_reg.step()

        # Delayed policy update: runs only on calls where step is a multiple of
        # policy_frequency, and then repeats policy_frequency times on the same batch.
        if self.step % self.policy_frequency == 0:
            # Policy
            for _ in range(self.policy_frequency):
                mass = agent.action(observation=batch.state())  # type: ignore
                pi = mass.rsample()
                log_pi = mass.log_prob(action=pi)
                qf1_pi, qf2_pi = agent.q_value(observation=batch.state(), action=pi)  # type: ignore
                min_qf_pi = torch.min(qf1_pi, qf2_pi)
                policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

                optimizer_actor.zero_grad()
                policy_loss.backward()
                actor_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
                    agent.parameters_actor(), self.max_grad_norm
                )
                optimizer_actor.step()
                # alpha
                # The log-prob is re-sampled after the actor step, without gradient
                # tracking, so only log_alpha receives a gradient from alpha_loss.
                with torch.no_grad():
                    mass = agent.action(observation=batch.state())  # type: ignore
                    pi = mass.sample()
                    log_pi = mass.log_prob(action=pi)

                alpha_loss = (-self.log_alpha * (log_pi + self.target_entropy)).mean()

                self.alpha_optim.zero_grad()
                alpha_loss.backward()
                self.alpha_optim.step()
                # Refresh the cached float used by the losses above.
                self.alpha = self.log_alpha.exp().item()
        else:
            # No actor/temperature step on this call: these entries are omitted from the info.
            policy_loss = None
            actor_grad_norm = None
            alpha_loss = None

        # Add cond soft update
        agent.soft_update(tau=self.tau)  # soft update for q_networks

        info = self._update_info_loss(
            policy_loss=policy_loss,
            q_loss=qf_loss,
            mpc_loss=mpc_loss,
            q_no_reg_grad_norm=q_no_reg_grad_norm,
            q_no_reg_loss=q_no_reg_loss,
            alpha=self.alpha,
            alpha_loss=alpha_loss,
            actor_grad_norm=actor_grad_norm,
            critic_grad_norm=qf_grad_norm,
        )
        self.step += 1
        return info

    @classmethod
    def _update_info_loss(
        cls,
        policy_loss: Optional[torch.Tensor],
        q_loss: torch.Tensor,
        mpc_loss: Optional[torch.Tensor],
        q_no_reg_loss: Optional[torch.Tensor],
        alpha: float,
        alpha_loss: Optional[torch.Tensor],
        actor_grad_norm: Optional[torch.Tensor],
        critic_grad_norm: torch.Tensor,
        q_no_reg_grad_norm: Optional[torch.Tensor],
    ) -> Dict[str, float]:
        """Convert the update's tensors to floats via ``.item()``.

        Arguments that are ``None`` are left out of the returned dict;
        ``q_loss``, ``alpha`` and ``critic_grad_norm`` are always present.
        """
        info_dict = {}
        if policy_loss is not None:
            info_dict["policy_loss"] = policy_loss.item()

        info_dict["q_loss"] = q_loss.item()
        if q_no_reg_loss is not None:
            info_dict["q_no_reg_loss"] = q_no_reg_loss.item()
        if mpc_loss is not None:
            info_dict["mpc_loss"] = mpc_loss.item()
        info_dict["alpha"] = alpha

        if alpha_loss is not None:
            info_dict["alpha_loss"] = alpha_loss.item()
        if actor_grad_norm is not None:
            info_dict["actor_grad_norm"] = actor_grad_norm.item()

        info_dict["critic_grad_norm"] = critic_grad_norm.item()
        if q_no_reg_grad_norm is not None:
            info_dict["q_no_reg_grad_norm"] = q_no_reg_grad_norm.item()
        return info_dict


class RAVI(SAC):
    """SAC variant whose TD target is capped by the agent's "robust" target Q-values."""

    def update(
        self,
        replay: OffPolicyMemory,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_actor: optim.Optimizer,
        optimizer_critic: optim.Optimizer,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
    ) -> Dict[str, float]:
        """Run one critic step and, every ``policy_frequency`` calls, the actor/temperature steps.

        Args:
            replay (OffPolicyMemory): Replay buffer the batch is sampled from
            agent (IAgent[EnvObservation, EnvAction, Action]): Agent providing the policy and the Q-networks
            optimizer_actor (optim.Optimizer): Optimizer stepped for the actor
            optimizer_critic (optim.Optimizer): Optimizer stepped for the critic
            scheduler (Optional[optim.lr_scheduler._LRScheduler], optional): Accepted but not used
                by this method. Defaults to None.

        Returns:
            Dict[str, float]: Scalars assembled by ``_update_info_loss``
        """
        samples = replay.sample(self.batch_size)

        # --- Critic target, built without tracking gradients ---
        with torch.no_grad():
            next_dist = agent.action(observation=samples.next_state())  # type: ignore
            next_actions = next_dist.sample()
            next_log_prob = next_dist.log_prob(action=next_actions)

            # The agent returns the two regular targets followed by the two "robust" ones.
            target_q1, target_q2, robust_q1, robust_q2 = agent.target_q_value(
                observation=samples.next_state(), action=next_actions  # type: ignore
            )

            # Entropy-regularised value of the next state (clipped double-Q).
            soft_next_value = torch.min(target_q1, target_q2) - self.alpha * next_log_prob
            bellman_target = (
                samples.reward() + (1 - samples.done()) * self.discount * soft_next_value
            )
            # The TD target never exceeds the smaller of the two robust Q-values.
            robust_cap = torch.min(robust_q1, robust_q2)
            critic_target = torch.min(bellman_target, robust_cap)

        # --- Critic step ---
        q1_pred, q2_pred = agent.q_value(
            observation=samples.state(), action=samples.action()  # type: ignore
        )
        q1_error = F.mse_loss(q1_pred, critic_target)
        q2_error = F.mse_loss(q2_pred, critic_target)
        critic_loss = (q1_error + q2_error) / 2

        optimizer_critic.zero_grad()
        critic_loss.backward()
        critic_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_critic.step()

        # --- Delayed actor / temperature step ---
        if self.step % self.policy_frequency != 0:
            # Nothing to report for the actor or the temperature on this call.
            actor_loss = None
            actor_grad_norm = None
            temperature_loss = None
        else:
            # The step is repeated ``policy_frequency`` times on the same batch.
            for _ in range(self.policy_frequency):
                dist = agent.action(observation=samples.state())  # type: ignore
                actions = dist.rsample()
                log_prob = dist.log_prob(action=actions)
                q1_pi, q2_pi = agent.q_value(observation=samples.state(), action=actions)  # type: ignore
                actor_loss = (self.alpha * log_prob - torch.min(q1_pi, q2_pi)).mean()

                optimizer_actor.zero_grad()
                actor_loss.backward()
                actor_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
                    agent.parameters_actor(), self.max_grad_norm
                )
                optimizer_actor.step()

                # Temperature: log-probabilities are re-evaluated with the freshly
                # updated actor, without gradients flowing into it.
                with torch.no_grad():
                    dist = agent.action(observation=samples.state())  # type: ignore
                    actions = dist.sample()
                    log_prob = dist.log_prob(action=actions)

                temperature_loss = (
                    -self.log_alpha * (log_prob + self.target_entropy)
                ).mean()

                self.alpha_optim.zero_grad()
                temperature_loss.backward()
                self.alpha_optim.step()
                self.alpha = self.log_alpha.exp().item()

        # Soft update of the Q-networks happens on every call.
        agent.soft_update(tau=self.tau)

        summary = self._update_info_loss(
            policy_loss=actor_loss,
            q_loss=critic_loss,
            alpha=self.alpha,
            alpha_loss=temperature_loss,
            actor_grad_norm=actor_grad_norm,
            critic_grad_norm=critic_grad_norm,
        )
        self.step += 1
        return summary


class LifeLongRAVI(SAC):
    """SAC variant whose TD target is capped by the minimum over a group of "robust" Q-values."""

    def update(
        self,
        replay: OffPolicyMemory,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_actor: optim.Optimizer,
        optimizer_critic: optim.Optimizer,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
    ) -> Dict[str, float]:
        """Run one critic step and, every ``policy_frequency`` calls, the actor/temperature steps.

        Args:
            replay (OffPolicyMemory): Replay buffer the batch is sampled from
            agent (IAgent[EnvObservation, EnvAction, Action]): Agent providing the policy and the Q-networks
            optimizer_actor (optim.Optimizer): Optimizer stepped for the actor
            optimizer_critic (optim.Optimizer): Optimizer stepped for the critic
            scheduler (Optional[optim.lr_scheduler._LRScheduler], optional): Accepted but not used
                by this method. Defaults to None.

        Returns:
            Dict[str, float]: Scalars assembled by ``_update_info_loss``
        """
        samples = replay.sample(self.batch_size)

        # --- Critic target, built without tracking gradients ---
        with torch.no_grad():
            next_dist = agent.action(observation=samples.next_state())  # type: ignore
            next_actions = next_dist.sample()
            next_log_prob = next_dist.log_prob(action=next_actions)

            q_bundle = agent.all_target_q_value(  # type: ignore
                observation=samples.next_state(), action=next_actions
            )
            target_q1, target_q2 = q_bundle["target"]
            # Stack every robust estimate on a trailing axis and keep the smallest one.
            robust_stack = torch.stack(q_bundle["robust"], dim=-1)
            robust_cap = robust_stack.min(dim=-1).values

            # Entropy-regularised value of the next state (clipped double-Q).
            soft_next_value = torch.min(target_q1, target_q2) - self.alpha * next_log_prob
            bellman_target = (
                samples.reward() + (1 - samples.done()) * self.discount * soft_next_value
            )
            # The TD target never exceeds the smallest robust Q-value.
            critic_target = torch.min(bellman_target, robust_cap)

        # --- Critic step ---
        q1_pred, q2_pred = agent.q_value(
            observation=samples.state(), action=samples.action()  # type: ignore
        )
        q1_error = F.mse_loss(q1_pred, critic_target)
        q2_error = F.mse_loss(q2_pred, critic_target)
        critic_loss = (q1_error + q2_error) / 2

        optimizer_critic.zero_grad()
        critic_loss.backward()
        critic_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_critic.step()

        # --- Delayed actor / temperature step ---
        if self.step % self.policy_frequency != 0:
            # Nothing to report for the actor or the temperature on this call.
            actor_loss = None
            actor_grad_norm = None
            temperature_loss = None
        else:
            # The step is repeated ``policy_frequency`` times on the same batch.
            for _ in range(self.policy_frequency):
                dist = agent.action(observation=samples.state())  # type: ignore
                actions = dist.rsample()
                log_prob = dist.log_prob(action=actions)
                q1_pi, q2_pi = agent.q_value(observation=samples.state(), action=actions)  # type: ignore
                actor_loss = (self.alpha * log_prob - torch.min(q1_pi, q2_pi)).mean()

                optimizer_actor.zero_grad()
                actor_loss.backward()
                actor_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
                    agent.parameters_actor(), self.max_grad_norm
                )
                optimizer_actor.step()

                # Temperature: log-probabilities are re-evaluated with the freshly
                # updated actor, without gradients flowing into it.
                with torch.no_grad():
                    dist = agent.action(observation=samples.state())  # type: ignore
                    actions = dist.sample()
                    log_prob = dist.log_prob(action=actions)

                temperature_loss = (
                    -self.log_alpha * (log_prob + self.target_entropy)
                ).mean()

                self.alpha_optim.zero_grad()
                temperature_loss.backward()
                self.alpha_optim.step()
                self.alpha = self.log_alpha.exp().item()

        # Soft update of the Q-networks happens on every call.
        agent.soft_update(tau=self.tau)

        summary = self._update_info_loss(
            policy_loss=actor_loss,
            q_loss=critic_loss,
            alpha=self.alpha,
            alpha_loss=temperature_loss,
            actor_grad_norm=actor_grad_norm,
            critic_grad_norm=critic_grad_norm,
        )
        self.step += 1
        return summary


class ImitationLearningPolicy:
    """Actor-only update driven by the smallest of the agent's "robust" target Q-values."""

    def __init__(
        self,
        base_temperature: float = 1.0,
        batch_size: int = 256,
        max_grad_norm: float = 0.5,
    ) -> None:
        """Store the hyper-parameters of the imitation step.

        Args:
            base_temperature (float, optional): Weight of the log-probability term in the loss. Defaults to 1.0.
            batch_size (int, optional): Number of transitions sampled per update. Defaults to 256.
            max_grad_norm (float, optional): Clipping threshold for the actor gradient norm. Defaults to 0.5.
        """
        self.max_grad_norm = max_grad_norm
        self.batch_size = batch_size
        self.base_temperature = base_temperature

    def update(
        self,
        replay: OffPolicyMemory,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_actor: optim.Optimizer,
        optimizer_critic: Optional[optim.Optimizer] = None,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
    ):
        """Take one gradient step on the actor.

        Args:
            replay (OffPolicyMemory): Replay buffer the batch is sampled from
            agent (IAgent[EnvObservation, EnvAction, Action]): Agent providing the policy and the robust Q-values
            optimizer_actor (optim.Optimizer): Optimizer stepped for the actor
            optimizer_critic (Optional[optim.Optimizer], optional): Accepted but not used. Defaults to None.
            scheduler (Optional[optim.lr_scheduler._LRScheduler], optional): Accepted but not used. Defaults to None.

        Returns:
            dict: The policy loss and the actor gradient norm as Python floats
        """
        samples = replay.sample(self.batch_size)

        # Reparameterised sample so the loss can be differentiated through the action.
        dist = agent.action(observation=samples.state())  # type: ignore
        actions = dist.rsample()
        log_prob = dist.log_prob(action=actions)
        q_bundle = agent.all_target_q_value(observation=samples.state(), action=actions)  # type: ignore

        # Stack every robust estimate on a trailing axis and keep the smallest one.
        robust_values: Tuple[torch.Tensor] = q_bundle["robust"]
        pessimistic_q = torch.stack(robust_values, dim=-1).min(dim=-1).values

        loss = (self.base_temperature * log_prob - pessimistic_q).mean()

        optimizer_actor.zero_grad()
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_actor(), self.max_grad_norm
        )
        optimizer_actor.step()

        return {
            "imitation_policy_loss": loss.item(),
            "imimation_grad_norm": grad_norm.item(),
        }


class ImitationLearningMPCPolicy:
    """Imitation update of critic and actor, gated by a next-state prediction check.

    A sample is kept for the critic regression only when at least one entry of
    ``agent.infos()["predicted_next_state_robust"]`` is, on average, closer than
    ``mpc_threshold`` to the observed next state. When too many samples fail
    that check the whole update is skipped ("withdrawn").
    """

    def __init__(
        self,
        base_temperature: float = 1.0,
        batch_size: int = 256,
        max_grad_norm: float = 0.5,
        mpc_threshold: float = 0.5,
        minimum_withdraw: Optional[int] = None,
    ) -> None:
        """Store the hyper-parameters of the imitation step.

        Args:
            base_temperature (float, optional): Weight of the log-probability term in the actor loss. Defaults to 1.0.
            batch_size (int, optional): Number of transitions sampled per update. Defaults to 256.
            max_grad_norm (float, optional): Base clipping threshold for gradient norms. Defaults to 0.5.
            mpc_threshold (float, optional): Prediction error below which a prediction counts as in-distribution. Defaults to 0.5.
            minimum_withdraw (Optional[int], optional): Number of rejected samples from which the
                update is skipped. ``None`` skips only when every sample is rejected. Defaults to None.
        """
        self.base_temperature = base_temperature
        self.batch_size = batch_size
        self.max_grad_norm = max_grad_norm
        self.mpc_threshold = mpc_threshold
        self.minimum_withdraw = minimum_withdraw

    def update(
        self,
        replay: OffPolicyMemory,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_actor: optim.Optimizer,
        optimizer_critic: optim.Optimizer,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
    ) -> Dict[str, float]:
        """Regress the critic on the robust target, then take one actor step.

        Args:
            replay (OffPolicyMemory): Replay buffer the batch is sampled from
            agent (IAgent[EnvObservation, EnvAction, Action]): Agent providing policy, Q-networks and predictions
            optimizer_actor (optim.Optimizer): Optimizer stepped for the actor
            optimizer_critic (optim.Optimizer): Optimizer stepped for the critic
            scheduler (Optional[optim.lr_scheduler._LRScheduler], optional): Accepted but not used. Defaults to None.

        Returns:
            Dict[str, float]: ``"batch_is_withdraw"`` always; losses and gradient
            norms only when the update was not withdrawn.
        """
        batch = replay.sample(self.batch_size)

        # NOTE(review): agent.infos() is read right after these two calls, so they
        # presumably populate "predicted_next_state_robust" -- keep this call order
        # unless the agent implementation confirms otherwise.
        target_values_critic = agent.all_target_q_value(  # type: ignore
            observation=batch.state(), action=batch.action()
        )
        q_values_1_critic, q_values_2_critic = agent.q_value(
            observation=batch.state(), action=batch.action()  # type: ignore
        )

        robust_q_value: Tuple[torch.Tensor, ...] = target_values_critic["robust"]
        robust_q_value_group = torch.stack(robust_q_value, dim=-1)

        mpc_values = agent.infos()["predicted_next_state_robust"]
        mpc_values_group = torch.stack(mpc_values, dim=-1)
        (
            imitation_target,
            reconstruction_to_keep,
            all_batch_is_withdraw,
        ) = self.__determine_target_value_withdraw(
            next_state=batch.next_state(),
            robust_value_group=robust_q_value_group,
            mpc_value_group=mpc_values_group,
            min_before_withdraw=self.minimum_withdraw,
        )

        info = {"batch_is_withdraw": float(all_batch_is_withdraw)}
        if all_batch_is_withdraw:
            # Too many rejected samples: leave both networks untouched.
            return info

        # --- Critic step, restricted to the accepted samples ---
        # NOTE(review): the target is not detached here; if the robust Q-values
        # carry gradients they will receive some from this loss -- confirm intended.
        critic_loss_1 = F.mse_loss(
            q_values_1_critic[reconstruction_to_keep],
            imitation_target[reconstruction_to_keep],
        )
        critic_loss_2 = F.mse_loss(
            q_values_2_critic[reconstruction_to_keep],
            imitation_target[reconstruction_to_keep],
        )
        critic_loss = (critic_loss_1 + critic_loss_2) / 2

        optimizer_critic.zero_grad()
        critic_loss.backward()
        # NOTE(review): the critic clip threshold is scaled by 1000 here, whereas
        # ImitationLearningMPCHydraPolicy uses max_grad_norm unscaled -- confirm.
        critic_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm * 1000
        )
        optimizer_critic.step()

        info["imitation_critic_loss"] = critic_loss.item()
        info["imimation_grad_norm_critic"] = critic_grad_norm.item()

        # --- Actor step against the freshly updated critic ---
        mass = agent.action(observation=batch.state())  # type: ignore
        pi = mass.rsample()
        log_pi = mass.log_prob(action=pi)
        q_values_1_critic_imitation, q_values_2_critic_imitation = agent.q_value(
            observation=batch.state(), action=pi  # type: ignore
        )
        imitation_target_policy = torch.min(
            q_values_1_critic_imitation, q_values_2_critic_imitation
        )
        policy_loss = (
            (self.base_temperature * log_pi) - imitation_target_policy
        ).mean()

        optimizer_actor.zero_grad()
        policy_loss.backward()
        actor_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_actor(), self.max_grad_norm
        )
        optimizer_actor.step()

        info["imitation_policy_loss"] = policy_loss.item()
        info["imimation_grad_norm"] = actor_grad_norm.item()
        return info

    def __determine_target_value(
        self, robust_value_group: torch.Tensor, mpc_value_group: torch.Tensor
    ):
        """Pick one imitation target per sample from per-expert Q-values.

        Currently not called by ``update``.

        For each row, experts whose error in ``mpc_value_group`` is below
        ``self.mpc_threshold`` are considered in-distribution:

        * if at least one expert is in-distribution, the target is the smallest
          Q-value among the in-distribution experts;
        * otherwise the target is the Q-value of the expert with the smallest
          error.

        Args:
            robust_value_group (torch.Tensor): Q-values of shape ``(B, E)``
            mpc_value_group (torch.Tensor): Errors of shape ``(B, E)``

        Returns:
            torch.Tensor: Targets of shape ``(B, 1)`` on the inputs' device
        """
        in_distribution = mpc_value_group < self.mpc_threshold
        every_expert_ood = ~in_distribution.any(dim=-1, keepdim=True)

        # Fallback for rows where no expert is trusted: the least-wrong expert.
        closest_expert = mpc_value_group.argmin(dim=-1, keepdim=True)
        fallback_target = torch.gather(robust_value_group, dim=-1, index=closest_expert)

        # Regular case: untrusted experts are pushed to +max so they never win the min.
        trusted_values = robust_value_group.masked_fill(
            ~in_distribution, torch.finfo(robust_value_group.dtype).max
        )
        pessimistic_target = trusted_values.min(dim=-1, keepdim=True).values

        return torch.where(every_expert_ood, fallback_target, pessimistic_target)

    def __determine_target_value_withdraw(
        self,
        next_state: torch.Tensor,
        robust_value_group: torch.Tensor,
        mpc_value_group: torch.Tensor,
        min_before_withdraw: Optional[int] = None,
    ):
        """Build the critic target, the keep-mask and the withdraw decision.

        Args:
            next_state (torch.Tensor): Observed next states, ``(B, F)``
            robust_value_group (torch.Tensor): Robust Q-values stacked on the last
                axis; the first axis is the batch. Not modified.
            mpc_value_group (torch.Tensor): Predicted next states, ``(B, F, E)``
            min_before_withdraw (Optional[int], optional): Number of rejected samples
                from which the update is withdrawn. Falls back to
                ``self.minimum_withdraw``; if that is ``None`` too, the update is
                withdrawn only when every sample is rejected. Defaults to None.

        Returns:
            tuple: ``(target, keep_mask, withdraw)`` where ``target`` is the minimum
            of ``robust_value_group`` over its last axis (rejected rows hold the
            dtype's max and must be filtered with ``keep_mask``), ``keep_mask`` is a
            ``(B, 1)`` boolean tensor and ``withdraw`` is a Python ``bool``.
        """
        # Unpacking enforces a 3-D prediction tensor.
        batch_size, _, _ = mpc_value_group.size()

        # Mean absolute prediction error per sample and per expert: (B, E).
        prediction_error = torch.abs(next_state.unsqueeze(dim=-1) - mpc_value_group).mean(dim=1)

        in_distribution = prediction_error < self.mpc_threshold
        trusted_count = in_distribution.sum(dim=-1, keepdim=True)
        sample_is_ood = trusted_count == 0
        sample_to_keep = trusted_count != 0

        # Work on a copy so the caller's tensor is left untouched.
        ood_rows = torch.where(sample_is_ood)[0]
        masked_values = robust_value_group.clone()
        masked_values[ood_rows] = torch.finfo(robust_value_group.dtype).max
        target = masked_values.min(dim=-1).values

        threshold = (
            min_before_withdraw if min_before_withdraw is not None else self.minimum_withdraw
        )
        if threshold is None:
            threshold = batch_size

        ood_count = int(sample_is_ood.sum().item())
        # A fully rejected batch is always withdrawn: an empty selection would
        # otherwise produce a NaN loss.
        withdraw = ood_count >= threshold or ood_count == batch_size

        return target, sample_to_keep, withdraw


class ImitationLearningMPCHydraPolicy:
    """Imitation update of critic and actor using the agent's ``"robust_no_reg"`` Q-values.

    A sample is kept for the critic regression only when at least one entry of
    ``agent.infos()["predicted_next_state_robust"]`` is, on average, closer than
    ``mpc_threshold`` to the observed next state. When too many samples fail
    that check the whole update is skipped ("withdrawn").
    """

    def __init__(
        self,
        base_temperature: float = 1.0,
        batch_size: int = 256,
        max_grad_norm: float = 0.5,
        mpc_threshold: float = 0.5,
        minimum_withdraw: Optional[int] = None,
    ) -> None:
        """Store the hyper-parameters of the imitation step.

        Args:
            base_temperature (float, optional): Weight of the log-probability term in the actor loss. Defaults to 1.0.
            batch_size (int, optional): Number of transitions sampled per update. Defaults to 256.
            max_grad_norm (float, optional): Clipping threshold for gradient norms. Defaults to 0.5.
            mpc_threshold (float, optional): Prediction error below which a prediction counts as in-distribution. Defaults to 0.5.
            minimum_withdraw (Optional[int], optional): Number of rejected samples from which the
                update is skipped. ``None`` skips only when every sample is rejected. Defaults to None.
        """
        self.base_temperature = base_temperature
        self.batch_size = batch_size
        self.max_grad_norm = max_grad_norm
        self.mpc_threshold = mpc_threshold
        self.minimum_withdraw = minimum_withdraw

    def update(
        self,
        replay: OffPolicyMemory,
        agent: IAgent[EnvObservation, EnvAction, Action],
        optimizer_actor: optim.Optimizer,
        optimizer_critic: optim.Optimizer,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
    ) -> Dict[str, float]:
        """Regress the critic on the robust target, then take one actor step.

        Args:
            replay (OffPolicyMemory): Replay buffer the batch is sampled from
            agent (IAgent[EnvObservation, EnvAction, Action]): Agent providing policy, Q-networks and predictions
            optimizer_actor (optim.Optimizer): Optimizer stepped for the actor
            optimizer_critic (optim.Optimizer): Optimizer stepped for the critic
            scheduler (Optional[optim.lr_scheduler._LRScheduler], optional): Accepted but not used. Defaults to None.

        Returns:
            Dict[str, float]: ``"batch_is_withdraw"`` always; losses and gradient
            norms only when the update was not withdrawn.
        """
        batch = replay.sample(self.batch_size)

        # NOTE(review): agent.infos() is read right after these two calls, so they
        # presumably populate "predicted_next_state_robust" -- keep this call order
        # unless the agent implementation confirms otherwise.
        target_values_critic = agent.all_target_q_value(  # type: ignore
            observation=batch.state(), action=batch.action()
        )
        q_values_1_critic, q_values_2_critic = agent.q_value(
            observation=batch.state(), action=batch.action()  # type: ignore
        )

        robust_q_value: Tuple[torch.Tensor, ...] = target_values_critic["robust_no_reg"]
        robust_q_value_group = torch.stack(robust_q_value, dim=-1)

        mpc_values = agent.infos()["predicted_next_state_robust"]
        mpc_values_group = torch.stack(mpc_values, dim=-1)
        (
            imitation_target,
            reconstruction_to_keep,
            all_batch_is_withdraw,
        ) = self.__determine_target_value_withdraw(
            next_state=batch.next_state(),
            robust_value_group=robust_q_value_group,
            mpc_value_group=mpc_values_group,
            min_before_withdraw=self.minimum_withdraw,
        )

        info = {"batch_is_withdraw": float(all_batch_is_withdraw)}
        if all_batch_is_withdraw:
            # Too many rejected samples: leave both networks untouched.
            return info

        # --- Critic step, restricted to the accepted samples ---
        # NOTE(review): the target is not detached here; if the robust Q-values
        # carry gradients they will receive some from this loss -- confirm intended.
        critic_loss_1 = F.mse_loss(
            q_values_1_critic[reconstruction_to_keep],
            imitation_target[reconstruction_to_keep],
        )
        critic_loss_2 = F.mse_loss(
            q_values_2_critic[reconstruction_to_keep],
            imitation_target[reconstruction_to_keep],
        )
        critic_loss = (critic_loss_1 + critic_loss_2) / 2

        optimizer_critic.zero_grad()
        critic_loss.backward()
        critic_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_critic(), self.max_grad_norm
        )
        optimizer_critic.step()

        info["imitation_critic_loss"] = critic_loss.item()
        info["imimation_grad_norm_critic"] = critic_grad_norm.item()

        # --- Actor step against the freshly updated critic ---
        mass = agent.action(observation=batch.state())  # type: ignore
        pi = mass.rsample()
        log_pi = mass.log_prob(action=pi)
        q_values_1_critic_imitation, q_values_2_critic_imitation = agent.q_value(
            observation=batch.state(), action=pi  # type: ignore
        )
        imitation_target_policy = torch.min(
            q_values_1_critic_imitation, q_values_2_critic_imitation
        )
        policy_loss = (
            (self.base_temperature * log_pi) - imitation_target_policy
        ).mean()

        optimizer_actor.zero_grad()
        policy_loss.backward()
        actor_grad_norm = nn.utils.clip_grad_norm_(  # type: ignore
            agent.parameters_actor(), self.max_grad_norm
        )
        optimizer_actor.step()

        info["imitation_policy_loss"] = policy_loss.item()
        info["imimation_grad_norm"] = actor_grad_norm.item()
        return info

    def __determine_target_value_withdraw(
        self,
        next_state: torch.Tensor,
        robust_value_group: torch.Tensor,
        mpc_value_group: torch.Tensor,
        min_before_withdraw: Optional[int] = None,
    ):
        """Build the critic target, the keep-mask and the withdraw decision.

        Args:
            next_state (torch.Tensor): Observed next states, ``(B, F)``
            robust_value_group (torch.Tensor): Robust Q-values stacked on the last
                axis; the first axis is the batch. Not modified.
            mpc_value_group (torch.Tensor): Predicted next states, ``(B, F, E)``
            min_before_withdraw (Optional[int], optional): Number of rejected samples
                from which the update is withdrawn. Falls back to
                ``self.minimum_withdraw``; if that is ``None`` too, the update is
                withdrawn only when every sample is rejected. Defaults to None.

        Returns:
            tuple: ``(target, keep_mask, withdraw)`` where ``target`` is the minimum
            of ``robust_value_group`` over its last axis (rejected rows hold the
            dtype's max and must be filtered with ``keep_mask``), ``keep_mask`` is a
            ``(B, 1)`` boolean tensor and ``withdraw`` is a Python ``bool``.
        """
        # Unpacking enforces a 3-D prediction tensor.
        batch_size, _, _ = mpc_value_group.size()

        # Mean absolute prediction error per sample and per expert: (B, E).
        prediction_error = torch.abs(next_state.unsqueeze(dim=-1) - mpc_value_group).mean(dim=1)

        in_distribution = prediction_error < self.mpc_threshold
        trusted_count = in_distribution.sum(dim=-1, keepdim=True)
        sample_is_ood = trusted_count == 0
        sample_to_keep = trusted_count != 0

        # Work on a copy so the caller's tensor is left untouched.
        ood_rows = torch.where(sample_is_ood)[0]
        masked_values = robust_value_group.clone()
        masked_values[ood_rows] = torch.finfo(robust_value_group.dtype).max
        target = masked_values.min(dim=-1).values

        threshold = (
            min_before_withdraw if min_before_withdraw is not None else self.minimum_withdraw
        )
        if threshold is None:
            threshold = batch_size

        ood_count = int(sample_is_ood.sum().item())
        # A fully rejected batch is always withdrawn: an empty selection would
        # otherwise produce a NaN loss.
        withdraw = ood_count >= threshold or ood_count == batch_size

        return target, sample_to_keep, withdraw
