import json
import os
import random
import zipfile
from copy import deepcopy

import torch
import torchvision
from MLclf import MLclf
from torch.utils.data import ConcatDataset
from torch.utils.data import Dataset
from torch.utils.data import Subset, DataLoader
from torchvision.datasets.utils import download_url
from tqdm import tqdm
from transformers import AutoTokenizer

from data.data import CUB200
from utils import verbose_iterator


def preprocess(data, tokenizer, label_to_int):
    """Tokenize text samples and translate their labels to integer ids.

    Args:
        data: Iterable of ``(text, label)`` pairs.
        tokenizer: Callable invoked as ``tokenizer(text, **options)``; every
            text is padded/truncated to 64 tokens and returned as "pt" tensors.
        label_to_int: Mapping indexed with each label to obtain its integer id.

    Returns:
        Dict with "inputs" (list of tokenizer outputs) and "labels" (list of
        mapped label ids), both in the order of ``data``.
    """
    tokenizer_options = {
        "return_tensors": "pt",
        "padding": "max_length",
        "truncation": True,
        "max_length": 64,
    }
    # Encode the text and look up the label of each sample in a single pass
    # so the two output lists stay aligned.
    encoded_pairs = [
        (tokenizer(text, **tokenizer_options), label_to_int[label_name])
        for text, label_name in data
    ]
    return {
        "inputs": [encoded for encoded, _ in encoded_pairs],
        "labels": [label_id for _, label_id in encoded_pairs],
    }


