#!/usr/bin/env python3


import mlflow
from omegaconf import OmegaConf

import hydra
from clcp import HF_TOKEN, Model
from clcp.data import CLF_BAD_DSS, CLF_DSS, build_dl
from clcp.metrics import BinaryMetrics, MultiClassMetrics
from clcp.models.base import TestDL, save_mdl
from ml_utils import init_mlflow, log, log_hydra
from ml_utils.tracking import log_params

# Custom OmegaConf resolvers usable from the Hydra configs as ${name:arg}.
# SECURITY NOTE: "eval" lets any config interpolation execute arbitrary Python;
# only compose configs from trusted sources.
OmegaConf.register_new_resolver("eval", eval)  # noqa: S307
# "org/model.v2" -> "model_v2"; ids without a "/" (un-namespaced models) are used
# whole instead of raising IndexError.
OmegaConf.register_new_resolver(
    "split", lambda name: (name.split("/")[1] if "/" in name else name).replace(".", "_")
)
# "a.b" -> "a_b"
OmegaConf.register_new_resolver("format", lambda name: name.replace(".", "_"))


def build_test_dls(
    names: list[str],
    mdl_name: str,
    batch_size: int,
    *,
    paired_data: bool,
    is_test: bool,
    is_dummy: bool,
    bin_metrics: list[str],
    multi_metrics: list[str],
) -> list[TestDL]:
    """Build one test dataloader, paired with its metrics, per requested dataset.

    Args:
        names: Dataset names to evaluate on. The special entry ``"clf"`` expands to
            every dataset in ``CLF_DSS`` that is not in ``CLF_BAD_DSS``. Duplicate
            names (e.g. a dataset listed explicitly *and* pulled in via ``"clf"``) are
            kept only once, at their first position.
        mdl_name: Forwarded to ``build_dl`` as ``mdl_name``.
        batch_size: Batch size used for every test dataloader.
        paired_data: Forwarded to ``build_dl`` as ``paired``.
        is_test: Forwarded to ``build_dl``.
        is_dummy: Forwarded to ``build_dl``.
        bin_metrics: Metric names for ``BinaryMetrics`` (datasets not in ``CLF_DSS``).
        multi_metrics: Metric names for ``MultiClassMetrics`` (datasets in ``CLF_DSS``).

    Returns:
        One ``TestDL`` per unique dataset name, in request order.
    """
    # 'clf' is short-hand for all clf dss
    if "clf" in names:
        clf_dss = [ds for ds in CLF_DSS if ds not in CLF_BAD_DSS]
        names = [name for name in names if name != "clf"] + clf_dss
    # De-duplicate while keeping first-seen order so overlapping requests do not
    # build and evaluate the same test set twice.
    names = list(dict.fromkeys(names))

    te_dls = []
    for name in names:
        dl = build_dl(
            mdl_name=mdl_name,
            name=name,
            split="test",
            batch_size=batch_size,
            paired=paired_data,
            is_test=is_test,
            is_dummy=is_dummy,
        )
        # clf datasets are scored with multi-class metrics, everything else as binary.
        metrics = MultiClassMetrics(metrics=multi_metrics) if name in CLF_DSS else BinaryMetrics(metrics=bin_metrics)
        te_dls.append(TestDL(name=name, dl=dl, metrics=metrics))
    return te_dls


