import wandb
import traceback
from itertools import product
from src.tool import load_config, set_seed, wait_for_gpu_memory
import src.model as generate_utils
import src.task as task_utils


def sweep(cfg):
    """Run the configured task once for every combination of list-valued options.

    Every non-dunder attribute of ``cfg`` whose value is a ``list`` is treated
    as a sweep axis, and the Cartesian product of all such lists is executed
    in order. For each combination the chosen values are written back onto
    ``cfg`` (``cfg`` is mutated in place), the seed is reset via
    ``set_seed(cfg.seed)``, and ``task_utils.task_<cfg.dataset>`` is called
    with the shared model. An exception raised by one combination is printed
    and the sweep continues with the next one.

    The model is built once, from ``cfg.model``, before the loop and reused
    for every combination, so ``model`` may not be swept. ``generate_mode`` is
    rejected as well (presumably because it is also consumed when the model
    is constructed -- confirm against ``src.model``).

    Args:
        cfg: Config object as returned by ``load_config()``. Attributes read
            here: ``model``, ``seed``, ``dataset``, ``use_wandb`` and, when
            ``use_wandb`` is true, ``wandb_proj_name`` and ``entity``.
            ``gpu_memory`` (default 0), ``time_least_wait`` (default 10) and
            ``wandb_name`` (default "") are optional.

    Raises:
        AssertionError: if ``model`` or ``generate_mode`` is given as a list.
    """
    # Every public list-valued attribute is a sweep axis. Dunder attributes
    # are skipped (as in the hyperparameter snapshot below) so class
    # machinery such as a list-valued ``__slots__`` is never swept.
    param_dict = {}
    for k in dir(cfg):
        if k.startswith("__"):
            continue
        v = getattr(cfg, k)
        if isinstance(v, list):
            param_dict[k] = v

    assert "model" not in param_dict, "model should not be a list to sweep"
    assert "generate_mode" not in param_dict, (
        "generate_mode should not be a list to sweep"
    )

    keys = list(param_dict)
    param_combinations = [
        dict(zip(keys, combination))
        for combination in product(*param_dict.values())
    ]
    if not param_combinations:
        # An empty list on any axis makes the product empty; bail out before
        # waiting for GPU memory and loading the model for nothing.
        empty_axes = [k for k in keys if not param_dict[k]]
        print(f"Nothing to run: empty sweep list(s) for {empty_axes}")
        return

    # Only axes with more than one value distinguish runs from each other.
    varied_keys = [k for k in keys if len(param_dict[k]) > 1]

    wait_for_gpu_memory(
        getattr(cfg, "gpu_memory", 0), getattr(cfg, "time_least_wait", 10)
    )

    # Load the model once; it is reused for every combination.
    model = getattr(generate_utils, cfg.model)(cfg)

    for i, ablation in enumerate(param_combinations):
        print(
            f"\nNow running {i + 1} / {len(param_combinations)} combinations: {ablation}"
        )

        for k, v in ablation.items():
            setattr(cfg, k, v)

        set_seed(cfg.seed)

        # Snapshot of the full (now scalar-valued) config for logging.
        hyperparams = {
            attr: getattr(cfg, attr) for attr in dir(cfg) if not attr.startswith("__")
        }
        print(hyperparams)
        # Include the chosen values, not just the axis names, so that runs of
        # the same sweep get distinct names.
        run_name = "_".join(f"{k}={ablation[k]}" for k in varied_keys)

        wandb_run = None
        if cfg.use_wandb:
            # Join only non-empty parts to avoid stray leading/trailing "_";
            # fall back to None so W&B auto-generates a name if both are empty.
            name_parts = (getattr(cfg, "wandb_name", ""), run_name)
            wandb_run = wandb.init(
                project=cfg.wandb_proj_name,
                entity=cfg.entity,
                config=hyperparams,
                name="_".join(p for p in name_parts if p) or None,
            )

        # A failing combination must not abort the whole sweep: print the
        # traceback and move on, but record the failure in the W&B run.
        exit_code = 0
        try:
            getattr(task_utils, f"task_{cfg.dataset}")(cfg=cfg, model=model)
        except Exception:
            traceback.print_exc()
            exit_code = 1

        if wandb_run is not None:
            # A non-zero exit code marks the W&B run as failed.
            wandb_run.finish(exit_code=exit_code)


if __name__ == "__main__":
    cfg = load_config()
    # ``cfg.task`` names a function defined in this module that takes the
    # config as its only argument (e.g. ``sweep``).
    func = globals().get(cfg.task)
    if func and callable(func):
        func(cfg)
    else:
        # Exit with a non-zero status so shells and job schedulers see the
        # failure instead of a silent success.
        raise SystemExit(f"Function '{cfg.task}' not found or not callable.")
