import json
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
from numpy.linalg import norm
from collections import Counter
import os
import torch
import copy
import glob
import random
import tqdm


def distance_score(embed1, embed2, budget1=None, budget2=None):
    """Return the distance between two embeddings and, optionally, two budgets.

    Both distances are ``numpy.linalg.norm(a - b, ord=2)``. For 1-D inputs
    that is the Euclidean distance.
    NOTE(review): callers appear to pass 1-D vectors; for 2-D inputs
    ``ord=2`` would yield the spectral norm instead -- confirm.

    Args:
        embed1: First embedding; must support ``embed1 - embed2``.
        embed2: Second embedding.
        budget1: Optional first budget vector.
        budget2: Optional second budget vector.

    Returns:
        A ``(embed_dis, budget_dis)`` tuple. ``budget_dis`` is ``None``
        unless both budgets are given.
    """
    embedding_gap = norm(embed1 - embed2, ord=2)
    # The budget distance is only defined when both budgets are supplied.
    if budget1 is None or budget2 is None:
        return embedding_gap, None
    return embedding_gap, norm(budget1 - budget2, ord=2)


def read_json(file):
    """Load and return the JSON document stored at ``file``.

    The file is opened explicitly as UTF-8 so that the result does not
    depend on the platform's default text encoding.

    Args:
        file: Path of the JSON file to read.

    Returns:
        The deserialized JSON content (whatever ``json.load`` produces).
    """
    with open(file, encoding="utf-8") as f:
        return json.load(f)


def embed_budget_similarity_score_scatter(
    proxy_res_list,
    test_res_list,
    proxy_budgets,
    test_budgets,
    proxy_orig_idxs,
    test_orig_idxs,
    log_dir,
):
    """Plot budget distance against embedding distance for sampled test queries.

    For every (proxy sample, test sample) pair, ``distance_score`` gives the
    distance between the two embeddings and between the two budget vectors.
    Up to 500 randomly chosen test indices then each get one scatter plot
    (one point per proxy sample), saved as a PNG under
    ``imgs/embed_budget_check/SST5/<log_dir suffix>``.

    Args:
        proxy_res_list: Iterable of dicts with at least ``"idx"`` and
            ``"embed"`` entries, one per proxy sample.
        test_res_list: Same structure, one dict per test sample.
        proxy_budgets: Mapping from proxy ``idx`` to its budget vector.
        test_budgets: Mapping from test ``idx`` to its budget vector.
        proxy_orig_idxs: Not used by this function.
        test_orig_idxs: Test indices from which the plotted subset is drawn.
        log_dir: Run directory; its ``"/"``-separated components from
            position 7 onward name the image sub-folder.
    """
    proxy_test_distance_list = []
    for proxy_sample in tqdm.tqdm(
        proxy_res_list, desc="Calculating proxy sample distance on each test query"
    ):
        p_orig_idx = proxy_sample["idx"]
        p_embed = proxy_sample["embed"]
        p_budget = proxy_budgets[p_orig_idx]
        for test_sample in test_res_list:
            t_orig_idx = test_sample["idx"]
            t_budget = test_budgets[t_orig_idx]
            embed_dis, budget_dis = distance_score(
                p_embed, test_sample["embed"], p_budget, t_budget
            )
            proxy_test_distance_list.append(
                {
                    "proxy_idx": p_orig_idx,
                    "test_idx": t_orig_idx,
                    "embed_dis": embed_dis,
                    "budget_dis": budget_dis,
                }
            )

    proxy_test_distance_df = pd.DataFrame(proxy_test_distance_list)
    print("Embedding-Budget distance DataFrame build done")

    # --- plot scatter
    # NOTE(review): the [7:] slice hard-codes the depth of the log root; a
    # log_dir with a different number of leading components maps to a
    # different (possibly colliding) image folder -- confirm the layout.
    img_folder = os.path.join(
        "imgs/embed_budget_check/SST5",
        "/".join(log_dir.split("/")[7:]),
    )
    os.makedirs(img_folder, exist_ok=True)

    # random.sample raises ValueError when asked for more items than the
    # population holds, so cap the request at the number of test samples.
    population = list(test_orig_idxs)
    selected_t_idxs = random.sample(population, min(500, len(population)))
    for t_idx in tqdm.tqdm(
        selected_t_idxs, desc="Ploting scatter for selected test samples"
    ):
        tmp = proxy_test_distance_df[proxy_test_distance_df["test_idx"] == t_idx]

        fig, ax = plt.subplots(figsize=(5, 5))
        sns.scatterplot(x="embed_dis", y="budget_dis", data=tmp, ax=ax)
        ax.set_title(f"budget dist vs embed dist scatter for test sample idx={t_idx}")
        img_path = os.path.join(img_folder, f"scatter-test_idx={t_idx}.png")
        fig.savefig(img_path)
        # Close this specific figure rather than whichever one is "current".
        plt.close(fig)

    plt.close("all")


