# Script for generating visualizations of the VIST dataset
import json
import os
import random

from table_visualizer import TableVisualizer

DATA = "/n/fs/nlp-xxxx/datasets/VIST/sis/val.story-in-sequence.json"
# Optional TSV of generated captions, one image per line:
#   <image id> TAB <JSON list whose first element has a "caption" key>
# Leave as None to omit the "Generated Captions" column. A previously used
# path was "/u/xxxx/world-models/data/VIST/preprocessed/VIST_val_gen_caption.tsv".
GEN_CAPTIONS = None
OUT_DIR = "/n/fs/nlp-xxxx/projects/world-models/checkpoints"

# Number of consecutive annotation entries that are grouped into one table row
# (and therefore the number of image columns in the table).
_STORY_LENGTH = 5


def _build_table_configs(include_generated=False):
    """Return the list of column configs handed to ``TableVisualizer``.

    Columns are, in order: the row index, the ground-truth captions,
    optionally the generated captions, then one image column per story slot.

    Args:
        include_generated: when true, add the "Generated Captions" text
            column right after the ground-truth caption column.
    """
    configs = [
        {
            "id": "idx",
            "display_name": "idx",
            "type": "text",
            "sortable": True,
            "width": "3%",
        },
        {
            "id": "GTCaptions",
            "display_name": "GT Captions",
            "type": "text",
            "sortable": True,
            "width": "15%",
        },
    ]
    if include_generated:
        configs.append(
            {
                "id": "GenCaptions",
                "display_name": "Generated Captions",
                "type": "text",
                "sortable": True,
                "width": "15%",
            }
        )
    # The image columns differ only by their 1-based position.
    configs.extend(
        {
            "id": "Image{}".format(position),
            "display_name": "Image-{}".format(position),
            "type": "image",
            "height": 200,
            "width": "13%",
        }
        for position in range(1, _STORY_LENGTH + 1)
    )
    return configs


def _load_generated_captions(path):
    """Read the generated-caption TSV at *path* into ``{int image id: caption}``.

    Each line is split on tabs: the first field is parsed as the integer image
    id, the second as JSON whose first element's "caption" value is kept.
    """
    captions_by_image = {}
    with open(path, encoding="utf-8") as tsv_file:
        for line in tsv_file:
            fields = line.split("\t")
            captions_by_image[int(fields[0])] = json.loads(fields[1])[0]["caption"]
    return captions_by_image


def _build_story_row(
    story_index, story, images, image_index_by_id, generated_captions=None
):
    """Build the table row for one story, or return None if it must be skipped.

    Args:
        story_index: position of the story; rendered as text in the first cell.
        story: the annotation entries of one story. Each entry is a sequence
            whose first element is a dict with "original_text" and
            "photo_flickr_id" keys.
        images: list of image records, each with an "id" key and, when the
            image is available, a "url_o" key.
        image_index_by_id: maps an image "id" to its position in *images*.
        generated_captions: optional mapping from integer image id to a
            generated caption; when given, a cell with those captions is
            inserted after the ground-truth cell.

    Returns:
        ``[index, captions + image ids, (generated captions,) url, url, ...]``,
        or None when any image of the story has no "url_o".

    Raises:
        KeyError: if a referenced photo id is missing from *image_index_by_id*
            (or from *generated_captions* when that mapping is given).
    """
    captions = []
    image_ids = []
    image_urls = []
    story_generated = []
    for entry in story:
        sentence = entry[0]
        image = images[image_index_by_id[sentence["photo_flickr_id"]]]
        captions.append(sentence["original_text"])
        image_ids.append(image["id"])
        image_urls.append(image.get("url_o"))
        if generated_captions is not None:
            story_generated.append(
                generated_captions[int(sentence["photo_flickr_id"])]
            )

    # A story is only useful if every one of its images can be displayed.
    if None in image_urls:
        return None

    # The caption cell lists the sentences followed by the image ids.
    row = [str(story_index), captions + image_ids]
    if generated_captions is not None:
        row.append(story_generated)
    row.extend(image_urls)
    return row


def main():
    """Render the stories in ``DATA`` as a table via ``TableVisualizer``.

    Consecutive groups of ``_STORY_LENGTH`` annotations are treated as one
    story; stories with an image lacking "url_o" are left out. The output
    path handed to the visualizer is ``viz_sample_stories_val.html`` under
    ``OUT_DIR``.

    Raises:
        ValueError: if the annotation count is not a multiple of the story
            length, since the grouping would then be misaligned.
    """
    with open(DATA, encoding="utf-8") as data_file:
        data = json.load(data_file)
    annotations = data["annotations"]
    images = data["images"]

    if len(annotations) % _STORY_LENGTH != 0:
        raise ValueError(
            "expected the number of annotations to be a multiple of {}, got {}".format(
                _STORY_LENGTH, len(annotations)
            )
        )

    image_index_by_id = {image["id"]: j for j, image in enumerate(images)}
    generated_captions = (
        _load_generated_captions(GEN_CAPTIONS) if GEN_CAPTIONS else None
    )

    table_viz = TableVisualizer(
        _build_table_configs(include_generated=generated_captions is not None),
        "assets/table_visualizer_style.css",
        os.path.join(OUT_DIR, "viz_sample_stories_val.html"),
    )

    # To preview only a subset, iterate over e.g.
    # random.sample(range(num_stories), k) instead of the full range.
    num_stories = len(annotations) // _STORY_LENGTH
    for i in range(num_stories):
        story = annotations[_STORY_LENGTH * i : _STORY_LENGTH * (i + 1)]
        row = _build_story_row(
            i, story, images, image_index_by_id, generated_captions
        )
        if row is not None:
            table_viz.add_row(row)

    table_viz.render()


if __name__ == "__main__":
    main()
