from typing import Generator, List, Sequence, Tuple

import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler, Subset

Sample = Tuple[torch.Tensor, torch.Tensor]


def my_adv_loader(train_data: Dataset, adv_data: List[Sample], T_i: int) -> Generator[int, None, None]:
    """Lazily yield ``T_i`` random sample ids over the training set and the adversarial pool.

    Each id is drawn uniformly, with replacement, from ``[-len(adv_data), len(train_data))``.
    A non-negative id addresses ``train_data[id]``; a negative id ``-k`` addresses
    ``adv_data[k - 1]`` (the encoding decoded by ``load_sample``).

    train_data: The whole training set
    adv_data: The adversarial examples that are currently available
    T_i: The number of ids to yield (this is a plain generator, not a DataLoader)

    Both lengths are read once, when iteration starts; examples appended to
    ``adv_data`` after that are not sampled by this generator.
    If both pools are empty and ``T_i > 0``, ``torch.randint`` raises.
    """
    # Named to avoid shadowing the builtins ``min``/``max``.
    num_adv = len(adv_data)
    num_train = len(train_data)
    for _ in range(T_i):
        # One draw per step (rather than one batched draw up front) keeps the global
        # torch RNG consumption interleaved with whatever the consumer does between
        # ``next()`` calls, so the sampled sequence matches a per-step draw exactly.
        yield torch.randint(-num_adv, num_train, size=(1,)).item()


def load_sample(train_data: Dataset, adv_data: List[Sample], i: int):
    """Fetch one sample given its combined (training / adversarial) id.

    train_data: The whole training set
    adv_data: The adversarial examples that are currently available
    i: The sample id. ``i >= 0`` selects ``train_data[i]``; a negative id selects
       ``adv_data[-i - 1]``, so -1 is the first adversarial example, -2 the second, ...
    """
    if i < 0:
        # Decode the negative id back into a 0-based position inside adv_data.
        adv_position = -1 - i
        return adv_data[adv_position]
    return train_data[i]


def my_dataloader(train_data: Dataset, T: int) -> Generator[int, None, None]:
    """Lazily yield ``T`` random indices into ``train_data``.

    Each index is drawn uniformly, with replacement, from ``[0, len(train_data))``.

    train_data: The whole training dataset without the adversarial examples
    T: The number of indices to yield (this is a plain generator, not a DataLoader)

    If ``train_data`` is empty and ``T > 0``, ``torch.randint`` raises.
    """
    # Named to avoid shadowing the builtin ``max``.
    num_train = len(train_data)
    for _ in range(T):
        # One draw per step keeps torch RNG consumption interleaved with the consumer,
        # exactly like a per-step ``torch.randint`` call.
        yield torch.randint(0, num_train, size=(1,)).item()


def load_stage(train_data: Dataset, sorted_indices: Sequence[int], i: int, T_i: int,
               replacement: bool = False, batch_size: int = 1) -> DataLoader:
    """Build a random-sampling DataLoader over a label-sorted prefix of the training set.

    train_data: The whole training set for all stages
    sorted_indices: The indices of the data sorted according to their label
    i: Index of the last element of the prefix; the stage uses
       ``sorted_indices[:i + 1]``, i.e. the first ``i + 1`` entries (not ``i``)
    T_i: The number of iterations that will be run; the loader draws
         ``T_i * batch_size`` samples in total, i.e. ``T_i`` batches
    replacement: Whether to draw the samples with replacement
    batch_size: The batch size of the DataLoader

    Returns the DataLoader itself; iterate it to obtain batches.

    NOTE(review): with ``replacement=False`` the requested number of samples may
    exceed the prefix size. Recent PyTorch releases then concatenate fresh random
    permutations of the prefix, while some older releases reject ``num_samples``
    when ``replacement=False`` — confirm against the installed version.
    """
    # The portion of the dataset available at this stage.
    prefix_i = sorted_indices[:i + 1]
    train_data_prefix = Subset(train_data, prefix_i)

    # Draw exactly T_i batches' worth of samples, independent of the prefix size.
    num_samples = T_i * batch_size
    train_sampler = RandomSampler(train_data_prefix, replacement=replacement, num_samples=num_samples)
    train_loader = DataLoader(train_data_prefix, batch_size=batch_size, sampler=train_sampler,
                              num_workers=0, pin_memory=False)
    return train_loader


def full_stage(train_data: Dataset, sorted_indices: Sequence[int], i: int,
               shuffle: bool = True, batch_size: int = 1) -> DataLoader:
    """Return a DataLoader that makes full passes over a label-sorted prefix of the data.

    train_data: The whole training set for all stages
    sorted_indices: The indices of the data sorted according to their label
    i: Index of the last element of the prefix; the loader covers
       ``sorted_indices[:i + 1]``, i.e. the first ``i + 1`` entries
    shuffle: Whether the DataLoader shuffles the prefix on each pass
    batch_size: The batch size of the DataLoader
    """
    stage_subset = Subset(train_data, sorted_indices[:i + 1])
    return DataLoader(stage_subset, batch_size=batch_size, shuffle=shuffle)


def load_stage_standard(train_data: Dataset, T_i: int,
                        replacement: bool = True, batch_size: int = 1) -> DataLoader:
    """Build a random-sampling DataLoader over the whole training set.

    train_data: The whole training set for all stages
    T_i: The number of iterations that will be run; the loader draws
         ``T_i * batch_size`` samples in total
    replacement: Whether to draw the samples with replacement
    batch_size: The batch size of the DataLoader

    Returns the DataLoader itself; iterate it to obtain batches.
    """
    total_draws = T_i * batch_size
    sampler = RandomSampler(
        train_data,
        replacement=replacement,
        num_samples=total_draws,
    )
    return DataLoader(
        train_data,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=0,
        pin_memory=False,
    )


def full_stage_standard(train_data: Dataset, shuffle: bool = True, batch_size: int = 1) -> DataLoader:
    """Return a DataLoader that makes full passes over the whole training set.

    train_data: The whole training set for all stages
    shuffle: Whether the DataLoader shuffles the data on each pass
    batch_size: The batch size of the DataLoader
    """
    return DataLoader(train_data, shuffle=shuffle, batch_size=batch_size)
