import torch
import torchvision
import torchvision.transforms as transforms
from utils import CVConfig
from imbalanced_cifar import ImbalanceCIFAR10, ImbalanceCIFAR100
import os
import pickle
from val_sampler import sample_val

# Module-level project configuration, instantiated once at import time.
# get_dataset reads the CIFAR-10/100 dataset roots and the data-pool
# directories (loaded through AugmentedDataset for the val split) from it.
PC = CVConfig()


def get_transform(dst_name, t_type):
    """Build a torchvision transform pipeline for a CIFAR-style dataset.

    Args:
        dst_name: dataset name. Only names containing ``'cifar'`` get a
            transform; they must also contain ``'100'`` or ``'10'``
            (e.g. ``'im_cifar10'``, ``'im_cifar100'``).
        t_type: which pipeline to build:
            ``'train'``     - random horizontal flip, random 32x32 crop with
                              padding 4, tensor conversion, normalization;
            ``'test'``      - tensor conversion + normalization;
            ``'to_tensor'`` - tensor conversion only;
            ``'norm'``      - normalization only (expects tensor input).

    Returns:
        A ``transforms.Compose`` instance, or ``None`` when ``dst_name`` is not
        a CIFAR dataset or ``t_type`` is not one of the values above.

    Raises:
        Exception: if ``dst_name`` mentions cifar but neither 10 nor 100.
    """
    if 'cifar' not in dst_name:
        return None

    # '100' must be tested before '10': '10' is a substring of '100', so the
    # reverse order would hand CIFAR-100 the CIFAR-10 channel statistics.
    if '100' in dst_name:
        normalize = transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
                                         std=[0.2009, 0.1984, 0.2023])
    elif '10' in dst_name:
        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                         std=[0.2023, 0.1994, 0.2010])
    else:
        raise Exception('No transforms for ' + dst_name)

    if t_type == 'train':
        return transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ])
    if t_type == 'test':
        return transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    if t_type == 'to_tensor':
        return transforms.Compose([
            transforms.ToTensor()
        ])
    if t_type == 'norm':
        return transforms.Compose([
            normalize
        ])
    # Unknown transform type: no pipeline (mirrors the non-CIFAR case).
    return None


def get_class_num_counts(dst_name, cls_num, imb_factor, img_max=None):
    """Return per-class sample counts of an exponentially imbalanced split.

    Class ``i`` receives ``int(img_max * imb_factor ** (i / (cls_num - 1)))``
    samples, so class 0 has ``img_max`` samples and the last class has
    ``img_max * imb_factor`` (truncated to int).

    Args:
        dst_name: dataset key used to look up ``img_max`` when it is not given
            (``'im_cifar10'`` -> 5000, ``'im_cifar100'`` -> 500).
        cls_num: number of classes.
        imb_factor: ratio between the smallest and the largest class size.
        img_max: optional head-class size. It overrides the built-in table, so
            datasets outside that table can be used.

    Returns:
        dict mapping class index -> integer sample count.

    Raises:
        KeyError: if ``img_max`` is None and ``dst_name`` is not in the table.
    """
    if img_max is None:
        img_max = {'im_cifar10': 5000, 'im_cifar100': 500}[dst_name]
    if cls_num == 1:
        # The exponent i / (cls_num - 1) would be 0 / 0 for a single class;
        # that class is simply the head class.
        return {0: int(img_max)}
    return {cls_idx: int(img_max * (imb_factor ** (cls_idx / (cls_num - 1.0))))
            for cls_idx in range(cls_num)}


