import json
import os
import random
from argparse import ArgumentParser

import numpy as np
from tqdm import tqdm


def read_params(argv=None):
    """Parse the command-line options for distractor generation.

    Args:
        argv: list of argument strings to parse; ``None`` parses
            ``sys.argv[1:]`` (argparse default).

    Returns:
        dict mapping option names (``seed``, ``story_root``, ``caption_root``,
        ``split``, ``out_dir``, ``out_prefix``, ``distractor_type``,
        ``num_distractors``) to their parsed values.

    Invalid ``--split`` / ``--distractor_type`` values are rejected by argparse
    (which exits via ``SystemExit``) instead of failing later in the run.
    """
    parser = ArgumentParser(description="Options for generating distractor files")

    # settings
    parser.add_argument("--seed", help="random seed", default=123, type=int)
    parser.add_argument("--story_root", help="vist story root", default="data/VIST/sis")
    parser.add_argument("--caption_root", help="vist caption root", default="data/VIST/dii")
    parser.add_argument(
        "--split",
        help="captions or stories",
        default="stories",
        choices=["stories", "captions"],
    )

    parser.add_argument(
        "--out_dir", help="out folder root", default="configs/vist_distractor_ids"
    )
    parser.add_argument(
        "--out_prefix", help="prefix for outputs", default="sampled_distractors"
    )
    parser.add_argument(
        "--distractor_type",
        help="type of distractor (hard, easy, gt)",
        default="easy",
        choices=["easy", "hard", "gt"],
    )
    parser.add_argument(
        "--num_distractors",
        help="Number of distractors per sample",
        default=4,
        type=int,
    )

    # args and init
    try:
        params = vars(parser.parse_args(args=argv))
    except IOError as msg:
        parser.error(str(msg))

    return params


def seed(seed):
    """Seed Python's ``random`` module, then NumPy's global RNG, with ``seed``."""
    for seed_rng in (random.seed, np.random.seed):
        seed_rng(seed)


def _load_split(params, subset):
    """Load the VIST annotation JSON of ``subset`` for the configured split.

    ``split == "stories"`` reads ``<story_root>/<subset>.story-in-sequence.json``;
    ``split == "captions"`` reads
    ``<caption_root>/<subset>.description-in-isolation.json``.

    Raises:
        NotImplementedError: for any other split value.
    """
    if params["split"] == "stories":
        path = os.path.join(params["story_root"], "%s.story-in-sequence.json" % subset)
    elif params["split"] == "captions":
        path = os.path.join(
            params["caption_root"], "%s.description-in-isolation.json" % subset
        )
    else:
        raise NotImplementedError()
    # Context manager so the handle is closed as soon as the JSON is parsed.
    with open(path) as f_in:
        return json.load(f_in)


def _sequence_photo_ids(annotations, seq_idx):
    """Return the photo ids of the ``seq_idx``-th group of five annotations.

    Each annotation entry is indexed with ``[0]`` before reading
    ``photo_flickr_id`` -- presumably every entry is a one-element list as in
    the VIST JSON files (TODO confirm).
    """
    group = annotations[5 * seq_idx : 5 * seq_idx + 5]
    return [item[0]["photo_flickr_id"] for item in group]


def _sample_from(candidates, num_distractors):
    """Draw ``num_distractors`` distinct ids from ``candidates``; None if too few.

    ``candidates`` is sorted before sampling: iteration order of a set of
    strings depends on the per-process hash seed (PYTHONHASHSEED), so sampling
    from ``list(some_set)`` is not reproducible across runs even with a fixed
    ``--seed``.
    """
    pool = sorted(candidates)
    if len(pool) < num_distractors:
        return None
    return random.sample(pool, num_distractors)


