import wandb
from .util import train_and_test, set_cfg_value


def run_ablation(cfg):
    """Run a one-factor-at-a-time hyper-parameter ablation.

    The ablated fields are, in order: ``learning_rate`` (``cfg.lr_list``),
    ``batch_size`` (``cfg.bz_list``), ``optimizer_type`` (``cfg.opt_list``),
    ``hidden_size`` (``cfg.hidden_size_list``), ``shuffle``
    (``cfg.shuffle_list``) and ``head_model_type`` (``cfg.head_model_list``).
    For each field, every candidate value is applied with ``set_cfg_value``
    and ``train_and_test(cfg)`` is run once. If ``set_cfg_value`` returns a
    falsy value, the remaining candidates for that field are skipped and the
    sweep moves on to the next field.

    After a field's sweep, the field is reset (via ``set_cfg_value``) to the
    value it held before the sweep. Later ablations then vary a single factor
    against the starting settings instead of inheriting the last value tried
    for earlier fields.

    When ``cfg.use_wandb`` is truthy, each training run is logged to its own
    wandb run named ``"<field>=<value>"``. Its config is a snapshot of every
    attribute of ``cfg`` whose name does not start with ``"__"``. A run whose
    training raises is finished with a non-zero exit code and the exception is
    re-raised.

    Args:
        cfg: Mutable config object. It is modified in place during the sweep.

    Raises:
        AssertionError: If any of the ``*_list`` fields above is not a list.
    """
    assert isinstance(cfg.head_model_list, list)
    assert isinstance(cfg.lr_list, list)
    assert isinstance(cfg.bz_list, list)
    assert isinstance(cfg.opt_list, list)
    assert isinstance(cfg.hidden_size_list, list)
    assert isinstance(cfg.shuffle_list, list)

    # STD Settings (disabled): optional single baseline run with the unmodified cfg.
    # hyperparams_std = {attr: getattr(cfg, attr) for attr in dir(cfg) if not attr.startswith("__")}
    # if cfg.use_wandb:
    #     run_name = f"STD Settings: head_model={cfg.head_model_type}, lr={cfg.learning_rate}, bz={cfg.batch_size}, opt={cfg.optimizer_type}, hidden_size={cfg.hidden_size}, shuffle={cfg.shuffle}, k={cfg.k}"
    #     wandb_run = wandb.init(project=cfg.wandb_proj_name, config=hyperparams_std, name=run_name)
    # train_and_test(cfg)
    # if cfg.use_wandb:
    #     wandb_run.finish()

    # Ablation Settings
    ablation_dict = {
        'learning_rate': cfg.lr_list,
        'batch_size': cfg.bz_list,
        'optimizer_type': cfg.opt_list,
        'hidden_size': cfg.hidden_size_list,
        'shuffle': cfg.shuffle_list,
        'head_model_type': cfg.head_model_list,
    }
    missing = object()  # sentinel: field had no value before its sweep
    for ablation_name, ablation_list in ablation_dict.items():
        # Remember the starting value so this sweep does not leak into the next one.
        baseline = getattr(cfg, ablation_name, missing)
        for ablation_value in ablation_list:
            successful = set_cfg_value(cfg, ablation_name, ablation_value)
            if not successful:
                break

            wandb_run = None
            if cfg.use_wandb:
                hyperparams = {attr: getattr(cfg, attr) for attr in dir(cfg) if not attr.startswith("__")}
                run_name = f"{ablation_name}={ablation_value}"
                wandb_run = wandb.init(project=cfg.wandb_proj_name, config=hyperparams, name=run_name)
            try:
                train_and_test(cfg)
            except BaseException:
                # Close the wandb run and mark it failed instead of leaving it open.
                if wandb_run is not None:
                    wandb_run.finish(exit_code=1)
                raise
            if wandb_run is not None:
                wandb_run.finish()

        # Restore the starting value so later ablations vary only their own field.
        if baseline is not missing:
            _ = set_cfg_value(cfg, ablation_name, baseline)


