import os
from pathlib import Path
import pickle
import logging
import json
import numpy as np
import scipy.io as sio

from torchvision import datasets
from .data_utils import make_val_data
from .datasets import PickleTaskDataset, FolderTaskDataset, ClassIncrementalDataset
from ..timm_custom.data.readers import PickleReader, AnnotationReader
from .vdd_utils import get_task_data as get_imagenet_data

_logger = logging.getLogger(name=__name__)  # module-scoped logger shared by every loader below


def _load_cifar_batches(batch_dir, file_names, idx_to_class):
    """Load pickled CIFAR batches and return ``(HWC images, int labels, class names)``."""
    raw_images, raw_labels = [], []
    for file_name in file_names:
        # NOTE: pickle can execute arbitrary code; only load trusted dataset files.
        with open(os.path.join(batch_dir, file_name), "rb") as f:
            entry = pickle.load(f, encoding="latin1")
        raw_images.append(entry["data"])
        # Labels are stored under "labels", falling back to "fine_labels".
        raw_labels.extend(entry["labels"] if "labels" in entry else entry["fine_labels"])
    labels = np.array(raw_labels).astype(int)
    class_names = [idx_to_class[label] for label in labels]
    # Each row is a flattened (3, 32, 32) image; reshape it and convert to HWC.
    images = np.vstack(raw_images).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
    return images, labels, class_names


def get_and_split_cifar_data(root, dataset_name=None, test_percent=None):
    """Load the CIFAR-10 python-format batches from ``<root>/cifar-10-batches-py``.

    Args:
        root: Directory that contains the ``cifar-10-batches-py`` folder.
        dataset_name: Unused; kept so every loader shares one signature.
        test_percent: Unused; CIFAR-10 ships with a fixed ``test_batch``.

    Returns:
        ``(train_images, test_images, train_labels, test_labels,
        train_class_names, test_class_names)``. Images are ``(N, 32, 32, 3)``
        arrays, labels are int arrays, and class names are the strings from
        ``batches.meta["label_names"]``.
    """
    batch_dir = os.path.join(root, "cifar-10-batches-py")

    with open(os.path.join(batch_dir, "batches.meta"), "rb") as infile:
        metadata = pickle.load(infile, encoding="latin1")
    idx_to_class = dict(enumerate(metadata["label_names"]))

    train_files = [f"data_batch_{i}" for i in range(1, 6)]
    _images, _labels, _class_names = _load_cifar_batches(batch_dir, train_files, idx_to_class)
    test_images, test_labels, test_class_names = _load_cifar_batches(batch_dir, ["test_batch"], idx_to_class)

    return _images, test_images, _labels, test_labels, _class_names, test_class_names


def get_and_split_mnist_data(root, dataset_name, test_percent=None):
    """Read an MNIST-format dataset (used for MNIST and Fashion-MNIST) from ``<root>/<dataset_name>``.

    The idx files are parsed with torchvision's MNIST readers. ``test_percent`` is
    accepted only for signature compatibility with the other loaders; the files
    already provide a fixed train/test split. Class names are always the digit
    strings below, including for Fashion-MNIST.

    Returns:
        ``(train_images, test_images, train_labels, test_labels,
        train_class_names, test_class_names)``: images as numpy arrays, labels as
        lists of Python ints, and one class-name string per label.
    """
    data_dir = Path(root, dataset_name)
    digit_names = dict(enumerate([
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]))

    def read_split(image_file, label_file):
        # The readers return tensor-like objects; convert to numpy / plain ints.
        split_images = datasets.mnist.read_image_file(os.path.join(data_dir, image_file)).numpy()
        split_labels = datasets.mnist.read_label_file(os.path.join(data_dir, label_file)).numpy().tolist()
        return split_images, split_labels, [digit_names[label] for label in split_labels]

    train_images, train_labels, train_names = read_split("train-images-idx3-ubyte", "train-labels-idx1-ubyte")
    test_images, test_labels, test_names = read_split("t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte")

    return train_images, test_images, train_labels, test_labels, train_names, test_names


def _load_svhn_split(mat_path, idx_to_class):
    """Load one SVHN ``.mat`` split and return ``(images, int64 labels, class names)``."""
    loaded = sio.loadmat(mat_path)
    # The .mat arrays keep the sample axis last; move it to the front.
    images = loaded["X"].transpose(3, 0, 1, 2)
    # ravel() (not squeeze()) so a single-sample split still yields a 1-D label array.
    labels = loaded["y"].astype(np.int64).ravel()
    # SVHN stores the digit 0 as label 10; remap it so labels match the class list.
    labels[labels == 10] = 0
    class_names = [idx_to_class[label] for label in labels]
    return images, labels, class_names


