from typing import Tuple
from collections import defaultdict, Counter

import numpy as np 
import torch.nn.functional as F
import torch.optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from backbone.ResNet18 import resnet18
from PIL import Image
from torchvision.datasets import CIFAR100

from datasets.transforms.denormalization import DeNormalize
from datasets.utils.continual_dataset import ContinualDataset
from datasets.utils.validation import get_train_val
from utils.conf import base_path


# Lookup tables for the CIFAR-100 label hierarchy.
# fine_dict: fine label index (0-99) -> human-readable class name.
fine_dict = {19: 'cattle', 29: 'dinosaur', 0: 'apple', 11: 'boy', 1: 'aquarium_fish', 86: 'telephone', 90: 'train', 28: 'cup', 23: 'cloud', 31: 'elephant', 39: 'keyboard', 96: 'willow_tree', 82: 'sunflower', 17: 'castle', 71: 'sea', 8: 'bicycle', 97: 'wolf', 80: 'squirrel', 74: 'shrew', 59: 'pine_tree', 70: 'rose', 87: 'television', 84: 'table', 64: 'possum', 52: 'oak_tree', 42: 'leopard', 47: 'maple_tree', 65: 'rabbit', 21: 'chimpanzee', 22: 'clock', 81: 'streetcar', 24: 'cockroach', 78: 'snake', 45: 'lobster', 49: 'mountain', 56: 'palm_tree', 76: 'skyscraper', 89: 'tractor', 73: 'shark', 14: 'butterfly', 9: 'bottle', 6: 'bee', 20: 'chair', 98: 'woman', 36: 'hamster', 55: 'otter', 72: 'seal', 43: 'lion', 51: 'mushroom', 35: 'girl', 83: 'sweet_pepper', 33: 'forest', 27: 'crocodile', 53: 'orange', 92: 'tulip', 50: 'mouse', 15: 'camel', 18: 'caterpillar', 46: 'man', 75: 'skunk', 38: 'kangaroo', 66: 'raccoon', 77: 'snail', 69: 'rocket', 95: 'whale', 99: 'worm', 93: 'turtle', 4: 'beaver', 61: 'plate', 94: 'wardrobe', 68: 'road', 34: 'fox', 32: 'flatfish', 88: 'tiger', 67: 'ray', 30: 'dolphin', 62: 'poppy', 63: 'porcupine', 40: 'lamp', 26: 'crab', 48: 'motorcycle', 79: 'spider', 85: 'tank', 54: 'orchid', 44: 'lizard', 7: 'beetle', 12: 'bridge', 2: 'baby', 41: 'lawn_mower', 37: 'house', 13: 'bus', 25: 'couch', 10: 'bowl', 57: 'pear', 5: 'bed', 60: 'plain', 91: 'trout', 3: 'bear', 58: 'pickup_truck', 16: 'can'}
# fine2corase_dict: fine label index (0-99) -> coarse (superclass) label index (0-19).
# The "corase" spelling is kept because other code refers to these names.
fine2corase_dict = {19: 11, 29: 15, 0: 4, 11: 14, 1: 1, 86: 5, 90: 18, 28: 3, 23: 10, 31: 11, 39: 5, 96: 17, 82: 2, 17: 9, 71: 10, 8: 18, 97: 8, 80: 16, 74: 16, 59: 17, 70: 2, 87: 5, 84: 6, 64: 12, 52: 17, 42: 8, 47: 17, 65: 16, 21: 11, 22: 5, 81: 19, 24: 7, 78: 15, 45: 13, 49: 10, 56: 17, 76: 9, 89: 19, 73: 1, 14: 7, 9: 3, 6: 7, 20: 6, 98: 14, 36: 16, 55: 0, 72: 0, 43: 8, 51: 4, 35: 14, 83: 4, 33: 10, 27: 15, 53: 4, 92: 2, 50: 16, 15: 11, 18: 7, 46: 14, 75: 12, 38: 11, 66: 12, 77: 13, 69: 19, 95: 0, 99: 13, 93: 15, 4: 0, 61: 3, 94: 6, 68: 9, 34: 12, 32: 1, 88: 8, 67: 1, 30: 0, 62: 2, 63: 12, 40: 5, 26: 13, 48: 18, 79: 13, 85: 19, 54: 2, 44: 15, 7: 7, 12: 9, 2: 14, 41: 19, 37: 9, 13: 18, 25: 6, 10: 3, 57: 4, 5: 6, 60: 10, 91: 1, 3: 8, 58: 18, 16: 3}
# corase2fine_dict: inverse grouping, coarse label index -> list of the fine
# label indices that map to it (in the insertion order of fine2corase_dict).
corase2fine_dict = defaultdict(list)
for k, v in fine2corase_dict.items():
    corase2fine_dict[v].append(k)

