from typing import Any, ClassVar, Dict, List, Optional, Sequence, Type, Union

from d3rlpy.models.encoders import EncoderFactory, PixelEncoderFactory, VectorEncoderFactory
from d3rlpy.models.torch import (
    Encoder,
    EncoderWithAction,
    PixelEncoder,
    PixelEncoderWithAction,
    VectorEncoder,
    VectorEncoderWithAction,
)

from d3rlpy.models.encoders import register_encoder_factory


class FeatureSizeDefaultEncoderFactory(EncoderFactory):
    """Default encoder factory class with feature size specification.

    This encoder factory returns an encoder based on the observation shape
    with a configurable output feature size:

    * 3-dimensional observation shapes are handled by
      ``PixelEncoderFactory(feature_size=feature_size, ...)``.
    * All other shapes are handled by
      ``VectorEncoderFactory(hidden_units=[256, feature_size], ...)``.

    Args:
        feature_size (int): default output feature size of the encoder. It can
            be overridden per call via the ``feature_size`` argument of
            :meth:`create` / :meth:`create_with_action`.
        activation (str): activation function name.
        use_batch_norm (bool): flag to insert batch normalization layers.
        dropout_rate (float): dropout probability.

    """

    TYPE: ClassVar[str] = "default_fsize"
    _feature_size: int
    _activation: str
    _use_batch_norm: bool
    _dropout_rate: Optional[float]

    def __init__(
        self,
        feature_size: int = 10,
        activation: str = "relu",
        use_batch_norm: bool = False,
        dropout_rate: Optional[float] = None,
    ):
        self._feature_size = feature_size
        self._activation = activation
        self._use_batch_norm = use_batch_norm
        self._dropout_rate = dropout_rate

    def _make_factory(
        self, observation_shape: Sequence[int], feature_size: int
    ) -> Union[PixelEncoderFactory, VectorEncoderFactory]:
        """Build the concrete d3rlpy factory for ``observation_shape``.

        Args:
            observation_shape: shape of a single observation.
            feature_size: requested output feature size. A falsy value (``0``
                or ``None``) selects the size configured in ``__init__``.

        Returns:
            A ``PixelEncoderFactory`` for 3-dimensional shapes, otherwise a
            ``VectorEncoderFactory`` with ``hidden_units=[256, feature_size]``.

        """
        if not feature_size:
            feature_size = self._feature_size
        if len(observation_shape) == 3:
            return PixelEncoderFactory(
                activation=self._activation,
                use_batch_norm=self._use_batch_norm,
                dropout_rate=self._dropout_rate,
                feature_size=feature_size,
            )
        # Non-image observations: one fixed 256-unit hidden layer followed by
        # the requested feature size as the output width.
        return VectorEncoderFactory(
            activation=self._activation,
            use_batch_norm=self._use_batch_norm,
            dropout_rate=self._dropout_rate,
            hidden_units=[256, feature_size],
        )

    def create(
        self, observation_shape: Sequence[int], feature_size: int = 0
    ) -> Encoder:
        """Create an observation encoder.

        Args:
            observation_shape: shape of a single observation.
            feature_size: output feature size; ``0`` (the default) uses the
                size given to the constructor.

        Returns:
            The encoder built by the selected pixel/vector factory.

        """
        factory = self._make_factory(observation_shape, feature_size)
        return factory.create(observation_shape)

    def create_with_action(
        self,
        observation_shape: Sequence[int],
        action_size: int,
        discrete_action: bool = False,
        feature_size: int = 0,
    ) -> EncoderWithAction:
        """Create an encoder that takes both an observation and an action.

        Args:
            observation_shape: shape of a single observation.
            action_size: dimension of the action (forwarded unchanged).
            discrete_action: forwarded unchanged to the underlying factory.
            feature_size: output feature size; ``0`` (the default) uses the
                size given to the constructor.

        Returns:
            The encoder built by the selected pixel/vector factory.

        """
        factory = self._make_factory(observation_shape, feature_size)
        return factory.create_with_action(
            observation_shape, action_size, discrete_action
        )

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Return the constructor arguments needed to rebuild this factory.

        ``feature_size`` must be included; otherwise a factory reconstructed
        from these params would silently fall back to the default size.

        Args:
            deep: unused; kept for interface compatibility.

        Returns:
            Mapping of ``__init__`` keyword arguments to their current values.

        """
        return {
            "feature_size": self._feature_size,
            "activation": self._activation,
            "use_batch_norm": self._use_batch_norm,
            "dropout_rate": self._dropout_rate,
        }


# Module-level side effect: hand the class to d3rlpy's encoder-factory
# registry. Presumably keyed by ``TYPE`` ("default_fsize") so the factory can
# be selected/rebuilt by name — confirm against d3rlpy.models.encoders.
register_encoder_factory(FeatureSizeDefaultEncoderFactory)