def get_and_split_svhn_data(root, dataset_name=None, test_percent=None):
    """Load the SVHN cropped-digit train/test splits from ``<root>/svhn``.

    Args:
        root: Directory containing the ``svhn`` folder with ``train_32x32.mat``
            and ``test_32x32.mat``.
        dataset_name: Unused; the folder name is always ``"svhn"``. Kept so every
            loader shares one signature.
        test_percent: Unused; SVHN ships with a fixed test split.

    Returns:
        ``(train_images, test_images, train_labels, test_labels,
        train_class_names, test_class_names)``. Images have the sample axis first,
        labels are 1-D int64 arrays with 10 remapped to 0, and class names are
        strings such as ``"0 - zero"``.
    """
    root_path = Path(root, "svhn")

    classes_name_list = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]
    idx_to_class = dict(enumerate(classes_name_list))

    _images, _labels, _class_names = _load_svhn_split(os.path.join(root_path, "train_32x32.mat"), idx_to_class)
    test_images, test_labels, test_class_names = _load_svhn_split(os.path.join(root_path, "test_32x32.mat"), idx_to_class)

    return _images, test_images, _labels, test_labels, _class_names, test_class_names


def get_and_split_not_mnist_data(root, dataset_name=None, test_percent=20.0):
    """Collect notMNIST image paths from ``<root>/<dataset_name>/<letter>/*.png`` and split them.

    Unlike the other loaders, this returns image *paths* (relative to ``root``,
    e.g. ``"not-mnist/A/xyz.png"``), not pixel arrays.

    Args:
        root: Directory containing the dataset folder.
        dataset_name: Dataset folder name. Defaults to ``"not-mnist"``; previously
            ``None`` crashed inside ``Path``.
        test_percent: Forwarded to ``make_val_data`` as the held-out percentage.

    Returns:
        The 6-tuple produced by ``make_val_data``: ``(train_images, test_images,
        train_labels, test_labels, train_class_names, test_class_names)``.
    """
    if dataset_name is None:
        dataset_name = "not-mnist"

    root_path = Path(root, dataset_name)

    class_name_list = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]
    class_to_idx = {_class: i for i, _class in enumerate(class_name_list)}

    all_images = []
    all_classes = []
    for _class in class_name_list:
        class_path = Path(root_path, _class)
        # os.listdir order is filesystem-dependent; sort so the order fed to
        # make_val_data is reproducible across machines.
        png_files = sorted(f for f in os.listdir(class_path) if ".png" in f)
        all_images.extend(str(Path(dataset_name, _class, f)) for f in png_files)
        all_classes.extend([_class] * len(png_files))
    all_labels = [class_to_idx[_class] for _class in all_classes]

    _images, test_images, _labels, test_labels, _class_names, test_class_names = make_val_data(all_images, all_labels, all_classes, test_percent)

    return _images, test_images, _labels, test_labels, _class_names, test_class_names


def _merge_val_into_train(images, labels, class_names, join):
    """Fold the cached "val" split into "train" (in place), using *join* for images and labels."""
    _logger.info("Merging Train and Val data.")
    images["train"] = join(images["train"], images.pop("val"))
    labels["train"] = join(labels["train"], labels.pop("val"))
    class_names["train"] = class_names["train"] + class_names.pop("val")


