import os
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from .randaugment import RandomAugment
import augment
from augment.autoaugment_extra import CIFAR10Policy
from augment.cutout import Cutout
from .utils_algo import generate_instance_dependent_candidate_labels, generate_instance_dependent_candidate_labels_num
from global_var import DATA_ROOT


def _load_single_batch(loader):
    """Return the first batch of ``loader`` as ``(data, labels)``, labels cast to int64.

    Meant for loaders whose ``batch_size`` equals the dataset length, in which
    case the first batch is the whole dataset.
    """
    data, targets = next(iter(loader))
    return data, targets.long()


def load_cifar100(ds, batch_size, split_seed=42, device=None, partial_rate=None, partial_num=None, num_or_rate="rate", has_eval_train_loader=False, has_meta_valid=False):
    """Build the CIFAR-100 partial-label training, validation and test loaders.

    The CIFAR-100 training set is split 90% / 10% into a training and a
    validation part with a generator seeded by ``split_seed``.  Candidate
    (partial) labels for the training part are produced by one of the
    ``generate_instance_dependent_candidate_labels*`` helpers and wrapped,
    together with the raw images and true labels, in
    ``CIFAR100_Augmentention``.

    Args:
        ds: Forwarded unchanged as the first argument of the candidate-label
            generator.
        batch_size: Batch size of the training loaders and of the loader
            handed to the candidate-label generator.
        split_seed: Seed of the generator used for the train/validation split.
        device: Forwarded unchanged to the candidate-label generator.
        partial_rate: Passed as ``_rate`` when ``num_or_rate`` is ``"rate"``
            or ``None``.
        partial_num: Passed as ``_num`` when ``num_or_rate`` is ``"num"``.
        num_or_rate: Selects the generator: ``"rate"`` / ``None`` or ``"num"``.
        has_eval_train_loader: If true, append a loader over the training
            part that yields a single, non-augmented view per image.
        has_meta_valid: If true, append the whole validation part as two
            tensors (images, int64 labels).

    Returns:
        list: ``[partial_matrix_train_loader, valid_loader, test_loader, dim, K]``
        followed by ``eval_train_loader`` when ``has_eval_train_loader`` is
        set, then by ``valdata, vallabels`` when ``has_meta_valid`` is set.

    Raises:
        ValueError: If ``num_or_rate`` is not ``"rate"``, ``"num"`` or ``None``.
    """
    # Fail fast: previously an unknown mode only surfaced as an unbound-variable
    # error after the data had been downloaded and fully loaded.
    if num_or_rate not in ("rate", "num", None):
        raise ValueError(
            f"num_or_rate must be 'rate', 'num' or None, got {num_or_rate!r}"
        )

    root = os.path.join(DATA_ROOT, "REDGE")
    # NOTE(review): these mean/std values look like the commonly quoted CIFAR-10
    # statistics -- confirm they are the intended ones for CIFAR-100.
    test_transform = transforms.Compose(
            [transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])

    # Two views of the same training set: un-normalized tensors (augmented and
    # normalized later by the Dataset wrappers) and normalized tensors (used
    # for validation and for generating the candidate labels).
    temp_train = dsets.CIFAR100(root=root, train=True, download=True, transform=transforms.ToTensor())
    temp_valid = dsets.CIFAR100(root=root, train=True, transform=test_transform)
    data_size = len(temp_train)
    split_sizes = [int(data_size * 0.9), data_size - int(data_size * 0.9)]
    # Both splits use the same seed and the same sizes, so they select the
    # same indices; sample i of one training part is sample i of the other.
    train_dataset, _ = torch.utils.data.random_split(
        temp_train, split_sizes, torch.Generator().manual_seed(split_seed))
    train_dataset_for_partial_labels, valid_dataset = torch.utils.data.random_split(
        temp_valid, split_sizes, torch.Generator().manual_seed(split_seed))

    # Materialize the whole (un-normalized) training part as one batch.
    full_train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=len(train_dataset), shuffle=False, num_workers=8)
    traindata, trainlabels = _load_single_batch(full_train_loader)

    # shuffle=False keeps this loader in the same order as `trainlabels`.
    train_loader_for_partial_labels = torch.utils.data.DataLoader(dataset=train_dataset_for_partial_labels, batch_size=batch_size, shuffle=False, num_workers=8)
    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=len(valid_dataset), shuffle=False, num_workers=8)
    test_dataset = dsets.CIFAR100(root=root, train=False, transform=test_transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=False, num_workers=8)

    # generate partial labels ("rate" and None share the rate-based generator)
    if num_or_rate == "num":
        partialY, avgC = generate_instance_dependent_candidate_labels_num(ds, train_loader_for_partial_labels, trainlabels, device, _num=partial_num)
    else:
        partialY, avgC = generate_instance_dependent_candidate_labels(ds, train_loader_for_partial_labels, trainlabels, device, _rate=partial_rate)
    print('Average candidate num: ', avgC)

    # generate partial label dataset
    partial_matrix_dataset = CIFAR100_Augmentention(traindata, partialY.float(), trainlabels.float())
    # NOTE(review): prefetch_factor counts *batches* queued per worker, so
    # 8 * batch_size requests a very large number of batches -- confirm this
    # is intended rather than a per-sample count.
    partial_matrix_train_loader = torch.utils.data.DataLoader(dataset=partial_matrix_dataset,
                                                                batch_size=batch_size,
                                                                shuffle=True,
                                                                num_workers=8,
                                                                prefetch_factor=8*batch_size,
                                                                drop_last=True)
    dim = 32 * 32 * 3
    K = 100
    return_list = [partial_matrix_train_loader, valid_loader, test_loader, dim, K]

    if has_eval_train_loader:
        eval_train_dataset = CIFAR100(traindata, partialY.float(), trainlabels.float())
        # NOTE(review): this evaluation loader shuffles and drops the last
        # incomplete batch, so not every sample is seen per pass -- confirm
        # callers expect that.
        eval_train_loader = torch.utils.data.DataLoader(dataset=eval_train_dataset,
                                                                batch_size=batch_size,
                                                                shuffle=True,
                                                                num_workers=8,
                                                                drop_last=True)
        return_list.append(eval_train_loader)

    if has_meta_valid:
        # valid_loader already yields the whole validation part as one
        # ordered batch, so it is reused instead of building a duplicate loader.
        valdata, vallabels = _load_single_batch(valid_loader)
        return_list.append(valdata)
        return_list.append(vallabels)

    return return_list


