import pandas as pd

from pathlib import Path

# Directory the script is launched from; the log folder is resolved against it.
cwd = Path.cwd()

# Folder holding the training runs to process. Replace "belebele" with
# "sib-fleurs" to process the SIB-FLEURS runs instead.
logs = cwd.joinpath("logs", "train", "belebele")


def parse_hyperparameters(hparams):
    """Split a run-folder name into its hyperparameter values.

    The name is expected to consist of six ``key=value`` fields joined by
    ``-``, in the order input, kind, seed, lr, batch size and gradient
    accumulation. The first three fields are peeled off from the left and
    the last two from the right, so the ``lr`` value itself may contain
    ``-`` (e.g. ``1e-05``).

    Args:
        hparams: Folder name such as
            ``"input=text-kind=best-seed=42-lr=1e-05-batch_size=32-accumulate_grad_batches=1"``.

    Returns:
        A dict with the string values under the keys ``input``, ``kind``,
        ``seed``, ``lr``, ``batch_size`` and ``accumulate_grad_batches``.

    Raises:
        ValueError: If fewer than three ``-`` separators are present.
        IndexError: If a field lacks ``=`` or trailing fields are missing.
    """
    # Peel the three leading fields off one at a time.
    input_field, remainder = hparams.split("-", maxsplit=1)
    kind_field, remainder = remainder.split("-", maxsplit=1)
    seed_field, remainder = remainder.split("-", maxsplit=1)

    parsed = {
        "input": input_field.split("=", maxsplit=1)[1],
        "kind": kind_field.split("=", maxsplit=1)[1],
        "seed": seed_field.split("=")[1],
    }

    # Split from the right so that a "-" inside the lr value stays intact.
    trailing = remainder.rsplit("-", maxsplit=2)
    trailing_keys = ("lr", "batch_size", "accumulate_grad_batches")
    for position, key in enumerate(trailing_keys):
        parsed[key] = trailing[position].split("=")[1]
    return parsed


# "./logs/train/sib-fleurs/"

def _metric_columns(dataset, input_kind):
    """Return the ``(validation, test)`` accuracy column names of ``metrics.csv``.

    Args:
        dataset: Name of the log folder being processed; matched by substring
            against ``"sib"`` and ``"belebele"``.
        input_kind: The ``input`` hyperparameter of the run. Only consulted
            for SIB folders, where ``"whisper"`` / ``"seamless"`` select the
            matching column names.

    Raises:
        ValueError: If ``dataset`` names neither a SIB nor a Belebele folder.
    """
    if "sib" in dataset:
        if "whisper" in input_kind:
            split = "sib_whisper_eng_Latn"
        elif "seamless" in input_kind:
            split = "sib_seamlessm4t_eng_Latn"
        else:
            split = "sib_eng_Latn"
    elif "belebele" in dataset:
        split = "belebele_eng_Latn"
    else:
        raise ValueError("Wrong folder")
    return f"validation_{split}/val/acc", f"test_{split}/val/acc"


# Collect one record per "best" run: the parsed hyperparameters, the model
# name, the position of the row with the highest validation accuracy ("ckpt")
# and the test accuracy logged on that same row ("test").
records = []
for model in logs.iterdir():
    # Stray files next to the model folders are ignored rather than iterated.
    if not model.is_dir() or not model.name.startswith("model"):
        continue
    # Only the roberta runs are collected. To collect the NLLB-LLM2Vec runs
    # instead, test `"NLLB" in model.stem and "LLM2Vec" in model.stem`.
    if "roberta" not in model.stem:
        continue
    for hyperparams in model.iterdir():
        # Skip anything that is not a run folder (e.g. editor or OS files),
        # whose name would not parse as hyperparameters.
        if not hyperparams.is_dir():
            continue
        params = parse_hyperparameters(hyperparams.name)
        params["model"] = model.name.split("=")[1]
        if params["kind"] != "best":
            continue
        metrics = hyperparams / "lightning_logs" / "version_0" / "metrics.csv"
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if not metrics.exists():
            raise FileNotFoundError(str(metrics))
        csv = pd.read_csv(metrics)
        val_col, test_col = _metric_columns(logs.stem, params["input"])
        # Keep only the rows on which the validation accuracy was logged, so
        # the positional argmax below indexes into those rows alone.
        csv = csv.loc[csv[val_col].notnull()]
        ckpt = csv[val_col].argmax(0)
        params["ckpt"] = ckpt
        params["test"] = csv[test_col].iloc[ckpt]
        records.append(params)
