from sklearn.linear_model import Lasso, ElasticNet, ElasticNetCV, LinearRegression
from modules.survey_converter import Survey, BinaryExtendedSurvey
from modules.aggregate_responses import AggregateResponses
from modules.response_converter import Responses, BinaryExtendedResponses, ResponseUtils
from sklearn.model_selection import GridSearchCV
from scipy.spatial import distance
from scipy.stats import wasserstein_distance
from sklearn.metrics import root_mean_squared_error, r2_score
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import seaborn as sns
import pandas as pd
from joblib import dump, load
from typing import Callable

def save_model(model, save_path):
    """
    Save a fitted scikit-learn model.

    Serializes `model` to `save_path` using joblib's `dump`. Nothing is
    returned; the value produced by `dump` is discarded.

    Args:
        model: fitted model (must have feature_names_)
            NOTE(review): this requirement is not checked here -- confirm
            against the code that loads and uses the model.
        save_path (str): output .joblib file path
    """
    dump(model, save_path)

def load_model(load_path):
    """
    Load a model with embedded feature names.

    Args:
        load_path: path handed to joblib's `load` (presumably a file
            previously written by `save_model` -- verify against callers).

    Returns:
        model: the object joblib deserializes from `load_path`
    """
    # NOTE(review): joblib files are pickle-based; only load files from a trusted source.
    return load(load_path)

def fit_elastic_net_model(experiment, config, verbose = True, random_seed = 101):
    """
    Fit an ElasticNet model using YAML-driven config and experiment object.

    `alpha` and `l1_ratio` are selected with `ElasticNetCV` on the concatenated
    train + validation splits; a fresh `ElasticNet` is then refit on the same
    data with the selected values.

    Args:
        experiment: object with a `.get_dataframe_by_split(split, proxy_only=...)`
            method returning a DataFrame that contains an 'aggregate' column.
        config (dict): reads
            - config["split_settings"]["use_proxy_only"]
            - config["split_settings"]["train_val_split"] (train split, val split)
            - config["elasticnet"]["alpha_expr"] and ["max_iter"]
            - config["elasticnet"]["l1_ratio_expr"] (optional, default [0.1, 0.5, 0.9])
            - config["elasticnet"]["validation"] (optional: "strategy", "cv_folds")
            `alpha_expr` / `l1_ratio_expr` may be a Python expression string
            (evaluated with `np` in scope) or an already-built sequence/number.
        verbose (bool): whether to print progress messages.
        random_seed (int): `random_state` passed to `ElasticNetCV`.

    Returns:
        model: trained ElasticNet model
        best_alpha: float
        best_l1_ratio: float
        diagnostics: dict with keys 'mean_mse', 'std_mse' (fold mean / std of the
            CV MSE path), 'cv_mse_path', 'alphas' and 'l1_ratios'

    Raises:
        NotImplementedError: if the configured validation strategy is not "cv".
    """
    split_cfg = config["split_settings"]
    enet_cfg = config["elasticnet"]
    validation_cfg = enet_cfg.get("validation", {})

    use_proxy_only = split_cfg["use_proxy_only"]
    split_train, split_val = split_cfg["train_val_split"]
    strategy = validation_cfg.get("strategy")
    cv_folds = validation_cfg.get("cv_folds", 5)
    max_iter = enet_cfg["max_iter"]

    def _resolve_grid(spec):
        # NOTE(security): expression strings are run through eval(); the config
        # must come from a trusted source.
        # Non-string specs (e.g. the list default for l1_ratio_expr) are used
        # as-is, because eval() only accepts strings / code objects.
        return eval(spec, {"np": np}) if isinstance(spec, str) else spec

    alphas = _resolve_grid(enet_cfg["alpha_expr"])
    l1_ratios = _resolve_grid(enet_cfg.get("l1_ratio_expr", [0.1, 0.5, 0.9]))

    # Fail fast, before any data is loaded, when the strategy is unsupported.
    if strategy != "cv":
        raise NotImplementedError("Only CV-based selection is currently supported for ElasticNet.")

    df_train = experiment.get_dataframe_by_split(split_train, proxy_only=use_proxy_only)
    df_val = experiment.get_dataframe_by_split(split_val, proxy_only=use_proxy_only)
    target_col = 'aggregate'
    feature_cols = [col for col in df_train.columns if col != target_col]

    X_train = df_train[feature_cols].astype(float)
    y_train = df_train[target_col].astype(float).values
    X_val = df_val[feature_cols].astype(float)
    y_val = df_val[target_col].astype(float).values

    if verbose:
        print(f"Using {cv_folds}-fold CV to select alpha and l1_ratio.")

    # Cross-validation does its own splitting, so train and val are pooled.
    X_all = pd.concat([X_train, X_val])
    y_all = np.concatenate([y_train, y_val])

    enet_cv = ElasticNetCV(
        alphas=alphas,
        l1_ratio=l1_ratios,
        cv=cv_folds,
        max_iter=max_iter,
        random_state=random_seed
    )
    enet_cv.fit(X_all, y_all)

    # Folds are the last axis of mse_path_; reducing over axis=-1 also works
    # when a single scalar l1_ratio makes the path 2-D instead of 3-D.
    mean_mse_grid = enet_cv.mse_path_.mean(axis=-1)
    std_mse_grid = enet_cv.mse_path_.std(axis=-1)

    best_alpha = enet_cv.alpha_
    best_l1_ratio = enet_cv.l1_ratio_

    if verbose:
        print(f"Best alpha: {best_alpha:.2e}, Best l1_ratio: {best_l1_ratio:.2f}")

    model = ElasticNet(alpha=best_alpha, l1_ratio=best_l1_ratio, max_iter=max_iter)
    model.fit(X_all, y_all)

    diagnostics = {
        "mean_mse": mean_mse_grid,
        "std_mse": std_mse_grid,
        "cv_mse_path": enet_cv.mse_path_,
        "alphas": enet_cv.alphas_,
        "l1_ratios": enet_cv.l1_ratio,
    }

    return model, best_alpha, best_l1_ratio, diagnostics


