from collections import defaultdict
import numpy as np
import torchvision


class IMBALANCECIFAR100(torchvision.datasets.CIFAR100):
    """CIFAR-100 with an optional exponentially long-tailed training split.

    When ``train`` is true and ``imb_factor`` is given, each class is
    subsampled so that the per-class counts decay exponentially from the
    full class size (first class) to ``imb_factor`` times that size (last
    class).

    Attributes set by ``__init__``:
        classnames: alias of ``self.classes``.
        cls_num_list: per-class sample counts ordered by ascending label;
            only labels that are present contribute an entry.
        num_classes: ``len(self.cls_num_list)``.
        labels: alias of ``self.targets``; kept in sync by every method of
            this class that rebinds ``self.targets``.
        num_per_cls_dict: ``{label: kept sample count}``; only set when an
            imbalanced split was generated.
    """

    cls_num = 100

    def __init__(self, root, imb_factor=None, rand_number=0, train=True,
                 transform=None, target_transform=None, download=True):
        """Load the dataset and optionally subsample it into a long tail.

        Args:
            root: forwarded to the torchvision base class.
            imb_factor: ratio between the last and the first class size;
                ``None`` leaves the split untouched.
            rand_number: seed handed to ``np.random.seed`` before the
                per-class subsampling.
            train: forwarded to the base class; subsampling is only applied
                when it is true.
            transform: forwarded to the base class.
            target_transform: forwarded to the base class.
            download: forwarded to the base class.
        """
        super().__init__(root, train, transform, target_transform, download)

        if train and imb_factor is not None:
            # NOTE: this seeds NumPy's *global* RNG so that the selection in
            # gen_imbalanced_data is reproducible for a given rand_number.
            np.random.seed(rand_number)
            img_num_list = self.get_img_num_per_cls(self.cls_num, imb_factor)
            self.gen_imbalanced_data(img_num_list)

        self.classnames = self.classes
        self.cls_num_list = self.get_cls_num_list()
        self.num_classes = len(self.cls_num_list)

        self.labels = self.targets

    def get_img_num_per_cls(self, cls_num, imb_factor):
        """Return the target sample count for each of ``cls_num`` classes.

        Class ``i`` gets ``int(img_max * imb_factor ** (i / (cls_num - 1)))``
        samples, where ``img_max = len(self.data) / cls_num``.

        Args:
            cls_num: number of classes to produce counts for.
            imb_factor: multiplier reached by the last class.

        Returns:
            A list of ``cls_num`` non-negative ints.
        """
        img_max = len(self.data) / cls_num
        if cls_num == 1:
            # The exponent's denominator (cls_num - 1) would be zero; a single
            # class simply keeps the full per-class budget.
            return [int(img_max)]
        return [
            int(img_max * (imb_factor ** (cls_idx / (cls_num - 1.0))))
            for cls_idx in range(cls_num)
        ]

    def gen_imbalanced_data(self, img_num_per_cls):
        """Subsample ``self.data`` / ``self.targets`` class by class.

        The sorted unique labels are paired with ``img_num_per_cls`` in
        order. For each pair the indices of that label are shuffled with the
        global NumPy RNG and the first ``n`` of them are kept.

        Args:
            img_num_per_cls: requested number of samples per label, in
                ascending label order.

        Side effects:
            Rebinds ``self.data``, ``self.targets`` and ``self.labels`` and
            sets ``self.num_per_cls_dict``.
        """
        targets_np = np.array(self.targets, dtype=np.int64)
        self.num_per_cls_dict = {}
        kept_data = []
        kept_targets = []
        for the_class, the_img_num in zip(np.unique(targets_np), img_num_per_cls):
            idx = np.where(targets_np == the_class)[0]
            np.random.shuffle(idx)
            selec_idx = idx[:the_img_num]
            # Use the number of samples actually selected: a request larger
            # than the class size must not make targets longer than data.
            kept = len(selec_idx)
            label = int(the_class)  # plain int, like the untouched targets
            self.num_per_cls_dict[label] = kept
            kept_data.append(self.data[selec_idx, ...])
            kept_targets.extend([label] * kept)
        self.data = np.vstack(kept_data)
        self.targets = kept_targets
        self.labels = self.targets

    def _count_per_label(self):
        """Return a ``{label: count}`` dict for the current ``self.targets``."""
        counts = defaultdict(int)
        for label in self.targets:
            counts[label] += 1
        return dict(counts)

    def _select_samples(self, keep):
        """Keep the samples whose flag in ``keep`` is true.

        ``keep`` holds one boolean per entry of ``self.targets``.
        ``self.data`` stays a NumPy array and ``self.labels`` is re-aliased
        to the new ``self.targets`` list.
        """
        mask = np.asarray(keep, dtype=bool)
        self.data = np.asarray(self.data)[mask]
        self.targets = [label for label, flag in zip(self.targets, keep) if flag]
        self.labels = self.targets

    def get_cls_num_list(self):
        """Return per-class sample counts ordered by ascending label.

        Labels with no samples do not appear, so the list position equals
        the label only while the present labels are ``0..k-1``.
        """
        counts = self._count_per_label()
        return [counts[label] for label in sorted(counts)]

    def drop_class_samples(self, drop_shot):
        """Remove every class that currently has ``drop_shot`` or fewer samples.

        Counts are taken from the current ``self.targets`` and keyed by
        label, so the result does not depend on ``self.cls_num_list`` being
        aligned with label ids.

        ``self.cls_num_list`` and ``self.num_classes`` are not refreshed
        here; callers that need them recompute them afterwards.
        """
        counts = self._count_per_label()
        self._select_samples([counts[label] > drop_shot for label in self.targets])

    def drop_samples(self, max_shot=None):
        """Keep at most ``max_shot`` samples per class, in dataset order.

        Args:
            max_shot: per-class cap; ``None`` keeps every sample.

        ``self.cls_num_list`` and ``self.num_classes`` are not refreshed
        here.
        """
        if max_shot is None:
            return
        taken = defaultdict(int)
        keep = []
        for label in self.targets:
            wanted = taken[label] < max_shot
            if wanted:
                taken[label] += 1
            keep.append(wanted)
        self._select_samples(keep)


