import os
import torch
from torch.utils.data import Dataset
import torchvision
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from .randaugment import RandomAugment
import augment
from augment.autoaugment_extra import CIFAR10Policy
from augment.cutout import Cutout
from .utils_algo import generate_instance_dependent_candidate_labels, generate_instance_dependent_candidate_labels_num
from global_var import TINY_DATA_ROOT

def _load_single_batch(dataset, num_workers=0):
    """Return all of *dataset* as one ``(data, targets)`` pair, targets cast to long."""
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=len(dataset), shuffle=False, num_workers=num_workers)
    data, targets = next(iter(loader))
    return data, targets.long()


def load_tinyimagenet(ds, batch_size, split_seed=42, device=None, partial_rate=None, partial_num=None, num_or_rate="rate", has_eval_train_loader=False, has_meta_valid=False):
    """Build the partial-label training, validation and test loaders.

    The ``train`` folder under ``TINY_DATA_ROOT`` is split 90% / 10% into a
    training part and a validation part; the ``val`` folder is used as the
    test set. Candidate (partial) labels for the training part are produced by
    ``generate_instance_dependent_candidate_labels`` or its ``_num`` variant.

    Args:
        ds: Dataset identifier, forwarded unchanged to the label generator.
        batch_size: Batch size of the returned training loader(s).
        split_seed: Seed of the generator used for the train/validation split.
        device: Forwarded unchanged to the label generator.
        partial_rate: Passed as ``_rate`` when ``num_or_rate`` is ``"rate"`` or
            ``None``.
        partial_num: Passed as ``_num`` when ``num_or_rate`` is ``"num"``.
        num_or_rate: ``"rate"``, ``"num"`` or ``None`` (treated as ``"rate"``).
        has_eval_train_loader: If true, append a loader over the same training
            data that applies only the plain (non-augmenting) transform.
        has_meta_valid: If true, append the validation images and labels as
            two tensors.

    Returns:
        A list ``[partial_matrix_train_loader, valid_loader, test_loader, dim,
        K]``, followed by ``eval_train_loader`` when ``has_eval_train_loader``
        is set, followed by ``valdata, vallabels`` when ``has_meta_valid`` is
        set.

    Raises:
        ValueError: If ``num_or_rate`` is not ``"rate"``, ``"num"`` or ``None``.
    """
    # Validate before any data is read: previously an unknown mode only failed
    # after the whole dataset had been loaded, with an unbound-variable error.
    if num_or_rate not in (None, "rate", "num"):
        raise ValueError("num_or_rate must be 'rate', 'num' or None, got {!r}".format(num_or_rate))

    test_transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(32),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

    # The train folder is opened twice: once with raw tensors (kept in memory
    # and transformed later by the dataset classes) and once with the test
    # transform (fed to the label generator and used for validation).
    train_root = os.path.join(TINY_DATA_ROOT, "train")
    temp_train = ImageFolder(root=train_root, transform=transforms.ToTensor())
    temp_valid = ImageFolder(root=train_root, transform=test_transform)
    data_size = len(temp_train)
    num_train = int(data_size * 0.9)
    split_lengths = [num_train, data_size - num_train]

    # Both splits use the same lengths and an identically seeded generator, so
    # they are expected to select the same indices; the code below relies on
    # that to pair `trainlabels` with `train_loader_for_partial_labels`.
    train_dataset, _ = torch.utils.data.random_split(temp_train, split_lengths,
                                                     torch.Generator().manual_seed(split_seed))
    train_dataset_for_partial_labels, valid_dataset = torch.utils.data.random_split(temp_valid, split_lengths,
                                                     torch.Generator().manual_seed(split_seed))

    # Materialize the whole raw training split in memory as a single batch.
    traindata, trainlabels = _load_single_batch(train_dataset, num_workers=0)

    train_loader_for_partial_labels = torch.utils.data.DataLoader(dataset=train_dataset_for_partial_labels, batch_size=batch_size*4, shuffle=False, num_workers=8)
    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=len(valid_dataset), shuffle=False, num_workers=8)

    # set test dataloader
    test_dataset = ImageFolder(root=os.path.join(TINY_DATA_ROOT, "val"), transform=test_transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=False, num_workers=8)

    # generate partial labels
    if num_or_rate == "num":
        partialY, avgC = generate_instance_dependent_candidate_labels_num(ds, train_loader_for_partial_labels, trainlabels, device, _num=partial_num)
    else:
        # "rate" or None
        partialY, avgC = generate_instance_dependent_candidate_labels(ds, train_loader_for_partial_labels, trainlabels, device, _rate=partial_rate)
    print('Average candidate num: ', avgC)

    # generate partial label dataset
    partial_matrix_dataset = TinyImageNet_Augmentention(traindata, partialY.float(), trainlabels.float())
    # NOTE(review): per the DataLoader docs `prefetch_factor` counts batches
    # per worker, not samples, so 8 * batch_size may be far larger than
    # intended — confirm. Value kept as-is to preserve behaviour.
    partial_matrix_train_loader = torch.utils.data.DataLoader(dataset=partial_matrix_dataset,
                                                                batch_size=batch_size,
                                                                shuffle=True,
                                                                num_workers=8,
                                                                prefetch_factor=8*batch_size,
                                                                drop_last=True)
    dim = 32 * 32 * 3
    K = 200
    return_list = [partial_matrix_train_loader, valid_loader, test_loader, dim, K]

    if has_eval_train_loader:
        eval_train_dataset = TinyImageNet(traindata, partialY.float(), trainlabels.float())
        eval_train_loader = torch.utils.data.DataLoader(dataset=eval_train_dataset,
                                                                batch_size=batch_size,
                                                                shuffle=True,
                                                                num_workers=8,
                                                                drop_last=True)
        return_list.append(eval_train_loader)

    if has_meta_valid:
        # Hand the validation split back as plain tensors as well.
        valdata, vallabels = _load_single_batch(valid_dataset, num_workers=8)
        return_list.append(valdata)
        return_list.append(vallabels)

    return return_list