class CustomDataset(Dataset):
    """Map-style dataset over pre-tokenized inputs and their integer labels."""

    def __init__(self, data):
        """Store the paired samples found in ``data``.

        Args:
            data: Mapping with "inputs" and "labels" sequences. They are
                paired with ``zip``, so surplus items of the longer sequence
                are dropped.
        """
        paired = list(zip(data["inputs"], data["labels"]))
        self.inputs = [sample for sample, _ in paired]
        self.targets = [target for _, target in paired]
        # Distinct label values, in set iteration order (not sorted).
        self.classes = list(set(self.targets))

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``(features, target)`` for the sample at position ``idx``.

        Each value of the stored input mapping is passed through
        ``squeeze()`` before being returned.
        """
        features = {}
        for name, tensor in self.inputs[idx].items():
            features[name] = tensor.squeeze()
        return features, self.targets[idx]


def _base_transform(dataset_name):
    """Return the deterministic (non-augmenting) image transform for a dataset."""
    steps = [torchvision.transforms.ToTensor()]
    if dataset_name in ("imagenet", "cub"):
        # Large-image datasets are resized and centre-cropped to 224x224.
        steps += [
            torchvision.transforms.Resize(256, antialias=True),
            torchvision.transforms.CenterCrop(224),
        ]
    steps.append(
        torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25])
    )
    return torchvision.transforms.Compose(steps)


def _augment_transforms(dataset_name, times_augment):
    """Build ``times_augment`` augmentation pipelines of increasing strength."""
    pipelines = []
    for i in range(times_augment):
        strength = i + 1  # rotation / jitter / affine ranges grow per pipeline
        steps = [
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.RandomRotation(15 * strength),
            torchvision.transforms.ColorJitter(
                brightness=0.1 * strength, contrast=0.1 * strength
            ),
            torchvision.transforms.RandomAffine(degrees=15 * strength),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25]
            ),
        ]
        if dataset_name == "cub":
            # Resize/crop run first, before the random augmentations.
            steps = [
                torchvision.transforms.Resize(256, antialias=True),
                torchvision.transforms.CenterCrop(224),
            ] + steps
        if dataset_name in ("miniimagenet", "tinyimagenet"):
            # These datasets get a ToPILImage step in front of the
            # augmentations (the original comment says a conversion to PIL
            # Image is needed for them).
            steps.insert(0, torchvision.transforms.ToPILImage())
        pipelines.append(torchvision.transforms.Compose(steps))
    return pipelines


def _load_clinc150(dir_store):
    """Download (if missing), tokenize and wrap the CLINC150 train/test splits."""
    root = os.path.join(dir_store, "clinc150")
    data_file = os.path.join(root, "clinc150_uci", "data_full.json")
    # Gate on the extracted JSON rather than on the directory: a directory left
    # behind by a failed download must not prevent a retry.
    if not os.path.isfile(data_file):
        os.makedirs(root, exist_ok=True)
        url = "https://archive.ics.uci.edu/static/public/570/clinc150.zip"
        filename = "clinc150.zip"
        download_url(url, root, filename)
        with zipfile.ZipFile(os.path.join(root, filename), "r") as zip_ref:
            zip_ref.extractall(root)

    with open(data_file, encoding="utf-8") as f:
        data = json.load(f)

    # specify tokenizer
    model_name = "distilroberta-base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Label ids follow the sorted order of the label names seen in "train".
    label_names = sorted({label for _, label in data["train"]})
    label_to_int = {label: index for index, label in enumerate(label_names)}

    train_dataset = CustomDataset(preprocess(data["train"], tokenizer, label_to_int))
    test_dataset = CustomDataset(preprocess(data["test"], tokenizer, label_to_int))
    return train_dataset, test_dataset


def _load_mlclf(dataset_name, dir_store, transform, aug_transforms, verbose):
    """Load Mini-/Tiny-ImageNet through MLclf, optionally adding augmented train copies."""
    if dataset_name == "miniimagenet":
        archive = "./data_miniimagenet/miniimagenet.zip"
        download_fn = MLclf.miniimagenet_download
        build_fn = MLclf.miniimagenet_clf_dataset
    else:
        archive = "./data_tinyimagenet/tiny-imagenet-200.zip"
        download_fn = MLclf.tinyimagenet_download
        build_fn = MLclf.tinyimagenet_clf_dataset

    def build(tf):
        # NOTE(review): seed_value=None with shuffle=True is passed on every
        # call; if MLclf reshuffles per call, the augmented train splits may
        # not match the base split - confirm against MLclf.
        return build_fn(
            ratio_train=5 / 6,
            ratio_val=0.0,
            seed_value=None,
            shuffle=True,
            transform=tf,
            save_clf_data=False,
        )

    # The archive path above is relative, so the work is done from inside
    # dir_store; always restore the previous working directory, even when
    # MLclf raises.
    previous_dir = os.getcwd()
    os.chdir(dir_store)
    try:
        # Only download when the archive is not already present.
        download_fn(Download=not os.path.isfile(archive))
        train_dataset, _, test_dataset = build(transform)
        if aug_transforms is not None:
            augmented_train_datasets = []
            for aug_transform in verbose_iterator(aug_transforms, verbose):
                train_dataset_, _, _ = build(aug_transform)
                augmented_train_datasets.append(train_dataset_)
            # Only the train split is augmented; the test split stays as is.
            train_dataset = ConcatDataset([train_dataset] + augmented_train_datasets)
    finally:
        os.chdir(previous_dir)
    return train_dataset, test_dataset


def get_data(
    dataset_name,
    cl_type,
    verbose,
    times_augment=0,
    download=True,
    dir_store="./store/datasets",
    **kwargs,
):
    """Get train, test data

    Args:
        dataset_name: One of "clinc150", "cifar10", "cifar100", "imagenet",
            "cub", "miniimagenet" or "tinyimagenet".
        cl_type: When equal to "augment", augmented copies of the image data
            are concatenated to the base dataset(s).
        verbose: Forwarded to ``verbose_iterator`` while building augmented
            Mini-/Tiny-ImageNet copies.
        times_augment: Number of augmentation pipelines (and therefore
            augmented copies) to create when ``cl_type == "augment"``.
        download: Passed to the CIFAR and CUB dataset constructors. It is not
            used for clinc150, imagenet, miniimagenet or tinyimagenet.
        dir_store: Directory where datasets are stored; created if missing.
        **kwargs: Accepted and ignored.

    Returns:
        Tuple ``(train_dataset, test_dataset)``.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    os.makedirs(dir_store, exist_ok=True)

    if dataset_name == "clinc150":
        return _load_clinc150(dir_store)

    transform = _base_transform(dataset_name)
    # ``None`` means "no augmentation requested"; an empty list still wraps the
    # datasets in ConcatDataset, as callers got before.
    aug_transforms = (
        _augment_transforms(dataset_name, times_augment)
        if cl_type == "augment"
        else None
    )

    if dataset_name in ("cifar10", "cifar100"):
        dataset_cls = (
            torchvision.datasets.CIFAR10
            if dataset_name == "cifar10"
            else torchvision.datasets.CIFAR100
        )
        train_dataset, test_dataset = [
            dataset_cls(
                root=dir_store, train=train, download=download, transform=transform
            )
            for train in (True, False)
        ]
        if aug_transforms is not None:
            # The base datasets above already handled any download.
            train_augmented_datasets = [
                dataset_cls(
                    root=dir_store, train=True, download=False, transform=aug_transform
                )
                for aug_transform in aug_transforms
            ]
            test_augmented_datasets = [
                dataset_cls(
                    root=dir_store, train=False, download=False, transform=aug_transform
                )
                for aug_transform in aug_transforms
            ]
            train_dataset = ConcatDataset([train_dataset] + train_augmented_datasets)
            test_dataset = ConcatDataset([test_dataset] + test_augmented_datasets)
        return train_dataset, test_dataset

    if dataset_name == "imagenet":
        # torchvision's ImageNet selects the subset through ``split``
        # ("train" / "val") and takes no ``train`` or ``download`` argument;
        # the data must already be available under dir_store.
        train_dataset, test_dataset = [
            torchvision.datasets.ImageNet(
                root=dir_store, split=split, transform=transform
            )
            for split in ("train", "val")
        ]
        return train_dataset, test_dataset

    if dataset_name == "cub":
        train_dataset, test_dataset = CUB200(
            root=dir_store, download=download, transform=transform
        )
        if aug_transforms is not None:
            train_augmented_datasets, test_augmented_datasets = [], []
            for aug_transform in aug_transforms:
                train_dataset_, test_dataset_ = CUB200(
                    root=dir_store, download=download, transform=aug_transform
                )
                train_augmented_datasets.append(train_dataset_)
                test_augmented_datasets.append(test_dataset_)
            train_dataset = ConcatDataset([train_dataset] + train_augmented_datasets)
            test_dataset = ConcatDataset([test_dataset] + test_augmented_datasets)
        return train_dataset, test_dataset

    if dataset_name in ("miniimagenet", "tinyimagenet"):
        return _load_mlclf(dataset_name, dir_store, transform, aug_transforms, verbose)

    raise ValueError(
        "Dataset not supported, choose between clinc150, cifar10, cifar100, "
        "imagenet, cub, miniimagenet, tinyimagenet"
    )