def fit_lasso_model(experiment, config, verbose=True):
    """
    Fit a Lasso model using YAML-driven design and experiment object.

    Args:
        experiment (object): An object with a `.get_dataframe_by_split(split, proxy_only)` method
            returning a DataFrame that contains an 'aggregate' target column.
        config (dict): YAML config dictionary. Reads
            - config["split_settings"]["use_proxy_only"] and ["train_val_split"]
            - config["lasso"]["alpha_expr"] and ["max_iter"]
            - config["lasso"]["validation"] (optional: "strategy", "cv_folds")
            `alpha_expr` may be a Python expression string (evaluated with `np`
            in scope) or an already-built sequence of alphas.
        verbose (bool): Whether to print progress messages.

    Returns:
        tuple: A 3-tuple containing:
            model (Lasso): scikit-learn Lasso refit on train + validation data
                with the selected alpha.
            best_alpha (float): Best alpha value selected by CV or hold-out.
            diagnostics (dict): Evaluation metrics.

            The structure of `diagnostics` depends on the validation strategy:

            If strategy == "cv":
                - 'mean_rmse' (ndarray): Mean RMSE across folds, one entry per alpha.
                - 'std_rmse' (ndarray): Standard deviation of the fold RMSEs, per alpha.
                - 'alphas' (ndarray of float): Alpha values tested.

            Otherwise (hold-out; any other strategy value, including a missing one):
                - 'train_errors' (list of float): Training RMSE per alpha.
                - 'val_errors' (list of float): Validation RMSE per alpha.
                - 'alphas': Alpha values tested.
    """
    split_cfg = config["split_settings"]
    lasso_cfg = config["lasso"]
    validation_cfg = lasso_cfg.get("validation", {})

    use_proxy_only = split_cfg["use_proxy_only"]
    split_train, split_val = split_cfg["train_val_split"]
    strategy = validation_cfg.get("strategy")
    cv_folds = validation_cfg.get("cv_folds", 5)

    alpha_expr = lasso_cfg["alpha_expr"]
    # NOTE(security): expression strings are run through eval(); the config must
    # come from a trusted source. Non-string values are used as-is.
    alphas = eval(alpha_expr, {"np": np}) if isinstance(alpha_expr, str) else alpha_expr
    max_iter = lasso_cfg["max_iter"]

    # Prepare data from experiment
    df_train = experiment.get_dataframe_by_split(split_train, proxy_only=use_proxy_only)
    df_val = experiment.get_dataframe_by_split(split_val, proxy_only=use_proxy_only)

    target_col = 'aggregate'
    feature_cols = [col for col in df_train.columns if col != target_col]

    X_train = df_train[feature_cols].astype(float)
    y_train = df_train[target_col].astype(float).values
    X_val = df_val[feature_cols].astype(float)
    y_val = df_val[target_col].astype(float).values

    # Both strategies refit the final model on the pooled train + val data.
    X_all = pd.concat([X_train, X_val])
    y_all = np.concatenate([y_train, y_val])

    diagnostics = {}
    if strategy == "cv":
        if verbose:
            print(f"Using {cv_folds}-fold cross-validation to select alpha.")

        grid = GridSearchCV(
            Lasso(max_iter=max_iter),
            param_grid={'alpha': alphas},
            cv=cv_folds,
            scoring='neg_mean_squared_error'
        )
        grid.fit(X_all, y_all)

        best_alpha = grid.best_params_['alpha']
        if verbose:
            print(f"Best alpha (CV): {best_alpha:.2e}")

        # The search scores by negative MSE, so the fold scores are converted
        # to RMSE per fold before aggregating. Reporting -mean_test_score
        # directly would label MSE as RMSE.
        fold_mse = -np.vstack([
            grid.cv_results_[f"split{fold}_test_score"]
            for fold in range(grid.n_splits_)
        ])
        fold_rmse = np.sqrt(fold_mse)  # shape: (n_folds, n_alphas)
        diagnostics['mean_rmse'] = fold_rmse.mean(axis=0)
        diagnostics['std_rmse'] = fold_rmse.std(axis=0)
        # param_alpha may be an object-dtype masked array; expose plain floats.
        diagnostics['alphas'] = np.asarray(grid.cv_results_['param_alpha'].data, dtype=float)

    else:  # Hold-out strategy
        if verbose:
            print("Using hold-out validation (train/valid split) to select alpha.")
        train_errors, val_errors = [], []

        for alpha in alphas:
            candidate = Lasso(alpha=alpha, max_iter=max_iter)
            candidate.fit(X_train, y_train)

            train_errors.append(root_mean_squared_error(y_train, candidate.predict(X_train)))
            val_errors.append(root_mean_squared_error(y_val, candidate.predict(X_val)))

        best_idx = np.argmin(val_errors)
        best_alpha = alphas[best_idx]
        if verbose:
            print(f"Best alpha (Hold-out): {best_alpha:.2e}, RMSE = {val_errors[best_idx]:.4f}")

        diagnostics['train_errors'] = train_errors
        diagnostics['val_errors'] = val_errors
        diagnostics['alphas'] = alphas

    model = Lasso(alpha=best_alpha, max_iter=max_iter)
    model.fit(X_all, y_all)

    return model, best_alpha, diagnostics

def fit_post_selection_ols(model, X_all, y_all, feature_cols):
    """
    Refit ordinary least squares on the features a sparse model kept.

    A feature counts as "kept" when its entry in `model.coef_` is non-zero.

    Parameters:
        model: fitted Lasso or ElasticNet model (anything exposing `coef_`)
        X_all: full input data; wrapped in a DataFrame with `feature_cols` as columns
        y_all: target array
        feature_cols: list of all feature names, aligned with `model.coef_`

    Returns:
        ols_model: LinearRegression fitted on the kept columns only
        selected_features: list of the kept feature names
    """
    nonzero_flags = model.coef_ != 0
    selected_features = [
        name for name, is_kept in zip(feature_cols, nonzero_flags) if is_kept
    ]

    full_frame = pd.DataFrame(X_all, columns=feature_cols)
    reduced_frame = full_frame[selected_features]

    ols_model = LinearRegression()
    ols_model.fit(reduced_frame, y_all)
    return ols_model, selected_features

def plot_lasso_diagnostics(diagnostics, best_alpha, strategy="cv"):
    """
    Plot Lasso RMSE diagnostics for either cross-validation or hold-out validation.

    Parameters:
        diagnostics: dict returned by `fit_lasso_model`. For "cv" it must hold
            'mean_rmse', 'std_rmse' and 'alphas'; for "holdout" it must hold
            'train_errors', 'val_errors' and 'alphas'. Values may be lists or
            arrays; they are converted to float arrays before plotting.
        best_alpha: float, the selected regularization parameter
        strategy: str, "cv" or "holdout" (used for branching)

    Returns:
        The matplotlib figure that was drawn.

    Raises:
        ValueError: if `strategy` is neither "cv" nor "holdout".
    """
    # Reject unknown strategies up front instead of silently returning None.
    if strategy not in ("cv", "holdout"):
        raise ValueError(f"Unknown strategy: {strategy}. Choose from 'cv' or 'holdout'.")

    grid_style = {"grid.linestyle": "--", "grid.linewidth": 0.3, "grid.alpha": 1}
    with sns.axes_style("whitegrid", rc=grid_style), \
         sns.plotting_context("notebook", font_scale=1.15):
        if strategy == "cv":
            # Float arrays keep `mean - std` and fill_between working for
            # plain lists and object-dtype arrays.
            mean_rmse = np.asarray(diagnostics['mean_rmse'], dtype=float)
            std_rmse = np.asarray(diagnostics['std_rmse'], dtype=float)
            alpha_vals = np.asarray(diagnostics['alphas'], dtype=float)

            fig, ax = plt.subplots(figsize=(6, 4))

            ax.plot(alpha_vals, mean_rmse, marker='o', linestyle='-', linewidth=2,
                    color='#26567e', label='Mean CV RMSE')
            ax.fill_between(alpha_vals,
                            mean_rmse - std_rmse,
                            mean_rmse + std_rmse,
                            alpha=0.5,
                            color='#ccd5dd',
                            label='± 1 STD')

            ax.axvline(best_alpha, linestyle='--', color='#f50000', linewidth=2, alpha=0.6,
                       label=f'Best $\\alpha = {best_alpha:.2e}$')

            ax.set_xscale('log')
            ax.set_xlabel("$\\alpha$ (log scale)")
            ax.set_ylabel("RMSE (CV mean ± std)")
            ax.set_title("Cross-Validated RMSE by Lasso Alpha ($\\alpha$)")
            ax.legend()
            fig.tight_layout()

            return fig

        # strategy == "holdout"
        train_errors = np.asarray(diagnostics['train_errors'], dtype=float)
        val_errors = np.asarray(diagnostics['val_errors'], dtype=float)
        alphas = np.asarray(diagnostics['alphas'], dtype=float)

        fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharey=True)

        # Train RMSE
        axes[0].plot(alphas, train_errors, marker='o', linestyle='-', linewidth=2,
                     color='#26567e', label='Train RMSE')
        axes[0].axvline(best_alpha, color='#f50000', linestyle='--', linewidth=2, alpha=0.6)
        axes[0].set_xscale('log')
        axes[0].set_title("Train RMSE")
        axes[0].set_xlabel("$\\alpha$ (log scale)")
        axes[0].set_ylabel("RMSE")
        axes[0].grid(True)

        # Validation RMSE
        axes[1].plot(alphas, val_errors, marker='o', linestyle='-', linewidth=2,
                     color='#77ab59', label='Validation RMSE')
        axes[1].axvline(best_alpha, color='#f50000', linestyle='--', linewidth=2, alpha=0.6,
                        label=f'Best $\\alpha = {best_alpha:.2e}$')
        axes[1].set_xscale('log')
        axes[1].set_title("Validation RMSE")
        axes[1].set_xlabel("$\\alpha$ (log scale)")
        axes[1].grid(True)

        # Unified legend below both panels
        handles, labels = [], []
        for ax in axes:
            panel_handles, panel_labels = ax.get_legend_handles_labels()
            handles += panel_handles
            labels += panel_labels

        fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, 0.0),
                   ncol=3, frameon=False)

        fig.suptitle("Lasso Regression: RMSE by $\\alpha$ (Train vs Validation)", fontsize=14, y=0.91)
        # Leave room at the bottom for the shared legend.
        fig.tight_layout(rect=[0, 0.05, 1, 0.95])
        return fig