def embed_sim_client_ovelap_vis(
    proxy_res_list,
    test_res_list,
    proxy_budgets,
    test_budgets,
    proxy_orig_idxs,
    test_orig_idxs,
    log_dir,
):
    """Plot embedding distance against top-client overlap for sampled test queries.

    Clients are ranked per sample by descending budget (``np.argsort`` of the
    budget vector, reversed). For every (proxy sample, test sample) pair the
    function records:

    * ``embed_dis``: embedding distance from ``distance_score``;
    * ``top1_client_overlap_score``: 1 if both samples share the same
      top-ranked client, else 0;
    * ``top2_client_overlap_score``: size of the intersection of the two
      top-2 client sets, divided by 2.

    Up to 500 randomly chosen test indices then each get a 2x2 grid of
    histograms of ``embed_dis`` coloured by overlap score (top row stacked,
    bottom row overlaid; left column top-1, right column top-2), saved as a
    PNG under ``imgs/embed_top_client_overlap_check/SST5/<log_dir suffix>``.

    Args:
        proxy_res_list: Iterable of dicts with at least ``"idx"`` and
            ``"embed"`` entries, one per proxy sample.
        test_res_list: Same structure, one dict per test sample.
        proxy_budgets: Mapping from proxy ``idx`` to its budget vector.
        test_budgets: Mapping from test ``idx`` to its budget vector.
        proxy_orig_idxs: Not used by this function.
        test_orig_idxs: Test indices from which the plotted subset is drawn.
        log_dir: Run directory; its ``"/"``-separated components from
            position 7 onward name the image sub-folder.
    """
    # The test-side ranking does not depend on the proxy sample, so compute it
    # once per test sample instead of once per (proxy, test) pair.
    test_entries = []
    for test_sample in test_res_list:
        t_orig_idx = test_sample["idx"]
        t_client_ranking = np.argsort(test_budgets[t_orig_idx])[::-1]
        test_entries.append(
            (
                t_orig_idx,
                test_sample["embed"],
                t_client_ranking[0],
                set(t_client_ranking[:2]),
            )
        )

    proxy_test_top_client_overlap_list = []
    for proxy_sample in tqdm.tqdm(
        proxy_res_list,
        desc="Calculating proxy sample distance and top client overlap on each test query",
    ):
        p_orig_idx = proxy_sample["idx"]
        p_embed = proxy_sample["embed"]
        p_client_ranking = np.argsort(proxy_budgets[p_orig_idx])[::-1]
        p_top1_client = p_client_ranking[0]
        p_top2_client = set(p_client_ranking[:2])
        for t_orig_idx, t_embed, t_top1_client, t_top2_client in test_entries:
            # check top-1 client overlap
            top1_client_overlap_score = int(t_top1_client == p_top1_client)
            # check top-2 client overlap
            top2_client_overlap_score = len(p_top2_client & t_top2_client) / 2

            embed_dis, _ = distance_score(p_embed, t_embed, None, None)
            proxy_test_top_client_overlap_list.append(
                {
                    "proxy_idx": p_orig_idx,
                    "test_idx": t_orig_idx,
                    "embed_dis": embed_dis,
                    "top1_client_overlap_score": top1_client_overlap_score,
                    "top2_client_overlap_score": top2_client_overlap_score,
                }
            )

    proxy_test_top_client_overlap_df = pd.DataFrame(proxy_test_top_client_overlap_list)
    print("Embedding-ClientOverlap DataFrame build done")

    # --- plot histograms
    # NOTE(review): the [7:] slice hard-codes the depth of the log root; a
    # log_dir with a different number of leading components maps to a
    # different (possibly colliding) image folder -- confirm the layout.
    img_folder = os.path.join(
        "imgs/embed_top_client_overlap_check/SST5",
        "/".join(log_dir.split("/")[7:]),
    )
    os.makedirs(img_folder, exist_ok=True)

    # random.sample raises ValueError when asked for more items than the
    # population holds, so cap the request at the number of test samples.
    population = list(test_orig_idxs)
    selected_t_idxs = random.sample(population, min(500, len(population)))

    # (score column, grid column, grid row, stacked?, panel title)
    panels = [
        ("top1_client_overlap_score", 0, 0, True, "top1_client_overlap_score (stack)"),
        ("top1_client_overlap_score", 0, 1, False, "top1_client_overlap_score"),
        ("top2_client_overlap_score", 1, 0, True, "top2_client_overlap_score (stack)"),
        ("top2_client_overlap_score", 1, 1, False, "top2_client_overlap_score"),
    ]
    for t_idx in tqdm.tqdm(
        selected_t_idxs, desc="Ploting histplot for selected test samples"
    ):
        tmp = proxy_test_top_client_overlap_df[
            proxy_test_top_client_overlap_df["test_idx"] == t_idx
        ]
        fig, axes = plt.subplots(2, 2, figsize=(18, 10))
        for score_col, grid_col, grid_row, stacked, title in panels:
            panel_ax = axes[grid_row][grid_col]
            # Only pass ``multiple`` for the stacked panels so the overlaid
            # panels keep seaborn's default behaviour.
            extra = {"multiple": "stack"} if stacked else {}
            sns.histplot(x="embed_dis", hue=score_col, data=tmp, ax=panel_ax, **extra)
            panel_ax.set_title(title)

        fig.suptitle(
            f"EmbedDist vs TopClientOverlap for test sample idx={t_idx}", fontsize=15
        )
        img_path = os.path.join(img_folder, f"histplot-test_idx={t_idx}.png")
        fig.savefig(img_path, bbox_inches="tight")
        # Close this specific figure rather than whichever one is "current".
        plt.close(fig)

    plt.close("all")


