import abc
from typing import (
    Any,
    Dict,
    Generic,
    Hashable,
    Iterable,
    NoReturn,
    Sequence,
    TypeVar,
    Union,
)

import torch
from torch.nn.parameter import Parameter

# Environment-side type parameters of `IAgent`.
# `EnvObservation` is the type of the observation handed to `IAgent.action`
# and to the Q-value methods.
EnvObservation = TypeVar("EnvObservation")
# `EnvAction` is a type parameter of `IAgent`; presumably the action format
# the environment itself consumes -- TODO confirm against implementations.
EnvAction = TypeVar("EnvAction")

# Agent-side type parameter: the payload produced by `IAction`
# (return type of `sample`, `rsample` and `best_action`).
Action = TypeVar("Action")


# TODO: Rename to IOutput ?
class IAction(Generic[Action]):
    """
    Interface for actions, more generally it is the output of the neural network.

    Every method raises ``NotImplementedError`` until a subclass overrides it.

    NOTE(review): this class does not derive from ``abc.ABC``. On Python
    versions where ``typing.Generic`` carries no ABC metaclass, the
    ``@abc.abstractmethod`` markers are therefore not enforced when a
    subclass is instantiated; a missing override only fails when called.

    Args:
        Generic ([Action]): Generic type of action
    """

    # The docstrings of `sample` and `rsample` are raw strings because they
    # contain the LaTeX fragment `$\pi$`; in a normal string literal `\p` is
    # an invalid escape sequence and triggers a warning at compile time.
    @abc.abstractmethod
    def sample(self) -> Action:
        r"""
        Sample an action.

        If the $\pi$ policy is a distribution implement this method.
        For example if you are using a policy-based algorithm.
        Returns:
            Action: General type of action. Usually it is a `torch.Tensor`,
            `np.ndarray` or a sequence of these.
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def rsample(self) -> Action:
        r"""
        Sample an action, `rsample` variant.

        If the $\pi$ policy is a distribution implement this method.
        For example if you are using a policy-based algorithm.

        NOTE(review): presumably the reparameterized counterpart of
        `sample` (same naming as `torch.distributions`), i.e. a sample that
        stays differentiable -- TODO confirm against implementations.
        Returns:
            Action: General type of action. Usually it is a `torch.Tensor`,
            `np.ndarray` or a sequence of these.
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def best_action(self) -> Action:
        """
        Output the best action.
        Returns:
            Action: General type of action. Usually it is a `torch.Tensor`,
            `np.ndarray` or a sequence of these.
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def log_prob(self, action: Action) -> torch.Tensor:
        """
        Log probability of a given action.
        Args:
            action (Action): Action
        Returns:
            torch.Tensor: tensor of log probability with size `[Any, 1]`
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def entropy(self) -> torch.Tensor:
        """
        Entropy
        Returns:
            torch.Tensor: tensor of entropy with size `[Any, 1]`
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def value(self) -> torch.Tensor:
        """
        Output of value function.
        Returns:
            torch.Tensor: Value output with size `[Any, 1]`
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def target_value(self) -> torch.Tensor:
        """
        Target counterpart of `value`.

        NOTE(review): presumably the output of a target value function /
        target network -- TODO confirm against implementations.
        Returns:
            torch.Tensor: Target value output.
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def infos(self) -> Dict[Hashable, Any]:
        """
        Additional information attached to this output.
        Returns:
            Dict[Hashable, Any]: mapping whose keys and values are defined
            by the implementation.
        """
        raise NotImplementedError("to implements")


class IAgent(Generic[EnvObservation, EnvAction, Action]):
    """
    Interface for Agent.

    Every method raises ``NotImplementedError`` until a subclass overrides it.

    Methods that only perform a side effect are annotated ``-> None``.
    ``typing.NoReturn`` is not used because it means "never returns
    normally", which would make type checkers treat code following a call
    as unreachable.

    NOTE(review): this class does not derive from ``abc.ABC``. On Python
    versions where ``typing.Generic`` carries no ABC metaclass, the
    ``@abc.abstractmethod`` markers are therefore not enforced when a
    subclass is instantiated; a missing override only fails when called.

    Args:
        Generic ([EnvObservation, EnvAction, Action]): Generic type for EnvObservation, EnvAction and Action.
    """

    @abc.abstractmethod
    def action(self, observation: EnvObservation) -> IAction[Action]:
        """
        Output an IAction instance.
        Args:
            observation (EnvObservation): Current state.
        Returns:
            IAction[Action]: Output of the agent.
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def q_value(
        self, observation: EnvObservation, action: Action, **kwargs: Any
    ) -> Sequence[torch.Tensor]:
        """
        Q-value of an action for a given observation.
        Args:
            observation (EnvObservation): Observation to evaluate.
            action (Action): Action to evaluate.
            **kwargs: Extra implementation-specific arguments.
        Returns:
            Sequence[torch.Tensor]: Sequence of Q-value tensors
            (presumably one per critic -- TODO confirm).
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def target_q_value(
        self, observation: EnvObservation, action: Action, **kwargs: Any
    ) -> Sequence[torch.Tensor]:
        """
        Target counterpart of `q_value`.

        NOTE(review): presumably evaluated with target networks -- TODO
        confirm against implementations.
        Args:
            observation (EnvObservation): Observation to evaluate.
            action (Action): Action to evaluate.
            **kwargs: Extra implementation-specific arguments.
        Returns:
            Sequence[torch.Tensor]: Sequence of target Q-value tensors.
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def infos(self) -> Dict[Hashable, Any]:
        """
        Additional information about the agent.
        Returns:
            Dict[Hashable, Any]: mapping whose keys and values are defined
            by the implementation.
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def save(self, path: str) -> None:
        """
        Save the agent.
        Args:
            path (str): Destination path; what is persisted is defined by
            the implementation.
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def load(self, path: str) -> None:
        """
        Load the agent.
        Args:
            path (str): Source path, as previously given to `save`
            (presumably -- TODO confirm).
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def train(self) -> None:
        """
        Set agent to train mode
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def eval(self) -> None:
        """
        Set agent to eval mode and disable grad
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def clone(self) -> None:
        """
        Clone the agent

        NOTE(review): annotated as returning nothing, mirroring the other
        side-effect methods of this interface. If implementations return
        the copy, widen this annotation accordingly.
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def parameters(self) -> Union[Parameter, Iterable[Parameter]]:
        """
        Parameters of the agent
        Returns:
            Union[Parameter, Iterable[Parameter]]: a single parameter or an
            iterable of parameters.
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def parameters_actor(self) -> Union[Parameter, Iterable[Parameter]]:
        """
        Parameters of the actor part of the agent
        Returns:
            Union[Parameter, Iterable[Parameter]]: a single parameter or an
            iterable of parameters.
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def parameters_value(self) -> Union[Parameter, Iterable[Parameter]]:
        """
        Parameters of the value part of the agent
        Returns:
            Union[Parameter, Iterable[Parameter]]: a single parameter or an
            iterable of parameters.
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def parameters_critic(self) -> Union[Parameter, Iterable[Parameter]]:
        """
        Parameters of the critic part of the agent
        Returns:
            Union[Parameter, Iterable[Parameter]]: a single parameter or an
            iterable of parameters.
        """
        raise NotImplementedError("to implements")

    @abc.abstractmethod
    def soft_update(self, tau: float) -> None:
        """
        Soft update step controlled by `tau`.

        NOTE(review): presumably a Polyak-style interpolation of target
        networks toward the online networks -- TODO confirm.
        Args:
            tau (float): Update coefficient.
        """
        raise NotImplementedError("to implements")
