from dataclasses import dataclass

from egxc.dataloading.datasets.base import (
    BaseDataset,
    PartiallySplitDataset,
    PresplitDataset,
    UnsplitDataset,
)


@dataclass
class DatasetEnsemble:
    """Bundle of the train, validation and test datasets used together.

    Instances are usually created through one of the alternate constructors
    (``infer_split``, ``from_presplit_dataset``, ``from_partial_random_split``,
    ``from_random_split``) rather than by passing the three datasets directly.
    """

    train: BaseDataset
    val: BaseDataset
    test: BaseDataset

    @classmethod
    def infer_split(
        cls,
        dataset: BaseDataset,
        train_fraction: float | None = None,
        val_fraction: float | None = None,
        data_split_seed: int = 0,
    ) -> 'DatasetEnsemble':
        """Build an ensemble, picking the split strategy from the dataset's type.

        The fractions that must be supplied depend on the concrete class:

        * ``PresplitDataset``: neither fraction; the dataset's own ``split()``
          is used.
        * ``PartiallySplitDataset``: ``val_fraction`` only.
        * ``UnsplitDataset``: both ``train_fraction`` and ``val_fraction``.
        * any other ``BaseDataset``: all arguments are forwarded unchanged to
          the dataset's own ``infer_split`` method.

        Args:
            dataset: The dataset to split.
            train_fraction: Train fraction, forwarded to
                ``UnsplitDataset.random_split``.
            val_fraction: Validation fraction, forwarded to ``random_split``
                of partially split and unsplit datasets.
            data_split_seed: Seed forwarded to the random split methods.

        Returns:
            A new ensemble holding the resulting train/val/test datasets.

        Raises:
            ValueError: If the supplied fractions do not match what the
                dataset's type requires, or if ``dataset`` is not a
                ``BaseDataset`` at all.
        """
        # Explicit raises instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let ``None`` fractions slip through.
        if isinstance(dataset, PresplitDataset):
            if train_fraction is not None or val_fraction is not None:
                raise ValueError(
                    f'{type(dataset).__name__} is already split; '
                    'train_fraction and val_fraction must both be None '
                    f'(got train_fraction={train_fraction!r}, '
                    f'val_fraction={val_fraction!r})'
                )
            return cls.from_presplit_dataset(dataset)

        if isinstance(dataset, PartiallySplitDataset):
            if train_fraction is not None or val_fraction is None:
                raise ValueError(
                    f'{type(dataset).__name__} requires val_fraction and no '
                    'train_fraction '
                    f'(got train_fraction={train_fraction!r}, '
                    f'val_fraction={val_fraction!r})'
                )
            return cls.from_partial_random_split(
                dataset, val_fraction, data_split_seed
            )

        if isinstance(dataset, UnsplitDataset):
            if train_fraction is None or val_fraction is None:
                raise ValueError(
                    f'{type(dataset).__name__} requires both train_fraction '
                    'and val_fraction '
                    f'(got train_fraction={train_fraction!r}, '
                    f'val_fraction={val_fraction!r})'
                )
            return cls.from_random_split(
                dataset, train_fraction, val_fraction, data_split_seed
            )

        if isinstance(dataset, BaseDataset):
            # Unknown subclass: let the dataset decide how to split itself.
            return cls(
                *dataset.infer_split(
                    train_fraction=train_fraction,
                    val_fraction=val_fraction,
                    data_split_seed=data_split_seed,
                )
            )

        raise ValueError(f'Unknown dataset type: {type(dataset)}')

    @classmethod
    def from_presplit_dataset(cls, dataset: PresplitDataset) -> 'DatasetEnsemble':
        """Wrap the ``(train, val, test)`` triple returned by ``dataset.split()``."""
        train, val, test = dataset.split()
        return cls(train, val, test)

    @classmethod
    def from_partial_random_split(
        cls,
        dataset: PartiallySplitDataset,
        val_fraction: float,
        data_split_seed: int = 0,
    ) -> 'DatasetEnsemble':
        """Build an ensemble from ``dataset.random_split(val_fraction, seed)``.

        Args:
            dataset: A partially split dataset.
            val_fraction: Validation fraction forwarded to ``random_split``.
            data_split_seed: Seed forwarded to ``random_split``; defaults to 0
                to match ``infer_split``.
        """
        train, val, test = dataset.random_split(val_fraction, data_split_seed)
        return cls(train, val, test)

    @classmethod
    def from_random_split(
        cls,
        dataset: UnsplitDataset,
        train_fraction: float,
        val_fraction: float,
        data_split_seed: int = 0,
    ) -> 'DatasetEnsemble':
        """Build an ensemble from ``dataset.random_split(train, val, seed)``.

        Args:
            dataset: An unsplit dataset.
            train_fraction: Train fraction forwarded to ``random_split``.
            val_fraction: Validation fraction forwarded to ``random_split``.
            data_split_seed: Seed forwarded to ``random_split``; defaults to 0
                to match ``infer_split``.
        """
        train, val, test = dataset.random_split(
            train_fraction, val_fraction, data_split_seed
        )
        return cls(train, val, test)

    def __repr__(self) -> str:
        # Report split sizes rather than dataset contents; this explicit
        # definition takes precedence over the dataclass-generated repr.
        return (
            f'DatasetEnsemble(train={len(self.train)}, '
            f'val={len(self.val)}, test={len(self.test)})'
        )