def load_cache(dataset_name, cache_root, dev_percent=0.0):
    """Load a cached train/val/test split for *dataset_name* from ``<cache_root>/<dataset_name>``.

    ``not-mnist`` is cached as JSON because its "images" are file paths. Every other
    dataset is cached as ``images.npz`` / ``labels.npz`` plus a pickled
    ``class_names`` dict.

    Args:
        dataset_name: Task name; also the cache sub-directory name.
        cache_root: Root directory of the cache.
        dev_percent: When ``<= 0``, the cached "val" split is merged back into
            "train" and removed. The cache is keyed by dataset name only, so a
            cache written with one positive dev_percent is reused for any other.

    Returns:
        ``(cache_exists, images, labels, class_names)``. If any cache file is
        missing, ``cache_exists`` is False and the other three are empty lists.
        Otherwise they are dicts keyed by split name.

    Note:
        ``class_names.pickle`` is read with ``pickle.load``, which can execute
        arbitrary code, so only point ``cache_root`` at a trusted directory.
    """
    cache_path = Path(cache_root, dataset_name)
    is_json = dataset_name == "not-mnist"

    if is_json:
        image_cache = Path(cache_path, "images.json")
        label_cache = Path(cache_path, "labels.json")
        class_names_cache = Path(cache_path, "class_names.json")
    else:
        image_cache = Path(cache_path, "images.npz")
        label_cache = Path(cache_path, "labels.npz")
        class_names_cache = Path(cache_path, "class_names.pickle")

    cache_exists = image_cache.exists() and label_cache.exists() and class_names_cache.exists()
    if not cache_exists:
        _logger.info(f"Cache for {dataset_name} does not exists at paths {image_cache}, {label_cache}, {class_names_cache}")
        return cache_exists, [], [], []

    if is_json:
        _logger.info(f"Loading not-mnist image metadata from {image_cache}")
        with open(image_cache, "r") as f:
            images = json.load(f)
        _logger.info(f"Loading not-mnist label metadata from {label_cache}")
        with open(label_cache, "r") as f:
            labels = json.load(f)
        _logger.info(f"Loading not-mnist class_names metadata from {class_names_cache}")
        with open(class_names_cache, "r") as f:
            class_names = json.load(f)

        if dev_percent <= 0.:
            _merge_val_into_train(images, labels, class_names, join=lambda a, b: a + b)
    else:
        # Use the NpzFile as a context manager so its file handle is closed;
        # indexing it reads each array fully into memory.
        _logger.info(f"Loading {dataset_name} images from {image_cache}")
        with np.load(image_cache) as npz:
            images = {split: npz[split] for split in ("train", "val", "test")}
        _logger.info(f"Loading {dataset_name} labels from {label_cache}")
        with np.load(label_cache) as npz:
            labels = {split: npz[split] for split in ("train", "val", "test")}
        _logger.info(f"Loading {dataset_name} class_names from {class_names_cache}")
        with open(class_names_cache, "rb") as f:
            class_names = pickle.load(f)

        if dev_percent <= 0.:
            _merge_val_into_train(images, labels, class_names, join=lambda a, b: np.concatenate([a, b], axis=0))

    return cache_exists, images, labels, class_names


def save_cache(dataset_name, images, labels, class_names, root):
    """Write a train/val/test split to ``<root>/<dataset_name>`` for later ``load_cache`` calls.

    ``not-mnist`` (whose images are file paths) is written as three JSON files.
    Every other dataset is written as compressed ``images.npz`` / ``labels.npz``,
    each holding ``train``, ``val`` and ``test`` arrays, plus a pickled
    ``class_names`` object. The target directory is created if needed.
    """
    cache_path = Path(root, dataset_name)
    cache_path.mkdir(parents=True, exist_ok=True)

    if dataset_name != "not-mnist":
        # Array datasets: compressed npz archives plus a pickle for the class names.
        image_cache = cache_path / "images.npz"
        label_cache = cache_path / "labels.npz"
        class_names_cache = cache_path / "class_names.pickle"

        _logger.info(f"Saving {dataset_name} images to {image_cache}")
        np.savez_compressed(image_cache, train=images["train"], val=images["val"], test=images["test"])
        _logger.info(f"Saving {dataset_name} labels to {label_cache}")
        np.savez_compressed(label_cache, train=labels["train"], val=labels["val"], test=labels["test"])
        _logger.info(f"Saving {dataset_name} class_names to {class_names_cache}")
        with open(class_names_cache, "wb") as out:
            pickle.dump(class_names, out)
        return

    # not-mnist: everything is JSON-serialisable metadata.
    image_cache = cache_path / "images.json"
    label_cache = cache_path / "labels.json"
    class_names_cache = cache_path / "class_names.json"
    json_targets = (
        (f"Saving not-mnist image metadata to {image_cache}", image_cache, images),
        (f"Saving not-mnist label metadata to {label_cache}", label_cache, labels),
        (f"Saving not-mnist class_names metadata to {class_names_cache}", class_names_cache, class_names),
    )
    for message, target, payload in json_targets:
        _logger.info(message)
        with open(target, "w") as out:
            json.dump(payload, out)