df = pd.DataFrame.from_records(records)

# Long model names that recur across the lookup tables below.
_S4T = "seamless-m4t-v2-large-speech-encoder"
_S4T_EXCL = "seamless-m4t-v2-large-speech-encoder-excl-adapter"
_LLM2VEC = "LLM2Vec-Meta-Llama-3.1-8B-Instruct-mntp-unsup-simcse"
_NLLB_LLM2VEC = "NLLB-LLM2Vec-Meta-Llama-31-8B-Instruct-mntp-unsup-simcse"

# The two `--gres` requests used by the generated sbatch commands.
_A100L = "gpu:a100l:1"
_L40S = "gpu:l40s:1"

# Model name -> value passed as `+arch=`.
model2arch = {
    _S4T: "seamlessm4tv2",
    _S4T_EXCL: "seamlessm4tv2_excl_adapter",
    "roberta-large": "text",
    "mHuBERT-147": "wav2vec2",
    "whisper-large-v3-turbo": "whisper",
    "whisper-large-v3": "whisper",
    _LLM2VEC: "text",
    _NLLB_LLM2VEC: "text",
    "mms-1b": "mms",
    "mms-1b-fl102": "mms",
    "mms-1b-all": "mms",
}

# Model name -> value passed as `--gres=`.
# Alternative for both LLM2Vec variants: "gpu:a100l:1".
model2gres = {
    _S4T: _A100L,
    _S4T_EXCL: _A100L,
    "roberta-large": _L40S,
    "mHuBERT-147": _L40S,
    "whisper-large-v3": _A100L,
    "whisper-large-v3-turbo": _L40S,
    _LLM2VEC: _L40S,
    _NLLB_LLM2VEC: _L40S,
    "mms-1b": _A100L,
    "mms-1b-fl102": _A100L,
    "mms-1b-all": _A100L,
}

# Model name -> value passed as `run.val_test_batch_size=` (kept as strings).
model2batch_size = {
    _S4T: "2",
    _S4T_EXCL: "2",
    "roberta-large": "32",
    "mHuBERT-147": "8",
    "whisper-large-v3-turbo": "16",
    "whisper-large-v3": "16",
    _LLM2VEC: "16",
    _NLLB_LLM2VEC: "16",
    "mms-1b": "8",
    "mms-1b-all": "8",
    "mms-1b-fl102": "8",
}

# `input` hyperparameter -> value passed as `datasets=` in the SIB commands.
# Alternative ASR datasets: "sib_val_test_whisper_translation" and
# "sib_val_test_seamlessm4t_translation".
kind2dataset = {
    "whisper_asr": "sib_val_test_whisper",
    "seamlessm4t_asr": "sib_val_test_seamlessm4t",
    "speech": "sib_val_test",
    "text": "sib_val_test_translate",
}

# Model name -> value passed as `datasets=` in the Belebele commands.
model2dataset = {
    _LLM2VEC: "belebele_translation_test",
    "roberta-large": "belebele_translation_test",
    _NLLB_LLM2VEC: "belebele_test",
}

# Model name -> value passed as `model.pretrained_model_name_or_path=`.
# Membership in this table also decides whether the SIB commands carry that
# override at all. The first three names map to themselves; the MMS
# checkpoints get the "facebook/" prefix.
model2model = {name: name for name in (_LLM2VEC, _NLLB_LLM2VEC, "roberta-large")}
model2model.update(
    (name, f"facebook/{name}") for name in ("mms-1b", "mms-1b-all", "mms-1b-fl102")
)

