import json
import jsonlines
from collections import defaultdict

# results = []
# files = ['submit_test_challenge_public', 'submit_val_train_seen', 'submit_test_standard_public',   'submit_val_unseen', 'submit_val_seen']
# for file in files:
#     with open('../datasets/RxR/trained_models/valid_ensemble/preds/%s.json' % file, 'r') as f:
#         results = json.load(f)
#     f.close()
#
#     processed_results = []
#     for i, result in enumerate(results):
#         p_result = dict()
#         p_result['instruction_id'] = int(result['instr_id'].split("_")[-1])
#         # p_result['instruction_id'] = int(result['instr_id'])
#         p_result['path'] = [traj[0] for traj in result['trajectory']]
#         processed_results.append(p_result)
#
#     with jsonlines.open('../datasets/RxR/trained_models/valid_ensemble/preds/%s_processed.json' % file, mode='w') as writer:
#         writer.write_all(processed_results)



# results = []
# with open('snap/valid_clip/submit_test.json', 'r') as f:
#     results = json.load(f)
# f.close()
#
# for i, result in enumerate(results):
#     del results[i]['path_id']
#
# json.dump(
#                 results,
#                 open('snap/valid_clip/submit_test_processed_clip.json', 'w'),
#                 sort_keys=True, indent=4, separators=(',', ': ')
#             )
# processed_results = []
# for i, result in enumerate(results):
#     p_result = dict()
#     p_result['instruction_id'] = result['instr_id']
#     p_result['path'] = [traj[0] for traj in result['trajectory']]
#     processed_results.append(p_result)
#
# with jsonlines.open('snap/valid/submit_test_processed.jsonl', mode='w') as writer:
#     writer.write_all(processed_results)

# with open("/nas-hdd/jialu/VLN-Diffusion/datasets/R2R/exprs_map/finetune/r2r_test/preds/submit_test.json", "r") as f:
#     data = json.load(f)
# f.close()
#
# output = []
# for item in data:
#     new_item = dict()
#     instr_id = item['instr_id']
#     traj = item['trajectory']
#     new_item['instr_id'] = instr_id
#     traj_new = sum(traj, [])
#     new_item["trajectory"] = []
#     for traj in traj_new:
#         new_item["trajectory"].append([traj])
#     # if new_item["trajectory"][0] != item["trajectory"][0]:
#     #     print(item)
#     output.append(new_item)
#
# with open("/nas-hdd/jialu/VLN-Diffusion/datasets/R2R/exprs_map/finetune/r2r_test/preds/submit_test_processed.json", "w") as f:
#     json.dump(output, f)
# f.close()

# Convert the CVDN test predictions into a flattened layout in which every
# trajectory entry is a single-element list, then write the result next to
# the input file.
INPUT_PATH = "/nas-hdd/jialu/VLN-Diffusion/datasets/CVDN/exprs_map/finetune/cvdn_test/preds/submit_test.json"
OUTPUT_PATH = "/nas-hdd/jialu/VLN-Diffusion/datasets/CVDN/exprs_map/finetune/cvdn_test/preds/submit_test_processed.json"


def flatten_trajectory(traj):
    """Return *traj* with every step reduced to single-element lists.

    Args:
        traj: A sequence of steps, each step itself a sequence of entries
            (presumably viewpoint ids -- confirm against the prediction
            writer).

    Returns:
        A new list. Each non-final step holding more than one entry is
        expanded into one ``[entry]`` list per entry; non-final steps with
        zero or one entry are kept as they are. Only the first entry of the
        final step is kept, as ``[entry]``. An empty trajectory yields
        ``[]``, and an empty final step contributes nothing (both cases
        previously raised ``IndexError``).
    """
    if not traj:
        return []

    flat = []
    for step in traj[:-1]:
        if len(step) > 1:
            flat.extend([entry] for entry in step)
        else:
            # Zero- or one-entry steps are passed through untouched.
            flat.append(step)

    # Only the first entry of the final step is retained.
    final_step = traj[-1]
    if final_step:
        flat.append([final_step[0]])
    return flat


def convert_predictions(data):
    """Build the processed records for every prediction in *data*.

    Args:
        data: An iterable of dicts carrying ``'instr_id'`` and
            ``'trajectory'`` keys.

    Returns:
        A list of dicts with the id stored under ``'inst_idx'`` (presumably
        the key the CVDN evaluation expects -- confirm) and the flattened
        trajectory under ``'trajectory'``.
    """
    return [
        {
            "inst_idx": item["instr_id"],
            "trajectory": flatten_trajectory(item["trajectory"]),
        }
        for item in data
    ]


def main(input_path=INPUT_PATH, output_path=OUTPUT_PATH):
    """Read predictions from *input_path* and write them to *output_path*.

    Args:
        input_path: JSON file containing a list of prediction dicts.
        output_path: Destination JSON file for the processed records.
    """
    # The ``with`` blocks close the files; no explicit ``close()`` is needed.
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    output = convert_predictions(data)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f)


if __name__ == "__main__":
    main()