def plot_lasso_predictions(model, experiment, config):
    """
    Plot predicted vs true aggregate responses for train/val and test data.

    Each panel places the model prediction on the x-axis and the ground truth
    on the y-axis, with a [0, 1] diagonal as the perfect-prediction reference.

    Parameters:
        model: trained sklearn Lasso model
        experiment: object with `.get_dataframe_by_split(split, proxy_only)` method
        config: loaded YAML config dict; reads config["split_settings"]
            ("use_proxy_only", "train_val_split", "test_split")

    Returns:
        The matplotlib figure with the two panels (train+val, test).
    """
    split_cfg = config["split_settings"]
    use_proxy_only = split_cfg["use_proxy_only"]
    split_train, split_val = split_cfg["train_val_split"]
    split_test = split_cfg["test_split"]

    # Load data and prepare inputs
    df_train = experiment.get_dataframe_by_split(split_train, proxy_only=use_proxy_only)
    df_val = experiment.get_dataframe_by_split(split_val, proxy_only=use_proxy_only)
    df_test = experiment.get_dataframe_by_split(split_test, proxy_only=use_proxy_only)

    target_col = 'aggregate'
    feature_cols = [col for col in df_train.columns if col != target_col]

    # Keep the features as DataFrames (as plot_model_predictions does) so a
    # model fitted with column names receives the same names at predict time.
    df_trainval = pd.concat([df_train, df_val])
    X_trainval = df_trainval[feature_cols].astype(float)
    y_trainval = df_trainval[target_col].astype(float).values
    X_test = df_test[feature_cols].astype(float)
    y_test = df_test[target_col].astype(float).values

    # Predictions and metrics
    y_pred_trainval = model.predict(X_trainval)
    y_pred_test = model.predict(X_test)
    rmse_train = root_mean_squared_error(y_trainval, y_pred_trainval)
    rmse_test = root_mean_squared_error(y_test, y_pred_test)
    r2_train = r2_score(y_trainval, y_pred_trainval)
    r2_test = r2_score(y_test, y_pred_test)

    # Plot
    with sns.axes_style("whitegrid", rc={"grid.linestyle": "--", "grid.linewidth": 0.3, "grid.alpha": 1}), \
         sns.plotting_context("notebook", font_scale=1.15):
        fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharey=True)

        for ax, y_true, y_pred, title, color, rmse, r2 in zip(
            axes,
            [y_trainval, y_test],
            [y_pred_trainval, y_pred_test],
            ["Train-time performance", "Test-time performance"],
            ['#1D4F91', '#f4b505'],
            [rmse_train, rmse_test],
            [r2_train, r2_test],
        ):
            ax.scatter(y_pred, y_true, alpha=0.5, color=color)
            ax.plot([0, 1], [0, 1], linewidth=2, color='#a3b09a')
            # Labels follow the scatter call above: x is the prediction, y the truth.
            ax.set_xlabel(r"$\hat{y}$ (prediction)")
            ax.set_ylabel(r"$y$ (ground truth)")
            ax.set_title(title)
            ax.text(0.05, 0.9, f'RMSE = {rmse:.2f}', transform=ax.transAxes, fontsize=10)
            ax.text(0.05, 0.83, f'R² = {r2:.2f}', transform=ax.transAxes, fontsize=10)

        plt.suptitle("Lasso Regression: Predicting Aggregate Response with Proxy Agents", fontsize=14, y=0.93)
        plt.tight_layout()
        return fig
    
def plot_model_predictions(model, experiment, config, selected_features=None, method_label="Regression"):
    """
    Draw prediction-vs-truth scatter panels for the train+val and test splits.

    Parameters:
        model: Trained model exposing `.predict` (Lasso, ElasticNet, or OLS)
        experiment: Provides data through `.get_dataframe_by_split(split, proxy_only=...)`
        config: YAML config dict; reads config["split_settings"]
            ("use_proxy_only", "train_val_split", "test_split")
        selected_features: Feature subset to predict from (optional; needed for OLS).
            When omitted, the columns are resolved by `get_model_features`.
        method_label: Prefix of the figure title (e.g., "ElasticNet", "OLS")

    Returns:
        The matplotlib figure; each panel is annotated with its RMSE and R².
    """
    split_cfg = config["split_settings"]
    use_proxy_only = split_cfg["use_proxy_only"]
    split_train, split_val = split_cfg["train_val_split"]
    split_test = split_cfg["test_split"]
    target_col = "aggregate"

    # Resolve the feature columns before loading the splits.
    feature_cols = get_model_features(model, experiment, target_col, selected_features, use_proxy_only)

    train_frame = experiment.get_dataframe_by_split(split_train, proxy_only=use_proxy_only)
    val_frame = experiment.get_dataframe_by_split(split_val, proxy_only=use_proxy_only)
    test_frame = experiment.get_dataframe_by_split(split_test, proxy_only=use_proxy_only)
    trainval_frame = pd.concat([train_frame, val_frame])

    # One record per panel: (truth, prediction, title, colour, rmse, r2)
    panels = []
    for frame, panel_title, panel_color in (
        (trainval_frame, "Train-time performance", '#1D4F91'),
        (test_frame, "Test-time performance", '#f4b505'),
    ):
        inputs = frame[feature_cols].astype(float)
        truth = frame[target_col].astype(float).values
        predicted = model.predict(inputs)
        panels.append((
            truth,
            predicted,
            panel_title,
            panel_color,
            root_mean_squared_error(truth, predicted),
            r2_score(truth, predicted),
        ))

    grid_style = {"grid.linestyle": "--", "grid.linewidth": 0.3, "grid.alpha": 1}
    with sns.axes_style("whitegrid", rc=grid_style), \
         sns.plotting_context("notebook", font_scale=1.15):
        fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharey=True)

        for ax, (truth, predicted, panel_title, panel_color, rmse, r2) in zip(axes, panels):
            ax.scatter(predicted, truth, alpha=0.5, color=panel_color)
            # Diagonal reference line across the unit square.
            ax.plot([0, 1], [0, 1], linewidth=2, color='#a3b09a')
            ax.set_xlabel(r"$\hat{y}$ (prediction)")
            ax.set_ylabel(r"$y$ (ground truth)")
            ax.set_title(panel_title)
            ax.text(0.05, 0.9, f'RMSE = {rmse:.2f}', transform=ax.transAxes, fontsize=10)
            ax.text(0.05, 0.83, f'R² = {r2:.2f}', transform=ax.transAxes, fontsize=10)

        plt.suptitle(f"{method_label}: Predicting Aggregate Response with Proxy Agents", fontsize=14, y=0.93)
        plt.tight_layout()
        return fig