def get_superclass_label(label):
    """
    Translate a fine CIFAR-100 label into its superclass label.
    :param label: fine label index; must be a key of ``fine2corase_dict``
    :returns: the coarse label index stored for ``label``
    :raises KeyError: if ``label`` is not present in ``fine2corase_dict``
    """
    superclass = fine2corase_dict[label]
    return superclass

class TCIFAR100(CIFAR100):
    """
    CIFAR-100 wrapper used as a workaround to avoid printing the
    "already downloaded" messages.

    The ``download`` argument is accepted for signature compatibility but is
    not used: a download is requested only when the integrity check of the
    files under ``root`` fails.
    """

    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False) -> None:
        # ``root`` is assigned up front because the integrity check runs
        # before the parent constructor has had a chance to set it.
        self.root = root
        needs_download = not self._check_integrity()
        super().__init__(root, train, transform, target_transform,
                         download=needs_download)

class MyCIFAR100(CIFAR100):
    """
    Overrides the CIFAR100 dataset to change the getitem function, so that
    each sample also carries a non-augmented version of the image.
    """
    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False) -> None:
        """
        :param root: directory holding (or receiving) the CIFAR-100 files
        :param train: forwarded to the parent dataset to select the split
        :param transform: optional transform applied to the PIL image
        :param target_transform: optional transform applied to the target
        :param download: accepted for compatibility but not used; a download
            is requested only when the integrity check fails
        """
        # Only converts to a tensor: no augmentation and no normalization.
        self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
        # Assigned before the parent constructor runs because the integrity
        # check below is evaluated first.
        self.root = root
        super(MyCIFAR100, self).__init__(root, train, transform, target_transform, not self._check_integrity())

    def __getitem__(self, index: int) -> tuple:
        """
        Gets the requested element from the dataset.
        :param index: index of the element to be returned
        :returns: tuple ``(img, target, not_aug_img)``; when the instance has
            a ``logits`` attribute, ``self.logits[index]`` is appended as a
            fourth element. ``img`` is the output of ``self.transform`` (or
            the plain PIL image when no transform is set), ``target`` is the
            class index after the optional ``target_transform`` and
            ``not_aug_img`` is the output of ``self.not_aug_transform``.
        """
        img, target = self.data[index], self.targets[index]

        # to return a PIL Image
        img = Image.fromarray(img, mode='RGB')

        # Taken before the (possibly random) transform is applied, so this is
        # always the original picture.
        not_aug_img = self.not_aug_transform(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        # ``logits`` is never set in this class; it is only returned when
        # some other code has attached it to the instance.
        if hasattr(self, 'logits'):
            return img, target, not_aug_img, self.logits[index]

        return img, target, not_aug_img


class CIFAR20TrainDataset(Dataset):
    """
    Training dataset over an array of images and their fine labels that
    yields the superclass label (see ``get_superclass_label``) as target,
    together with a non-augmented tensor version of each image.
    """

    def __init__(self, data, targets, transform=None, target_transform=None) -> None:
        """
        :param data: indexable collection of images (one array per sample)
        :param targets: indexable collection of fine labels, aligned with ``data``
        :param transform: optional transform applied to the PIL image
        :param target_transform: optional transform applied to the fine label
            before it is mapped to its superclass
        """
        # Only converts to a tensor: no augmentation and no normalization.
        self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
        self.data = data
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _as_uint8(img):
        """Return ``img`` as a ``uint8`` array without corrupting 8-bit data."""
        arr = np.asarray(img)
        if np.issubdtype(arr.dtype, np.floating):
            # Floating-point samples are taken to be scaled to [0, 1].
            return np.uint8(255 * arr)
        # Integer samples already hold 0-255 intensities. Multiplying them by
        # 255 would wrap around modulo 256 and negate the image.
        return arr.astype(np.uint8, copy=False)

    def __getitem__(self, index):
        """
        Gets the requested element from the dataset.
        :param index: index of the element to be returned
        :returns: tuple ``(img, superclass_label, not_aug_img)``
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(self._as_uint8(img))

        # Taken before the (possibly random) transform is applied.
        not_aug_img = self.not_aug_transform(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, get_superclass_label(target), not_aug_img

    def __len__(self):
        """Number of samples held by the dataset."""
        return len(self.data)

class CIFAR20TestDataset(Dataset):
    """
    Evaluation dataset over an array of images and their fine labels that
    yields the superclass label (see ``get_superclass_label``) as target.
    """

    def __init__(self, data, targets, transform=None, target_transform=None) -> None:
        """
        :param data: indexable collection of images (one array per sample)
        :param targets: indexable collection of fine labels, aligned with ``data``
        :param transform: optional transform applied to the PIL image
        :param target_transform: optional transform applied to the fine label
            before it is mapped to its superclass
        """
        self.data = data
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _as_uint8(img):
        """Return ``img`` as a ``uint8`` array without corrupting 8-bit data."""
        arr = np.asarray(img)
        if np.issubdtype(arr.dtype, np.floating):
            # Floating-point samples are taken to be scaled to [0, 1].
            return np.uint8(255 * arr)
        # Integer samples already hold 0-255 intensities. Multiplying them by
        # 255 would wrap around modulo 256 and negate the image.
        return arr.astype(np.uint8, copy=False)

    def __getitem__(self, index):
        """
        Gets the requested element from the dataset.
        :param index: index of the element to be returned
        :returns: tuple ``(img, superclass_label)``
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(self._as_uint8(img))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, get_superclass_label(target)

    def __len__(self):
        """Number of samples held by the dataset."""
        return len(self.data)

class GenCIFAR100(ContinualDataset):
    """
    CIFAR-100 split into ``N_TASKS`` tasks of consecutive fine classes.

    Each task is served through ``CIFAR20TrainDataset`` /
    ``CIFAR20TestDataset``, so the targets seen by the model are the
    superclass labels, and the backbone has ``N_CLASSES`` outputs.
    """

    NAME = 'gen-cifar100'
    SETTING = 'general-continual'
    N_CLASSES = 20
    N_TASKS = 20
    N_CLASSES_PER_TASK = -1

    INPUT_SIZE = (3, 32, 32)
    TRANSFORM = transforms.Compose(
            [transforms.RandomCrop(32, padding=4),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize((0.5071, 0.4867, 0.4408),
                                  (0.2675, 0.2565, 0.2761))])

    def __init__(self, args):
        """
        Builds the task/class assignment and loads the underlying datasets.
        :param args: run arguments; ``args.validation`` and ``args.batch_size``
            are read by this class
        """
        # Task t owns the fine labels [t * k, (t + 1) * k) with k = 100 // N_TASKS.
        classes_per_task = 100 // self.N_TASKS
        self.CLASSES_PER_TASK = [
            [t * classes_per_task + i for i in range(classes_per_task)]
            for t in range(self.N_TASKS)]
        super(GenCIFAR100, self).__init__(args)
        transform = self.TRANSFORM

        test_transform = transforms.Compose(
            [transforms.ToTensor(), self.get_normalization_transform()])

        self.train_dataset = MyCIFAR100(base_path() + 'CIFAR100', train=True,
                                  download=True, transform=transform)
        if self.args.validation:
            # Evaluation data is carved out of the training set.
            self.train_dataset, self.test_dataset = get_train_val(self.train_dataset,
                                                    test_transform, self.NAME)
        else:
            self.test_dataset = TCIFAR100(base_path() + 'CIFAR100',train=False,
                                   download=True, transform=test_transform)


    def get_examples_number(self):
        """
        :returns: number of samples in the full CIFAR-100 training split
        """
        # Uses the same 'CIFAR100' directory as ``__init__``. Pointing a
        # CIFAR-100 dataset at another folder would make it download the
        # archive a second time.
        train_dataset = MyCIFAR100(base_path() + 'CIFAR100', train=True,
                                  download=True)
        return len(train_dataset.data)

    def get_data_loaders(self):
        """
        :returns: ``(train_loader, test_loader)`` for the current task
        """
        return self.store_masked_loaders()

    def store_masked_loaders(self):
        """
        Builds the loaders restricted to the fine classes of the current task
        (``self.CLASSES_PER_TASK[self.i]``), records them and advances to the
        next task.
        :returns: ``(train_loader, test_loader)`` for the current task
        """
        # NOTE(review): ``self.i`` and ``self.test_loaders`` are not set in
        # this class; presumably ContinualDataset.__init__ initialises them.
        task_classes = self.CLASSES_PER_TASK[self.i]

        all_train_targets = np.array(self.train_dataset.targets)
        all_test_targets = np.array(self.test_dataset.targets)

        # Boolean masks selecting the samples whose fine label is in the task.
        train_mask = np.isin(all_train_targets, task_classes)
        test_mask = np.isin(all_test_targets, task_classes)

        train_data = self.train_dataset.data[train_mask]
        test_data = self.test_dataset.data[test_mask]

        train_targets = all_train_targets[train_mask]
        test_targets = all_test_targets[test_mask]

        split_train_dataset = CIFAR20TrainDataset(train_data, train_targets, transform=self.train_dataset.transform)
        split_test_dataset = CIFAR20TestDataset(test_data, test_targets, transform=self.test_dataset.transform)

        train_loader = DataLoader(split_train_dataset,
                                batch_size=self.args.batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(split_test_dataset,
                                batch_size=self.args.batch_size, shuffle=False, num_workers=4)
        self.test_loaders.append(test_loader)
        self.train_loader = train_loader

        self.i += 1
        return train_loader, test_loader


    @staticmethod
    def get_transform():
        """
        :returns: ``ToPILImage`` followed by the training ``TRANSFORM``
        """
        transform = transforms.Compose(
            [transforms.ToPILImage(), GenCIFAR100.TRANSFORM])
        return transform

    @staticmethod
    def get_aug_transform():
        """
        :returns: ``ToPILImage`` followed by ``AUG_TRANSFORM``
        """
        # NOTE(review): ``AUG_TRANSFORM`` is not defined in this class; unless
        # ContinualDataset provides it, this raises AttributeError. Confirm.
        transform = transforms.Compose(
            [transforms.ToPILImage(), GenCIFAR100.AUG_TRANSFORM])
        return transform

    @staticmethod
    def get_backbone():
        """
        :returns: a ResNet-18 built with ``N_CLASSES`` as its argument
        """
        return resnet18(GenCIFAR100.N_CLASSES)

    @staticmethod
    def get_loss():
        """
        :returns: the loss function (cross entropy)
        """
        return F.cross_entropy

    @staticmethod
    def get_normalization_transform():
        """
        :returns: the per-channel normalization also used in ``TRANSFORM``
        """
        transform = transforms.Normalize((0.5071, 0.4867, 0.4408),
                                         (0.2675, 0.2565, 0.2761))
        return transform

    @staticmethod
    def get_denormalization_transform():
        """
        :returns: ``DeNormalize`` built with the same mean/std as the
            normalization transform
        """
        transform = DeNormalize((0.5071, 0.4867, 0.4408),
                                (0.2675, 0.2565, 0.2761))
        return transform

    @staticmethod
    def get_epochs():
        """
        :returns: default number of epochs
        """
        return 50

    @staticmethod
    def get_batch_size():
        """
        :returns: default batch size
        """
        return 32

    @staticmethod
    def get_minibatch_size():
        """
        :returns: default minibatch size, equal to the default batch size
        """
        return GenCIFAR100.get_batch_size()

    @staticmethod
    def get_scheduler(model, args):
        """
        :returns: ``None``, i.e. no learning-rate scheduler is provided
        """
        return None

