"""
Model selection and hyperparameter optimization utilities for Beta-Binomial model.
"""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from .beta_binomial_model import (
    fit_beta_binomial_model,
    predict,
)


def grid_search_beta_binomial_priors(
    x_data,
    successes,
    trials,
    hyperparams_grid=None,
    n_samples=1000,
    tune=500,
    verbose=True,
    plot_results=True,
    progressbar=True,
):
    """
    Perform grid search to find optimal prior parameters for Beta-Binomial model.

    Every combination in the Cartesian product of ``hyperparams_grid`` is
    fitted with ``fit_beta_binomial_model`` and scored by the mean of
    ``idata.log_likelihood["y_obs"]``; the combination with the highest score
    is kept as the best one. A combination whose fit or prediction raises is
    reported (when ``verbose``) and skipped, so it has no row in
    ``results_df``.

    Args:
        x_data: Input data (capability differences)
        successes: Number of successful attacks (count data)
        trials: Number of trials for each data point
        hyperparams_grid: Dictionary of hyperparameter lists to try, e.g.,
                          {'sigma_w': [0.1, 0.5, 1.0], 'sigma_b': [0.1, 0.5, 1.0]}
                          If None, defaults to standard grid
        n_samples: Number of MCMC samples for each fit
        tune: Number of tuning steps for each fit
        verbose: Whether to print progress
        plot_results: Whether to plot the results
        progressbar: Whether to display PyMC sampling progress bar

    Returns:
        Dictionary with:
            - results_df: DataFrame with one row per successfully evaluated
              combination (hyperparameters, log_likelihood, rmse_prob,
              r2_prob, rmse_count)
            - best_params: Dictionary with best parameters, or None if no
              combination produced a score above -inf
            - best_model: Tuple (model, idata) for best parameters, or None
    """
    import itertools

    # Set default hyperparameter grid if none provided
    if hyperparams_grid is None:
        hyperparams_grid = {
            "sigma_w": [0.1, 0.5, 1.0, 2.0],
            "sigma_b": [0.1, 0.5, 1.0, 2.0],
            "sigma_nu": [1.0, 5.0, 10.0, 20.0],
        }

    # Ensure inputs are at least 1-D arrays (also makes trials[0] positional)
    x_data = np.atleast_1d(x_data)
    successes = np.atleast_1d(successes)
    trials = np.atleast_1d(trials)

    # Enumerate every hyperparameter combination up front so progress can be
    # reported as "current/total"
    hyperparams_keys = list(hyperparams_grid.keys())
    all_combinations = list(
        itertools.product(*(hyperparams_grid[key] for key in hyperparams_keys))
    )
    total_combinations = len(all_combinations)

    # Results storage and best-so-far tracking
    results = []
    best_log_likelihood = -np.inf
    best_params = None
    best_model_result = None

    # Grid search
    for current, combo in enumerate(all_combinations, start=1):
        # Create parameter dictionary for this combination
        prior_params = dict(zip(hyperparams_keys, combo))

        if verbose:
            combo_str = ", ".join(
                f"{key}={value}" for key, value in prior_params.items()
            )
            print(f"Trying combination {current}/{total_combinations}: {combo_str}")

        # Fit model with these parameters; any failure skips the combination
        try:
            model, idata = fit_beta_binomial_model(
                x_data,
                successes,
                trials,
                prior_params=prior_params,
                n_samples=n_samples,
                tune=tune,
                progressbar=progressbar,
            )

            # Extract log-likelihood from InferenceData
            if hasattr(idata, "log_likelihood") and "y_obs" in idata.log_likelihood:
                log_likelihood = idata.log_likelihood["y_obs"].mean().item()
            else:
                if verbose:
                    print("  Log likelihood not found in InferenceData.")
                log_likelihood = -np.inf  # Assign a poor score

            # Calculate fit metrics on the training data
            pred_result = predict(
                idata,
                x_data,
                n_trials=trials[0],  # Assuming constant trials
                successes_true=successes,
                trials_true=trials,
            )
            # Extract metrics safely; missing entries become NaN
            metrics = pred_result.get("metrics", {})
            rmse_prob = metrics.get("rmse_prob", float("nan"))
            r2_prob = metrics.get("r2_prob", float("nan"))
            rmse_count = metrics.get("rmse_count", float("nan"))

            # Store results: hyperparameters first, then the scores
            results.append(
                {
                    **prior_params,
                    "log_likelihood": log_likelihood,
                    "rmse_prob": rmse_prob,
                    "r2_prob": r2_prob,
                    "rmse_count": rmse_count,
                }
            )

            # Check if this is the best model so far
            if log_likelihood > best_log_likelihood:
                best_log_likelihood = log_likelihood
                best_params = prior_params.copy()
                best_model_result = (model, idata)

                if verbose:
                    print(
                        f"  New best: log_likelihood={log_likelihood:.4f}, "
                        f"rmse_prob={rmse_prob:.4f}, "
                        f"r2_prob={r2_prob:.4f}"
                    )

        except Exception as e:
            if verbose:
                print(f"  Error with combination: {e}")

    # Convert results to DataFrame
    results_df = pd.DataFrame(results)

    # Plot results if requested. Plotting is best-effort: a failure here must
    # not throw away the results of the (expensive) search above.
    if plot_results and not results_df.empty:
        try:
            plot_grid_search_results(results_df, hyperparams_keys)
        except Exception as e:
            if verbose:
                print(f"Warning: Could not create plots: {e}")

    # Return results
    return {
        "results_df": results_df,
        "best_params": best_params,
        "best_model": best_model_result,
    }