def compute_prediction_metrics(model, experiment, config, selected_features=None):
    """
    Compute prediction performance of a model using trainval and test splits.

    Regression metrics come from `model.score(...)` and `model.r2(...)`.
    Distribution metrics are computed per question from the aggregate response
    distributions and then averaged over each split's question ids.

    NOTE(review): `model` is expected to expose `score`, `r2` and `coef_dict_`
    (presumably a project wrapper rather than a bare sklearn estimator, whose
    `score` returns R²) -- confirm what `score` returns, since it is reported
    under the "*_mse" keys.

    Returns a dictionary with:
        - trainval_mse, test_mse
        - trainval_r2, test_r2
        - trainval_one_minus_jsd, test_one_minus_jsd
        - trainval_emd, test_emd
        - trainval_accuracy, test_accuracy
    """
    split_cfg = config["split_settings"]
    use_proxy_only = split_cfg["use_proxy_only"]
    split_train, split_val = split_cfg["train_val_split"]
    split_test = split_cfg["test_split"]
    target_col = "aggregate"

    feature_cols = get_model_features(model, experiment, target_col, selected_features, use_proxy_only)

    # Regression inputs: train and val are pooled, test stays separate.
    train_frame = experiment.get_dataframe_by_split(split_train, proxy_only=use_proxy_only)
    val_frame = experiment.get_dataframe_by_split(split_val, proxy_only=use_proxy_only)
    test_frame = experiment.get_dataframe_by_split(split_test, proxy_only=use_proxy_only)
    trainval_frame = pd.concat([train_frame, val_frame])

    X_trainval = trainval_frame[feature_cols].astype(float)
    y_trainval = trainval_frame[target_col].astype(float).values
    X_test = test_frame[feature_cols].astype(float)
    y_test = test_frame[target_col].astype(float).values

    # Per-question distribution metrics (agent aggregate vs. human aggregate).
    aggregates = compute_aggregate_responses(model, config)
    survey = aggregates["survey"]
    agent_agg = aggregates["agent_aggregate_responses"]
    human_agg = aggregates["human_aggregate_responses"]

    one_minus_jsd_map = _compute_one_minus_jsd_per_qid(
        agent_aggregate_responses=agent_agg,
        human_aggregate_responses=human_agg,
    )
    emd_map = _compute_emd_per_qid(
        agent_aggregate_responses=agent_agg,
        human_aggregate_responses=human_agg,
        survey=survey,
    )
    acc_map = _compute_majority_accuracy_per_qid(agent_agg, human_agg)

    # Question ids per split; train and val ids are merged without duplicates.
    train_qids = [question["id"] for question in survey.get_questions_by_split(split_train)]
    val_qids = [question["id"] for question in survey.get_questions_by_split(split_val)]
    test_qids = [question["id"] for question in survey.get_questions_by_split(split_test)]
    trainval_qids = list(set(train_qids + val_qids))

    metrics = {}
    metrics["trainval_mse"] = model.score(X_trainval, y_trainval)
    metrics["test_mse"] = model.score(X_test, y_test)
    metrics["trainval_r2"] = model.r2(X_trainval, y_trainval)
    metrics["test_r2"] = model.r2(X_test, y_test)
    metrics["trainval_one_minus_jsd"] = _mean_over_qids(one_minus_jsd_map, trainval_qids)
    metrics["test_one_minus_jsd"] = _mean_over_qids(one_minus_jsd_map, test_qids)
    metrics["trainval_emd"] = _mean_over_qids(emd_map, trainval_qids)
    metrics["test_emd"] = _mean_over_qids(emd_map, test_qids)
    metrics["trainval_accuracy"] = _mean_over_qids(acc_map, trainval_qids)
    metrics["test_accuracy"] = _mean_over_qids(acc_map, test_qids)
    return metrics

def compute_aggregate_responses(model, config, survey = None, responses = None):
    """
    Build the agents' weighted aggregate responses and pair them with the human aggregates.

    `survey` and `responses` are constructed from the files listed in
    config["paths"] when they are not supplied. Agent weights are taken from
    `model.coef_dict_`.

    Returns a dictionary with:
        - human_aggregate_responses: {original_qid -> {answer_label -> prob}}
        - agent_aggregate_responses: {original_qid -> {answer_label -> prob}}
        - survey: the Survey object that was used (passed in or freshly loaded)
    """
    file_paths = config["paths"]

    # Fall back to loading the survey from disk when the caller gave none.
    if survey is None:
        survey = Survey(csv_path=file_paths['survey_csv'], config_path=file_paths['survey_yaml'])
    binary_survey = BinaryExtendedSurvey.from_survey(survey)

    # Same fallback for the raw responses.
    if responses is None:
        responses = Responses(source=file_paths['responses_csv'], survey=survey, output_format='answer')

    weights_by_agent = model.coef_dict_
    human_side = AggregateResponses(survey=binary_survey, json_path=file_paths['aggregate_json'])
    agent_side = ResponseUtils.aggregate_weighted_responses(responses, weights_by_agent)

    result = {}
    result["human_aggregate_responses"] = human_side.raw
    result["agent_aggregate_responses"] = agent_side
    result["survey"] = survey
    return result

def _to_prob_vec(v):
    """Convert a 1D vector of probabilities or percentages to a prob. vector that sums to 1. Safeguard function."""
    v = np.asarray(v, dtype=float)
    s = v.sum()
    if s == 0:
        return v
    if s > 1.5:  # treat as percentages (≈100)
        v = v / 100.0
        s = v.sum()
    return v / s if s > 0 else v

def _align_support_dicts(human_dict, agent_dict, order=None):
    """
    Align dict supports and return aligned probability vectors and an index support.

    Key order of the aligned support:
      1. If `order` is provided, it is treated as the canonical sequence of
         category codes (e.g. from the survey's option_mapping); codes that
         appear in neither dict are dropped.
      2. Any remaining keys (those not covered by `order`, or all keys when
         `order` is None) follow in insertion order: human_dict first, then
         agent_dict.

    The order is fully deterministic, so position-sensitive metrics computed
    on the returned index (such as EMD) do not vary between runs.

    Returns:
        p: human probabilities over the aligned keys (normalised by `_to_prob_vec`)
        q: agent probabilities over the aligned keys (normalised by `_to_prob_vec`)
        idx: float array 0..len(keys)-1 giving each key's position
    """
    keys = []
    seen = set()

    if order is not None:
        present = set(human_dict) | set(agent_dict)
        for key in order:
            if key in present and key not in seen:
                keys.append(key)
                seen.add(key)

    # Append leftovers in dict insertion order rather than set order, which
    # depends on hashing and is not stable across interpreter runs.
    for key in list(human_dict) + list(agent_dict):
        if key not in seen:
            keys.append(key)
            seen.add(key)

    p = np.array([human_dict.get(k, 0.0) for k in keys], dtype=float)
    q = np.array([agent_dict.get(k, 0.0) for k in keys], dtype=float)

    p = _to_prob_vec(p)
    q = _to_prob_vec(q)

    idx = np.arange(len(keys), dtype=float)
    return p, q, idx

def _compute_one_minus_jsd_per_qid(agent_aggregate_responses, human_aggregate_responses):
    """
    Returns: dict {original_qid -> (1 - JSD)} over joint QIDs.
    """
    out = {}
    joint = set(human_aggregate_responses) & set(agent_aggregate_responses)
    for qid in joint:
        h = human_aggregate_responses[qid]
        a = agent_aggregate_responses[qid]
        try:
            p, q, _ = _align_support_dicts(h, a)
            out[qid] = 1.0 - float(distance.jensenshannon(p, q))
        except Exception:
            # skip problematic qid
            continue
    return out

def _compute_emd_per_qid(agent_aggregate_responses, human_aggregate_responses, survey):
    """
    Returns: dict {original_qid -> EMD} over joint QIDs.
    """
    out = {}
    joint = set(human_aggregate_responses) & set(agent_aggregate_responses)
    for qid in joint:
        h = human_aggregate_responses[qid]
        a = agent_aggregate_responses[qid]
        try:
            order = survey.get_option_order_text(qid)
            p, q, idx = _align_support_dicts(h, a, order=order)
            emd_raw = float(wasserstein_distance(idx, idx, u_weights=p, v_weights=q))

            L = len(idx)
            max_emd = L - 1 if L > 1 else 1
            out[qid] = emd_raw / max_emd
        except Exception:
            continue
    return out


def _mean_over_qids(metric_map: dict, qids: list[str]) -> float:
    vals = [metric_map[q] for q in qids if q in metric_map and np.isfinite(metric_map[q])]
    return float(np.mean(vals)) if vals else float('nan')

def _compute_majority_accuracy_per_qid(agent_aggregate_responses, human_aggregate_responses):
    """
    Returns dict {qid -> 0/1}, indicating whether predicted majority matches human majority.
    """
    out = {}
    joint = set(human_aggregate_responses) & set(agent_aggregate_responses)
    for qid in joint:
        h = human_aggregate_responses[qid]
        a = agent_aggregate_responses[qid]
        try:
            # Ensure supports are aligned and prob-normalized
            p, q, keys = _align_support_dicts(h, a)
            
            # p = human distribution, q = model distribution
            human_majority = int(np.argmax(p))
            model_majority = int(np.argmax(q))
            
            out[qid] = 1.0 if (human_majority == model_majority) else 0.0
        except Exception:
            continue
    return out


def get_model_features(model, experiment, target_col="aggregate", selected_features=None, proxy_only = False):
    """
    Infer feature columns used for model prediction.

    An explicit `selected_features` list is returned unchanged. Otherwise the
    features are every column of the "train" split except `target_col`.
    `model` is accepted but not consulted.
    """
    if selected_features is None:
        train_frame = experiment.get_dataframe_by_split("train", proxy_only)
        return [name for name in train_frame.columns if name != target_col]
    return selected_features
    
