import csv
from skopt.space import Integer, Real, Categorical
from skopt import gp_minimize
from skopt.utils import use_named_args
from gnnboundary.tuning.boundary_graph_generation import run_training


# Search space explored by the optimizer. The order of the dimensions is the
# order in which values are handed to `objective` (and written to the CSV log).
space = [
    # The targeted number of edges in a graph.
    Integer(low=20, high=60, name="target_size"),
    # String labels; `objective` translates them through TARGET_PROBS_MAP.
    Categorical(categories=["0.45-0.55", "0.4-0.6", "0.35-0.65"], name="target_probs"),
    # Learning rate, sampled uniformly.
    Real(low=0.01, high=1, prior="uniform", name="learning_rate"),
    # Temperature for the categorical sampling of the node feature distribution.
    Real(low=0.05, high=0.5, prior="log-uniform", name="temperature"),
    # Budget increase for the graph size regulation.
    Real(low=1.05, high=1.2, prior="uniform", name="w_budget_inc"),
    # Budget decrease for the graph size regulation.
    Real(low=0.94, high=0.99, prior="uniform", name="w_budget_dec"),
]

# Parameters that stay constant for every evaluation. Their insertion order is
# also the order of the leading columns in the CSV log.
DEFAULT_FIXED_PARAMS = dict(iterations=1000, k_samples=32, w_budget_init=1)

# Categorical label -> pair of floats encoded in that label ("a-b" -> (a, b)).
TARGET_PROBS_MAP = {
    label: tuple(float(bound) for bound in label.split("-"))
    for label in ("0.45-0.55", "0.4-0.6", "0.35-0.65")
}

# Reverse lookup used when reading the CSV log back: the key is the str() form
# of the tuple, which is how csv.writer serialises a tuple cell.
INVERSE_TARGET_PROBS_MAP = {str(pair): label for label, pair in TARGET_PROBS_MAP.items()}


# CSV file setup
# Alternative log locations used for other datasets:
# log_file = "hpo/hpo_results_motif.csv"
# log_file = "hpo/hpo_results_collab.csv"
log_file = "hpo_imdb.csv"

# Column layout of the log: the fixed parameters first, then the tuned
# hyperparameters, and the loss of the evaluation last.
_header_columns = list(DEFAULT_FIXED_PARAMS)
_header_columns += [
    "target_size",
    "target_probs",
    "learning_rate",
    "temperature",
    "w_budget_inc",
    "w_budget_dec",
    "Custom Loss",
]

# NOTE: mode "w" truncates the file, so an existing log at this path is
# replaced by a header-only file every time this module is executed.
with open(log_file, mode="w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(_header_columns)


# Bayesian optimization wrapper to use named arguments + logging
@use_named_args(space)
def objective(fixed_params: dict = None, **params):
    """
    Evaluate one hyperparameter configuration and append it to the CSV log.

    Parameters:
        fixed_params (dict): Parameters that are not tuned. Falls back to
            DEFAULT_FIXED_PARAMS when omitted.
        **params: One value per dimension of `space`, keyed by dimension name.

    Returns:
        The loss returned by run_training; this is the value being minimized.
    """
    # Resolve the default at call time instead of binding a mutable dict as the
    # default argument.
    if fixed_params is None:
        fixed_params = DEFAULT_FIXED_PARAMS

    # The search space carries target_probs as a string label; training gets
    # the corresponding tuple.
    if "target_probs" in params:
        params["target_probs"] = TARGET_PROBS_MAP[params["target_probs"]]

    # Using 3 runs instead of 500 to save resources.
    # NOTE: Empirically, the 3 runs were very consistent with deviations of only around 1% of each other.
    # A shallow copy is handed over so that the logged columns below cannot be
    # altered by whatever run_training does with the dict.
    loss = run_training(fixed_params=dict(fixed_params), num_runs=3, **params)

    # Log hyperparameters and results to CSV. Column order: fixed parameters,
    # tuned parameters (in the order they were passed in), then the loss.
    with open(log_file, mode="a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([*fixed_params.values(), *params.values(), loss])

    # Return the loss to minimize
    return loss


def run_bayesian_optimization(n_calls=200, n_random_starts=8, n_jobs=-1, prev_results: tuple = None):
    """
    Runs Bayesian optimization (gp_minimize) over `space` to tune hyperparameters.
    Our Laptop that we run this on has 8 CPU cores.

    Parameters:
        n_calls (int): Number of function evaluations.
        n_random_starts (int): Number of initial random points.
        n_jobs (int): Number of parallel jobs (-1 uses all available cores).
            Forwarded to gp_minimize as-is. NOTE(review): per the scikit-optimize
            documentation this parallelizes the optimization of the acquisition
            function, not the calls to `objective` - confirm for the installed version.
        prev_results (tuple): Optional `(X, y)` pair as returned by
            load_previous_results, used as already-evaluated points. `None` or
            an empty `X` starts the optimization from scratch.

    Returns:
        Result of the optimization process.
    """
    # A header-only log yields ([], []), which is truthy as a tuple. Passing the
    # empty list on as x0 is not a usable warm start, so treat it like "no
    # previous results".
    x0 = y0 = None
    if prev_results and len(prev_results[0]) > 0:
        x0, y0 = prev_results[0], prev_results[1]

    result = gp_minimize(
        func=objective,
        dimensions=space,
        n_calls=n_calls,
        n_random_starts=n_random_starts,
        n_jobs=n_jobs,
        x0=x0,
        y0=y0,
        verbose=True,
    )
    return result


def load_previous_results(log_file: str) -> tuple:
    """
    Load the hyperparameter tuning results from a CSV file.

    Each data row is expected to hold the fixed parameters first (one column
    per entry of DEFAULT_FIXED_PARAMS), then the six tuned hyperparameters in
    the order target_size, target_probs, learning_rate, temperature,
    w_budget_inc, w_budget_dec, and the loss in the last column.

    Parameters:
        log_file (str): Path to the CSV file.

    Returns:
        X (list of lists): Hyperparameter values.
        y (list): Corresponding losses.
        Both lists are empty when the file has no data rows.

    Raises:
        KeyError: If a target_probs cell is not a key of INVERSE_TARGET_PROBS_MAP.
        ValueError: If a numeric cell cannot be converted.
    """
    n_fixed = len(DEFAULT_FIXED_PARAMS)

    X = []
    y = []
    # newline="" is what the csv module recommends for files given to csv.reader.
    with open(log_file, mode="r", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # Skip header; the default keeps an empty file from raising
        for row in reader:
            # csv.reader yields [] for blank lines (e.g. a trailing newline)
            if not row:
                continue

            # Extract hyperparameters and loss
            hyperparams = row[n_fixed:-1]
            loss = float(row[-1])

            # Convert values to the appropriate types. target_probs was logged
            # as the str() of a tuple and is mapped back to its category label.
            X.append(
                [
                    int(hyperparams[0]),
                    INVERSE_TARGET_PROBS_MAP[hyperparams[1]],
                    float(hyperparams[2]),
                    float(hyperparams[3]),
                    float(hyperparams[4]),
                    float(hyperparams[5]),
                ]
            )
            y.append(loss)
    return X, y


if __name__ == "__main__":
    # To start from previous results, assign
    # load_previous_results("hpo/hpo_imdb.csv") here instead of None.
    prev_results = None
    res = run_bayesian_optimization(
        prev_results=prev_results,
        n_random_starts=8,
        n_calls=200,
    )
    print("Finished HPO!")
    print(res)
