from functools import partial
import ray
import numpy as np
import hydra
from ray import tune, air, train
from ray.air import session
from ray.tune.search.optuna import OptunaSearch
import optuna


def _apply_tune_params(cfg, tune_params):
    """Write one trial's sampled hyperparameters from ``tune_params`` into ``cfg`` in place.

    Keys ending in ``_log10`` / ``_log2`` hold exponents and are converted back to
    linear values here (power-of-two sizes are cast to ``int``).

    Raises:
        ValueError: if ``cfg.model.base_class`` is not one of the supported models.
    """
    # Optimizer
    cfg.optimizer.params.lr = 10.0 ** tune_params["lr_log10"]
    cfg.optimizer.params.weight_decay = 10.0 ** tune_params["wd_log10"]
    cfg.model.ffn_dropout = tune_params["ffn_dropout"]

    # Eigen subset
    cfg.eigen_subset.args.num = tune_params["n_eigen"]

    # Attention
    cfg.model.block.attn_dropout = tune_params["attn_dropout"]
    cfg.model.n_layers = tune_params["tf_depth"]
    cfg.model.block.heads = int(2 ** tune_params["attn_heads_log2"])

    # Model-specific parameters; must mirror the branches in create_search_space.
    if cfg.model.base_class == "src.models.nagphormer.NAGphormer":
        cfg.model.num_hops = tune_params["num_hops"]
        cfg.model.hidden_dim = int(2 ** tune_params["tf_hidden_dim_log2"])
    elif cfg.model.base_class == "src.models.gt_baseline.GT":
        cfg.model.hidden_dim = int(2 ** tune_params["tf_hidden_dim_log2"])
    elif cfg.model.base_class == "src.models.simple.SimpleTransformer":  # GPS
        # Feat encoder
        assert cfg.model.feat_encoder.name == "src.feat_encoder.mlp.MLP"
        cfg.model.feat_encoder.emb_dim = int(2 ** tune_params["feat_d_log2"])
        cfg.model.feat_encoder.hidden_dim = int(2 ** tune_params["feat_dh_log2"])
        # Pos encoder
        assert cfg.model.pos_encoder.name == "src.pos_encoder.laplacian.Laplacian"
        cfg.model.pos_encoder.params.layers = tune_params["n_lapPE_layers"]
        cfg.model.pos_encoder.params.post_layers = tune_params["n_lapPE_post_layers"]
        cfg.model.pos_encoder.emb_dim = int(2 ** tune_params["pos_d_log2"])
    else:
        raise ValueError(f"Unknown model {cfg.model.base_class}")


def run_training_hyperparameter(tune_params, cfg):
    """Ray Tune trainable: train one configuration and report its mean scores.

    Args:
        tune_params: Hyperparameters sampled for this trial; the keys are the names
            registered in ``create_search_space``.
        cfg: Experiment config. It is updated in place with ``tune_params`` before
            training.

    Returns:
        ``{"val": float, "test": float}``. Each value is ``cfg.dataset.metric``
        averaged over all dataset splits. The per-split values are the
        ``Best/Val/<metric>`` and ``Best/Test/<metric>`` entries of each split's
        last stats record.

    Raises:
        ValueError: if ``cfg.model.base_class`` is not a supported model.
        Exception: any error from ``run_training`` is re-raised after the error,
            stack trace and config are printed.
    """
    from train import run_training

    _apply_tune_params(cfg, tune_params)

    try:
        split_stats = run_training(cfg)
    except Exception as e:
        import traceback

        print(f"Training failed with error: {str(e)}")
        print("Stack trace:")
        print(traceback.format_exc())
        print("Config:")
        print(cfg)
        raise

    # Expect exactly one stats sequence per configured split.
    assert len(split_stats) == len(cfg.dataset.split_index)

    metric = cfg.dataset.metric
    val_scores = [stats[-1][f"Best/Val/{metric}"] for stats in split_stats]
    test_scores = [stats[-1][f"Best/Test/{metric}"] for stats in split_stats]

    # np.mean yields np.float64; report plain floats.
    mean_val = float(np.mean(val_scores))
    mean_test = float(np.mean(test_scores))

    # Report once. The former extra `session.report` call (the deprecated
    # ray.air alias of ray.train.report) logged every result a second time.
    # "val" is the metric OptunaSearch optimizes; "test" is logged only.
    train.report({"val": mean_val, "test": mean_test})

    return {"val": mean_val, "test": mean_test}