def compute_mode_summary(model, endowments):
    """
    Summarise, for every endowment mode, how many of its entries the model selected.

    The support mask comes from `model.get_support()` when available and
    otherwise from `model.coef_ != 0`. Weights are the absolute values of
    `model.beta_`, keyed by the ids in `model.feature_names_`.

    Returns:
        dict: {
            mode: {
                "num_selected": int,
                "total_weight": float,
                "fraction_selected": float,
                "avg_weight": float
            }
        }

    Raises:
        ValueError: if `model.is_fitted` is falsy.
    """
    if not model.is_fitted:
        raise ValueError("Model has not been fit yet.")

    # Prefer the model's own support mask; fall back to non-zero coefficients.
    try:
        support_mask = model.get_support()
    except AttributeError:
        support_mask = model.coef_ != 0

    # Absolute weight of every selected endowment id.
    weight_by_eid = {}
    for eid, coef, is_selected in zip(model.feature_names_, model.beta_, support_mask):
        if is_selected:
            weight_by_eid[eid] = abs(coef)

    summary = {}
    for mode, entries in endowments.group_by_mode().items():
        mode_size = len(entries)
        weights = [
            weight_by_eid[entry["eid"]]
            for entry in entries
            if entry["eid"] in weight_by_eid
        ]
        selected_count = len(weights)
        weight_sum = sum(weights)

        if selected_count > 0:
            mean_weight = weight_sum / selected_count
        else:
            mean_weight = 0.0

        if mode_size > 0:
            selected_share = selected_count / mode_size
        else:
            selected_share = 0.0

        summary[mode] = {
            "num_selected": selected_count,
            "total_weight": weight_sum,
            "fraction_selected": selected_share,
            "avg_weight": mean_weight,
        }

    return summary


def plot_elasticnet_diagnostics(diagnostics, best_alpha, best_l1_ratio, style="2D"):
    """
    Dispatch to one of the ElasticNet diagnostic plots.

    Parameters:
        diagnostics: dict produced by `fit_elastic_net_model`
        best_alpha: selected alpha
        best_l1_ratio: selected l1_ratio
        style: "2D" (heatmap), "3D" (static surface) or "interactive" (plotly surface)

    Returns:
        Whatever the chosen plotting helper returns (a figure object).

    Raises:
        ValueError: if `style` is not one of the three supported values.
    """
    if style not in ("2D", "3D", "interactive"):
        raise ValueError(f"Unknown plot style: {style}. Choose from '2D', '3D', or 'interactive'.")

    if style == "2D":
        plotter = _plot_elasticnet_heatmap
    elif style == "3D":
        plotter = _plot_elasticnet_3d_static
    else:
        plotter = _plot_elasticnet_3d_interactive
    return plotter(diagnostics, best_alpha, best_l1_ratio)

def _plot_elasticnet_heatmap(diagnostics, best_alpha, best_l1_ratio):
    """
    Draw the mean CV MSE grid as a heatmap (rows: l1_ratio, columns: alpha).

    The cell with the lowest mean MSE is marked with a red dot and labelled
    with its value. `best_alpha` and `best_l1_ratio` are accepted for a uniform
    plotting signature but are not used; the marker is located from the grid.

    Returns:
        The current matplotlib figure.
    """
    mse_grid = diagnostics["mean_mse"]
    alpha_values = diagnostics["alphas"]
    ratio_values = diagnostics["l1_ratios"]

    # (row, col) of the smallest mean MSE in the grid.
    min_cell = np.unravel_index(np.argmin(mse_grid), mse_grid.shape)
    min_mse = mse_grid[min_cell]

    grid_style = {"grid.linestyle": "--", "grid.linewidth": 0.3, "grid.alpha": 1}
    with sns.axes_style("whitegrid", rc=grid_style), \
         sns.plotting_context("notebook", font_scale=1.15):

        plt.figure(figsize=(12, 6))
        heatmap_options = dict(
            xticklabels=np.round(alpha_values, 3),
            yticklabels=np.round(ratio_values, 2),
            cmap="YlGnBu",
            cbar_kws={'label': 'Mean CV MSE'},
            annot=False,
        )
        ax = sns.heatmap(mse_grid, **heatmap_options)

        # Heatmap cells are unit squares, so +0.5 lands on the cell centre.
        marker_x = min_cell[1] + 0.5
        ax.plot(marker_x, min_cell[0] + 0.5, 'ro', markersize=8, label='Best')
        ax.text(marker_x, min_cell[0] + 1, f'{min_mse:.4f}', color='red',
                ha='center', va='center', fontweight='bold')

        # Label every 5th alpha, and always the last one.
        alpha_count = len(alpha_values)
        labelled = list(range(0, alpha_count, 5))
        last_position = alpha_count - 1
        if last_position not in labelled:
            labelled.append(last_position)

        ax.set_xticks([position + 0.5 for position in labelled])
        ax.set_xticklabels(
            [f"{alpha_values[position]:.3f}" for position in labelled],
            rotation=45, ha='right',
        )
        ax.set_xlabel("alpha")
        ax.set_ylabel("l1_ratio")
        ax.set_title("ElasticNetCV Cross-Validated MSE Heatmap")
        ax.legend()
        plt.tight_layout()
        return plt.gcf()

def _plot_elasticnet_3d_static(diagnostics, best_alpha, best_l1_ratio):
    """
    Draw the mean CV MSE as a static 3D surface over (log10(alpha), l1_ratio).

    A red marker is placed at (log10(best_alpha), best_l1_ratio) at the height
    of the smallest mean MSE in the grid.

    Returns:
        The matplotlib figure.
    """
    mse_grid = diagnostics["mean_mse"]
    alpha_mesh, ratio_mesh = np.meshgrid(diagnostics["alphas"], diagnostics["l1_ratios"])

    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_subplot(111, projection='3d')

    surface_options = dict(
        cmap='YlGnBu', edgecolor=None, linewidth=0.5, alpha=0.75, antialiased=True
    )
    surface = ax.plot_surface(np.log10(alpha_mesh), ratio_mesh, mse_grid, **surface_options)

    # Height of the marker: the grid's minimum mean MSE.
    min_cell = np.unravel_index(np.argmin(mse_grid), mse_grid.shape)
    lowest_mse = mse_grid[min_cell]
    ax.scatter(np.log10(best_alpha), best_l1_ratio, lowest_mse, color='red', s=50, label='Best')

    ax.set_xlabel("log10(alpha)")
    ax.set_ylabel("l1_ratio")
    ax.set_zlabel("Mean CV MSE")
    ax.set_title("Cross-Valid MSE by Elastic Net penalty parameter ($\\alpha$) and l1_ratio")
    ax.legend()

    fig.colorbar(surface, shrink=0.5, aspect=10, label='Mean CV MSE')
    plt.tight_layout()
    return fig