def get_dataset(dst_name, split, rand_number, is_wrapper=False, val_method=None, **kwargs):
    """Build one split of an imbalanced CIFAR benchmark.

    Args:
        dst_name: ``'im_cifar10'`` (train imbalance factor 0.01) or
            ``'im_cifar100'`` (train imbalance factor 0.1).
        split: ``'train'`` or ``'val'``. Any other value returns the test split
            (``train=False``, ``imb_factor=1.0``).
        rand_number: forwarded as ``rand_number`` to the dataset constructors
            and as ``random_seed`` to ``sample_val``.
        is_wrapper: if True, wrap the result in ``DatasetWrapper``.
        val_method: sampling method passed to ``sample_val`` for the val split.
            ``'LZO'`` samples with ``'RANDOM'``, using the per-class counts from
            ``get_class_num_counts`` with the training imbalance factor.
        **kwargs: ``val_num_per_class`` overrides the default per-class
            validation size (500 for CIFAR-10, 30 for CIFAR-100) for non-LZO
            validation sampling.

    Returns:
        The requested dataset, optionally wrapped in ``DatasetWrapper``.

    Raises:
        ValueError: if ``dst_name`` is not a supported dataset.
    """
    # Per-dataset settings. Path getters are stored uncalled so that only the
    # path needed by the requested split is resolved.
    if dst_name == 'im_cifar10':
        dataset_cls = ImbalanceCIFAR10
        num_class = 10
        train_imb_factor = 0.01
        default_val_num = 500  # an earlier version used 300
        get_root = PC.get_cifar10_dataset_path
        get_pool = PC.get_cifar10_data_pool_path
    elif dst_name == 'im_cifar100':
        dataset_cls = ImbalanceCIFAR100
        num_class = 100
        train_imb_factor = 0.1
        default_val_num = 30
        get_root = PC.get_cifar100_dataset_path
        get_pool = PC.get_cifar100_data_pool_path
    else:
        # Previously fell through and crashed with UnboundLocalError on `dst`.
        raise ValueError('Unsupported dataset: ' + str(dst_name))

    if split == 'train':
        dst = dataset_cls(root=get_root(), train=True, download=True,
                          imb_factor=train_imb_factor, rand_number=rand_number)
    elif split == 'val':
        if val_method == 'LZO':
            # Per-class validation counts follow the same exponential profile
            # (get_class_num_counts) as the training imbalance factor.
            val_index_dict = sample_val(num_class=num_class, method='RANDOM',
                                        val_num_per_class=get_class_num_counts(dst_name, num_class, train_imb_factor),
                                        random_seed=rand_number)
        else:
            val_index_dict = sample_val(num_class=num_class, method=val_method,
                                        val_num_per_class=kwargs.get('val_num_per_class', default_val_num),
                                        random_seed=rand_number)
        dst = AugmentedDataset(get_pool(), index_dict=val_index_dict)
    else:
        dst = dataset_cls(root=get_root(), train=False, download=True,
                          imb_factor=1.0, rand_number=rand_number)

    if is_wrapper:
        dst = DatasetWrapper(dst)

    return dst


from torch.utils.data import  Dataset
import numpy as np
from PIL import Image