def create_search_space(trial: optuna.Trial, cfg):
    """Sample one trial's hyperparameters using Optuna's define-by-run API.

    Every hyperparameter is registered on ``trial`` through ``suggest_*`` calls. The
    parameter names must match the keys read by ``run_training_hyperparameter``. The
    order of the calls is kept fixed so that seeded sampling stays reproducible.

    Args:
        trial: The Optuna trial to sample from.
        cfg: Experiment config. ``cfg.model.base_class`` selects which
            model-specific hyperparameters are sampled.

    Returns:
        None. The sampled values live on ``trial``.

    Raises:
        ValueError: if ``cfg.model.base_class`` is not a supported model.
        AssertionError: if the GPS encoders or the attention block are not the
            supported variants.
    """
    model_cls = cfg.model.base_class

    # Optimizer and regularization ("_log10" values are base-10 exponents).
    trial.suggest_float("lr_log10", -4, -1)
    trial.suggest_float("wd_log10", -7, -2)
    trial.suggest_float("ffn_dropout", 0.0, 0.5)

    # Size of the eigen subset.
    trial.suggest_int("n_eigen", 4, 16)

    # Attention stack.
    trial.suggest_float("attn_dropout", 0.0, 0.5)
    trial.suggest_int("tf_depth", 1, 8)

    if model_cls in ("src.models.nagphormer.NAGphormer", "src.models.gt_baseline.GT"):
        if model_cls == "src.models.nagphormer.NAGphormer":
            trial.suggest_int("num_hops", 2, 20)  # range from the NAGphormer paper
        # The attention width is the transformer hidden size.
        total_head_dim = 2 ** trial.suggest_int("tf_hidden_dim_log2", 6, 9)
    elif model_cls == "src.models.simple.SimpleTransformer":  # GPS
        assert cfg.model.feat_encoder.name == "src.feat_encoder.mlp.MLP"
        feat_dim_log2 = trial.suggest_int("feat_d_log2", 3, 7)
        trial.suggest_int("feat_dh_log2", 4, 7)

        assert cfg.model.pos_encoder.name == "src.pos_encoder.laplacian.Laplacian"
        trial.suggest_int("n_lapPE_layers", 1, 8)
        trial.suggest_int("n_lapPE_post_layers", 0, 4)
        pos_dim_log2 = trial.suggest_int("pos_d_log2", 3, 7)

        # The attention width is feature-embedding dim + positional-embedding dim
        # (presumably the two embeddings are concatenated; confirm in the model).
        total_head_dim = 2**feat_dim_log2 + 2**pos_dim_log2
    else:
        raise ValueError(f"Unknown model {model_cls}")

    # Flash attention bounds the per-head dimension, so wider models must use more
    # heads. NOTE(review): factorized blocks are rejected, presumably because their
    # head-dim constraint differs; confirm before lifting this assertion.
    assert "Factorized" not in cfg.model.block._target_
    if total_head_dim < 256:
        min_heads_log2 = 0
    elif total_head_dim < 512:
        min_heads_log2 = 1
    else:
        min_heads_log2 = 2
    trial.suggest_int("attn_heads_log2", min_heads_log2, 3)


@hydra.main(version_base="1.3", config_path="../configs", config_name="ray.yaml")
def main(cfg):
    """Run an Optuna-driven Ray Tune hyperparameter search and print the best result.

    Steps:
    - Connects to the Ray cluster at ``cfg.ray.address``.
    - Runs ``cfg.dataset.num_tuning_trials`` trials of
      ``run_training_hyperparameter``, maximizing the reported ``"val"`` metric.
    - Stores the experiment as ``cfg.ray.name`` under ``<cfg.out_dir>/hyp_tuning``.
    """
    import os
    from pathlib import Path

    ray.init(address=cfg.ray.address, log_to_driver=cfg.ray.log_to_driver)

    trial_resources = {"cpu": cfg.ray.cpu_req, "gpu": cfg.ray.gpu_req}

    # Build a well-formed file:// URI. The former f"file:///{out_dir}" gave
    # "file:////abs/path" for absolute paths and silently placed relative
    # paths at the filesystem root ("file:///outputs" -> "/outputs").
    storage_uri = Path(os.path.abspath(cfg.out_dir), "hyp_tuning").as_uri()

    search_alg = OptunaSearch(
        space=partial(create_search_space, cfg=cfg),
        metric="val",
        mode="max",
        seed=0,
    )

    num_samples = cfg.dataset.num_tuning_trials
    print(f"Will tune for {num_samples} trials")

    tuner = tune.Tuner(
        tune.with_resources(
            # with_parameters stores cfg in the Ray object store and passes it to
            # the trainable as the `cfg=` keyword argument.
            tune.with_parameters(run_training_hyperparameter, cfg=cfg),
            resources=trial_resources,
        ),
        tune_config=tune.TuneConfig(
            search_alg=search_alg,
            num_samples=num_samples,
            max_concurrent_trials=cfg.ray.max_concurrent_trials,
        ),
        run_config=air.RunConfig(
            stop=None,
            storage_path=storage_uri,
            name=str(cfg.ray.name),
        ),
    )

    results = tuner.fit()
    best_result = results.get_best_result(metric="val", mode="max")

    print(f"Best trial config: {best_result.config}")
    print(best_result)
    print("Best trial directory path:", best_result.path)


if __name__ == "__main__":
    # Hydra composes the config from ../configs/ray.yaml (plus any command-line
    # overrides) and passes it to main(); main() is therefore called with no arguments.
    main()
