import argparse
import json
import os
import sys
import matplotlib.pyplot as plt
import numpy as np

# Root directory that holds every caption-dataset JSON referenced below.
_DATASET_ROOT = "/data/dataset/dataset_json"

# One row per caption source: (method-name template, path template relative
# to _DATASET_ROOT, indices substituted into both templates). Generating the
# two lists from a single table keeps names and files in the same order.
# The ground-truth file data_rewrite/flickr30k_train_caps=1.json
# ("cap_train") is not part of this list.
_CAPTION_SOURCES = (
    ("cap_top{}", "data_sorted/flickr30k_train_top{}.json", range(1, 5)),
    ("EDA0.3_{}", "EDA0.3/flickr30k_train_{}.json", range(1, 5)),
    (
        "llama_top{}",
        "data_rewrite/flickr30k_train_llama_coco_multi_top{}_limit_caps=1.json",
        range(1, 5),
    ),
    ("OFA_{}", "OFA-large-caption/flickr30k_train_{}.json", range(4)),
)

json_files_list = []
method_names = []
for _name_template, _path_template, _indices in _CAPTION_SOURCES:
    for _idx in _indices:
        json_files_list.append(f"{_DATASET_ROOT}/{_path_template.format(_idx)}")
        method_names.append(_name_template.format(_idx))

# Each dataset file has a sibling "<name>_metric.json" with per-entry metrics.
files_with_metric = [
    dataset_path.replace(".json", "_metric.json") for dataset_path in json_files_list
]

# Lookup from short method name to its metric file.
method_name2file = {
    name: metric_path for name, metric_path in zip(method_names, files_with_metric)
}


# Scatter-plot "alignment" against "bef_aft_l2dist" for every method, one
# series per metric file, and save the combined figure as plot_metric.png.
# Names and files are paired with zip() so no loop index is needed; the
# previous version reused `i` for both the outer and the inner loop.
for method_name, metric_file in zip(method_names, files_with_metric):
    # Missing metric files are reported and skipped rather than aborting.
    if not os.path.exists(metric_file):
        print(f"File {metric_file} does not exist")
        continue

    with open(metric_file, 'r') as f:
        metric_data = json.load(f)

    print(f"File {metric_file} has {len(metric_data)} entries")

    # Each entry is expected to carry these three keys; a missing key raises
    # KeyError. The "diff" field is read as the cosine-similarity value.
    bef_aft_l2dist_list = [entry["bef_aft_l2dist"] for entry in metric_data]
    cos_sim_list = [entry["diff"] for entry in metric_data]
    alignment_list = [entry["alignment"] for entry in metric_data]

    # Alternative x-axis (cosine similarity instead of L2 distance):
    # plt.scatter(cos_sim_list, alignment_list, label=method_name, alpha=0.5)
    plt.scatter(bef_aft_l2dist_list, alignment_list, label=method_name, alpha=0.5)

plt.legend()
plt.savefig("plot_metric.png")


# Quartiles of the "alignment" metric over the caps=1 metric file
# ("gt" presumably stands for ground truth -- confirm). These thresholds are
# used by the filtering steps later in the script.
gt_cap_file = "/data/dataset/dataset_json/data_rewrite/flickr30k_train_caps=1.json"
gt_cap_metric_file = gt_cap_file.replace(".json", "_metric.json")
with open(gt_cap_metric_file, 'r') as f:
    gt_cap_metric_data = json.load(f)

alignment_list = [record["alignment"] for record in gt_cap_metric_data]
# One np.percentile call per quartile, unpacked in ascending order.
alignment_25th, alignment_50th, alignment_75th = (
    np.percentile(alignment_list, quartile) for quartile in (25, 50, 75)
)

print(f"GT cap alignment 25th percentile: {alignment_25th}")
print(f"GT cap alignment 50th percentile: {alignment_50th}")
print(f"GT cap alignment 75th percentile: {alignment_75th}")

# Quartiles of the "bef_aft_l2dist" metric, taken from the idx1 metric file.
# These thresholds are used by the filtering steps later in the script.
cap_1_metric_file = "/data/dataset/dataset_json/data_idx/flickr30k_train_idx1_metric.json"
with open(cap_1_metric_file, 'r') as f:
    cap_1_metric_data = json.load(f)

