from math import ceil
from PIL.Image import BICUBIC
from PIL import Image
from torchvision.datasets.cifar import CIFAR100, CIFAR10
from torchvision.transforms import Compose, RandomCrop, Pad, RandomHorizontalFlip, Resize, RandomAffine
from torchvision.transforms import ToTensor, Normalize

from torch.utils.data import Subset,Dataset, Sampler
import torchvision.utils as vutils
import random
from torch.utils.data import DataLoader
import numpy as np
import random

class BalancedSampler(Sampler):
    """Sampler that picks a bucket uniformly at random, then emits its next index.

    Every bucket is walked in a shuffled order. When a bucket runs out, its
    cursor wraps back to the start and the bucket is reshuffled. The lists
    passed in ``buckets`` are shuffled in place, so the caller's lists are
    modified.

    Args:
        buckets: list of index lists, one per bucket (e.g. per class).
        retain_epoch_size: if True, one pass yields as many indices as all
            buckets hold combined; otherwise it yields
            ``max(bucket size) * number of buckets`` indices.
    """

    def __init__(self, buckets, retain_epoch_size=False):
        for members in buckets:
            random.shuffle(members)
        self.buckets = buckets
        self.bucket_num = len(buckets)
        self.bucket_pointers = [0] * self.bucket_num
        self.retain_epoch_size = retain_epoch_size

    def __iter__(self):
        for _ in range(len(self)):
            yield self._next_item()

    def _next_item(self):
        # Uniform choice of bucket gives every bucket equal sampling weight,
        # regardless of how many indices it holds.
        which = random.randint(0, self.bucket_num - 1)
        members = self.buckets[which]
        pos = self.bucket_pointers[which]
        item = members[pos]
        pos += 1
        if pos == len(members):
            # Bucket exhausted: rewind and reshuffle for the next pass.
            pos = 0
            random.shuffle(members)
        self.bucket_pointers[which] = pos
        return item

    def __len__(self):
        sizes = [len(members) for members in self.buckets]
        if self.retain_epoch_size:
            # Note: not rounded up to a full batch.
            return sum(sizes)
        # Large enough that every index of the biggest bucket can be visited.
        return max(sizes) * self.bucket_num

def _long_tail_counts(size, mu, num_classes):
    """Return per-class counts ``ceil(size * mu**c)`` for ``c = 0..num_classes-1``."""
    return [ceil(size * (mu ** c)) for c in range(num_classes)]


def _select_per_class(labels, num_classes, num_first, num_last):
    """For each class, pick the first ``num_first[c]`` and last ``num_last[c]`` positions.

    Positions are taken from ``np.where(labels == c)`` for each class ``c`` in
    order. The two parts may overlap when a class has fewer than
    ``num_first[c] + num_last[c]`` samples.

    Returns:
        ``(first_index, last_index)`` as flat lists of numpy integer positions.
    """
    first_index, last_index = [], []
    for c in range(num_classes):
        positions = np.where(labels == c)[0]
        first_index.extend(positions[:num_first[c]])
        # Explicit start instead of ``positions[-n:]``: for n == 0 that slice
        # would return every position of the class instead of none.
        last_index.extend(positions[max(len(positions) - num_last[c], 0):])
    return first_index, last_index