def plot_grid_search_results(results_df, param_keys=None):
    """
    Plot the grid search results for Beta-Binomial model.

    Draws one "log-likelihood vs parameter" line panel per entry of
    ``param_keys`` (one line per combination of the remaining parameters).
    When at least two parameters are given, an extra panel shows a heatmap of
    the best log-likelihood for every value pair of the first two parameters.
    The figure is displayed with ``plt.show()``; nothing is returned.

    Args:
        results_df: DataFrame with grid search results. Must contain a
            "log_likelihood" column and one column per plotted hyperparameter.
        param_keys: List of hyperparameter keys to plot. If None, every column
            that is not one of the known metric columns is used.
    """
    if results_df.empty:
        print("No results to plot")
        return

    # If param_keys not provided, use all columns except the metric columns
    if param_keys is None:
        metrics_cols = {
            "log_likelihood",
            "rmse_prob",
            "r2_prob",
            "rmse_count",
        }
        param_keys = [col for col in results_df.columns if col not in metrics_cols]
    else:
        param_keys = list(param_keys)

    # One panel per parameter plus one slot reserved for the heatmap,
    # laid out two panels per row.
    n_params = len(param_keys)
    n_plots = n_params + 1
    fig_rows = (n_plots + 1) // 2
    # squeeze=False always yields a 2-D axes array, even for a single row
    _, axes = plt.subplots(fig_rows, 2, figsize=(14, 5 * fig_rows), squeeze=False)
    axes_flat = axes.flatten()

    # Plot log-likelihood vs each parameter
    for ax, param in zip(axes_flat, param_keys):
        # Each line holds all the other parameters fixed
        other_params = [p for p in param_keys if p != param]

        if other_params:
            grouped = results_df.groupby(other_params)
            for group_key, group_df in grouped:
                # Depending on the pandas version, grouping by a length-1
                # list yields either a scalar key or a 1-tuple; normalize so
                # the label never shows a raw tuple.
                if not isinstance(group_key, tuple):
                    group_key = (group_key,)

                group_name = "-".join(
                    f"{p}={g}" for p, g in zip(other_params, group_key)
                )

                # Keep group name reasonably short
                if len(group_name) > 30:
                    group_name = group_name[:27] + "..."

                # Sort by parameter value for cleaner plots
                plot_df = group_df.sort_values(param)
                ax.plot(
                    plot_df[param], plot_df["log_likelihood"], "o-", label=group_name
                )

            # Add legend only if it will not be too cluttered
            if grouped.ngroups < 10:
                ax.legend(fontsize="small")
        else:
            # If only one parameter, just plot it directly
            plot_df = results_df.sort_values(param)
            ax.plot(plot_df[param], plot_df["log_likelihood"], "o-")

        ax.set_xlabel(param)
        ax.set_ylabel("Log-likelihood")
        ax.set_title(f"Log-likelihood vs {param}")
        ax.grid(True, alpha=0.3)

    used_axes = n_params

    # If we have at least two parameters, create a heatmap for the first two.
    # The slot at index n_params always exists because the grid was sized for
    # n_params + 1 panels.
    if n_params >= 2:
        heat_ax = axes_flat[n_params]
        used_axes += 1

        param1, param2 = param_keys[:2]

        # Keep the best-scoring row for each (param1, param2) cell; with
        # exactly two parameters and unique combinations this is a no-op.
        best_models = results_df.sort_values(
            "log_likelihood", ascending=False
        ).drop_duplicates([param1, param2])

        # Sorted unique values define the heatmap rows / columns
        param1_unique = sorted(results_df[param1].unique().tolist())
        param2_unique = sorted(results_df[param2].unique().tolist())
        row_of = {value: idx for idx, value in enumerate(param1_unique)}
        col_of = {value: idx for idx, value in enumerate(param2_unique)}

        # Start from NaN so combinations without a result stay blank instead
        # of masquerading as a log-likelihood of 0.
        heatmap_data = np.full((len(param1_unique), len(param2_unique)), np.nan)
        for value1, value2, score in zip(
            best_models[param1].tolist(),
            best_models[param2].tolist(),
            best_models["log_likelihood"].tolist(),
        ):
            heatmap_data[row_of[value1], col_of[value2]] = score

        # Plot the heatmap
        im = heat_ax.imshow(heatmap_data, cmap="viridis")

        # Set tick labels
        heat_ax.set_xticks(np.arange(len(param2_unique)))
        heat_ax.set_yticks(np.arange(len(param1_unique)))
        heat_ax.set_xticklabels(param2_unique)
        heat_ax.set_yticklabels(param1_unique)

        # Set labels and title
        heat_ax.set_xlabel(param2)
        heat_ax.set_ylabel(param1)
        heat_ax.set_title(f"Log-likelihood heatmap: {param1} vs {param2}")

        # Add colorbar
        plt.colorbar(im, ax=heat_ax, label="Log-likelihood")

    # Hide every subplot that received no content
    for ax in axes_flat[used_axes:]:
        ax.axis("off")

    # Adjust layout
    plt.tight_layout()
    plt.show()


