from PIL import Image

import numpy as np
from numpy.random import multinomial, dirichlet

import torch
import torchvision.datasets as datasets

def get_label_proportion(num_bags=100, num_classes=10, seed=42):
    """Draw one random class-proportion vector per bag.

    Every entry is first sampled uniformly from [0, 1) and each row is then
    divided by its own total, so every row sums to 1.

    Args:
        num_bags: Number of rows (bags) to generate.
        num_classes: Number of columns (classes) in each row.
        seed: Seed applied to NumPy's global RNG before sampling; ``None``
            leaves the current RNG state untouched.

    Returns:
        Float array of shape ``(num_bags, num_classes)``.
    """
    if seed is not None:
        np.random.seed(seed)
    raw_weights = np.random.rand(num_bags, num_classes)
    row_totals = raw_weights.sum(axis=1, keepdims=True)
    return raw_weights / row_totals

def get_N_label_proportion(proportion, num_instances, num_classes, seed=42):
    """Turn per-bag class proportions into per-bag integer class counts.

    For every bag, each class except the last receives
    ``round(num_instances * proportion)`` instances, clamped so the running
    total never exceeds ``num_instances``; the last class receives whatever
    is left. The counts of each bag are then shuffled across classes.

    Two diagnostics are printed before returning: the per-class totals over
    all bags, and the number of bags whose counts do not add up to
    ``num_instances``.

    Args:
        proportion: 2-D array with one row of class proportions per bag.
        num_instances: Number of instances every bag must contain.
        num_classes: Number of classes; the class at position
            ``num_classes - 1`` is the one that absorbs the remainder.
        seed: Seed applied to NumPy's global RNG before shuffling; ``None``
            leaves the current RNG state untouched.

    Returns:
        Float array with the same shape as ``proportion`` holding
        whole-number counts.
    """
    if seed is not None:
        np.random.seed(seed)

    counts = np.zeros(proportion.shape)
    # Rows of ``counts`` are views, so writing to / shuffling ``row`` updates
    # ``counts`` in place.
    for row, bag_prop in zip(counts, proportion):
        for c, frac in enumerate(bag_prop):
            remaining = int(num_instances - row.sum())
            if c + 1 == num_classes:
                # Last class takes the remainder so the bag is exactly full.
                take = remaining
            else:
                take = int(np.round(num_instances * frac))
                if row.sum() + take >= num_instances:
                    take = remaining
            row[c] = int(take)
        np.random.shuffle(row)

    print(counts.sum(axis=0))
    print((counts.sum(axis=1) != num_instances).sum())
    return counts

def create_bags(label, num_bags, num_classes, num_instances, seed=42):
    """Sample ``num_bags`` bags of instance indices with random class mixes.

    Per-bag class counts come from ``get_label_proportion`` followed by
    ``get_N_label_proportion``. Each bag then draws, for every class, that
    many indices from the class's pool without replacement *within the
    draw*; pools are never depleted, so different bags may share indices.

    Args:
        label: Sequence of integer class labels, one per instance.
        num_bags: Number of bags to build.
        num_classes: Number of classes; labels are matched against
            ``0 .. num_classes - 1``.
        num_instances: Number of instances per bag.
        seed: Seed applied to NumPy's global RNG (also forwarded to the two
            helper functions); ``None`` leaves the RNG state untouched.

    Returns:
        List of ``num_bags`` lists, each holding the shuffled instance
        indices of one bag.
    """
    if seed is not None:
        np.random.seed(seed)
    label = np.array(label)

    # Per-bag class counts.
    proportion = get_label_proportion(num_bags, num_classes, seed)
    proportion_N = get_N_label_proportion(proportion, num_instances, num_classes, seed)

    # Index pool of every class, shuffled in place.
    all_idx = np.arange(len(label))
    idx_c = []
    for c in range(num_classes):
        members = all_idx[label == c]
        np.random.shuffle(members)
        idx_c.append(members)

    # Assemble the bags class by class, then mix the instance order.
    bags_idx = []
    for bag_counts in proportion_N:
        bag = []
        for pool, count in zip(idx_c, bag_counts):
            bag.extend(np.random.choice(pool, size=int(count), replace=False))
        np.random.shuffle(bag)
        bags_idx.append(bag)

    return bags_idx

