import json
from pathlib import Path
from .data_utils import make_val_data
from .datasets import FolderTaskDataset, ClassIncrementalDataset
from ..timm_custom.data.readers import AnnotationReader


def parse_vdd_json(json_path):
    """Parse a VDD annotation JSON file into a per-image class lookup.

    The file must contain the top-level lists "images", "annotations" and
    "categories", with exactly one annotation per image and every
    annotation's "id" equal to its "image_id".

    Args:
        json_path: Path of the annotation JSON file to read.

    Returns:
        A dict keyed by image id. Each value has the form
        ``{"filename": <file_name>, "class_info": {"category_id": ...,
        "index": ..., "name": ...}}`` where "index" is the category id
        minus the smallest category id found in the file.

    Raises:
        AssertionError: If the image and annotation counts differ, or an
            annotation's "id" does not match its "image_id".
        KeyError: If a required field is missing, or an annotation or
            image refers to an unknown category / annotation id.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(json_path, "r", encoding="utf-8") as f:
        content = json.load(f)

    # Only these three sections are consumed; any other top-level key
    # (e.g. "info") is optional and ignored.
    images = content["images"]
    annotations = content["annotations"]
    categories = content["categories"]

    assert len(images) == len(annotations), (
        f"{json_path}: {len(images)} images but {len(annotations)} annotations"
    )

    # Shift category ids so the smallest one maps to index 0. The default
    # keeps a file with no categories (and hence no images) parseable.
    min_category_id = min((c["id"] for c in categories), default=0)
    category_lookup = {
        c["id"]: {"name": c["name"], "index": c["id"] - min_category_id}
        for c in categories
    }

    # Annotation id -> class info. Annotations whose id differs from their
    # image id are filtered out here and reported by the assert below.
    annotation_lookup = {
        a["id"]: {
            "category_id": a["category_id"],
            "index": category_lookup[a["category_id"]]["index"],
            "name": category_lookup[a["category_id"]]["name"],
        }
        for a in annotations
        if a["id"] == a["image_id"]
    }
    assert len(annotation_lookup) == len(annotations), (
        f"{json_path}: annotation ids must be unique and equal to their image_id"
    )

    # Image id -> filename plus the class info of its annotation.
    image_info = {
        i["id"]: {
            "filename": i["file_name"],
            "class_info": annotation_lookup[i["id"]],
        }
        for i in images
    }
    assert len(image_info) == len(images), f"{json_path}: duplicate image ids"

    return image_info


def get_task_data(vdd_root, task_name, dev_percent=0.0):
    """Collect filenames, label indices and class names for one task.

    Reads ``<vdd_root>/annotations/<task_name>_train.json`` and
    ``<vdd_root>/annotations/<task_name>_val.json`` through
    ``parse_vdd_json``. The annotated "val" file is exposed as the "test"
    split.

    Args:
        vdd_root: Root directory containing the ``annotations`` folder.
        task_name: Task name used as the annotation file prefix.
        dev_percent: When greater than 0, the train split is handed to
            ``make_val_data`` together with this value and its six outputs
            replace the "train" entries and populate new "val" entries
            (presumably a train/validation split; confirm against
            ``make_val_data``).

    Returns:
        A tuple ``(images, labels, class_names)`` of dicts keyed by split
        name ("train", "test" and, if requested, "val"), each holding a
        list with one entry per image.
    """
    annotation_dir = Path(vdd_root, "annotations")
    # (split name, annotation file); the "val" file is used as the test set.
    split_sources = (
        ("train", f"{task_name}_train.json"),
        ("test", f"{task_name}_val.json"),
    )

    images, labels, class_names = {}, {}, {}
    for split, json_name in split_sources:
        records = parse_vdd_json(Path(annotation_dir, json_name)).values()
        images[split] = [rec["filename"] for rec in records]
        labels[split] = [rec["class_info"]["index"] for rec in records]
        class_names[split] = [rec["class_info"]["name"] for rec in records]

    if dev_percent > 0.:
        # make_val_data returns train/val pairs for images, labels and
        # class names, in that order.
        (
            images["train"], images["val"],
            labels["train"], labels["val"],
            class_names["train"], class_names["val"],
        ) = make_val_data(images["train"], labels["train"], class_names["train"], dev_percent)

    return images, labels, class_names

def get_vdd_data(vdd_root, task_names, task_idx=None, load_single_task=False, cache=True, dev_percent=0.1, offset_task_labels=False, scenario="class-incremental"):
    """
    Build one ClassIncrementalDataset per split for the requested VDD tasks.

    For every selected task the per-split file lists and labels come from
    get_task_data; each split is wrapped in an AnnotationReader and a
    FolderTaskDataset, and the per-task datasets of a split are then
    combined into a single ClassIncrementalDataset.

    Args:
        vdd_root: Dataset root; passed to get_task_data, AnnotationReader
            and FolderTaskDataset.
        task_names: Sequence of task names. Names that are not one of the
            known VDD tasks are skipped silently.
        task_idx: Position within task_names of the only task to load when
            load_single_task is True.
        load_single_task: If True, load only the task at task_idx.
        cache: Accepted for compatibility; metadata caching is not
            implemented, so this currently has no effect.
        dev_percent: Forwarded to get_task_data; when greater than 0 a
            "val" split is produced in addition to "train" and "test".
        offset_task_labels: Forwarded to ClassIncrementalDataset.
        scenario: Not used by this function at present.

    Returns:
        dict mapping split name ("train", "test" and optionally "val") to a
        ClassIncrementalDataset built from {task_name: FolderTaskDataset}.
        If no task is selected, each split is built from an empty dict
        (NOTE(review): confirm ClassIncrementalDataset accepts that).

    Raises:
        AssertionError: If load_single_task is True and task_idx is None.
    """

    vdd_tasks = ["imagenet12", "cifar100", "svhn", "ucf101", "omniglot", "gtsrb", "daimlerpedcls", "vgg-flowers", "aircraft", "dtd"]

    # TODO: metadata caching (e.g. <vdd_root>/cache/metadata/<task>.json) is
    # not implemented yet; `cache` is intentionally ignored.

    if load_single_task:
        assert task_idx is not None, "task_idx is required when load_single_task=True"

    # Decide the split names once, before the loop, so they are defined
    # even when no task ends up being processed.
    sets = ["train", "test"]
    if dev_percent > 0.:
        sets.append("val")

    datasets = {_set: dict() for _set in sets}

    for _task_idx, task_name in enumerate(task_names):

        if task_name not in vdd_tasks or (load_single_task and _task_idx != task_idx):
            continue

        # Class names are returned as well but are not needed here.
        images, labels, _ = get_task_data(vdd_root, task_name, dev_percent=dev_percent)

        for _set in sets:
            # Create the parser
            reader = AnnotationReader(vdd_root, images[_set], labels[_set])
            datasets[_set][task_name] = FolderTaskDataset(reader, vdd_root)

    # Make the datasets
    return {_set: ClassIncrementalDataset(datasets[_set], offset_task_labels=offset_task_labels) for _set in sets}