def get_task_data(root, task_name, dev_percent=0.0, test_percent=20., cache=True, cache_root=None):
    """Load (or fetch from cache) the train/[val]/test split for one non-ImageNet task.

    Args:
        root: Dataset root passed to the per-dataset loader.
        task_name: One of ``"svhn"``, ``"cifar10"``, ``"not-mnist"``,
            ``"fashion-mnist"``, ``"mnist"``.
        dev_percent: If > 0, this percentage of the training data is split off
            as a "val" set via ``make_val_data``.
        test_percent: Forwarded to the loader. Only the not-mnist loader uses it.
        cache: Whether to read and write the on-disk cache.
        cache_root: Cache directory. If ``None``, caching is skipped with a warning,
            because ``Path(None, ...)`` would raise ``TypeError``.

    Returns:
        ``(images, labels, class_names)``: dicts keyed by ``"test"``, ``"train"``
        and, when ``dev_percent > 0``, ``"val"``.

    Raises:
        KeyError: if ``task_name`` has no registered loader.
    """
    use_cache = cache and cache_root is not None
    if cache and cache_root is None:
        _logger.warning(f"Caching requested for {task_name} but cache_root is None; skipping cache.")

    if use_cache:
        cache_exists, images, labels, class_names = load_cache(task_name, cache_root, dev_percent=dev_percent)
        if cache_exists:
            return images, labels, class_names

    dataset_loaders = {
        "svhn": get_and_split_svhn_data,
        "cifar10": get_and_split_cifar_data,
        "not-mnist": get_and_split_not_mnist_data,
        "fashion-mnist": get_and_split_mnist_data,
        "mnist": get_and_split_mnist_data
    }
    try:
        loader = dataset_loaders[task_name]
    except KeyError:
        raise KeyError(f"Unknown task {task_name!r}; expected one of {sorted(dataset_loaders)}") from None

    images, labels, class_names = dict(), dict(), dict()

    _images, test_images, _labels, test_labels, _class_names, test_class_names = loader(root, dataset_name=task_name, test_percent=test_percent)

    images["test"] = test_images
    labels["test"] = test_labels
    class_names["test"] = test_class_names

    if dev_percent > 0.:
        train_images, val_images, train_labels, val_labels, train_class_names, val_class_names = make_val_data(_images, _labels, _class_names, dev_percent)
        images["val"] = val_images
        labels["val"] = val_labels
        class_names["val"] = val_class_names
    else:
        train_images = _images
        train_labels = _labels
        train_class_names = _class_names

    images["train"] = train_images
    labels["train"] = train_labels
    class_names["train"] = train_class_names

    # Only splits with a "val" set are cached: load_cache reads "val" and merges it
    # back into "train" when a later call asks for dev_percent <= 0.
    if use_cache and dev_percent > 0.:
        save_cache(task_name, images, labels, class_names, cache_root)

    return images, labels, class_names

def get_5datasets_data(imnet_root, root, task_names, task_idx=None, load_single_task=False, cache=True, cache_root=None, dev_percent=10.0, test_percent=20.0, offset_task_labels=False, scenario="class-incremental"):
    """Build one ``ClassIncrementalDataset`` per split over the requested 5-datasets tasks.

    Args:
        imnet_root: Root for the ``"imagenet12"`` task (loaded via ``get_imagenet_data``).
        root: Root for every other task (loaded via ``get_task_data``).
        task_names: Ordered task names; each must be a key of the reader/dataset maps below.
        task_idx: Index into ``task_names`` of the single task to load when
            ``load_single_task`` is True.
        load_single_task: Load only ``task_names[task_idx]``.
        cache, cache_root, test_percent: Forwarded to ``get_task_data``.
        dev_percent: If > 0, a "val" split is built in addition to "train"/"test".
        offset_task_labels: Forwarded to ``ClassIncrementalDataset``.
        scenario: Currently not read; only the class-incremental layout is built.

    Returns:
        ``{split: ClassIncrementalDataset({task_name: task_dataset, ...})}`` for
        split in ``"train"``, ``"test"`` and, if ``dev_percent > 0``, ``"val"``.

    Raises:
        ValueError: if ``load_single_task`` is True but ``task_idx`` is None.
    """
    if load_single_task and task_idx is None:
        raise ValueError("task_idx must be provided when load_single_task=True")

    reader_classes = {
        "imagenet12": AnnotationReader,
        "svhn": PickleReader,
        "cifar10": PickleReader,
        "not-mnist": AnnotationReader,
        "fashion-mnist": PickleReader,
        "mnist": PickleReader
    }
    dataset_classes = {
        "imagenet12": FolderTaskDataset,
        "svhn": PickleTaskDataset,
        "cifar10": PickleTaskDataset,
        "not-mnist": FolderTaskDataset,
        "fashion-mnist": PickleTaskDataset,
        "mnist": PickleTaskDataset
    }

    # Computed once up front so the final comprehension never depends on a
    # loop variable (which was unbound when no task was processed).
    split_names = ["train", "test"]
    if dev_percent > 0.:
        split_names.append("val")
    # Named to avoid shadowing the module-level torchvision `datasets` import.
    task_datasets = {split: dict() for split in split_names}

    for _task_idx, task_name in enumerate(task_names):

        if load_single_task and _task_idx != task_idx:
            continue

        if task_name == "imagenet12":
            images, labels, _class_names = get_imagenet_data(imnet_root, task_name, dev_percent=dev_percent)
            data_root = imnet_root
        else:
            images, labels, _class_names = get_task_data(root, task_name, dev_percent=dev_percent, test_percent=test_percent, cache=cache, cache_root=cache_root)
            data_root = root

        for split in split_names:
            reader = reader_classes[task_name](data_root, images[split], labels[split])
            task_datasets[split][task_name] = dataset_classes[task_name](reader, root=None, class_map=None)

    return {
        split: ClassIncrementalDataset(task_datasets[split], offset_task_labels=offset_task_labels)
        for split in split_names
    }