def _plot_elasticnet_3d_interactive(diagnostics, best_alpha, best_l1_ratio):
    """
    Build and show an interactive plotly surface of the mean CV MSE.

    The figure has the mean surface, two translucent grey surfaces at
    mean ± 1 standard deviation, and a red "Best" marker at
    (log10(best_alpha), best_l1_ratio, min mean MSE).

    Side effect: calls `fig.show()` before returning.

    Returns:
        The plotly figure.
    """
    mean_grid = diagnostics["mean_mse"]
    std_grid = diagnostics["std_mse"]
    log_alpha_mesh, ratio_mesh = np.meshgrid(
        np.log10(diagnostics["alphas"]), diagnostics["l1_ratios"]
    )

    fig = go.Figure()

    # Mean surface (the only one with a colour bar).
    fig.add_trace(go.Surface(
        z=mean_grid, x=log_alpha_mesh, y=ratio_mesh,
        colorscale='YlGnBu', name='Mean MSE', colorbar=dict(title='Mean CV MSE'),
        showscale=True, hoverinfo='x+y+z', legendgroup='mean', showlegend=True
    ))

    # Upper and lower ±1 SD surfaces share every setting except data and name.
    band_surfaces = (
        ('+1 SD', mean_grid + std_grid),
        ('-1 SD', mean_grid - std_grid),
    )
    for band_name, band_grid in band_surfaces:
        fig.add_trace(go.Surface(
            z=band_grid, x=log_alpha_mesh, y=ratio_mesh,
            surfacecolor=band_grid, colorscale='Greys', opacity=0.3,
            showscale=False, name=band_name, legendgroup='band', showlegend=True, hoverinfo='skip'
        ))

    # Marker for the selected hyperparameters.
    best_marker = go.Scatter3d(
        x=[np.log10(best_alpha)], y=[best_l1_ratio], z=[mean_grid.min()],
        mode='markers+text', marker=dict(size=6, color='red', symbol='circle'),
        text=["Best"], textposition="top center", name='Best', showlegend=True
    )
    fig.add_trace(best_marker)

    axis_titles = dict(
        xaxis_title='log10(alpha)', yaxis_title='l1_ratio', zaxis_title='Mean CV MSE'
    )
    legend_box = dict(x=0.02, y=0.98, bgcolor='rgba(255,255,255,0.7)', bordercolor='gray', borderwidth=1)
    fig.update_layout(
        title=dict(text='ElasticNetCV Error Surface with ±1 SD', x=0.5),
        scene=axis_titles,
        legend=legend_box,
        width=900, height=700, margin=dict(l=0, r=0, b=0, t=50)
    )

    fig.show()
    return fig

def summarize_sparse_model_selection(model, experiment, endowments, top_n=None, target_col="aggregate", verbose=True, model_name="Lasso"):
    """
    Report which proxy personae received non-zero coefficients from a sparse linear model.

    Feature names are taken from the proxy-only "train" dataframe (every column
    except `target_col`) and paired positionally with `model.coef_`.

    Parameters:
        model: fitted model exposing `.coef_`
        experiment: object exposing `.get_dataframe_by_split(split, proxy_only)`
        endowments: object exposing `.get_endowments_by_role(role)`
        top_n: if truthy, print only this many rows (largest |coefficient| first);
            otherwise print every non-zero row
        target_col: column excluded from the feature list (default 'aggregate')
        verbose: when False nothing is printed
        model_name: label used in the printed summary sentence

    Returns:
        None
    """
    train_df = experiment.get_dataframe_by_split("train", proxy_only=True)
    predictors = [name for name in train_df.columns if name != target_col]

    coef_table = pd.DataFrame({'feature': predictors, 'coefficient': model.coef_})
    is_nonzero = coef_table['coefficient'] != 0
    selected = coef_table[is_nonzero].copy()
    # Largest magnitude first, regardless of sign.
    selected = selected.sort_values(by='coefficient', key=abs, ascending=False)

    n_selected = len(selected)
    n_candidates = len(endowments.get_endowments_by_role('proxy'))
    n_ground_truth = len(endowments.get_endowments_by_role('ground_truth'))

    if not verbose:
        return

    print(selected.head(top_n) if top_n else selected)
    summary = (
        f"\n{model_name} selected {n_selected} proxy personae "
        f"out of {n_candidates} candidates to emulate the aggregate target, "
        f"which was derived from {n_ground_truth} ground-truth personae."
    )
    print(summary)

def assign_model_weights_to_endowments(model, experiment, endowments, target_col="aggregate", selected_features=None):
    """
    Assign model coefficients as weights to endowments (by matching feature names to eids).

    Coefficients are paired positionally with feature names and handed to
    `endowments.update_weights` as a `{feature_name: coefficient}` dict.

    Parameters:
        model: Trained model with a `.coef_` attribute
        experiment: Provides feature names via `.get_dataframe_by_split("train", proxy_only=True)`.
            Only consulted when `selected_features` is None.
        endowments: Object exposing `.update_weights(mapping)`
        target_col: Name of target variable (excluded from the feature list)
        selected_features: Optional. If the model was trained on a subset of
            features (e.g., a post-selection refit), pass those feature names
            in the same order as `model.coef_`.

    Raises:
        ValueError: if the number of feature names differs from `len(model.coef_)`.
    """
    if selected_features is None:
        # Default case: every non-target column of the proxy-only training split.
        train_df = experiment.get_dataframe_by_split("train", proxy_only=True)
        feature_cols = [col for col in train_df.columns if col != target_col]
    else:
        # The caller already knows the feature order; no need to load the dataframe.
        feature_cols = list(selected_features)

    # Safety check: zip() below would otherwise silently truncate.
    if len(feature_cols) != len(model.coef_):
        raise ValueError(
            f"Mismatch between number of model coefficients ({len(model.coef_)}) "
            f"and selected features ({len(feature_cols)})."
        )

    endowments.update_weights(dict(zip(feature_cols, model.coef_)))

def plot_model_predictions_interactive(model, experiment, config, selected_features=None, method_label="Regression", error_threshold=0.2):
    """
    Plot predicted vs. true aggregate responses for train/val and test data (interactive).

    Produces a 1x2 Plotly figure: left panel is the concatenated train+val
    splits, right panel is the test split. Points whose absolute error exceeds
    `error_threshold` are drawn in an emphasised colour. Each panel carries a
    dashed y = x reference line and an RMSE / R² annotation.

    Parameters:
        model: Trained model exposing `.predict` and `.feature_names_`
        experiment: Provides split-based data via `.get_dataframe_by_split(split, proxy_only=...)`
        config: Config dict; reads `config["split_settings"]` keys
            "use_proxy_only", "train_val_split" (a pair) and "test_split"
        selected_features: List of features used (optional); forwarded to `get_model_features`
        method_label: Label for the figure title (e.g., "ElasticNet", "OLS")
        error_threshold: Absolute-error cut-off separating "small" from "large"
            errors in the colouring and legend (default 0.2)

    Returns:
        plotly.graph_objects.Figure
    """
    split_cfg = config["split_settings"]
    use_proxy_only = split_cfg["use_proxy_only"]
    split_train, split_val = split_cfg["train_val_split"]
    split_test = split_cfg["test_split"]
    target_col = "aggregate"

    # NOTE(review): the feature list returned here is not used below -- columns
    # are selected via model.feature_names_. The call is kept in case the helper
    # validates its inputs; confirm, then either use its result or drop the call.
    get_model_features(model, experiment, target_col, selected_features)

    # Load data
    df_train = experiment.get_dataframe_by_split(split_train, proxy_only=use_proxy_only)
    df_val = experiment.get_dataframe_by_split(split_val, proxy_only=use_proxy_only)
    df_test = experiment.get_dataframe_by_split(split_test, proxy_only=use_proxy_only)
    df_trainval = pd.concat([df_train, df_val])

    X_trainval = df_trainval[model.feature_names_].astype(float)
    y_trainval = df_trainval[target_col].astype(float).values
    X_test = df_test[model.feature_names_].astype(float)
    y_test = df_test[target_col].astype(float).values

    # Predict
    y_pred_trainval = model.predict(X_trainval)
    y_pred_test = model.predict(X_test)

    # Base palette: one colour for small errors and one for large errors per panel.
    color_train_val = "#96ccf4"
    color_train_val_em = "#0448a3"
    color_test = "#ffd700"
    color_test_em = "#d2691e"

    # Per-point colours live under their own names so the scalar base colours
    # above remain available for the legend entries.
    point_colors_trainval = np.where(
        np.abs(y_pred_trainval - y_trainval) > error_threshold, color_train_val_em, color_train_val
    )
    point_colors_test = np.where(
        np.abs(y_pred_test - y_test) > error_threshold, color_test_em, color_test
    )

    qids_trainval = df_trainval.index.astype(str).tolist()
    qids_test = df_test.index.astype(str).tolist()

    rmse_train = root_mean_squared_error(y_trainval, y_pred_trainval)
    rmse_test = root_mean_squared_error(y_test, y_pred_test)
    r2_train = r2_score(y_trainval, y_pred_trainval)
    r2_test = r2_score(y_test, y_pred_test)

    # Create two subplots
    fig = make_subplots(rows=1, cols=2, subplot_titles=["Train/Val Performance", "Test Performance"])

    hover = "QID: %{text}<br>Pred: %{x:.2f}<br>True: %{y:.2f}<extra></extra>"
    panels = [
        (1, "Train/Val", y_pred_trainval, y_trainval, point_colors_trainval, qids_trainval),
        (2, "Test", y_pred_test, y_test, point_colors_test, qids_test),
    ]
    for col, panel_name, y_hat, y_true, point_colors, qids in panels:
        fig.add_trace(
            go.Scatter(
                x=y_hat,
                y=y_true,
                mode="markers",
                marker=dict(color=point_colors, size=8),
                text=qids,
                name=panel_name,
                showlegend=False,
                hovertemplate=hover,
            ),
            row=1, col=col
        )

    # Data traces carry per-point colours, so the legend is built from empty
    # dummy traces, one per (panel, error class) combination.
    legend_entries = [
        (color_train_val, f"Train/Val: |error| ≤ {error_threshold}"),
        (color_train_val_em, f"Train/Val: |error| > {error_threshold}"),
        (color_test, f"Test: |error| ≤ {error_threshold}"),
        (color_test_em, f"Test: |error| > {error_threshold}"),
    ]
    for legend_color, legend_name in legend_entries:
        fig.add_trace(
            go.Scatter(
                x=[None], y=[None],
                mode="markers",
                marker=dict(color=legend_color, size=8),
                name=legend_name,
                showlegend=True
            )
        )

    # Add y = x dashed reference lines (drawn from (0, 0) to (1, 1)).
    for col in (1, 2):
        fig.add_shape(
            type="line", x0=0, y0=0, x1=1, y1=1,
            line=dict(dash="dash", color="gray"), row=1, col=col
        )

    fig.update_layout(
        title=f"{method_label}: Predicting Aggregate Response with Proxy Agents",
        height=450,
        width=950,
        margin=dict(t=60),
        plot_bgcolor="rgba(0,0,0,0)",
    )

    # Annotation for subplot 1
    fig.add_annotation(
        text=f"RMSE = {rmse_train:.2f}<br>R² = {r2_train:.2f}",
        xref="x domain", yref="y domain",
        x=0.95, y=0.05,
        showarrow=False,
        font=dict(size=10),
        align="right"
    )

    # Annotation for subplot 2
    fig.add_annotation(
        text=f"RMSE = {rmse_test:.2f}<br>R² = {r2_test:.2f}",
        xref="x2 domain", yref="y2 domain",
        x=0.95, y=0.05,
        showarrow=False,
        font=dict(size=10),
        align="right"
    )

    for col in (1, 2):
        fig.update_xaxes(title_text="ŷ (prediction)", row=1, col=col, showgrid=True, gridcolor="lightgray")
        fig.update_yaxes(title_text="y (ground truth)", row=1, col=col, showgrid=True, gridcolor="lightgray")

    return fig