class CIFAR100_Augmentention(Dataset):
    """Partial-label dataset that yields three views of every image.

    Each item is ``([original, weak, strong], candidate_labels, true_label,
    index)``: the first view is only normalized, the other two go through
    the augmentation pipelines built in ``__init__``.
    """

    def __init__(self, images, given_label_matrix, true_labels):
        """Store the data and build the three transform pipelines.

        Args:
            images: Indexable collection of images; each item is fed to
                ``ToPILImage`` (plain view) or to the tensor-level
                augmentations (weak / strong views).
            given_label_matrix: Indexable per-sample candidate (partial) labels.
            true_labels: Indexable per-sample ground-truth labels; its length
                defines the dataset length.
        """
        self.images = images
        # user-defined label (partial labels)
        self.given_label_matrix = given_label_matrix
        self.true_labels = true_labels

        channel_mean = (0.4914, 0.4822, 0.4465)

        # Plain view: round-trip through PIL, then normalize.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, (0.247, 0.243, 0.261)),
        ])

        # PLCR-style pipelines; a PiCO-style variant (RandomResizedCrop +
        # ColorJitter/RandomGrayscale or RandomAugment) can be substituted here.
        def make_augmented_pipeline():
            # NOTE(review): the std used here differs from the one used for
            # the plain view above -- confirm the mismatch is intended.
            return transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4, padding_mode='reflect'),
                Cutout(n_holes=1, length=16),
                transforms.ToPILImage(),
                CIFAR10Policy(),
                transforms.ToTensor(),
                transforms.Normalize(channel_mean, (0.2023, 0.1994, 0.2010)),
            ])

        # The weak and strong views are configured identically but are
        # separate pipeline instances.
        self.weak_transform = make_augmented_pipeline()
        self.strong_transform = make_augmented_pipeline()

    def __len__(self):
        """Return the number of samples (length of ``true_labels``)."""
        return len(self.true_labels)

    def __getitem__(self, index):
        """Return ``([original, weak, strong], candidate_labels, true_label, index)``."""
        source = self.images[index]
        pipelines = (self.transform, self.weak_transform, self.strong_transform)
        # Applied in order: original, weak, strong.
        views = [pipeline(source) for pipeline in pipelines]
        return views, self.given_label_matrix[index], self.true_labels[index], index

class CIFAR100(Dataset):
    """Partial-label dataset that yields one normalized, non-augmented view.

    Each item is ``(image, candidate_labels, true_label, index)``.
    """

    def __init__(self, images, given_label_matrix, true_labels):
        """Store the data and build the normalization pipeline.

        Args:
            images: Indexable collection of images; each item is fed to
                ``ToPILImage``.
            given_label_matrix: Indexable per-sample candidate (partial) labels.
            true_labels: Indexable per-sample ground-truth labels; its length
                defines the dataset length.
        """
        self.images = images
        self.given_label_matrix = given_label_matrix
        self.true_labels = true_labels

        # Round-trip through PIL, then normalize per channel.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
        steps = [transforms.ToPILImage(), transforms.ToTensor(), normalize]
        self.transform = transforms.Compose(steps)

    def __len__(self):
        """Return the number of samples (length of ``true_labels``)."""
        return len(self.true_labels)

    def __getitem__(self, index):
        """Return ``(image, candidate_labels, true_label, index)`` for ``index``."""
        normalized = self.transform(self.images[index])
        return (
            normalized,
            self.given_label_matrix[index],
            self.true_labels[index],
            index,
        )