bef_aft_l2dist_list = [record["bef_aft_l2dist"] for record in cap_1_metric_data]
# One np.percentile call per quartile, unpacked in ascending order.
dist_25th, dist_50th, dist_75th = (
    np.percentile(bef_aft_l2dist_list, quartile) for quartile in (25, 50, 75)
)
print(f"Dist 25th percentile: {dist_25th}")
print(f"Dist 50th percentile: {dist_50th}")
print(f"Dist 75th percentile: {dist_75th}")

# cos sim percentile
# cos_sim_list = [entry["diff"] for entry in cap_1_metric_data]
# cos_sim_25th = np.percentile(cos_sim_list, 25)
# cos_sim_50th = np.percentile(cos_sim_list, 50)
# cos_sim_75th = np.percentile(cos_sim_list, 75)
# print(f"Cos sim 25th percentile: {cos_sim_25th}")
# print(f"Cos sim 50th percentile: {cos_sim_50th}")
# print(f"Cos sim 75th percentile: {cos_sim_75th}")


# Split every entry of the top1..top4 caption metric files around the two
# medians computed earlier in the script (alignment_50th, dist_50th).

# Single-metric splits: at-or-above the median vs. below it.
caps_alignOver50th = []
caps_alignUnder50th = []
caps_distOver50th = []
caps_distUnder50th = []
# Joint splits: one bucket per (alignment, distance) quadrant.
caps_alignOver50th_distOver50th = []
caps_alignOver50th_distUnder50th = []
caps_alignUnder50th_distOver50th = []
caps_alignUnder50th_distUnder50th = []

caps_files = [
    "/data/dataset/dataset_json/data_sorted/flickr30k_train_top1.json",
    "/data/dataset/dataset_json/data_sorted/flickr30k_train_top2.json",
    "/data/dataset/dataset_json/data_sorted/flickr30k_train_top3.json",
    "/data/dataset/dataset_json/data_sorted/flickr30k_train_top4.json",
]

for cap_file in caps_files:
    cap_metric_file = cap_file.replace(".json", "_metric.json")
    with open(cap_metric_file, 'r') as f:
        cap_metric_data = json.load(f)

    for record in cap_metric_data:
        l2_dist = record["bef_aft_l2dist"]
        align_score = record["alignment"]

        # Both directions are evaluated explicitly: a value for which neither
        # comparison holds (e.g. NaN) must land in the final fallback bucket.
        align_over = align_score >= alignment_50th
        align_under = align_score < alignment_50th
        dist_over = l2_dist >= dist_50th
        dist_under = l2_dist < dist_50th

        # Single-metric buckets: anything not "over" counts as "under".
        (caps_alignOver50th if align_over else caps_alignUnder50th).append(record)
        (caps_distOver50th if dist_over else caps_distUnder50th).append(record)

        # Joint bucket: pick the quadrant, then append once.
        if align_over and dist_over:
            quadrant = caps_alignOver50th_distOver50th
        elif align_over and dist_under:
            quadrant = caps_alignOver50th_distUnder50th
        elif align_under and dist_over:
            quadrant = caps_alignUnder50th_distOver50th
        else:
            quadrant = caps_alignUnder50th_distUnder50th
        quadrant.append(record)

# Report the size of each joint (alignment x distance) quadrant.
print(f"caps_alignOver50th_distOver50th: {len(caps_alignOver50th_distOver50th)}")
print(f"caps_alignOver50th_distUnder50th: {len(caps_alignOver50th_distUnder50th)}")
print(f"caps_alignUnder50th_distOver50th: {len(caps_alignUnder50th_distOver50th)}")
print(f"caps_alignUnder50th_distUnder50th: {len(caps_alignUnder50th_distUnder50th)}")

# Truncate all four quadrants to the size of the smallest one so the exported
# sets have equal length. Slicing keeps the earliest entries in load order.
min_len = min(
    len(caps_alignOver50th_distOver50th),
    len(caps_alignOver50th_distUnder50th),
    len(caps_alignUnder50th_distOver50th),
    len(caps_alignUnder50th_distUnder50th),
)
print(f"min_len: {min_len}")
caps_alignOver50th_distOver50th = caps_alignOver50th_distOver50th[:min_len]
caps_alignOver50th_distUnder50th = caps_alignOver50th_distUnder50th[:min_len]
caps_alignUnder50th_distOver50th = caps_alignUnder50th_distOver50th[:min_len]
caps_alignUnder50th_distUnder50th = caps_alignUnder50th_distUnder50th[:min_len]

