from PIL import Image
from itertools import chain
from copy import deepcopy
import numpy as np
import torch
import clip
from torch.utils.data.dataset import Subset
from torchvision.datasets import CIFAR100
from custom_imagefolder import ImageFolder

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Only CLIP's image preprocessing pipeline is kept; the model itself is discarded.
_, TRANSFORM = clip.load('ViT-B/32', device)


class CIFAR100_(CIFAR100):
    """
    CIFAR-100 whose ``__getitem__`` returns ``(img, target, full_labels, name)``.

    ``names`` (one class name per sample) and ``full_labels`` (one row per
    sample) are normally assigned by ``StandardCL.make_dataset``. When they are
    missing or None -- e.g. on a freshly constructed dataset -- ``__getitem__``
    falls back to ``self.classes[target]`` and ``np.array([target, target])``
    instead of raising.
    """
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
    meta = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        # super(CIFAR100, self) starts the MRO lookup *after* CIFAR100, i.e. it
        # runs the __init__ of CIFAR100's base class (torchvision's CIFAR10).
        super(CIFAR100, self).__init__(root, train, transform, target_transform, download)
        # Per-sample label rows; assigned by StandardCL.make_dataset, None until then.
        self.full_labels = None

        self.targets = np.array(self.targets)

    def __getitem__(self, index):
        """Return ``(img, target, full_labels, name)`` for sample ``index``."""
        img, target = self.data[index], self.targets[index]

        # Per-sample metadata. ``names`` is never set in __init__, so fall back
        # to values derived from the raw class index when it is not populated.
        names = getattr(self, 'names', None)
        name = names[index] if names is not None else self.classes[target]
        full_labels_all = getattr(self, 'full_labels', None)
        if full_labels_all is not None:
            full_labels = full_labels_all[index]
        else:
            full_labels = np.array([target, target])

        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, full_labels, name


class CIFAR10_(CIFAR100_):
    """CIFAR-10 flavour of ``CIFAR100_``.

    Only the download/file metadata differs; item access is inherited.
    """
    # Download location and archive checksum.
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'

    # Pickled batches as [file name, md5] pairs.
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]
    test_list = [['test_batch', '40351d587109b95175f43aff81a1287e']]

    # Metadata file holding the class names.
    meta = dict(
        filename='batches.meta',
        key='label_names',
        md5='5ff9c542aee3614f3951f8cda6e48888',
    )

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        # Resolve __init__ past CIFAR100 in the MRO, so the base CIFAR loader
        # runs with the CIFAR-10 file lists declared on this class.
        super(CIFAR100, self).__init__(root, train, transform, target_transform, download)
        self.targets = np.array(self.targets)
        # Placeholder: None -> np.array for concept shift.
        self.full_labels = None

def generate_random_cl(args):
    """Split a dataset's classes into ``args.n_tasks`` equally sized tasks.

    Args:
        args: namespace providing ``dataset`` ('cifar100', 'cifar10' or
            'imagenet'), ``n_tasks`` and ``seed``. When ``seed`` is non-zero the
            class order is shuffled with NumPy's global RNG, which this function
            does not seed itself. ``seed == 0`` keeps the natural class order.

    Returns:
        A list with one ``[names_list, class_ids]`` pair per task.

    Raises:
        NotImplementedError: for an unsupported dataset.
        ValueError: if ``n_tasks`` is not a positive divisor of the class count.
    """
    if args.dataset == 'cifar100':
        total_n_cls = 100
        from label_names import fine_label
    elif args.dataset == 'cifar10':
        total_n_cls = 10
        from label_names import cifar10_labels as fine_label
    elif args.dataset == 'imagenet':
        total_n_cls = 1000
        # ImageNet classes are named by their index.
        fine_label = [str(i) for i in range(1000)]
    else:
        raise NotImplementedError("dataset not implemented")

    n_tasks = args.n_tasks
    # Fail with a clear message instead of an opaque reshape/ZeroDivisionError.
    if n_tasks <= 0 or total_n_cls % n_tasks:
        raise ValueError(
            f"n_tasks={n_tasks} must be a positive divisor of the "
            f"{total_n_cls} classes in {args.dataset}")
    n_cls = total_n_cls // n_tasks  # number of classes per task

    seq = np.arange(total_n_cls)
    if args.seed != 0:
        np.random.shuffle(seq)

    # Each row of the reshaped order is one task's class ids.
    return [[[fine_label[c] for c in row], list(row)]
            for row in seq.reshape(n_tasks, n_cls)]

