# tune_hyperparameters_skopt_compare_v2.py

import argparse
import json
import os
import random
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from skopt import BayesSearchCV
from skopt.space import Real, Categorical, Integer

from sklearn.linear_model import LinearRegression, ElasticNet
from sklearn.ensemble import (
    ExtraTreesRegressor,
    RandomForestRegressor,
    GradientBoostingRegressor,
    AdaBoostRegressor,
)
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
from data_utils import load_uci_dataset


# Dataset names processed when --datasets is not given on the command line.
DEFAULT_DATASETS = ["CALIFORNIA_HOUSING", "ABALONE", "BIKE_SHARING"]


def get_model_definitions(random_state_base=42):
    """
    Build the catalogue of regressors to tune.

    Args:
        random_state_base: Seed placed in every search space that has a
            'random_state' entry, as a single-choice Categorical so the
            search can only ever select that one value.

    Returns:
        list: (model_name, model_class, skopt_search_space) tuples, in the
        order Extratree, Random Forest, Gradient Boosting, Elastic Net,
        Decision Tree, AdaBoost Reg.
    """

    def pinned(value):
        # A Categorical with one choice keeps the parameter constant
        # while still passing it through the search space.
        return Categorical([value])

    def forest_space():
        # Extra-trees and random forest are searched over the same ranges;
        # each call builds fresh dimension objects.
        return {
            "n_estimators": Integer(50, 200),
            "max_depth": Categorical([None, 10, 20, 30]),
            "min_samples_split": Integer(2, 10),
            "min_samples_leaf": Integer(1, 10),
            "max_features": Categorical(["sqrt", "log2", None]),
            "random_state": pinned(random_state_base),
        }

    extra_trees = ("Extratree", ExtraTreesRegressor, forest_space())
    random_forest = ("Random Forest", RandomForestRegressor, forest_space())

    gradient_boosting = (
        "Gradient Boosting",
        GradientBoostingRegressor,
        {
            "n_estimators": Integer(50, 150),
            "learning_rate": Real(1e-3, 0.3, prior="log-uniform"),
            "max_depth": Integer(3, 8),
            "subsample": Real(0.7, 1.0, prior="uniform"),
            "loss": Categorical(["squared_error", "absolute_error", "huber"]),
            "random_state": pinned(random_state_base),
        },
    )

    # The Elastic Net space carries no 'random_state' entry; max_iter is
    # pinned to a single value rather than tuned.
    elastic_net = (
        "Elastic Net",
        ElasticNet,
        {
            "alpha": Real(1e-3, 1e1, prior="log-uniform"),
            "l1_ratio": Real(0.01, 0.99, prior="uniform"),
            "max_iter": pinned(20000),
        },
    )

    decision_tree = (
        "Decision Tree",
        DecisionTreeRegressor,
        {
            "max_depth": Categorical([None, 5, 10, 15, 20]),
            "min_samples_split": Integer(2, 15),
            "min_samples_leaf": Integer(1, 15),
            "criterion": Categorical(
                ["squared_error", "friedman_mse", "absolute_error"]
            ),
            "max_features": Categorical(["sqrt", "log2", None]),
            "random_state": pinned(random_state_base),
        },
    )

    adaboost = (
        "AdaBoost Reg",
        AdaBoostRegressor,
        {
            "n_estimators": Integer(50, 150),
            "learning_rate": Real(1e-3, 0.5, prior="log-uniform"),
            "loss": Categorical(["linear", "square", "exponential"]),
            "random_state": pinned(random_state_base),
        },
    )

    return [
        extra_trees,
        random_forest,
        gradient_boosting,
        elastic_net,
        decision_tree,
        adaboost,
    ]


def _to_builtin_scalar(value):
    """Return *value* as a built-in Python scalar if it is a NumPy scalar, else unchanged."""
    return value.item() if isinstance(value, np.generic) else value


