import numpy as np
from PIL import Image
import torchvision

def get_cifar10(root, args, train=True,
                 transform_train=None, transform_val=None,
                 download=False):
    """Build the CIFAR-10 training wrapper and the validation dataset.

    Args:
        root: Dataset root directory passed to the CIFAR-10 loaders.
        args: Options object stored on the training dataset; its
            ``batch_size`` and ``begin_first`` attributes are read by
            ``CIFAR10_train`` methods.
        train: Forwarded as ``train`` to the training dataset.
        transform_train: Transform applied to training images.
        transform_val: Transform applied to validation images.
        download: Forwarded to both datasets. The default ``False`` keeps
            the previous behaviour of never downloading.

    Returns:
        tuple: ``(train_dataset, val_dataset)``. ``train_dataset`` is a
        ``CIFAR10_train`` whose ``train_labels_real`` has already been
        snapshotted; ``val_dataset`` is the plain torchvision test split.
    """
    train_dataset = CIFAR10_train(root, args, train=train,
                                  transform=transform_train,
                                  download=download)
    # Snapshot the labels as loaded, before any method rewrites train_labels.
    train_dataset.label_real_init()
    # To start from previously saved soft labels, additionally call
    # train_dataset.label_init_mid() here.
    val_dataset = torchvision.datasets.CIFAR10(root, train=False,
                                               transform=transform_val,
                                               download=download)

    print(f"Train: {len(train_dataset)} Val: {len(val_dataset)}")
    return train_dataset, val_dataset

class CIFAR10_train(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset that carries mutable hard and soft labels per sample.

    Besides the images (``train_data``) the instance keeps these NumPy
    arrays, all indexed by sample:

    * ``train_labels`` -- current hard labels; rewritten by ``label_init``,
      ``label_init_mid`` and ``label_update``.
    * ``train_labels_real`` -- copy of ``train_labels`` taken when
      ``label_real_init`` is called (all zeros before that).
    * ``soft_labels`` -- per-class scores, shape ``(N, 10)``.
    * ``prediction`` -- buffer of network outputs, shape ``(N, 1, 10)``.
    """

    def __init__(self, root, args=None, train=True,
                 transform=None, target_transform=None,
                 download=False):
        """Load CIFAR-10 through torchvision and allocate the label buffers.

        Args:
            root: Dataset root directory.
            args: Options object; ``batch_size`` is read by ``label_init``
                and ``begin_first`` by ``label_update``.
            train: Forwarded to ``torchvision.datasets.CIFAR10``.
            transform: Transform applied to each image in ``__getitem__``.
            target_transform: Transform applied to the current hard label
                in ``__getitem__``.
            download: Forwarded to ``torchvision.datasets.CIFAR10``.
        """
        super().__init__(root, train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)
        self.args = args

        # torchvision >= 0.2.2 exposes the arrays as ``data`` / ``targets``;
        # older releases used ``train_data`` / ``train_labels``. Support both
        # and keep the names the rest of this class relies on.
        if hasattr(self, 'data'):
            images, labels = self.data, self.targets
        else:
            images, labels = self.train_data, self.train_labels
        self.train_data = images
        self.train_labels = np.array(labels)

        num_samples = len(self.train_data)
        self.train_labels_real = np.zeros((num_samples))
        self.soft_labels = np.zeros((num_samples, 10), dtype=np.float32)
        # Axis 1 is the history depth used by label_update (one slot here).
        self.prediction = np.zeros((num_samples, 1, 10), dtype=np.float32)
        self.count = 0

    def label_real_init(self):
        """Store a copy of the current hard labels in ``train_labels_real``."""
        self.train_labels_real = np.array(self.train_labels)

    def label_init(self):
        """Shuffle hard labels inside batch-sized chunks and one-hot them.

        Labels are permuted independently within each consecutive chunk of
        ``args.batch_size`` samples. Samples after the last full chunk keep
        their labels. ``soft_labels`` is then reset to the one-hot encoding
        of the resulting ``train_labels``.
        """
        batch_size = self.args.batch_size
        num_samples = len(self.train_data)

        for chunk_idx in range(num_samples // batch_size):
            # Python slices already exclude the stop index, so no ``- 1``:
            # subtracting one left the last label of each chunk unshuffled.
            chunk = slice(chunk_idx * batch_size, (chunk_idx + 1) * batch_size)
            self.train_labels[chunk] = np.random.permutation(self.train_labels[chunk])

        # Clear stale scores so every row ends up exactly one-hot.
        self.soft_labels.fill(0.)
        self.soft_labels[np.arange(num_samples), self.train_labels] = 1.

    def label_init_mid(self):
        """Load soft labels from ``soft_labels.txt`` and derive hard labels.

        The hard label of each sample is the argmax of its soft-label row.
        """
        self.soft_labels = np.loadtxt('soft_labels.txt', dtype=np.float32)
        self.train_labels = np.argmax(self.soft_labels, axis=1).astype(np.int64)

    def label_update(self, results):
        """Record new network outputs and refresh soft and hard labels.

        Args:
            results: Values assigned to one history slot of ``prediction``;
                must be assignable to an ``(N, 10)`` slice.

        ``results`` is written into a ring buffer and ``soft_labels`` becomes
        the mean over the buffer's history axis. The buffer allocated in
        ``__init__`` has a single slot, so the soft labels currently equal
        the latest ``results``. When the call counter reaches
        ``args.begin_first - 1`` the soft labels are saved to
        ``soft_labels_64_gan``.
        """
        self.count += 1

        # Ring-buffer slot derived from the buffer's own depth, so the index
        # stays in range if ``prediction`` is allocated with more slots.
        history = self.prediction.shape[1]
        slot = (self.count - 1) % history
        self.prediction[:, slot] = results
        self.soft_labels = self.prediction.mean(axis=1)
        self.train_labels = np.argmax(self.soft_labels, axis=1).astype(np.int64)

        if self.count == self.args.begin_first - 1:
            np.savetxt('soft_labels_64_gan', self.soft_labels)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target, target_real, soft_target, index) where
            target is the current hard label (after ``target_transform``,
            if set), target_real is the entry of ``train_labels_real``,
            soft_target is the row of ``soft_labels`` and index is the
            requested index.
        """
        img = self.train_data[index]
        target = self.train_labels[index]
        target_real = self.train_labels_real[index]
        soft_target = self.soft_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        # Only the current hard label goes through target_transform.
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, target_real, soft_target, index