from copy import deepcopy
from typing import Iterable, NoReturn, Optional, Tuple, Union, Dict, Hashable, Any, List
from collections import defaultdict
import logging

import cherry as ch
import torch
import torch.nn as nn
from torch.nn import Parameter  # type: ignore
from algos import IAction, IAgent
from models import create_target_network, ModelHydraPredictiveCodingCritic

# Action-distribution interface specialised to torch tensors.
IMujocoAction = IAction[torch.Tensor]
# Agent interface with all three type parameters bound to torch tensors.
# NOTE(review): the role of each type parameter is defined by `algos.IAgent`.
IMujocoAgent = IAgent[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]


class MujocoMass(IMujocoAction):  # type: ignore
    """Tanh-squashed diagonal Gaussian over actions.

    A pre-squash sample ``x_t`` is drawn from ``Normal(mean, exp(log_std))``
    and squashed with ``tanh``, so every action component lies in ``(-1, 1)``.

    Args:
        mean (torch.Tensor): mean of the pre-squash Gaussian.
        log_std (torch.Tensor): log standard deviation of the pre-squash
            Gaussian (exponentiated to obtain ``std``).
    """

    def __init__(self, mean: torch.Tensor, log_std: torch.Tensor):
        std = log_std.exp()
        self.mean = mean
        self.std = std
        self.normal = torch.distributions.Normal(mean, std)

    def _get_action(self) -> torch.Tensor:
        """Draw a reparameterised sample and cache it before and after ``tanh``.

        The cached pre-squash value ``self.x_t`` is reused by :meth:`log_prob`.
        """
        self.x_t = self.normal.rsample()
        self.y_t = torch.tanh(self.x_t)
        return self.y_t

    def sample(self) -> torch.Tensor:
        """Return a squashed action detached from the autograd graph."""
        action = self._get_action()
        return action.detach()

    def rsample(self) -> torch.Tensor:
        """Return a squashed action that keeps gradients w.r.t. mean/log_std."""
        return self._get_action()

    def log_prob(self, action: torch.Tensor) -> torch.Tensor:  # Size [Batch, 1]
        """Log-density of a squashed action, summed over dim 1 (kept).

        Applies the tanh change-of-variables correction
        ``log p(a) = log N(x) - log(1 - a^2 + 1e-6)``.

        If an action has already been drawn from this object, the cached
        pre-squash sample ``self.x_t`` is used, and ``action`` is expected to be
        that sample's squashed value. If nothing has been sampled yet, the
        pre-squash value is recovered as ``atanh(action)`` instead of raising
        ``AttributeError``.

        Args:
            action (torch.Tensor): squashed action in ``(-1, 1)``. Summation is
                over dim 1, so presumably shaped ``(batch, action_dim)`` -- TODO
                confirm.

        Returns:
            torch.Tensor: log-probabilities with dim 1 reduced to size 1.
        """
        x_t = getattr(self, "x_t", None)
        if x_t is None:
            # atanh(a) = 0.5 * (log1p(a) - log1p(-a)); clamp so |a| < 1 stays finite.
            bounded = action.clamp(-1.0 + 1e-6, 1.0 - 1e-6)
            x_t = 0.5 * (torch.log1p(bounded) - torch.log1p(-bounded))
        log_prob = self.normal.log_prob(x_t)
        # Enforcing action bound: tanh Jacobian, with 1e-6 to avoid log(0).
        log_prob = log_prob - torch.log((1 - action.pow(2)) + 1e-6)
        return log_prob.sum(1, keepdim=True)

    def best_action(self) -> torch.Tensor:
        """Deterministic action: the squashed mean."""
        return torch.tanh(self.mean)

    def infos(self) -> Dict[Hashable, Any]:
        """Delegate to the base interface's ``infos``."""
        return super().infos()


class MujocoSacAgent(IMujocoAgent):  # type: ignore
    """Soft Actor-Critic agent with a stochastic actor and twin Q critics.

    Holds an actor, two online Q networks and a target copy of each Q network
    (built with ``create_target_network``). All five modules are moved to
    ``device``. The two target copies are frozen (``requires_grad=False``) and
    switched to eval mode at construction.

    Args:
        actor (nn.Module): \\pi_{\\theta}; called on an observation it must
            return a ``(mean, log_std)`` pair.
        q_value_network_1 (nn.Module): Q_{\\psi}, called as ``net(obs, action)``.
        q_value_network_2 (nn.Module): Q_{\\phi}, same calling convention.
        device (torch.device): device every network is moved to.
        action_size (int): dimension of action. Defaults to `1`.
    """

    def __init__(
        self,
        actor: nn.Module,
        q_value_network_1: nn.Module,
        q_value_network_2: nn.Module,
        device: torch.device,
        action_size: int = 1,
    ):
        # Target copies are built from the online critics before the device move.
        self.q_value_network_1 = q_value_network_1
        self.q_value_network_2 = q_value_network_2
        self.target_q_value_network_1 = create_target_network(
            network=q_value_network_1
        )
        self.target_q_value_network_2 = create_target_network(
            network=q_value_network_2
        )

        self.actor = actor.to(device)
        self.q_value_network_1 = q_value_network_1.to(device)
        self.q_value_network_2 = q_value_network_2.to(device)
        self.target_q_value_network_1 = self.target_q_value_network_1.to(device)
        self.target_q_value_network_2 = self.target_q_value_network_2.to(device)

        # Target networks are only updated through `soft_update`, never by SGD.
        for frozen in (self.target_q_value_network_1, self.target_q_value_network_2):
            self._eval_model(model=frozen)

        self.device = device
        self.action_size = action_size

    def action(self, observation: torch.Tensor):
        """Return the squashed-Gaussian action distribution for ``observation``."""
        mean, log_std = self.actor(observation)
        return MujocoMass(log_std=log_std, mean=mean)

    def q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate both online critics on ``(observation, action)``."""
        return (
            self.q_value_network_1(observation, action),
            self.q_value_network_2(observation, action),
        )

    def target_q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate both target critics on ``(observation, action)``."""
        return (
            self.target_q_value_network_1(observation, action),
            self.target_q_value_network_2(observation, action),
        )

    def save(self, path: str) -> NoReturn:  # type: ignore
        """Write the state dicts of all five networks to ``path``."""
        checkpoint = {
            "Actor": self.actor.state_dict(),
            "Q_values_1": self.q_value_network_1.state_dict(),
            "Q_values_2": self.q_value_network_2.state_dict(),
            "Target_Q_values_1": self.target_q_value_network_1.state_dict(),
            "Target_Q_values_2": self.target_q_value_network_2.state_dict(),
        }
        torch.save(checkpoint, path)

    def load(self, path: str) -> NoReturn:  # type: ignore
        """Restore all five networks from a checkpoint written by :meth:`save`.

        ``torch.load`` unpickles the file, so only load trusted checkpoints.
        """
        checkpoint = torch.load(path, map_location=self.device)
        for key, module in (
            ("Actor", self.actor),
            ("Q_values_1", self.q_value_network_1),
            ("Q_values_2", self.q_value_network_2),
            ("Target_Q_values_1", self.target_q_value_network_1),
            ("Target_Q_values_2", self.target_q_value_network_2),
        ):
            module.load_state_dict(checkpoint[key])

    def train(self) -> NoReturn:  # type: ignore
        """Unfreeze the actor and both online critics and set them to train mode."""
        for module in (self.actor, self.q_value_network_1, self.q_value_network_2):
            self._train_model(module)

    def _train_model(self, model: nn.Module):
        """Enable gradients on every parameter of ``model`` and call ``train()``."""
        model.requires_grad_(True)
        model.train()

    def eval(self) -> NoReturn:  # type: ignore
        """Freeze the actor and both online critics and set them to eval mode."""
        for module in (self.actor, self.q_value_network_1, self.q_value_network_2):
            self._eval_model(module)

    def _eval_model(self, model: nn.Module):
        """Disable gradients on every parameter of ``model`` and call ``eval()``."""
        model.requires_grad_(False)
        model.eval()

    def clone(self) -> NoReturn:
        """Not implemented: returns ``None``."""
        ...

    def parameters(self) -> Union[Parameter, Iterable[Parameter]]:
        """List the parameters of all five networks, target networks included."""
        return [
            weight
            for module in (
                self.actor,
                self.q_value_network_1,
                self.q_value_network_2,
                self.target_q_value_network_1,
                self.target_q_value_network_2,
            )
            for weight in module.parameters()
        ]

    def parameters_actor(self) -> Union[Parameter, Iterable[Parameter]]:
        """Return the actor's parameter iterator."""
        return self.actor.parameters()

    def parameters_critic(self) -> Union[Parameter, Iterable[Parameter]]:
        """List the parameters of both online critics (targets excluded)."""
        return [
            weight
            for module in (self.q_value_network_1, self.q_value_network_2)
            for weight in module.parameters()
        ]

    def soft_update(self, tau: float):
        """Polyak-average each target critic towards its online critic.

        Calls ``cherry.models.polyak_average`` with the target network as
        ``source`` and the online network as ``target``. Per cherry's
        documentation this computes ``target_net = tau * target_net +
        (1 - tau) * online_net``, i.e. ``tau`` is the weight kept on the OLD
        target parameters.
        NOTE(review): confirm callers pass tau with that convention (e.g. 0.995,
        not 0.005).
        """
        for target_net, online_net in (
            (self.target_q_value_network_1, self.q_value_network_1),
            (self.target_q_value_network_2, self.q_value_network_2),
        ):
            ch.models.polyak_average(
                source=target_net,
                target=online_net,
                alpha=tau,
            )