@hydra.main(version_base=None, config_path="../hydra/configs", config_name="config")
def main(cfg) -> None:
    """Build the model, optionally pre-train and save it, and log the run to MLflow.

    Args:
        cfg: Config composed by Hydra. NOTE(review): it is immediately replaced by the
            YAML loaded from ``.azureml/hydra_config.yaml`` below, so Hydra CLI
            overrides only take effect if that file reflects them — confirm intended.
    """
    # -- Setup ----------
    init_mlflow()
    # The context manager ends the run on every exit path: FINISHED on success,
    # FAILED if anything below raises (previously an exception skipped end_run()).
    with mlflow.start_run():
        # NOTE(review): relative path, resolved against the current working directory,
        # which Hydra may change depending on `hydra.job.chdir` — confirm.
        cfg = OmegaConf.load(".azureml/hydra_config.yaml")
        log_hydra(cfg)

        # -- Model ----------
        mdl = Model.build(arch=cfg.model.arch, backbone=cfg.model.backbone)
        log.info(f"model: \n{mdl}")
        log_params(mdl=mdl, path="mdl_params_pre_training.json")

        # --- Pre-training ---
        if cfg.pretraining.is_on:
            mlflow.set_tag("stage", "pretraining")
            tr_dl = build_dl(
                mdl_name=cfg.model.backbone,
                name=cfg.data.train,
                split="train",
                batch_size=cfg.pretraining.batch_size,
                paired=mdl.requires_paired_inp,
                is_test=cfg.data.is_test,
                is_dummy=cfg.data.is_dummy,
            )
            te_dls = build_test_dls(
                names=cfg.data.test,
                mdl_name=cfg.model.backbone,
                batch_size=cfg.pretraining.batch_size,
                paired_data=mdl.requires_paired_inp,
                is_test=cfg.data.is_test,
                is_dummy=cfg.data.is_dummy,
                bin_metrics=cfg.metrics.binary,
                multi_metrics=cfg.metrics.multiclass,
            )

            hist = mdl.fit(
                tr_dl=tr_dl,
                te_dls=te_dls,
                epochs=cfg.pretraining.epochs,
                lr={
                    "backbone": cfg.pretraining.lr_backbone,
                    "head": cfg.pretraining.lr_head,
                },
                lr_schedule=cfg.pretraining.lr_schedule,
                freeze_backbone_until=cfg.pretraining.freeze_backbone_until,
                patience=cfg.pretraining.patience,
            )
            mlflow.log_dict(hist, "hist_pretraining.json")
            save_mdl(model=mdl, tokenizer=cfg.model.backbone, mdl_name=cfg.save_mdl.name)

        # --- Finetuning --- TODO: needs to have own data configuration
        # if cfg.finetuning.is_on:
        #     run_name = f"{cfg.model.backbone.split('/')[1].replace('.', '-')}-{cfg.model.arch}"
        #     with mlflow.start_run(run_name=run_name, nested=True):
        #         mlflow.set_tag("stage", "finetuning")
        #
        #         tr_dl = build_dl(
        #             mdl_name=cfg.model.backbone,
        #             name=cfg.data.train,
        #             split="train",
        #             batch_size=cfg.finetuning.batch_size,
        #             paired=mdl.requires_paired_inp,
        #             is_test=cfg.data.is_test,
        #             is_dummy=cfg.data.is_dummy,
        #         )
        #         te_dls = build_test_dls(
        #             names=cfg.data.test,
        #             mdl_name=cfg.model.backbone,
        #             batch_size=cfg.finetuning.batch_size,
        #             paired_data=mdl.requires_paired_inp,
        #             is_test=cfg.data.is_test,
        #             is_dummy=cfg.data.is_dummy,
        #             bin_metrics=cfg.metrics.binary,
        #             multi_metrics=cfg.metrics.multiclass,
        #         )
        #
        #         hist = mdl.fit(
        #             tr_dl=tr_dl,
        #             te_dls=te_dls,
        #             epochs=cfg.finetuning.epochs,
        #             lr={"backbone": cfg.finetuning.lr_backbone, "head": cfg.finetuning.lr_head},
        #             lr_schedule=cfg.finetuning.lr_schedule,
        #             freeze_backbone_until=cfg.finetuning.freeze_backbone_until,
        #             patience=cfg.finetuning.patience,
        #         )
        #         mlflow.log_dict(hist, "hist_finetuning.json")

        # Log parameters of the final model (after pre-training, if it ran).
        log_params(mdl=mdl, path="mdl_params_post_training.json")


# Only launch the Hydra app when executed as a script, not when imported.
if __name__ == "__main__":
    main()