class CIFAR100(IMBALANCECIFAR100):
    """CIFAR-100 without imbalance subsampling (``imb_factor=None``)."""

    def __init__(self, root, train=True, transform=None, drop_shot=0):
        """Build the dataset and filter the training split by class size.

        Args:
            root: forwarded to ``IMBALANCECIFAR100``.
            train: forwarded to ``IMBALANCECIFAR100``; the class filter below
                only runs when it is true.
            transform: forwarded to ``IMBALANCECIFAR100``.
            drop_shot: passed to ``drop_class_samples``; classes with this
                many samples or fewer are removed from the training split.
        """
        super().__init__(root, imb_factor=None, train=train, transform=transform)
        if train:
            self.drop_class_samples(drop_shot=drop_shot)
            # drop_class_samples rebinds self.targets, so everything derived
            # from it has to be refreshed, including the labels alias.
            self.labels = self.targets
            self.cls_num_list = self.get_cls_num_list()
            self.num_classes = len(self.cls_num_list)


class CIFAR100_IR10(IMBALANCECIFAR100):
    """Long-tailed CIFAR-100 built with ``imb_factor=0.1``."""

    def __init__(self, root, train=True, transform=None, drop_shot=0):
        """Build the dataset and filter the training split by class size.

        Args:
            root: forwarded to ``IMBALANCECIFAR100``.
            train: forwarded to ``IMBALANCECIFAR100``; the class filter below
                only runs when it is true.
            transform: forwarded to ``IMBALANCECIFAR100``.
            drop_shot: passed to ``drop_class_samples``; classes with this
                many samples or fewer are removed from the training split.
        """
        super().__init__(root, imb_factor=0.1, train=train, transform=transform)
        if train:
            self.drop_class_samples(drop_shot=drop_shot)
            # drop_class_samples rebinds self.targets, so everything derived
            # from it has to be refreshed, including the labels alias.
            self.labels = self.targets
            self.cls_num_list = self.get_cls_num_list()
            self.num_classes = len(self.cls_num_list)


class CIFAR100_IR50(IMBALANCECIFAR100):
    """Long-tailed CIFAR-100 built with ``imb_factor=0.02``."""

    def __init__(self, root, train=True, transform=None, drop_shot=0):
        """Build the dataset and filter the training split by class size.

        Args:
            root: forwarded to ``IMBALANCECIFAR100``.
            train: forwarded to ``IMBALANCECIFAR100``; the class filter below
                only runs when it is true.
            transform: forwarded to ``IMBALANCECIFAR100``.
            drop_shot: passed to ``drop_class_samples``; classes with this
                many samples or fewer are removed from the training split.
        """
        super().__init__(root, imb_factor=0.02, train=train, transform=transform)
        if train:
            self.drop_class_samples(drop_shot=drop_shot)
            # drop_class_samples rebinds self.targets, so everything derived
            # from it has to be refreshed, including the labels alias.
            self.labels = self.targets
            self.cls_num_list = self.get_cls_num_list()
            self.num_classes = len(self.cls_num_list)


class CIFAR100_IR100(IMBALANCECIFAR100):
    """Long-tailed CIFAR-100 built with ``imb_factor=0.01``."""

    def __init__(self, root, train=True, transform=None, drop_shot=0):
        """Build the dataset and filter the training split by class size.

        Args:
            root: forwarded to ``IMBALANCECIFAR100``.
            train: forwarded to ``IMBALANCECIFAR100``; the class filter below
                only runs when it is true.
            transform: forwarded to ``IMBALANCECIFAR100``.
            drop_shot: passed to ``drop_class_samples``; classes with this
                many samples or fewer are removed from the training split.
        """
        super().__init__(root, imb_factor=0.01, train=train, transform=transform)
        if train:
            self.drop_class_samples(drop_shot=drop_shot)
            # drop_class_samples rebinds self.targets, so everything derived
            # from it has to be refreshed, including the labels alias.
            self.labels = self.targets
            self.cls_num_list = self.get_cls_num_list()
            self.num_classes = len(self.cls_num_list)
            
            
