# We thank the authors of Co^2L.
# Their GitHub repository is https://github.com/chaht01/Co2L
# Much of this code is adapted from their repository.

import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import Subset, WeightedRandomSampler, Dataset
import random
import math
from PIL import Image
from dataset.load_tiny_imagenet import get_tiny_imagenet_data, get_tiny_imagenet_test_data

class TwoCropTransform:
    """Wrap a transform so that each call produces two independent outputs.

    Calling the wrapper applies ``transform`` twice to the same input, left
    to right, and returns both results as a two-element list.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        views = []
        for _ in range(2):
            views.append(self.transform(x))
        return views

def infinitify(loader):
    """Yield ``(x, y)`` pairs from ``loader`` forever, restarting it after each pass.

    Each item of ``loader`` is unpacked into exactly two values and re-emitted
    as a tuple, so a loader that yields two-element lists produces tuples.

    Raises:
        ValueError: if a complete pass over ``loader`` yields no items (an
            empty loader, or a one-shot iterator that is already exhausted).
            Without this check the generator would loop forever without
            ever yielding.
    """
    while True:
        produced = False
        for x, y in loader:
            produced = True
            yield x, y
        if not produced:
            raise ValueError('infinitify: loader produced no items')

class Custom_Dataset(Dataset):
    """Map-style dataset over in-memory images with a parallel target sequence.

    Args:
        dataset: indexable collection of images. Each item is passed to
            ``PIL.Image.fromarray``, so presumably these are numpy image
            arrays (TODO confirm dtype/layout against the loader).
        targets: indexable targets, parallel to ``dataset``. Each item must
            provide ``.item()`` (for example a torch tensor element).
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, dataset, targets, transform=None):
        self.dataset, self.targets, self.transform = dataset, targets, transform

    def __len__(self):
        # The image collection defines the dataset size.
        return len(self.dataset)

    def __getitem__(self, ind):
        pil_image = Image.fromarray(self.dataset[ind])
        sample = pil_image if self.transform is None else self.transform(pil_image)
        return sample, self.targets[ind].item()

