# %% Setup libs
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots
import seaborn as sns
import yaml

from clcp.data import CLF_DSS, build_dl, get_n_classes
from clcp.metrics import BinaryMetrics, MultiClassMetrics
from clcp.models.base import TestDL

# Apply the "science" and "no-latex" style sheets to all figures created below.
# NOTE(review): these style names are presumably registered by the
# `scienceplots` import above -- confirm before dropping that import.
plt.style.use(["science", "no-latex"])

# %% Setup env


# Identifier passed as `mdl_name` when the test dataloaders are built.
TOK_NAME = "answerdotai/ModernBERT-base"

# Per-dataset metadata keyed by dataset name; the table cell reads the
# "task" and "domain" fields of every entry.
with open("./scripts/paper/utils/dss.yml", encoding="utf-8") as dss_file:
    dss_cfg = yaml.safe_load(dss_file)
ds_info = dss_cfg["dataset_info"]


def build_test_dls(
    names: list[str],
    mdl_name: str,
    batch_size: int,
    *,
    paired_data: bool,
    is_test: bool,
    is_dummy: bool,
) -> list[TestDL]:
    """Create one ``TestDL`` per dataset name, each on the ``"test"`` split.

    Args:
        names: Dataset names to build loaders for; the output follows this order.
        mdl_name: Forwarded to ``build_dl`` as ``mdl_name``.
        batch_size: Forwarded to ``build_dl`` as ``batch_size``.
        paired_data: Forwarded to ``build_dl`` as ``paired``.
        is_test: Forwarded to ``build_dl`` unchanged.
        is_dummy: Forwarded to ``build_dl`` unchanged.

    Returns:
        A list of ``TestDL`` objects, each bundling a dataset name, its
        dataloader and the metrics object selected for that dataset.
    """

    def _metrics_for(ds_name: str):
        # Datasets listed in CLF_DSS get multi-class metrics; every other
        # dataset gets the binary metric set.
        if ds_name in CLF_DSS:
            return MultiClassMetrics(metrics=["f1", "acc"])
        return BinaryMetrics(metrics=["roc", "f1"])

    return [
        TestDL(
            name=ds_name,
            dl=build_dl(
                mdl_name=mdl_name,
                name=ds_name,
                split="test",
                batch_size=batch_size,
                paired=paired_data,
                is_test=is_test,
                is_dummy=is_dummy,
            ),
            metrics=_metrics_for(ds_name),
        )
        for ds_name in names
    ]


def get_test_dls(mdl_name: str, paired_data: bool) -> "list[TestDL]":
    """Build the test dataloaders for every dataset listed in ``ds_info``.

    Args:
        mdl_name: Forwarded to ``build_test_dls`` as ``mdl_name``.
        paired_data: Forwarded to ``build_test_dls`` as ``paired_data``.

    Returns:
        One ``TestDL`` per key of ``ds_info``, in the mapping's iteration
        order, built with a batch size of 64 and with both ``is_test`` and
        ``is_dummy`` set to ``False``.
    """
    return build_test_dls(
        # Materialise the keys: `build_test_dls` declares `names: list[str]`,
        # and a list snapshot is unaffected by later changes to `ds_info`.
        names=list(ds_info),
        mdl_name=mdl_name,
        batch_size=64,
        paired_data=paired_data,
        is_test=False,
        is_dummy=False,
    )


def avg_token_count(texts: pd.Series, tokenizer=None) -> int:
    """Return the mean number of tokens per text, rounded with ``round()``.

    Args:
        texts: Iterable of texts to tokenize (a ``pd.Series`` in this script).
        tokenizer: Callable invoked as ``tokenizer(list_of_texts,
            padding=False, truncation=False, add_special_tokens=False)`` whose
            result holds one token-id sequence per text under ``"input_ids"``.
            When omitted, the module-level ``tok`` is used; it must have been
            assigned before the call.

    Returns:
        The rounded average token count, or ``0`` when ``texts`` is empty.
    """
    text_list = list(texts)
    if not text_list:
        # Nothing to average: avoids a ZeroDivisionError and skips tokenizing
        # an empty batch.
        return 0
    if tokenizer is None:
        # Resolved at call time because `tok` is assigned later in the script.
        tokenizer = tok
    encodings = tokenizer(text_list, padding=False, truncation=False, add_special_tokens=False)
    token_counts = [len(ids) for ids in encodings["input_ids"]]
    return round(sum(token_counts) / len(token_counts))