def tune_hyperparameters_for_dataset_skopt(
    X_tune,
    y_tune_original,
    model_definitions,
    n_iter_search,
    cv_folds,
    random_state_seed,
):
    """
    Tunes hyperparameters using BayesSearchCV with X- and y-scaling.

    Both X and y are standardized with scalers fitted on the full tuning
    set before the cross-validated search is run.

    Args:
        X_tune: Feature matrix of the tuning split.
        y_tune_original: 1-D array-like of unscaled targets.
        model_definitions: Iterable of (name, model_class, search_space)
            tuples; an empty/falsy search space skips tuning for that model.
        n_iter_search: Number of parameter settings BayesSearchCV samples.
        cv_folds: Value passed as BayesSearchCV's ``cv``.
        random_state_seed: Seed for BayesSearchCV's sampling; also set on
            estimators that expose ``random_state`` when the search space
            does not contain it.

    Returns:
        dict: Best hyperparameters found for each model, with NumPy scalar
            values converted to built-in Python types so the result can be
            serialized with ``json`` directly.
        StandardScaler: Fitted scaler for X.
        StandardScaler: Fitted scaler for y.

    Raises:
        ValueError: If a model class cannot be instantiated without arguments.
        Exception: Any error raised while fitting a candidate propagates,
            because the search is configured with ``error_score="raise"``.
    """
    dataset_best_params = {}
    x_scaler = StandardScaler()
    X_tune_scaled = x_scaler.fit_transform(X_tune)

    # np.asarray lets callers pass any 1-D array-like, not only ndarrays.
    y_scaler = StandardScaler()
    y_tune_scaled_2d = y_scaler.fit_transform(
        np.asarray(y_tune_original).reshape(-1, 1)
    )
    y_tune_scaled_1d = y_tune_scaled_2d.ravel()

    print(
        f"    Original y_tune stats: mean={np.mean(y_tune_original):.2f}, std={np.std(y_tune_original):.2f}"
    )
    print(
        f"    Scaled y_tune stats: mean={np.mean(y_tune_scaled_1d):.2f}, std={np.std(y_tune_scaled_1d):.2f}"
    )

    for name, model_class, search_space_def in model_definitions:
        print(f"    Tuning {name} with BayesSearchCV (y scaled)...")
        if not search_space_def:
            # Nothing to search over: record an empty parameter set.
            print(f"      Skipping tuning for {name} (no parameters specified).")
            dataset_best_params[name] = {}
            continue

        try:
            model_instance = model_class()
            # Fix the seed on the estimator itself when the model supports
            # random_state but the search space does not supply it.
            if (
                "random_state" in model_instance.get_params()
                and "random_state" not in search_space_def
            ):
                model_instance.set_params(random_state=random_state_seed)
            # Same fallback for Elastic Net's iteration cap.
            if name == "Elastic Net" and "max_iter" not in search_space_def:
                model_instance.set_params(max_iter=20000)
        except TypeError as te:
            # Keep the original traceback attached to the re-raised error.
            raise ValueError(f"Could not instantiate model {name}: {te}") from te

        search = BayesSearchCV(
            estimator=model_instance,
            search_spaces=search_space_def,
            n_iter=n_iter_search,
            cv=cv_folds,
            n_jobs=-1,
            random_state=random_state_seed,  # For BayesSearchCV's own sampling process
            verbose=0,
            error_score="raise",  # Surface failures instead of scoring them
        )

        search.fit(X_tune_scaled, y_tune_scaled_1d)
        # The search can report NumPy scalars (e.g. for Integer dimensions),
        # which json.dump rejects; store plain Python values instead.
        dataset_best_params[name] = {
            param: _to_builtin_scalar(value)
            for param, value in search.best_params_.items()
        }
        print(f"      Best params for {name}: {dataset_best_params[name]}")
        print(
            f"      Best CV score for {name} (on scaled y, default R^2): {search.best_score_:.4f}"
        )

    return dataset_best_params, x_scaler, y_scaler