def set_replay_samples(args, task_number, prev_indexes = None):
    """Select the labeled training-set indexes to keep in the replay memory.

    Two modes:

    * ``prev_indexes is None``: build the memory from scratch out of every
      class in ``range(0, task_number * args.class_per_task)``.
    * otherwise: first shrink ``prev_indexes`` (class-balanced random
      subsampling) to free room, then add exemplars for the classes
      ``range((task_number-1) * class_per_task, task_number * class_per_task)``.

    Only the first ``args.labeled_dataset.num_labeled_per_class`` images of
    each class are candidates. This is the same slice that
    ``get_each_task_dataloader`` uses as the labeled pool. A per-class budget
    that is not an integer is rounded up with probability equal to its
    fractional part (down otherwise), so its expected value is the
    fractional budget.

    Args:
        args: config namespace. Reads ``labeled_dataset_name``,
            ``class_per_task``, ``memory_size`` and
            ``labeled_dataset.num_labeled_per_class``.
        task_number: classes of tasks ``0 .. task_number-1`` are treated as
            observed. Presumably ``task_number`` is the task about to be
            trained; confirm against the caller.
        prev_indexes: current memory contents (indexes into the training
            split), or ``None`` on the first call.

    Returns:
        list of int indexes: the (possibly shrunk) previous memory followed by
        the newly selected exemplars. Per-class counts are printed as a side
        effect.

    Raises:
        ValueError: if ``args.labeled_dataset_name`` is not supported.
    """

    # Only the dataset targets are used below. ``train`` is not passed, so
    # torchvision's default (training) split is loaded, which is the split
    # the replay indexes are used with.
    val_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if args.labeled_dataset_name == 'cifar10':
        val_dataset = datasets.CIFAR10(root="data/cifar10",
                                         transform=val_transform,
                                         download=True)
        val_targets = np.array(val_dataset.targets)

    elif args.labeled_dataset_name == 'cifar100':
        val_dataset = datasets.CIFAR100(root="data/cifar100",
                                         transform=val_transform,
                                         download=True)
        val_targets = np.array(val_dataset.targets)

    elif args.labeled_dataset_name == 'tinyimagenet':
        
        # Synthesized targets: 200 classes x 500 images, in class order
        # (the same labels the Tiny ImageNet loaders in this file build).
        val_targets = np.array(torch.arange(200).repeat_interleave(500))

    else:
        raise ValueError('dataset not supported: {}'.format(args.labeled_dataset_name))

    if prev_indexes is None:
        prev_indexes = []
        observed_classes = list(range(0, task_number * args.class_per_task))
    else:
        # If one task's worth of labeled candidates fits in the free memory,
        # keep every old exemplar. Otherwise shrink the old memory to
        # (task_number - 1) / task_number of the capacity (may be a float).
        if args.labeled_dataset.num_labeled_per_class * args.class_per_task <= (args.memory_size - len(prev_indexes)):
            shrink_size = len(prev_indexes)
        else:
            shrink_size = (task_number - 1) * args.memory_size / task_number
        if len(prev_indexes) > 0:
            unique_cls = np.unique(val_targets[prev_indexes])
            _prev_indexes = prev_indexes
            prev_indexes = []

            for c in unique_cls:
                mask = val_targets[_prev_indexes] == c
                # Even per-class share of shrink_size; p is its fractional
                # part, used for the stochastic rounding below.
                size_for_c = shrink_size / len(unique_cls)
                p = size_for_c - (shrink_size // len(unique_cls))
                if random.random() < p:
                    size_for_c = math.ceil(size_for_c)
                else:
                    size_for_c = math.floor(size_for_c)

                # Keep a random subset (at most size_for_c) of class c's old exemplars.
                prev_indexes += torch.tensor(_prev_indexes)[mask][torch.randperm(mask.sum())[:size_for_c]].tolist()

            print(np.unique(val_targets[prev_indexes], return_counts=True))
        # Only the classes of the most recently finished task are new to memory.
        observed_classes = list(range(max(task_number-1, 0) * args.class_per_task, task_number * args.class_per_task))

    # Nothing observed yet (e.g. task_number == 0 on the first call).
    if len(observed_classes) == 0:
        return prev_indexes

    # Candidates: the first num_labeled_per_class training images of each observed class.
    observed_indexes = []
    for tc in observed_classes:
        observed_indexes += np.where(val_targets == tc)[0][:args.labeled_dataset.num_labeled_per_class].tolist()


    val_observed_targets = val_targets[observed_indexes]
    val_unique_cls = np.unique(val_observed_targets)


    selected_observed_indexes = []
    for c_idx, c in enumerate(val_unique_cls):
        # Remaining budget = min(free memory, one task's labeled candidates)
        # minus what is already selected, split evenly over the classes still
        # to process. The fractional part p drives the stochastic rounding.
        size_for_c_float = ((min(args.memory_size - len(prev_indexes), args.labeled_dataset.num_labeled_per_class * args.class_per_task) - len(selected_observed_indexes)) / (len(val_unique_cls) - c_idx))
        p = size_for_c_float -  ((min(args.memory_size - len(prev_indexes), args.labeled_dataset.num_labeled_per_class * args.class_per_task) - len(selected_observed_indexes)) // (len(val_unique_cls) - c_idx))
        if random.random() < p:
            size_for_c = math.ceil(size_for_c_float)
        else:
            size_for_c = math.floor(size_for_c_float)
        mask = val_targets[observed_indexes] == c
        # Randomly pick (at most) size_for_c candidates of class c.
        selected_observed_indexes += torch.tensor(observed_indexes)[mask][torch.randperm(mask.sum())[:size_for_c]].tolist()
    print(np.unique(val_targets[selected_observed_indexes], return_counts=True))


    return prev_indexes + selected_observed_indexes



def get_each_task_dataloader(args, task_number, replay_indexes):
    """Build the labeled training loader for task ``task_number``.

    The loader mixes the current task's labeled images with the replay-memory
    indexes in ``replay_indexes``. The labeled images are the first
    ``args.labeled_dataset.num_labeled_per_class`` training images of each of
    the task's ``args.class_per_task`` classes. Every image is returned as
    two independently augmented views (``TwoCropTransform``).

    Args:
        args: config namespace. Reads ``labeled_dataset_name``,
            ``labeled_dataset.image_size``,
            ``labeled_dataset.num_labeled_per_class``, ``class_per_task``,
            ``batch_size`` and ``num_workers``.
        task_number: zero-based task index. Task ``t`` owns classes
            ``[t * class_per_task, (t + 1) * class_per_task)``.
        replay_indexes: training-set indexes kept in replay memory. They are
            appended after the current task's indexes. ``None`` is treated as
            empty.

    Returns:
        ``(loader, dataset, indexes)``: a shuffling ``DataLoader`` over the
        ``Subset``, the ``Subset`` itself, and the list of indexes it covers
        (current-task indexes in random order, followed by the replay
        indexes).

    Raises:
        ValueError: if ``args.labeled_dataset_name`` is not supported.
    """
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=args.labeled_dataset.image_size, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ])

    target_classes = list(range(task_number * args.class_per_task, (task_number + 1) * args.class_per_task))
    print(target_classes)

    if args.labeled_dataset_name == 'cifar10':
        _train_dataset = datasets.CIFAR10(root='data/cifar10',
                                          transform=TwoCropTransform(train_transform),
                                          download=True)
    elif args.labeled_dataset_name == 'cifar100':
        _train_dataset = datasets.CIFAR100(root='data/cifar100',
                                           transform=TwoCropTransform(train_transform),
                                           download=True)
    elif args.labeled_dataset_name == 'tinyimagenet':
        # Targets assume the images come in class order: 200 classes x 500 images.
        _train_dataset = Custom_Dataset(
            get_tiny_imagenet_data(),
            torch.arange(200).repeat_interleave(500),
            TwoCropTransform(train_transform)
        )
    else:
        raise ValueError(args.labeled_dataset_name)

    # Convert the targets once, not once per class.
    all_targets = np.array(_train_dataset.targets)

    subset_indexes = []
    for tc in target_classes:
        # The first num_labeled_per_class images of each class form the labeled pool.
        subset_indexes += np.where(all_targets == tc)[0][:args.labeled_dataset.num_labeled_per_class].tolist()
    random.shuffle(subset_indexes)

    if replay_indexes is not None:
        subset_indexes += list(replay_indexes)

    train_loader1_indexes = list(subset_indexes)

    train_dataset1 = Subset(_train_dataset, train_loader1_indexes)
    print('Dataset1 size: {}'.format(len(train_loader1_indexes)))
    # np.unique returns the classes sorted, so the counts are already in class order.
    uk, uc = np.unique(all_targets[train_loader1_indexes], return_counts=True)
    print(uk)
    print(uc)

    train_loader1 = torch.utils.data.DataLoader(
        train_dataset1, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers)

    return train_loader1, train_dataset1, subset_indexes