# %% Compute table stats
dls = get_test_dls(TOK_NAME, paired_data=True)
# The tokenizer attached to the first dataset is used for every dataset.
# NOTE(review): presumably all datasets share one tokenizer because they are
# built with the same model name -- confirm.
tok = dls[0].dl.dataset.tok
# Per-dataset statistics, in the same order as `dls`.
lengths = [avg_token_count(te_dl.dl.dataset.data.to_pandas()["text"]) for te_dl in dls]
n_classes = [get_n_classes(te_dl.dl.dataset.data) for te_dl in dls]
# %% tbl_01_dss_stats

# Take the dataset names from the loaders themselves so the metadata columns
# are guaranteed to be row-aligned with `lengths` and `n_classes`.
dss = [te_dl.name for te_dl in dls]

df = pd.DataFrame({
    "Task": [ds_info[d]["task"] for d in dss],
    "Domain": [ds_info[d]["domain"] for d in dss],
    "Dataset": dss,
    "Num Classes": n_classes,
    "Avg Token Count": lengths,
})

df = df.sort_values(by=["Task", "Domain", "Dataset"]).reset_index(drop=True)

latex_code = df.to_latex(
    index=False,
    caption="Summary statistics of BTZSC datasets.",
    label="tab:tbl_01_dss_stats",
    column_format="lllrr",
    escape=True,
    longtable=False,
    multicolumn=False,
    multicolumn_format="c",
    bold_rows=False,
)
# Switch to the starred `table*` environment in the generated LaTeX.
latex_code = latex_code.replace(r"\begin{table}", r"\begin{table*}").replace(r"\end{table}", r"\end{table*}")
print(latex_code)


# %% fig_01_dss_jaq_sim


def get_normalized_token_freqs(tokenizer, texts):
    """Map each token to its share of all tokens produced from ``texts``.

    Every text is encoded with
    ``tokenizer.encode(text, add_special_tokens=False)``. Each returned value
    is that token's count divided by the total number of tokens over all
    texts. Input that yields no tokens gives an empty dict.
    """
    flat_ids = [
        token_id
        for text in texts
        for token_id in tokenizer.encode(text, add_special_tokens=False)
    ]
    n_tokens = len(flat_ids)
    return {token_id: n / n_tokens for token_id, n in Counter(flat_ids).items()}


def weighted_jaccard(freqs1, freqs2):
    """Return the weighted Jaccard similarity of two frequency mappings.

    Over the union of keys, the result is the sum of per-key minima divided by
    the sum of per-key maxima; a key missing from one mapping counts as 0
    there. Returns ``0`` when the sum of maxima is not positive.
    """
    vocab = set(freqs1).union(freqs2)
    paired = [(freqs1.get(key, 0), freqs2.get(key, 0)) for key in vocab]
    overlap = sum(min(left, right) for left, right in paired)
    envelope = sum(max(left, right) for left, right in paired)
    if envelope > 0:
        return overlap / envelope
    return 0


# Tokenizer attached to the first dataset; used for all datasets below.
tok = dls[0].dl.dataset.tok
# NOTE(review): the table-stats cell reads texts via
# `dataset.data.to_pandas()["text"]`, whereas this indexes the dataset object
# directly -- confirm both yield the same raw texts.
datasets_dict = {te_dl.name: te_dl.dl.dataset["text"] for te_dl in dls}

norm_freqs_dict = {
    ds_name: get_normalized_token_freqs(tok, ds_texts) for ds_name, ds_texts in datasets_dict.items()
}

# Similarity for every ordered pair of datasets (rows and columns share the
# same dataset order), rounded to two decimals.
ds_names = list(datasets_dict)
pairwise_sims = [
    [weighted_jaccard(norm_freqs_dict[row_ds], norm_freqs_dict[col_ds]) for col_ds in ds_names]
    for row_ds in ds_names
]
weighted_jaccard_matrix = pd.DataFrame(pairwise_sims, index=ds_names, columns=ds_names)

weighted_jaccard_matrix = weighted_jaccard_matrix.astype(float).round(2)

# %% plot

fig, ax = plt.subplots(figsize=(12, 10))
heatmap_style = {
    "annot": True,  # write the similarity value into each cell
    "fmt": ".2f",  # annotations use two decimal places
    "cmap": "BuGn",
    "square": True,
    "cbar_kws": {"label": "Weighted Jaccard Similarity"},
}
sns.heatmap(weighted_jaccard_matrix.astype(float), ax=ax, **heatmap_style)

# Set the face colour of both the axes and the figure patch to "none".
ax.set_facecolor("none")
fig.patch.set_facecolor("none")

# Rotated, right-aligned x tick labels; unrotated y tick labels.
plt.xticks(rotation=45, ha="right")
plt.yticks(rotation=0)
plt.tight_layout()

plt.savefig("./paper/figs/fig_01_dss_jaq_sim.pdf", format="pdf", bbox_inches="tight")
