from config import *
import optuna
from functools import partial
from training import train_loop

def objective(trial, result_df, num_features_, num_users_, limit, dataset):
    """Optuna objective: sample hyperparameters, run one training, score it.

    Args:
        trial: The Optuna trial used to suggest hyperparameters.
        result_df: Forwarded unchanged to ``train_loop(result_df=...)``.
        num_features_: Forwarded to ``train_loop(num_features=...)``.
        num_users_: Forwarded to ``train_loop(num_users=...)``.
        limit: Forwarded to ``train_loop(training_points=...)``.
        dataset: Forwarded to ``train_loop(dataset_type=...)``.

    Returns:
        ``(final_acc, final_auc)``: the 5th percentile of the accuracy list
        and of the AUC list returned by ``train_loop``. This is a pessimistic,
        lower-tail summary rather than the mean. Both values are objectives
        to maximize.

    Raises:
        optuna.exceptions.TrialPruned: if BOTH the 5th-percentile accuracy
            and the 5th-percentile AUC are below 0.5.
    """
    # Re-seed at the start of every trial so runs differ only by their
    # hyperparameters, not by leftover RNG state from a previous trial.
    random.seed(SEED)
    np.random.seed(SEED)
    tf.random.set_seed(SEED)

    # NOTE: keep the suggestion order stable; the sampler's RNG is seeded, so
    # reordering these calls would change which configurations get sampled.
    meta_lr = trial.suggest_float('meta_lr', 1e-5, 1., log=True)
    main_lr = trial.suggest_float('main_lr', 1e-5, 1., log=True)
    neighbour_margin = trial.suggest_float('neighbour_margin', 0.1, 5., log=True)
    lambda_reg = trial.suggest_float('lambda_reg', 1e-4, 1e-2, log=True)
    loss_reg = trial.suggest_float('loss_reg', 1e-4, 1., log=True)
    alpha = trial.suggest_float('alpha', 0.1, 0.9)
    batch_size = trial.suggest_categorical('batch_size', [1, 2, 4, 16])
    tepoch = trial.suggest_int('tepoch', 10, 100, step=10)
    # NOTE(review): the upper bounds below depend on ``tepoch``, so the search
    # space differs between trials (multivariate TPE may fall back to
    # independent sampling for these parameters). With tepoch == 10 the range
    # is [10, 11], i.e. cold_start >= tepoch. Confirm train_loop handles that.
    cold_start = trial.suggest_int('cold_start', 10, max(tepoch - 10, 11))
    increase_after = trial.suggest_int('increase_after', 1, max(tepoch - 10, 2), step=1)
    increase_factor = trial.suggest_float('increase_factor', 0.05, 0.5, log=True)
    dynamic_neighbour = trial.suggest_categorical('dynamic_neighbour', [True, False])
    increase_limit = trial.suggest_float('increase_limit', 1., 10., log=True)
    training_points = limit
    init_weight = trial.suggest_categorical('init_weight', ['zero', 'uniform'])

    # Print current hyperparameters for debugging. Every adjacent f-string
    # piece must end with ", " or fields run together in the output.
    print(f"Trial parameters: meta_lr={meta_lr}, main_lr={main_lr}, neighbour_margin={neighbour_margin}, "
          f"lambda_reg={lambda_reg}, loss_reg={loss_reg}, alpha={alpha}, batch_size={batch_size}, "
          f"cold_start={cold_start}, tepoch={tepoch}, "
          f"increase_after={increase_after}, increase_factor={increase_factor}, "
          f"increase_limit={increase_limit}, dynamic_neighbour={dynamic_neighbour}, "
          f"training_points={training_points}, init_weight={init_weight}")

    # Run the training loop with the suggested hyperparameters.
    acc_list, auc_list, total_loss_list, meta_loss_list = train_loop(
        num_features=num_features_,
        num_users=num_users_,
        tepoch=tepoch,
        meta_lr=meta_lr,
        main_lr=main_lr,
        neighbour_margin=neighbour_margin,
        lambda_reg=lambda_reg,
        loss_reg=loss_reg,
        alpha=alpha,
        batch_size=batch_size,
        cold_start=cold_start,
        increase_after=increase_after,
        increase_factor=increase_factor,
        increase_limit=increase_limit,
        dynamic_neighbour=dynamic_neighbour,
        training_points=training_points,
        init_weight=init_weight,
        result_df=result_df,
        dataset_type=dataset,
    )

    # Score with the 5th percentile so a configuration is judged by its
    # weak runs/users, not by its average.
    final_acc = np.percentile(acc_list, 5)
    final_auc = np.percentile(auc_list, 5)
    ACC_THRESHOLD = 0.5
    AUC_THRESHOLD = 0.5
    # Prune only when both metrics are at or below chance level.
    if final_acc < ACC_THRESHOLD and final_auc < AUC_THRESHOLD:
        print(f"Pruning trial {trial.number} due to low performance: ACC={final_acc}, AUC={final_auc}")
        raise optuna.exceptions.TrialPruned()

    return final_acc, final_auc

def run_optuna(dataset, limit, trials, result_df, num_features, num_users):
    """Create or resume a two-objective Optuna study and optimize ``objective``.

    The study is persisted in the SQLite file ``optuna_study.db`` under the
    name ``"{dataset}_limit_{limit}"``. Because ``load_if_exists=True``, an
    existing study with that name is resumed and ``trials`` more trials are
    run on it.

    Args:
        dataset: Dataset identifier. Forwarded to ``objective`` and used in
            the study name.
        limit: Forwarded to ``objective``; also part of the study name.
        trials: Number of trials to run in this call.
        result_df: Forwarded to ``objective``.
        num_features: Forwarded to ``objective`` as ``num_features_``.
        num_users: Forwarded to ``objective`` as ``num_users_``.
    """
    obj_with_df = partial(
        objective,
        result_df=result_df,
        num_features_=num_features,
        num_users_=num_users,
        limit=limit,
        dataset=dataset,
    )
    study = optuna.create_study(
        # Two objectives (ACC, AUC), both maximized.
        directions=['maximize'] * 2,
        storage="sqlite:///optuna_study.db",
        study_name=f"{dataset}_limit_{limit}",
        sampler=optuna.samplers.TPESampler(multivariate=True, seed=SEED),
        load_if_exists=True,
    )
    study.optimize(obj_with_df, n_trials=trials)

    # A multi-objective study has no single best trial: ``study.best_params``
    # (via ``study.best_trial``) raises RuntimeError. Report the Pareto front.
    best_trials = study.best_trials
    if not best_trials:
        print("No completed trials to report (all trials were pruned, failed, or none ran).")
        return
    for best in best_trials:
        print(f"Pareto-optimal trial {best.number}: values={best.values}, params={best.params}")