def load_cifar10(r=None,train_size=4000,train_rho=0.01,val_size=1000,val_rho=0.01,image_size=32,batch_size=128,num_workers=4,path='./data',num_classes=10,balance_val=False,noise_mode='uniform'):
    """Build long-tailed (optionally label-noisy) CIFAR-10 train/val/test loaders.

    Per-class counts follow ``ceil(size * mu**c)`` with
    ``mu = rho ** (1 / (num_classes - 1))``, so the last class gets ``rho``
    times the count of class 0 (before rounding up). Train and validation
    samples both come from the CIFAR-10 training split. The split is
    restricted to candidate samples and shuffled; then, per class, train takes
    the head and val the tail. The CIFAR-10 test split is used as is.

    Args:
        r: per-class noise rates forwarded to ``noisy_label`` (``r[c]`` for
            class ``c``), or None to keep clean labels.
        train_size, val_size: sample count of class 0 in each split.
        train_rho, val_rho: last-class / first-class ratio of each split.
        image_size: currently unused (the Resize transforms are disabled).
        batch_size, num_workers: DataLoader settings for every loader.
        path: dataset root passed to torchvision's CIFAR10 (``download=True``).
        num_classes: number of classes used (labels ``0..num_classes-1``).
        balance_val: if True, every class gets exactly ``val_size`` val samples.
        noise_mode: noise type forwarded to ``noisy_label``.

    Returns:
        ``(train_loader, val_loader, test_loader, eval_train_loader,
        eval_val_loader, num_train_samples, num_val_samples,
        true_train_new_num, true_train_num)``. ``val_loader`` uses the
        augmenting train transform. The ``eval_*`` loaders use the test
        transform without shuffling. The ``true_train_*`` lists are empty
        when ``r`` is None.
    """
    print(r)
    train_transform = Compose([
        RandomCrop(32,padding=4),
        #Resize(image_size, BICUBIC),
        #RandomAffine(degrees=2, translate=(0.02, 0.02), scale=(0.98, 1.02), shear=2, fillcolor=(124,117,104)),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
    ])

    test_transform = Compose([
        #Resize(image_size, BICUBIC),
        ToTensor(),
        Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
    ])

    train_dataset = CIFAR10(root=path, train=True, transform=train_transform, download=True)
    test_dataset = CIFAR10(root=path, train=False, transform=test_transform, download=True)
    train_x, train_y = np.array(train_dataset.data), np.array(train_dataset.targets)

    # Exponent chosen so the last class is scaled by exactly rho
    # (1/9 for the 10 CIFAR-10 classes).
    exponent = 1. / max(num_classes - 1, 1)
    num_train_samples = _long_tail_counts(train_size, train_rho ** exponent, num_classes)
    if balance_val:
        num_val_samples = [val_size] * num_classes
    else:
        num_val_samples = _long_tail_counts(val_size, val_rho ** exponent, num_classes)

    # First pass: restrict the pool to samples either split could draw, then shuffle it.
    train_index, val_index = _select_per_class(train_y, num_classes, num_train_samples, num_val_samples)
    total_index = list(set(train_index + val_index))
    random.shuffle(total_index)
    train_x, train_y = train_x[total_index], train_y[total_index]

    # Second pass on the shuffled pool: head of each class -> train, tail -> val.
    print(num_train_samples,num_val_samples)
    train_index, val_index = _select_per_class(train_y, num_classes, num_train_samples, num_val_samples)
    random.shuffle(train_index)
    random.shuffle(val_index)

    train_data,train_targets=train_x[train_index],train_y[train_index]
    val_data,val_targets=train_x[val_index],train_y[val_index]
    true_train_new_num = []
    true_train_num = []
    if r is not None:
        new_val_targets = noisy_label(r,val_targets,num_classes,noise_mode,num_val_samples,val_rho)
        new_train_targets = noisy_label(r,train_targets,num_classes,noise_mode,num_train_samples,train_rho)

        # Per class c: samples now labelled c whose true label is also c.
        true_train_new_num = [np.sum(train_targets[new_train_targets == c] == c) for c in range(num_classes)]
        print('true_train_new_num',true_train_new_num)
        # Per class c: samples truly of class c that kept label c.
        true_train_num = [np.sum(new_train_targets[train_targets == c] == c) for c in range(num_classes)]
        print('true_train_num',true_train_num)
        train_targets = new_train_targets
        val_targets = new_val_targets

    print('train class num',np.bincount(train_targets, minlength=num_classes))

    train_dataset = CustomDataset(train_data,train_targets,train_transform)
    val_dataset = CustomDataset(val_data,val_targets,train_transform)
    train_eval_dataset = CustomDataset(train_data,train_targets,test_transform)
    val_eval_dataset = CustomDataset(val_data,val_targets,test_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                            shuffle=True, drop_last=False, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers,
                            shuffle=True, drop_last=False, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers,
                            shuffle=False, drop_last=False, pin_memory=True)

    eval_train_loader = DataLoader(train_eval_dataset, batch_size=batch_size, num_workers=num_workers,
                                shuffle=False, drop_last=False, pin_memory=True)
    eval_val_loader = DataLoader(val_eval_dataset, batch_size=batch_size, num_workers=num_workers,
                                shuffle=False, drop_last=False, pin_memory=True)

    return train_loader,val_loader,test_loader,eval_train_loader,eval_val_loader,num_train_samples,num_val_samples,true_train_new_num,true_train_num