def make_bags(label, num_classes, num_instances, alg='diri', seed=42):
    """Partition instance indices into disjoint bags of a fixed size.

    The indices of every class are shuffled once. Bags are then built one
    after another: per-class counts for a bag are drawn from a multinomial
    over ``num_instances`` trials, and that many indices are taken from the
    front of each class pool. Construction stops at the first bag that
    cannot be filled completely because some class pool ran short; the
    indices consumed by that incomplete bag are discarded.

    Args:
        label: 1-D sequence of integer class labels (list, NumPy array or
            CPU tensor), one per instance.
        num_classes: Number of classes; labels are matched against
            ``0 .. num_classes - 1``.
        num_instances: Number of instances per bag; must be positive.
        alg: How the class probabilities of each bag are chosen:
            ``'uniform'`` uses equal probabilities for all classes,
            ``'diri'`` draws them from a flat Dirichlet (all alphas 1).
        seed: Seed passed to ``np.random.seed`` before any sampling.

    Returns:
        List of bags; each bag is a shuffled list of ``num_instances``
        instance indices. No index appears in more than one bag.

    Raises:
        ValueError: If ``alg`` is not ``'uniform'`` or ``'diri'``, or if
            ``num_instances`` is not positive.
    """
    if alg not in ('uniform', 'diri'):
        raise ValueError(alg)
    if num_instances <= 0:
        # With zero instances every (empty) bag would count as complete and
        # the construction loop below would never terminate.
        raise ValueError(
            f"num_instances must be positive, got {num_instances}")

    np.random.seed(seed)

    # Accept plain sequences as well as arrays/tensors.
    label = np.asarray(label)

    data_idx = np.arange(len(label))
    class_idx = []
    for c in range(num_classes):
        members = data_idx[label == c]
        np.random.shuffle(members)
        class_idx.append(members)

    uniform_probs = np.ones(num_classes) / num_classes

    bags_idx = []
    while True:
        # Class probabilities for this bag, then the per-class counts.
        if alg == 'uniform':
            class_probs = uniform_probs
        else:
            class_probs = dirichlet(np.ones(num_classes))
        proportion_N = multinomial(num_instances, class_probs)

        idx = []
        for c, count in enumerate(proportion_N):
            count = int(count)
            idx.extend(class_idx[c][:count])
            class_idx[c] = class_idx[c][count:]

        # A short bag means at least one class pool is exhausted: stop.
        if len(idx) != num_instances:
            break
        np.random.shuffle(idx)
        bags_idx.append(idx)

    return bags_idx

class MNISTLLPLabels(datasets.MNIST):
    """MNIST grouped into bags for learning from label proportions (LLP).

    Instances are partitioned into disjoint bags by ``make_bags``. Indexing
    the dataset returns one whole bag: the stacked images of its instances
    together with the bag's class-proportion vector instead of per-instance
    labels.

    Args:
        num_instances (int): Number of instances per bag (default: 32).
        alg (string): How per-bag class probabilities are drawn
            (default: 'diri'). The value is either 'uniform' or 'diri'.
        seed (int): Random seed used for bag construction (default: 42).
        **kwargs: Forwarded unchanged to the `MNIST` constructor
            (e.g. ``root``, ``train``, ``download``, ``transform``).

    This is a subclass of the `MNIST` Dataset.
    """

    def __init__(self,
                 num_instances=32,
                 alg='diri',
                 seed=42,
                 **kwargs):
        super().__init__(**kwargs)
        self.seed = seed
        self.num_classes = 10
        # One list of instance indices per bag.
        self.bag_indices = make_bags(
            self.targets,
            self.num_classes,
            num_instances,
            alg,
            seed)

    def __len__(self):
        """Return the number of bags (not the number of instances)."""
        return len(self.bag_indices)

    def __getitem__(self, index: int):
        """Return the ``index``-th bag.

        Returns:
            tuple: ``(img_bag, target_prop)`` where ``img_bag`` stacks the
            bag's (transformed) images along a new leading dimension and
            ``target_prop`` is a ``num_classes``-long vector holding the
            mean of the instances' one-hot labels.

        The transformed images must be tensors, since they are combined
        with ``torch.stack``.
        """
        indices_bag = self.bag_indices[index]
        img_bag = torch.stack([self.get_image_by_idx(idx_inst) for idx_inst in indices_bag], dim=0)
        target_bag = torch.tensor([self.get_target_by_idx(idx_inst) for idx_inst in indices_bag])
        # Rows of the identity matrix are one-hot vectors; their mean is the
        # class proportion of the bag.
        target_prop = torch.eye(self.num_classes)[target_bag].mean(dim=0)
        return img_bag, target_prop

    def get_image_by_idx(self, idx: int):
        """Return instance ``idx`` as a mode-"L" PIL image, with
        ``self.transform`` applied when one is set."""
        img = self.data[idx]
        img = Image.fromarray(img.numpy(), mode="L")
        if self.transform is not None:
            img = self.transform(img)
        return img

    def get_target_by_idx(self, idx: int):
        """Return the label of instance ``idx`` as an ``int``, with
        ``self.target_transform`` applied when one is set."""
        target = int(self.targets[idx])
        if self.target_transform is not None:
            target = self.target_transform(target)
        return target

