import os
import sys

import hydra


# Add the bundled shared libraries to the import path; otherwise the modules
# inside those libraries cannot import from each other.
# The base directory is made absolute: if ``__file__`` is a bare filename,
# ``os.path.dirname`` returns "" and plain string concatenation would yield
# the filesystem-root path "/lib_llm/lib_llm". A relative base would also
# stop resolving if the working directory changes later.
_PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))
# Insertion order is kept: lib_dl ends up in front of lib_llm on sys.path.
sys.path.insert(0, os.path.join(_PROJECT_DIR, "lib_llm", "lib_llm"))
sys.path.insert(0, os.path.join(_PROJECT_DIR, "lib_dl", "lib_dl"))

# from experiments.named_entity_detection.config import NEDHandle
# from experiments.incremental_memorization.config import IMHandle

from experiments.memorization_hyperparam_rel.config import MHRHandle

# from experiments.relevant_context_search.config import RCSHandle
# from experiments.same_data_different_division.config import SDDDHandle
from experiments.model_training.config import MTHandle

# from experiments.context_size.config import CSHandle
from experiments.prefix_performance.config import PPHandle
from lib_dl.analysis.project import ProjectConfig, run_project, setup_project


# Experiments enabled for this project. This list is registered via
# setup_project() at import time and later handed to run_project() in main().
# To re-enable a disabled experiment, uncomment its import above and add its
# handle here.
experiment_handles = [
    MHRHandle,  # experiments.memorization_hyperparam_rel
    PPHandle,  # experiments.prefix_performance
    MTHandle,  # experiments.model_training
    # Disabled: IMHandle, NEDHandle, CSHandle, RCSHandle, SDDDHandle
]
setup_project(experiment_handles)


# Hydra configuration directory and primary config name used by main().
CONFIG_PATH, CONFIG_NAME = "../conf", "config"

# Local rank of this process. Stays -1 unless parse_local_rank() finds a
# rank argument on the command line.
local_rank = -1


@hydra.main(version_base=None, config_path=CONFIG_PATH, config_name=CONFIG_NAME)
def main(cfg: ProjectConfig) -> None:
    """Hydra entry point: run the registered experiments with ``cfg``.

    The module-level ``local_rank`` is read at call time, so the value stored
    by ``parse_local_rank()`` before ``main()`` is invoked is the one that
    gets forwarded to the experiment handles.
    """
    rank_args = {"local_rank": local_rank}
    # Non-default ports so this project does not conflict with other
    # experiments.
    notebook_port, tensorboard_port = 8881, 6007
    run_project(
        cfg,
        experiment_handles,
        handle_args=rank_args,
        nb_port=notebook_port,
        tb_port=tensorboard_port,
    )


def parse_local_rank() -> int:
    """Parse the local-rank argument set by the DeepSpeed launcher.

    The launcher appends ``--local_rank=<num>`` to the command line. That
    argument is removed from ``sys.argv`` so only the remaining arguments are
    forwarded to Hydra. The hyphenated spelling ``--local-rank=<num>``, used
    by newer ``torch.distributed`` launchers, is accepted as well.

    Side effects:
        - Updates the module-level ``local_rank`` global.
        - Exports the parsed value as the ``LOCAL_RANK`` environment variable.
        - Rebinds ``sys.argv`` to the list of remaining arguments.

    Returns:
        The parsed local rank. If no rank argument is present, returns the
        current value of the module-level ``local_rank``, which is unchanged.
        If the argument appears more than once, the last occurrence wins.

    Raises:
        ValueError: if the text after ``=`` is not an integer.
    """
    global local_rank
    rank_prefixes = ("--local_rank=", "--local-rank=")
    remaining_args = []
    for arg in sys.argv:
        prefix = next((p for p in rank_prefixes if arg.startswith(p)), None)
        if prefix is None:
            # Not a rank flag: leave it for Hydra.
            remaining_args.append(arg)
            continue
        local_rank = int(arg[len(prefix) :])
        os.environ["LOCAL_RANK"] = str(local_rank)
    sys.argv = remaining_args
    return local_rank


if __name__ == "__main__":
    # Order matters: the launcher's rank argument is stripped from sys.argv
    # (and stored in the module-level local_rank) before main() runs, so the
    # argument list Hydra receives no longer contains it.
    parse_local_rank()
    main()