def _sample_excluding(sorted_pool, excluded, num_distractors):
    """Uniformly draw ``num_distractors`` ids from ``sorted_pool`` not in ``excluded``.

    Avoids materializing (and sorting) the large complement for every sample:
    draw up to ``len(excluded)`` extra ids, drop the excluded ones and keep the
    first ``num_distractors``. ``random.sample`` returns ids in a uniformly
    random selection order, so the kept ids are a uniform draw from the
    complement. Returns None when the complement is too small.
    """
    excluded = set(excluded)
    draw_size = min(num_distractors + len(excluded), len(sorted_pool))
    drawn = random.sample(sorted_pool, draw_size)
    picked = [img_id for img_id in drawn if img_id not in excluded][:num_distractors]
    if len(picked) < num_distractors:
        return None
    return picked


def main(argv=None):
    """Write one distractor file per VIST subset (train, val, test).

    For every album that has at least one sequence, one of its sequences is
    picked at random and paired with ``num_distractors`` distractor image ids:

    * ``easy``: any image of the subset that is not in the picked sequence;
    * ``hard``: images of the same album that are not in the picked sequence;
    * ``gt``: images that appear in some sequence of the same album, excluding
      the picked sequence's images.

    Albums without enough candidate images are skipped. Each output file
    ``<out_dir>/seed-<seed>/<subset>_<type>_<split>_<prefix>.json`` holds a JSON
    list of ``{"photo_sequence": [...], "distractors": [...]}`` objects.

    Args:
        argv: optional argument list forwarded to ``read_params``.

    Raises:
        ValueError: if a subset's annotation count is not a multiple of 5.
        NotImplementedError: for an unsupported split or distractor type.
    """
    params = read_params(argv)
    input_seed = params["seed"]
    seed(input_seed)
    distractor_type = params["distractor_type"]
    num_distractors = params["num_distractors"]
    seed_dir = os.path.join(params["out_dir"], "seed-%d" % input_seed)
    os.makedirs(seed_dir, exist_ok=True)

    for subset in ["train", "val", "test"]:
        out_path = os.path.join(
            seed_dir,
            "%s_%s_%s_%s.json"
            % (subset, distractor_type, params["split"], params["out_prefix"]),
        )
        data = _load_split(params, subset)
        annotations = data["annotations"]
        images = data["images"]

        # image id -> position in ``images``; album id -> ids of its images.
        image_id_2_index = {}
        album_id_2_image_id = {}
        for j, image in enumerate(images):
            image_id_2_index[image["id"]] = j
            album_id_2_image_id.setdefault(image["album_id"], []).append(image["id"])
        if len(annotations) % 5 != 0:
            raise ValueError(
                "expected annotations in groups of 5, got %d" % len(annotations)
            )
        # Sorted once, so "easy" sampling is reproducible across processes.
        sorted_image_ids = sorted(image_id_2_index)

        # Group sequence indices by the album of their first photo.
        album_id_2_index = {}
        for i in tqdm(range(len(annotations) // 5)):
            first_photo = _sequence_photo_ids(annotations, i)[0]
            album_id = images[image_id_2_index[first_photo]]["album_id"]
            album_id_2_index.setdefault(album_id, []).append(i)

        out_data = []
        for album_id, seq_indices in tqdm(album_id_2_index.items()):
            # pick one sample per album id
            sample_id = random.choice(seq_indices)
            cur_images = _sequence_photo_ids(annotations, sample_id)
            if distractor_type == "easy":
                distractor_ids = _sample_excluding(
                    sorted_image_ids, cur_images, num_distractors
                )
            elif distractor_type == "hard":
                candidates = set(album_id_2_image_id[album_id]) - set(cur_images)
                distractor_ids = _sample_from(candidates, num_distractors)
            elif distractor_type == "gt":
                # all images of the current album that are part of some story
                album_story_images = set()
                for j in seq_indices:
                    album_story_images.update(_sequence_photo_ids(annotations, j))
                distractor_ids = _sample_from(
                    album_story_images - set(cur_images), num_distractors
                )
            else:
                raise NotImplementedError()
            if distractor_ids is None:
                # not enough images to create distractors for this album
                continue
            out_data.append(
                {"photo_sequence": cur_images, "distractors": distractor_ids}
            )

        with open(out_path, "w") as f_out:
            json.dump(out_data, f_out)


if __name__ == "__main__":
    main()
