# Script for building VIST story files from generated image captions, adding
# distractor variants in which the last image of each story is swapped out
import copy
import json
import os
import random
from argparse import ArgumentParser

def read_params(argv=None):
    """Parse the command-line options for building generated-caption story files.

    Args:
        argv: Argument list to parse. ``None`` lets argparse fall back to
            ``sys.argv[1:]``.

    Returns:
        A dict mapping each option name (e.g. ``"out_dir"``) to its parsed value.
    """
    parser = ArgumentParser(description="Options for generating distractor files")

    # settings
    parser.add_argument("--seed", help="random seed", default=123, type=int)
    parser.add_argument(
        "--gen_caption_root", help="generated story files root", default="data/VIST/sis"
    )

    parser.add_argument("--out_dir", help="out folder root", default="checkpoints")
    parser.add_argument(
        "--out_prefix", help="prefix for outputs", default="gen_stories"
    )
    parser.add_argument(
        "--distractor_suffix",
        help="distractor files suffix",
        default="easy_sampled_distractors",
    )
    parser.add_argument(
        "--distractor_root",
        help="distractor root",
        default="configs/vist_distractor_ids/seed-123",
    )

    parser.add_argument(
        "--num_distractors",
        help="Number of distractors per sample",
        default=4,
        type=int,
    )

    # argparse reports invalid arguments itself (parser.error -> SystemExit),
    # so no exception handling is needed around parse_args.
    return vars(parser.parse_args(args=argv))


def _load_gen_captions(path):
    """Map each image id to its first generated caption from a TSV file.

    Each non-blank line is ``<image_id>\\t<json>``, where ``<json>`` decodes to a
    list whose first element has a ``"caption"`` key. Later lines with the same
    image id overwrite earlier ones.
    """
    image2gencaption = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines instead of crashing
            fields = line.split("\t")
            image2gencaption[fields[0]] = json.loads(fields[1])[0]["caption"]
    return image2gencaption


def _build_stories(distractor_data, image2gencaption):
    """Build the ``output_stories`` list for one subset.

    ``distractor_data`` is a list of samples, each with ``"photo_sequence"``
    (image ids) and ``"distractors"`` (candidate image ids). For every sample
    whose images all have a generated caption, emit the original story followed
    by one variant per known distractor, in which the last image and its
    caption are replaced by the distractor's.
    """
    stories = []
    for sample in distractor_data:
        cur_images = sample["photo_sequence"]
        distractors = sample["distractors"]

        cur_captions = [image2gencaption.get(image_id) for image_id in cur_images]
        if None in cur_captions:
            print("story not found")
            continue

        stories.append(
            {"photo_sequence": cur_images, "story_text_normalized": cur_captions}
        )

        for distractor in distractors:
            if distractor not in image2gencaption:
                print("distractor not found")
                continue
            # Slicing builds fresh lists, so the original story is not mutated.
            stories.append(
                {
                    "photo_sequence": cur_images[:-1] + [distractor],
                    "story_text_normalized": cur_captions[:-1]
                    + [image2gencaption[distractor]],
                }
            )
    return stories


def main(params, subsets=("train", "val", "test")):
    """Write one story file per subset, built from generated captions and distractors.

    For each subset this reads
    ``<distractor_root>/<subset>_<distractor_suffix>.json`` and
    ``<gen_caption_root>/VIST_<subset>_gen_caption.tsv``, then writes
    ``<out_dir>/<subset>_<out_prefix>.json`` as ``{"output_stories": [...]}``.

    NOTE: ``seed`` and ``num_distractors`` from ``read_params`` are not used
    here; every distractor listed in the input file is emitted.
    """
    if params["out_dir"]:
        # Create the output folder so a fresh checkout does not fail on write.
        os.makedirs(params["out_dir"], exist_ok=True)

    for subset in subsets:
        distractor_path = os.path.join(
            params["distractor_root"],
            "%s_%s.json" % (subset, params["distractor_suffix"]),
        )
        with open(distractor_path, encoding="utf-8") as f:
            distractor_data = json.load(f)

        gen_caption_path = os.path.join(
            params["gen_caption_root"], "VIST_%s_gen_caption.tsv" % subset
        )
        image2gencaption = _load_gen_captions(gen_caption_path)

        out_data = {"output_stories": _build_stories(distractor_data, image2gencaption)}

        # write file
        print("tot samples", len(out_data["output_stories"]))
        out_path = os.path.join(
            params["out_dir"], "%s_%s.json" % (subset, params["out_prefix"])
        )
        print("out path", out_path)
        with open(out_path, "w", encoding="utf-8") as f_out:
            json.dump(out_data, f_out)


if __name__ == "__main__":
    main(read_params())
