from typing import Dict, Optional
import copy
from omegaconf import DictConfig, OmegaConf
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.distributed import DistributedSampler

__all__ = ["PadoDataLoader"]


class PadoDataLoader(DataLoader):
    """
    Base class of data loaders.
    * batch_size, num_workers should be passed "per-gpu"

    The constructor arguments (everything except ``dataset``, ``seed`` and
    ``start_epoch``) are recorded in ``self.init_kwargs`` so that an equivalent
    loader can be rebuilt with :meth:`from_other`.
    """

    def __init__(self,
                 dataset,  # PadoDataset
                 batch_size: int,
                 shuffle: bool = False,
                 num_workers: int = 0,
                 pin_memory: bool = False,
                 drop_last: bool = False,
                 timeout: float = 0, *,
                 prefetch_factor: Optional[int] = 2,
                 persistent_workers: bool = False,
                 sampler=None,
                 collate_fn=None,
                 seed: int = 0,
                 start_epoch: int = 0) -> None:
        """
        Args:
            dataset: dataset to load from. If ``collate_fn`` is not given and the
                dataset has a ``collate_fn`` attribute, that attribute is used.
            batch_size, shuffle, num_workers, pin_memory, drop_last, timeout,
            persistent_workers, sampler, collate_fn: forwarded to ``DataLoader``.
            prefetch_factor: forwarded to ``DataLoader`` only when
                ``num_workers > 0`` and it is not None; otherwise DataLoader's own
                default is used (the option is meaningless without workers).
            seed: stored as ``self.seed``; this class itself does not apply it.
            start_epoch: initial epoch, applied through :meth:`set_epoch`.
        """
        if (collate_fn is None) and hasattr(dataset, 'collate_fn'):
            collate_fn = dataset.collate_fn

        # Full configuration (including prefetch_factor) so from_other can rebuild it.
        self.init_kwargs = dict(
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
            sampler=sampler,
            collate_fn=collate_fn
        )

        loader_kwargs = dict(self.init_kwargs)
        if not (num_workers > 0 and prefetch_factor is not None):
            # PyTorch >= 2.0 raises ValueError if prefetch_factor is passed with
            # num_workers == 0; older versions only accept their default there.
            # Omitting it lets DataLoader choose its own default on every version.
            del loader_kwargs["prefetch_factor"]

        super().__init__(dataset, **loader_kwargs)

        self.seed = seed  # this is currently only supported by DistributedSampler.
        self.epoch = start_epoch
        self.set_epoch(start_epoch)

    def set_epoch(self, epoch: int) -> None:
        """Record ``epoch`` and forward it to the sampler if it is a DistributedSampler."""
        self.epoch = epoch
        if isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)

    def set_collate_fn(self, fn) -> None:
        """Replace the collate function, also for loaders later built via from_other."""
        self.collate_fn = fn
        # Keep the recorded configuration in sync; otherwise from_other would
        # silently rebuild with the previous collate function.
        self.init_kwargs["collate_fn"] = fn

    @classmethod
    def from_config(cls, cfg: "DictConfig", dataset, sampler=None, collate_fn=None) -> "PadoDataLoader":
        """
        Build a loader from an OmegaConf config.

        ``cfg`` is resolved into a plain dict and its entries are passed as
        keyword arguments to the constructor together with ``sampler`` and
        ``collate_fn``.
        """
        cfg_as_dict = OmegaConf.to_container(cfg, resolve=True)
        return cls(dataset, sampler=sampler, collate_fn=collate_fn, **cfg_as_dict)

    @classmethod
    def from_other(cls, other: "PadoDataLoader", override_kwargs: Optional[Dict] = None) -> "PadoDataLoader":
        """
        Build a new loader with the same dataset, seed, epoch and configuration as ``other``.

        Args:
            other: loader to copy the configuration from.
            override_kwargs: optional replacements for entries of ``other.init_kwargs``.

        Raises:
            ValueError: if ``override_kwargs`` contains a key not in ``init_kwargs``.

        Note:
            The configuration is copied shallowly: the sampler and collate_fn
            objects are shared with ``other``. A deep copy would also deep-copy
            the dataset, which is reachable from a bound ``dataset.collate_fn``
            and from most samplers. When the shared sampler is a
            DistributedSampler, ``set_epoch`` on either loader affects both.
        """
        init_kwargs = dict(other.init_kwargs)
        if override_kwargs is not None:
            for k, v in override_kwargs.items():
                if k in init_kwargs:
                    init_kwargs[k] = v
                else:
                    raise ValueError(f"Dataloader does not accept key {k} (value: {v})")
        return cls(dataset=other.dataset, seed=other.seed, start_epoch=other.epoch, **init_kwargs)  # noqa
