# Process/third-party configuration. These statements run at module import
# time, before the project modules (utils.*) further below are imported.
import os
# NOTE(review): PYTHONUTF8 is read by the interpreter at startup, so setting it
# here most likely only affects child processes, not this one — confirm intent.
os.environ["PYTHONUTF8"] = "1"
import ir_datasets.util.download as download
# Overrides a private flag of ir_datasets' download module; judging by the name
# it keeps temporary download files. NOTE(review): private API — confirm it
# still exists in the installed ir_datasets version.
download._CLEANUP_TMP = False

# (redundant: os is already imported above)
import os
# NOTE(review): this is set *after* ir_datasets has already been imported above;
# confirm the variable is read lazily (or only by project code such as
# utils.datasets) rather than at ir_datasets import time.
os.environ["IR_DATASETS_CACHE"] = "D:/ir_datasets_cache"


import torch
import numpy as np

import utils.const as C

import utils.datasets as UD
import utils.embeddings as UE
import utils.distance as D
import utils.helpers as UH

def _embedding_path(encoder_name, kind, n_docs, name):
    """Return the cache path of the ``kind`` tensor file for bucket ``n_docs`` of dataset ``name``."""
    return f"cache/{encoder_name}/{kind}/{n_docs}/{name}.pt"


def _load_bucket(encoder_name, name, n_docs):
    """Load the cached tensors of one bucket onto ``C.DEVICE``.

    Returns ``(q_embs, tl_embs, centroid_embs, sf_embs)``. The centroid file
    holds a pair whose second element is discarded; score fields are cast to
    float.
    """
    def load(kind):
        return torch.load(_embedding_path(encoder_name, kind, n_docs, name), map_location=C.DEVICE)

    q_embs = load("q_embs")
    tl_embs = load("tl_embs")
    centroid_embs, _ = load("centroids")
    sf_embs = load("score_fields").float()
    return q_embs, tl_embs, centroid_embs, sf_embs


def _pairwise_ratios(bucket_n, bucket_m, same_bucket, sqrt_num_points):
    """Compute the K1/K2/K3 ratios for every (query i of bucket n, query j of bucket m).

    When ``same_bucket`` is true, a query is not paired with itself (i == j).
    Returns three lists of Python floats in row-major (i, j) order.
    """
    q_n, tl_n, cent_n, sf_n = bucket_n
    q_m, tl_m, cent_m, sf_m = bucket_m
    k1, k2, k3 = [], [], []
    for i in range(q_n.shape[0]):
        for j in range(q_m.shape[0]):
            if same_bucket and i == j:
                continue
            q_dist = D.calculate_embedding_distance(q_n[i], q_m[j])
            tl_dist = D.calculate_token_level_distance(tl_n[i], tl_m[j])
            centroid_dist = D.calculate_embedding_distance(cent_n[i], cent_m[j])
            sf_dist = D.calculate_score_field_distance(sf_n[i], sf_m[j], sqrt_num_points)
            # NOTE(review): a zero q_dist / tl_dist (e.g. duplicate queries)
            # presumably yields inf/nan here rather than an error — confirm.
            k1.append((centroid_dist / q_dist).item())
            k2.append((sf_dist / q_dist).item())
            k3.append((sf_dist / tl_dist).item())
    return k1, k2, k3


