from typing import Dict, NamedTuple, Optional, Protocol, Tuple, Union

import numpy as np
import torch as th

from ucrl.classify.classifier import PtEstGRU, DistributionGRU, CostBudgetEstMLP
from ucrl.common.policies import BaseHPolicy, ActorCriticCPolicy

# Mapping from string keys to tensors; used as the type of the ``observations``
# field in the ``Dict*Samples`` tuples of this module.
TensorDict = Dict[str, th.Tensor]


class RolloutBufferHSamples(NamedTuple):
    """
    Batch of rollout data stored as a tuple of tensors, for the variant that
    carries hidden observations (``hidden_obs`` / ``full_hidden_obs``).

    Presumably yielded by a rollout buffer when sampling minibatches for a
    policy update (the name mirrors stable-baselines3's
    ``RolloutBufferSamples``) -- confirm against the producing buffer.

    Each ``log_score_*`` field sits alongside the unprefixed field with the
    same role (``old_values`` / ``old_log_score_values``, ``advantages`` /
    ``log_score_advantages``, ``returns`` / ``log_score_returns``); presumably
    the ``log_score_*`` ones are computed from ``log_scores`` instead of the
    ordinary reward -- verify.

    NOTE: field order fixes the tuple positions; keep it in sync with any
    code that constructs or unpacks these samples positionally.
    """

    observations: th.Tensor
    # How ``full_hidden_obs`` differs from ``hidden_obs`` is not visible in
    # this file -- TODO document.
    hidden_obs: th.Tensor
    full_hidden_obs: th.Tensor
    actions: th.Tensor
    log_scores: th.Tensor
    # ``old_*`` fields: presumably quantities recorded at rollout time, before
    # the current update -- confirm.
    old_values: th.Tensor
    old_log_score_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    log_score_advantages: th.Tensor
    returns: th.Tensor
    log_score_returns: th.Tensor


class DictRolloutBufferHSamples(NamedTuple):
    """
    Dict-observation counterpart of ``RolloutBufferHSamples``.

    Same fields in the same order; the only difference is ``observations``,
    which is a ``TensorDict`` (mapping from string keys to tensors) instead of
    a single tensor.

    NOTE: field order fixes the tuple positions; keep it in sync with any
    code that constructs or unpacks these samples positionally.
    """

    # One tensor per observation key.
    observations: TensorDict
    # How ``full_hidden_obs`` differs from ``hidden_obs`` is not visible in
    # this file -- TODO document.
    hidden_obs: th.Tensor
    full_hidden_obs: th.Tensor
    actions: th.Tensor
    log_scores: th.Tensor
    # ``old_*`` fields: presumably quantities recorded at rollout time, before
    # the current update -- confirm.
    old_values: th.Tensor
    old_log_score_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    log_score_advantages: th.Tensor
    returns: th.Tensor
    log_score_returns: th.Tensor

class PolicyHPredictor(Protocol):
    """
    Structural (duck-typed) interface for predictors that act on an
    observation together with a hidden observation.

    Implementers expose a ``classifier`` (``PtEstGRU`` or ``DistributionGRU``),
    a ``policy`` (``BaseHPolicy``), the torch ``device`` they run on, plus the
    ``predict`` and ``get_cost_sa`` methods below. Static type checkers accept
    any object with these members; no inheritance is required.
    """

    classifier: Union[PtEstGRU, DistributionGRU]
    policy: BaseHPolicy
    device: th.device

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        hidden_obs: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Compute the policy's action for an observation and hidden observation.

        Implementations are expected to take care of observation
        preprocessing (e.g. normalizing images) themselves.

        :param observation: the observation to act on (array, or dict of
            arrays for dict observation spaces)
        :param hidden_obs: the hidden observation paired with ``observation``
        :param state: previous recurrent hidden states, or None; only
            meaningful for recurrent policies
        :param episode_start: episode-start masks, or None; they mark where an
            episode begins, i.e. where recurrent hidden states must be reset
        :param deterministic: if True, return deterministic actions
        :return: a tuple ``(action, next_state)``, where ``next_state`` is the
            next recurrent hidden state (used by recurrent policies)
        """
        ...

    def get_cost_sa(
        self,
        original_obs: np.ndarray,
        clipped_actions: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute the cost estimate for an observation-action pair.

        :param original_obs: the original (unnormalized) observation
        :param clipped_actions: the clipped actions that will be executed
        :return: a pair of arrays; what each element holds is defined by the
            implementation (TODO: document both elements)
        """
        ...