def convert_labels_to_long(batch):
    """Collate ``(sample, label)`` tensor pairs, casting the labels to int64.

    Args:
        batch: Sequence of ``(sample, label)`` pairs; both members are stacked
            with ``torch.stack``.

    Returns:
        Tuple ``(samples, labels)`` of stacked tensors, labels as long.
    """
    samples, targets = map(list, zip(*batch))
    stacked_samples = torch.stack(samples)
    stacked_targets = torch.stack(targets).to(torch.long)
    return stacked_samples, stacked_targets


def get_ada_loader(
    dataset, task_labels, n_tasks, n_repeats, batch_size, collate_fn, verbose, **kwargs
):
    """Build a sequential DataLoader whose sample stream is cut into task segments.

    Samples are grouped by task, each group is shuffled and cut into segments,
    and the segments of all tasks are shuffled together. The loader then walks
    the segments in that order without further shuffling.

    Args:
        dataset: Dataset yielding ``(sample, label)`` pairs; it is deep-copied.
        task_labels: One label collection per task; a sample goes to the first
            task whose collection contains its label. Samples matching no task
            are dropped.
        n_tasks: Number of tasks, used only to size the segments.
        n_repeats: Number of times each task should appear, used only to size
            the segments.
        batch_size: Batch size of the returned loader.
        collate_fn: Collate function passed to the loader (may be ``None``).
        verbose: Forwarded to ``verbose_iterator`` while scanning the dataset.
        **kwargs: Accepted and ignored.

    Returns:
        A ``DataLoader`` over the reordered subset with ``shuffle=False``.
    """
    # Split indices of train_dataset into n_tasks by labels
    dataset_ = deepcopy(dataset)
    indices_tasks = [[] for _ in task_labels]
    for sample_idx, (_, label) in enumerate(verbose_iterator(dataset_, verbose)):
        for task, labels in enumerate(task_labels):
            if label in labels:
                indices_tasks[task].append(sample_idx)
                break

    # Shuffle indices within each task, and split them into segments of size
    # n_data_per_segment. Clamp to 1: with fewer samples than
    # n_repeats * n_tasks the integer division yields 0, which would make the
    # range() below fail with "arg 3 must not be zero".
    n_data_per_segment = max(1, len(dataset_) // (n_repeats * n_tasks))
    print(f"num of tasks {n_tasks}, each task repeats {n_repeats} times")
    print(f"num of samples per segment: {n_data_per_segment}")
    indices_batches = []
    for indices in indices_tasks:
        random.shuffle(indices)
        indices_batches.extend(
            indices[start : start + n_data_per_segment]
            for start in range(0, len(indices), n_data_per_segment)
        )
    random.shuffle(indices_batches)

    # Create dataloader in order of indices_batches
    ordered_indices = [idx for segment in indices_batches for idx in segment]
    ada_cl_dataset = Subset(dataset_, ordered_indices)

    return DataLoader(
        ada_cl_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
    )


def _task_loaders(dataset, task_labels, n_tasks, batch_size, collate_fn, verbose):
    """Return one sequential DataLoader per task over the matching samples of ``dataset``."""
    # Collect the labels in a single pass; every task reuses them instead of
    # re-reading (and re-transforming) the whole dataset.
    sample_labels = [label for _, label in dataset]
    loaders = []
    for task_idx in tqdm(verbose_iterator(range(n_tasks), verbose)):
        indices = [
            i
            for i, label in enumerate(sample_labels)
            if label in task_labels[task_idx]
        ]
        loaders.append(
            DataLoader(
                Subset(dataset, indices),
                batch_size=batch_size,
                shuffle=False,
                collate_fn=collate_fn,
            )
        )
    return loaders


def get_ada_cl(
    dataset_name,
    n_tasks,
    n_repeats,
    batch_size,
    n_classes,
    cl_type,
    verbose,
    times_augment=0,
    return_train_list=False,
    **kwargs,
):
    """Generate Adaptive CL scenario, and return Adaptive CL dataloader and test dataloader list for all tasks

    Args:
        dataset_name: Dataset identifier forwarded to ``get_data``.
        n_tasks: Number of tasks; classes are split into consecutive groups of
            ``n_classes // n_tasks`` labels.
        n_repeats: Forwarded to ``get_ada_loader`` to size the task segments.
        batch_size: Batch size for every returned loader.
        n_classes: Total number of classes.
        cl_type: Forwarded to ``get_data``.
        verbose: Forwarded to ``get_data``, ``get_ada_loader`` and
            ``verbose_iterator``.
        times_augment: Forwarded to ``get_data``.
        return_train_list: When true, also return one train loader per task.
        **kwargs: Accepted and ignored.

    Returns:
        ``(ada_cl, test_cl_list)``, or ``(ada_cl, test_cl_list, train_cl_list)``
        when ``return_train_list`` is true.
    """
    # get data
    train_dataset, test_dataset = get_data(
        dataset_name=dataset_name,
        cl_type=cl_type,
        verbose=verbose,
        times_augment=times_augment,
    )
    print(f"Len of dataset: {len(train_dataset)}, {len(test_dataset)}")

    collate_fn = (
        convert_labels_to_long
        if dataset_name in ("miniimagenet", "tinyimagenet")
        else None
    )

    # task information
    # NOTE(review): when n_classes is not divisible by n_tasks this produces
    # more than n_tasks label groups, while the per-task loaders below only
    # cover the first n_tasks groups.
    n_class_per_task = n_classes // n_tasks
    all_classes = torch.arange(n_classes)
    task_labels = [
        all_classes[start : start + n_class_per_task]
        for start in range(0, n_classes, n_class_per_task)
    ]

    ada_cl = get_ada_loader(
        train_dataset,
        task_labels,
        n_tasks,
        n_repeats,
        batch_size,
        collate_fn,
        verbose,
    )

    # Create test dataloader
    print("Creating Adaptive CL scenario")
    test_cl_list = _task_loaders(
        test_dataset, task_labels, n_tasks, batch_size, collate_fn, verbose
    )

    if not return_train_list:
        return ada_cl, test_cl_list

    # Reuse the train dataset loaded above rather than calling get_data again:
    # the per-task train loaders then come from exactly the same data (and
    # split) as ada_cl and the test loaders, and the data is not loaded twice.
    train_cl_list = _task_loaders(
        train_dataset, task_labels, n_tasks, batch_size, collate_fn, verbose
    )
    return ada_cl, test_cl_list, train_cl_list