class StandardCL:
    """Class-incremental stream that serves one task's data per ``make_dataset`` call.

    ``task_list`` holds one ``[class_names, class_ids]`` entry per task (the
    format returned by ``generate_random_cl``). Class names receive labels in
    order of first appearance across tasks (``seen_names``), so task ``t``'s
    labels continue from those of earlier tasks.
    """

    def __init__(self, dataset, args, task_list):
        self.dataset = dataset
        self.args = args
        self.seen_names = []  # class names in first-seen order; position == label
        self.task_id = 0      # index of the next task make_dataset will build
        self.task_list = task_list

        # Fraction of each class kept for training; the remainder forms the
        # validation split. None disables the split.
        self.validation = args.validation

        if args.dataset == 'imagenet':
            assert not args.clip_init
            # ImageNet is relabelled in place and served through Subset views
            # of this shared dataset, so its per-sample metadata is created once.
            targets = np.array(self.dataset.targets)
            self.dataset.targets = targets
            self.dataset.full_labels = np.concatenate((targets.reshape(-1, 1),
                                                       targets.reshape(-1, 1)), 1)
            self.dataset.names = targets.tolist()
            self.dataset.targets_relabeled = targets.copy()

    def _relabel_imagenet(self, idx, name, label):
        """Assign ``name``/``label`` to samples ``idx`` of the shared ImageNet dataset."""
        names = self.dataset.names
        for j in idx:
            names[j] = name
        self.dataset.targets_relabeled[idx] = label
        # Both columns of each full_labels row carry the relabelled target.
        self.dataset.full_labels[idx] = label

    @staticmethod
    def _fill_cifar(target, data_parts, target_parts, full_label_parts, name_parts):
        """Overwrite ``target``'s per-sample fields with the concatenated per-class parts."""
        target.data = np.array(list(chain(*data_parts)))
        target.targets = np.array(list(chain(*target_parts)))
        target.full_labels = np.array(list(chain(*full_label_parts)))
        target.names = list(chain(*name_parts))

    def make_dataset(self):
        """Build the data for task ``self.task_id`` and advance to the next task.

        CIFAR: returns deep copies of ``self.dataset`` whose ``data``,
        ``targets``, ``full_labels`` and ``names`` are restricted to the task's
        classes and relabelled via ``seen_names``.
        ImageNet: relabels the task's samples in place on ``self.dataset``
        (``names``, ``targets_relabeled``, ``full_labels``) and returns
        ``Subset`` views of it.

        Returns:
            The task dataset, or ``(train, valid)`` when ``self.validation`` is set.
        """
        dataset_name = self.args.dataset
        is_cifar = dataset_name in ('cifar100', 'cifar10')
        is_imagenet = dataset_name == 'imagenet'
        use_valid = self.validation is not None

        # The ImageNet results are Subset views of self.dataset, so a full copy
        # there would be discarded immediately.
        if not is_imagenet:
            dataset_ = deepcopy(self.dataset)
            if use_valid:
                dataset_valid = deepcopy(self.dataset)

        data_aux, data_aux_valid = [], []
        targets_aux, targets_aux_valid = [], []
        full_targets_aux, full_targets_aux_valid = [], []
        names_aux, names_aux_valid = [], []
        idx_list, idx_list_valid = [], []  # only consumed for ImageNet

        task = self.task_list[self.task_id]
        for name, c in zip(task[0], task[1]):
            if name not in self.seen_names:
                self.seen_names.append(name)
            label = self.seen_names.index(name)
            idx = np.where(self.dataset.targets == c)[0]

            if use_valid:
                np.random.shuffle(idx)
                n_train = int(len(idx) * self.validation)
                idx, idx_valid = idx[:n_train], idx[n_train:]
                idx_list_valid.append(idx_valid)
            idx_list.append(idx)

            if is_cifar:
                # dtype=int: np.int (removed in NumPy 1.24) was an alias of builtin int.
                data_aux.append(self.dataset.data[idx])
                targets_aux.append(np.zeros(len(idx), dtype=int) + label)
                full_targets_aux.append([[label, label] for _ in range(len(idx))])
                names_aux.append([name] * len(idx))

                if use_valid:
                    data_aux_valid.append(self.dataset.data[idx_valid])
                    targets_aux_valid.append(np.zeros(len(idx_valid), dtype=int) + label)
                    full_targets_aux_valid.append([[label, label] for _ in range(len(idx_valid))])
                    names_aux_valid.append([name] * len(idx_valid))

            elif is_imagenet:
                self._relabel_imagenet(idx, name, label)
                if use_valid:
                    self._relabel_imagenet(idx_valid, name, label)

        if is_cifar:
            self._fill_cifar(dataset_, data_aux, targets_aux, full_targets_aux, names_aux)
            # Release the per-class copies before building the validation arrays.
            del data_aux, targets_aux, full_targets_aux, names_aux

            if use_valid:
                self._fill_cifar(dataset_valid, data_aux_valid, targets_aux_valid,
                                 full_targets_aux_valid, names_aux_valid)
                del data_aux_valid, targets_aux_valid, full_targets_aux_valid, names_aux_valid

        elif is_imagenet:
            dataset_ = Subset(self.dataset, np.concatenate(idx_list))
            if use_valid:
                dataset_valid = Subset(self.dataset, np.concatenate(idx_list_valid))

        self.task_id += 1

        if not use_valid:
            return dataset_
        self.args.logger.print(f"******* Validation {self.validation} used *******")
        return dataset_, dataset_valid