def load_datasets():
    """Compute and cache K1/K2/K3 distance ratios for every dataset in ``C.IR_DATASETS``.

    Queries are grouped into buckets by the keys of ``UD.get_qrels_by_len``.
    For every ordered pair of buckets (n, m) whose cached query embeddings
    exist, and for every query pair across them, the following ratios are
    computed:

        K1 = centroid distance    / query-embedding distance
        K2 = score-field distance / query-embedding distance
        K3 = score-field distance / token-level distance

    Each result dict maps ``(n_docs, m_docs)`` to a float64 array. The dicts
    are saved with ``UH.save_dicts_to_cache`` under ``<encoder>/K1``, ``/K2``
    and ``/K3``. Only after saving are the dataset's score-field cache files
    deleted.
    """
    sqrt_num_points = int(np.sqrt(C.NUM_POINTS))
    # Take the part after "org/" of a hub-style name; fall back to the whole
    # name when there is no "/" (indexing [1] used to raise IndexError).
    name_parts = C.ENCODER_NAME.split("/")
    encoder_name = name_parts[1] if len(name_parts) > 1 else name_parts[0]

    for dataset_key in C.IR_DATASETS:
        name = "_".join(dataset_key.split("/"))
        qrels, query_texts, doc_texts = UD.get_dataset(dataset_key, name)
        qrels_by_len = UD.get_qrels_by_len(qrels)
        UE.get_embeddings(qrels, query_texts, doc_texts, name)

        # Normalize bucket keys to int once, so cache paths are built the same
        # way for loading and for the cleanup at the end.
        num_docs = [int(k) for k in sorted(qrels_by_len.keys())]
        print(num_docs)

        K1s, K2s, K3s = {}, {}, {}
        for n_docs in num_docs:
            # Load bucket n once per outer iteration instead of once per (n, m).
            bucket_n = None
            if os.path.isfile(_embedding_path(encoder_name, "q_embs", n_docs, name)):
                bucket_n = _load_bucket(encoder_name, name, n_docs)

            for m_docs in num_docs:
                print(n_docs, m_docs)
                if bucket_n is None:
                    continue
                if not os.path.isfile(_embedding_path(encoder_name, "q_embs", m_docs, name)):
                    continue
                same_bucket = n_docs == m_docs
                bucket_m = bucket_n if same_bucket else _load_bucket(encoder_name, name, m_docs)

                k1, k2, k3 = _pairwise_ratios(bucket_n, bucket_m, same_bucket, sqrt_num_points)
                key = (n_docs, m_docs)
                # Accumulate in Python lists: np.append copies on every call.
                K1s.setdefault(key, []).extend(k1)
                K2s.setdefault(key, []).extend(k2)
                K3s.setdefault(key, []).extend(k3)

        # Convert once to float64 arrays, the type the previous np.append-based
        # accumulation produced.
        K1s = {k: np.asarray(v, dtype=np.float64) for k, v in K1s.items()}
        K2s = {k: np.asarray(v, dtype=np.float64) for k, v in K2s.items()}
        K3s = {k: np.asarray(v, dtype=np.float64) for k, v in K3s.items()}

        # import utils.plot as UP
        # UP.plot_median_heatmap(K1s, f"K1_{name}")
        # UP.plot_median_logplot(K1s, f"K1_{name}")

        # UP.plot_median_heatmap(K2s, f"K2_{name}")
        # UP.plot_median_logplot(K2s, f"K2_{name}")

        # UP.plot_median_heatmap(K3s, f"K3_{name}")
        # UP.plot_median_logplot(K3s, f"K3_{name}")
        UH.save_dicts_to_cache(UH.encode_keys(K1s), f"{encoder_name}/K1", f"{name}.json")
        UH.save_dicts_to_cache(UH.encode_keys(K2s), f"{encoder_name}/K2", f"{name}.json")
        UH.save_dicts_to_cache(UH.encode_keys(K3s), f"{encoder_name}/K3", f"{name}.json")

        # Remove score-field caches only after the results are saved. Buckets
        # skipped above may have no score-field file, so a missing file must not
        # abort the run.
        for n_docs in num_docs:
            try:
                os.remove(_embedding_path(encoder_name, "score_fields", n_docs, name))
            except FileNotFoundError:
                pass

def main():
    """Script entry point: delegates all work to ``load_datasets``."""
    load_datasets()


if __name__ == "__main__":
    # Run the file-writing pipeline only when executed directly, not on import.
    main()