# sbatch command that evaluates one SIB run. `{model_override}` is either
# empty or a complete " model.pretrained_model_name_or_path=..." segment, so
# a single template serves models with and without an explicit checkpoint.
_SIB_TEMPLATE = (
    "sbatch --partition=long -c 4 -t 02:00:00 --gres={gres} --wrap "
    "\"bash -c 'source ~/miniforge3/etc/profile.d/conda.sh"
    " && conda activate trident-fleurs"
    " && python -m trident.run experiment=sib"
    " run.seed={seed} run.kind={kind} run.text_column={text_column}"
    " module.optimizer.lr={lr}"
    " trainer.accumulate_grad_batches={accumulate_grad_batches}"
    " run.train_batch_size={train_batch_size}"
    " run.val_test_batch_size={val_test_batch_size}"
    " datasets={datasets} +arch={arch}{model_override}"
    " +stage=eval run.ckpt={ckpt}'\""
)

# Build one SIB evaluation command per collected record.
# NOTE(review): `scripts` is rebound to a new list further down before
# anything is written to disk, so these commands are currently not saved.
scripts = []
for record in records:
    model_key = record["model"]
    # Only models listed in model2model get the checkpoint override.
    model_override = ""
    if model_key in model2model:
        model_override = (
            f" model.pretrained_model_name_or_path={model2model[model_key]}"
        )
    scripts.append(
        _SIB_TEMPLATE.format(
            gres=model2gres[model_key],
            seed=record["seed"],
            kind=record["kind"],
            text_column=record["input"],
            lr=record["lr"],
            accumulate_grad_batches=record["accumulate_grad_batches"],
            train_batch_size=record["batch_size"],
            val_test_batch_size=model2batch_size[model_key],
            arch=model2arch[model_key],
            ckpt=record["ckpt"],
            datasets=kind2dataset[record["input"]],
            model_override=model_override,
        )
    )
# #
# with open("nllb-sib.sh", "w") as f:
#     for line in scripts:
#         f.write(line + "\n")
# with open("llm-sib-sib.sh", "w") as f:
#     for line in scripts:
#         f.write(line + "\n")
# with open("s4t-excl-sib.sh", "w") as f:
#     for line in scripts:
#         f.write(line + "\n")

# sbatch command that evaluates one Belebele run with the LoRA setup.
_BELEBELE_TEMPLATE = (
    "sbatch --partition=long -c 4 --mem=64GB -t 02:30:00 --gres={gres} --wrap "
    "\"bash -c 'source ~/miniforge3/etc/profile.d/conda.sh"
    " && conda activate trident-fleurs"
    " && python -m trident.run experiment=belebele"
    " run.seed={seed} run.kind={kind} run.text_column={text_column}"
    " module.optimizer.lr={lr}"
    " trainer.accumulate_grad_batches={accumulate_grad_batches}"
    " run.train_batch_size={train_batch_size}"
    " run.val_test_batch_size={val_test_batch_size}"
    " datasets={datasets} +arch={arch} +peft=lora +stage=eval run.ckpt={ckpt}"
    " model.pretrained_model_name_or_path={model}"
    " trainer.deterministic=false'\""
)


def _belebele_command(record):
    """Fill the Belebele sbatch template from one collected run record."""
    model_key = record["model"]
    return _BELEBELE_TEMPLATE.format(
        gres=model2gres[model_key],
        seed=record["seed"],
        kind=record["kind"],
        text_column=record["input"],
        lr=record["lr"],
        accumulate_grad_batches=record["accumulate_grad_batches"],
        train_batch_size=record["batch_size"],
        val_test_batch_size=model2batch_size[model_key],
        arch=model2arch[model_key],
        ckpt=record["ckpt"],
        datasets=model2dataset[model_key],
        model=model2model[model_key],
    )


# Only roberta runs are turned into Belebele evaluation commands.
scripts = [
    _belebele_command(record) for record in records if "roberta" in record["model"]
]

# One command per line, written to the working directory.
with open("rob-belebele.sh", "w") as handle:
    handle.writelines(script + "\n" for script in scripts)
