import argparse
import importlib
import os

import huggingface_hub
import torch
from dotenv import load_dotenv
from transformers import AutoTokenizer

from scripts.runner.evolve import run_evolve_cpu, run_evolve_gpu, run_evolve_cpu2
from scripts.utils.pure import (
    SimpleDataLoder,
    configure_args,
    configure_default,
    configure_path,
    get_metrics,
)
from scripts.utils.task_dependents import get_data, get_dataset


def main(cfg):
    """Build data, tokenizer and metrics from ``cfg`` and run exactly one evolve loop.

    Args:
        cfg: Config module imported from ``scripts.config`` (after the
            ``configure_*`` helpers have been applied). This function reads
            ``cfg.vlm.name``, ``cfg.l4``, ``cfg.gpu`` and
            ``cfg.exclude_param_names_regex``; the helpers it calls may read more.

    Runner selection (mutually exclusive):
        * ``run_evolve_cpu2`` when ``cfg.l4`` is truthy,
        * otherwise ``run_evolve_gpu`` when ``cfg.gpu`` is truthy,
        * otherwise ``run_evolve_cpu``.
    """
    data = get_data(cfg)
    # When get_data returns a tuple, only its first element is printed.
    print(data[0] if isinstance(data, tuple) else data)
    dataset = get_dataset(cfg, data)
    dataloader = SimpleDataLoder(dataset)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(cfg.vlm.name, trust_remote_code=True)

    metrics = get_metrics(cfg, tokenizer)

    # wandb.init(
    #     project="xray-evo-merge",
    #     name=cfg.config_name,
    # )

    # The runners are alternatives: the l4 path must not fall through and
    # start a second evolve loop with the CPU/GPU runner afterwards.
    if cfg.l4:
        runner = run_evolve_cpu2
    elif cfg.gpu:
        runner = run_evolve_gpu
    else:
        runner = run_evolve_cpu

    runner(
        cfg,
        cfg.exclude_param_names_regex,
        tokenizer,
        dataloader,
        metrics,
        device,
    )


if __name__ == "__main__":
    # CLI: choose a config module under scripts.config by name, apply the
    # command-line overrides to it, authenticate, then hand off to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", "-c", type=str, default="config")
    cli.add_argument("--default_cfg", type=str)
    for switch in ("--gpu", "--l4"):
        cli.add_argument(switch, action="store_true")
    cli.add_argument("--dev_null", type=str)
    cli_args = cli.parse_args()

    config_module = importlib.import_module(f"scripts.config.{cli_args.config}")
    # Order matters: argument overrides and defaults are applied before paths.
    for configure in (configure_args, configure_default):
        configure(config_module, cli_args)
    configure_path(config_module)

    # load_dotenv() populates os.environ from a .env file if one is found,
    # so HF_TOKEN may be supplied there instead of the shell environment.
    load_dotenv()
    huggingface_hub.login(os.getenv("HF_TOKEN"))
    # wandb.login(key=os.getenv("WANDB_API_KEY"))

    main(config_module)
