import numpy as np
from PIL import Image
import torch

from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler


def subset_cifar100(cifar100_dataset, keep=None):
    """Restrict a CIFAR-100 dataset, in place, to a subset of fine classes.

    Args:
        cifar100_dataset: dataset exposing ``data`` (a NumPy array that can be
            indexed by a boolean mask) and ``targets`` (sequence of int
            labels). Both attributes are overwritten.
        keep: fine-label ids to retain. Defaults to ``[72, 4, 95, 30, 55]``.
            Larger subsets used in earlier experiments were
            ``[72, 4, 95, 30, 55, 73, 32, 67, 91]`` and
            ``[72, 4, 95, 30, 55, 73, 32, 67, 91, 1, 92, 70, 82, 54, 62]``.

    Returns:
        The same dataset object, with ``targets`` converted to a NumPy array
        holding only the kept labels (original label values, not remapped).
    """
    if keep is None:
        keep = [72, 4, 95, 30, 55]

    labels = np.asarray(cifar100_dataset.targets)
    mask = np.isin(labels, keep)

    cifar100_dataset.data = cifar100_dataset.data[mask]
    cifar100_dataset.targets = labels[mask]
    return cifar100_dataset


def subset_cifar10(cifar100_dataset, keep=None):
    """Restrict a CIFAR-10 dataset, in place, to a subset of classes.

    Args:
        cifar100_dataset: the CIFAR-10 dataset to filter. The parameter keeps
            its historical name so keyword callers keep working. It must
            expose ``data`` (a NumPy array indexable by a boolean mask) and
            ``targets`` (sequence of int labels). Both are overwritten.
        keep: class ids to retain. Defaults to ``[0, 1, 8, 9]``
            (plane, car, ship, truck).

    Returns:
        The same dataset object, with ``targets`` converted to a NumPy array
        holding only the kept labels (original label values, not remapped).
    """
    if keep is None:
        keep = [0, 1, 8, 9]

    labels = np.asarray(cifar100_dataset.targets)
    mask = np.isin(labels, keep)

    cifar100_dataset.data = cifar100_dataset.data[mask]
    cifar100_dataset.targets = labels[mask]
    return cifar100_dataset


def pretrain_remap(cifar100_dataset, keep=None):
    """Filter a dataset to ``keep`` classes and relabel them to ``0..len(keep)-1``.

    The new label of each sample is the position of its original label in
    ``keep``.

    Args:
        cifar100_dataset: dataset exposing ``data`` (a NumPy array indexable
            by a boolean mask) and ``targets`` (sequence of int labels). Both
            attributes are overwritten in place.
        keep: original label ids to retain, in the order that defines the new
            labels. Defaults to ``[72, 4, 95, 30, 55, 73, 32, 67, 91, 1]``.

    Returns:
        The same dataset object. ``targets`` becomes a Python ``list`` of
        ints.
    """
    if keep is None:
        keep = [72, 4, 95, 30, 55, 73, 32, 67, 91, 1]
    mapper = {label: new_label for new_label, label in enumerate(keep)}

    labels = np.asarray(cifar100_dataset.targets)
    mask = np.isin(labels, keep)

    cifar100_dataset.data = cifar100_dataset.data[mask]
    # .tolist() yields plain ints, so dict lookups do not depend on NumPy scalar hashing.
    cifar100_dataset.targets = [mapper[label] for label in labels[mask].tolist()]
    return cifar100_dataset


# def superclass_map():
#     coarse2fine = {
#         0: [72, 4, 95, 30, 55],
#         1: [73, 32, 67, 91, 1],
#         2: [92, 70, 82, 54, 62],
#         3: [16, 61, 9, 10, 28],
#         4: [51, 0, 53, 57, 83],
#         5: [40, 39, 22, 87, 86],
#         6: [20, 25, 94, 84, 5],
#         7: [14, 24, 6, 7, 18],
#         8: [43, 97, 42, 3, 88],
#         9: [37, 17, 76, 12, 68],
#         10: [49, 33, 71, 23, 60],
#         11: [15, 21, 19, 31, 38],
#         12: [75, 63, 66, 64, 34],
#         13: [77, 26, 45, 99, 79],
#         14: [11, 2, 35, 46, 98],
#         15: [29, 93, 27, 78, 44],
#         16: [65, 50, 74, 36, 80],
#         17: [56, 52, 47, 59, 96],
#         18: [8, 58, 90, 13, 48],
#         19: [81, 69, 41, 89, 85]
#     }

#     fine2coarse = {}
#     for k, v in coarse2fine.items():
#         for ix in v:
#             fine2coarse[ix] = k
#     return coarse2fine, fine2coarse


def superclass_map():
    """Return the two-way CIFAR-10 superclass grouping.

    CIFAR-10 class order is
    ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck').
    Group 0 holds the vehicle classes and group 1 the animal classes.

    Returns:
        (coarse2fine, fine2coarse): ``coarse2fine`` maps a group id to its list
        of class ids; ``fine2coarse`` is the inverse lookup from class id to
        group id.
    """
    vehicles = [0, 1, 8, 9]
    animals = [2, 3, 4, 5, 6, 7]
    coarse2fine = {0: vehicles, 1: animals}

    fine2coarse = {
        fine: coarse
        for coarse, members in coarse2fine.items()
        for fine in members
    }
    return coarse2fine, fine2coarse