def _two_crop_train_dataset(name, transform):
    """Return the training split of dataset ``name``, yielding two augmented views per image.

    Raises:
        ValueError: if ``name`` is not 'cifar10', 'cifar100' or 'tinyimagenet'.
    """
    if name == 'cifar10':
        return datasets.CIFAR10(root='data/cifar10',
                                transform=TwoCropTransform(transform),
                                download=True)
    if name == 'cifar100':
        return datasets.CIFAR100(root='data/cifar100',
                                 transform=TwoCropTransform(transform),
                                 download=True)
    if name == 'tinyimagenet':
        # Targets assume the images come in class order: 200 classes x 500 images.
        return Custom_Dataset(
            get_tiny_imagenet_data(),
            torch.arange(200).repeat_interleave(500),
            TwoCropTransform(transform)
        )
    raise ValueError(name)


def get_unlabeled_data_loader(args, labeled_subset):
    """Build the loader over unlabeled images (plus the given labeled indexes).

    Two pools are combined:

    * "main" unlabeled images: for every class of the labeled dataset, the
      training images beyond the first ``num_labeled_per_class``. Those first
      images form the labeled pool in ``get_each_task_dataloader``.
      ``args.num_main_unlabeled`` of them are drawn at random.
    * "peripheral" images: ``args.num_peripheral_unlabeled`` images drawn
      uniformly without replacement from the training split of
      ``args.unlabeled_dataset_name``.

    Every image is returned as two augmented views (``TwoCropTransform``).

    Args:
        args: config namespace. Reads ``labeled_dataset_name``,
            ``unlabeled_dataset_name``, ``unlabeled_dataset.image_size``,
            ``labeled_dataset.num_classes``,
            ``labeled_dataset.num_labeled_per_class``, ``num_main_unlabeled``,
            ``num_peripheral_unlabeled``, ``batch_size`` and ``num_workers``.
        labeled_subset: indexes into the labeled training set that are mixed
            into the loader alongside the main unlabeled images.

    Returns:
        ``(loader, dataset)``:
        ``loader`` is a shuffling ``DataLoader`` over main-unlabeled +
        ``labeled_subset`` + peripheral images.
        ``dataset`` covers the main-unlabeled + peripheral images only.

    Raises:
        ValueError: if either dataset name is not supported. The message
            names the offending dataset.
    """
    # One augmentation, sized by the unlabeled dataset, is used for both
    # pools so that all images have equal size and can share one batch.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=args.unlabeled_dataset.image_size, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ])

    labeled_train = _two_crop_train_dataset(args.labeled_dataset_name, train_transform)
    peripheral_train = _two_crop_train_dataset(args.unlabeled_dataset_name, train_transform)

    # Images after the first num_labeled_per_class of each class lie outside
    # the labeled pool, so they form the main unlabeled pool.
    labeled_targets = np.array(labeled_train.targets)
    main_indexes = []
    for tc in range(args.labeled_dataset.num_classes):
        main_indexes += np.where(labeled_targets == tc)[0][args.labeled_dataset.num_labeled_per_class:].tolist()
    random.shuffle(main_indexes)
    main_indexes = main_indexes[:args.num_main_unlabeled]

    # Sample uniformly without replacement over the whole peripheral split
    # (50000 images for CIFAR, 100000 for Tiny ImageNet).
    peripheral_indexes = np.random.choice(len(peripheral_train.targets), args.num_peripheral_unlabeled, replace=False)

    main_only = Subset(labeled_train, main_indexes)
    main_and_labeled = Subset(labeled_train, main_indexes + list(labeled_subset))
    peripheral = Subset(peripheral_train, peripheral_indexes)

    loader_dataset = torch.utils.data.ConcatDataset([main_and_labeled, peripheral])
    unlabeled_dataset = torch.utils.data.ConcatDataset([main_only, peripheral])

    print('Dataset1 size: {}'.format(len(loader_dataset)))

    train_loader1 = torch.utils.data.DataLoader(
        loader_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers)

    return train_loader1, unlabeled_dataset