def optuna_search_beta_binomial_priors(
    x_data,
    successes,
    trials,
    n_trials=100,
    n_samples=1000,
    tune=500,
    verbose=True,
    plot_results=True,
    progressbar=True,
):
    """
    Perform hyperparameter optimization using Optuna for Beta-Binomial model.

    Each Optuna trial samples ``sigma_w``, ``sigma_b`` and ``sigma_nu``
    log-uniformly from [0.1, 10], fits the model with
    ``fit_beta_binomial_model`` and returns the mean of
    ``idata.log_likelihood["y_obs"]`` as the value to maximize. A trial whose
    fit or prediction raises, or whose InferenceData has no ``y_obs``
    log-likelihood, returns -inf and adds no row to ``results_df``.

    Args:
        x_data: Input data (capability differences)
        successes: Number of successful attacks (count data)
        trials: Number of trials for each data point
        n_trials: Number of Optuna trials to run
        n_samples: Number of MCMC samples for each fit
        tune: Number of tuning steps for each fit
        verbose: Whether to print progress
        plot_results: Whether to plot the results
        progressbar: Whether to display PyMC sampling progress bar

    Returns:
        Dictionary with:
            - results_df: DataFrame with all successfully evaluated trials
            - best_params: Dictionary with best parameters, or None
            - best_model: Tuple (model, idata) for best parameters, or None
            - study: The Optuna study object
    """
    import optuna
    import optuna.visualization

    # Normalize inputs the same way the grid search does, so trials[0] below
    # is a positional lookup regardless of the input container.
    x_data = np.atleast_1d(x_data)
    successes = np.atleast_1d(successes)
    trials = np.atleast_1d(trials)

    # Initialize storage for results
    results = []
    best_log_likelihood = -np.inf
    best_params = None
    best_model_result = None

    def objective(trial):
        nonlocal best_log_likelihood, best_params, best_model_result

        # Define the hyperparameters to optimize
        prior_params = {
            "sigma_w": trial.suggest_float("sigma_w", 0.1, 10, log=True),
            "sigma_b": trial.suggest_float("sigma_b", 0.1, 10, log=True),
            "sigma_nu": trial.suggest_float("sigma_nu", 0.1, 10, log=True),
        }

        if verbose:
            print(f"\nTrial {trial.number + 1}/{n_trials}:")
            print("Parameters:", prior_params)

        try:
            # Fit model with these parameters
            model, idata = fit_beta_binomial_model(
                x_data,
                successes,
                trials,
                prior_params=prior_params,
                n_samples=n_samples,
                tune=tune,
                progressbar=progressbar,
            )

            # Extract log-likelihood from InferenceData
            if hasattr(idata, "log_likelihood") and "y_obs" in idata.log_likelihood:
                # Use the mean log likelihood across all posterior samples and data points
                log_likelihood = idata.log_likelihood["y_obs"].mean().item()
            else:
                if verbose:
                    print("  Log likelihood not found in InferenceData.")
                return float("-inf")

            # Calculate metrics
            pred_result = predict(
                idata,
                x_data,
                n_trials=trials[0],  # Assuming constant trials
                successes_true=successes,
                trials_true=trials,
            )
            # Extract metrics, handling potential absence of the key
            metrics = pred_result.get("metrics", {})
            rmse_prob = metrics.get("rmse_prob", float("nan"))
            r2_prob = metrics.get("r2_prob", float("nan"))
            rmse_count = metrics.get("rmse_count", float("nan"))

            # Store results
            results.append(
                {
                    **prior_params,
                    "log_likelihood": log_likelihood,
                    "rmse_prob": rmse_prob,
                    "r2_prob": r2_prob,
                    "rmse_count": rmse_count,
                }
            )

            # Update best model if this is better
            if log_likelihood > best_log_likelihood:
                best_log_likelihood = log_likelihood
                best_params = prior_params.copy()
                best_model_result = (model, idata)

                if verbose:
                    print(
                        f"  New best: log_likelihood={log_likelihood:.4f}, "
                        f"rmse_prob={rmse_prob:.4f}, "
                        f"r2_prob={r2_prob:.4f}"
                    )

            return log_likelihood

        except Exception as e:
            if verbose:
                print(f"  Error in trial: {e}")
            return float("-inf")

    # Create and run Optuna study
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)

    # Convert results to DataFrame
    results_df = pd.DataFrame(results)

    # Plot results if requested. The optuna.visualization functions only
    # build and return a figure, so each one has to be shown explicitly.
    # Every plot is attempted independently so one failure does not suppress
    # the others.
    if plot_results and not results_df.empty:
        figure_builders = (
            optuna.visualization.plot_optimization_history,
            optuna.visualization.plot_param_importances,
            optuna.visualization.plot_parallel_coordinate,
        )
        for build_figure in figure_builders:
            try:
                build_figure(study).show()
            except Exception as e:
                if verbose:
                    print(f"Warning: Could not create plots: {e}")

    # Return results
    return {
        "results_df": results_df,
        "best_params": best_params,
        "best_model": best_model_result,
        "study": study,  # Include the study for additional analysis if needed
    }


