from PIL import Image
import random
import numpy as np
import os
import pandas as pd
import csv
import torch
from torch.utils.data import Dataset
from torchvision import transforms

# Presumably the number of clients the data is partitioned across; it is not
# referenced anywhere in the code shown here — TODO confirm with callers.
CLIENTS = 100

# Per-channel normalisation (x - 0.5) / 0.5 on a 3-channel tensor, which maps
# values in [0, 1] onto [-1, 1]. Inputs must already be tensors, since no
# ToTensor step is included.
train_transform = transforms.Compose(
    [transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3)]
)

class cinicDataset(Dataset):
    """Map-style dataset of flattened RGB images stored as rows of a CSV file.

    Every column except the last holds pixel values; the last column holds the
    class label. The pixel columns are reshaped to ``(3, 32, 32)``, so each row
    must contain 3 * 32 * 32 = 3072 pixel values in channel-major order.

    Args:
        csv_file: Anything accepted by ``pandas.read_csv`` (path or buffer).
        transform: Callable applied to the ``(3, 32, 32)`` float tensor.
            Defaults to the module-level ``train_transform``. Pass ``None``
            to get the scaled tensor back untransformed.
    """

    def __init__(self, csv_file, transform=train_transform):
        self.annotations = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        """Return the number of samples (CSV rows)."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the row at position ``index``.

        Returns:
            image: float32 tensor of shape ``(3, 32, 32)``, scaled by 1/255
                and then passed through ``self.transform`` if one is set.
            label: scalar ``torch.long`` tensor taken from the last column.
        """
        # Go through NumPy explicitly: handing a pandas Series to torch.tensor
        # makes torch index it positionally via Series.__getitem__, which
        # pandas has deprecated for non-integer (header-string) indexes.
        pixels = self.annotations.iloc[index, :-1].to_numpy(dtype=np.float32)
        # Dividing by 255 assumes 8-bit pixel values — confirm against the CSV.
        x = torch.from_numpy(pixels) / 255
        x = x.reshape(3, 32, 32)
        y_label = torch.tensor(self.annotations.iloc[index, -1], dtype=torch.long)
        # Previously `image` was unbound when no transform was given.
        image = self.transform(x) if self.transform else x
        return image, y_label
