import os

import torch
import numpy as np

import utils.const as C
import utils.helpers as UH
import utils.plot as UP

def _finite_values(tensor):
    """Return the finite entries of ``tensor`` as a flat Python list.

    NaN and +/-inf entries are dropped; the boolean mask flattens the result.
    """
    arr = tensor.detach().cpu().numpy()
    return arr[np.isfinite(arr)].tolist()


def _collect_mean_errors(encoder_name, K_path):
    """Gather finite mean errors per document count for one encoder.

    For every file in ``cache/<encoder_name>/<K_path>`` the stem before the
    first ``.`` is looked up as ``<stem>.pt`` inside each numeric sub-folder
    of ``cache/<encoder_name>/centroids``. Each matching file is loaded with
    ``torch.load`` and is expected to hold a pair whose second item unpacks
    into ``(mean_errors, min_errors, max_errors)``.

    Returns:
        dict mapping ``(n_docs, n_docs)`` to a list of finite mean-error
        values, where ``n_docs`` is the integer folder name.
    """
    centroids_root = f"cache/{encoder_name}/centroids"

    # Scan the centroid tree once instead of once per K file: the listing
    # does not depend on the file being processed.
    folders = []
    for folder in os.listdir(centroids_root):
        folder_path = f"{centroids_root}/{folder}"
        try:
            n_docs = int(folder)
        except ValueError:
            # Stray entries such as ".DS_Store" are not document counts.
            continue
        if not os.path.isdir(folder_path):
            continue
        folders.append((n_docs, folder_path, set(os.listdir(folder_path))))

    means = {}
    for file in os.listdir(f"cache/{encoder_name}/{K_path}"):
        target = f"{file.split('.')[0]}.pt"
        for n_docs, folder_path, files in folders:
            if target not in files:
                continue
            _, errors = torch.load(
                f"{folder_path}/{target}", map_location=C.DEVICE
            )
            # Only the mean errors feed the plot; the unpack still checks
            # that the stored triple has the expected length.
            mean_errors, _min_errors, _max_errors = errors
            means.setdefault((n_docs, n_docs), []).extend(
                _finite_values(mean_errors)
            )
    return means


def main():
    """Plot cached centroid errors for each sentence-transformer encoder.

    For every encoder the mean errors are collected from the on-disk cache,
    passed through ``UH.bin_tuple_dict_log_to_tuple``, ``UH.get_median_Ks``
    and ``UH.get_mean_deltas_diagonal``, and the result is plotted with
    ``UP.plot_individ_median_logplot_with_fit_linear`` under
    ``plots/errors/<encoder_name>``.
    """
    K_path = "K3"
    encoders = [
        "sentence-transformers/all-MiniLM-L6-v2",
        "sentence-transformers/all-mpnet-base-v2",
        "sentence-transformers/all-roberta-large-v1",
    ]
    for encoder in encoders:
        # Cache and plot folders are keyed by the model name without the org.
        encoder_name = encoder.split("/")[1]

        means = _collect_mean_errors(encoder_name, K_path)
        means = UH.bin_tuple_dict_log_to_tuple(means)
        G = UH.get_median_Ks(means)
        G = UH.get_mean_deltas_diagonal(G)

        # The (1, 1) entry is excluded from the plot; tolerate caches that
        # have no single-document folder instead of raising KeyError.
        # NOTE(review): assumes G supports `in` and `del` like a dict.
        if (1, 1) in G:
            del G[(1, 1)]

        path = f"errors/{encoder_name}"
        os.makedirs(f"plots/{path}", exist_ok=True)
        UP.plot_individ_median_logplot_with_fit_linear(G, f"{path}/errors")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()