import os
import pickle
import numpy as np
import torch
import PIL
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.datasets import Flowers102
from PIL import Image
from .randaugment import RandomAugment
from sklearn.preprocessing import OneHotEncoder
from .utils_algo import *
import copy

def load_flower102(input_size, partial_rate, noisy_rate, batch_size, hierarchical=False,
                   root='/Projects/data/flower102', num_workers=4):
    """Build distributed train/test loaders for Flowers-102 with candidate (partial) labels.

    The torchvision ``train`` and ``val`` splits are merged into one training
    pool. Their ground-truth labels are turned into candidate-label sets by the
    ``generate_*_cv_candidate_labels`` helpers. The ``test`` split is used for
    evaluation with a deterministic resize / center-crop pipeline.

    Args:
        input_size: Side length of the center crop used by the test transform.
            NOTE(review): ``Flower102_Augmentention`` is constructed here without
            this value, so the training crop size is whatever that class uses.
        partial_rate: Forwarded to the candidate-label generator.
        noisy_rate: Forwarded to the candidate-label generator as ``noisy_rate``.
        batch_size: Per-process training batch size. The test loader uses
            ``batch_size * 4``.
        hierarchical: If true, use ``generate_hierarchical_cv_candidate_labels``,
            otherwise ``generate_uniform_cv_candidate_labels``.
        root: Directory where torchvision stores/downloads Flowers-102.
        num_workers: Number of worker processes for each DataLoader.

    Returns:
        ``(partial_matrix_train_loader, partialY, train_sampler, test_loader)``,
        where ``partialY`` is the candidate-label matrix held by the training
        dataset and ``train_sampler`` is the training ``DistributedSampler``
        (presumably returned so the caller can call ``set_epoch`` — confirm).

    Both loaders use ``DistributedSampler`` without explicit ``num_replicas`` /
    ``rank``, so torch.distributed must be initialised before calling this.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    test_transform = transforms.Compose([
        # Resize the short side so that the center crop keeps 87.5% of it.
        transforms.Resize(int(input_size / 0.875)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Merge the official train and val splits into a single training pool.
    # Only file paths and labels are kept here; images are opened per item by
    # the training dataset. `_image_files` / `_labels` are private torchvision
    # attributes and may change between torchvision versions.
    temp_train = Flowers102(root=root, split='train', download=True)
    temp_val = Flowers102(root=root, split='val', download=True)
    data = temp_train._image_files + temp_val._image_files
    labels = torch.tensor(temp_train._labels + temp_val._labels, dtype=torch.long)

    if hierarchical:
        # NOTE(review): the hierarchy name passed here is 'cifar100', not a
        # Flowers-102 hierarchy; confirm the helper is meant to be used with
        # these 102-class labels.
        partialY = generate_hierarchical_cv_candidate_labels('cifar100', labels, partial_rate, noisy_rate=noisy_rate)
    else:
        partialY = generate_uniform_cv_candidate_labels(labels, partial_rate, noisy_rate=noisy_rate)

    train_dataset = Flower102_Augmentention(data, partialY, labels)

    test_dataset = Flowers102(root=root,
                              download=True,
                              split='test',
                              transform=test_transform)

    partialY = train_dataset.partial_label_matrix
    print('Average candidate num: ', partialY.sum(1).mean())

    # With its default drop_last=False, DistributedSampler pads the index list
    # so every rank gets the same number of samples. Some test images may
    # therefore be evaluated twice when the split size is not divisible by the
    # world size.
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size * 4,
        shuffle=False,
        num_workers=num_workers,
        sampler=torch.utils.data.distributed.DistributedSampler(test_dataset, shuffle=False),
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

    partial_matrix_train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        # A sampler is always supplied, and shuffling is handled by it.
        # DataLoader does not allow shuffle=True together with a sampler.
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True,
    )
    return partial_matrix_train_loader, partialY, train_sampler, test_loader


class Flower102_Augmentention(Dataset):
    """Image-file dataset yielding a weak and a strong augmented view per sample.

    Each item is ``(image_w, image_s, partial_label, true_label, index)``:
    - two independently augmented views of the same RGB image;
    - the row of the candidate-label matrix for that sample;
    - its ground-truth label;
    - the sample index.
    """

    def __init__(self, images, partial_label_matrix, true_labels, input_size=224):
        """Store the sample data and build the two augmentation pipelines.

        Args:
            images: Sequence of image file paths, indexed in parallel with
                ``true_labels``.
            partial_label_matrix: Per-sample candidate labels. Row ``index`` is
                returned for sample ``index`` (presumably an (N, num_classes)
                0/1 tensor — TODO confirm against the label generator).
            true_labels: Ground-truth labels. Its length defines the dataset size.
            input_size: Output side length of both ``RandomResizedCrop`` steps.
                Defaults to 224, the value previously hard-coded here.
        """
        self.images = images
        self.partial_label_matrix = partial_label_matrix
        self.true_labels = true_labels
        # Standard ImageNet normalisation constants (same values as the test
        # transform in load_flower102).
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # Weak view: crop/flip plus probabilistic colour jitter and grayscale.
        self.weak_transform = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        # Strong view: crop/flip followed by the project's RandomAugment.
        # Its arguments (3, 5) are passed through as-is; see randaugment for
        # their meaning.
        self.strong_transform = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            RandomAugment(3, 5),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    def __len__(self):
        return len(self.true_labels)

    def __getitem__(self, index):
        """Load image ``index`` and return its two augmented views, labels and index."""
        # Open inside a context manager so the file handle is released as soon
        # as the pixels are decoded into the RGB copy, not left to the GC.
        with Image.open(self.images[index]) as img:
            image = img.convert("RGB")
        image_w = self.weak_transform(image)
        image_s = self.strong_transform(image)
        partial_label = self.partial_label_matrix[index]
        true_label = self.true_labels[index]
        return image_w, image_s, partial_label, true_label, index


def binarize_class(y):
    """Return a dense float32 one-hot encoding of ``y`` as a torch tensor.

    ``y`` is reshaped to ``(len(y), -1)`` and passed to scikit-learn's
    ``OneHotEncoder(categories='auto')``, which infers the categories of each
    column from the values present. The sparse result is densified, cast to
    ``np.float32`` and wrapped with ``torch.from_numpy``.
    """
    as_columns = y.reshape(len(y), -1)
    encoder = OneHotEncoder(categories='auto')
    dense = encoder.fit_transform(as_columns).toarray()
    return torch.from_numpy(dense.astype(np.float32))

if __name__ == '__main__':
    # Running this module directly does nothing; its functions and classes are
    # meant to be imported.
    pass
