import os
from torch.utils.data import DataLoader, Dataset
from pretrainer.datasets.pretrain.FaceDataset import HeatmapDataset_Image


class CustomDataset(Dataset):
    """Map-style dataset pairing pre-loaded samples with their labels.

    Indexing returns ``(data[idx], labels[idx])``. The dataset length is taken
    from the label container, not from ``data``.
    """

    def __init__(self, data, label):
        # Store references only; nothing is copied or converted here.
        self.data = data
        self.labels = label

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target


def load_data_to_memory(dataloader):
    """Drain *dataloader* and return all of its samples as two CPU tensors.

    Every ``(data, label)`` batch yielded by the loader is moved to the CPU.
    The batches are then concatenated along dimension 0, so the result holds
    every sample the loader produced, not just the final batch.

    Args:
        dataloader: An iterable (typically a ``torch.utils.data.DataLoader``)
            yielding ``(data, label)`` pairs of tensors.

    Returns:
        tuple: ``(data, label)``, the concatenation of all batches' data and
        labels, both on the CPU.

    Raises:
        ValueError: If *dataloader* yields no batches.
    """
    import torch

    data_batches = []
    label_batches = []
    for data, label in dataloader:
        data_batches.append(data.cpu())
        label_batches.append(label.cpu())

    # Without this check an empty loader would surface as an obscure error.
    if not data_batches:
        raise ValueError("dataloader yielded no batches; nothing to load")

    return torch.cat(data_batches, dim=0), torch.cat(label_batches, dim=0)


def face(data_mode='train', resize=224, shots=10,
         face_root="C:/Users/Zber/Desktop/Subjects_Heatmap"):
    """Build a DataLoader over the face heatmap dataset, cached in memory.

    The raw ``HeatmapDataset_Image`` is iterated once through a temporary
    loader. The resulting tensors are collected via ``load_data_to_memory``
    and wrapped in a ``CustomDataset``, so the returned loader iterates those
    in-memory tensors instead of ``HeatmapDataset_Image``.

    Args:
        data_mode: ``'train'`` selects the annotation file
            ``fs_{shots}_train.txt``; any other value selects
            ``fs_{data_mode}.txt``.
        resize: Forwarded to ``HeatmapDataset_Image`` as ``resize``.
        shots: In training mode, sets the batch size to ``7 * shots`` and
            picks the annotation file. Ignored for other modes.
        face_root: Directory containing the heatmap data and annotation
            files. Defaults to the previously hard-coded location for
            backward compatibility; pass a path to use another machine.

    Returns:
        DataLoader: Unshuffled loader over the cached samples.
    """
    num_classes = 7
    shuffle = False  # both modes keep a fixed sample order
    if data_mode == 'train':
        # NOTE(review): presumably one batch == one full few-shot support
        # set (num_classes * shots samples) — confirm against the trainer.
        batch_size = num_classes * shots
        ann = os.path.join(face_root, "fs_{}_{}.txt".format(shots, data_mode))
    else:
        batch_size = 350
        ann = os.path.join(face_root, "fs_{}.txt".format(data_mode))

    dataset = HeatmapDataset_Image(face_root, ann, cumulated=True, num_frames=100, resize=resize)
    # Both loaders share the same settings so they cannot drift apart.
    loader_kwargs = dict(num_workers=1, pin_memory=True, batch_size=batch_size, shuffle=shuffle)
    raw_loader = DataLoader(dataset, **loader_kwargs)
    data, label = load_data_to_memory(raw_loader)
    return DataLoader(CustomDataset(data, label), **loader_kwargs)


def create_dataloader(dataname='', resize=224, n_shot=5, n_query=15, n_eposide=1, mode='d'):
    """Return ``(train_loader, test_loader)`` for the named dataset.

    Only ``'face'`` is currently supported. ``n_query``, ``n_eposide`` and
    ``mode`` are accepted for interface compatibility but are not used.

    Args:
        dataname: Dataset identifier; must be ``'face'``.
        resize: Forwarded to ``face`` for both splits.
        n_shot: Samples per class for the training split.

    Returns:
        tuple: ``(train, test)`` loaders built by ``face``.

    Raises:
        NotImplementedError: If *dataname* is not ``'face'``.
    """
    if dataname != 'face':
        raise NotImplementedError("unsupported dataset: {!r}".format(dataname))

    train = face('train', resize=resize, shots=n_shot)
    test = face('test', resize=resize)
    return train, test
