from copy import deepcopy
import os
import json
import argparse

from httpx import options
from pycocotools import mask as mask_utils
import cv2
from region_features.region_utils import show_image

def load_json(file: str):
    """Load a ``.json`` document or a ``.jsonl`` file (one JSON value per line).

    Args:
        file: Path whose name ends in ``.json`` or ``.jsonl``.

    Returns:
        For ``.json``: the decoded document.
        For ``.jsonl``: a list with one decoded value per non-blank line.

    Raises:
        ValueError: if the path ends in neither ``.json`` nor ``.jsonl``.
    """
    if file.endswith('.json'):
        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        with open(file, 'r', encoding='utf-8') as f:
            return json.load(f)
    if file.endswith('.jsonl'):
        with open(file, 'r', encoding='utf-8') as f:
            # Skip whitespace-only lines (e.g. trailing empty lines) instead of
            # letting json.loads fail on them.
            return [json.loads(line) for line in f if line.strip()]
    raise ValueError("File must be a json or jsonl file")

def istype(ans_type, answer: str):
    """Tell whether *answer* belongs to the answer category *ans_type*.

    Categories:
        "all"    -- every answer qualifies.
        "yesno"  -- exactly "yes" or "no" (case-sensitive).
        "number" -- strings for which ``str.isdigit()`` is true.
        "other"  -- anything that is neither yes/no nor a digit string.

    Raises:
        ValueError: if *ans_type* is not one of the categories above.
    """
    if ans_type == "all":
        return True

    # Predicates are evaluated lazily so only the checks a category needs
    # ever touch ``answer``.
    def _is_yes_no():
        return answer in ("yes", "no")

    def _is_digits():
        return answer.isdigit()

    if ans_type == "yesno":
        return _is_yes_no()
    if ans_type == "number":
        return _is_digits()
    if ans_type == "other":
        return not _is_yes_no() and not _is_digits()
    raise ValueError(f"Invalid answer type: {ans_type}")

# Command-line tool: compare the answers of two answer files on the same
# question set, dump every disagreement to <save_dir>/diff.json, and save the
# image plus a mask visualization for the first --save_limit of them.


def parse_args():
    """Build the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="Dump questions on which two answer files disagree.",
    )

    # These four have no usable default: without them the script cannot run,
    # so let argparse report the problem instead of crashing later on None.
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="json file with image paths and annotations",
    )

    parser.add_argument(
        "--image-folder",
        type=str,
        default="playground/data/eval/vqav2/test2015",
        help="folder the records' 'image' paths are relative to",
    )

    parser.add_argument(
        "--mask-folder",
        type=str,
        default="playground/data/regions/vqav2/regions-mixed",
        help="folder with one <image stem>.json mask file per image",
    )

    parser.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="path to save the differences",
    )

    parser.add_argument(
        "--ans1_file",
        type=str,
        required=True,
        help="json file for the first answer",
    )

    parser.add_argument(
        "--ans2_file",
        type=str,
        required=True,
        help="json file for the second answer",
    )

    # Validate up front; istype() would otherwise only raise once a
    # difference is found.
    parser.add_argument(
        "--ans_type",
        type=str,
        default="all",
        choices=["all", "yesno", "number", "other"],
        help="[all, yesno, number, other]",
    )

    parser.add_argument(
        "--save_limit",
        type=int,
        default=32,
        help="max number of images to save",
    )

    return parser.parse_args()


def index_by_question_id(records):
    """Return a dict mapping each record's 'question_id' to the record (last one wins)."""
    return {record['question_id']: record for record in records}