def find_best_beta_binomial_model_for_target(
    df,
    target_model_key,
    hyperparams_grid=None,
    capability_diff_col="capability_diff",
    n_samples=1000,
    tune=500,
    progressbar=True,
    use_optuna=False,
    n_optuna_trials=50,
    plot_results=True,
):
    """
    Find the best prior parameters for Beta-Binomial model for a specific target model.

    Rows of ``df`` whose "target_model_key" equals ``target_model_key`` are
    selected, the "ASR" column is converted to success counts by multiplying
    with the per-row trial count and rounding, and the priors are then
    optimized with either the grid search or the Optuna search.

    Args:
        df: DataFrame with model data. Needs the columns "target_model_key",
            "ASR" and ``capability_diff_col``; "total_behaviors" is used as
            the per-row trial count when present, otherwise 50 trials per row
            are assumed.
        target_model_key: Key of the target model to optimize for
        hyperparams_grid: Dictionary of hyperparameter lists to try
                          If None, defaults to standard grid.
                          Only used by the grid search; ignored when
                          ``use_optuna`` is True.
        capability_diff_col: Name of the capability difference column
        n_samples: Number of MCMC samples for each fit
        tune: Number of tuning steps for each fit
        progressbar: Whether to display PyMC sampling progress bar
        use_optuna: Whether to use Optuna optimization instead of grid search
        n_optuna_trials: Number of Optuna trials to run when ``use_optuna``
            is True
        plot_results: Whether the underlying search should plot its results

    Returns:
        None if there is no data for the target model or no best model was
        found; otherwise a dictionary with "best_params", "best_model",
        "best_idata" and "metrics".
    """
    # Filter data for the target model
    model_df = df[df["target_model_key"] == target_model_key]

    if model_df.empty:
        print(f"No data found for target model {target_model_key}")
        return None

    x_data = model_df[capability_diff_col].values

    # Use total_behaviors as the number of trials when available, otherwise
    # fall back to a fixed 50 trials per data point.
    if "total_behaviors" in model_df.columns:
        trials = model_df["total_behaviors"].values
    else:
        trials = np.full(len(model_df), 50)

    # Convert ASR to success counts.
    # NOTE(review): this assumes ASR is a fraction in [0, 1] - confirm.
    successes = np.round(model_df["ASR"].values * trials).astype(int)

    print(f"Optimizing Beta-Binomial model for: {target_model_key}")
    print(f"Data points: {len(x_data)}")
    print(f"Using {'Optuna' if use_optuna else 'grid search'} optimization")

    if use_optuna:
        results = optuna_search_beta_binomial_priors(
            x_data,
            successes,
            trials,
            n_trials=n_optuna_trials,
            n_samples=n_samples,
            tune=tune,
            verbose=True,
            plot_results=plot_results,
            progressbar=progressbar,
        )
    else:
        results = grid_search_beta_binomial_priors(
            x_data,
            successes,
            trials,
            hyperparams_grid=hyperparams_grid,
            n_samples=n_samples,
            tune=tune,
            verbose=True,
            plot_results=plot_results,
            progressbar=progressbar,
        )

    # Process results
    if results is None or results["best_model"] is None:
        print("Optimization failed or returned no best model.")
        return None

    print("Best Parameters:", results["best_params"])

    # Get metrics for best model
    best_model, best_idata = results["best_model"]
    pred_result = predict(
        best_idata,
        x_data,
        n_trials=trials[0],  # Assuming constant trials
        successes_true=successes,
        trials_true=trials,
    )
    metrics = pred_result.get("metrics", {})
    print("Best Model Metrics:")
    for key, value in metrics.items():
        # Fall back to plain printing for values that do not support
        # fixed-point formatting, so reporting never aborts the function.
        try:
            print(f"  {key}: {value:.4f}")
        except (TypeError, ValueError):
            print(f"  {key}: {value}")

    # Return best model details
    return {
        "best_params": results["best_params"],
        "best_model": best_model,
        "best_idata": best_idata,
        "metrics": metrics,
    }
