from typing import List

import torch
from opacus.data_loader import dtype_safe, shape_safe, wrap_collate_with_empty
from torch.utils.data import Dataset, DataLoader, IterableDataset
from torch.utils.data._utils.collate import default_collate

from ipp import IPP
from ipp.utils.sampler import IPPSampler


class IPPDataLoader(DataLoader):
    """
    Modifies the original `opacus.data_loader.DPDataLoader` for individualized privacy progression.

    Batches are produced by an `IPPSampler` that receives one sampling rate per
    sample, instead of DataLoader's usual fixed batch size. Batches are collated
    through opacus' `wrap_collate_with_empty`. That wrapper is given, per field of
    `dataset[0]`, an empty shape `(0, *shape)` and a dtype, so that a step in
    which no sample is drawn can still be collated.

    Args:
        dataset: the dataset to be loaded. Must be map-style (support `len` and
            integer indexing); an `IterableDataset` is rejected.
        per_sample_sampling_rates: The sampling rate of each sample in `dataset`.
            Must contain exactly `len(dataset)` elements whose sum is positive.
        collate_fn: collate function for the drawn samples; defaults to
            `default_collate`.
        **kwargs: extra keyword arguments forwarded to `torch.utils.data.DataLoader`
            (e.g. `num_workers`, `pin_memory`). Do not pass `dataset`,
            `batch_sampler`, `collate_fn` or `generator`; this class sets them itself.

    Raises:
        ValueError: if `dataset` is an `IterableDataset`, if the number of rates
            differs from `len(dataset)`, or if the rates do not sum to a positive
            value.
    """
    def __init__(self, dataset: Dataset, per_sample_sampling_rates: torch.Tensor, collate_fn=None, **kwargs):
        # Both `len(dataset)` and `dataset[0]` below require a map-style dataset.
        if isinstance(dataset, IterableDataset):
            raise ValueError("Individualized sampling is not supported for IterableDataset")

        n_data = len(dataset)
        n_rates = per_sample_sampling_rates.numel()
        if n_rates != n_data:
            raise ValueError(
                f"Expected one sampling rate per sample ({n_data}), got {n_rates}"
            )

        # Sum of the per-sample rates = expected number of samples per batch.
        expected_batch_size = torch.sum(per_sample_sampling_rates).item()
        # `not x > 0` also catches NaN; without this guard the division below raises
        # a bare ZeroDivisionError (or int(nan) fails) with no hint of the cause.
        if not expected_batch_size > 0:
            raise ValueError(
                f"per_sample_sampling_rates must sum to a positive value, got {expected_batch_size}"
            )
        # Steps per epoch: dataset size / expected batch size, truncated.
        # NOTE(review): this is 0 (an empty epoch) if the rates sum to more than
        # len(dataset), i.e. if some rate exceeds 1 — confirm rates are <= 1.
        n_steps = int(n_data / expected_batch_size)
        batch_sampler = IPPSampler(
            n_samples=n_data,
            per_sample_sampling_rates=per_sample_sampling_rates,
            n_steps=n_steps
        )

        if collate_fn is None:
            collate_fn = default_collate
        # Per-field shape/dtype templates taken from the first sample, passed to
        # opacus so it can build correctly-typed empty batches.
        sample_empty_shapes = [(0, *shape_safe(x)) for x in dataset[0]]
        dtypes = [dtype_safe(x) for x in dataset[0]]

        super().__init__(
            dataset=dataset,
            batch_sampler=batch_sampler,
            collate_fn=wrap_collate_with_empty(
                collate_fn=collate_fn,
                sample_empty_shapes=sample_empty_shapes,
                dtypes=dtypes,
            ),
            generator=None,
            **kwargs,
        )

    @classmethod
    def from_data_loader(cls, data_loader: DataLoader, ipp: IPP):
        """
        Privatize the given `data_loader` according to `ipp`.

        The new loader reuses `data_loader`'s dataset, collate function and
        worker/memory settings. Its batches come from the individualized sampler
        rather than `data_loader`'s own sampler and batch size. `drop_last`,
        `shuffle` and `generator` are not carried over, because they conflict
        with a custom batch sampler.

        Args:
            data_loader: The `DataLoader` to be privatized.
            ipp: The `IPP` profile to follow.

        Returns:
            A new instance of `cls` over `data_loader.dataset`.

        Raises:
            ValueError: propagated from the constructor (see the class docstring).
        """
        # NOTE(review): `[0]` takes the first element of whatever
        # `get_per_sample_sampling_rates()` returns — presumably the per-sample
        # rate tensor; confirm against `IPP`.
        per_sample_sampling_rates = ipp.get_per_sample_sampling_rates()[0]

        return cls(
            dataset=data_loader.dataset,
            per_sample_sampling_rates=per_sample_sampling_rates,
            collate_fn=data_loader.collate_fn,
            # Keep the source loader's execution settings instead of silently
            # reverting to single-process DataLoader defaults.
            num_workers=data_loader.num_workers,
            pin_memory=data_loader.pin_memory,
            timeout=data_loader.timeout,
            worker_init_fn=data_loader.worker_init_fn,
            multiprocessing_context=data_loader.multiprocessing_context,
            prefetch_factor=data_loader.prefetch_factor,
            persistent_workers=data_loader.persistent_workers,
        )
