import os

import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

from models.models import NormedGNN, NormedGNN_Residuals
from utils.data_utils import get_dataloaders
from utils.training_utils import (
    setup_pytorch,
    get_device,
    get_dist_grid_codes,
    train
)

def train_gnn(config, model_class, data_dir, grids):
    """Train a single model for one Ray Tune trial and report its losses.

    Args:
        config: Hyperparameters sampled by Ray Tune for this trial. Must
            provide ``"batch_size"``, ``"num_layers"``, ``"epochs"`` and
            ``"lr"``. May provide ``"patience"`` (early-stopping patience
            forwarded to ``train``); defaults to 500 when absent.
        model_class: Model constructor, called as
            ``model_class(num_layers=config["num_layers"])``.
        data_dir: Dataset directory forwarded to ``get_dataloaders``.
        grids: Grid codes used as the training grids; no testing grid is
            requested.

    Reports (via ``tune.report``):
        validation_performance: the third value returned by ``train``
            (``best_val_loss``).
        train_performance: the fourth value returned by ``train`` (the
            training loss corresponding to that validation loss).
    """
    device = get_device()

    # The test loader is not needed while tuning (testing_grid=None), so the
    # third loader is discarded.
    loader_train, loader_val, _ = get_dataloaders(
        data_dir,
        training_grids=grids,
        testing_grid=None,
        batch_size=config["batch_size"],
    )

    model = model_class(num_layers=config["num_layers"]).to(device)

    # Early-stopping patience; overridable per search space, 500 by default.
    patience = config.get("patience", 500)

    # Only the best validation loss and its matching training loss are used;
    # the other values returned by ``train`` are discarded.
    _, _, best_val_loss, corresponding_train_loss, _, _ = train(
        model,
        device,
        loader_train,
        loader_val,
        epochs=config['epochs'],
        learning_rate=config['lr'],
        early_stopping=True,
        patience=patience,
        best_val_weights=True,
    )

    # NOTE(review): metrics are reported exactly once, after training ends, so
    # an iteration-based scheduler (e.g. ASHA) sees a single result per trial
    # and cannot terminate it early. Per-epoch reporting would have to happen
    # inside ``train``.
    tune.report({"validation_performance": best_val_loss,
                 "train_performance": corresponding_train_loss})

def run_hyperparameter_tuning(gpus_per_trial=1, *, model_type='n-gnn',
                              num_cpus=4, num_gpus=None):
    """Grid-search GNN batch size and depth with Ray Tune; return the best config.

    Args:
        gpus_per_trial: GPUs requested by each trial (Ray allows fractional
            values so several trials can share one GPU).
        model_type: Which model to tune; one of ``'n-gnn'`` or
            ``'n-gnn-residuals'``. Also used to name the experiment.
        num_cpus: Total CPUs exposed to Ray if this call initializes it.
        num_gpus: Total GPUs exposed to Ray if this call initializes it.
            Defaults to ``ceil(gpus_per_trial)`` (enough for exactly one
            trial, matching the previous behaviour); raise it to run several
            GPU trials concurrently. Ignored when Ray is already initialized.

    Returns:
        The ``config`` dict of the trial with the lowest last-reported
        ``validation_performance``.

    Raises:
        ValueError: if ``model_type`` is not a known model key.
        RuntimeError: if no trial reported ``validation_performance``.
    """
    import math

    # Select the model type before doing any setup so a typo fails fast.
    model_classes = {'n-gnn': NormedGNN, 'n-gnn-residuals': NormedGNN_Residuals}
    if model_type not in model_classes:
        raise ValueError(
            f"Unknown model_type {model_type!r}; "
            f"expected one of {sorted(model_classes)}"
        )
    model_class = model_classes[model_type]

    # Set up pytorch and training
    setup_pytorch()

    # Configuration space. "model_class" is not read by train_gnn (it gets the
    # class through tune.with_parameters); it is kept so the returned best
    # config records which model was tuned.
    config = {
        "model_class": model_class,
        "batch_size": tune.grid_search([64, 128, 256]),
        "lr": tune.grid_search([1e-3]),
        "epochs": tune.grid_search([1000]),
        "num_layers": tune.grid_search([3, 4, 5, 6, 7, 8, 9, 10]),
    }

    # NOTE(review): train_gnn reports only once per trial, so ASHA's
    # grace_period / max_t (measured in reported iterations by default) never
    # cut a trial short; the scheduler is effectively inert here.
    scheduler = ASHAScheduler(
        metric="validation_performance",
        mode="min",
        max_t=3000,
        grace_period=10,
        reduction_factor=2
    )

    # Progress reporter
    reporter = CLIReporter(
        metric_columns=["validation_performance", "train_performance"]
    )

    # Initialize Ray once. Calling ray.init twice in one process raises, so
    # reuse an existing session instead (its resource limits then apply).
    if num_gpus is None:
        num_gpus = math.ceil(gpus_per_trial)
    if not ray.is_initialized():
        ray.init(num_cpus=num_cpus, num_gpus=num_gpus)

    # Dataset details. The path is made absolute here, in the driver, because
    # trial processes may run with a different working directory.
    data_dir = os.path.abspath('data/ENGAGE_dataset/')
    all_grids = get_dist_grid_codes(scenario=1)

    # Run the hyperparameter search
    result = tune.run(
        tune.with_parameters(train_gnn, model_class=model_class, data_dir=data_dir, grids=all_grids),
        resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
        config=config,
        scheduler=scheduler,
        progress_reporter=reporter,
        name=f"{model_type}_hyperparameter_tuning"
    )

    # Get the best configuration; get_best_trial returns None when no trial
    # reported the metric.
    best_trial = result.get_best_trial("validation_performance", "min", "last")
    if best_trial is None:
        raise RuntimeError(
            "No trial reported 'validation_performance'; "
            "check the Ray Tune trial logs for errors."
        )
    print("Best trial config:", best_trial.config)
    print("Best trial final validation loss:", best_trial.last_result["validation_performance"])
    print("Best trial corresponding train loss:", best_trial.last_result["train_performance"])

    return best_trial.config

if __name__ == '__main__':
    # Script entry point: tune with default settings, then show the winning depth.
    tuned_config = run_hyperparameter_tuning()
    best_num_layers = tuned_config['num_layers']
    print(f"Best number_layers is: {best_num_layers}\n")