class MujocoSacMPC(MujocoSacAgent):
    """SAC agent whose critics return a ``(q_value, next_state)`` pair.

    The second output of each online critic is cached as ``next_state_1`` /
    ``next_state_2`` on every :meth:`q_value` call and exposed by :meth:`infos`.
    """

    def q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate both online critics and cache their next-state outputs."""
        q_value_1, self.next_state_1 = self.q_value_network_1(observation, action)
        q_value_2, self.next_state_2 = self.q_value_network_2(observation, action)
        return q_value_1, q_value_2

    def target_q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate both target critics, discarding their next-state outputs."""
        (target_q_value_1, _) = self.target_q_value_network_1(observation, action)
        (target_q_value_2, _) = self.target_q_value_network_2(observation, action)
        return target_q_value_1, target_q_value_2

    def infos(self) -> Dict[Hashable, Any]:
        """Return the cached next-state predictions.

        Entries are present only once :meth:`q_value` has been called. Before
        that an empty dict is returned instead of raising ``AttributeError``,
        matching the guarded ``infos`` of the Hydra agents.
        """
        infos: Dict[Hashable, Any] = {}
        if hasattr(self, "next_state_1"):
            infos["predicted_next_state_1"] = self.next_state_1
        if hasattr(self, "next_state_2"):
            infos["predicted_next_state_2"] = self.next_state_2
        return infos


class MujocoSacMPCHydra(MujocoSacAgent):
    """SAC agent whose critics return ``(q_value, next_state, q_value_no_reg)``.

    The extra critic outputs from the latest :meth:`q_value` /
    :meth:`target_q_value` calls are cached on the agent and reported by
    :meth:`infos`.
    """

    def q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate both online critics, caching next-state and no-reg outputs."""
        first = self.q_value_network_1(observation, action)
        q_first, self.next_state_1, self.q_value_no_reg_1 = first
        second = self.q_value_network_2(observation, action)
        q_second, self.next_state_2, self.q_value_no_reg_2 = second
        return q_first, q_second

    def target_q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate both target critics, caching only their no-reg outputs."""
        first = self.target_q_value_network_1(observation, action)
        q_first, _, self.target_q_value_no_reg_1 = first
        second = self.target_q_value_network_2(observation, action)
        q_second, _, self.target_q_value_no_reg_2 = second
        return q_first, q_second

    def infos(self) -> Dict[Hashable, Any]:
        """Return every cached critic output that has been computed so far."""
        # (reported key, cached attribute name), in reporting order.
        reported = (
            ("predicted_next_state_1", "next_state_1"),
            ("predicted_next_state_2", "next_state_2"),
            ("q_no_reg_value_1", "q_value_no_reg_1"),
            ("q_no_reg_value_2", "q_value_no_reg_2"),
            ("target_q_no_reg_value_1", "target_q_value_no_reg_1"),
            ("target_q_no_reg_value_2", "target_q_value_no_reg_2"),
        )
        return {  # type: ignore
            key: getattr(self, attribute)
            for key, attribute in reported
            if hasattr(self, attribute)
        }


class MujocoSacMPCVarianceEnsembleHydra(MujocoSacAgent):
    """SAC agent whose critics return a 5-tuple.

    Each critic call yields ``(q_value, next_state, variance_next_state,
    ensemble_infos, q_value_no_reg)``. The extra outputs from the latest
    :meth:`q_value` / :meth:`target_q_value` calls are cached on the agent and
    reported by :meth:`infos`.
    """

    def q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate both online critics and cache all of their auxiliary outputs."""
        first = self.q_value_network_1(observation, action)
        (
            q_first,
            self.next_state_1,
            self.variance_next_state_1,
            self.infos_ensemble_variance_network_1,
            self.q_value_no_reg_1,
        ) = first

        second = self.q_value_network_2(observation, action)
        (
            q_second,
            self.next_state_2,
            self.variance_next_state_2,
            self.infos_ensemble_variance_network_2,
            self.q_value_no_reg_2,
        ) = second
        return q_first, q_second

    def target_q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate both target critics, caching only their no-reg outputs."""
        first = self.target_q_value_network_1(observation, action)
        q_first, _, _, _, self.target_q_value_no_reg_1 = first
        second = self.target_q_value_network_2(observation, action)
        q_second, _, _, _, self.target_q_value_no_reg_2 = second
        return q_first, q_second

    def infos(self) -> Dict[Hashable, Any]:
        """Return every cached critic output that has been computed so far."""
        # (reported key, cached attribute name), in reporting order.
        reported = (
            ("predicted_next_state_1", "next_state_1"),
            ("predicted_next_state_2", "next_state_2"),
            ("q_no_reg_value_1", "q_value_no_reg_1"),
            ("q_no_reg_value_2", "q_value_no_reg_2"),
            ("target_q_no_reg_value_1", "target_q_value_no_reg_1"),
            ("target_q_no_reg_value_2", "target_q_value_no_reg_2"),
            ("variance_next_state_1", "variance_next_state_1"),
            ("variance_next_state_2", "variance_next_state_2"),
            (
                "infos_ensemble_variance_network_1",
                "infos_ensemble_variance_network_1",
            ),
            (
                "infos_ensemble_variance_network_2",
                "infos_ensemble_variance_network_2",
            ),
        )
        return {  # type: ignore
            key: getattr(self, attribute)
            for key, attribute in reported
            if hasattr(self, attribute)
        }