# Named output_dir rather than `dir` so the builtin dir() is not shadowed.
output_dir = "/data/dataset/dataset_json/data_filtered"
os.makedirs(output_dir, exist_ok=True)
for suffix, bucket in (
    ("alignOver50th_distOver50th", caps_alignOver50th_distOver50th),
    ("alignOver50th_distUnder50th", caps_alignOver50th_distUnder50th),
    ("alignUnder50th_distOver50th", caps_alignUnder50th_distOver50th),
    ("alignUnder50th_distUnder50th", caps_alignUnder50th_distUnder50th),
):
    with open(f"{output_dir}/caps1234_{suffix}.json", 'w') as f:
        json.dump(bucket, f, indent=4)

# Same procedure for the four single-metric buckets.
print(f"caps_alignOver50th: {len(caps_alignOver50th)}")
print(f"caps_alignUnder50th: {len(caps_alignUnder50th)}")
print(f"caps_distOver50th: {len(caps_distOver50th)}")
print(f"caps_distUnder50th: {len(caps_distUnder50th)}")

min_len = min(
    len(caps_alignOver50th),
    len(caps_alignUnder50th),
    len(caps_distOver50th),
    len(caps_distUnder50th),
)
print(f"min_len: {min_len}")
# The in-memory lists are left untruncated here; only the exported copies
# are cut to min_len.
for suffix, bucket in (
    ("alignOver50th", caps_alignOver50th),
    ("alignUnder50th", caps_alignUnder50th),
    ("distOver50th", caps_distOver50th),
    ("distUnder50th", caps_distUnder50th),
):
    with open(f"{output_dir}/caps1234_{suffix}.json", 'w') as f:
        json.dump(bucket[:min_len], f, indent=4)

# Scatter the four joint quadrants on a fresh figure, one series per
# quadrant, and save the result as plot_cap_filtered.png.
plt.figure()
for quadrant_label, quadrant_entries in (
    ("alignOver50th_distOver50th", caps_alignOver50th_distOver50th),
    ("alignOver50th_distUnder50th", caps_alignOver50th_distUnder50th),
    ("alignUnder50th_distOver50th", caps_alignUnder50th_distOver50th),
    ("alignUnder50th_distUnder50th", caps_alignUnder50th_distUnder50th),
):
    x_values = [item["bef_aft_l2dist"] for item in quadrant_entries]
    y_values = [item["alignment"] for item in quadrant_entries]
    plt.scatter(x_values, y_values, label=quadrant_label, alpha=0.5)
plt.legend()
plt.savefig("plot_cap_filtered.png")



# # filter all by alignment / dist percentile
# all_alignOver50th = []
# all_alignUnder50th = []
# all_distOver50th = []
# all_distUnder50th = []
# all_alignOver50th_distOver50th = []
# all_alignOver50th_distUnder50th = []
# all_alignUnder50th_distOver50th = []
# all_alignUnder50th_distUnder50th = []

# json_files_list = [
#     "/data/dataset/dataset_json/data_sorted/flickr30k_train_top1.json",
#     "/data/dataset/dataset_json/data_sorted/flickr30k_train_top2.json",
#     "/data/dataset/dataset_json/data_sorted/flickr30k_train_top3.json",
#     "/data/dataset/dataset_json/data_sorted/flickr30k_train_top4.json",
#     "/data/dataset/dataset_json/EDA0.3/flickr30k_train_1.json",
#     "/data/dataset/dataset_json/EDA0.3/flickr30k_train_2.json",
#     "/data/dataset/dataset_json/EDA0.3/flickr30k_train_3.json",
#     "/data/dataset/dataset_json/EDA0.3/flickr30k_train_4.json",
#     "/data/dataset/dataset_json/data_rewrite/flickr30k_train_llama_coco_multi_top1_limit_caps=1.json",
#     "/data/dataset/dataset_json/data_rewrite/flickr30k_train_llama_coco_multi_top2_limit_caps=1.json",
#     "/data/dataset/dataset_json/data_rewrite/flickr30k_train_llama_coco_multi_top3_limit_caps=1.json",
#     "/data/dataset/dataset_json/data_rewrite/flickr30k_train_llama_coco_multi_top4_limit_caps=1.json",
#     "/data/dataset/dataset_json/OFA-large-caption/flickr30k_train_0.json",
#     "/data/dataset/dataset_json/OFA-large-caption/flickr30k_train_1.json",
#     "/data/dataset/dataset_json/OFA-large-caption/flickr30k_train_2.json",
#     "/data/dataset/dataset_json/OFA-large-caption/flickr30k_train_3.json",
#     # "/data/dataset/dataset_json/data_stableDiffusion/many-to-one_cap=0_useCap0_imgGenIdx0_metric.json"
# ]