class TinyImageNet_Augmentention(Dataset):
    """Partial-label dataset that yields three views of every image.

    Each item is ``([plain, weak, strong], candidate_labels, true_label,
    index)``, where the three views come from ``transform``,
    ``weak_transform`` and ``strong_transform`` respectively.

    NOTE(review): the weak and strong pipelines are configured identically
    here, and neither resizes its input the way ``transform`` does — confirm
    this is intended.
    """

    def __init__(self, images, given_label_matrix, true_labels):
        self.images = images
        # user-defined label (partial labels)
        self.given_label_matrix = given_label_matrix
        self.true_labels = true_labels

        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        def build_augmenting_pipeline():
            # A fresh Compose (with fresh Cutout / policy objects) per call,
            # so the weak and strong pipelines do not share instances.
            steps = [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4, padding_mode='reflect'),
                Cutout(n_holes=1, length=16),
                transforms.ToPILImage(),
                CIFAR10Policy(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
            return transforms.Compose(steps)

        # Plain view: resize and normalize only.
        plain_steps = [
            transforms.ToPILImage(),
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        self.transform = transforms.Compose(plain_steps)
        self.weak_transform = build_augmenting_pipeline()
        self.strong_transform = build_augmenting_pipeline()

    def __len__(self):
        """Return the number of samples (one per true label)."""
        return len(self.true_labels)

    def __getitem__(self, index):
        """Return ``([plain, weak, strong], candidate_labels, true_label, index)``."""
        # Evaluation order (plain, weak, strong) is kept so that any random
        # state consumed by the transforms is drawn in the same sequence.
        views = [
            self.transform(self.images[index]),
            self.weak_transform(self.images[index]),
            self.strong_transform(self.images[index]),
        ]
        candidate_labels = self.given_label_matrix[index]
        true_label = self.true_labels[index]
        return views, candidate_labels, true_label, index

class TinyImageNet(Dataset):
    """Partial-label dataset that yields one non-augmented view per image.

    Each item is ``(image, candidate_labels, true_label, index)``, where the
    image has been passed through ``transform`` (to PIL, resize to 32,
    to tensor, normalize).
    """

    def __init__(self, images, given_label_matrix, true_labels):
        self.images = images
        # user-defined label (partial labels)
        self.given_label_matrix = given_label_matrix
        self.true_labels = true_labels

        pipeline_steps = [
            transforms.ToPILImage(),
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
        self.transform = transforms.Compose(pipeline_steps)

    def __len__(self):
        """Return the number of samples (one per true label)."""
        return len(self.true_labels)

    def __getitem__(self, index):
        """Return ``(image, candidate_labels, true_label, index)`` for *index*."""
        image = self.transform(self.images[index])
        return image, self.given_label_matrix[index], self.true_labels[index], index
    
# Script entry point: intentionally a no-op, running this module directly does nothing.
if __name__ == "__main__":
    pass