"""Common aliases for type hints"""

from enum import Enum
from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union

import gym
import numpy as np
import torch as th

from stable_baselines3.common import callbacks, vec_env

# --- Environment-related aliases -------------------------------------------
# Either a plain ``gym.Env`` or a stable-baselines3 ``vec_env.VecEnv``.
GymEnv = Union[gym.Env, vec_env.VecEnv]
# Accepted observation containers: tuple, string-keyed dict, numpy array or int.
GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
# 4-tuple of (GymObs, float, bool, dict); presumably gym's
# (observation, reward, done, info) step result -- confirm against callers.
GymStepReturn = Tuple[GymObs, float, bool, Dict]

# --- Tensor / optimizer aliases --------------------------------------------
# Mapping from a string or integer key to a torch tensor.
TensorDict = Dict[Union[str, int], th.Tensor]
# String-keyed mapping with arbitrary values.
OptimizerStateDict = Dict[str, Any]

# --- Callback alias --------------------------------------------------------
# Nothing at all, any callable, a list of callbacks, or a single callback.
MaybeCallback = Union[
    None,
    Callable,
    List[callbacks.BaseCallback],
    callbacks.BaseCallback,
]

# A schedule receives the remaining progress as input
# and outputs a scalar (e.g. learning rate, clip range, ...)
Schedule = Callable[[float], float]

class RolloutBufferSamples_fair(NamedTuple):
    """Rollout-buffer sample whose value-related fields are lists of tensors.

    ``old_values``, ``advantages`` and ``returns`` are each annotated as a
    list of tensors; every other field is annotated as a single tensor.
    Field order is part of the tuple layout and must not be changed.
    """

    observations: th.Tensor
    actions: th.Tensor
    # List layout as noted by the original author: [r, [r_U_0,..],[r_B_0,..]]
    old_values: List[th.Tensor]
    old_log_prob: th.Tensor
    # List-valued, like ``old_values``.
    advantages: List[th.Tensor]
    returns: List[th.Tensor]

    # The two fields below are only used by APPO.
    deltas: th.Tensor
    delta_deltas: th.Tensor

# The definitions below are from APPO's paper; they are currently commented out.

# class RolloutBufferSamples(NamedTuple):
#     observations: th.Tensor
#     actions: th.Tensor
#     old_values: th.Tensor
#     old_log_prob: th.Tensor
#     advantages: th.Tensor
#     returns: th.Tensor
#     deltas: th.Tensor   # only used by APPO
#     delta_deltas: th.Tensor # only used by APPO


# class DictRolloutBufferSamples(RolloutBufferSamples):
#     observations: TensorDict
#     actions: th.Tensor
#     old_values: th.Tensor
#     old_log_prob: th.Tensor
#     advantages: th.Tensor
#     returns: th.Tensor


# class ReplayBufferSamples(NamedTuple):
#     observations: th.Tensor
#     actions: th.Tensor
#     next_observations: th.Tensor
#     dones: th.Tensor
#     rewards: th.Tensor


# class DictReplayBufferSamples(ReplayBufferSamples):
#     observations: TensorDict
#     actions: th.Tensor
#     next_observations: th.Tensor
#     dones: th.Tensor
#     rewards: th.Tensor


# class RolloutReturn(NamedTuple):
#     episode_timesteps: int
#     n_episodes: int
#     continue_training: bool


# The definitions below are not used anywhere.

# class TrainFrequencyUnit(Enum):
#     STEP = "step"
#     EPISODE = "episode"


# class TrainFreq(NamedTuple):
#     frequency: int
#     unit: TrainFrequencyUnit  # either "step" or "episode"