from functools import partial
import ray
import numpy as np
import hydra
from ray import tune, air, train
from ray.air import session
from ray.tune.search.optuna import OptunaSearch
import optuna


def _apply_tune_params(cfg, tune_params):
    """Write one trial's sampled hyperparameters into ``cfg`` (mutated in place).

    Entries named ``*_log10`` / ``*_log2`` hold exponents and are converted back
    to ``10.0 ** x`` / ``int(2 ** x)``. The keys read here must match the ones
    suggested by ``create_search_space``.

    Raises:
        ValueError: if ``cfg.model.pos_encoder.name`` is not a supported encoder.
    """
    # Optimizer
    cfg.optimizer.params.lr = 10.0 ** tune_params["lr_log10"]
    cfg.optimizer.params.weight_decay = 10.0 ** tune_params["wd_log10"]
    cfg.model.ffn_dropout = tune_params["ffn_dropout"]

    # Attention
    cfg.model.block.attn_dropout = tune_params["attn_dropout"]
    cfg.model.n_layers = tune_params["tf_depth"]
    cfg.model.block.heads = int(2 ** tune_params["attn_heads_log2"])

    # Feature encoder
    cfg.model.feat_encoder.emb_dim = int(2 ** tune_params["feat_d_log2"])
    cfg.model.feat_encoder.hidden_dim = int(2 ** tune_params["feat_dh_log2"])

    # Eigen subset: derived from the dataset size, not tuned.
    cfg.eigen_subset.args.num = cfg.dataset.num_nodes

    # Positional encoder: both supported encoders take a hidden dim; only the
    # deep variant additionally takes a number of hidden layers.
    pos_encoder_name = cfg.model.pos_encoder.name
    cfg.model.pos_encoder.emb_dim = int(2 ** tune_params["pos_d_log2"])
    if pos_encoder_name not in (
        "src.pos_encoder.mlp.MLP",
        "src.pos_encoder.deep_mlp.DeepMLP",
    ):
        raise ValueError(f"Unknown pos_encoder {pos_encoder_name}")
    cfg.model.pos_encoder.hidden_dim = int(2 ** tune_params["pos_dh_log2"])
    if pos_encoder_name == "src.pos_encoder.deep_mlp.DeepMLP":
        cfg.model.pos_encoder.n_hidden_layers = tune_params["n_pos_hidden_layers"]

    # NAGphormer-specific
    if cfg.model.base_class == "src.models.nagphormer_ours.NAGphormerOurs":
        cfg.model.num_hops = tune_params["num_hops"]


def run_training_hyperparameter(tune_params, cfg):
    """Run one Ray Tune trial: apply sampled hyperparameters, train, report.

    Args:
        tune_params: Mapping of hyperparameter names (as suggested by
            ``create_search_space``) to the values sampled for this trial.
        cfg: Hydra/OmegaConf config; mutated in place with the sampled values.

    Returns:
        ``{"val": mean_val, "test": mean_test}``: the ``Best/Val/<metric>`` and
        ``Best/Test/<metric>`` values, each averaged over all dataset splits.
        The same two values are also reported to Ray Tune via ``train.report``.

    Raises:
        ValueError: for an unsupported positional encoder, or when
            ``run_training`` returns a different number of splits than
            ``cfg.dataset.split_index`` lists.
        Exception: any error raised by ``run_training`` is printed with its
            traceback and the config, then re-raised so Ray marks the trial failed.
    """
    # Imported inside the trial function. This binds only ``run_training``;
    # ``train.report`` below still refers to the file-level ``ray.train``.
    from train import run_training

    _apply_tune_params(cfg, tune_params)

    try:
        split_stats = run_training(cfg)
    except Exception as e:
        import traceback

        print(f"Training failed with error: {str(e)}")
        print("Stack trace:")
        print(traceback.format_exc())
        print("Config:")
        print(cfg)
        raise

    # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
    expected_splits = len(cfg.dataset.split_index)
    if len(split_stats) != expected_splits:
        raise ValueError(
            f"run_training returned {len(split_stats)} splits, "
            f"expected {expected_splits}"
        )

    # The last entry of each split's stats holds the "Best/..." metric values.
    metric = cfg.dataset.metric
    mean_val = np.mean([stats[-1][f"Best/Val/{metric}"] for stats in split_stats])
    mean_test = np.mean([stats[-1][f"Best/Test/{metric}"] for stats in split_stats])

    # Report exactly once. ``ray.air.session.report`` is a deprecated alias of
    # ``ray.train.report``, so calling both would record every result twice.
    # "val" is the metric OptunaSearch optimises; "test" is logged alongside it.
    train.report({"val": mean_val, "test": mean_test})

    return {"val": mean_val, "test": mean_test}