def compare_tuned_models(
    X_tune_original,
    y_tune_original,
    x_scaler_fitted,
    y_scaler_fitted,
    dataset_best_params,
    model_definitions,
    random_state_seed,  # Passed for consistency if models need it
):
    """
    Compares tuned models pairwise based on pointwise squared error.

    Each model is instantiated with its tuned parameters, fitted on the
    scaled tuning data, and evaluated on that same data (an in-sample
    comparison). Predictions are mapped back to the original target scale
    before the errors are computed.

    Args:
        X_tune_original: Unscaled feature matrix of the tuning split.
        y_tune_original: 1-D NumPy array of unscaled targets.
        x_scaler_fitted: Fitted scaler exposing ``transform`` for X.
        y_scaler_fitted: Fitted scaler exposing ``transform`` and
            ``inverse_transform`` for y.
        dataset_best_params: Mapping of model name to tuned parameters;
            missing names fall back to an empty parameter set.
        model_definitions: Iterable of (name, model_class, search_space)
            tuples; the search space is ignored here.
        random_state_seed: Seed applied to models that expose
            ``random_state`` when the tuned parameters do not include it.

    Returns:
        pandas.DataFrame: Float matrix indexed by model name on both axes.
        Cell (row, col) is the number of samples on which the row model's
        squared error is strictly smaller than the column model's; the
        diagonal is 0.

    Raises:
        Exception: Errors from model instantiation, fitting or prediction
            propagate to the caller.
    """
    print("\n    Starting pairwise model comparison...")
    model_names = [definition[0] for definition in model_definitions]

    # Inputs are the same for every model, so scale them once up front.
    X_tune_scaled_for_comp = x_scaler_fitted.transform(X_tune_original)
    y_tune_scaled_for_fit_1d = y_scaler_fitted.transform(
        y_tune_original.reshape(-1, 1)
    ).ravel()

    # Per-sample squared error of each model, in model_definitions order.
    squared_errors = []
    for name, model_class, _ in model_definitions:
        instantiation_params = dict(dataset_best_params.get(name, {}))

        # Best effort: inspect a default instance to learn which parameters
        # the model accepts. Classes that cannot be built without arguments
        # simply skip the fallbacks below.
        try:
            supported_params = model_class().get_params()
        except TypeError:
            supported_params = None

        if supported_params is not None:
            if (
                "random_state" in supported_params
                and "random_state" not in instantiation_params
            ):
                instantiation_params["random_state"] = random_state_seed
            if name == "Elastic Net" and "max_iter" not in instantiation_params:
                instantiation_params["max_iter"] = 20000

        model = model_class(**instantiation_params)
        model.fit(X_tune_scaled_for_comp, y_tune_scaled_for_fit_1d)

        # Predict once per model and reuse the errors for every pairing.
        pred_scaled = model.predict(X_tune_scaled_for_comp)
        pred_orig = y_scaler_fitted.inverse_transform(
            pred_scaled.reshape(-1, 1)
        ).ravel()
        squared_errors.append((y_tune_original - pred_orig) ** 2)

    n_models = len(model_names)
    win_counts = np.zeros((n_models, n_models), dtype=float)
    for row, se_row in enumerate(squared_errors):
        for col, se_col in enumerate(squared_errors):
            if row != col:
                # Ties count for neither model (strict comparison).
                win_counts[row, col] = np.sum(se_row < se_col)

    comparison_matrix = pd.DataFrame(
        win_counts, index=model_names, columns=model_names
    )

    print("\n    Pairwise Model Comparison Matrix (N_row_better_than_N_col):")
    print(comparison_matrix)
    return comparison_matrix


def _placeholder_results(model_names):
    """Return (empty per-model params, all-NaN comparison matrix as dict) for a skipped dataset."""
    empty_params = {name: {} for name in model_names}
    nan_matrix = pd.DataFrame(
        np.nan, index=model_names, columns=model_names
    ).to_dict()
    return empty_params, nan_matrix