class RaviAgent(MujocoSacAgent):
    """SAC agent that additionally carries two frozen "robust" Q networks.

    The actor, twin critics and their targets are set up by
    :class:`MujocoSacAgent`. ``q_robust_1`` / ``q_robust_2`` are moved to
    ``device`` and frozen with ``_eval_model``. They are evaluated in
    :meth:`target_q_value`, but the inherited ``train()`` never unfreezes them.

    Args:
        actor (nn.Module): policy network returning ``(mean, log_std)``.
        q_value_network_1 (nn.Module): first online critic.
        q_value_network_2 (nn.Module): second online critic.
        q_robust_1 (nn.Module): first frozen robust critic.
        q_robust_2 (nn.Module): second frozen robust critic.
        device (torch.device): device every network is moved to.
        action_size (int): dimension of the action. Defaults to ``1``.
    """

    def __init__(
        self,
        actor: nn.Module,
        q_value_network_1: nn.Module,
        q_value_network_2: nn.Module,
        q_robust_1: nn.Module,
        q_robust_2: nn.Module,
        device: torch.device,
        action_size: int = 1,
    ):
        # Actor, critics, targets, device and action_size: same as the base agent.
        super().__init__(
            actor, q_value_network_1, q_value_network_2, device, action_size
        )
        self.q_robust_1 = q_robust_1.to(device)
        self.q_robust_2 = q_robust_2.to(device)
        for robust in (self.q_robust_1, self.q_robust_2):
            self._eval_model(model=robust)

    def set_robust_nets(
        self, q_robust_1: nn.Module, q_robust_2: nn.Module, *args, **kwargs
    ):
        """Replace both robust critics, moving them to the device and freezing them.

        Extra positional/keyword arguments are accepted and ignored.
        """
        self.q_robust_1 = q_robust_1.to(self.device)
        self.q_robust_2 = q_robust_2.to(self.device)
        for robust in (self.q_robust_1, self.q_robust_2):
            self._eval_model(model=robust)

    def target_q_value(  # type: ignore
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(target_1, target_2, robust_1, robust_2)`` Q values."""
        return (
            self.target_q_value_network_1(observation, action),
            self.target_q_value_network_2(observation, action),
            self.q_robust_1(observation, action),
            self.q_robust_2(observation, action),
        )

    def save(self, path: str) -> NoReturn:  # type: ignore
        """Write the state dicts of all seven networks to ``path``."""
        checkpoint = {
            "Actor": self.actor.state_dict(),
            "Q_values_1": self.q_value_network_1.state_dict(),
            "Q_values_2": self.q_value_network_2.state_dict(),
            "Target_Q_values_1": self.target_q_value_network_1.state_dict(),
            "Target_Q_values_2": self.target_q_value_network_2.state_dict(),
            "Robust_Q_1": self.q_robust_1.state_dict(),
            "Robust_Q_2": self.q_robust_2.state_dict(),
        }
        torch.save(checkpoint, path)

    def load(self, path: str) -> NoReturn:  # type: ignore
        """Restore all seven networks from a checkpoint written by :meth:`save`.

        ``torch.load`` unpickles the file, so only load trusted checkpoints.
        """
        checkpoint = torch.load(path, map_location=self.device)
        for key, module in (
            ("Actor", self.actor),
            ("Q_values_1", self.q_value_network_1),
            ("Q_values_2", self.q_value_network_2),
            ("Target_Q_values_1", self.target_q_value_network_1),
            ("Target_Q_values_2", self.target_q_value_network_2),
            ("Robust_Q_1", self.q_robust_1),
            ("Robust_Q_2", self.q_robust_2),
        ):
            module.load_state_dict(checkpoint[key])


class LifeLongRaviAgent(RaviAgent):
    """RaviAgent that accumulates a growing pool of frozen robust Q networks.

    Each :meth:`set_robust_nets` call appends two more frozen networks to
    ``robust_q_nets``. :meth:`all_target_q_value` evaluates every network in
    the pool alongside the target critics.
    """

    def __init__(
        self,
        actor: nn.Module,
        q_value_network_1: nn.Module,
        q_value_network_2: nn.Module,
        q_robust_1: nn.Module,
        q_robust_2: nn.Module,
        device: torch.device,
        action_size: int = 1,
    ):
        super().__init__(
            actor,
            q_value_network_1,
            q_value_network_2,
            q_robust_1,
            q_robust_2,
            device,
            action_size,
        )
        self.robust_q_nets: List[nn.Module] = []

    def set_robust_nets(self, q_robust_1: nn.Module, q_robust_2: nn.Module):  # type: ignore
        """Move both networks to the device, freeze them and append them to the pool."""
        q_robust_1 = q_robust_1.to(device=self.device)
        q_robust_2 = q_robust_2.to(device=self.device)

        self._eval_model(q_robust_1)
        self._eval_model(q_robust_2)
        self.robust_q_nets.append(q_robust_1)
        self.robust_q_nets.append(q_robust_2)

    def target_q_value(  # type: ignore
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, ...]:
        """Return the two target-critic Q values (robust pool not included)."""
        target_q_value_1 = self.target_q_value_network_1(observation, action)
        target_q_value_2 = self.target_q_value_network_2(observation, action)
        return target_q_value_1, target_q_value_2

    def all_target_q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Dict[str, Tuple[torch.Tensor, ...]]:
        """Return target Q values under ``"target"`` and pool Q values under ``"robust"``.

        The ``"robust"`` tuple follows the order of ``robust_q_nets``.
        """
        target_q_value_1 = self.target_q_value_network_1(observation, action)
        target_q_value_2 = self.target_q_value_network_2(observation, action)
        robust_q_values = [
            q_robust(observation, action) for q_robust in self.robust_q_nets
        ]
        return {
            "target": (target_q_value_1, target_q_value_2),
            "robust": tuple(robust_q_values),
        }

    def save(self, path: str) -> NoReturn:  # type: ignore
        """Save all networks; pool entries are stored as ``robust_net_<i>``."""
        robust_dict = {
            f"robust_net_{i}": net.state_dict()
            for i, net in enumerate(self.robust_q_nets)
        }
        classic_nets = {
            "Actor": self.actor.state_dict(),
            "Q_values_1": self.q_value_network_1.state_dict(),
            "Q_values_2": self.q_value_network_2.state_dict(),
            "Target_Q_values_1": self.target_q_value_network_1.state_dict(),
            "Target_Q_values_2": self.target_q_value_network_2.state_dict(),
            "Robust_Q_1": self.q_robust_1.state_dict(),
            "Robust_Q_2": self.q_robust_2.state_dict(),
        }

        classic_nets.update(robust_dict)
        torch.save(
            classic_nets,
            path,
        )

    def load(self, path: str) -> NoReturn:  # type: ignore
        """Restore every network, rebuilding the robust pool from the checkpoint.

        The pool is replaced rather than appended to, so loading twice does not
        duplicate its entries. Pool entries are the ``robust_net_<i>`` keys
        written by :meth:`save`, restored in checkpoint order. ``torch.load``
        unpickles the file, so only load trusted checkpoints.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.actor.load_state_dict(checkpoint["Actor"])
        self.q_value_network_1.load_state_dict(checkpoint["Q_values_1"])
        self.q_value_network_2.load_state_dict(checkpoint["Q_values_2"])
        self.target_q_value_network_1.load_state_dict(checkpoint["Target_Q_values_1"])
        self.target_q_value_network_2.load_state_dict(checkpoint["Target_Q_values_2"])
        self.q_robust_1.load_state_dict(checkpoint["Robust_Q_1"])
        self.q_robust_2.load_state_dict(checkpoint["Robust_Q_2"])

        robust_q_nets: List[nn.Module] = []
        for key, state_dict in checkpoint.items():
            if not key.startswith("robust_net_"):
                continue
            # Pool nets are assumed to share q_robust_1's architecture.
            q_robust_copy = deepcopy(self.q_robust_1)
            q_robust_copy.load_state_dict(state_dict)
            # Freeze like set_robust_nets does.
            self._eval_model(q_robust_copy)
            robust_q_nets.append(q_robust_copy)
        self.robust_q_nets = robust_q_nets


class LifeLongRaviMPCAgent(RaviAgent):
    """Life-long Ravi agent whose critics return ``(q_value, next_state)``.

    Keeps a growing pool of frozen robust Q networks (``robust_q_nets``). It
    also caches the next-state outputs of the online critics and of the pool,
    which are reported by :meth:`infos`.
    """

    def __init__(
        self,
        actor: nn.Module,
        q_value_network_1: nn.Module,
        q_value_network_2: nn.Module,
        q_robust_1: nn.Module,
        q_robust_2: nn.Module,
        device: torch.device,
        action_size: int = 1,
    ):
        super().__init__(
            actor,
            q_value_network_1,
            q_value_network_2,
            q_robust_1,
            q_robust_2,
            device,
            action_size,
        )
        self.robust_q_nets: List[nn.Module] = []
        self.next_state_values: Optional[List[torch.Tensor]] = None

    def q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate both online critics and cache their next-state outputs."""
        q_value_1, self.next_state_1 = self.q_value_network_1(observation, action)
        q_value_2, self.next_state_2 = self.q_value_network_2(observation, action)
        return q_value_1, q_value_2

    def set_robust_nets(self, q_robust_1: nn.Module, q_robust_2: nn.Module):  # type: ignore
        """Move both networks to the device, freeze them and append them to the pool."""
        q_robust_1 = q_robust_1.to(device=self.device)
        q_robust_2 = q_robust_2.to(device=self.device)

        self._eval_model(q_robust_1)
        self._eval_model(q_robust_2)
        self.robust_q_nets.append(q_robust_1)
        self.robust_q_nets.append(q_robust_2)

    def target_q_value(  # type: ignore
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, ...]:
        """Return the two target-critic Q values, discarding next-state outputs."""
        target_q_value_1, _ = self.target_q_value_network_1(observation, action)
        target_q_value_2, _ = self.target_q_value_network_2(observation, action)
        return target_q_value_1, target_q_value_2

    def all_target_q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Dict[str, Tuple[torch.Tensor, ...]]:
        """Return target and pool Q values; cache the pool's next-state outputs.

        The ``"robust"`` tuple and ``self.next_state_values`` follow the order
        of ``robust_q_nets``.
        """
        target_q_value_1, _ = self.target_q_value_network_1(observation, action)
        target_q_value_2, _ = self.target_q_value_network_2(observation, action)
        robust_q_values = []
        next_state_values = []
        for q_robust in self.robust_q_nets:
            q_value, next_state = q_robust(observation, action)
            robust_q_values.append(q_value)
            next_state_values.append(next_state)
        self.next_state_values = next_state_values
        return {
            "target": (target_q_value_1, target_q_value_2),
            "robust": tuple(robust_q_values),
        }

    def infos(self) -> Dict[Hashable, Any]:
        """Return whichever cached next-state predictions exist so far."""
        infos: Dict[Hashable, Any] = {}
        if hasattr(self, "next_state_1"):
            infos["predicted_next_state_1"] = self.next_state_1
        if hasattr(self, "next_state_2"):
            infos["predicted_next_state_2"] = self.next_state_2
        if self.next_state_values is not None:
            infos["predicted_next_state_robust"] = self.next_state_values
        return infos

    def save(self, path: str) -> NoReturn:  # type: ignore
        """Save all networks; pool entries are stored as ``robust_net_<i>``."""
        robust_dict = {
            f"robust_net_{i}": net.state_dict()
            for i, net in enumerate(self.robust_q_nets)
        }
        classic_nets = {
            "Actor": self.actor.state_dict(),
            "Q_values_1": self.q_value_network_1.state_dict(),
            "Q_values_2": self.q_value_network_2.state_dict(),
            "Target_Q_values_1": self.target_q_value_network_1.state_dict(),
            "Target_Q_values_2": self.target_q_value_network_2.state_dict(),
            "Robust_Q_1": self.q_robust_1.state_dict(),
            "Robust_Q_2": self.q_robust_2.state_dict(),
        }

        classic_nets.update(robust_dict)
        torch.save(
            classic_nets,
            path,
        )

    def load(self, path: str) -> NoReturn:  # type: ignore
        """Restore every network, rebuilding the robust pool from the checkpoint.

        The pool is replaced rather than appended to, so loading twice does not
        duplicate its entries. Pool entries are the ``robust_net_<i>`` keys
        written by :meth:`save`, restored in checkpoint order. ``torch.load``
        unpickles the file, so only load trusted checkpoints.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.actor.load_state_dict(checkpoint["Actor"])
        self.q_value_network_1.load_state_dict(checkpoint["Q_values_1"])
        self.q_value_network_2.load_state_dict(checkpoint["Q_values_2"])
        self.target_q_value_network_1.load_state_dict(checkpoint["Target_Q_values_1"])
        self.target_q_value_network_2.load_state_dict(checkpoint["Target_Q_values_2"])
        self.q_robust_1.load_state_dict(checkpoint["Robust_Q_1"])
        self.q_robust_2.load_state_dict(checkpoint["Robust_Q_2"])

        robust_q_nets: List[nn.Module] = []
        for key, state_dict in checkpoint.items():
            if not key.startswith("robust_net_"):
                continue
            # Pool nets are assumed to share q_robust_1's architecture.
            q_robust_copy = deepcopy(self.q_robust_1)
            q_robust_copy.load_state_dict(state_dict)
            # Freeze like set_robust_nets does.
            self._eval_model(q_robust_copy)
            robust_q_nets.append(q_robust_copy)
        self.robust_q_nets = robust_q_nets


class LifeLongRaviMPCHydraAgent(RaviAgent):
    """Life-long Ravi agent whose critics return a 3-tuple.

    Every critic (online, target and pool) returns
    ``(q_value, next_state, q_value_no_reg)``. The agent keeps a growing pool of
    frozen robust Q networks (``robust_q_nets``) and caches auxiliary critic
    outputs, which are reported by :meth:`infos`.
    """

    def __init__(
        self,
        actor: nn.Module,
        q_value_network_1: nn.Module,
        q_value_network_2: nn.Module,
        q_robust_1: nn.Module,
        q_robust_2: nn.Module,
        device: torch.device,
        action_size: int = 1,
    ):
        super().__init__(
            actor,
            q_value_network_1,
            q_value_network_2,
            q_robust_1,
            q_robust_2,
            device,
            action_size,
        )
        self.robust_q_nets: List[nn.Module] = []
        self.next_state_values: Optional[List[torch.Tensor]] = None

    def q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate both online critics, caching next-state and no-reg outputs."""
        q_value_1, self.next_state_1, self.q_no_reg_value_1 = self.q_value_network_1(
            observation, action
        )
        q_value_2, self.next_state_2, self.q_no_reg_value_2 = self.q_value_network_2(
            observation, action
        )
        return q_value_1, q_value_2

    def set_robust_nets(self, q_robust_1: nn.Module, q_robust_2: nn.Module):
        """Move both networks to the device, freeze them and append them to the pool."""
        q_robust_1 = q_robust_1.to(device=self.device)
        q_robust_2 = q_robust_2.to(device=self.device)

        self._eval_model(q_robust_1)
        self._eval_model(q_robust_2)
        self.robust_q_nets.append(q_robust_1)
        self.robust_q_nets.append(q_robust_2)

    def target_q_value(  # type: ignore
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, ...]:
        """Return the two target-critic Q values, discarding auxiliary outputs."""
        target_q_value_1, _, _ = self.target_q_value_network_1(observation, action)
        target_q_value_2, _, _ = self.target_q_value_network_2(observation, action)
        return target_q_value_1, target_q_value_2

    def all_target_q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Dict[str, Tuple[torch.Tensor, ...]]:
        """Return target, pool and pool no-reg Q values.

        The result has keys ``"target"``, ``"robust"`` and ``"robust_no_reg"``.
        The pool's next-state outputs are cached in ``self.next_state_values``.
        Pool tuples follow the order of ``robust_q_nets``.
        """
        # Target critics return 3 values, exactly as in target_q_value; unpacking
        # only 2 here previously raised ValueError.
        target_q_value_1, _, _ = self.target_q_value_network_1(observation, action)
        target_q_value_2, _, _ = self.target_q_value_network_2(observation, action)
        robust_q_values = []
        robust_q_no_reg_values = []
        next_state_values = []
        for q_robust in self.robust_q_nets:
            q_value, next_state, q_no_reg_value = q_robust(observation, action)
            robust_q_values.append(q_value)
            next_state_values.append(next_state)
            robust_q_no_reg_values.append(q_no_reg_value)
        self.next_state_values = next_state_values
        return {
            "target": (target_q_value_1, target_q_value_2),
            "robust": tuple(robust_q_values),
            "robust_no_reg": tuple(robust_q_no_reg_values),
        }

    def infos(self) -> Dict[Hashable, Any]:
        """Return whichever cached critic outputs exist so far.

        The no-reg keys use the same names as ``MujocoSacMPCHydra.infos``.
        """
        infos: Dict[Hashable, Any] = {}
        if hasattr(self, "next_state_1"):
            infos["predicted_next_state_1"] = self.next_state_1
        if hasattr(self, "next_state_2"):
            infos["predicted_next_state_2"] = self.next_state_2
        if hasattr(self, "q_no_reg_value_1"):
            infos["q_no_reg_value_1"] = self.q_no_reg_value_1
        if hasattr(self, "q_no_reg_value_2"):
            infos["q_no_reg_value_2"] = self.q_no_reg_value_2
        if self.next_state_values is not None:
            infos["predicted_next_state_robust"] = self.next_state_values
        return infos

    def save(self, path: str) -> NoReturn:  # type: ignore
        """Save all networks; pool entries are stored as ``robust_net_<i>``."""
        robust_dict = {
            f"robust_net_{i}": net.state_dict()
            for i, net in enumerate(self.robust_q_nets)
        }
        classic_nets = {
            "Actor": self.actor.state_dict(),
            "Q_values_1": self.q_value_network_1.state_dict(),
            "Q_values_2": self.q_value_network_2.state_dict(),
            "Target_Q_values_1": self.target_q_value_network_1.state_dict(),
            "Target_Q_values_2": self.target_q_value_network_2.state_dict(),
            "Robust_Q_1": self.q_robust_1.state_dict(),
            "Robust_Q_2": self.q_robust_2.state_dict(),
        }

        classic_nets.update(robust_dict)
        torch.save(
            classic_nets,
            path,
        )

    def load(self, path: str) -> NoReturn:  # type: ignore
        """Restore every network, rebuilding the robust pool from the checkpoint.

        The pool is replaced rather than appended to, so loading twice does not
        duplicate its entries. Pool entries are the ``robust_net_<i>`` keys
        written by :meth:`save`, restored in checkpoint order. ``torch.load``
        unpickles the file, so only load trusted checkpoints.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.actor.load_state_dict(checkpoint["Actor"])
        self.q_value_network_1.load_state_dict(checkpoint["Q_values_1"])
        self.q_value_network_2.load_state_dict(checkpoint["Q_values_2"])
        self.target_q_value_network_1.load_state_dict(checkpoint["Target_Q_values_1"])
        self.target_q_value_network_2.load_state_dict(checkpoint["Target_Q_values_2"])
        self.q_robust_1.load_state_dict(checkpoint["Robust_Q_1"])
        self.q_robust_2.load_state_dict(checkpoint["Robust_Q_2"])

        robust_q_nets: List[nn.Module] = []
        for key, state_dict in checkpoint.items():
            if not key.startswith("robust_net_"):
                continue
            # Pool nets are assumed to share q_robust_1's architecture.
            q_robust_copy = deepcopy(self.q_robust_1)
            q_robust_copy.load_state_dict(state_dict)
            # Freeze like set_robust_nets does.
            self._eval_model(q_robust_copy)
            robust_q_nets.append(q_robust_copy)
        self.robust_q_nets = robust_q_nets


class PessimistExpertAgent(RaviAgent):
    """Ensemble agent acting with the most pessimistic trustworthy expert.

    Expert ``i`` is made of the actor ``robust_actor_nets[i]`` and the two
    critics ``robust_q_nets[2 * i]`` / ``robust_q_nets[2 * i + 1]`` (see
    :meth:`set_robust_nets`). Each critic maps ``(observation, action)`` to a
    Q-value and a prediction of the next observation. At each call of
    :meth:`action`:

    1. the predictions made at the previous call are compared with the current
       observation (mean absolute error per expert);
    2. experts whose error is ``>= threshold_pc`` are masked;
    3. the unmasked expert with the lowest Q-value is selected or, when every
       expert is masked, the expert with the smallest prediction error.

    Args:
        actor (nn.Module): actor template, deep-copied by :meth:`load`.
        critic (nn.Module): critic template, deep-copied by :meth:`load`.
        device (torch.device): device every network is moved to.
        action_size (int): dimension of action. Defaults to ``1``.
        threshold_pc (float): prediction-error threshold at or above which an
            expert is masked. Defaults to ``0.5``.
        monitor (bool): record masking statistics (see :meth:`infos`).
            Defaults to ``False``.
    """

    def __init__(
        self,
        actor: nn.Module,
        critic: nn.Module,
        device: torch.device,
        action_size: int = 1,
        threshold_pc: float = 0.5,
        monitor: bool = False,
    ):
        # NOTE(review): RaviAgent.__init__ is not called; only the attributes
        # used by this class are initialised here.
        self.threshold_pc = threshold_pc
        self.robust_q_nets: List[nn.Module] = []
        self.robust_actor_nets: List[nn.Module] = []
        self.next_state_values: Optional[List[torch.Tensor]] = None
        # Next-state predictions made at the previous `action` call, or None
        # before the first call / after `reset_memory`.
        self.observation_prediction: Optional[torch.Tensor] = None
        self.action_size = action_size
        # Templates deep-copied when loading experts from a checkpoint.
        self.actor = actor.to(device)
        self.critic = critic.to(device)

        self.device = device
        self.monitor = monitor
        # Number of masked experts -> number of observations in that situation.
        self.info_action: Dict[int, int] = defaultdict(int)
        # Expert index -> list of its prediction errors.
        self.info_pc_error: Dict[int, List[float]] = defaultdict(list)

    def action(self, observation: torch.Tensor) -> MujocoMass:
        """Return the action distribution of the selected expert.

        The most pessimistic expert (lowest Q-value) is selected among the
        experts whose previously predicted next state is close enough to
        ``observation`` (error below ``threshold_pc``). When no expert is
        trustworthy, the expert with the smallest prediction error is used.
        The selection is made independently for every row of the batch.

        Args:
            observation: batch of observations, (B, F).

        Returns:
            MujocoMass built from the selected expert's mean and log-std.
        """
        tensor_action_mean, tensor_action_log_std, masses = self._experts_policy(
            observation
        )
        tensor_q_value_min, tensor_pc_value = self._experts_evaluation(
            observation, masses
        )
        _, selected_expert = self._compute_pessimist_q_value(
            observation=observation,
            tensor_q_value_min=tensor_q_value_min,
            observation_prediction=self.observation_prediction,
            threshold_pc=self.threshold_pc,
            tensor_pc_value_mean=tensor_pc_value,
        )
        # Keep this step's predictions to score the experts at the next step.
        self.observation_prediction = tensor_pc_value

        # Row b takes the parameters of expert selected_expert[b]
        # (index_select would apply every row's index to the whole batch).
        rows = torch.arange(selected_expert.size(0), device=selected_expert.device)
        return MujocoMass(
            mean=tensor_action_mean[rows, selected_expert],
            log_std=tensor_action_log_std[rows, selected_expert],
        )

    def _experts_policy(
        self, observation: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[MujocoMass]]:
        """Run every expert actor on ``observation``.

        Returns:
            Means and log-stds stacked on dim 1 ((B, E, A), assuming each actor
            returns (B, A) tensors), and one MujocoMass per expert.
        """
        means: List[torch.Tensor] = []
        log_stds: List[torch.Tensor] = []
        masses: List[MujocoMass] = []
        for actor in self.robust_actor_nets:
            mean, log_std = actor(observation)
            means.append(mean)
            log_stds.append(log_std)
            masses.append(MujocoMass(mean=mean, log_std=log_std))
        return torch.stack(means, dim=1), torch.stack(log_stds, dim=1), masses

    def _experts_evaluation(
        self, observation: torch.Tensor, masses: List[MujocoMass]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score each expert's greedy action with its pair of critics.

        Returns:
            The pessimistic Q-values concatenated on the last dim, (B, E), and
            the matching next-state predictions stacked on the last dim,
            (B, F, E). Assumes critics return (B, 1) Q-values and (B, F)
            predictions.
        """
        q_values: List[torch.Tensor] = []
        predictions: List[torch.Tensor] = []
        for idx_robust, mass in zip(range(0, len(self.robust_q_nets), 2), masses):
            greedy_action = mass.best_action()
            q_value_1, pc_value_1 = self.robust_q_nets[idx_robust](
                observation, greedy_action
            )
            q_value_2, pc_value_2 = self.robust_q_nets[idx_robust + 1](
                observation, greedy_action
            )
            q_value_min, pc_value = self._pessimist_critic_pair(
                q_value_1, q_value_2, pc_value_1, pc_value_2
            )
            q_values.append(q_value_min)
            predictions.append(pc_value)
        return torch.cat(q_values, dim=-1), torch.stack(predictions, dim=-1)

    @staticmethod
    def _pessimist_critic_pair(
        q_value_1: torch.Tensor,
        q_value_2: torch.Tensor,
        pc_value_1: torch.Tensor,
        pc_value_2: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the lower Q-value of two critics and that critic's prediction.

        Args:
            q_value_1, q_value_2: Q-values, (B, 1).
            pc_value_1, pc_value_2: next-state predictions, (B, F).

        Returns:
            The minimum Q-value, (B, 1), and the full prediction of the critic
            that produced it, (B, F).
        """
        q_value_min, index_min = torch.min(
            torch.stack([q_value_1, q_value_2], dim=1), dim=1
        )  # (B, 1), (B, 1)
        pc_values = torch.stack([pc_value_1, pc_value_2], dim=1)  # B, 2, F
        # Expand the critic index over the feature axis so that the whole
        # prediction vector is gathered, not only its first feature.
        index = index_min.unsqueeze(dim=-1).expand(-1, -1, pc_values.size(-1))
        return q_value_min, pc_values.gather(dim=1, index=index).squeeze(dim=1)

    def _compute_pessimist_q_value(
        self,
        observation: torch.Tensor,
        tensor_pc_value_mean: torch.Tensor,
        tensor_q_value_min: torch.Tensor,
        observation_prediction: Optional[torch.Tensor],
        threshold_pc: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select, for every batch row, the pessimist trustworthy expert.

        Args:
            observation: current observations, (B, F).
            tensor_pc_value_mean: current next-state predictions, (B, F, E);
                only its expert dimension is used here.
            tensor_q_value_min: pessimistic Q-value per expert, (B, E).
            observation_prediction: predictions made at the previous step,
                (B, F, E), or None (every expert then has a zero error).
            threshold_pc: experts whose error is ``>= threshold_pc`` are masked.

        Returns:
            The selected Q-values, (B,), and the selected expert indices, (B,).
            Rows where every expert is masked fall back on the expert with the
            smallest prediction error.
        """
        num_experts = tensor_pc_value_mean.size(-1)
        duplicated_observation = observation.unsqueeze(dim=-1).repeat(
            (1, 1, num_experts)
        )  # B, F, E
        if observation_prediction is None:
            observation_prediction = duplicated_observation

        # Mean absolute error between last step's predictions and the current
        # observation: B, F, E -> B, E
        tensor_error = torch.abs(observation_prediction - duplicated_observation).mean(
            dim=1
        )

        # Fallback: the expert whose prediction was the most accurate.
        closest_idx_error = torch.argmin(tensor_error, dim=-1)  # B
        q_value_expert_closest = tensor_q_value_min.gather(
            dim=-1, index=closest_idx_error.unsqueeze(dim=-1)
        ).squeeze(dim=-1)  # B

        mask_fill = tensor_error >= threshold_pc  # B, E
        if self.monitor:
            self._update_info(
                nb_all_masked=mask_fill.sum(dim=-1).view(-1, 1),
                tensor_error_pc_value_mean=tensor_error,
            )

        # Masked experts can never be the minimum. No squeeze here: with a
        # single expert it would turn (B, 1) into (B,) and reduce over the batch.
        masked_q_value_min, masked_index = torch.min(
            tensor_q_value_min.masked_fill(mask_fill, float("inf")), dim=-1
        )
        all_masked = mask_fill.all(dim=-1)  # B
        return (
            torch.where(all_masked, q_value_expert_closest, masked_q_value_min),
            torch.where(all_masked, closest_idx_error, masked_index),
        )

    def reset_memory(self):
        """Forget the predictions of the previous step."""
        self.observation_prediction = None

    def _update_info(
        self, nb_all_masked: torch.Tensor, tensor_error_pc_value_mean: torch.Tensor
    ):
        """Accumulate monitoring statistics for one batch.

        Args:
            nb_all_masked: number of masked experts per batch row, (B, 1).
            tensor_error_pc_value_mean: prediction error per expert, (B, E).
        """
        per_row_masked = nb_all_masked.view(-1).tolist()
        per_row_errors = tensor_error_pc_value_mean.tolist()
        for nb_masked, row_errors in zip(per_row_masked, per_row_errors):
            self.info_action[nb_masked] += 1
            for expert_idx, error in enumerate(row_errors):
                self.info_pc_error[expert_idx].append(error)

    def q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Not supported by this agent."""
        raise NotImplementedError

    def set_robust_nets(  # type: ignore
        self, q_robust_1: nn.Module, q_robust_2: nn.Module, actor_robust: nn.Module
    ):
        """Register a new expert (two critics and one actor) in eval mode.

        The previous-step predictions are discarded because the number of
        experts changes.
        """
        q_robust_1 = q_robust_1.to(device=self.device)
        q_robust_2 = q_robust_2.to(device=self.device)
        actor_robust = actor_robust.to(device=self.device)

        self._eval_model(q_robust_1)
        self._eval_model(q_robust_2)
        self._eval_model(actor_robust)
        self.robust_q_nets.append(q_robust_1)
        self.robust_q_nets.append(q_robust_2)
        self.robust_actor_nets.append(actor_robust)
        self.reset_memory()

    def target_q_value(  # type: ignore
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, ...]:
        """Not supported by this agent."""
        raise NotImplementedError

    def all_target_q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Dict[str, Tuple[torch.Tensor, ...]]:
        """Not supported by this agent."""
        raise NotImplementedError

    def infos(self) -> Dict[Hashable, Any]:
        """Return the monitoring statistics gathered when ``monitor`` is True."""
        infos: Dict[Hashable, Any] = {
            "nb_masked": self.info_action,
            "pc_error": self.info_pc_error,
        }
        return infos

    def save(self, path: str) -> NoReturn:  # type: ignore
        """Save the expert networks as ``robust_actor_<i>`` then
        ``robust_critic_<j>`` state dicts."""
        checkpoint = {
            f"robust_actor_{i}": net.state_dict()
            for i, net in enumerate(self.robust_actor_nets)
        }
        checkpoint.update(
            (f"robust_critic_{i}", net.state_dict())
            for i, net in enumerate(self.robust_q_nets)
        )
        torch.save(checkpoint, path)

    def load(self, path: str) -> NoReturn:  # type: ignore
        """Append the experts stored by :meth:`save` to the current ones.

        Each stored state dict is loaded into a deep copy of the ``actor`` or
        ``critic`` template given to the constructor.
        """
        checkpoint = torch.load(path, map_location=self.device)
        for k, v in checkpoint.items():
            if k.startswith("robust_critic"):
                # Every robust critic shares the template's architecture.
                q_robust_copy = deepcopy(self.critic)
                q_robust_copy.load_state_dict(v)
                self.robust_q_nets.append(q_robust_copy)
                logging.info("Loaded critic %s", k)
            elif k.startswith("robust_actor"):
                actor_robust_copy = deepcopy(self.actor)
                actor_robust_copy.load_state_dict(v)
                self.robust_actor_nets.append(actor_robust_copy)
                logging.info("Loaded actor %s", k)

    def parameters(self) -> Union[Parameter, Iterable[Parameter]]:
        raise NotImplementedError

    def parameters_actor(self) -> Union[Parameter, Iterable[Parameter]]:
        raise NotImplementedError

    def parameters_critic(self) -> Union[Parameter, Iterable[Parameter]]:
        raise NotImplementedError

    def train(self) -> NoReturn:  # type: ignore
        """Put every expert network in training mode."""
        # TODO: TO fix
        for actor in self.robust_actor_nets:
            self._train_model(actor)
        for critic in self.robust_q_nets:
            self._train_model(critic)

    def eval(self) -> NoReturn:  # type: ignore
        """Put every expert network in evaluation mode."""
        for actor in self.robust_actor_nets:
            self._eval_model(actor)
        for critic in self.robust_q_nets:
            self._eval_model(critic)


class PessimistExpertAgentDefaultAdaptativeThreshold(PessimistExpertAgent):
    """Pessimist expert agent with per-expert thresholds and a default fallback.

    Differences with :class:`PessimistExpertAgent`:

    * ``threshold_pc`` is a tensor holding one threshold per expert. It is
      built with :meth:`append_last_threshold` / :meth:`set_last_threshold`
      and persisted by :meth:`save` / :meth:`load`.
    * when every expert is masked, expert ``DEFAULT_EXPERT`` is used instead
      of the expert with the smallest prediction error.

    Args:
        actor, critic, device, action_size: see :class:`PessimistExpertAgent`.
        monitor (bool): record masking statistics. Defaults to ``False``.
    """

    # Expert index used when every expert is masked.
    DEFAULT_EXPERT = 0

    def __init__(
        self,
        actor: nn.Module,
        critic: nn.Module,
        device: torch.device,
        action_size: int = 1,
        monitor: bool = False,
    ):
        super().__init__(actor, critic, device, action_size, monitor=monitor)
        # Python mirror of the per-expert thresholds held in `threshold_pc`.
        self.threshold_list: List[float] = []
        self.threshold_pc: torch.Tensor = torch.tensor([], device=self.device)  # type: ignore

    def _compute_pessimist_q_value(
        self,
        observation: torch.Tensor,
        tensor_pc_value_mean: torch.Tensor,
        tensor_q_value_min: torch.Tensor,  # B, E
        observation_prediction: Optional[torch.Tensor],
        threshold_pc: Union[float, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select, for every batch row, the pessimist trustworthy expert.

        Args:
            observation: current observations, (B, F).
            tensor_pc_value_mean: current next-state predictions, (B, F, E);
                only its expert dimension is used here.
            tensor_q_value_min: pessimistic Q-value per expert, (B, E).
            observation_prediction: predictions made at the previous step,
                (B, F, E), or None (every expert then has a zero error).
            threshold_pc: scalar or per-expert (E,) thresholds; experts whose
                error is ``>= threshold_pc`` are masked.

        Returns:
            The selected Q-values, (B,), and expert indices, (B,). Rows where
            every expert is masked use expert ``DEFAULT_EXPERT``.
        """
        num_experts = tensor_pc_value_mean.size(-1)
        duplicated_observation = observation.unsqueeze(dim=-1).repeat(
            (1, 1, num_experts)
        )  # B, F, E
        if observation_prediction is None:
            observation_prediction = duplicated_observation

        # Mean absolute prediction error: B, F, E -> B, E
        tensor_error = torch.abs(observation_prediction - duplicated_observation).mean(
            dim=1
        )

        mask_fill = tensor_error >= threshold_pc  # B, E
        if self.monitor:
            self._update_info(
                nb_all_masked=mask_fill.sum(dim=-1).view(-1, 1),
                tensor_error_pc_value_mean=tensor_error,
            )

        # No squeeze: with a single expert it would turn (B, 1) into (B,) and
        # take the minimum over the batch.
        masked_q_value_min, masked_index = torch.min(
            tensor_q_value_min.masked_fill(mask_fill, float("inf")), dim=-1
        )
        all_masked = mask_fill.all(dim=-1)  # B

        q_value_expert_default = tensor_q_value_min[:, self.DEFAULT_EXPERT]  # B
        default_idx = torch.full_like(masked_index, self.DEFAULT_EXPERT)  # B
        return (
            torch.where(all_masked, q_value_expert_default, masked_q_value_min),
            torch.where(all_masked, default_idx, masked_index),
        )

    def _refresh_threshold_tensor(self) -> None:
        """Rebuild ``threshold_pc`` from ``threshold_list`` on ``self.device``."""
        self.threshold_pc = torch.tensor(self.threshold_list, device=self.device)

    def append_last_threshold(self, last_threshold):
        """Add the threshold of a newly added expert."""
        self.threshold_list.append(last_threshold)
        self._refresh_threshold_tensor()

    def set_last_threshold(self, last_threshold):
        """Replace the threshold of the most recently added expert.

        Raises:
            IndexError: if no threshold has been appended yet.
        """
        if not self.threshold_list:
            raise IndexError(
                "set_last_threshold called before any threshold was appended"
            )
        self.threshold_list[-1] = last_threshold
        self._refresh_threshold_tensor()

    def save(self, path: str):
        """Save the experts (``robust_actor_<i>``, ``robust_critic_<j>``) and
        the ``threshold_pc`` tensor."""
        checkpoint = {
            f"robust_actor_{i}": net.state_dict()
            for i, net in enumerate(self.robust_actor_nets)
        }
        checkpoint.update(
            (f"robust_critic_{i}", net.state_dict())
            for i, net in enumerate(self.robust_q_nets)
        )
        checkpoint["threshold_pc"] = self.threshold_pc
        torch.save(checkpoint, path)

    def load(self, path: str) -> NoReturn:  # type: ignore
        """Append the stored experts and restore the stored thresholds."""
        checkpoint = torch.load(path, map_location=self.device)

        for k, v in checkpoint.items():
            if k.startswith("robust_critic"):
                # Every robust critic shares the template's architecture.
                q_robust_copy = deepcopy(self.critic)
                q_robust_copy.load_state_dict(v)
                self.robust_q_nets.append(q_robust_copy)
                logging.info("Loaded critic %s", k)
            elif k.startswith("robust_actor"):
                actor_robust_copy = deepcopy(self.actor)
                actor_robust_copy.load_state_dict(v)
                self.robust_actor_nets.append(actor_robust_copy)
                logging.info("Loaded actor %s", k)
            elif k.startswith("threshold_pc"):
                self.threshold_pc = v
                self.threshold_list = v.tolist()
                logging.info("Loaded threshold %s", k)


class PessimistExpertAgentHydra(PessimistExpertAgent):
    """Pessimist expert agent whose critics return three outputs.

    Every critic returns ``(q_value, next_state_prediction, q_value_no_reg)``.
    Experts are compared on ``q_value_no_reg``. The rest of the selection
    procedure is inherited from :class:`PessimistExpertAgent`.

    Args:
        actor, critic, device, action_size, threshold_pc: see
            :class:`PessimistExpertAgent`.
        monitor (bool): record masking statistics. Defaults to ``False``.
    """

    def __init__(
        self,
        actor: nn.Module,
        critic: nn.Module,
        device: torch.device,
        action_size: int = 1,
        threshold_pc: float = 0.5,
        monitor: bool = False,
    ):
        super().__init__(
            actor=actor,
            critic=critic,
            device=device,
            action_size=action_size,
            threshold_pc=threshold_pc,
            monitor=monitor,
        )

    def action(self, observation: torch.Tensor) -> MujocoMass:
        """Return the action distribution of the selected expert.

        The most pessimistic expert (lowest ``q_value_no_reg``) is selected
        among the experts whose previously predicted next state is close
        enough to ``observation``. The fallback when every expert is masked is
        defined by ``_compute_pessimist_q_value``. The selection is made
        independently for every row of the batch.

        Args:
            observation: batch of observations, (B, F).

        Returns:
            MujocoMass built from the selected expert's mean and log-std.
        """
        means: List[torch.Tensor] = []
        log_stds: List[torch.Tensor] = []
        masses: List[MujocoMass] = []
        for actor in self.robust_actor_nets:
            mean, log_std = actor(observation)
            means.append(mean)
            log_stds.append(log_std)
            masses.append(MujocoMass(mean=mean, log_std=log_std))
        tensor_action_mean = torch.stack(means, dim=1)  # B, E, A
        tensor_action_log_std = torch.stack(log_stds, dim=1)  # B, E, A

        tensor_q_value_min, tensor_pc_value = self._experts_evaluation(
            observation, masses
        )
        _, selected_expert = self._compute_pessimist_q_value(
            observation=observation,
            tensor_q_value_min=tensor_q_value_min,
            observation_prediction=self.observation_prediction,
            threshold_pc=self.threshold_pc,
            tensor_pc_value_mean=tensor_pc_value,
        )
        # Keep this step's predictions to score the experts at the next step.
        self.observation_prediction = tensor_pc_value

        # Row b takes the parameters of expert selected_expert[b]
        # (index_select would apply every row's index to the whole batch).
        rows = torch.arange(selected_expert.size(0), device=selected_expert.device)
        return MujocoMass(
            mean=tensor_action_mean[rows, selected_expert],
            log_std=tensor_action_log_std[rows, selected_expert],
        )

    def _experts_evaluation(
        self, observation: torch.Tensor, masses: List[MujocoMass]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score each expert's greedy action with its pair of critics.

        Returns:
            The pessimistic ``q_value_no_reg`` per expert concatenated on the
            last dim, (B, E), and the matching next-state predictions stacked
            on the last dim, (B, F, E). Assumes critics return (B, 1) Q-values
            and (B, F) predictions.
        """
        q_values: List[torch.Tensor] = []
        predictions: List[torch.Tensor] = []
        for idx_robust, mass in zip(range(0, len(self.robust_q_nets), 2), masses):
            greedy_action = mass.best_action()
            _, pc_value_1, q_value_1_no_reg = self.robust_q_nets[idx_robust](
                observation, greedy_action
            )
            _, pc_value_2, q_value_2_no_reg = self.robust_q_nets[idx_robust + 1](
                observation, greedy_action
            )
            q_value_min, pc_value = self._pessimist_critic_pair(
                q_value_1_no_reg, q_value_2_no_reg, pc_value_1, pc_value_2
            )
            q_values.append(q_value_min)
            predictions.append(pc_value)
        return torch.cat(q_values, dim=-1), torch.stack(predictions, dim=-1)

    @staticmethod
    def _pessimist_critic_pair(
        q_value_1: torch.Tensor,
        q_value_2: torch.Tensor,
        pc_value_1: torch.Tensor,
        pc_value_2: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the lower Q-value of two critics and that critic's prediction.

        Args:
            q_value_1, q_value_2: Q-values, (B, 1).
            pc_value_1, pc_value_2: next-state predictions, (B, F).

        Returns:
            The minimum Q-value, (B, 1), and the full prediction of the critic
            that produced it, (B, F).
        """
        q_value_min, index_min = torch.min(
            torch.stack([q_value_1, q_value_2], dim=1), dim=1
        )  # (B, 1), (B, 1)
        pc_values = torch.stack([pc_value_1, pc_value_2], dim=1)  # B, 2, F
        # Expand the critic index over the feature axis so that the whole
        # prediction vector is gathered, not only its first feature.
        index = index_min.unsqueeze(dim=-1).expand(-1, -1, pc_values.size(-1))
        return q_value_min, pc_values.gather(dim=1, index=index).squeeze(dim=1)

    def q_value(
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the regularised Q-values of the two main critics.

        The predicted next states and the unregularised Q-values are stored in
        ``next_state_1/2`` and ``q_value_no_reg_1/2``.
        """
        # NOTE(review): q_value_network_1/2 are not set by this class's
        # constructors; presumably assigned elsewhere — verify before use.
        q_value_1, self.next_state_1, self.q_value_no_reg_1 = self.q_value_network_1(
            observation, action
        )
        q_value_2, self.next_state_2, self.q_value_no_reg_2 = self.q_value_network_2(
            observation, action
        )
        return q_value_1, q_value_2

    def target_q_value(  # type: ignore
        self, observation: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the regularised Q-values of the two target critics.

        The unregularised values are stored in ``target_q_value_no_reg_1/2``.
        """
        # NOTE(review): target_q_value_network_1/2 are not set by this class's
        # constructors; presumably assigned elsewhere — verify before use.
        (
            target_q_value_1,
            _,
            self.target_q_value_no_reg_1,
        ) = self.target_q_value_network_1(observation, action)
        (
            target_q_value_2,
            _,
            self.target_q_value_no_reg_2,
        ) = self.target_q_value_network_2(observation, action)
        return target_q_value_1, target_q_value_2


class PessimistExpertAgentHydraDefault(PessimistExpertAgentHydra):
    """Hydra pessimist expert agent falling back on a default expert.

    Identical to :class:`PessimistExpertAgentHydra`, except for one case:
    when every expert is masked, expert ``DEFAULT_EXPERT`` is used instead of
    the expert with the smallest prediction error.
    """

    # Expert index used when every expert is masked.
    DEFAULT_EXPERT = 0

    def _compute_pessimist_q_value(
        self,
        observation: torch.Tensor,
        tensor_pc_value_mean: torch.Tensor,
        tensor_q_value_min: torch.Tensor,  # B, E
        observation_prediction: Optional[torch.Tensor],
        threshold_pc: Union[float, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select, for every batch row, the pessimist trustworthy expert.

        Args:
            observation: current observations, (B, F).
            tensor_pc_value_mean: current next-state predictions, (B, F, E);
                only its expert dimension is used here.
            tensor_q_value_min: pessimistic Q-value per expert, (B, E).
            observation_prediction: predictions made at the previous step,
                (B, F, E), or None (every expert then has a zero error).
            threshold_pc: experts whose error is ``>= threshold_pc`` are masked.

        Returns:
            The selected Q-values, (B,), and expert indices, (B,). Rows where
            every expert is masked use expert ``DEFAULT_EXPERT``.
        """
        num_experts = tensor_pc_value_mean.size(-1)
        duplicated_observation = observation.unsqueeze(dim=-1).repeat(
            (1, 1, num_experts)
        )  # B, F, E
        if observation_prediction is None:
            observation_prediction = duplicated_observation

        # Mean absolute prediction error: B, F, E -> B, E
        tensor_error = torch.abs(observation_prediction - duplicated_observation).mean(
            dim=1
        )

        mask_fill = tensor_error >= threshold_pc  # B, E
        if self.monitor:
            self._update_info(
                nb_all_masked=mask_fill.sum(dim=-1).view(-1, 1),
                tensor_error_pc_value_mean=tensor_error,
            )

        # No squeeze: with a single expert it would turn (B, 1) into (B,) and
        # take the minimum over the batch.
        masked_q_value_min, masked_index = torch.min(
            tensor_q_value_min.masked_fill(mask_fill, float("inf")), dim=-1
        )
        all_masked = mask_fill.all(dim=-1)  # B

        q_value_expert_default = tensor_q_value_min[:, self.DEFAULT_EXPERT]  # B
        default_idx = torch.full_like(masked_index, self.DEFAULT_EXPERT)  # B
        return (
            torch.where(all_masked, q_value_expert_default, masked_q_value_min),
            torch.where(all_masked, default_idx, masked_index),
        )


class PessimistExpertAgentHydraDefaultAdaptativeThreshold(
    PessimistExpertAgentHydraDefault
):
    """Hydra default-fallback agent with one adaptive threshold per expert.

    ``threshold_pc`` is a tensor holding one threshold per expert. It is
    built with :meth:`append_last_threshold` / :meth:`set_last_threshold` and
    persisted by :meth:`save` / :meth:`load`.

    Args:
        actor, critic, device, action_size: see the parent classes.
        monitor (bool): record masking statistics. Defaults to ``False``.
    """

    def __init__(
        self,
        actor: nn.Module,
        critic: nn.Module,
        device: torch.device,
        action_size: int = 1,
        monitor: bool = False,
    ):
        super().__init__(actor, critic, device, action_size)
        self.monitor = monitor
        # Python mirror of the per-expert thresholds held in `threshold_pc`.
        self.threshold_list: List[float] = []
        self.threshold_pc: torch.Tensor = torch.tensor([], device=self.device)  # type: ignore

    def _refresh_threshold_tensor(self) -> None:
        """Rebuild ``threshold_pc`` from ``threshold_list`` on ``self.device``."""
        self.threshold_pc = torch.tensor(self.threshold_list, device=self.device)

    def append_last_threshold(self, last_threshold):
        """Add the threshold of a newly added expert."""
        self.threshold_list.append(last_threshold)
        self._refresh_threshold_tensor()

    def set_last_threshold(self, last_threshold):
        """Replace the threshold of the most recently added expert.

        Raises:
            IndexError: if no threshold has been appended yet.
        """
        if not self.threshold_list:
            raise IndexError(
                "set_last_threshold called before any threshold was appended"
            )
        self.threshold_list[-1] = last_threshold
        self._refresh_threshold_tensor()

    def save(self, path: str):
        """Save the experts (``robust_actor_<i>``, ``robust_critic_<j>``) and
        the ``threshold_pc`` tensor."""
        checkpoint = {
            f"robust_actor_{i}": net.state_dict()
            for i, net in enumerate(self.robust_actor_nets)
        }
        checkpoint.update(
            (f"robust_critic_{i}", net.state_dict())
            for i, net in enumerate(self.robust_q_nets)
        )
        checkpoint["threshold_pc"] = self.threshold_pc
        torch.save(checkpoint, path)

    def load(self, path: str) -> NoReturn:  # type: ignore
        """Append the stored experts and restore the stored thresholds."""
        checkpoint = torch.load(path, map_location=self.device)

        for k, v in checkpoint.items():
            if k.startswith("robust_critic"):
                # Every robust critic shares the template's architecture.
                q_robust_copy = deepcopy(self.critic)
                q_robust_copy.load_state_dict(v)
                self.robust_q_nets.append(q_robust_copy)
                logging.info("Loaded critic %s", k)
            elif k.startswith("robust_actor"):
                actor_robust_copy = deepcopy(self.actor)
                actor_robust_copy.load_state_dict(v)
                self.robust_actor_nets.append(actor_robust_copy)
                logging.info("Loaded actor %s", k)
            elif k.startswith("threshold_pc"):
                self.threshold_pc = v
                self.threshold_list = v.tolist()
                logging.info("Loaded threshold %s", k)
                logging.info("Threshold value :%s", v)


class PessimistExpertAgentHydraAdaptativeThreshold(PessimistExpertAgentHydra):
    """Hydra pessimist expert agent with one adaptive threshold per expert.

    ``threshold_pc`` is a tensor holding one threshold per expert. It is
    built with :meth:`append_last_threshold` / :meth:`set_last_threshold` and
    persisted by :meth:`save` / :meth:`load`.

    Args:
        actor, critic, device, action_size: see the parent classes.
        monitor (bool): record masking statistics. Defaults to ``False``.
    """

    def __init__(
        self,
        actor: nn.Module,
        critic: nn.Module,
        device: torch.device,
        action_size: int = 1,
        monitor: bool = False,
    ):
        super().__init__(actor, critic, device, action_size)
        self.monitor = monitor
        # Python mirror of the per-expert thresholds held in `threshold_pc`.
        self.threshold_list: List[float] = []
        self.threshold_pc: torch.Tensor = torch.tensor([], device=self.device)  # type: ignore

    def _refresh_threshold_tensor(self) -> None:
        """Rebuild ``threshold_pc`` from ``threshold_list`` on ``self.device``."""
        self.threshold_pc = torch.tensor(self.threshold_list, device=self.device)

    def append_last_threshold(self, last_threshold):
        """Add the threshold of a newly added expert."""
        self.threshold_list.append(last_threshold)
        self._refresh_threshold_tensor()

    def set_last_threshold(self, last_threshold):
        """Replace the threshold of the most recently added expert.

        Raises:
            IndexError: if no threshold has been appended yet.
        """
        if not self.threshold_list:
            raise IndexError(
                "set_last_threshold called before any threshold was appended"
            )
        self.threshold_list[-1] = last_threshold
        self._refresh_threshold_tensor()

    def save(self, path: str):
        """Save the experts (``robust_actor_<i>``, ``robust_critic_<j>``) and
        the ``threshold_pc`` tensor."""
        checkpoint = {
            f"robust_actor_{i}": net.state_dict()
            for i, net in enumerate(self.robust_actor_nets)
        }
        checkpoint.update(
            (f"robust_critic_{i}", net.state_dict())
            for i, net in enumerate(self.robust_q_nets)
        )
        checkpoint["threshold_pc"] = self.threshold_pc
        torch.save(checkpoint, path)

    def load(self, path: str) -> NoReturn:  # type: ignore
        """Append the stored experts and restore the stored thresholds."""
        checkpoint = torch.load(path, map_location=self.device)

        for k, v in checkpoint.items():
            if k.startswith("robust_critic"):
                # Every robust critic shares the template's architecture.
                q_robust_copy = deepcopy(self.critic)
                q_robust_copy.load_state_dict(v)
                self.robust_q_nets.append(q_robust_copy)
                logging.info("Loaded critic %s", k)
            elif k.startswith("robust_actor"):
                actor_robust_copy = deepcopy(self.actor)
                actor_robust_copy.load_state_dict(v)
                self.robust_actor_nets.append(actor_robust_copy)
                logging.info("Loaded actor %s", k)
            elif k.startswith("threshold_pc"):
                self.threshold_pc = v
                self.threshold_list = v.tolist()
                logging.info("Loaded threshold %s", k)


