# Script for generating visualizations of the VIST dataset
import csv
import json
import os
import random
from time import sleep

import requests
from tqdm import tqdm


# utility for checking image url
def server_response(url, timeout=10):
    """Check whether an image URL can be fetched.

    Sends a GET request with a browser-like User-Agent header. The request is
    made with ``stream=True`` so the response body is not downloaded up front.

    Args:
        url: URL to probe.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a single stalled host would block the whole run.

    Returns:
        "ok" if the server answered with HTTP status 200, otherwise "bad"
        (including when the request itself fails).
    """
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36"
    }

    try:
        # The context manager closes the streamed response so the underlying
        # connection is released instead of lingering until garbage collection.
        with requests.get(
            url, headers=headers, stream=True, timeout=timeout
        ) as response:
            status = response.status_code
    except requests.RequestException as err:
        # Network-level failures (DNS, refused connection, timeout, ...) mean
        # the URL is unusable; report it the same way as a bad status code.
        print("Request failed: %s \nURI: %s" % (err, url))
        return "bad"

    if status == 200:
        return "ok"

    print("Response code: %d \nURI: %s" % (status, url))
    return "bad"


STORIES = "/n/fs/nlp-xxxx/datasets/VIST/sis/test.story-in-sequence.json"
OUT_DIR = "/u/xxxx/world-models/checkpoints/VIST-AMT/"
NUM_SAMPLES = 600
NUM_DISTRACTORS = 4
DISTRACTORS_PATH = "/u/xxxx/world-models/configs/vist_distractor_ids/seed-123/test_hard_sampled_distractors.json"

# Column names of the output CSV: the context images, then the candidate
# images (distractors with the ground truth mixed in), then the position of
# the ground truth among the candidates.
CSV_HEADER = [
    "context_url_1",
    "context_url_2",
    "context_url_3",
    "context_url_4",
    "candidate_url_1",
    "candidate_url_2",
    "candidate_url_3",
    "candidate_url_4",
    "candidate_url_5",
    "gt_index",
]


def lookup_urls(image_ids, images, image_id_2_index):
    """Return the "url_o" value of every image in *image_ids*.

    Args:
        image_ids: Image ids to resolve, in order.
        images: List of image records (dicts) from the stories file.
        image_id_2_index: Mapping from image id to its position in *images*.

    Returns:
        A list of URLs in the same order as *image_ids*, or None if any of
        the records has no "url_o" key.

    Raises:
        KeyError: If an id is not present in *image_id_2_index*.
    """
    urls = []
    for img_id in image_ids:
        record = images[image_id_2_index[img_id]]
        if "url_o" not in record:
            return None
        urls.append(record["url_o"])
    return urls


def build_row(sequence_urls, distractor_urls, gt_index):
    """Assemble one CSV row from a photo sequence and its distractors.

    The last URL of *sequence_urls* is treated as the ground-truth candidate
    and inserted into the distractors at position *gt_index*; the remaining
    sequence URLs are the context. Neither input list is modified.

    Returns:
        context URLs + candidate URLs + [gt_index].
    """
    context_urls = list(sequence_urls[:-1])
    candidate_urls = list(distractor_urls)
    candidate_urls.insert(gt_index, sequence_urls[-1])
    return context_urls + candidate_urls + [gt_index]


def main():
    """Write the AMT CSV for the pre-sampled distractor file.

    Walks the entries of DISTRACTORS_PATH in order, skips entries with a
    missing or unreachable image URL, and writes at most NUM_SAMPLES rows.
    """
    # Load both inputs before touching the output so that a missing or broken
    # input file does not leave a truncated CSV behind.
    with open(STORIES) as f:
        images = json.load(f)["images"]
    image_id_2_index = {image["id"]: j for j, image in enumerate(images)}

    with open(DISTRACTORS_PATH) as f:
        distractor_data = json.load(f)

    # Look at up to 1.5x the requested number of entries to leave room for
    # skipped ones, but never index past the end of the distractor file.
    max_attempts = min(int(1.5 * NUM_SAMPLES), len(distractor_data))

    out_path = os.path.join(OUT_DIR, "vist_amt_sample_hard_seed_123_viz.csv")
    cur_num_samples = 0
    # newline="" lets the csv module control line endings itself.
    with open(out_path, "w", newline="") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(CSV_HEADER)

        # for i in tqdm(random.sample(range(len(distractor_data)), int(1.5 * NUM_SAMPLES))):
        for i in tqdm(range(max_attempts)):
            if cur_num_samples >= NUM_SAMPLES:
                break
            sample = distractor_data[i]

            sequence_urls = lookup_urls(
                sample["photo_sequence"], images, image_id_2_index
            )
            distractor_urls = lookup_urls(
                sample["distractors"], images, image_id_2_index
            )
            if sequence_urls is None or distractor_urls is None:
                print("url_o key missing")
                continue

            # check if images have valid urls (stops at the first bad one)
            if any(
                server_response(url) == "bad"
                for url in sequence_urls + distractor_urls
            ):
                continue

            # sample gt index as we need to shuffle the candidates
            gt_index = random.randrange(NUM_DISTRACTORS + 1)
            writer.writerow(build_row(sequence_urls, distractor_urls, gt_index))
            cur_num_samples += 1

    if cur_num_samples < NUM_SAMPLES:
        print(
            "Only wrote %d of %d requested samples" % (cur_num_samples, NUM_SAMPLES)
        )


if __name__ == "__main__":
    main()
    

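    # NOTE: the string literal below holds disabled code whose contents are never run.
    # It wrote "vist_amt_sample.csv", sampling distractors from the same album as the sequence.
    """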
    with open(os.path.join(OUT_DIR, "vist_amt_sample.csv"), "w") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(
            [
                "context_url_1",
                "context_url_2",
                "context_url_3",
                "context_url_4",
                "candidate_url_1",
                "candidate_url_2",
                "candidate_url_3",
                "candidate_url_4",
                "candidate_url_5",
                "gt_index",
            ]
        )
        data = json.load(open(STORIES))
        annotations = data["annotations"]
        images = data["images"]
        image_id_2_index = {}
        album_id_2_image_id = {}
        for j, image in enumerate(images):
            image_id_2_index[image["id"]] = j
            album_id = image["album_id"]
            if album_id not in album_id_2_image_id:
                album_id_2_image_id[album_id] = []
            album_id_2_image_id[album_id].append(image["id"])
        assert len(annotations) % 5 == 0

        out_data = {"output_stories": []}
        unique_set_images = {}
        cur_num_samples = 0
        for i in tqdm(random.sample(range(len(annotations) // 5), 3 * NUM_SAMPLES)):
            if cur_num_samples >= NUM_SAMPLES:
                break
            cur_sequence = annotations[5 * i : 5 * i + 5]
            cur_images = [item[0]["photo_flickr_id"] for item in cur_sequence]

            if "-".join(cur_images) not in unique_set_images:
                unique_set_images["-".join(cur_images)] = ""
            else:
                continue

            album_id = cur_sequence[0][0]["album_id"]
            distractors = set(album_id_2_image_id[album_id]) - set(cur_images)
            assert len(distractors.intersection(set(cur_images))) == 0
            distractors = list(distractors)

            # sample a distractor
            sampled_distractors = random.sample(distractors, NUM_DISTRACTORS)
            cur_image_urls = [
                images[image_id_2_index[img_id]]["url_o"] for img_id in cur_images
            ]
            distractor_image_urls = [
                images[image_id_2_index[img_id]]["url_o"]
                for img_id in sampled_distractors
            ]
            # check if images have valid urls
            bad_url = False
            for url in cur_image_urls + distractor_image_urls:
                status = server_response(url)
                if "bad" == status:
                    bad_url = True
                    break
            if bad_url:
                continue

            # sample gt index as we need to shuffle the candidates
            gt_index = random.randrange(NUM_DISTRACTORS + 1)
            gt_candidate_url = cur_image_urls.pop(-1)
            distractor_image_urls.insert(gt_index, gt_candidate_url)
            row = cur_image_urls + distractor_image_urls + [gt_index]
            writer.writerow(row)
            cur_num_samples += 1
    
    """
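    # NOTE: the string literal below holds disabled code whose contents are never run.
    # It re-read "vist_amt_sample.csv" and wrote "vist_amt_sample_easy_distractors.csv",
    # resampling the distractors from all images rather than from the same album.
    """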
    with open(os.path.join(OUT_DIR, "vist_amt_sample.csv")) as csv_file, open(
        os.path.join(OUT_DIR, "vist_amt_sample_easy_distractors.csv"), "w"
    ) as csv_file_out:
        writer = csv.writer(csv_file_out)
        writer.writerow(
            [
                "context_url_1",
                "context_url_2",
                "context_url_3",
                "context_url_4",
                "candidate_url_1",
                "candidate_url_2",
                "candidate_url_3",
                "candidate_url_4",
                "candidate_url_5",
                "gt_index",
            ]
        )

        data = json.load(open(STORIES))
        annotations = data["annotations"]
        images = data["images"]
        image_id_2_index = {}
        album_id_2_image_id = {}
        imageurl_2_index = {}
        for j, image in enumerate(images):
            image_id_2_index[image["id"]] = j
            if "url_o" in image:
                imageurl_2_index[image["url_o"]] = j
            album_id = image["album_id"]
            if album_id not in album_id_2_image_id:
                album_id_2_image_id[album_id] = []
            album_id_2_image_id[album_id].append(image["id"])
        assert len(annotations) % 5 == 0
        all_images = set(list(image_id_2_index.keys()))

        csv_reader = csv.DictReader(csv_file)
        for row in csv_reader:
            row_out = []
            row_out.extend(
                [
                    row["context_url_1"],
                    row["context_url_2"],
                    row["context_url_3"],
                    row["context_url_4"],
                ]
            )
            cur_album = images[imageurl_2_index[row["context_url_1"]]]["album_id"]
            gt_index = int(row["gt_index"])
            cur_images = [
                images[imageurl_2_index[url]]["id"]
                for url in [
                    row["context_url_1"],
                    row["context_url_2"],
                    row["context_url_3"],
                    row["context_url_4"],
                    row["candidate_url_%d" % (gt_index + 1)],
                ]
            ]
            while True:
                distractors = all_images - set(cur_images)
                distractors = list(distractors)
                sampled_distractors = random.sample(distractors, NUM_DISTRACTORS)

                distractor_image_urls = []
                absent_url = False
                for img_id in sampled_distractors:
                    if "url_o" not in images[image_id_2_index[img_id]]:
                        absent_url = True
                        break
                    else:
                        distractor_image_urls.append(
                            images[image_id_2_index[img_id]]["url_o"]
                        )
                if absent_url:
                    continue
                # check if images have valid urls
                bad_url = False
                for url in distractor_image_urls:
                    status = server_response(url)
                    if "bad" == status:
                        bad_url = True
                        break
                if bad_url:
                    continue
                else:
                    print("gt index", gt_index)
                    distractor_image_urls.insert(
                        gt_index, row["candidate_url_%d" % (gt_index + 1)]
                    )
                    row_out.extend(distractor_image_urls)
                    row_out.append(gt_index)
                    writer.writerow(row_out)
                    break
    """