class RolloutBufferCSamples(NamedTuple):
    """
    Batch of rollout data stored as a tuple of tensors, for the cost variant.

    Presumably yielded by a rollout buffer when sampling minibatches for a
    policy update (the name mirrors stable-baselines3's
    ``RolloutBufferSamples``) -- confirm against the producing buffer.

    Each ``neg_cost_*`` field sits alongside the unprefixed field with the
    same role (``old_values`` / ``old_neg_cost_values``, ``advantages`` /
    ``neg_cost_advantages``, ``returns`` / ``neg_cost_returns``); presumably
    the ``neg_cost_*`` ones are computed from ``neg_costs`` instead of the
    ordinary reward -- verify.

    NOTE: field order fixes the tuple positions; keep it in sync with any
    code that constructs or unpacks these samples positionally.
    """

    observations: th.Tensor
    actions: th.Tensor
    # Presumably per-step costs with their sign flipped (per the name) -- confirm.
    neg_costs: th.Tensor
    # ``old_*`` fields: presumably quantities recorded at rollout time, before
    # the current update -- confirm.
    old_values: th.Tensor
    old_neg_cost_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    neg_cost_advantages: th.Tensor
    returns: th.Tensor
    neg_cost_returns: th.Tensor

class DictRolloutBufferCSamples(NamedTuple):
    """
    Dict-observation counterpart of ``RolloutBufferCSamples``.

    Same fields in the same order; the only difference is ``observations``,
    which is a ``TensorDict`` (mapping from string keys to tensors) instead of
    a single tensor.

    NOTE: field order fixes the tuple positions; keep it in sync with any
    code that constructs or unpacks these samples positionally.
    """

    # One tensor per observation key.
    observations: TensorDict
    actions: th.Tensor
    # Presumably per-step costs with their sign flipped (per the name) -- confirm.
    neg_costs: th.Tensor
    # ``old_*`` fields: presumably quantities recorded at rollout time, before
    # the current update -- confirm.
    old_values: th.Tensor
    old_neg_cost_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    neg_cost_advantages: th.Tensor
    returns: th.Tensor
    neg_cost_returns: th.Tensor

class PolicyCPredictor(Protocol):
    """
    Structural (duck-typed) interface for predictors built around an
    ``ActorCriticCPolicy``.

    Implementers expose a ``classifier`` (``CostBudgetEstMLP``), a ``policy``
    (``ActorCriticCPolicy``), the torch ``device`` they run on, and the
    ``predict`` method below. Static type checkers accept any object with
    these members; no inheritance is required.
    """

    classifier: CostBudgetEstMLP
    policy: ActorCriticCPolicy
    device: th.device

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Compute the policy's action for an observation.

        Implementations are expected to take care of observation
        preprocessing (e.g. normalizing images) themselves.

        :param observation: the observation to act on (array, or dict of
            arrays for dict observation spaces)
        :param state: previous recurrent hidden states, or None; only
            meaningful for recurrent policies
        :param episode_start: episode-start masks, or None; they mark where an
            episode begins, i.e. where recurrent hidden states must be reset
        :param deterministic: if True, return deterministic actions
        :return: a tuple ``(action, next_state)``, where ``next_state`` is the
            next recurrent hidden state (used by recurrent policies)
        """
        ...
