import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from hyper_params import CLASSES

# Preprocessing for the tensors built by EmnistDataset. Those are already
# float tensors, so this pipeline has no ToTensor step.
# NOTE(review): Normalize(0.5, 0.5) computes (x - 0.5) / 0.5, which maps
# [0, 1] to [-1, 1]. Unlike synthetic_transform, nothing here rescales the
# input first. If the CSV stores 0-255 pixel values, the two pipelines
# produce differently scaled inputs. Confirm against the data.
train_transform = transforms.Compose(
    [
        transforms.Resize(size=(28, 28)),
        transforms.Normalize(0.5, 0.5),
    ]
)

# Preprocessing for PIL images loaded from disk. ToTensor converts the image
# to a float tensor before it is normalized.
synthetic_transform = transforms.Compose(
    [
        transforms.Resize(size=(28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
)


class EmnistDataset(Dataset):
    """Dataset backed by a CSV file of flattened 28x28 images.

    Each CSV row holds the pixel values of one image in every column except
    the last, and the class label in the last column. The pixel columns are
    reshaped to ``(1, 28, 28)``, so each row must have exactly 784 of them.

    Args:
        csv_file: Anything accepted by ``pandas.read_csv``.
        trainsform: Callable applied to the ``(1, 28, 28)`` float32 tensor,
            or a falsy value to return that tensor unchanged. The misspelled
            name is kept for backward compatibility with existing callers.
    """

    def __init__(self, csv_file, trainsform=train_transform):
        self.dataframe = pd.read_csv(csv_file)
        self.transform = trainsform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        """Return ``(image, label)`` for row *index*.

        ``image`` is the (optionally transformed) float32 pixel tensor.
        ``label`` is a 0-dim tensor built from the row's last column.
        """
        # Convert through NumPy. Passing a pandas Series straight to
        # torch.tensor iterates it element by element via Series.__getitem__.
        # That is slow, and it depends on a deprecated positional fallback
        # when the column labels are not integers.
        pixels = self.dataframe.iloc[index, :-1].to_numpy(dtype="float32")
        x = torch.tensor(pixels, dtype=torch.float32).reshape(1, 28, 28)
        y = torch.tensor(self.dataframe.iloc[index, -1])
        # Return the raw tensor when no transform is configured instead of
        # falling through with ``image`` unbound.
        image = self.transform(x) if self.transform else x
        return image, y

class latentDataset(Dataset):
    """Flatten a ``{label: latents}`` mapping into an indexable dataset.

    Each element obtained by iterating ``z_dict[label]`` becomes one sample
    paired with ``label``. Samples keep the dict's iteration order.

    Args:
        z_dict: Mapping from label to an iterable of tensors.
            NOTE(review): element shapes are not checked here; callers
            presumably pass latent vectors of a common size.
    """

    def __init__(self, z_dict):
        # Walk each iterable exactly once, then split the pairs apart.
        pairs = [
            (latent, label)
            for label, latents in z_dict.items()
            for latent in latents
        ]
        self.samples = [latent for latent, _ in pairs]
        self.labels = [label for _, label in pairs]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(latent, label)``: a detached float copy and a label tensor."""
        latent = self.samples[index]
        return latent.clone().detach().float(), torch.tensor(self.labels[index])


class EmnistDataset_synthetic(Dataset):
    """Dataset of images stored as ``<folder_path>/<label>/<file>``.

    One sub-folder per class label ``0 .. CLASSES - 1`` is read. Every
    regular file inside a sub-folder is treated as an image of that class
    and loaded as single-channel ("L") grayscale.

    Args:
        trainsform: Callable applied to the PIL image, or a falsy value to
            return the PIL image itself. The misspelled name is kept for
            backward compatibility with existing callers.
        folder_path: Root directory containing the per-class sub-folders.
    """

    def __init__(self, trainsform=synthetic_transform, folder_path='distill_images'):
        self.samples = []
        self.labels = []
        for label in range(CLASSES):
            folder = os.path.join(folder_path, str(label))
            # os.listdir returns entries in arbitrary order. Sorting keeps
            # the index -> sample mapping reproducible across runs.
            for name in sorted(os.listdir(folder)):
                path = os.path.join(folder, name)
                # Skip sub-directories and other non-files; Image.open
                # cannot read them.
                if not os.path.isfile(path):
                    continue
                self.samples.append(path)
                self.labels.append(label)
        self.transform = trainsform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample *index*."""
        y = torch.tensor(self.labels[index])
        # Use a context manager so the file handle is released immediately.
        # convert() loads the pixels and returns a new image, so it stays
        # usable after the file is closed.
        with Image.open(self.samples[index]) as img:
            x = img.convert("L")
        # Return the PIL image itself when no transform is configured
        # instead of falling through with ``image`` unbound.
        image = self.transform(x) if self.transform else x
        return image, y