# for json_file in json_files_list:
#     json_metric_file = json_file.replace(".json", "_metric.json")
#     with open(json_metric_file, 'r') as f:
#         json_metric_data = json.load(f)

#     for i, entry in enumerate(json_metric_data):
#         bef_aft_l2dist = entry["bef_aft_l2dist"]
#         alignment = entry["alignment"]

#         if alignment >= alignment_50th:
#             all_alignOver50th.append(entry)
#         else:
#             all_alignUnder50th.append(entry)

#         if bef_aft_l2dist >= dist_50th:
#             all_distOver50th.append(entry)
#         else:
#             all_distUnder50th.append(entry)

#         if alignment >= alignment_50th and bef_aft_l2dist >= dist_50th:
#             all_alignOver50th_distOver50th.append(entry)
#         elif alignment >= alignment_50th and bef_aft_l2dist < dist_50th:
#             all_alignOver50th_distUnder50th.append(entry)
#         elif alignment < alignment_50th and bef_aft_l2dist >= dist_50th:
#             all_alignUnder50th_distOver50th.append(entry)
#         else:
#             all_alignUnder50th_distUnder50th.append(entry)

# print(f"all_alignOver50th_distOver50th: {len(all_alignOver50th_distOver50th)}")
# print(f"all_alignOver50th_distUnder50th: {len(all_alignOver50th_distUnder50th)}")
# print(f"all_alignUnder50th_distOver50th: {len(all_alignUnder50th_distOver50th)}")
# print(f"all_alignUnder50th_distUnder50th: {len(all_alignUnder50th_distUnder50th)}")

# min_len = min(len(all_alignOver50th_distOver50th), len(all_alignOver50th_distUnder50th), len(all_alignUnder50th_distOver50th), len(all_alignUnder50th_distUnder50th))
# print(f"min_len: {min_len}")
# all_alignOver50th_distOver50th = all_alignOver50th_distOver50th[:min_len]
# all_alignOver50th_distUnder50th = all_alignOver50th_distUnder50th[:min_len]
# all_alignUnder50th_distOver50th = all_alignUnder50th_distOver50th[:min_len]
# all_alignUnder50th_distUnder50th = all_alignUnder50th_distUnder50th[:min_len]

# dir = "/data/dataset/dataset_json/data_filtered"
# os.makedirs(dir, exist_ok=True)
# with open(f"{dir}/all_alignOver50th_distOver50th.json", 'w') as f:
#     json.dump(all_alignOver50th_distOver50th, f, indent=4)
# with open(f"{dir}/all_alignOver50th_distUnder50th.json", 'w') as f:
#     json.dump(all_alignOver50th_distUnder50th, f, indent=4)
# with open(f"{dir}/all_alignUnder50th_distOver50th.json", 'w') as f:
#     json.dump(all_alignUnder50th_distOver50th, f, indent=4)
# with open(f"{dir}/all_alignUnder50th_distUnder50th.json", 'w') as f:
#     json.dump(all_alignUnder50th_distUnder50th, f, indent=4)

# print(f"all_alignOver50th: {len(all_alignOver50th)}")
# print(f"all_alignUnder50th: {len(all_alignUnder50th)}")
# print(f"all_distOver50th: {len(all_distOver50th)}")
# print(f"all_distUnder50th: {len(all_distUnder50th)}")

# min_len = min(len(all_alignOver50th), len(all_alignUnder50th), len(all_distOver50th), len(all_distUnder50th))
# print(f"min_len: {min_len}")
# with open(f"{dir}/all_alignOver50th.json", 'w') as f:
#     json.dump(all_alignOver50th[:min_len], f, indent=4)
# with open(f"{dir}/all_alignUnder50th.json", 'w') as f:
#     json.dump(all_alignUnder50th[:min_len], f, indent=4)
# with open(f"{dir}/all_distOver50th.json", 'w') as f:
#     json.dump(all_distOver50th[:min_len], f, indent=4)
# with open(f"{dir}/all_distUnder50th.json", 'w') as f:
#     json.dump(all_distUnder50th[:min_len], f, indent=4)

