import torch
import torchvision
import os
import csv
import config
import numpy as np
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image


def get_transform(opt, train=True):
    """Build the torchvision preprocessing pipeline for CIFAR images.

    Args:
        opt: Options object; reads ``opt.dataset``, ``opt.input_height`` and
            ``opt.input_width``.
        train: When True, add data augmentation (padded random crop, small
            random rotation, random horizontal flip) after resizing.

    Returns:
        A ``transforms.Compose`` that resizes, optionally augments, converts
        to a tensor and normalizes with mean 0.5 / std 0.5.

    Raises:
        ValueError: If ``opt.dataset`` is not ``'cifar10'`` or ``'cifar100'``.
            (``ValueError`` subclasses ``Exception``, so existing handlers
            that catch ``Exception`` keep working.)
    """
    # Validate first so unsupported datasets fail before any work is done.
    if opt.dataset not in ("cifar10", "cifar100"):
        raise ValueError(f"Invalid Dataset: {opt.dataset!r}")

    height, width = opt.input_height, opt.input_width
    transforms_list = [transforms.Resize((height, width))]
    if train:
        transforms_list += [
            # Pad each side by height // 8, then crop back to the target size.
            transforms.RandomCrop((height, width), padding=height // 8),
            transforms.RandomRotation(3),
            transforms.RandomHorizontalFlip(p=0.5),
        ]
    transforms_list.append(transforms.ToTensor())
    # Single-element mean/std is broadcast across every channel.
    transforms_list.append(transforms.Normalize([0.5], [0.5]))
    return transforms.Compose(transforms_list)


def _split_indices(n, portion, seed):
    """Partition ``range(n)`` into a random fine-tune part and the remainder."""
    perm = np.random.default_rng(seed).permutation(n)
    n_ft = int(portion * n)
    # Plain Python ints; the remainder is kept in ascending index order.
    return perm[:n_ft].tolist(), sorted(perm[n_ft:].tolist())


def get_dataloader(opt, train, mode=None, *, split_seed=0):
    """Create a shuffled DataLoader over CIFAR-10/100, optionally subset.

    Args:
        opt: Options object; reads ``dataset``, ``data_root``, ``portion``,
            ``bs``, ``num_workers`` plus the fields used by ``get_transform``.
        train: Load the training split (with augmentation) if True, otherwise
            the test split.
        mode: Only consulted when ``train`` is True. ``"ft"`` selects the
            ``opt.portion`` fraction reserved for fine-tuning, ``"train"``
            selects the complementary samples; any other value (including the
            default ``None``) uses the full training split.
        split_seed: Seed for the fine-tune/train split. Because the split is
            derived from this seed rather than the global NumPy RNG, separate
            ``"ft"`` and ``"train"`` calls with the same seed yield disjoint
            subsets that together cover the training split.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True``.

    Raises:
        ValueError: If ``opt.dataset`` is not ``'cifar10'`` or ``'cifar100'``.
    """
    if opt.dataset not in ("cifar10", "cifar100"):
        raise ValueError(f"Invalid dataset: {opt.dataset!r}")

    transform = get_transform(opt, train)
    dataset_cls = (
        torchvision.datasets.CIFAR10
        if opt.dataset == "cifar10"
        else torchvision.datasets.CIFAR100
    )
    dataset = dataset_cls(opt.data_root, train=train, transform=transform, download=True)

    if train and mode in ("ft", "train"):
        ft_idx, train_idx = _split_indices(len(dataset), opt.portion, split_seed)
        dataset = torch.utils.data.Subset(dataset, ft_idx if mode == "ft" else train_idx)

    # NOTE: the evaluation split is shuffled as well (unchanged behavior).
    return torch.utils.data.DataLoader(
        dataset, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=True
    )


def main():
    """Smoke-test the evaluation data pipeline by iterating one full pass.

    Parses options via ``config.get_arguments()`` and pulls every batch from
    the test-split loader so that download, decoding and transforms all run.
    """
    opt = config.get_arguments().parse_args()
    # ``mode`` only selects a subset of the training split; it is unused for
    # the test split, but it must be supplied to the original signature.
    dataloader = get_dataloader(opt, False, mode=None)
    for _images, _labels in dataloader:
        pass
    

if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not on import.
    main()