def get_data(args):
    """Load the raw ``(train, test)`` datasets for ``args.dataset`` using CLIP's ``TRANSFORM``.

    Args:
        args: namespace with ``dataset`` ('cifar100', 'cifar10' or 'imagenet'),
            ``root`` (data directory) and ``validation``.

    Returns:
        ``(train, test)``. For ImageNet with a truthy ``validation``, ``test`` is
        read from ``ImageNet/train`` instead of ``ImageNet/val``. CIFAR always
        returns its test split.

    Raises:
        NotImplementedError: for an unsupported dataset.
    """
    if args.dataset == 'cifar100':
        train = CIFAR100_(root=args.root, train=True, download=True, transform=TRANSFORM)
        test = CIFAR100_(root=args.root, train=False, download=True, transform=TRANSFORM)
    elif args.dataset == 'cifar10':
        train = CIFAR10_(root=args.root, train=True, download=True, transform=TRANSFORM)
        test = CIFAR10_(root=args.root, train=False, download=True, transform=TRANSFORM)
    elif args.dataset == 'imagenet':
        train = ImageFolder(root=args.root + '/ImageNet/train', transform=TRANSFORM)
        # With validation enabled, "test" comes from the training folder, so
        # the val folder is only scanned when it will actually be used.
        test_dir = '/ImageNet/train' if args.validation else '/ImageNet/val'
        test = ImageFolder(root=args.root + test_dir, transform=TRANSFORM)
    else:
        # Previously fell through to an UnboundLocalError on `train`.
        raise NotImplementedError("dataset not implemented")
    return train, test

