import PIL
from PIL import Image

import json
import os
import numpy as np
# data file: scan, vp, idx, direction, prompt

# with open("../../VLN-DUET/datasets/R2R/annotations/R2R_train_enc.json", 'r') as f:
#     train_data = json.load(f)
# f.close()

# with open("../../VLN-DUET/datasets/R2R/annotations/R2R_val_seen_enc.json", 'r') as f:
#     train_data = json.load(f)
#     print(len(train_data))
# f.close()

# Load the split to evaluate. NOTE: despite the name, `train_data` holds the
# val-unseen split; swap in one of the commented-out paths above for another split.
with open("../../datasets/R2R/annotations/R2R_val_unseen_enc.json", "r", encoding="utf-8") as f:
    train_data = json.load(f)
    print(len(train_data))

# Distinct `scan` ids appearing in the split, in sorted order.
scene_ids = {item["scan"] for item in train_data}
scans = sorted(scene_ids)


# with open("../../BLIP-2/captions.json", "r") as f:
#     results = json.load(f)
#
# results_new = dict()
# print(len(results))
# for key_, value in results.items():
#     scan = key_.split("_")[0]
#     if scan in scans:
#         results_new[key_] = value
#
# print(len(results_new))

# Candidate table keyed by "<scan>_<vp>". Each value maps a neighbouring
# viewpoint id to a list whose first element is used below as the view index
# of that candidate's image file.
with open("../../datasets/R2R/annotations/scanvp_candview_relangles.json", "r", encoding="utf-8") as f:
    scanvp_cands = json.load(f)

VIEWS_DIR = "/nas-ssd/jialu/datasets/views_img"
GENERATED_DIR = "/nas-ssd/jialu/datasets/views_img_sd_tuneip2p_future_val_cand_all"
IMG_SIZE = (512, 512)


def _load_rgb(path, size=IMG_SIZE):
    """Load the image at *path* as an RGB float64 array resized to *size*.

    The float cast matters: PIL images become uint8 arrays, and subtracting or
    squaring uint8 arrays wraps modulo 256, silently corrupting distances.
    The file handle is closed once the pixels have been read.
    """
    with Image.open(path) as img:
        return np.asarray(img.convert("RGB").resize(size), dtype=np.float64)


def _l2_distance(a, b):
    """Return the Euclidean distance between two equal-shape arrays."""
    # With ord=None and axis=None, norm() is the 2-norm of the flattened array.
    return float(np.linalg.norm(a - b))


output_list = []  # currently unused in this script

# Optional evaluation window (see the commented-out loop header below).
start_index = 0
end_index = 1200
print(len(train_data))
correct = 0
total = 0

# for i, data in enumerate(train_data[start_index:end_index]):
for i, data in enumerate(train_data):
    print("Processing:", i, len(train_data), correct, total)
    scan = data["scan"]
    path = data["path"]
    instructions = data["instructions"]
    if not instructions:
        # No (instruction, step) pairs to score for this item.
        continue
    for j in range(len(path) - 1):
        vp = path[j]
        target_vp = path[j + 1]
        key = "%s_%s" % (scan, vp)
        cands = scanvp_cands[key]
        if target_vp not in cands:
            raise KeyError("ground-truth next viewpoint %s is not a candidate of %s" % (target_vp, key))
        candidates = list(cands)

        # Candidate observations depend only on (scan, vp): load them once per
        # step instead of once per instruction and per generated image.
        cand_imgs = [
            _load_rgb(os.path.join(VIEWS_DIR, scan, cand_vp, str(cands[cand_vp][0]) + ".jpg"))
            for cand_vp in candidates
        ]

        for k in range(len(instructions)):
            # Generated images for this (instruction k, step j); missing files are skipped.
            gen_imgs = []
            for c in range(len(candidates)):
                gen_path = os.path.join(GENERATED_DIR, "%s_%s_%s_%s.jpg" % (data["path_id"], k, j, c))
                if os.path.exists(gen_path):
                    gen_imgs.append(_load_rgb(gen_path))

            # Score each candidate by its summed L2 distance to every generated
            # image; the lowest score (first one on ties) is the prediction.
            # NOTE(review): generated image c is compared against *all*
            # candidates, not only candidate c -- confirm this is intended.
            best_score = float("inf")
            predicted = None
            for cand_vp, cand_img in zip(candidates, cand_imgs):
                score = sum(_l2_distance(gen_img, cand_img) for gen_img in gen_imgs)
                if score < best_score:
                    best_score = score
                    predicted = cand_vp

            if predicted == target_vp:
                correct += 1
            total += 1

print(correct, total)
if total:
    print(correct / total)
else:
    print("No (instruction, step) pairs were evaluated.")