class BaseDataset(Dataset):
    """In-memory image dataset backed by parallel ``data`` / ``targets`` sequences.

    ``data`` and ``targets`` start as None. Subclasses, or code that builds
    views of another dataset, assign them after construction. Each item of
    ``data`` is converted with ``Image.fromarray`` on access.

    Args:
        transform: optional callable applied to the PIL image in ``__getitem__``.
    """

    def __init__(self, transform=None):
        self.transform = transform
        # Filled in later by whoever populates this dataset.
        self.data = None
        self.targets = None

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``, applying ``transform`` if set."""
        pixels = self.data[index]
        label = self.targets[index]
        sample = Image.fromarray(pixels)
        if self.transform is None:
            return sample, label
        return self.transform(sample), label

    def __len__(self):
        """Number of samples, i.e. the length of ``data``."""
        samples = self.data
        return len(samples)


class AugmentedDataset(BaseDataset):
    """Image dataset loaded from a directory tree with one folder per class.

    Expected layout::

        file_path/
            0/  0.jpg  1.jpg ...
            1/  0.jpg ...

    Folder names are integer class labels and file stems are integer indices.
    All selected images are decoded eagerly into ``self.data`` (stacked with
    ``np.stack``), and their labels go into ``self.targets`` (int array).

    Args:
        file_path: root directory containing the per-class folders.
        transform: optional callable applied to each PIL image on access.
        index_dict: optional mapping ``{class_label: iterable of indices}``.
            When given, only ``<index>.jpg`` for the listed indices is loaded
            from each class folder; every folder under ``file_path`` must have
            an entry.
    """

    def __init__(self, file_path, transform=None, index_dict=None):
        super().__init__(transform)
        self.__read_data_from_path(file_path, index_dict)

    def __read_data_from_path(self, path, index_dict=None):
        """Load the selected images under ``path`` into ``data`` / ``targets``.

        Raises:
            ValueError: if no image is selected (nothing to stack).
        """
        data = []
        targets = []
        for folder in sorted(os.listdir(path), key=int):
            label = int(folder)
            folder_path = os.path.join(path, folder)
            if index_dict is not None:
                # Explicit selection: the directory listing is not needed.
                files = [str(ind) + '.jpg' for ind in index_dict[label]]
            else:
                files = sorted(os.listdir(folder_path), key=lambda x: int(x.split('.')[0]))
            for f in files:
                # Context manager closes the file handle once the pixels are
                # copied into a numpy array.
                with Image.open(os.path.join(folder_path, f)) as img:
                    data.append(np.array(img))
                targets.append(label)
        if not data:
            raise ValueError(f'No images found under {path}')
        self.data = np.stack(data, axis=0)
        self.targets = np.asarray(targets, dtype=int)


class DatasetWrapper(Dataset):
    """Index-based view over another dataset that also groups samples by label.

    Args:
        dataset: the wrapped dataset. Its ``__getitem__`` must return
            ``(sample, label)``. ``get_dataset_by_class`` and
            ``get_dataset_by_indexes`` additionally read ``dataset.data``
            (indexed with a list of indexes) and ``dataset.targets``;
            ``get_dataset_by_class`` also reads ``dataset.transform``.
        indexset: positions of ``dataset`` exposed by this wrapper, in order.
            Defaults to every index of ``dataset``.

    Attributes:
        class_split_indexes: dict mapping label -> list of underlying dataset
            indexes with that label.
        label_list: label of each wrapped sample, in ``indexset`` order.

    Construction reads every wrapped sample once to collect its label, then
    prints the per-class sample counts.
    """

    def __init__(self, dataset, indexset: list = None):
        self.dataset = dataset
        # Explicit default instead of relying on __len__'s side effect.
        self.indexset = indexset if indexset is not None else list(range(len(dataset)))
        self.class_split_indexes = {}
        self.label_list = []
        self._set_class_split_indexes()
        self.print_class_proportion()

    def print_class_proportion(self):
        """Print the number of samples per label, ordered by sorted label."""
        counts = [len(self.class_split_indexes[label])
                  for label in sorted(self.class_split_indexes)]
        print(np.asarray(counts))

    def _set_class_split_indexes(self):
        """Populate ``class_split_indexes`` and ``label_list`` from ``indexset``."""
        for i in self.indexset:
            # Labels come from the wrapped __getitem__, so any per-sample
            # transform the wrapped dataset applies may also run here.
            _, label = self.dataset[i]
            self.class_split_indexes.setdefault(label, []).append(i)
            self.label_list.append(label)

    def __len__(self):
        """Number of wrapped samples."""
        # Guard kept so the wrapper still works if ``indexset`` is set to None
        # after construction.
        if self.indexset is None:
            self.indexset = list(range(len(self.dataset)))
        return len(self.indexset)

    def __getitem__(self, indx):
        """Return the wrapped sample at position ``indx`` of the index set."""
        return self.dataset[self.indexset[indx]]

    def _make_subset(self, indexes, transform):
        """Build a BaseDataset holding ``data``/``targets`` at underlying ``indexes``."""
        dst = BaseDataset(transform)
        dst.data = self.dataset.data[indexes]
        dst.targets = [self.dataset.targets[i] for i in indexes]
        return dst

    def get_dataset_by_class(self, label):
        """Return a BaseDataset of all samples with ``label``.

        The result reuses the wrapped dataset's transform. Raises KeyError if
        ``label`` never occurred in the index set.
        """
        return self._make_subset(self.class_split_indexes[label], self.dataset.transform)

    def get_dataset_by_indexes(self, indexes, transform=None):
        """Return a BaseDataset of the given underlying dataset ``indexes``.

        ``indexes`` refer to positions in the wrapped dataset, not in this
        wrapper's index set. ``transform`` (default None) is applied on access.
        """
        return self._make_subset(indexes, transform)

    def get_label_list(self):
        """Return the labels of the wrapped samples in ``indexset`` order."""
        return self.label_list
