from typing import Tuple, Any

import torch
from PIL import Image
from torchvision.datasets.cifar import CIFAR10, CIFAR100
import numpy as np
import math
from tools.get_path import *
import json

class ImbalanceCIFAR10(CIFAR10):
    """CIFAR-10 with every class randomly subsampled to a configured fraction.

    The per-class fractions are read from ``<project path>/saves/cifar10.json``,
    a JSON object keyed by the stringified class id. The subsampling is applied
    to whichever split (``train`` or not) the base class loaded.

    After construction ``self.data`` is a list of per-image arrays and
    ``self.targets`` is the matching list of integer class ids, grouped by
    class in ascending class-id order.
    """

    def __init__(self,
                 root,
                 train=True,
                 transform=None,
                 target_transform=None,
                 download=False,
                 ):
        """Load the split via the base class, then subsample each class.

        Args:
            root: Dataset root directory, forwarded to the base class.
            train: Split selector, forwarded to the base class.
            transform: Image transform, forwarded to the base class.
            target_transform: Label transform, forwarded to the base class.
            download: Download flag, forwarded to the base class.

        Raises:
            KeyError: If the JSON file has no entry for a class id present
                in the loaded targets.
        """
        super().__init__(root, train, transform, target_transform, download)

        with open(get_project_path() + "/saves/cifar10.json", "r") as f2:
            scalers = json.load(f2)

        all_targets = np.asarray(self.targets)
        datas = []
        targets = []
        for class_id in sorted(set(self.targets)):
            class_indices = np.flatnonzero(all_targets == class_id)
            available = len(class_indices)

            # Round before ceil so float drift (0.07 * 5000 ==
            # 350.00000000000006) does not add a spurious extra sample.
            num = math.ceil(np.round(scalers[str(class_id)] * available, 2))

            # Draw distinct images whenever the class has enough of them;
            # only fall back to sampling with replacement when the requested
            # count exceeds what is available (scale factor > 1).
            chosen = np.random.choice(available, size=num,
                                      replace=num > available)

            datas.extend(self.data[class_indices[chosen]])
            targets.extend([class_id] * num)

        self.data = datas
        self.targets = targets


class ImbalanceCIFAR100(CIFAR100):
    """CIFAR-100 with every class randomly subsampled to a configured fraction.

    The per-class fractions are read from
    ``<project path>/saves/cifar100.json``, a JSON object keyed by the
    stringified class id. The subsampling is applied to whichever split
    (``train`` or not) the base class loaded.

    After construction ``self.data`` is a list of per-image arrays and
    ``self.targets`` is the matching list of integer class ids, grouped by
    class in ascending class-id order.
    """

    def __init__(self, root,
                 train=True,
                 transform=None,
                 target_transform=None,
                 download=False, ):
        """Load the split via the base class, then subsample each class.

        Args:
            root: Dataset root directory, forwarded to the base class.
            train: Split selector, forwarded to the base class.
            transform: Image transform, forwarded to the base class.
            target_transform: Label transform, forwarded to the base class.
            download: Download flag, forwarded to the base class.

        Raises:
            KeyError: If the JSON file has no entry for a class id present
                in the loaded targets.
        """
        super().__init__(root=root, train=train, transform=transform,
                         target_transform=target_transform, download=download)

        with open(get_project_path() + "/saves/cifar100.json", "r") as f2:
            scalers = json.load(f2)

        all_targets = np.asarray(self.targets)
        datas = []
        targets = []
        for class_id in sorted(set(self.targets)):
            class_indices = np.flatnonzero(all_targets == class_id)
            available = len(class_indices)

            # Round before ceil so float drift (0.07 * 500 style products
            # landing a hair above an integer) does not add an extra sample.
            num = math.ceil(np.round(scalers[str(class_id)] * available, 2))

            # Draw distinct images whenever the class has enough of them;
            # only fall back to sampling with replacement when the requested
            # count exceeds what is available (scale factor > 1).
            chosen = np.random.choice(available, size=num,
                                      replace=num > available)

            datas.extend(self.data[class_indices[chosen]])
            targets.extend([class_id] * num)

        self.data = datas
        self.targets = targets


# NOTE(review): `TinyImagenet` is not imported explicitly in the visible part
# of this file; presumably it is provided by the star import from
# `tools.get_path` or defined elsewhere -- confirm.
class ImbalenceTinyImageNet(TinyImagenet):
    """TinyImageNet variant whose classes are randomly subsampled.

    The per-class fractions are read from
    ``<project path>/saves/tinyimagenet.json``, a JSON object keyed by the
    stringified class id.
    """

    def load_data(self):
        """Collect a per-class random subset of image paths and labels.

        For each of the 200 class ids, the image paths of the active split
        (``self.train`` selects train vs. test) are fetched and a subset of
        ``ceil(round(fraction * count, 2))`` of them is kept.

        Returns:
            A two-element list ``[paths, labels]`` where ``paths`` holds the
            selected image paths and ``labels`` the matching class ids, both
            grouped by class in ascending class-id order.
        """
        with open(get_project_path()+"/saves/tinyimagenet.json", "r") as f2:
            scalers = json.load(f2)

        paths = []
        labels = []
        for class_id in range(200):
            class_name = self.id2label[class_id]

            if self.train:
                class_paths = self.get_train_images_paths(class_name)
            else:
                # test set
                class_paths = self.get_test_images_paths(class_name)

            available = len(class_paths)
            # Rounding to two decimals before ceil absorbs float drift in the
            # product so it cannot bump the count by one.
            sample_nums = math.ceil(
                np.round(scalers[str(class_id)] * available, 2))

            # Draw distinct images whenever the class has enough of them;
            # only fall back to sampling with replacement when the requested
            # count exceeds what is available (scale factor > 1).
            chosen = np.random.choice(available, size=sample_nums,
                                      replace=sample_nums > available)

            paths += [class_paths[idx] for idx in chosen]
            labels += [class_id] * sample_nums

        return [paths, labels]

