import copy
import json
import os
import random
from argparse import ArgumentParser

import numpy as np
from tqdm import tqdm


def read_params(argv=None):
    """Parse the command-line options for building story/distractor files.

    Args:
        argv: Optional list of argument strings. ``None`` lets argparse fall
            back to the process command line.

    Returns:
        A plain ``dict`` mapping each option name (``story_root``,
        ``out_dir``, ``out_prefix``, ``distractor_suffix``,
        ``distractor_root``) to its string value.
    """
    parser = ArgumentParser(description="Options for generating distractor files")

    # (flag, help text, default) — declared in the order they appear in --help.
    option_specs = (
        ("--story_root", "vist story root", "data/VIST/sis"),
        ("--out_dir", "out folder root", "data/VIST/preprocessed/gt_stories"),
        ("--out_prefix", "prefix for outputs", "gt_stories_sampled_distractors"),
        (
            "--distractor_suffix",
            "distractor files suffix",
            "gt_stories_sampled_distractors",
        ),
        (
            "--distractor_root",
            "distractor root",
            "configs/vist_distractor_ids/seed-123",
        ),
    )
    for flag, help_text, default_value in option_specs:
        parser.add_argument(flag, help=help_text, default=default_value)

    try:
        parsed = vars(parser.parse_args(args=argv))
    except IOError as err:
        # Report I/O problems through argparse's usual usage/exit path.
        parser.error(str(err))
    return parsed


# The annotation list is consumed in consecutive groups of this many entries,
# one group per story.
STORY_LEN = 5


def _story_slice(annotations, story_index):
    """Return the annotation entries belonging to story number ``story_index``."""
    start = STORY_LEN * story_index
    return annotations[start : start + STORY_LEN]


def index_stories(annotations, progress=None):
    """Build lookup tables from story / photo ids to story positions.

    ``annotations`` is treated as consecutive groups of ``STORY_LEN`` entries,
    one group per story. Each entry is a sequence whose first element is a
    dict with at least ``"photo_flickr_id"`` and ``"text"`` keys. Any trailing
    partial group is ignored.

    Args:
        annotations: The annotation list described above.
        progress: Optional wrapper (e.g. ``tqdm``) applied to the iterable of
            story positions, used only for progress reporting.

    Returns:
        ``(story_2_index, image_id_2_index)``:
        ``story_2_index`` maps the ``"-"``-joined photo ids of a story to its
        position. ``image_id_2_index`` maps each photo id to the position of
        the *last* story that contains it.
    """
    story_2_index = {}
    image_id_2_index = {}
    positions = range(len(annotations) // STORY_LEN)
    if progress is not None:
        positions = progress(positions)
    for i in positions:
        image_ids = [item[0]["photo_flickr_id"] for item in _story_slice(annotations, i)]
        story_2_index["-".join(image_ids)] = i
        for image_id in image_ids:
            # Later stories overwrite earlier ones if a photo is reused.
            image_id_2_index[image_id] = i
    return story_2_index, image_id_2_index


def _caption_for_image(annotations, image_id_2_index, image_id):
    """Return the caption paired with ``image_id`` in its indexed story.

    Raises:
        KeyError: if ``image_id`` was never indexed.
        ValueError: if the indexed story does not contain ``image_id``.
    """
    for item in _story_slice(annotations, image_id_2_index[image_id]):
        if item[0]["photo_flickr_id"] == image_id:
            return item[0]["text"]
    raise ValueError("photo %r not found in its indexed story" % (image_id,))


def build_output_stories(annotations, distractor_data, progress=None):
    """Expand every distractor sample into its original story plus variants.

    For each sample in ``distractor_data``, the function emits the following
    output stories:

    * The original story: ``photo_sequence`` plus its ground-truth captions.
    * One variant per id in ``sample["distractors"]``. In each variant the
      final photo is replaced by the distractor photo, and the final caption
      by that photo's caption.

    Args:
        annotations: Story annotations, as accepted by :func:`index_stories`.
        distractor_data: Iterable of dicts with ``"photo_sequence"`` (list of
            photo ids forming a story present in ``annotations``) and
            ``"distractors"`` (list of photo ids).
        progress: Optional progress wrapper forwarded to :func:`index_stories`.

    Returns:
        A list of ``{"photo_sequence": [...], "story_text_normalized": [...]}``
        dicts.

    Raises:
        KeyError: if a sample's photo sequence or a distractor photo does not
            occur in ``annotations``.
    """
    story_2_index, image_id_2_index = index_stories(annotations, progress=progress)
    output_stories = []
    for sample in distractor_data:
        photo_sequence = sample["photo_sequence"]
        story_index = story_2_index["-".join(photo_sequence)]
        captions = [item[0]["text"] for item in _story_slice(annotations, story_index)]
        output_stories.append(
            {"photo_sequence": photo_sequence, "story_text_normalized": captions}
        )

        for distractor in sample["distractors"]:
            distractor_caption = _caption_for_image(
                annotations, image_id_2_index, distractor
            )
            # Slicing builds new lists, so the original story is never mutated.
            output_stories.append(
                {
                    "photo_sequence": photo_sequence[:-1] + [distractor],
                    "story_text_normalized": captions[:-1] + [distractor_caption],
                }
            )
    return output_stories


def main(argv=None):
    """Write one ``<subset>_<out_prefix>.json`` file per VIST subset.

    For each of ``train``/``val``/``test`` this function does the following:

    * Read the distractor ids from
      ``<distractor_root>/<subset>_<distractor_suffix>.json``.
    * Read the stories from ``<story_root>/<subset>.story-in-sequence.json``.
    * Write ``{"output_stories": [...]}`` into ``out_dir``, creating the
      directory if needed.
    """
    params = read_params(argv)
    stories_root = params["story_root"]
    out_dir = params["out_dir"]
    out_file_prefix = params["out_prefix"]
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    for subset in ("train", "val", "test"):
        distractor_path = os.path.join(
            params["distractor_root"],
            "%s_%s.json" % (subset, params["distractor_suffix"]),
        )
        with open(distractor_path, encoding="utf-8") as f:
            distractor_data = json.load(f)

        story_path = os.path.join(stories_root, "%s.story-in-sequence.json" % subset)
        with open(story_path, encoding="utf-8") as f:
            annotations = json.load(f)["annotations"]

        out_data = {
            "output_stories": build_output_stories(
                annotations, distractor_data, progress=tqdm
            )
        }
        out_path = os.path.join(out_dir, "%s_%s.json" % (subset, out_file_prefix))
        with open(out_path, "w", encoding="utf-8") as f_out:
            json.dump(out_data, f_out)


if __name__ == "__main__":
    main()