# # plot
# plt.figure()
# plt.scatter([entry["bef_aft_l2dist"] for entry in all_alignOver50th_distOver50th], [entry["alignment"] for entry in all_alignOver50th_distOver50th], label="alignOver50th_distOver50th", alpha=0.5)
# plt.scatter([entry["bef_aft_l2dist"] for entry in all_alignOver50th_distUnder50th], [entry["alignment"] for entry in all_alignOver50th_distUnder50th], label="alignOver50th_distUnder50th", alpha=0.5)
# plt.scatter([entry["bef_aft_l2dist"] for entry in all_alignUnder50th_distOver50th], [entry["alignment"] for entry in all_alignUnder50th_distOver50th], label="alignUnder50th_distOver50th", alpha=0.5)
# plt.scatter([entry["bef_aft_l2dist"] for entry in all_alignUnder50th_distUnder50th], [entry["alignment"] for entry in all_alignUnder50th_distUnder50th], label="alignUnder50th_distUnder50th", alpha=0.5)
# plt.legend()
# plt.savefig("plot_all_filtered.png")


# Collect, across every method's metric file, the entries whose alignment
# AND before/after L2 distance both fall strictly between the 25th and 75th
# percentile thresholds computed earlier in the script, then export them.
all_align25_75_dist_25_75 = []

json_files_list = [
    "/data/dataset/dataset_json/data_sorted/flickr30k_train_top1.json",
    "/data/dataset/dataset_json/data_sorted/flickr30k_train_top2.json",
    "/data/dataset/dataset_json/data_sorted/flickr30k_train_top3.json",
    "/data/dataset/dataset_json/data_sorted/flickr30k_train_top4.json",
    "/data/dataset/dataset_json/EDA0.3/flickr30k_train_1.json",
    "/data/dataset/dataset_json/EDA0.3/flickr30k_train_2.json",
    "/data/dataset/dataset_json/EDA0.3/flickr30k_train_3.json",
    "/data/dataset/dataset_json/EDA0.3/flickr30k_train_4.json",
    "/data/dataset/dataset_json/data_rewrite/flickr30k_train_llama_coco_multi_top1_limit_caps=1.json",
    "/data/dataset/dataset_json/data_rewrite/flickr30k_train_llama_coco_multi_top2_limit_caps=1.json",
    "/data/dataset/dataset_json/data_rewrite/flickr30k_train_llama_coco_multi_top3_limit_caps=1.json",
    "/data/dataset/dataset_json/data_rewrite/flickr30k_train_llama_coco_multi_top4_limit_caps=1.json",
    "/data/dataset/dataset_json/OFA-large-caption/flickr30k_train_0.json",
    "/data/dataset/dataset_json/OFA-large-caption/flickr30k_train_1.json",
    "/data/dataset/dataset_json/OFA-large-caption/flickr30k_train_2.json",
    "/data/dataset/dataset_json/OFA-large-caption/flickr30k_train_3.json",
    # "/data/dataset/dataset_json/data_stableDiffusion/many-to-one_cap=0_useCap0_imgGenIdx0_metric.json"
]

for json_file in json_files_list:
    json_metric_file = json_file.replace(".json", "_metric.json")
    with open(json_metric_file, 'r') as f:
        json_metric_data = json.load(f)

    # Entries of this file that fall inside both inter-quartile bands.
    kept_entries = []
    for entry in json_metric_data:
        bef_aft_l2dist = entry["bef_aft_l2dist"]
        alignment = entry["alignment"]

        # Strict inequalities: values equal to a threshold are excluded.
        if (alignment_25th < alignment < alignment_75th
                and dist_25th < bef_aft_l2dist < dist_75th):
            kept_entries.append(entry)

    all_align25_75_dist_25_75.extend(kept_entries)
    n = len(kept_entries)
    print(f"File {json_file} has {n} entries in 25-75 alignment and 25-75 dist")

print(f"all_align25_75_dist_25_75: {len(all_align25_75_dist_25_75)}")

# Named output_dir rather than `dir` so the builtin dir() is not shadowed.
output_dir = "/data/dataset/dataset_json/data_filtered"
os.makedirs(output_dir, exist_ok=True)
with open(f"{output_dir}/all_align25_75_dist_25_75.json", 'w') as f:
    json.dump(all_align25_75_dist_25_75, f, indent=4)

# Scatter the entries kept by the 25-75 percentile band filter on a fresh
# figure and save it as plot_all_filtered.png. The legend label now names
# the plotted set; it previously carried a label copied from the quadrant
# plot ("alignOver50th_distOver50th").
plt.figure()
plt.scatter(
    [entry["bef_aft_l2dist"] for entry in all_align25_75_dist_25_75],
    [entry["alignment"] for entry in all_align25_75_dist_25_75],
    label="align25_75_dist_25_75",
    alpha=0.5,
)
plt.legend()
plt.savefig("plot_all_filtered.png")

