import pandas as pd
from pathlib import Path

# Root directory the ``logs/`` tree is resolved against.
cwd = Path.cwd()

# Language codes excluded when computing the "seamlessm4t_supported" averages.
# NOTE(review): presumably languages SeamlessM4T does not cover -- confirm.
# fmt: off
seamless_unsupported = {
    "ast_Latn", "hau_Latn", "kam_Latn", "kea_Latn", "lin_Latn", "mri_Latn",
    "nso_Latn", "oci_Latn", "umb_Latn", "wol_Latn", "xho_Latn",
}
# Language codes excluded when computing the "whisper_supported" averages.
# NOTE(review): presumably languages Whisper does not cover -- confirm.
whisper_unsupported = {
    "ast_Latn", "ceb_Latn", "ckb_Arab", "fuv_Latn", "gle_Latn", "ibo_Latn",
    "kam_Latn", "kea_Latn", "kir_Cyrl", "lug_Latn", "luo_Latn", "nso_Latn",
    "umb_Latn", "wol_Latn", "xho_Latn", "zul_Latn",
}
# fmt: on


def parse_hyperparameters(hparams):
    """Parse a run directory name into its hyper-parameter values.

    The name is expected to look like::

        input=<i>-kind=<k>-seed=<s>-lr=<lr>-batch_size=<b>-accumulate_grad_batches=<a>

    The first three fields are split off from the left and the last three from
    the right, so the learning rate may itself contain ``-`` (e.g. ``1e-05``).
    Each field is split on its first ``=`` only, so values may contain ``=``.

    Args:
        hparams: The directory name.

    Returns:
        Dict with keys ``input``, ``kind``, ``seed``, ``lr``, ``batch_size`` and
        ``accumulate_grad_batches`` (in that order); all values are strings.

    Raises:
        ValueError: if the name does not contain six ``-``-separated fields.
        IndexError: if a field has no ``=``.
    """
    input_field, kind_field, seed_field, rest = hparams.split("-", maxsplit=3)
    # Split from the right so a "-" inside the lr value (1e-05) stays intact.
    lr_field, batch_field, accum_field = rest.rsplit("-", maxsplit=2)
    fields = {
        "input": input_field,
        "kind": kind_field,
        "seed": seed_field,
        "lr": lr_field,
        "batch_size": batch_field,
        "accumulate_grad_batches": accum_field,
    }
    return {name: field.split("=", maxsplit=1)[1] for name, field in fields.items()}


def get_results(task="sib") -> pd.DataFrame:
    """Collect evaluation metrics for every model / hyper-parameter / checkpoint run.

    Walks ``<cwd>/logs/eval/<task>/model=<name>/<hparams>/<ckpt>/`` and, for each
    checkpoint directory, reads the first row of the highest-numbered
    ``lightning_logs/version_<i>/metrics.csv``.

    Args:
        task: Name of the evaluation task sub-directory under ``logs/eval``.

    Returns:
        One row per checkpoint with the parsed hyper-parameters, ``model``,
        ``ckpt`` and every ``test_*`` column ending in ``/test/acc``, plus any
        ``*validation*eng_Latn*/test/acc`` column.

    Raises:
        FileNotFoundError: if a checkpoint directory has no ``metrics.csv``.
    """

    def latest_metrics_csv(ckpt_dir: Path) -> Path:
        """Return ``metrics.csv`` of the highest-numbered Lightning version."""
        lightning_logs = ckpt_dir / "lightning_logs"
        versions = [
            int(p.name.removeprefix("version_"))
            for p in lightning_logs.glob("version_*")
            if p.name.removeprefix("version_").isdigit()
            and (p / "metrics.csv").exists()
        ]
        if not versions:
            raise FileNotFoundError(
                f"no lightning_logs/version_*/metrics.csv found under {ckpt_dir}"
            )
        return lightning_logs / f"version_{max(versions)}" / "metrics.csv"

    records = []
    logs = cwd / "logs" / "eval" / task
    for model in logs.iterdir():
        if not model.name.startswith("model"):
            continue
        for hyperparams in model.iterdir():
            for ckpt in hyperparams.iterdir():
                params = parse_hyperparameters(hyperparams.name)
                params["model"] = model.name.split("=", maxsplit=1)[1]
                params["ckpt"] = ckpt.name.removeprefix("ckpt=")
                # Only the first row of the metrics CSV is used.
                first_row = pd.read_csv(latest_metrics_csv(ckpt)).iloc[0]
                for key, value in first_row.items():
                    if not key.endswith("/test/acc"):
                        continue
                    if key.startswith("test_") or (
                        "validation" in key and "eng_Latn" in key
                    ):
                        params[key] = value
                records.append(params)
    return pd.DataFrame.from_records(records)


def get_belebele_results(task="sib") -> pd.DataFrame:
    """Collect Belebele evaluation metrics plus the best English validation score.

    Walks ``<cwd>/logs/eval/<task>/model=<name>/<hparams>/<ckpt>/`` and, for each
    checkpoint directory, reads the first row of the highest-numbered
    ``lightning_logs/version_<i>/metrics.csv``, keeping ``test_*`` columns and
    ``*validation*eng_Latn*`` columns that end in ``/test/acc``.

    It then reads the matching training log
    ``<cwd>/logs/train/<task>/<model dir>/<hparams>/lightning_logs/version_0/metrics.csv``.
    It checks that the evaluated checkpoint id equals the position of the highest
    ``validation_belebele_eng_Latn/val/acc`` among the logged (non-null) rows.
    That highest score is stored as ``validation_eng``.

    Args:
        task: Name of the task sub-directory under ``logs/eval`` and ``logs/train``.

    Returns:
        One row per checkpoint (see above).

    Raises:
        FileNotFoundError: if an eval ``metrics.csv`` or the training ``metrics.csv``
            is missing.
        ValueError: if the evaluated checkpoint is not the best validation row.
    """
    val_col = "validation_belebele_eng_Latn/val/acc"

    def latest_metrics_csv(ckpt_dir: Path) -> Path:
        """Return ``metrics.csv`` of the highest-numbered Lightning version."""
        lightning_logs = ckpt_dir / "lightning_logs"
        versions = [
            int(p.name.removeprefix("version_"))
            for p in lightning_logs.glob("version_*")
            if p.name.removeprefix("version_").isdigit()
            and (p / "metrics.csv").exists()
        ]
        if not versions:
            raise FileNotFoundError(
                f"no lightning_logs/version_*/metrics.csv found under {ckpt_dir}"
            )
        return lightning_logs / f"version_{max(versions)}" / "metrics.csv"

    def best_validation(train_run_dir: Path, ckpt_id: int):
        """Return the best validation score and check it belongs to ``ckpt_id``."""
        train_csv_path = train_run_dir / "lightning_logs" / "version_0" / "metrics.csv"
        if not train_csv_path.exists():
            raise FileNotFoundError(f"missing training metrics: {train_csv_path}")
        scores = pd.read_csv(train_csv_path)[val_col].dropna()
        best = int(scores.argmax())  # position among the logged validation rows
        if best != ckpt_id:
            raise ValueError(
                f"{train_run_dir}: evaluated ckpt {ckpt_id} is not the best "
                f"validation row ({best})"
            )
        return scores.iloc[best]

    records = []
    logs = cwd / "logs" / "eval" / task
    for model in logs.iterdir():
        if not model.name.startswith("model"):
            continue
        for hyperparams in model.iterdir():
            # Mirror of the eval run directory under logs/train. Built from path
            # parts so an "eval" elsewhere in cwd is left untouched.
            train_path = cwd / "logs" / "train" / task / model.name / hyperparams.name
            for ckpt in hyperparams.iterdir():
                params = parse_hyperparameters(hyperparams.name)
                params["model"] = model.name.split("=", maxsplit=1)[1]
                params["ckpt"] = ckpt.name.removeprefix("ckpt=")
                # Only the first row of the metrics CSV is used.
                first_row = pd.read_csv(latest_metrics_csv(ckpt)).iloc[0]
                for key, value in first_row.items():
                    if not key.endswith("/test/acc"):
                        continue
                    if key.startswith("test_") or (
                        "validation" in key and "eng_Latn" in key
                    ):
                        params[key] = value
                params["validation_eng"] = best_validation(
                    train_path, int(params["ckpt"])
                )
                records.append(params)

    return pd.DataFrame.from_records(records)


def parse_belebele(df, metric="acc", task="belebele"):
    """Reshape Belebele results into one row per (run, input modality).

    ``df`` is a wide frame such as the one returned by ``get_belebele_results``.
    It has hyper-parameter columns, ``validation_eng``, and test columns named
    ``test_<task>_[<modality>_]<lang>/test/<metric>``.

    For each modality (``text``, ``whisper_asr``, ``seamlessm4t_asr``,
    ``whisper_translation``, ``seamlessm4t_translation``), the function:

    * selects the matching test columns;
    * drops rows that have any missing value;
    * appends ``_<modality>`` to ``model``;
    * renames the test columns to bare language codes.

    The stacked frame gains these summary columns:

    * ``avg``: mean over all test languages (``eng_Latn`` included).
    * ``non_eng_avg``: mean over all non-English test languages.
    * ``seamlessm4t_supported_avg`` / ``whisper_supported_avg``: mean over the
      non-English languages not in ``seamless_unsupported`` /
      ``whisper_unsupported``.
    * ``unsupported_avg``: mean over languages in both unsupported sets.

    Args:
        df: Wide results frame.
        metric: Metric name used in the column suffix ``/test/<metric>``.
        task: Task name used in the column prefix ``test_<task>_``.

    Returns:
        The stacked frame, sorted by ``validation_eng`` (descending).
    """
    task_prefix = f"test_{task}_"
    metric_suffix = f"/test/{metric}"
    general_columns = [
        c for c in df.columns if not c.endswith(f"/{metric}") and c != "validation_eng"
    ]
    task_columns = [c for c in df.columns if c.startswith(task_prefix)]

    views = []
    for key in (
        "text",
        "whisper_asr",
        "seamlessm4t_asr",
        "whisper_translation",
        "seamlessm4t_translation",
    ):
        if key == "text":
            selected = [
                c for c in task_columns if "translation" not in c and "asr" not in c
            ]
        else:
            selected = [c for c in task_columns if key in c]
        columns = general_columns + selected + ["validation_eng"]
        view = df.loc[:, columns]
        # Keep only runs that have a value for every selected column; copy so the
        # assignments below operate on an independent frame.
        view = view.loc[view.notnull().all(axis=1)].copy()
        view["model"] = view["model"] + "_" + key
        view.columns = [
            c.removeprefix(f"{task_prefix}{key}_")
            .removeprefix(task_prefix)
            .removesuffix(metric_suffix)
            for c in columns
        ]
        views.append(view)
    combined = pd.concat(views, axis=0, ignore_index=True)

    cols = [c for c in combined.columns if c not in general_columns]
    cols_excl_eng = [c for c in cols if c not in {"eng_Latn", "validation_eng"}]
    assert len(cols) == len(cols_excl_eng) + 2
    # Test-language columns only: validation_eng is a training-time score and must
    # not be averaged together with test accuracies.
    lang_cols = [c for c in cols if c != "validation_eng"]
    # Not every code in whisper_unsupported is necessarily a column here, so the
    # supported lists are not checked against the set sizes.
    whisper_supported_cols = [c for c in cols_excl_eng if c not in whisper_unsupported]
    seamlessm4t_supported_cols = [
        c for c in cols_excl_eng if c not in seamless_unsupported
    ]
    unsupported_cols = [
        c
        for c in cols_excl_eng
        if c in seamless_unsupported and c in whisper_unsupported
    ]
    combined["avg"] = combined.loc[:, lang_cols].mean(axis=1)
    combined["non_eng_avg"] = combined.loc[:, cols_excl_eng].mean(axis=1)
    combined["seamlessm4t_supported_avg"] = combined.loc[
        :, seamlessm4t_supported_cols
    ].mean(axis=1)
    combined["whisper_supported_avg"] = combined.loc[:, whisper_supported_cols].mean(
        axis=1
    )
    combined["unsupported_avg"] = combined.loc[:, unsupported_cols].mean(axis=1)
    combined = combined.sort_values(["validation_eng"], ascending=False)
    print(f"{len(whisper_supported_cols)=}")
    print(f"{len(seamlessm4t_supported_cols)=}")
    print(f"{len(unsupported_cols)=}")
    print(f"{len(cols_excl_eng)=}")
    return combined


# Belebele: for every (model, input, kind) keep the learning rate whose mean
# (over seeds) validation_eng is highest, and report its test averages.
x = get_belebele_results("belebele")
df = x.copy()
d = parse_belebele(df.copy())
del d["batch_size"]
del d["accumulate_grad_batches"]
del d["ckpt"]
d["seed"] = d["seed"].astype(int)
# Mean / standard deviation over seeds per configuration (x_std is not used further
# in this section).
x = d.groupby(["input", "kind", "lr", "model"]).mean()
x_std = d.groupby(["input", "kind", "lr", "model"]).std()
# Index label (input, kind, lr, model) of the best lr per (model, input, kind).
idx = x.groupby(["model", "input", "kind"])["validation_eng"].idxmax()
y = x.loc[idx].reset_index()
columns = [
    "model",
    "kind",
    "lr",
    "eng_Latn",
    "whisper_supported_avg",
    "seamlessm4t_supported_avg",
    "unsupported_avg",
    "non_eng_avg",
]
out = y.loc[:, columns].copy()
out["model"] = out.model.apply(
    lambda s: s.replace(
        "LLM2Vec-Meta-Llama-3.1-8B-Instruct-mntp-unsup-simcse", "LLM2Vec"
    )
).apply(
    lambda s: s.replace(
        "NLLB-LLM2Vec-Meta-Llama-31-8B-Instruct-mntp-unsup-simcse", "NLLB-LLM2Vec"
    )
)

# Identifier columns unchanged; scores as percentages with one decimal. Bound to a
# name so the table is not computed and then discarded.
belebele_table = pd.concat(
    [out.iloc[:, :3], (out.iloc[:, 3:] * 100).round(1)], axis=1
)


def _sib_view(
    df,
    general_columns,
    test_columns,
    validation_column,
    strip_prefix,
    metric_suffix,
    model_suffix="",
):
    """Extract one input modality from a wide SIB results frame.

    Args:
        df: Wide results frame (one row per evaluated checkpoint).
        general_columns: Hyper-parameter / identifier columns, kept unchanged.
        test_columns: Test-score columns belonging to this modality.
        validation_column: English validation column of this modality; it is
            renamed to ``validation_eng``.
        strip_prefix: Prefix removed from column names (e.g. ``test_sib_whisper_``).
        metric_suffix: Suffix removed from column names (e.g. ``/test/acc``).
        model_suffix: Appended to every ``model`` value ("" leaves it unchanged).

    Returns:
        A new frame with ``general_columns``, bare language codes and
        ``validation_eng``, restricted to rows that have a ``deu_Latn`` value.
    """
    columns = general_columns + test_columns + [validation_column]
    view = df.loc[:, columns].copy()
    view.columns = [
        "validation_eng"
        if c == validation_column
        else c.removeprefix(strip_prefix).removesuffix(metric_suffix)
        for c in columns
    ]
    # deu_Latn serves as a probe language (it exists across tasks): a missing value
    # means the row was not evaluated on this modality.
    view = view.loc[view["deu_Latn"].notnull()].copy()
    if model_suffix:
        view["model"] = view["model"] + model_suffix
    return view


def parse_sib(df, task, metric="acc"):
    """Reshape SIB results into one row per (run, input modality) with averages.

    ``df`` is a wide frame such as the one returned by ``get_results``. Five
    modalities are extracted: text, Whisper ASR, SeamlessM4T ASR, Whisper
    translation and SeamlessM4T translation. Each modality's columns are
    renamed to bare language codes plus ``validation_eng``, and the modalities
    are stacked. Translation rows get ``_whisper_translation`` /
    ``_seamlessm4t_translation`` appended to ``model``.

    Summary columns added:

    * ``avg``: mean over all test languages (``eng_Latn`` included).
    * ``non_eng_avg``: mean over all non-English test languages.
    * ``seamlessm4t_supported_avg`` / ``whisper_supported_avg``: mean over the
      non-English languages not in ``seamless_unsupported`` /
      ``whisper_unsupported``.
    * ``unsupported_avg``: mean over languages in both unsupported sets.

    Args:
        df: Wide results frame.
        task: Task name; ``"sib-fleurs"`` is mapped to the column prefix ``sib``.
        metric: Metric name used in the column suffix ``/test/<metric>``.

    Returns:
        The stacked frame, sorted by ``validation_eng`` (descending).
    """
    if task == "sib-fleurs":
        task = "sib"
    prefix = f"test_{task}_"
    metric_suffix = f"/test/{metric}"
    general_columns = [c for c in df.columns if not c.endswith(f"/{metric}")]
    task_columns = [c for c in df.columns if c.startswith(prefix)]

    # (test columns, modality tag inside the column names, suffix for ``model``)
    # NOTE(review): text / whisper / seamless ASR rows keep the plain model name,
    # unlike translation rows; confirm that the "input" hyper-parameter keeps them
    # apart when grouping downstream.
    modalities = [
        (
            [
                c
                for c in task_columns
                if "translation" not in c and "whisper" not in c and "seamless" not in c
            ],
            "",
            "",
        ),
        (
            [c for c in task_columns if "whisper" in c and "translation" not in c],
            "whisper_",
            "",
        ),
        (
            [c for c in task_columns if "seamless" in c and "translation" not in c],
            "seamlessm4t_",
            "",
        ),
        (
            [c for c in task_columns if "whisper_translation" in c],
            "whisper_translation_",
            "_whisper_translation",
        ),
        (
            [c for c in task_columns if "seamlessm4t_translation" in c],
            "seamlessm4t_translation_",
            "_seamlessm4t_translation",
        ),
    ]
    views = [
        _sib_view(
            df,
            general_columns,
            test_columns,
            validation_column=f"validation_{task}_{tag}eng_Latn{metric_suffix}",
            strip_prefix=f"{prefix}{tag}",
            metric_suffix=metric_suffix,
            model_suffix=model_suffix,
        )
        for test_columns, tag, model_suffix in modalities
    ]
    combined = pd.concat(views, axis=0, ignore_index=True)

    cols = [c for c in combined.columns if c not in general_columns]
    cols_excl_eng = [c for c in cols if c not in {"eng_Latn", "validation_eng"}]
    assert len(cols) == len(cols_excl_eng) + 2
    # Test-language columns only: validation_eng must not enter the test average.
    lang_cols = [c for c in cols if c != "validation_eng"]
    whisper_supported_cols = [c for c in cols_excl_eng if c not in whisper_unsupported]
    assert len(whisper_supported_cols) == len(cols_excl_eng) - len(whisper_unsupported)
    seamlessm4t_supported_cols = [
        c for c in cols_excl_eng if c not in seamless_unsupported
    ]
    unsupported_cols = [
        c
        for c in cols_excl_eng
        if c in seamless_unsupported and c in whisper_unsupported
    ]

    combined["avg"] = combined.loc[:, lang_cols].mean(axis=1)
    combined["non_eng_avg"] = combined.loc[:, cols_excl_eng].mean(axis=1)
    combined["seamlessm4t_supported_avg"] = combined.loc[
        :, seamlessm4t_supported_cols
    ].mean(axis=1)
    combined["whisper_supported_avg"] = combined.loc[:, whisper_supported_cols].mean(
        axis=1
    )
    combined["unsupported_avg"] = combined.loc[:, unsupported_cols].mean(axis=1)
    combined = combined.sort_values(["validation_eng"], ascending=False)
    print(f"{len(whisper_supported_cols)=}")
    print(f"{len(seamlessm4t_supported_cols)=}")
    print(f"{len(unsupported_cols)=}")
    print(f"{len(cols_excl_eng)=}")
    return combined


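# Result groups to report:
#   - Whisper-supported languages
#   - SeamlessM4T-supported languages
#   - unsupported languages
#   - zs-xlt (zero-shot cross-lingual transfer)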


# SIB-FLEURS: collect runs, stack the input modalities, then keep, for every
# (model, input, kind), the learning rate with the best mean validation_eng.
sib_fleurs = get_results("sib-fleurs")
df = sib_fleurs.copy()
sib_fleurs_ = parse_sib(df, "sib-fleurs")

d = sib_fleurs_.drop(columns=["batch_size", "accumulate_grad_batches", "ckpt"])
# Shorten the long checkpoint names for the tables.
d["model"] = d["model"].apply(
    lambda name: name.replace(
        "LLM2Vec-Meta-Llama-3.1-8B-Instruct-mntp-unsup-simcse", "LLM2Vec"
    ).replace(
        "NLLB-LLM2Vec-Meta-Llama-31-8B-Instruct-mntp-unsup-simcse", "NLLB-LLM2Vec"
    )
)
d["seed"] = d["seed"].astype(int)

# Mean and spread over seeds for every configuration.
x = d.groupby(["input", "kind", "lr", "model"]).mean()
x_std = d.groupby(["input", "kind", "lr", "model"]).std()
# Best lr per (model, input, kind), judged on mean validation_eng.
idx = x.groupby(["model", "input", "kind"])["validation_eng"].idxmax()
y = x.loc[idx].reset_index()

columns = [
    "model",
    "kind",
    "input",
    "lr",
    "eng_Latn",
    "whisper_supported_avg",
    "seamlessm4t_supported_avg",
    "unsupported_avg",
    "non_eng_avg",
]
out = y.loc[:, columns].copy()

# The first four columns are identifiers; the rest become percentages (1 decimal).
out_ = pd.concat([out.iloc[:, :4], out.iloc[:, 4:].mul(100).round(1)], axis=1)

# Metrics shown side by side in the LaTeX table, one column per training "kind".
metrics = [
    "eng_Latn",
    "whisper_supported_avg",
    "seamlessm4t_supported_avg",
    "unsupported_avg",
    "non_eng_avg",
]

# Index ``out_`` by (model, input, kind). This is unique because ``y`` keeps one
# lr per (model, input, kind). Then move "kind" into the columns. The result has
# one row per (model, input) and MultiIndex columns such as ("eng_Latn", "best")
# and ("eng_Latn", "worst"). "lr" is not part of this table.
out_pivot = out_.set_index(["model", "input", "kind"])[metrics].unstack("kind")

# Rows without a "worst" run (e.g. rows whose input is "text") reuse the "best"
# value. If no row has a "worst" run at all, build the column from "best" rather
# than failing with a KeyError.
for metric in metrics:
    best = out_pivot[(metric, "best")]
    if (metric, "worst") in out_pivot.columns:
        out_pivot[(metric, "worst")] = out_pivot[(metric, "worst")].fillna(best)
    else:
        out_pivot[(metric, "worst")] = best

# Flatten to "<metric>_<kind>" column names (e.g. eng_Latn_best, eng_Latn_worst)
# and stringify every cell for the LaTeX export.
out_pivot.columns = [f"{metric}_{kind}" for metric, kind in out_pivot.columns]
out_pivot = out_pivot.reset_index().astype(str)

latex_table = out_pivot.to_latex(index=False)


# Per-run export: the first 8 columns of ``sib_fleurs_`` plus English and
# non-English accuracy.
# NOTE(review): assumes the first 8 columns are the identifier columns
# (input, kind, seed, lr, batch_size, accumulate_grad_batches, model, ckpt).
cols = list(sib_fleurs_.columns[:8]) + ["eng_Latn", "non_eng_avg"]
sib_fleurs_.loc[:, cols].to_csv("new.csv", index=False)

# Mean per (model, lr) over all remaining rows, best non-English average first.
# index=False: the reset integer index carries no information.
sib_fleurs_.loc[:, ["model", "lr", "eng_Latn", "non_eng_avg"]].groupby(
    ["model", "lr"]
).mean().reset_index().sort_values("non_eng_avg", ascending=False).to_csv(
    "agg.csv", index=False
)

# Language codes present in ``y`` but not in ``x`` (both defined below). Note the
# commas: adjacent string literals would otherwise be silently concatenated.
missing = ["bul_Cyrl", "oci_Latn", "ckb_Arab", "snd_Arab"]


# Set of language codes. It contains every code of ``y`` (below) except
# bul_Cyrl, ckb_Arab, oci_Latn and snd_Arab.
# NOTE(review): presumably the language coverage of Belebele-FLEURS -- confirm.
x = {
    "afr_Latn",
    "amh_Ethi",
    "arb_Arab",
    "asm_Beng",
    "ast_Latn",
    "azj_Latn",
    "bel_Cyrl",
    "ben_Beng",
    "bos_Latn",
    "cat_Latn",
    "ceb_Latn",
    "ces_Latn",
    "cym_Latn",
    "dan_Latn",
    "deu_Latn",
    "ell_Grek",
    "eng_Latn",
    "est_Latn",
    "fin_Latn",
    "fra_Latn",
    "fuv_Latn",
    "gaz_Latn",
    "gle_Latn",
    "glg_Latn",
    "guj_Gujr",
    "hau_Latn",
    "heb_Hebr",
    "hin_Deva",
    "hrv_Latn",
    "hun_Latn",
    "hye_Armn",
    "ibo_Latn",
    "ind_Latn",
    "isl_Latn",
    "ita_Latn",
    "jav_Latn",
    "jpn_Jpan",
    "kam_Latn",
    "kan_Knda",
    "kat_Geor",
    "kaz_Cyrl",
    "kea_Latn",
    "khk_Cyrl",
    "khm_Khmr",
    "kir_Cyrl",
    "kor_Hang",
    "lao_Laoo",
    "lin_Latn",
    "lit_Latn",
    "ltz_Latn",
    "lug_Latn",
    "luo_Latn",
    "lvs_Latn",
    "mal_Mlym",
    "mar_Deva",
    "mkd_Cyrl",
    "mlt_Latn",
    "mri_Latn",
    "mya_Mymr",
    "nld_Latn",
    "nob_Latn",
    "npi_Deva",
    "nso_Latn",
    "nya_Latn",
    "ory_Orya",
    "pan_Guru",
    "pbt_Arab",
    "pes_Arab",
    "pol_Latn",
    "por_Latn",
    "ron_Latn",
    "rus_Cyrl",
    "slk_Latn",
    "slv_Latn",
    "sna_Latn",
    "som_Latn",
    "spa_Latn",
    "srp_Cyrl",
    "swe_Latn",
    "swh_Latn",
    "tam_Taml",
    "tel_Telu",
    "tgk_Cyrl",
    "tgl_Latn",
    "tha_Thai",
    "tur_Latn",
    "ukr_Cyrl",
    "umb_Latn",
    "urd_Arab",
    "uzn_Latn",
    "vie_Latn",
    "wol_Latn",
    "xho_Latn",
    "yor_Latn",
    "zho_Hans",
    "zho_Hant",
    "zsm_Latn",
    "zul_Latn",
}


# Set of language codes. It is a superset of ``x`` (above) with four extra codes:
# bul_Cyrl, ckb_Arab, oci_Latn and snd_Arab.
# NOTE(review): presumably the full FLEURS / SIB-FLEURS language list -- confirm.
y = {
    "afr_Latn",
    "amh_Ethi",
    "arb_Arab",
    "asm_Beng",
    "ast_Latn",
    "azj_Latn",
    "bel_Cyrl",
    "bul_Cyrl",
    "ben_Beng",
    "bos_Latn",
    "cat_Latn",
    "ceb_Latn",
    "ckb_Arab",
    "zho_Hans",
    "ces_Latn",
    "cym_Latn",
    "dan_Latn",
    "deu_Latn",
    "ell_Grek",
    "eng_Latn",
    "spa_Latn",
    "est_Latn",
    "pes_Arab",
    "fin_Latn",
    "tgl_Latn",
    "fra_Latn",
    "gle_Latn",
    "glg_Latn",
    "guj_Gujr",
    "hau_Latn",
    "heb_Hebr",
    "hin_Deva",
    "hrv_Latn",
    "hun_Latn",
    "hye_Armn",
    "ind_Latn",
    "ibo_Latn",
    "isl_Latn",
    "ita_Latn",
    "jpn_Jpan",
    "jav_Latn",
    "kat_Geor",
    "kam_Latn",
    "kea_Latn",
    "kaz_Cyrl",
    "khm_Khmr",
    "kan_Knda",
    "kor_Hang",
    "kir_Cyrl",
    "ltz_Latn",
    "lug_Latn",
    "lin_Latn",
    "lao_Laoo",
    "lit_Latn",
    "luo_Latn",
    "lvs_Latn",
    "mri_Latn",
    "mkd_Cyrl",
    "mal_Mlym",
    "khk_Cyrl",
    "mar_Deva",
    "zsm_Latn",
    "mlt_Latn",
    "mya_Mymr",
    "nob_Latn",
    "npi_Deva",
    "nld_Latn",
    "nso_Latn",
    "nya_Latn",
    "oci_Latn",
    "ory_Orya",
    "pan_Guru",
    "pol_Latn",
    "pbt_Arab",
    "por_Latn",
    "ron_Latn",
    "rus_Cyrl",
    "snd_Arab",
    "slk_Latn",
    "slv_Latn",
    "sna_Latn",
    "som_Latn",
    "srp_Cyrl",
    "swe_Latn",
    "swh_Latn",
    "tam_Taml",
    "tel_Telu",
    "tgk_Cyrl",
    "tha_Thai",
    "tur_Latn",
    "ukr_Cyrl",
    "umb_Latn",
    "urd_Arab",
    "uzn_Latn",
    "vie_Latn",
    "wol_Latn",
    "xho_Latn",
    "yor_Latn",
    "zho_Hant",
    "zul_Latn",
    "fuv_Latn",
    "gaz_Latn",
}