def plot_human_vs_agent_response_distribution(
    qid_bin: str,
    survey,
    responses_bin,
    aggregate_responses: dict,
    agent_distribution: dict,
    title: str = "Comparison of Human vs Agent Aggregate Responses",
    save_path: str = None
):
    """
    Draw a grouped bar chart comparing human and agent distributions for one question.

    The binary question ID is mapped back to its original question; one pair
    of bars is drawn per answer option of that original question.

    Args:
        qid_bin (str): The binary question ID.
        survey: Object exposing `get_question_by_id(qid)`; the returned question
            must contain a 'code_to_answer' mapping whose values are the option labels.
        responses_bin: Object whose `.survey.binary_to_original_map[qid_bin]["original_id"]`
            yields the original question ID.
        aggregate_responses (dict): Human benchmark: qid -> {option -> prob}
        agent_distribution (dict): Model-weighted estimate: qid -> {option -> prob}
        title (str): Plot title.
        save_path (str or None): If given, saves the plot and closes the figure
            instead of displaying it.
    """
    # Map the binary question back to its original question and option labels.
    original_qid = responses_bin.survey.binary_to_original_map[qid_bin]["original_id"]
    code_to_answer = survey.get_question_by_id(original_qid)['code_to_answer']
    option_labels = list(code_to_answer.values())

    # Options missing from either distribution are plotted as probability 0.
    human_by_option = aggregate_responses.get(original_qid, {})
    agent_by_option = agent_distribution.get(original_qid, {})
    human_probs = [human_by_option.get(option, 0.0) for option in option_labels]
    agent_probs = [agent_by_option.get(option, 0.0) for option in option_labels]

    width = 0.35
    positions = list(range(len(option_labels)))
    human_x = [pos - width/2 for pos in positions]
    agent_x = [pos + width/2 for pos in positions]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.bar(human_x, human_probs, width=width, label="Human", color="steelblue")
    ax.bar(agent_x, agent_probs, width=width, label="Agent", color="indianred")

    ax.set_ylabel("Probability")
    ax.set_title(f"{title}\nQuestion: {original_qid}")
    ax.set_xticks(positions)
    ax.set_xticklabels(option_labels, rotation=15)
    ax.set_ylim(0, 1.0)
    ax.legend()

    plt.tight_layout()
    if not save_path:
        plt.show()
        return
    plt.savefig(save_path, transparent=True, bbox_inches='tight')
    plt.close(fig)

# def plot_human_vs_agent_response_distribution_interactive(
#     qid_bin: str,
#     survey,
#     responses_bin,
#     aggregate_responses: dict,
#     agent_distribution: dict,
#     title: str = "Comparison of Human vs Agent Aggregate Responses"
# ) -> go.Figure:
#     """
#     Returns a Plotly bar chart comparing human and agent aggregate responses for a given binary question.

#     Args:
#         qid_bin (str): Binary question ID.
#         survey (BinaryExtendedSurvey): Survey object to access original metadata.
#         responses_bin (BinaryExtendedResponses): Used to map binary to original questions.
#         aggregate_responses (dict): Human benchmark: original_qid -> {option -> prob}.
#         agent_distribution (dict): Model-weighted: original_qid -> {option -> prob}.
#         title (str): Plot title.

#     Returns:
#         plotly.graph_objects.Figure
#     """
#     # Get original question ID and label mappings
#     original_qid = responses_bin.survey.binary_to_original_map[qid_bin]["original_id"]
#     question = survey.get_question_by_id(original_qid)
#     code_to_answer = question['code_to_answer']  # {code: label}
#     labels = list(code_to_answer.values())

#     human_probs = [aggregate_responses.get(original_qid, {}).get(label, 0.0) for label in labels]
#     agent_probs = [agent_distribution.get(original_qid, {}).get(label, 0.0) for label in labels]

#     fig = go.Figure()
#     fig.add_bar(x=labels, y=human_probs, name="Human", marker_color='steelblue')
#     fig.add_bar(x=labels, y=agent_probs, name="Agent", marker_color='indianred')

#     question_text = question['question']
#     fig.update_layout(
#         barmode='group',
#         title=(
#             f"{title} for {original_qid}"
#             f"<br><sub>{question_text}</sub>"
#         ),
#         xaxis_title="Response Options",
#         yaxis_title="Probability",
#         yaxis=dict(range=[0, 1]),
#         bargap=0.2,
#         legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
#     )

#     return fig 

