from abc import abstractmethod
from typing import Any, Protocol, Sequence, TypeVar, runtime_checkable

import torch.utils.data as th_data

# Type variable for the static type of `MCARDataset.orig_dataset`. It is
# bounded by `torch.utils.data.Dataset`, so type checkers only let it bind to
# `Dataset` (sub)types.
DT = TypeVar("DT", bound=th_data.Dataset)


@runtime_checkable
class MCARDataset(Protocol[DT]):
    """Protocol for objects that expose an original dataset via ``orig_dataset``.

    Any object with an ``orig_dataset`` attribute satisfies this protocol. The
    generic parameter ``DT`` (bounded by ``torch.utils.data.Dataset``) records
    that attribute's static type.

    NOTE(review): "MCAR" presumably refers to missing-completely-at-random
    masking applied on top of ``orig_dataset``. Confirm this against the
    implementing classes.

    Because of ``@runtime_checkable``, ``isinstance(obj, MCARDataset)`` only
    checks that ``orig_dataset`` is present. Neither its type nor ``DT`` is
    verified at runtime. Use the unsubscripted class for such checks, because
    ``isinstance(obj, MCARDataset[X])`` raises ``TypeError``. ``issubclass`` is
    not supported either, because the protocol has a data member.
    """

    # The underlying dataset object; statically typed as ``DT``.
    orig_dataset: DT


@runtime_checkable
class DatasetHasFeaturizer(Protocol):
    """Protocol for (dataset) objects with a featurizer and a ``to_return_zs`` flag.

    Implementers must provide two members:

    - a ``featurizer`` attribute of any type;
    - a readable *and* writable boolean ``to_return_zs``, which can be a plain
      attribute or a property with a setter.

    What the flag controls is up to the implementer. This protocol fixes only
    its name and type.

    ``isinstance`` checks (enabled by ``@runtime_checkable``) only verify that
    both names exist. NOTE: before Python 3.12, the runtime check uses
    ``hasattr``, so an ``isinstance`` call evaluates the ``to_return_zs``
    getter.
    """

    # Featurizer object. Its interface is deliberately left unconstrained
    # (typed ``Any``).
    featurizer: Any

    # Read side of the ``to_return_zs`` flag. It is an abstract property, so a
    # class that explicitly subclasses this protocol must implement it before
    # it can be instantiated.
    @property
    @abstractmethod
    def to_return_zs(self) -> bool: ...

    # Write side. Declaring a setter makes the member read/write in the
    # protocol, so a class with only a read-only property does not satisfy it
    # for static type checkers.
    @to_return_zs.setter
    @abstractmethod
    def to_return_zs(self, to_return_zs: bool) -> None: ...


@runtime_checkable
class DatasetHasFeatureName(Protocol):
    """Protocol for (dataset) objects that describe their features by name and type.

    ``isinstance`` checks (enabled by ``@runtime_checkable``) only verify that
    both attributes exist. The element types are not inspected at runtime.
    """

    # Feature names as strings. Presumably there is one entry per feature,
    # index-aligned with ``feature_types``; TODO confirm with implementers.
    feature_names: Sequence[str]
    # Each element is a ``str``, a sequence of floats, or a sequence of
    # strings. NOTE(review): the meaning of each variant (for example, a type
    # tag versus a list of allowed values) is not defined here; confirm with
    # implementers.
    feature_types: Sequence[str | Sequence[float] | Sequence[str]]


@runtime_checkable
class ExpertHasName(Protocol):
    """Protocol for objects exposing ``expert_names``, a sequence of strings.

    ``isinstance`` checks (enabled by ``@runtime_checkable``) only verify that
    the attribute exists.
    """

    # Names of the experts. What an "expert" is, is defined by implementers.
    expert_names: Sequence[str]


@runtime_checkable
class SubsetFeatureSelectorDataManager(Protocol):
    """Protocol for data managers that provide an integer ``max_features``.

    It serves as the base protocol extended by ``KNNSelectorDataManager``.
    ``isinstance`` checks (enabled by ``@runtime_checkable``) only verify that
    the attribute exists.
    """

    # Presumably the maximum number of features a subset selector may pick.
    # Inferred from the name; confirm with implementers.
    max_features: int


@runtime_checkable
class KNNSelectorDataManager(SubsetFeatureSelectorDataManager, Protocol):
    """Protocol for KNN selector data managers.

    It inherits the ``max_features`` member from
    ``SubsetFeatureSelectorDataManager`` and adds ``n_neighs``. ``isinstance``
    checks verify that both attributes exist.
    """

    # ``Protocol`` is listed explicitly among the bases because subclassing a
    # protocol alone would make this a regular class, not a protocol.

    # Presumably the neighbour count (k) for the KNN selector; confirm with
    # implementers.
    n_neighs: int
