import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np

class IMBALANCECIFAR10(torchvision.datasets.CIFAR10):
    """CIFAR-10 subsampled per class to follow an imbalanced distribution.

    After the base dataset is loaded, a per-class image budget is computed
    (``get_img_num_per_cls``) and the data/targets are replaced by a random
    per-class subsample of that size (``gen_imbalanced_data``).

    Args:
        root, train, transform, target_transform, download: forwarded
            positionally to the ``torchvision.datasets.CIFAR10`` constructor.
        imb_factor: ratio between the smallest and the largest class budget
            for ``imb_type='exp'``.
        rand_number: seed passed to ``np.random.seed`` (global NumPy RNG).
        fix_number: optional total number of images to keep.
        imb_type: ``'exp'``, ``'step'`` or ``'specific_classes'``.
        stage_number: number of stages for ``imb_type='step'``.
        specific_classes: class index (or list of indices) whose budget is
            set explicitly for ``imb_type='specific_classes'``.
        specific_number: budget (or list of budgets) for ``specific_classes``.

    Attributes set here:
        img_num_list: requested per-class budgets.
        class_idx_list: per-class arrays of the selected original indices.
        class_start_idx: cumulative offsets of each class in the new data.
        num_per_cls_dict: class label -> number of images actually kept.
        debug_list: ``(kept, requested)`` pairs, one per class.
    """
    cls_num = 10

    def __init__(self, root, imb_factor=0.01, rand_number=0, train=True,
                 transform=None, target_transform=None,
                 download=False, fix_number=None, imb_type='exp', stage_number=None, specific_classes=None, specific_number=None):
        super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download)
        # NOTE: this seeds NumPy's *global* RNG, so it also affects any other
        # code in the process that draws from np.random afterwards.
        np.random.seed(rand_number)
        self.class_idx_list = [None]*self.cls_num
        self.fix_number = fix_number
        self.imb_type = imb_type
        self.stage_number = stage_number
        self.specific_classes = specific_classes
        self.specific_number = specific_number
        self.img_num_list = self.get_img_num_per_cls(self.cls_num, imb_factor)
        self.gen_imbalanced_data(self.img_num_list)

    def get_img_num_per_cls(self, cls_num, imb_factor):
        """Return the list of requested image counts, one entry per class.

        Args:
            cls_num: number of classes.
            imb_factor: imbalance ratio, only used by ``imb_type='exp'``.

        Returns:
            list[int]: requested number of images for each class.

        Raises:
            ValueError: if the ``imb_type`` / ``fix_number`` / ``stage_number``
                combination is not supported, or if ``fix_number`` cannot be
                reached under the per-class cap in ``'exp'`` mode.
        """
        if self.imb_type == 'exp':
            # Hard cap per class: the average number of images per class.
            max_num = len(self.data) / cls_num
            # Exponential decay profile: 1 for the first class, imb_factor
            # for the last one.
            decay = [imb_factor**(cls_idx / (cls_num - 1.0)) for cls_idx in range(cls_num)]
            img_max = max_num
            if self.fix_number is not None:
                # Scale the profile so that it sums (before rounding) to
                # fix_number.
                img_max = self.fix_number / np.sum(decay)
            if img_max > max_num:
                print(f"Cutting number: {img_max} to {max_num}")
                img_max = min(img_max, max_num)
            img_num_per_cls = [int(img_max * weight) for weight in decay]
            img_max = img_num_per_cls[0]
            if self.fix_number is not None:
                # Truncation (and possibly the cap) leaves a remainder; hand
                # it out one image at a time, head classes first, over a
                # window that shrinks from the tail on every pass.
                res = self.fix_number - np.sum(img_num_per_cls)
                # Skip class 0 when it already sits at the cap.
                begin_idx = 1 if img_max == max_num else 0
                end_idx = cls_num - 1
                while res > 0:
                    if begin_idx >= end_idx:
                        # The window is empty, so `res` can never shrink;
                        # without this guard the loop would not terminate.
                        raise ValueError(
                            f"Cannot distribute fix_number={self.fix_number} "
                            f"over {cls_num} classes: {res} images remain unassigned")
                    for cls_idx in range(begin_idx, end_idx):
                        if res > 0:
                            img_num_per_cls[cls_idx] += 1
                            res -= 1
                    end_idx -= 1
                    if img_num_per_cls[begin_idx] >= img_max:
                        begin_idx += 1
        elif self.imb_type == 'step' and self.fix_number == 10000 and self.stage_number in [1, 2, 4, 5]:
            if self.stage_number == 1:
                # A single stage is simply a balanced split.
                img_num_per_cls = [self.fix_number//self.cls_num]*self.cls_num
                return img_num_per_cls

            img_num_per_cls = []
            # Stage i receives i * a images in total, so the first stage
            # gets zero images and all stages together sum to ~fix_number.
            a = 2*self.fix_number // (self.stage_number*(self.stage_number-1))
            tmp_add_num = 0
            stage_cls_number = self.cls_num//self.stage_number
            # NOTE(review): when cls_num is not a multiple of stage_number
            # (e.g. 10 classes, 4 stages) this yields fewer than cls_num
            # entries -- confirm whether that combination is intended.
            for i in range(self.stage_number):
                img_num_per_cls += [tmp_add_num//stage_cls_number]*stage_cls_number
                tmp_add_num += a
        elif self.imb_type == 'specific_classes' and self.fix_number in [10000,]:
            # Note: due to the integer division, the actual number of images
            # may not be exactly the same as fix_number.
            if isinstance(self.specific_classes, int):
                self.specific_classes = [self.specific_classes,]
            if isinstance(self.specific_number, int):
                self.specific_number = [self.specific_number,]*len(self.specific_classes)
            # Whatever is not claimed by the specific classes is shared
            # evenly among the remaining ones.
            remain_number = self.fix_number - sum(self.specific_number)
            img_num_per_cls = [remain_number//(self.cls_num-len(self.specific_number))]*self.cls_num
            for i, cls in enumerate(self.specific_classes):
                img_num_per_cls[cls] = self.specific_number[i]
            print(f"fix_number: {self.fix_number}, actual number: {sum(img_num_per_cls)}")
        else:
            raise ValueError("Not Implemented yet")

        # The sum is not guaranteed to equal fix_number (integer rounding),
        # so no equality check is enforced here.
        return img_num_per_cls

    def gen_imbalanced_data(self, img_num_per_cls):
        """Replace ``self.data`` / ``self.targets`` by a per-class subsample.

        For every class label found in ``self.targets`` (sorted order), the
        indices of that class are shuffled with the global NumPy RNG and the
        first ``img_num_per_cls[k]`` of them are kept.

        Args:
            img_num_per_cls: requested number of images per class, paired
                with the sorted unique labels via ``zip``.
        """
        new_data = []
        new_targets = []
        self.class_start_idx = [0,]
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)
        self.num_per_cls_dict = dict()
        debug_list = []
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            idx = np.where(targets_np == the_class)[0]
            np.random.shuffle(idx)
            selec_idx = idx[:the_img_num]
            # The slice silently yields fewer items when the class has fewer
            # images than requested; book-keep with the real count so that
            # data and targets always stay aligned.
            kept_num = len(selec_idx)

            self.num_per_cls_dict[the_class] = kept_num
            self.class_idx_list[the_class] = selec_idx
            new_data.append(self.data[selec_idx, ...])
            new_targets.extend([the_class, ] * kept_num)

            debug_list.append((kept_num, the_img_num))
            self.class_start_idx.append(self.class_start_idx[-1] + kept_num)

        # Concatenate the per-class chunks along the first axis.
        new_data = np.vstack(new_data)
        self.data = new_data
        self.targets = new_targets
        self.debug_list = debug_list

    def get_cls_num_list(self):
        """Return the number of kept images for classes ``0..cls_num-1``."""
        return [self.num_per_cls_dict[i] for i in range(self.cls_num)]

class IMBALANCECIFAR100(IMBALANCECIFAR10):
    """Imbalanced `CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset.

    Inherits every method from ``IMBALANCECIFAR10``; only the class count
    and the archive metadata below (folder, URL, file names, checksums)
    are overridden.
    """
    cls_num = 100

    # Archive location and integrity information for CIFAR-100.
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    base_folder = "cifar-100-python"

    # [file name, md5] pairs for each split.
    train_list = [["train", "16019d7e3df5f24257cddd939b257f8d"]]
    test_list = [["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"]]

    # Label-name metadata file.
    meta = dict(
        filename="meta",
        key="fine_label_names",
        md5="7973b15100ade9c7d40fb424638fde48",
    )


class SUBSETCIFAR10(torchvision.datasets.CIFAR10):
    """Balanced random subset of CIFAR-10 with ``ipc`` images per class.

    Args:
        root, train, transform, target_transform, download: forwarded
            positionally to the ``torchvision.datasets.CIFAR10`` constructor.
        ipc: requested number of images per class.
        rand_number: seed passed to ``np.random.seed`` (global NumPy RNG).
    """
    cls_num = 10

    def __init__(self, root, ipc, rand_number=0, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(SUBSETCIFAR10, self).__init__(root, train, transform, target_transform, download)
        # NOTE: seeds NumPy's global RNG, which the per-class shuffles use.
        np.random.seed(rand_number)
        img_num_list = [ipc, ] * self.cls_num
        self.gen_imbalanced_data(img_num_list)

    def gen_imbalanced_data(self, img_num_per_cls):
        """Replace ``self.data`` / ``self.targets`` by a per-class subsample.

        For every class label found in ``self.targets`` (sorted order), the
        indices of that class are shuffled and the first
        ``img_num_per_cls[k]`` of them are kept.

        Args:
            img_num_per_cls: requested number of images per class, paired
                with the sorted unique labels via ``zip``.
        """
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)
        self.num_per_cls_dict = dict()
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            idx = np.where(targets_np == the_class)[0]
            np.random.shuffle(idx)
            selec_idx = idx[:the_img_num]
            # A class with fewer images than requested yields a shorter
            # slice; use the real count so data and targets stay aligned.
            kept_num = len(selec_idx)
            self.num_per_cls_dict[the_class] = kept_num
            new_data.append(self.data[selec_idx, ...])
            new_targets.extend([the_class, ] * kept_num)
        # Concatenate the per-class chunks along the first axis.
        new_data = np.vstack(new_data)
        self.data = new_data
        self.targets = new_targets

    def get_cls_num_list(self):
        """Return the number of kept images for classes ``0..cls_num-1``."""
        return [self.num_per_cls_dict[i] for i in range(self.cls_num)]

class SUBSETCIFAR100(SUBSETCIFAR10):
    """Balanced subset of the `CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset.

    Inherits every method from ``SUBSETCIFAR10``; only the class count and
    the archive metadata below (folder, URL, file names, checksums) are
    overridden.
    """
    cls_num = 100

    # Archive location and integrity information for CIFAR-100.
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    base_folder = "cifar-100-python"

    # [file name, md5] pairs for each split.
    train_list = [["train", "16019d7e3df5f24257cddd939b257f8d"]]
    test_list = [["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"]]

    # Label-name metadata file.
    meta = dict(
        filename="meta",
        key="fine_label_names",
        md5="7973b15100ade9c7d40fb424638fde48",
    )

# Manual smoke test: build an imbalanced CIFAR-100 training set with the
# default settings, fetch the first sample and drop into the debugger.
# NOTE: this block sits in the middle of the file, so when the module is run
# as a script the classes defined below it do not exist yet at this point.
if __name__ == '__main__':
    # Convert images to tensors, then normalize each of the three channels
    # with mean 0.5 and std 0.5.
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    # download=True is passed through to the base dataset constructor;
    # files go under ./data.
    trainset = IMBALANCECIFAR100(root='./data', train=True,
                    download=True, transform=transform)
    # Pull a single (data, label) pair from the dataset.
    trainloader = iter(trainset)
    data, label = next(trainloader)
    # Interactive inspection of `trainset`, `data` and `label`.
    import pdb; pdb.set_trace()

class SUBSETIMBALANCECIFAR10(torchvision.datasets.CIFAR10):
    """Imbalanced CIFAR-10 whose per-class budgets are capped at ``ipc``.

    Args:
        root, train, transform, target_transform, download: forwarded
            positionally to the ``torchvision.datasets.CIFAR10`` constructor.
        imb_type: ``'exp'`` or ``'step'``; any other value gives a balanced
            budget.
        imb_factor: ratio between the smallest and the largest class budget.
        rand_number: seed passed to ``np.random.seed`` (global NumPy RNG).
        ipc: upper bound on the number of images per class (default: no
            bound).
    """
    cls_num = 10

    def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True,
                 transform=None, target_transform=None,
                 download=False, ipc=np.inf):
        super(SUBSETIMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download)
        # NOTE: seeds NumPy's global RNG, which the per-class shuffles use.
        np.random.seed(rand_number)
        img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor)
        # Clamp every budget that exceeds ipc down to ipc.
        img_num_list = [ipc if num > ipc else num for num in img_num_list]
        self.ipc = ipc
        self.gen_imbalanced_data(img_num_list)

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        """Return the requested image count for each class.

        Args:
            cls_num: number of classes.
            imb_type: ``'exp'`` for an exponential decay from the per-class
                average down to ``average * imb_factor``; ``'step'`` for a
                full-size first half and a reduced second half; anything
                else for a balanced budget.
            imb_factor: imbalance ratio.

        Returns:
            list[int]: requested number of images per class.
        """
        # Average number of images per class in the loaded dataset.
        img_max = len(self.data) / cls_num
        if imb_type == 'exp':
            return [int(img_max * (imb_factor**(cls_idx / (cls_num - 1.0))))
                    for cls_idx in range(cls_num)]
        if imb_type == 'step':
            half = cls_num // 2
            return [int(img_max)] * half + [int(img_max * imb_factor)] * half
        return [int(img_max)] * cls_num

    def gen_imbalanced_data(self, img_num_per_cls):
        """Replace ``self.data`` / ``self.targets`` by a per-class sample.

        For every class label found in ``self.targets`` (sorted order), the
        indices of that class are shuffled and the first
        ``img_num_per_cls[k]`` are kept. A class with fewer images than
        requested is oversampled by repeating its indices.

        Args:
            img_num_per_cls: requested number of images per class, paired
                with the sorted unique labels via ``zip``.
        """
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)
        self.num_per_cls_dict = dict()
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            self.num_per_cls_dict[the_class] = the_img_num
            idx = np.where(targets_np == the_class)[0]
            np.random.shuffle(idx)
            if len(idx) < the_img_num:
                # Repeat the indices often enough to cover the request.
                idx = np.tile(idx, the_img_num//len(idx)+1)
                # np.random.shuffle works in place and returns None, so its
                # result must not be assigned back to idx.
                np.random.shuffle(idx)
            selec_idx = idx[:the_img_num]
            new_data.append(self.data[selec_idx, ...])
            new_targets.extend([the_class, ] * the_img_num)
        # Concatenate the per-class chunks along the first axis.
        new_data = np.vstack(new_data)
        self.data = new_data
        self.targets = new_targets

    def get_cls_num_list(self):
        """Return the per-class image counts for classes ``0..cls_num-1``."""
        return [self.num_per_cls_dict[i] for i in range(self.cls_num)]

class SUBSETIMBALANCECIFAR100(SUBSETIMBALANCECIFAR10):
    """Capped imbalanced `CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset.

    Inherits every method from ``SUBSETIMBALANCECIFAR10``; only the class
    count and the archive metadata below (folder, URL, file names,
    checksums) are overridden.
    """
    cls_num = 100

    # Archive location and integrity information for CIFAR-100.
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    base_folder = "cifar-100-python"

    # [file name, md5] pairs for each split.
    train_list = [["train", "16019d7e3df5f24257cddd939b257f8d"]]
    test_list = [["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"]]

    # Label-name metadata file.
    meta = dict(
        filename="meta",
        key="fine_label_names",
        md5="7973b15100ade9c7d40fb424638fde48",
    )