def run_different_LLMs(cfg):
    """Train and test once for every (k, learning rate) pair.

    Despite the name, the active sweep does not change the embedding model.
    It iterates over ``cfg.k_list`` (outer loop) and ``cfg.lr_list`` (inner
    loop), sets ``cfg.k`` and ``cfg.learning_rate`` accordingly, and calls
    ``train_and_test(cfg)``. A per-LLM sweep is kept below, commented out.

    When ``cfg.use_wandb`` is truthy, each run is logged to its own wandb run
    named after ``cfg.seq_type``, ``cfg.embed_name``, k and lr. A run whose
    training raises is finished with a non-zero exit code and the exception is
    re-raised.

    ``cfg.k`` and ``cfg.learning_rate`` are restored to the values they held on
    entry (when they had one) once the sweep ends, including on error.

    Args:
        cfg: Mutable config object. It is modified in place during the sweep.
    """
    # STD Settings (disabled): optional single baseline run with the unmodified cfg.
    # hyperparams_std = {attr: getattr(cfg, attr) for attr in dir(cfg) if not attr.startswith("__")}
    # if cfg.use_wandb:
    #     run_name = f"seq_type={cfg.seq_type}, LLM={cfg.embed_name}, k={cfg.k}"
    #     wandb_run = wandb.init(project=cfg.wandb_proj_name, config=hyperparams_std, name=run_name)
    # train_and_test(cfg)
    # if cfg.use_wandb:
    #     wandb_run.finish()

    missing = object()  # sentinel: attribute was not set on entry
    saved = {name: getattr(cfg, name, missing) for name in ("k", "learning_rate")}
    try:
        for k in cfg.k_list:
            cfg.k = k
            for lr in cfg.lr_list:
                cfg.learning_rate = lr
                wandb_run = None
                if cfg.use_wandb:
                    hyperparams = {attr: getattr(cfg, attr) for attr in dir(cfg) if not attr.startswith("__")}
                    run_name = f"seq_type={cfg.seq_type}, LLM={cfg.embed_name}, k={cfg.k}, lr={cfg.learning_rate}"
                    wandb_run = wandb.init(project=cfg.wandb_proj_name, config=hyperparams, name=run_name)
                try:
                    train_and_test(cfg)
                except BaseException:
                    # Close the wandb run and mark it failed instead of leaving it open.
                    if wandb_run is not None:
                        wandb_run.finish(exit_code=1)
                    raise
                if wandb_run is not None:
                    wandb_run.finish()
    finally:
        # Undo the sweep's mutations so cfg can be reused by the caller.
        for name, value in saved.items():
            if value is not missing:
                setattr(cfg, name, value)

    # Different LLM (disabled): sweep parallel lists of embedding models instead.
    # for i in range(len(cfg.llm_list)):
    #     cfg.embed_name = cfg.llm_list[i]
    #     cfg.seq_type = cfg.seq_type_list[i]
    #     cfg.num_features = cfg.feature_list[i]
    #     cfg.k = cfg.k_list[i]

    #     if cfg.use_wandb:
    #         hyperparams = {attr: getattr(cfg, attr) for attr in dir(cfg) if not attr.startswith("__")}
    #         run_name = f"seq_type={cfg.seq_type}, LLM={cfg.embed_name}, k={cfg.k}"
    #         wandb_run = wandb.init(project=cfg.wandb_proj_name, config=hyperparams, name=run_name)
    #     train_and_test(cfg)
    #     if cfg.use_wandb:
    #         wandb_run.finish()


def run_finetune(cfg):
    """Fine-tune every backbone over the grid of learning rates and optimizers.

    On entry, ``cfg.backbone``, ``cfg.num_features`` and ``cfg.seq_type`` are
    per-backbone sequences indexed in parallel, and ``cfg.learning_rate`` /
    ``cfg.optimizer_type`` are sequences of candidates. The loops visit each
    (backbone, learning rate, optimizer) combination with backbone outermost
    and optimizer innermost. For each one, these five fields are overwritten
    with that run's scalar values and ``train_and_test(cfg)`` is called.

    The original sequence-valued fields are restored when the sweep ends
    (also on error), so ``cfg`` stays reusable, e.g. for a second call.

    When ``cfg.use_wandb`` is truthy, each run is logged to its own wandb run
    in project ``cfg.wandb_proj_name`` under entity ``cfg.entity``. A run whose
    training raises is finished with a non-zero exit code and the exception is
    re-raised.

    Args:
        cfg: Mutable config object. It is modified in place during the sweep.

    Raises:
        AssertionError: If ``cfg.num_features`` and ``cfg.backbone`` differ in
            length, or ``cfg.seq_type`` is shorter than ``cfg.backbone``.
    """
    backbone_names = cfg.backbone
    num_features = cfg.num_features
    seq_types = cfg.seq_type
    learning_rates = cfg.learning_rate
    optimizer_types = cfg.optimizer_type
    assert len(num_features) == len(backbone_names)
    # seq_type is indexed per backbone as well; fail before any training starts
    # rather than with an IndexError part-way through the sweep.
    assert len(seq_types) >= len(backbone_names)

    try:
        for i, backbone in enumerate(backbone_names):
            for lr in learning_rates:
                for opt in optimizer_types:
                    cfg.num_features = num_features[i]
                    cfg.seq_type = seq_types[i]
                    cfg.backbone = backbone
                    cfg.learning_rate = lr
                    cfg.optimizer_type = opt
                    wandb_run = None
                    if cfg.use_wandb:
                        hyperparams = {attr: getattr(cfg, attr) for attr in dir(cfg) if not attr.startswith("__")}
                        run_name = f"Finetune: seq_type={cfg.seq_type}, LLM={cfg.backbone}, k={cfg.k}, lr={cfg.learning_rate}"
                        wandb_run = wandb.init(project=cfg.wandb_proj_name, entity=cfg.entity, config=hyperparams, name=run_name)
                    try:
                        train_and_test(cfg)
                    except BaseException:
                        # Close the wandb run and mark it failed instead of leaving it open.
                        if wandb_run is not None:
                            wandb_run.finish(exit_code=1)
                        raise
                    if wandb_run is not None:
                        wandb_run.finish()
    finally:
        # Put the sequence-valued settings back; the loop replaced them with scalars.
        cfg.backbone = backbone_names
        cfg.num_features = num_features
        cfg.seq_type = seq_types
        cfg.learning_rate = learning_rates
        cfg.optimizer_type = optimizer_types