def get_test_dataloaders(args, replay_indexes):
    """Build the evaluation loaders for all classes of the labeled dataset.

    Returns a class-balanced training loader over the replay memory and a
    loader over the full test split (every class in
    ``range(args.labeled_dataset.num_classes)``).

    The training loader uses a ``WeightedRandomSampler`` in which every
    sample of class ``c`` has weight ``1 / count(c)``. Each class present in
    the memory is therefore drawn equally often in expectation.
    ``len(replay_indexes)`` samples are drawn per epoch, using the sampler's
    default replacement setting.

    Args:
        args: config namespace. Reads ``labeled_dataset_name``,
            ``labeled_dataset.image_size``, ``labeled_dataset.num_classes``,
            ``batch_size`` and ``num_workers``.
        replay_indexes: indexes into the labeled training split.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: if ``replay_indexes`` is empty, or if
            ``args.labeled_dataset_name`` is not supported.
    """
    # An empty memory would only fail later, inside WeightedRandomSampler.
    if len(replay_indexes) == 0:
        raise ValueError('replay_indexes is empty: no samples to build the train loader from')

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=args.labeled_dataset.image_size, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ])

    val_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    target_classes = list(range(0, args.labeled_dataset.num_classes))
    print(target_classes)

    if args.labeled_dataset_name == 'cifar10':
        _train_dataset = datasets.CIFAR10(
            root='data/cifar10',
            train=True,
            transform=train_transform,
            download=True
        )

        _val_dataset = datasets.CIFAR10(
            root='data/cifar10',
            train=False,
            transform=val_transform,
            download=True
        )

    elif args.labeled_dataset_name == 'cifar100':
        _train_dataset = datasets.CIFAR100(
            root='data/cifar100',
            train=True,
            transform=train_transform,
            download=True
        )

        _val_dataset = datasets.CIFAR100(
            root='data/cifar100',
            train=False,
            transform=val_transform,
            download=True
        )

    elif args.labeled_dataset_name == 'tinyimagenet':
        # Targets assume class-ordered images: 500 train / 50 test per class, 200 classes.
        _train_dataset = Custom_Dataset(
            get_tiny_imagenet_data(),
            torch.arange(200).repeat_interleave(500),
            train_transform
        )

        _val_dataset = Custom_Dataset(
            get_tiny_imagenet_test_data(),
            torch.arange(200).repeat_interleave(50),
            val_transform
        )

    else:
        raise ValueError(args.labeled_dataset_name)

    replay_targets = np.array(_train_dataset.targets)[replay_indexes]
    ut, inverse, uc = np.unique(replay_targets, return_inverse=True, return_counts=True)
    print(ut)
    print(uc)

    train_dataset = Subset(_train_dataset, replay_indexes)

    # Inverse class frequency: each sample of class c gets weight 1 / count(c).
    weights = 1. / uc[inverse]

    # Convert the test targets once, not once per class.
    val_targets = np.array(_val_dataset.targets)
    subset_indexes = []
    for tc in target_classes:
        subset_indexes += np.where(val_targets == tc)[0].tolist()

    val_dataset = Subset(_val_dataset, subset_indexes)

    train_sampler = WeightedRandomSampler(torch.Tensor(weights), len(weights))

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

    # NOTE(review): shuffling is not needed for evaluation; kept to preserve behaviour.
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers)

    return train_loader, val_loader