def create_search_space(trial: "optuna.Trial", cfg):
    """Define-by-run Optuna search space consumed by ``OptunaSearch``.

    Suggests every hyperparameter read by ``run_training_hyperparameter``.
    Exponent-valued parameters (``*_log10`` / ``*_log2``) are sampled uniformly,
    so the underlying values are sampled log-uniformly.

    Args:
        trial: Optuna trial used to sample the values.
        cfg: Hydra config; selects which conditional parameters are suggested.

    Returns:
        None. All sampled values are recorded on ``trial``.

    Raises:
        ValueError: for an unsupported positional encoder or a factorized
            attention block. Both are checked before anything is sampled.
    """
    pos_encoder_name = cfg.model.pos_encoder.name
    if pos_encoder_name not in (
        "src.pos_encoder.mlp.MLP",
        "src.pos_encoder.deep_mlp.DeepMLP",
    ):
        raise ValueError(f"Unknown pos_encoder {pos_encoder_name}")
    # The head-count bounds below presumably only model a standard
    # (non-factorized) attention block; raise instead of ``assert`` so the
    # check survives ``python -O``.
    if "Factorized" in cfg.model.block._target_:
        raise ValueError(
            f"Factorized attention blocks are not supported: {cfg.model.block._target_}"
        )

    # Optimizer
    trial.suggest_float("lr_log10", -4, -1)
    trial.suggest_float("wd_log10", -7, -2)
    trial.suggest_float("ffn_dropout", 0.0, 0.5)

    # Feature encoder
    feat_emb_dim_log2 = trial.suggest_int("feat_d_log2", 3, 7)
    trial.suggest_int("feat_dh_log2", 4, 7)

    # Positional encoder (the deep variant also tunes its depth)
    pos_emb_dim_log2 = trial.suggest_int("pos_d_log2", 3, 7)
    trial.suggest_int("pos_dh_log2", 4, 11)
    if pos_encoder_name == "src.pos_encoder.deep_mlp.DeepMLP":
        trial.suggest_int("n_pos_hidden_layers", 1, 4)

    # Attention
    trial.suggest_float("attn_dropout", 0.0, 0.5)
    trial.suggest_int("tf_depth", 1, 8)

    # Head-dimension constraint (flash attention). Pick a minimum head count so
    # the per-head dim total_head_dim / 2**attn_heads_log2 stays below 256
    # (for any total_head_dim < 1024). With the ranges above, total_head_dim
    # is at most 2**7 + 2**7 = 256, so the ">= 512" case is currently unreachable.
    total_head_dim = 2**feat_emb_dim_log2 + 2**pos_emb_dim_log2
    if total_head_dim >= 512:
        min_heads_log2 = 2
    elif total_head_dim >= 256:
        min_heads_log2 = 1
    else:
        min_heads_log2 = 0
    trial.suggest_int("attn_heads_log2", min_heads_log2, 3)

    # NAGphormer-specific
    if cfg.model.base_class == "src.models.nagphormer_ours.NAGphormerOurs":
        trial.suggest_int("num_hops", 2, 20)  # from the NAGphormer paper


@hydra.main(version_base="1.3", config_path="../configs", config_name="ray.yaml")
def main(cfg):
    """Run the Ray Tune + Optuna hyperparameter search described by ``cfg``.

    Steps:
    - Connect to Ray at ``cfg.ray.address``.
    - Run ``cfg.dataset.num_tuning_trials`` trials of
      ``run_training_hyperparameter``, each requesting ``cfg.ray.cpu_req`` CPUs
      and ``cfg.ray.gpu_req`` GPUs, with at most
      ``cfg.ray.max_concurrent_trials`` trials running at once.
    - Maximise the reported ``"val"`` metric.
    - Print the best trial.

    Results are stored under ``<abs cfg.out_dir>/hyp_tuning/<cfg.ray.name>``.
    """
    import os

    ray.init(address=cfg.ray.address, log_to_driver=cfg.ray.log_to_driver)

    trial_resources = {"cpu": cfg.ray.cpu_req, "gpu": cfg.ray.gpu_req}

    # os.path.abspath() yields a path with exactly one leading "/", so the
    # "file://" prefix produces a well-formed "file:///..." URI. Splicing the
    # raw path after "file:///" would double the slash for absolute paths and
    # re-root relative ones at "/" instead of the working directory.
    storage_uri = "file://" + os.path.join(os.path.abspath(cfg.out_dir), "hyp_tuning")

    search_alg = OptunaSearch(
        space=partial(create_search_space, cfg=cfg),
        metric="val",
        mode="max",
        seed=0,
    )

    num_samples = cfg.dataset.num_tuning_trials
    print(f"Will tune for {num_samples} trials")

    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(partial(run_training_hyperparameter, cfg=cfg)),
            resources=trial_resources,
        ),
        tune_config=tune.TuneConfig(
            search_alg=search_alg,
            num_samples=num_samples,
            max_concurrent_trials=cfg.ray.max_concurrent_trials,
        ),
        run_config=air.RunConfig(
            storage_path=storage_uri,
            name=str(cfg.ray.name),
        ),
    )

    results = tuner.fit()
    # Surface how many trials raised (run_training_hyperparameter re-raises
    # training errors, so Ray records them as failed trials).
    if results.errors:
        print(f"{len(results.errors)} of {len(results)} trials failed")

    best_result = results.get_best_result(metric="val", mode="max")

    print(f"Best trial config: {best_result.config}")
    print(best_result)
    print("Best trial directory path:", best_result.path)


if __name__ == "__main__":
    # ``main`` is wrapped by ``hydra.main``, which composes ``cfg`` from
    # ../configs/ray.yaml, so it is invoked here without arguments.
    main()