def build_response_distribution_plot_interactive(qid_bin: str, responses_bin, responses):
    """
    Build a horizontal bar chart of raw response frequencies for the *original*
    question behind a binary question ID.

    Args:
        qid_bin (str): Binary question ID from the model prediction plot.
        responses_bin: Object whose `.survey.binary_to_original_map` maps binary
            IDs to `{"original_id": ...}`.
        responses: Object exposing `get_matrix_by_split(split)`, `questions`
            (qid -> metadata dict) and `output_format`.

    Returns:
        plotly.graph_objects.Figure on success. If the binary ID has no mapping,
        or the original ID has no row in the combined response matrix, a
        human-readable error string is returned instead.
    """
    # Step 1: binary QID -> original QID
    try:
        qid_original = responses_bin.survey.binary_to_original_map[qid_bin]["original_id"]
    except KeyError:
        return f"Original QID not found for binary QID: {qid_bin}"

    # Step 2: stitch the per-split matrices together column-wise
    combined = pd.concat(
        [responses.get_matrix_by_split(split_name) for split_name in ("train", "valid", "test")],
        axis=1,
    )
    if qid_original not in combined.index:
        return f"No data found for original QID: {qid_original}"

    answers = combined.loc[qid_original].dropna().astype(str)

    # Step 3: the axis vocabulary is either the answer texts or the stringified
    # codes, depending on the format the responses were stored in
    question_meta = responses.questions[qid_original]
    code_to_answer = question_meta.get("code_to_answer", {})
    answers_are_text = responses.output_format == "answer"
    vocab = list(code_to_answer.values()) if answers_are_text else list(map(str, code_to_answer.keys()))

    # Step 4: frequency per vocabulary entry, 0 for entries never observed
    frequency = answers.value_counts()
    padded_counts = [frequency.get(entry, 0) for entry in vocab]

    # Step 5: assemble the figure
    bars = go.Bar(
        y=vocab,
        x=padded_counts,
        orientation="h",
        marker_color="#588c73"
    )
    heading = (
        f"Response Distribution for {qid_original}"
        f"<br><sub>{question_meta.get('question', '')}</sub>"
    )
    layout = go.Layout(
        title=heading,
        xaxis_title="Frequency",
        yaxis_title="Answer",
        height=450,
        width=950,
        margin=dict(t=60, r=40, b=40, l=100)
    )
    return go.Figure(data=[bars], layout=layout)

def plot_human_vs_agent_response_distribution_interactive(
    qid_bin: str,
    survey,
    responses_bin,
    responses,
    aggregate_responses: dict,
    agent_distribution_after: dict,
    title: str = "Comparison of Human vs Agent Aggregate Responses"
) -> go.Figure:
    """
    Returns a Plotly bar chart comparing human and agent aggregate responses
    before and after alignment for a given binary question.

    Args:
        qid_bin (str): Binary question ID.
        survey: Object exposing `get_question_by_id(qid)`; the question must
            contain 'code_to_answer' and may contain 'question' (the text).
        responses_bin: Object whose `.survey.binary_to_original_map` maps binary
            IDs to `{"original_id": ...}`.
        responses: Raw responses exposing `get_matrix_by_split(split)`; used to
            compute the "before alignment" distribution. If it has an
            `output_format` attribute other than "answer", the matrix is
            treated as holding codes rather than answer texts.
        aggregate_responses (dict): Human benchmark: original_qid -> {option -> prob}.
        agent_distribution_after (dict): Aligned (regression-weighted) distribution:
            original_qid -> {option -> prob}.
        title (str): Plot title.

    Returns:
        plotly.graph_objects.Figure

    Raises:
        KeyError: if `qid_bin` has no entry in the binary-to-original map.
    """

    # Step 1: Recover original QID
    original_qid = responses_bin.survey.binary_to_original_map[qid_bin]["original_id"]
    question = survey.get_question_by_id(original_qid)
    code_to_answer = question['code_to_answer']
    labels = list(code_to_answer.values())

    # Step 2: Compute raw frequency distribution ("before alignment")
    splits = ["train", "valid", "test"]
    df_all = pd.concat([responses.get_matrix_by_split(split) for split in splits], axis=1)

    # The matrix cells are stringified below, so the lookup keys are strings too.
    # Which strings depends on how the responses were stored: answer texts, or
    # codes that must be matched via the keys of code_to_answer. Iterating the
    # mapping keeps the keys in the same order as `labels`.
    if getattr(responses, "output_format", "answer") == "answer":
        lookup_keys = [str(label) for label in labels]
    else:
        lookup_keys = [str(code) for code in code_to_answer]

    agent_probs_before = [0.0] * len(labels)
    if original_qid in df_all.index:
        raw_counts = df_all.loc[original_qid].dropna().astype(str).value_counts()
        total = raw_counts.sum()
        if total > 0:
            agent_probs_before = [raw_counts.get(key, 0) / total for key in lookup_keys]

    # Step 3: Look up human and after-alignment probs (missing options -> 0)
    human_by_label = aggregate_responses.get(original_qid, {})
    after_by_label = agent_distribution_after.get(original_qid, {})
    human_probs = [human_by_label.get(label, 0.0) for label in labels]
    agent_probs_after = [after_by_label.get(label, 0.0) for label in labels]

    # Step 4: Plot
    fig = go.Figure()
    fig.add_bar(x=labels, y=human_probs, name="Human", marker_color='steelblue')
    fig.add_bar(x=labels, y=agent_probs_before, name="Agent (before alignment)", marker_color='lightgray')
    fig.add_bar(x=labels, y=agent_probs_after, name="Agent (after alignment)", marker_color='indianred')

    question_text = question.get('question', '')
    fig.update_layout(
        barmode='group',
        title=f"{title} for {original_qid}<br><sub>{question_text}</sub>",
        xaxis_title="Response Options",
        yaxis_title="Probability",
        height=450,
        width=950,
        yaxis=dict(range=[0, 1]),
        bargap=0.2,
        legend=dict(orientation="h", yanchor="bottom", y=1.0, xanchor="right", x=1)
    )

    return fig

def plot_model_vs_endowment_weights(
    model,
    endowments,
    save_path: str = None,
    threshold: float = 0.0,
    sort_by: str = "ground_truth",
    sort_key: Callable = None,
    figsize=(12, 6),
    bar_colors=("red", "darkblue"),
    title: str = "Lasso Coefficients vs Endowment Weights\n(X-ticks silver = ground truth 0)",
):
    """
    Plot model coefficients against ground truth weights from endowments.

    Only features with a non-zero coefficient or a non-zero ground-truth weight
    are shown. Tick labels of features whose |ground truth| <= `threshold` are
    coloured silver.

    Parameters:
        model: fitted model exposing `.feature_names_` and either `.beta_` or `.coef_`
        endowments: object exposing `.get_endowment_by_eid(eid)`, returning a
            mapping that may contain a "weight" entry (missing or None counts as 0)
        save_path: if given, the figure is also written there (transparent, 300 dpi)
        threshold: |ground truth| values at or below this get a silver tick label
        sort_by: column to order bars by, descending ("ground_truth", "coefficient" or "feature")
        sort_key: optional per-value function applied to the `sort_by` column before sorting
        figsize: matplotlib figure size
        bar_colors: (ground-truth colour, coefficient colour)
        title: axes title

    Returns:
        fig (matplotlib.figure.Figure): The generated matplotlib figure.
    """
    feature_names = model.feature_names_

    # Prefer the `beta_` attribute when the model carries one; otherwise fall
    # back to the scikit-learn `coef_` attribute instead of raising AttributeError.
    coefficients = getattr(model, "beta_", None)
    if coefficients is None:
        coefficients = model.coef_
    model_coefs = pd.Series(coefficients, index=feature_names, name="coefficient")

    # Missing or None weights are treated as 0.
    gt_weights = []
    for eid in feature_names:
        weight = endowments.get_endowment_by_eid(eid).get("weight", 0)
        gt_weights.append(0 if weight is None else weight)
    gt_series = pd.Series(gt_weights, index=feature_names, name="ground_truth")

    df_plot = pd.concat([gt_series, model_coefs], axis=1)
    df_plot["feature"] = df_plot.index
    # Drop features that are zero on both sides; they add nothing to the chart.
    df_plot = df_plot[(df_plot["coefficient"] != 0) | (df_plot["ground_truth"] != 0)]

    if sort_key is not None:
        df_plot = df_plot.sort_values(by=sort_by, key=lambda col: col.map(sort_key), ascending=False)
    else:
        df_plot = df_plot.sort_values(by=sort_by, ascending=False)

    x = np.arange(len(df_plot))
    width = 0.35

    fig, ax = plt.subplots(figsize=figsize, facecolor="none")

    ax.bar(x - width / 2, df_plot["ground_truth"], width, label="Ground Truth", color=bar_colors[0])
    ax.bar(x + width / 2, df_plot["coefficient"], width, label="Model Coefficient", color=bar_colors[1])

    ax.set_xticks(x)
    ax.set_xticklabels(df_plot["feature"], rotation=90)

    for tick, gt in zip(ax.get_xticklabels(), df_plot["ground_truth"]):
        if abs(gt) <= threshold:
            tick.set_color("silver")

    ax.set_ylabel("Coefficient Value")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, axis="y", linestyle="--", alpha=0.6)
    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, transparent=True, dpi=300)

    return fig