def find_differences(dataset, ans1, ans2, keep):
    """Collect dataset records whose two answers disagree.

    Args:
        dataset: mapping question_id -> question record.
        ans1, ans2: mapping question_id -> answer record. The answer string is
            read from the record's "answer" key, falling back to "text"; a
            question missing from a mapping counts as the empty answer "".
        keep: predicate on the first answer; a disagreement is reported only
            when ``keep(answer1)`` is true (not called when the answers match).

    Returns:
        List of deep copies of the matching records, each extended with
        "answer1" and "answer2", in ``dataset`` iteration order. The input
        records are not modified.
    """
    diff = []
    for question_id, data in dataset.items():
        ans_1 = ans1.get(question_id, {"answer": ""})
        ans_2 = ans2.get(question_id, {"answer": ""})
        answer1 = ans_1.get("answer", ans_1.get("text", ""))
        answer2 = ans_2.get("answer", ans_2.get("text", ""))
        if answer1 != answer2 and keep(answer1):
            record = deepcopy(data)
            record["answer1"] = answer1
            record["answer2"] = answer2
            diff.append(record)
    return diff


def save_visualizations(diff, image_folder, mask_folder, save_dir, limit):
    """Save the image and a mask rendering for the first *limit* records of *diff*.

    Images are copied to <save_dir>/img/<image> and the mask rendering is written
    to <save_dir>/mask/<image>. Records whose image or mask file cannot be read
    are skipped with a warning rather than aborting the run.
    """
    os.makedirs(os.path.join(save_dir, "img"), exist_ok=True)
    os.makedirs(os.path.join(save_dir, "mask"), exist_ok=True)
    for data in diff[:limit]:
        image_name = data["image"]
        img = cv2.imread(os.path.join(image_folder, image_name))
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            print(f"Warning: could not read image {image_name}, skipping")
            continue

        # Keep the "img/<name>" join form so an absolute image path still lands
        # under save_dir; create subdirectories for nested image names, since
        # cv2.imwrite does not create them.
        img_out = os.path.join(save_dir, f"img/{image_name}")
        os.makedirs(os.path.dirname(img_out), exist_ok=True)
        cv2.imwrite(img_out, img)

        masks_file = os.path.join(mask_folder, os.path.splitext(image_name)[0] + ".json")
        if not os.path.isfile(masks_file):
            print(f"Warning: mask file {masks_file} not found, skipping mask rendering")
            continue
        with open(masks_file) as f:
            masks = json.load(f)
        # Each mask's "segmentation" is passed to pycocotools' RLE decoder and
        # replaced by the decoded boolean array.
        for mask in masks:
            mask["segmentation"] = mask_utils.decode(mask["segmentation"]).astype(bool)

        mask_out = os.path.join(save_dir, f"mask/{image_name}")
        os.makedirs(os.path.dirname(mask_out), exist_ok=True)
        # show_image receives the RGB image, the decoded masks and an output
        # path; presumably it renders and writes the overlay there (confirm).
        show_image(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), masks, mask_out)


if __name__ == '__main__':
    args = parse_args()

    dataset = index_by_question_id(load_json(args.data_path))
    ans1 = index_by_question_id(load_json(args.ans1_file))
    ans2 = index_by_question_id(load_json(args.ans2_file))
    print(f"Loaded {len(dataset)} questions, {len(ans1)} answers from ans1, {len(ans2)} answers from ans2")

    # A question_id type mismatch between files (e.g. int vs str) makes every
    # lookup miss and silently yields zero differences, so surface it.
    for label, answers in (("ans1", ans1), ("ans2", ans2)):
        missing = sum(1 for question_id in dataset if question_id not in answers)
        if missing:
            print(f"Warning: {missing} questions have no answer in {label} "
                  f"(treated as empty; check that question_id types match)")

    diff = find_differences(dataset, ans1, ans2, lambda answer: istype(args.ans_type, answer))

    print(f"Found {len(diff)} differences")
    os.makedirs(args.save_dir, exist_ok=True)
    diff_file = os.path.join(args.save_dir, "diff.json")
    with open(diff_file, "w") as f:
        json.dump(diff, f, indent=4)

    save_visualizations(diff, args.image_folder, args.mask_folder, args.save_dir, args.save_limit)