import random
from typing import Any, Optional, Tuple

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class DistributedDataset(Dataset):
    """View of ``dataset`` restricted to (and reordered by) ``index``.

    Item ``i`` of this view is ``dataset[index[i]]`` and its length is
    ``len(index)``. Nothing is copied: every lookup is forwarded to the
    wrapped dataset at access time.
    """

    def __init__(self, dataset: Dataset, index):
        super().__init__()
        self.dataset = dataset  # underlying dataset that holds the samples
        self.index = index  # positions into ``dataset``, in view order

    def __getitem__(self, item):
        # Subscript the wrapped dataset rather than calling the dunder
        # directly, so access goes through the normal indexing protocol.
        return self.dataset[self.index[item]]

    def __len__(self):
        return len(self.index)


def distributed_dataset(
    dataset: Dataset,
    sample_size: int,
    rank: int,
    size: Optional[int] = None,
    remove_index: int = 0,
    seed: int = 777,
    node: int = 2,
) -> Dataset:
    """Return the shard of ``dataset`` assigned to ``rank``.

    Process:

    1. The indices ``0 .. size - 1`` are shuffled with a generator seeded by
       ``seed``.
    2. The shuffled indices are cut into ``node`` consecutive,
       non-overlapping chunks of up to ``sample_size`` indices each.
    3. The index at position ``remove_index`` is deleted from every chunk,
       so a full chunk yields ``sample_size - 1`` indices.

    Shuffled indices past ``node * sample_size`` are not used by any shard.
    The result is deterministic for fixed arguments.

    Args:
        dataset: Source dataset; wrapped, not copied.
        sample_size: Number of shuffled indices taken per chunk before removal.
        rank: Which shard to return. It is used directly as a list index, so
            negative values count back from the last shard.
        size: Number of indices to shuffle; defaults to ``len(dataset)``.
        remove_index: Position within each chunk of the index to drop.
        seed: Seed for the shuffle.
        node: Number of chunks (shards) to build.

    Returns:
        A ``DistributedDataset`` view of ``dataset`` over the selected shard.

    Raises:
        IndexError: If any chunk is too short to delete ``remove_index`` from
            (e.g. an empty chunk when ``node * sample_size`` exceeds ``size``),
            or if ``rank`` is out of range for ``node`` shards.
    """
    if size is None:
        size = len(dataset)

    # A private generator avoids reseeding the process-wide ``random`` state
    # as a side effect of building a shard. ``Random(seed).shuffle`` yields
    # the same permutation as ``random.seed(seed); random.shuffle(...)``.
    rng = random.Random(seed)
    indexes = list(range(size))
    rng.shuffle(indexes)

    indexes_list = []
    for i in range(node):
        temp_index = indexes[i * sample_size:(i + 1) * sample_size]
        # NOTE(review): one index is always dropped from each chunk, even with
        # the default remove_index=0. The intent is not visible here; confirm
        # with callers before changing this.
        del temp_index[remove_index]
        indexes_list.append(temp_index)

    return DistributedDataset(dataset, indexes_list[rank])
