import os
import numpy as np

import utils.const as C

import utils.helpers as UH

import utils.plot as UP


# def main():

#     # all-MiniLM-L6-v2/K2/beir_climate-fever.json
#     encoders = [
#         "sentence-transformers/all-MiniLM-L6-v2",
#         "sentence-transformers/all-mpnet-base-v2",
#         "sentence-transformers/all-roberta-large-v1",
#     ]
#     for K_path in ["K1", "K2", "K3"]:
#         G = {}
#         for encoder in encoders:
#             emb_name = encoder.split("/")[1]
#             base_path = f"cache/{emb_name}"

            

#             for file in os.listdir(f"{base_path}/{K_path}"):

#                 K = UH.load_from_cache(f"{emb_name}/{K_path}", file)

#                 file_name = file.split("/")[-1]

#                 K = UH.bin_tuple_dict_log_to_tuple(K)

#                 M = UH.get_median_Ks(K)

#                 # delta = UH.get_tuple_deltas_from_medians(M)

#                 glob = UH.get_mean_deltas_diagonal(M)

#                 for k, v in glob.items():

#                     if k not in G:

#                         G[k] = []

#                     G[k].append(v)

#         G = UH.get_median_Ks(G)

#         UP.plot_individ_median_logplot_with_fit(G, f"deltas/{K_path}_all")

# Sentence-transformer checkpoints whose cached values are pooled together.
_DEFAULT_ENCODERS = (
    "sentence-transformers/all-MiniLM-L6-v2",
    "sentence-transformers/all-mpnet-base-v2",
    "sentence-transformers/all-roberta-large-v1",
)
# Cache sub-directories processed independently (one plot per entry).
_DEFAULT_K_PATHS = ("K1", "K2", "K3")


def main(encoders=_DEFAULT_ENCODERS, k_paths=_DEFAULT_K_PATHS):
    """Pool cached values across encoders and plot them, once per K path.

    For each ``K_path`` every cache file of every encoder is loaded and its
    values are merged key-by-key into one dict, dropping NaN and +/-inf
    entries. The pooled dict is then passed through
    ``UH.bin_tuple_dict_log_to_tuple``, ``UH.diagonalize``,
    ``UH.get_median_Ks`` and ``UH.get_mean_deltas_diagonal``. The result is
    plotted with ``UP.plot_individ_median_logplot_with_fit_exponential`` under
    the name ``deltas/<K_path>_all_v2``.

    Files are listed from ``cache/<model>/<K_path>/`` (e.g.
    ``all-MiniLM-L6-v2/K2/beir_climate-fever.json``) and loaded via
    ``UH.load_from_cache("<model>/<K_path>", file)``. The loader presumably
    resolves paths relative to ``cache/``; confirm against ``utils.helpers``.

    Args:
        encoders: model ids of the form "org/model". Only the last path
            component is used as the cache directory name.
        k_paths: cache sub-directory names to process.
    """
    for K_path in k_paths:
        G = {}
        for encoder in encoders:
            # "org/model" -> "model"; [-1] also tolerates ids without an org.
            emb_name = encoder.split("/")[-1]
            base_path = f"cache/{emb_name}"
            # os.listdir order is arbitrary; sort for reproducible pooling.
            for file in sorted(os.listdir(f"{base_path}/{K_path}")):
                K = UH.load_from_cache(f"{emb_name}/{K_path}", file)
                for k, v in K.items():
                    # The key is created even if all of its values are
                    # non-finite; only NaN/+-inf samples are discarded.
                    G.setdefault(k, []).extend(a for a in v if np.isfinite(a))
        D = UH.bin_tuple_dict_log_to_tuple(G)
        D = UH.diagonalize(D)
        # NOTE: a max-based variant can be built with UH.get_max_Ks(D) at this
        # point; it is not plotted, so it is no longer computed here.
        D = UH.get_median_Ks(D)
        D = UH.get_mean_deltas_diagonal(D)
        # NOTE: UP.plot_individ_median_logplot_with_fit_linear is the linear-fit
        # alternative to the exponential fit used below.
        UP.plot_individ_median_logplot_with_fit_exponential(D, f"deltas/{K_path}_all_v2")


if __name__ == "__main__":
    # Run the pooling/plotting pipeline only when executed as a script,
    # not when this module is imported.
    main()