if __name__ == "__main__":
    # Script entry point: for every "seed=0" run directory found under
    # data_dir, load the saved forward results and the optimal client budgets,
    # print a few sanity checks, and render the embedding-distance vs.
    # top-client-overlap histograms.
    data_dir = "cache/root/directoryFL-ICL-checkpoint/fed_icl_log/embed_opt_budget_check/sst5_proxy=test-500"
    log_dirs = glob.glob(os.path.join(data_dir, "**", "seed=0"), recursive=True)

    for log_dir in log_dirs:
        print(f"Processing {log_dir}...")
        # NOTE(review): torch.load unpickles its input, so only point this at
        # result files produced by a trusted run. Depending on the installed
        # torch version the default ``weights_only`` setting may reject
        # non-tensor payloads -- confirm against the environment in use.
        proxy_res_file = os.path.join(log_dir, "proxy_forward_result.pt")
        proxy_res_list = torch.load(proxy_res_file)["res_list"]

        test_res_file = os.path.join(log_dir, "remain_test_forward_result.pt")
        test_res_list = torch.load(test_res_file)["res_list"]

        # === sanity check
        proxy_orig_idxs = [res["idx"] for res in proxy_res_list]
        test_orig_idxs = [res["idx"] for res in test_res_list]
        print(
            f"len(proxy)={len(proxy_orig_idxs)}, proxy unique_num={len(set(proxy_orig_idxs))}"
        )
        print(
            f"len(test)={len(test_orig_idxs)}, test unique_num={len(set(test_orig_idxs))}"
        )

        print(
            f"Size of unique samples in proxy + test = {len(set(proxy_orig_idxs + test_orig_idxs))}"
        )

        # ==== read the optimal budgets
        proxy_budgets_raw = read_json(
            os.path.join(log_dir, "proxy_opt_client_budget.json")
        )
        # Every proxy budget is expected to sum to the same server-side total.
        # Taking an arbitrary element of the set would make the normalizer
        # depend on set ordering if that ever fails, so pick the largest total
        # deterministically and report the inconsistency.
        budget_totals = {sum(b) for b in proxy_budgets_raw.values()}
        if not budget_totals:
            raise ValueError(f"No proxy budgets found in {log_dir}")
        if len(budget_totals) > 1:
            print(
                f"WARNING: proxy budgets do not share one total "
                f"({sorted(budget_totals)}); normalizing by the largest"
            )
        max_budget = max(budget_totals)
        print(f"Server-side Budget: {max_budget}")
        # JSON object keys are strings; convert them back to int sample ids
        # and scale each budget vector by the server-side total.
        proxy_budgets = {
            int(orig_id): np.array(raw_budget) / max_budget
            for orig_id, raw_budget in proxy_budgets_raw.items()
        }

        # -- sanity check budget idx
        print(
            f"Proxy idxs check correct: {set(proxy_budgets) == set(proxy_orig_idxs)}"
        )

        test_budgets_raw = read_json(
            os.path.join(log_dir, "test_opt_client_budget.json")
        )
        # Test budgets are normalized by the same total as the proxy budgets.
        test_budgets = {
            int(orig_id): np.array(raw_budget) / max_budget
            for orig_id, raw_budget in test_budgets_raw.items()
        }

        # -- sanity check test idx
        print(
            f"Test idxs check correct: {set(test_budgets) == set(test_orig_idxs)}"
        )

        # Alternative analysis (currently disabled): scatter of budget
        # distance vs. embedding distance.
        # embed_budget_similarity_score_scatter(
        #     proxy_res_list,
        #     test_res_list,
        #     proxy_budgets,
        #     test_budgets,
        #     proxy_orig_idxs,
        #     test_orig_idxs,
        #     log_dir,
        # )
        embed_sim_client_ovelap_vis(
            proxy_res_list,
            test_res_list,
            proxy_budgets,
            test_budgets,
            proxy_orig_idxs,
            test_orig_idxs,
            log_dir,
        )