class CIFAR10LLPLabels(datasets.CIFAR10):
    """CIFAR10 grouped into bags for learning from label proportions (LLP).

    Instances are partitioned into disjoint bags by ``make_bags``. Indexing
    the dataset returns one whole bag: the stacked images of its instances
    together with the bag's class-proportion vector instead of per-instance
    labels.

    Args:
        num_instances (int): Number of instances per bag (default: 32).
        alg (string): How per-bag class probabilities are drawn
            (default: 'diri'). The value is either 'uniform' or 'diri'.
        seed (int): Random seed used for bag construction (default: 42).
        **kwargs: Forwarded unchanged to the `CIFAR10` constructor
            (e.g. ``root``, ``train``, ``download``, ``transform``).

    This is a subclass of the `CIFAR10` Dataset.
    """

    def __init__(self,
                 num_instances=32,
                 alg='diri',
                 seed=42,
                 **kwargs):
        super().__init__(**kwargs)
        self.seed = seed
        self.num_classes = 10
        # Converted to a tensor so labels can be compared and indexed in bulk.
        self.targets = torch.tensor(self.targets)

        # One list of instance indices per bag.
        self.bag_indices = make_bags(
            self.targets,
            self.num_classes,
            num_instances,
            alg,
            seed)

    def __len__(self):
        """Return the number of bags (not the number of instances)."""
        return len(self.bag_indices)

    def __getitem__(self, index: int):
        """Return the ``index``-th bag.

        Returns:
            tuple: ``(img_bag, target_prop)`` where ``img_bag`` stacks the
            bag's (transformed) images along a new leading dimension and
            ``target_prop`` is a ``num_classes``-long vector holding the
            mean of the instances' one-hot labels.

        The transformed images must be tensors, since they are combined
        with ``torch.stack``.
        """
        indices_bag = self.bag_indices[index]
        img_bag = torch.stack([self.get_image_by_idx(idx_inst) for idx_inst in indices_bag], dim=0)
        if self.target_transform is None:
            target_bag = self.targets[indices_bag]
        else:
            # Honour target_transform instead of silently ignoring it.
            target_bag = torch.tensor([self.get_target_by_idx(idx_inst) for idx_inst in indices_bag])
        # Rows of the identity matrix are one-hot vectors; their mean is the
        # class proportion of the bag.
        target_prop = torch.eye(self.num_classes)[target_bag].mean(dim=0)
        return img_bag, target_prop

    def get_image_by_idx(self, idx: int):
        """Return instance ``idx`` as a mode-"RGB" PIL image, with
        ``self.transform`` applied when one is set."""
        img = self.data[idx]
        img = Image.fromarray(img, mode="RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img

    def get_target_by_idx(self, idx: int):
        """Return the label of instance ``idx`` as an ``int``, with
        ``self.target_transform`` applied when one is set."""
        target = int(self.targets[idx])
        if self.target_transform is not None:
            target = self.target_transform(target)
        return target

if __name__ == '__main__':
    from torchvision import transforms

    # Smoke test: build the CIFAR10 LLP dataset (downloading it into
    # ./dataset if needed) and inspect the first bag.
    llp_kwargs = dict(
        root='./dataset',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
        num_instances=32,
    )
    cifar_llp = CIFAR10LLPLabels(**llp_kwargs)

    first_bag = cifar_llp[0]
    img_bag, target_prop = first_bag
    print("img_bag.shape:", img_bag.shape)
    print("target_prop:", target_prop)