# def examplar_classes():
#     examplars = [72, 73, 92, 16, 51, 40, 20, 14,
#                 43, 37, 49, 15, 75, 77, 11, 29,
#                 65, 56, 8, 81]
#     return examplars

def examplar_classes():
    """Return the CIFAR-10 class ids treated as exemplars.

    Class 1 ('car') falls in the vehicle group and class 4 ('deer') in the
    animal group of the CIFAR-10 superclass split. A fresh list is returned
    on every call.
    """
    return [1, 4]


class ClustCIFAR(Dataset):
    """Pair-sampling wrapper around a CIFAR-style dataset.

    Each item is ``(img1, img2, target)``. ``target`` is a scalar distance
    derived from the two images' class labels by :meth:`class2dist`.

    Train mode: the second image is drawn uniformly at random, using the
    global NumPy RNG, on every ``__getitem__`` call.
    Test mode: pairs are fixed once at construction from a permutation drawn
    with ``np.random.RandomState(111)``, so evaluation pairs are reproducible.

    Args:
        cifar_dataset: wrapped dataset. It must expose ``data`` (images
            accepted by ``PIL.Image.fromarray``), ``targets``, ``train``,
            ``transform`` and ``__len__``.
        dist_params: optional mapping whose entries override
            ``_DEFAULT_DIST_PARAMS``. Keys it omits fall back to the defaults.
            ``None`` or an empty mapping means "use the defaults".
    """

    # Target distance for each kind of label relationship (see class2dist).
    _DEFAULT_DIST_PARAMS = {
        'same': 0.5,
        'same_super': 8,
        'to_super': 2,
        'between_super': 8,
        'diff_super': 16,
        'to_diff_super': 16
    }

    def __init__(self, cifar_dataset, dist_params=None):
        # Build a fresh dict per instance. This avoids a shared mutable
        # default, and partial overrides no longer raise KeyError later in
        # class2dist.
        self.dist_params = {**self._DEFAULT_DIST_PARAMS, **(dist_params or {})}

        self.cifar_dataset = cifar_dataset
        self.train = cifar_dataset.train
        self.transform = cifar_dataset.transform

        self.coarse2fine, self.fine2coarse = superclass_map()
        self.examplars = examplar_classes()

        self.targets = torch.tensor(cifar_dataset.targets)
        self.data = cifar_dataset.data

        if not self.train:
            self.test_pairs, self.div_target = self._build_test_pairs()

    def _build_test_pairs(self):
        """Return fixed ``(pairs, targets)`` for evaluation.

        Sample ``i`` is paired with ``perm[i]``, where ``perm`` is a
        permutation drawn from a fixed seed. ``targets[i]`` is the
        class2dist value for that pair.
        """
        perm = np.random.RandomState(111).permutation(len(self.data))
        pairs = list(enumerate(perm))
        # One bulk conversion instead of a per-element tensor .item() call.
        labels = self.targets.tolist()
        targets = [self.class2dist(labels[i], labels[j]) for i, j in pairs]
        return pairs, targets

    def class2dist(self, label1, label2):
        """Map a pair of class labels to a target distance from ``dist_params``.

        Rules are checked in this order:
          1. Same label -> 'same'.
          2. Both labels are exemplars -> 'between_super'.
          3. Exactly one is an exemplar -> 'to_super' if both labels share a
             superclass, else 'to_diff_super'.
          4. Neither is an exemplar -> 'same_super' if both labels share a
             superclass, else 'diff_super'.

        Raises:
            KeyError: if either label is missing from ``fine2coarse``. The
                lookup happens before the same-label check.
        """
        super1 = self.fine2coarse[label1]
        super2 = self.fine2coarse[label2]

        if label1 == label2:
            return self.dist_params['same']

        is_ex1 = label1 in self.examplars
        is_ex2 = label2 in self.examplars
        same_super = super1 == super2

        if is_ex1 and is_ex2:
            return self.dist_params['between_super']
        if is_ex1 or is_ex2:
            return self.dist_params['to_super' if same_super else 'to_diff_super']
        return self.dist_params['same_super' if same_super else 'diff_super']

    def __getitem__(self, index):
        """Return ``(img1, img2, target)`` for sample ``index``.

        Images pass through ``self.transform`` when it is not None.
        """
        if self.train:
            # The partner index is redrawn on every call.
            index2 = np.random.choice(len(self.data))
            img1, img2 = self.data[index], self.data[index2]
            target = self.class2dist(self.targets[index].item(),
                                     self.targets[index2].item())
        else:
            index1, index2 = self.test_pairs[index]
            img1, img2 = self.data[index1], self.data[index2]
            target = self.div_target[index]

        img1 = Image.fromarray(img1)
        img2 = Image.fromarray(img2)
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return img1, img2, target

    def __len__(self):
        return len(self.cifar_dataset)