def _json_default(value):
    """json.dump fallback that converts NumPy scalars and arrays to built-in types."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(
        f"Object of type {type(value).__name__} is not JSON serializable"
    )


def main(args):
    """
    Tune and compare all models on each requested dataset, then save results.

    For every dataset name in ``args.datasets`` the data is loaded, a tuning
    split of size ``args.train_ratio_for_tuning`` is drawn, hyperparameters
    are tuned and the tuned models are compared pairwise. Datasets that fail
    to load with ValueError, are empty, or are too small for the requested
    number of CV folds are skipped and recorded with placeholder results.

    Args:
        args: Parsed command-line namespace providing ``seed``, ``datasets``,
            ``train_ratio_for_tuning``, ``cv_folds``, ``n_iter``,
            ``output_hyperparams_file`` and ``output_comparison_file``.

    Raises:
        Exception: Errors raised during tuning or comparison are not caught
            and stop the whole run.
    """
    np.random.seed(args.seed)
    random.seed(args.seed)

    all_datasets_hyperparams = {}
    all_datasets_comparison_matrices = {}

    model_definitions = get_model_definitions(random_state_base=args.seed)
    model_names = [definition[0] for definition in model_definitions]
    datasets_to_process = args.datasets

    for dataset_name_str in datasets_to_process:
        print(f"\n--- Processing dataset: {dataset_name_str} ---")
        try:
            # NOTE(review): X and y are assumed to be pandas objects (they are
            # used via .empty, .shape and .to_numpy() below) - confirm against
            # data_utils.load_uci_dataset.
            X, y, _ = load_uci_dataset(dataset_name_str)
        except ValueError as e:  # Catch specific ValueError from load_uci_dataset
            print(
                f"  Failed to load or process dataset {dataset_name_str}: {e}. Skipping."
            )
            # Record placeholder results so the other datasets can proceed.
            (
                all_datasets_hyperparams[dataset_name_str],
                all_datasets_comparison_matrices[dataset_name_str],
            ) = _placeholder_results(model_names)
            continue

        if X.empty or y.empty:
            print(f"  Dataset {dataset_name_str} is empty after loading. Skipping.")
            (
                all_datasets_hyperparams[dataset_name_str],
                all_datasets_comparison_matrices[dataset_name_str],
            ) = _placeholder_results(model_names)
            continue

        # The tuning split needs at least one sample per CV fold.
        min_samples_in_tune_set = args.cv_folds
        if int(X.shape[0] * args.train_ratio_for_tuning) < min_samples_in_tune_set:
            print(
                f"  Not enough samples in {dataset_name_str} ({X.shape[0]}) for tuning with "
                f"{args.cv_folds}-fold CV and train_ratio {args.train_ratio_for_tuning}. "
                f"Required at least {min_samples_in_tune_set} in tuning set. Skipping."
            )
            (
                all_datasets_hyperparams[dataset_name_str],
                all_datasets_comparison_matrices[dataset_name_str],
            ) = _placeholder_results(model_names)
            continue

        # Only the tuning portion of the split is used; the rest is discarded here.
        X_tune, _, y_tune_df, _ = train_test_split(
            X,
            y,
            train_size=args.train_ratio_for_tuning,
            random_state=args.seed,
            shuffle=True,
        )

        y_tune_np = y_tune_df.to_numpy().ravel()

        # Errors raised by tuning or comparison are deliberately not caught:
        # a critical failure on one dataset stops the script.
        dataset_best_params, x_scaler, y_scaler = (
            tune_hyperparameters_for_dataset_skopt(
                X_tune,
                y_tune_np,
                model_definitions,
                args.n_iter,
                args.cv_folds,
                args.seed,
            )
        )
        all_datasets_hyperparams[dataset_name_str] = dataset_best_params

        comparison_matrix_df = compare_tuned_models(
            X_tune_original=X_tune.copy(),
            y_tune_original=y_tune_np.copy(),
            x_scaler_fitted=x_scaler,
            y_scaler_fitted=y_scaler,
            dataset_best_params=dataset_best_params,
            model_definitions=model_definitions,
            random_state_seed=args.seed,
        )
        all_datasets_comparison_matrices[dataset_name_str] = (
            comparison_matrix_df.to_dict()
        )

    # --- Saving results ---
    # Ensure output directories exist (a bare filename has no directory part).
    for output_path in (args.output_hyperparams_file, args.output_comparison_file):
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    # default=_json_default keeps NumPy scalars among the tuned parameters
    # from aborting the save after all tuning work is done.
    with open(args.output_hyperparams_file, "w") as f:
        json.dump(all_datasets_hyperparams, f, indent=4, default=_json_default)
    print(f"\nBest hyperparameters saved to {args.output_hyperparams_file}")

    with open(args.output_comparison_file, "w") as f:
        json.dump(
            all_datasets_comparison_matrices, f, indent=4, default=_json_default
        )
    print(f"Model comparison matrices saved to {args.output_comparison_file}")

    print(
        "\nReminder: Y-scaling was applied during tuning. Ensure your main "
        "experimental script handles this consistently (scales y_train, inverse-transforms predictions)."
    )


if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments),
    # registered in the order they should appear in --help.
    cli_options = [
        (
            "--datasets",
            {
                "nargs": "+",
                "default": DEFAULT_DATASETS,
                "help": f"List of UCI dataset names to tune. If not provided, tunes all available datasets: {', '.join(DEFAULT_DATASETS)}",
            },
        ),
        (
            "--train_ratio_for_tuning",
            {
                "type": float,
                "default": 0.6,
                "help": "Proportion of data for tuning.",
            },
        ),
        (
            "--n_iter",
            {
                "type": int,
                "default": 25,
                "help": "Number of parameter settings sampled by BayesSearchCV.",
            },
        ),
        (
            "--cv_folds",
            {
                "type": int,
                "default": 3,
                "help": "Number of cross-validation folds.",
            },
        ),
        (
            "--output_hyperparams_file",
            {
                "type": str,
                "default": "results/tuned_hyperparams_skopt_y_scaled.json",
                "help": "File to save tuned hyperparameters.",
            },
        ),
        (
            "--output_comparison_file",
            {
                "type": str,
                "default": "results/model_comparison_matrices.json",
                "help": "File to save model comparison matrices.",
            },
        ),
        (
            "--seed",
            {
                "type": int,
                "default": 42,
                "help": "Random seed.",
            },
        ),
    ]

    parser = argparse.ArgumentParser(
        description="Hyperparameter Tuning and Comparison Script using Scikit-Optimize."
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)
