
"""
This script is used to convert the referring expression dataset annotations to COCO format as expected by MDETR.
data_path :  path to original refexp annotations to be downloaded from https://github.com/lichengunc/refer

"""
import argparse
import json
import os
import pickle
from pathlib import Path
import sys
PACKAGE_PARENT = ".."
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
from utils.spans import consolidate_spans
from utils.text import get_root_and_nouns


def parse_args():
    """Parse the command line of the conversion script.

    Returns:
        The ``argparse.Namespace`` holding the three mandatory string options
        ``data_path``, ``out_path`` and ``coco_path``.
    """
    cli = argparse.ArgumentParser("Conversion script")

    # (flag, help text) pairs, registered in this order; every option is a required string.
    option_help = (
        ("--data_path", "Path to the refexp data"),
        ("--out_path", "Path where to export the resulting dataset."),
        ("--coco_path", "Path to coco 2014 dataset."),
    )
    for flag, description in option_help:
        cli.add_argument(flag, required=True, type=str, help=description)

    return cli.parse_args()


def convert(dataset_path: Path, dataset_name: str, split: str, output_path, coco_path, next_img_id: int = 0, next_id: int = 0):
    """Export one split (eg 'train') of a refexp dataset as a COCO-style json file.

    Every sentence of every ref belonging to ``split`` becomes its own "image" entry (carrying the
    sentence as ``caption``) paired with exactly one annotation, so the same COCO image can show up
    several times in the output under different ids.

    Args:
        dataset_path: Root directory of the refexp annotations (``Path`` or ``str``).
        dataset_name: Pickle file relative to ``dataset_path``, eg ``"refcoco/refs(unc).p"``. The part
            before the first "/" is used as the dataset name in the output.
        split: Only refs whose ``"split"`` field equals this value are exported.
        output_path: Existing directory (``Path`` or ``str``) in which
            ``finetune_<dataset>_<split>.json`` is written.
        coco_path: Root of the COCO dataset; ``annotations/instances_train2014.json`` is read from it.
        next_img_id: First id to give to the exported image entries.
        next_id: First id to give to the exported annotations.

    Returns:
        The tuple ``(next_img_id, next_id)`` of the first ids left unused, so that successive calls
        can keep ids unique across datasets and splits.

    Raises:
        KeyError: If an exported ref points to an image or annotation id missing from the COCO file.
    """

    print(f"Exporting {split}...")

    # The train2014 instances file is read whatever `split` is - presumably every refexp image
    # comes from COCO train2014 (TODO confirm).
    coco_file = Path(coco_path) / "annotations" / "instances_train2014.json"
    with open(coco_file, "r", encoding="utf-8") as f:
        coco_annotations = json.load(f)
    annid2cocoann = {item["id"]: item for item in coco_annotations["annotations"]}
    imgid2cocoimgs = {item["id"]: item for item in coco_annotations["images"]}

    annotations = []
    images = []

    d_name = dataset_name.split("/")[0]

    # NOTE(security): pickle.load can execute arbitrary code while loading; only feed this
    # function annotation files obtained from a trusted source.
    with open(Path(dataset_path) / dataset_name, "rb") as f:
        data = pickle.load(f)

    for item in data:
        if item["split"] != split:
            continue

        sentences = item["sentences"]
        if not sentences:
            # Nothing to export for this ref; skip it before touching the COCO lookups.
            continue

        # Everything below depends on the ref only, so it is computed once per ref rather than
        # once per sentence.
        coco_img = imgid2cocoimgs[item["image_id"]]
        coco_ann = annid2cocoann[item["ann_id"]]
        # Drop the last "_"-separated token of the file name (presumably an annotation-id suffix
        # appended to the COCO image name - TODO confirm).
        filename = "_".join(item["file_name"].split("_")[:-1]) + ".jpg"

        for s in sentences:
            refexp = s["sent"]
            # Only the last two values returned by get_root_and_nouns are used here: the spans
            # stored as positive tokens and the spans stored as negative tokens.
            _, _, root_spans, neg_spans = get_root_and_nouns(refexp)
            root_spans = consolidate_spans(root_spans, refexp)
            neg_spans = consolidate_spans(neg_spans, refexp)

            images.append(
                {
                    "file_name": filename,
                    "height": coco_img["height"],
                    "width": coco_img["width"],
                    "id": next_img_id,
                    "original_id": item["image_id"],
                    "caption": refexp,
                    "dataset_name": d_name,
                    "tokens_negative": neg_spans,
                }
            )
            annotations.append(
                {
                    "area": coco_ann["area"],
                    "iscrowd": coco_ann["iscrowd"],
                    "image_id": next_img_id,
                    "category_id": item["category_id"],
                    "id": next_id,
                    "bbox": coco_ann["bbox"],
                    "original_id": item["ann_id"],
                    "tokens_positive": root_spans,
                }
            )
            next_img_id += 1
            next_id += 1

    ds = {
        "info": coco_annotations["info"],
        "licenses": coco_annotations["licenses"],
        "images": images,
        "annotations": annotations,
        "categories": coco_annotations["categories"],
    }
    with open(Path(output_path) / f"finetune_{d_name}_{split}.json", "w", encoding="utf-8") as j_file:
        json.dump(ds, j_file)
    return next_img_id, next_id


def main(args):
    """Convert the train and val splits of the three refexp datasets.

    Args:
        args: Parsed options providing ``data_path``, ``out_path`` and ``coco_path``.

    The output directory is created when missing. Image and annotation ids start at 0 and are
    threaded through every ``convert`` call, so each call continues where the previous one stopped.
    """
    refexp_root = Path(args.data_path)
    export_dir = Path(args.out_path)
    export_dir.mkdir(parents=True, exist_ok=True)

    pickles = ("refcoco/refs(unc).p", "refcoco+/refs(unc).p", "refcocog/refs(umd).p")
    splits = ("train", "val")
    # Dataset-major order: both splits of a dataset are exported before moving to the next one.
    jobs = [(pickle_name, split_name) for pickle_name in pickles for split_name in splits]

    img_counter = 0
    ann_counter = 0
    for pickle_name, split_name in jobs:
        img_counter, ann_counter = convert(
            refexp_root,
            pickle_name,
            split_name,
            export_dir,
            args.coco_path,
            next_img_id=img_counter,
            next_id=ann_counter,
        )


if __name__ == "__main__":
    # Script entry point: read the command-line options, then run the conversion.
    cli_args = parse_args()
    main(cli_args)