class CustomDataset(Dataset):
    """In-memory dataset pairing image arrays with labels, with optional transforms.

    Each item's image is converted with ``Image.fromarray`` and then passed
    through ``transform`` when one is provided.

    Args:
        data: indexable of image arrays accepted by ``Image.fromarray``.
        targets: indexable of labels aligned with ``data``.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, data, targets, transform=None):
        self.data, self.targets, self.transform = data, targets, transform

    def __getitem__(self, index):
        raw, label = self.data[index], self.targets[index]
        image = Image.fromarray(raw)
        if self.transform is None:
            return image, label
        return self.transform(image), label

    def __len__(self):
        return len(self.data)
#load_cifar10()
def noisy_label(r,sample_targets,num_classes,noise_mode,total_sample,rho):
    """Return a copy of ``sample_targets`` with per-class label noise injected.

    For each class ``i``, ``int(count_i * r[i])`` randomly chosen samples
    currently labelled ``i`` are relabelled, where ``count_i`` is the number
    of such samples:

    * ``noise_mode == 'uniform'`` and ``rho != 1``: the new label is drawn
      with probability proportional to ``total_sample``, redrawing until it
      differs from ``i``.
    * ``noise_mode == 'uniform'`` and ``rho == 1``: the new label is drawn
      uniformly from the classes other than ``i``.
    * any other ``noise_mode``: the noisy samples are split across classes
      ``j`` in proportion to ``total_sample[j] * r[j]``. Samples allotted to
      ``j == i`` keep their label, and any remainder from integer truncation
      is left unchanged.

    Args:
        r: per-class noise rates, ``r[i]`` for class ``i``.
        sample_targets: 1-D numpy array of integer labels (not modified).
        num_classes: number of classes.
        noise_mode: ``'uniform'`` or another mode (see above).
        total_sample: per-class sample counts, used as sampling weights.
        rho: imbalance ratio; ``rho == 1`` selects plain uniform flipping.

    Returns:
        A new numpy array of (possibly noisy) labels.
    """
    sample_targets_new = sample_targets.copy()
    cdf = None  # cumulative class distribution from total_sample, built on first use
    for i in range(num_classes):
        index_i = np.where(sample_targets==i)[0]
        random.shuffle(index_i)
        num_noise = int(len(index_i)*r[i])
        if num_noise <= 0:
            continue
        noise_index_i = index_i[:num_noise]
        if noise_mode == 'uniform':
            if rho != 1:
                if cdf is None:
                    cdf = np.cumsum(total_sample) / np.sum(total_sample)
                for idx in noise_index_i:
                    new_target = i
                    while new_target == i:
                        u = random.uniform(0,1)
                        # First k with cdf[k-1] <= u < cdf[k]: a class drawn
                        # proportionally to total_sample.
                        k = int(np.searchsorted(cdf, u, side='right'))
                        if k < num_classes:
                            new_target = k
                    sample_targets_new[idx] = new_target
            else:
                for idx in noise_index_i:
                    new_target = i
                    while new_target == i:
                        new_target = random.randint(0,num_classes-1)
                    sample_targets_new[idx] = new_target
        else:
            weights = np.array(total_sample)*np.array(r)
            frac_noisy = weights/np.sum(weights)
            start = 0
            for j in range(num_classes):
                count = int(frac_noisy[j]*num_noise)
                # Assign (not compare) the next slice of this class's noisy
                # samples to class j.
                sample_targets_new[noise_index_i[start:start+count]] = j
                start += count
    return sample_targets_new

