import os
import numpy as np
import pandas as pd
import random
from tqdm import tqdm
from typing import Tuple, Optional, Union, List

from experiments.experiments import SimulationExperiment
from modules.endowment_manager import ActiveEndowments
from modules.response_converter import Responses
from modules.survey_converter import Survey
from models import fit_lasso_model, fit_elastic_net_model
from experiments.utils import (
    plot_model_predictions,
    plot_model_vs_endowment_weights
)
from utils.plotting import plot_lasso_diagnostics, plot_elasticnet_diagnostics

def simulate_endowment_fraction(base_experiment, config, fraction=0.5, seed=101,
                                model_type="lasso", plot=True, verbose=True):
    """
    Run simulation after subsampling endowments at given fraction.

    The base experiment's endowments are subsampled with ``clone_with_fraction``,
    ground-truth weights are renormalized over the retained agents, responses are
    restricted to the retained agents, and a regression model is fitted and
    evaluated on the rebuilt experiment.

    Args:
        base_experiment (SimulationExperiment): Original full experiment object.
        config (dict): YAML-driven config. Reads ``split_settings``
            (``use_proxy_only``, ``test_split``) and, when plotting, the
            ``lasso`` / ``elasticnet`` sections.
        fraction (float): Fraction of endowments to retain.
        seed (int): Random seed for reproducibility.
        model_type (str): Regression model to use ("lasso" or "elasticnet").
        plot (bool): Whether to generate plots (default: True).
        verbose (bool): Whether to print summary statistics (default: True).

    Returns:
        model: Fitted regression model
        experiment: New SimulationExperiment object
        stats: Dictionary of fit and selection metrics. The ratio entries
            ``obs_per_agent`` and ``obs_per_gt`` are NaN when their denominator
            is zero (e.g. no ground-truth agent survived subsampling).
        fig_diag: Diagnostic plot (or None)
        fig_pred: Prediction vs ground truth plot (or None)
        fig_weights: Model vs endowment weight plot (or None)

    Raises:
        ValueError: If ``model_type`` is not "lasso" or "elasticnet".
    """
    # 1. Subsample endowments and responses
    sampled_endowments = base_experiment.endowments.clone_with_fraction(fraction=fraction, seed=seed)
    sampled_endowments.renormalize_ground_truth_weights()
    valid_eids = {e["eid"] for e in sampled_endowments.get_endowments()}
    filtered_responses = base_experiment.responses.clone_with_agents(valid_eids)

    # 2. Create new experiment
    experiment = SimulationExperiment(
        responses=filtered_responses,
        survey=base_experiment.survey,
        endowments=sampled_endowments,
        filter_binary=base_experiment.filter_binary,
        drop_na=base_experiment.drop_na
    )

    # 3. Fit model
    if model_type == "lasso":
        model, best_alpha, diagnostics = fit_lasso_model(experiment, config, verbose=verbose)
        model.log_fit_summary()
    elif model_type == "elasticnet":
        model, best_alpha, best_l1_ratio, diagnostics = fit_elastic_net_model(experiment, config, verbose=verbose)
        model.log_fit_summary()
    else:
        raise ValueError(f"Invalid model_type: {model_type}. Must be 'lasso' or 'elasticnet'.")

    split_settings = config["split_settings"]
    proxy_only = split_settings["use_proxy_only"]

    # 4. Training stats (train + val splits combined)
    df_trainval = experiment.get_dataframe_by_split(["train", "val"], proxy_only=proxy_only)
    X_trainval = df_trainval[model.feature_names_].astype(float)
    y_true = df_trainval["aggregate"].values
    mse = model.score(X_trainval, y_true)
    r2 = model.r2(X_trainval, y_true)
    n_obs, n_agents = X_trainval.shape

    # 5. Test stats
    df_test = experiment.get_dataframe_by_split(split_settings["test_split"], proxy_only=proxy_only)
    X_test = df_test[model.feature_names_].astype(float)
    y_test = df_test["aggregate"].values
    mse_test = model.score(X_test, y_test)
    r2_test = model.r2(X_test, y_test)

    # 6. Selection metrics: an agent is "selected" when its coefficient is non-zero
    selected = {eid for eid, coef in model.coef_dict_.items() if coef != 0}
    ground_truth = {e["eid"] for e in sampled_endowments.get_endowments_by_role("ground_truth")}
    tp = len(selected & ground_truth)
    fp = len(selected - ground_truth)
    fn = len(ground_truth - selected)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0

    # 7. Coefficient stats. mean_weight_gt averages every GT agent present in the
    # model (zero coefficients included); the tp/fp guards keep np.mean off empty lists.
    mean_weight_gt = np.mean([model.coef_dict_[eid] for eid in ground_truth if eid in model.coef_dict_]) if tp > 0 else 0.0
    mean_weight_fp = np.mean([model.coef_dict_[eid] for eid in selected - ground_truth]) if fp > 0 else 0.0

    # Ratios degrade to NaN instead of raising ZeroDivisionError on empty groups.
    n_gt = len(ground_truth)
    obs_per_agent = n_obs / n_agents if n_agents else float("nan")
    obs_per_gt = n_obs / n_gt if n_gt else float("nan")

    # 8. Verbose output
    if verbose:
        print("--- Simulation Summary ---")
        print(f"> Observations: {n_obs} | Agents: {n_agents} | Obs/Agent: {obs_per_agent:.2f}")
        print(f"> GT agents: {n_gt} | Obs/GT: {obs_per_gt:.2f}")
        print(f"> MSE (train): {mse:.4f} | R² (train): {r2:.4f}")
        print(f"> MSE (test):  {mse_test:.4f} | R² (test):  {r2_test:.4f}")
        print(f"> Selected: {len(selected)} | Precision: {precision:.2f} | Recall: {recall:.2f}")

    # 9. Plots
    fig_diag, fig_pred, fig_weights = None, None, None
    if plot:
        if model_type == "lasso":
            fig_diag = plot_lasso_diagnostics(
                diagnostics,
                best_alpha,
                strategy=config["lasso"].get("validation", {}).get("strategy", "cv")
            )
        elif model_type == "elasticnet":
            plot_style = config["elasticnet"].get("plot_style", "2D")
            fig_diag = plot_elasticnet_diagnostics(
                diagnostics,
                best_alpha,
                best_l1_ratio,
                style=plot_style
            )

        fig_pred = plot_model_predictions(model, experiment, config, method_label=f"{model_type.title()} Regression")
        fig_weights = plot_model_vs_endowment_weights(model, sampled_endowments)

    # 10. Return stats
    stats = {
        "fraction": fraction,
        "seed": seed,
        "model_type": model_type,
        "n_obs": n_obs,
        "n_agents": n_agents,
        "n_gt": n_gt,
        "obs_per_agent": obs_per_agent,
        "obs_per_gt": obs_per_gt,
        "mse": mse,
        "r2": r2,
        "mse_test": mse_test,
        "r2_test": r2_test,
        "selected": len(selected),
        "precision": precision,
        "recall": recall,
        "tp": tp,
        "fp": fp,
        "fn": fn,
        "mean_weight_gt": mean_weight_gt,
        "mean_weight_fp": mean_weight_fp,
        "selected_eids": list(selected),
        "gt_eids": list(ground_truth),
    }

    return model, experiment, stats, fig_diag, fig_pred, fig_weights

def simulate_response_fraction(base_experiment: SimulationExperiment,
                               config: dict,
                               fraction: float = 0.5,
                               seed: int = 101,
                               model_type: str = "lasso",
                               plot: bool = True,
                               verbose: bool = True):
    """
    Run simulation after subsampling questions (responses) at given fraction.

    Questions are subsampled with ``base_experiment.sample_fraction``; the
    ground-truth set and weight plot use the endowments of the resulting
    experiment.

    Args:
        base_experiment (SimulationExperiment): Original full experiment object.
        config (dict): YAML-driven config. Reads ``split_settings``
            (``use_proxy_only``, ``test_split``) and, when plotting, the
            ``lasso`` / ``elasticnet`` sections.
        fraction (float): Fraction of questions to retain.
        seed (int): Random seed for reproducibility.
        model_type (str): Regression model to use ("lasso" or "elasticnet").
        plot (bool): Whether to generate plots (default: True).
        verbose (bool): Whether to print summary statistics (default: True).

    Returns:
        model: Fitted regression model
        experiment: New SimulationExperiment object with subsampled responses
        stats: Dictionary of fit and selection metrics. The ratio entries
            ``obs_per_agent`` and ``obs_per_gt`` are NaN when their denominator
            is zero.
        fig_diag: Diagnostic plot (or None)
        fig_pred: Prediction vs ground truth plot (or None)
        fig_weights: Model vs endowment weight plot (or None)

    Raises:
        ValueError: If ``model_type`` is not "lasso" or "elasticnet".
    """
    # 1. Subsample questions
    experiment = base_experiment.sample_fraction(fraction=fraction, seed=seed)

    # 2. Fit model
    if model_type == "lasso":
        model, best_alpha, diagnostics = fit_lasso_model(experiment, config, verbose=verbose)
        model.log_fit_summary()
    elif model_type == "elasticnet":
        model, best_alpha, best_l1_ratio, diagnostics = fit_elastic_net_model(experiment, config, verbose=verbose)
        model.log_fit_summary()
    else:
        raise ValueError(f"Invalid model_type: {model_type}. Must be 'lasso' or 'elasticnet'.")

    split_settings = config["split_settings"]
    proxy_only = split_settings["use_proxy_only"]

    # 3. Training stats (train + val splits combined)
    df_trainval = experiment.get_dataframe_by_split(["train", "val"], proxy_only=proxy_only)
    X_trainval = df_trainval[model.feature_names_].astype(float)
    y_true = df_trainval["aggregate"].values
    mse = model.score(X_trainval, y_true)
    r2 = model.r2(X_trainval, y_true)
    n_obs, n_agents = X_trainval.shape

    # 4. Test stats
    df_test = experiment.get_dataframe_by_split(split_settings["test_split"], proxy_only=proxy_only)
    X_test = df_test[model.feature_names_].astype(float)
    y_test = df_test["aggregate"].values
    mse_test = model.score(X_test, y_test)
    r2_test = model.r2(X_test, y_test)

    # 5. Selection metrics (ground truth comes from endowments); "selected" = non-zero coefficient
    selected = {eid for eid, coef in model.coef_dict_.items() if coef != 0}
    ground_truth = {e["eid"] for e in experiment.endowments.get_endowments_by_role("ground_truth")}
    tp = len(selected & ground_truth)
    fp = len(selected - ground_truth)
    fn = len(ground_truth - selected)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0

    # 6. Coefficient stats; the tp/fp guards keep np.mean off empty lists.
    mean_weight_gt = np.mean([model.coef_dict_[eid] for eid in ground_truth if eid in model.coef_dict_]) if tp > 0 else 0.0
    mean_weight_fp = np.mean([model.coef_dict_[eid] for eid in selected - ground_truth]) if fp > 0 else 0.0

    # Ratios degrade to NaN instead of raising ZeroDivisionError on empty groups.
    n_gt = len(ground_truth)
    obs_per_agent = n_obs / n_agents if n_agents else float("nan")
    obs_per_gt = n_obs / n_gt if n_gt else float("nan")

    # 7. Verbose summary
    if verbose:
        print("--- Simulation Summary ---")
        print(f"> Questions: {len(experiment.responses.questions.keys())} | Obs: {n_obs} | Agents: {n_agents}")
        print(f"> Obs/Agent: {obs_per_agent:.2f} | Obs/GT: {obs_per_gt:.2f}")
        print(f"> MSE (train): {mse:.4f} | R² (train): {r2:.4f}")
        print(f"> MSE (test):  {mse_test:.4f} | R² (test):  {r2_test:.4f}")
        print(f"> Selected: {len(selected)} | Precision: {precision:.2f} | Recall: {recall:.2f}")

    # 8. Plots
    fig_diag, fig_pred, fig_weights = None, None, None
    if plot:
        if model_type == "lasso":
            fig_diag = plot_lasso_diagnostics(
                diagnostics,
                best_alpha,
                strategy=config["lasso"].get("validation", {}).get("strategy", "cv")
            )
        elif model_type == "elasticnet":
            plot_style = config["elasticnet"].get("plot_style", "2D")
            fig_diag = plot_elasticnet_diagnostics(
                diagnostics,
                best_alpha,
                best_l1_ratio,
                style=plot_style
            )
        fig_pred = plot_model_predictions(model, experiment, config, method_label=f"{model_type.title()} Regression")
        fig_weights = plot_model_vs_endowment_weights(model, experiment.endowments)

    # 9. Return stats
    stats = {
        "fraction": fraction,
        "seed": seed,
        "model_type": model_type,
        "n_obs": n_obs,
        "n_agents": n_agents,
        "n_gt": n_gt,
        "obs_per_agent": obs_per_agent,
        "obs_per_gt": obs_per_gt,
        "mse": mse,
        "r2": r2,
        "mse_test": mse_test,
        "r2_test": r2_test,
        "selected": len(selected),
        "precision": precision,
        "recall": recall,
        "tp": tp,
        "fp": fp,
        "fn": fn,
        "mean_weight_gt": mean_weight_gt,
        "mean_weight_fp": mean_weight_fp,
        "selected_eids": list(selected),
        "gt_eids": list(ground_truth),
    }

    return model, experiment, stats, fig_diag, fig_pred, fig_weights

def run_fraction_sweep(
    base_experiment,
    config: dict,
    fraction_values: List[float],
    mode: str = "endowment",  # "response" or "endowment"
    model_type: str = "lasso",
    seed: int = 101,
    output_csv: Optional[str] = None,
    plot: bool = False,
    plot_dir: Optional[str] = None,
    verbose: bool = True
) -> pd.DataFrame:
    """
    Run simulation sweep over different fractions of endowments or responses,
    optionally saving all result plots.

    Args:
        base_experiment: SimulationExperiment object
        config (dict): YAML config
        fraction_values (list[float]): Fractions to simulate
        mode (str): "endowment" (uses ``simulate_endowment_fraction``) or
            "response" (uses ``simulate_response_fraction``)
        model_type (str): "lasso" or "elasticnet"
        seed (int): Random seed (the same seed is reused for every fraction)
        output_csv (str): Optional CSV output path
        plot (bool): Whether to generate plots
        plot_dir (str): Directory to save plots (if plot=True); created if missing
        verbose (bool): Verbose logging

    Returns:
        pd.DataFrame: Collected stats from simulations (one row per fraction)

    Raises:
        ValueError: If ``mode`` is invalid. Checked before any work is done, so
            an invalid mode is reported even for an empty ``fraction_values``.
    """
    # Validate first: previously a bad mode went unnoticed when fraction_values was empty.
    if mode not in ("endowment", "response"):
        raise ValueError("Invalid mode. Must be 'endowment' or 'response'.")
    # The simulators are defined in this module, so no self-import is needed.
    simulate = simulate_endowment_fraction if mode == "endowment" else simulate_response_fraction

    if plot and plot_dir:
        os.makedirs(plot_dir, exist_ok=True)
    records = []

    for fraction in tqdm(fraction_values, desc=f"Running {mode} simulations"):
        # --- 1. Run Simulation ---
        _, _, stats, fig_diag, fig_pred, fig_weights = simulate(
            base_experiment, config,
            fraction=fraction, seed=seed,
            model_type=model_type, plot=plot, verbose=verbose
        )

        # --- 2. Save Plots ---
        if plot and plot_dir:
            fraction_str = f"{fraction:.2f}".replace(".", "p")
            base_name = f"{mode}_{model_type}_frac{fraction_str}_seed{seed}"
            for suffix, fig in (("diag", fig_diag), ("pred", fig_pred), ("weights", fig_weights)):
                if fig:
                    fig.savefig(os.path.join(plot_dir, f"{base_name}_{suffix}.png"), transparent=True, dpi=300)

        records.append(stats)

    df_results = pd.DataFrame.from_records(records)
    if output_csv:
        df_results.to_csv(output_csv, index=False)
        if verbose:
            print(f"[✓] Results saved to: {output_csv}")

    return df_results


def _evaluate_simulation_model(model, experiment, endowments, config, model_type, fraction, seed, diagnostics, plot, verbose):
    """
    Evaluate a fitted simulation model and collect fit, selection and entropy stats.

    Shared evaluation step of ``simulate_proxy_fraction``,
    ``simulate_trainvalid_fraction`` and ``simulate_proxy_sweep_by_mode``.

    Args:
        model: Fitted regression model. Uses ``feature_names_``, ``coef_dict_``,
            ``score``, ``r2`` and, for diagnostic plots, ``alpha`` (and
            ``l1_ratio`` for elastic net).
        experiment: Experiment the model was fitted on.
        endowments: Endowment manager providing "ground_truth" / "proxy" roles.
        config (dict): YAML config (``split_settings`` plus plot settings).
        model_type (str): "lasso" or "elasticnet"; selects the diagnostic plot.
        fraction: Recorded verbatim in ``stats["fraction"]`` (may be None).
        seed: Recorded verbatim in ``stats["seed"]``.
        diagnostics: Model-selection diagnostics forwarded to the diagnostic plot.
        plot (bool): Whether to build figures.
        verbose (bool): Whether to print a summary.

    Returns:
        Tuple ``(model, experiment, stats, fig_diag, fig_pred, fig_weights)``.
        Ratio entries in ``stats`` (``obs_per_agent``, ``obs_per_gt``,
        ``obs_per_proxy``, ``proxy_to_gt_ratio``) are NaN when their denominator
        is zero, e.g. when subsampling leaves no proxy agents.
    """
    split_settings = config["split_settings"]
    proxy_only = split_settings["use_proxy_only"]

    # Train + val fit statistics
    df_trainval = experiment.get_dataframe_by_split(["train", "val"], proxy_only=proxy_only)
    X_trainval = df_trainval[model.feature_names_].astype(float)
    y_true = df_trainval["aggregate"].values
    mse = model.score(X_trainval, y_true)
    r2 = model.r2(X_trainval, y_true)
    n_obs, n_agents = X_trainval.shape

    # Held-out test statistics
    df_test = experiment.get_dataframe_by_split(split_settings["test_split"], proxy_only=proxy_only)
    X_test = df_test[model.feature_names_].astype(float)
    y_test = df_test["aggregate"].values
    mse_test = model.score(X_test, y_test)
    r2_test = model.r2(X_test, y_test)

    # Selection metrics: an agent counts as selected when its coefficient is non-zero
    selected = {eid for eid, coef in model.coef_dict_.items() if coef != 0}
    ground_truth = {e["eid"] for e in endowments.get_endowments_by_role("ground_truth")}
    proxy = {e["eid"] for e in endowments.get_endowments_by_role("proxy")}
    tp = len(selected & ground_truth)
    fp = len(selected - ground_truth)
    fn = len(ground_truth - selected)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    # The tp/fp guards keep np.mean off empty lists.
    mean_weight_gt = np.mean([model.coef_dict_[eid] for eid in ground_truth if eid in model.coef_dict_]) if tp > 0 else 0.0
    mean_weight_fp = np.mean([model.coef_dict_[eid] for eid in selected - ground_truth]) if fp > 0 else 0.0

    # Ratios degrade to NaN instead of raising ZeroDivisionError (e.g. proxy fraction 0).
    n_gt = len(ground_truth)
    n_proxy = len(proxy)
    nan = float("nan")
    obs_per_agent = n_obs / n_agents if n_agents else nan
    obs_per_gt = n_obs / n_gt if n_gt else nan
    obs_per_proxy = n_obs / n_proxy if n_proxy else nan
    proxy_to_gt_ratio = n_proxy / n_gt if n_gt else nan

    if verbose:
        print("--- Simulation Summary ---")
        print(f"> Observations: {n_obs} | Agents: {n_agents} | Obs/Agent: {obs_per_agent:.2f}")
        print(f"> GT agents: {n_gt} | Obs/GT: {obs_per_gt:.2f}")
        print(f"> MSE (train): {mse:.4f} | R² (train): {r2:.4f}")
        print(f"> MSE (test):  {mse_test:.4f} | R² (test):  {r2_test:.4f}")
        print(f"> Selected: {len(selected)} | Precision: {precision:.2f} | Recall: {recall:.2f}")

    # Plot helpers come from the module-level imports.
    fig_diag, fig_pred, fig_weights = None, None, None
    if plot:
        if model_type == "lasso":
            fig_diag = plot_lasso_diagnostics(
                diagnostics, model.alpha,
                strategy=config["lasso"].get("validation", {}).get("strategy", "cv")
            )
        elif model_type == "elasticnet":
            fig_diag = plot_elasticnet_diagnostics(
                diagnostics, model.alpha, model.l1_ratio,
                style=config["elasticnet"].get("plot_style", "2D")
            )
        fig_pred = plot_model_predictions(model, experiment, config, method_label=f"{model_type.title()} Regression")
        fig_weights = plot_model_vs_endowment_weights(model, endowments)

    stats = {
        "fraction": fraction,
        "seed": seed,
        "model_type": model_type,
        "n_obs": n_obs,
        "n_agents": n_agents,
        "n_gt": n_gt,
        "n_proxy": n_proxy,
        "obs_per_agent": obs_per_agent,
        "obs_per_gt": obs_per_gt,
        "obs_per_proxy": obs_per_proxy,
        "proxy_to_gt_ratio": proxy_to_gt_ratio,
        "mse": mse,
        "r2": r2,
        "mse_test": mse_test,
        "r2_test": r2_test,
        "selected": len(selected),
        "precision": precision,
        "recall": recall,
        "tp": tp,
        "fp": fp,
        "fn": fn,
        "mean_weight_gt": mean_weight_gt,
        "mean_weight_fp": mean_weight_fp,
        "selected_eids": list(selected),
        "gt_eids": list(ground_truth),
    }

    # Entropy is measured against responses/survey reloaded from the experiment's
    # source files (presumably the unfiltered data — confirm).
    original_responses, original_survey = reload_original_responses_from_experiment(experiment)

    stats["gt_entropy"] = ActiveEndowments.compute_entropy(
        endowment_list=endowments.get_endowments_by_role("ground_truth"),
        responses=original_responses,
        survey=original_survey
    )
    stats["proxy_entropy"] = ActiveEndowments.compute_entropy(
        endowment_list=endowments.get_endowments_by_role("proxy"),
        responses=original_responses,
        survey=original_survey
    )

    return model, experiment, stats, fig_diag, fig_pred, fig_weights

def simulate_proxy_fraction(base_experiment: SimulationExperiment,
                             config: dict,
                             fraction: float = 0.5,
                             seed: int = 101,
                             model_type: str = "lasso",
                             plot: bool = True,
                             verbose: bool = True):
    """
    Subsample proxy agents only, refit, and evaluate.

    Ground-truth agents are left untouched; ``fraction`` of the proxy agents is
    kept via ``clone_with_proxy_fraction``, ground-truth weights are
    renormalized, and responses are restricted to the surviving agents before a
    new SimulationExperiment is built and a model is fitted on it.

    Args:
        base_experiment (SimulationExperiment): Full experiment to subsample.
        config (dict): YAML-driven configuration.
        fraction (float): Share of proxy agents to keep.
        seed (int): Seed forwarded to the subsampler.
        model_type (str): "lasso" or "elasticnet".
        plot (bool): Whether to build figures.
        verbose (bool): Whether to print progress and summaries.

    Returns:
        model, experiment, stats, fig_diag, fig_pred, fig_weights

    Raises:
        ValueError: If ``model_type`` is not recognised.
    """
    kept = base_experiment.endowments.clone_with_proxy_fraction(fraction=fraction, seed=seed)
    kept.renormalize_ground_truth_weights()
    kept_ids = set(endowment["eid"] for endowment in kept.get_endowments())

    sub_experiment = SimulationExperiment(
        responses=base_experiment.responses.clone_with_agents(kept_ids),
        survey=base_experiment.survey,
        endowments=kept,
        filter_binary=base_experiment.filter_binary,
        drop_na=base_experiment.drop_na,
    )

    if model_type == "lasso":
        fitted, _alpha, diag = fit_lasso_model(sub_experiment, config, verbose=verbose)
    elif model_type == "elasticnet":
        fitted, _alpha, _l1_ratio, diag = fit_elastic_net_model(sub_experiment, config, verbose=verbose)
    else:
        raise ValueError(f"Invalid model_type: {model_type}. Must be 'lasso' or 'elasticnet'.")
    fitted.log_fit_summary()

    return _evaluate_simulation_model(
        fitted, sub_experiment, kept, config, model_type,
        fraction, seed, diag, plot, verbose,
    )

def simulate_trainvalid_fraction(base_experiment: SimulationExperiment,
                                  config: dict,
                                  fraction: float = 0.5,
                                  seed: int = 101,
                                  model_type: str = "lasso",
                                  plot: bool = True,
                                  verbose: bool = True):
    """
    Thin out the train/validation questions, refit, and evaluate.

    Delegates the subsampling to ``base_experiment.sample_trainvalid_fraction``
    (test questions are meant to stay intact), fits the requested model and
    scores it with ``_evaluate_simulation_model`` using the new experiment's
    own endowments.

    Args:
        base_experiment (SimulationExperiment): Full experiment to subsample.
        config (dict): YAML-driven configuration.
        fraction (float): Share of train/valid questions to keep.
        seed (int): Seed forwarded to the subsampler.
        model_type (str): "lasso" or "elasticnet".
        plot (bool): Whether to build figures.
        verbose (bool): Whether to print progress and summaries.

    Returns:
        model, experiment, stats, fig_diag, fig_pred, fig_weights

    Raises:
        ValueError: If ``model_type`` is not recognised.
    """
    sub_experiment = base_experiment.sample_trainvalid_fraction(fraction=fraction, seed=seed)

    if model_type == "lasso":
        fitted, _alpha, diag = fit_lasso_model(sub_experiment, config, verbose=verbose)
    elif model_type == "elasticnet":
        fitted, _alpha, _l1_ratio, diag = fit_elastic_net_model(sub_experiment, config, verbose=verbose)
    else:
        raise ValueError(f"Invalid model_type: {model_type}. Must be 'lasso' or 'elasticnet'.")
    fitted.log_fit_summary()

    return _evaluate_simulation_model(
        fitted, sub_experiment, sub_experiment.endowments, config, model_type,
        fraction, seed, diag, plot, verbose,
    )

def run_customized_fraction_sweep(
    base_experiment,
    config: dict,
    fraction_values: List[float],
    mode: str = "proxy",  # "proxy" or "trainvalid"
    model_type: str = "lasso",
    seed: int = 101,
    output_csv: Optional[str] = None,
    plot: bool = False,
    plot_dir: Optional[str] = None,
    verbose: bool = True
) -> pd.DataFrame:
    """
    Run simulation sweep over customized modes:
    - "proxy": subsample only proxy agents (``simulate_proxy_fraction``)
    - "trainvalid": subsample only train/valid questions, test is intact
      (``simulate_trainvalid_fraction``)

    Args:
        base_experiment: SimulationExperiment object
        config (dict): YAML config
        fraction_values (list[float]): Fractions to simulate
        mode (str): "proxy" or "trainvalid"
        model_type (str): "lasso" or "elasticnet"
        seed (int): Random seed (the same seed is reused for every fraction)
        output_csv (str): Optional CSV output path
        plot (bool): Whether to generate plots
        plot_dir (str): Directory to save plots (if plot=True); created if missing
        verbose (bool): Verbose logging

    Returns:
        pd.DataFrame: Collected stats from simulations (one row per fraction)

    Raises:
        ValueError: If ``mode`` is invalid. Checked before any work is done, so
            an invalid mode is reported even for an empty ``fraction_values``.
    """
    # Validate first: previously a bad mode went unnoticed when fraction_values was empty.
    if mode not in ("proxy", "trainvalid"):
        raise ValueError("Invalid mode. Must be 'proxy' or 'trainvalid'.")
    simulate = simulate_proxy_fraction if mode == "proxy" else simulate_trainvalid_fraction

    if plot and plot_dir:
        os.makedirs(plot_dir, exist_ok=True)
    records = []

    for fraction in tqdm(fraction_values, desc=f"Running {mode} simulations"):
        # --- 1. Run Simulation ---
        _, _, stats, fig_diag, fig_pred, fig_weights = simulate(
            base_experiment, config,
            fraction=fraction, seed=seed,
            model_type=model_type, plot=plot, verbose=verbose
        )

        # --- 2. Save Plots ---
        if plot and plot_dir:
            fraction_str = f"{fraction:.2f}".replace(".", "p")
            base_name = f"{mode}_{model_type}_frac{fraction_str}_seed{seed}"
            for suffix, fig in (("diag", fig_diag), ("pred", fig_pred), ("weights", fig_weights)):
                if fig:
                    fig.savefig(os.path.join(plot_dir, f"{base_name}_{suffix}.png"), transparent=True, dpi=300)

        records.append(stats)

    df_results = pd.DataFrame.from_records(records)
    if output_csv:
        df_results.to_csv(output_csv, index=False)
        if verbose:
            print(f"[✓] Results saved to: {output_csv}")

    return df_results

def simulate_proxy_sweep_by_mode(
    base_experiment,
    config: dict,
    included_modes: List[str],
    n_proxies: int = 30,
    seed: Optional[int] = None,
    model_type: str = "lasso",
    plot: bool = True,
    verbose: bool = True
) -> Tuple:
    """
    Draw a proxy group from selected modes, refit, and evaluate.

    Ground-truth agents are expected to be assigned on ``base_experiment``
    already. ``clone_with_proxy_modes`` picks ``n_proxies`` proxy agents from
    ``included_modes``; responses are restricted to the resulting agents and a
    new experiment of the same class as ``base_experiment`` is built and fitted.

    Args:
        base_experiment: SimulationExperiment with ground-truth agents assigned.
        config (dict): YAML-driven configuration.
        included_modes (List[str]): Modes eligible to supply proxy agents.
        n_proxies (int): Size of the proxy group to draw.
        seed (int, optional): Seed forwarded to the proxy sampler.
        model_type (str): "lasso" or "elasticnet".
        plot (bool): Whether to build figures.
        verbose (bool): Whether to print progress and summaries.

    Returns:
        model, experiment, stats, fig_diag, fig_pred, fig_weights
        (``stats["fraction"]`` is None here).

    Raises:
        ValueError: If ``model_type`` is not recognised.
    """
    kept = base_experiment.endowments.clone_with_proxy_modes(
        included_modes=included_modes,
        n_proxies=n_proxies,
        seed=seed,
    )
    kept_ids = set(endowment["eid"] for endowment in kept.get_endowments())
    trimmed_responses = base_experiment.responses.clone_with_agents(kept_ids)

    # Rebuild with the caller's concrete experiment class rather than a fixed one.
    experiment_cls = base_experiment.__class__
    sub_experiment = experiment_cls(
        responses=trimmed_responses,
        survey=base_experiment.survey,
        endowments=kept,
        filter_binary=base_experiment.filter_binary,
        drop_na=base_experiment.drop_na,
    )

    if model_type == "lasso":
        fitted, _alpha, diag = fit_lasso_model(sub_experiment, config, verbose=verbose)
    elif model_type == "elasticnet":
        fitted, _alpha, _l1_ratio, diag = fit_elastic_net_model(sub_experiment, config, verbose=verbose)
    else:
        raise ValueError(f"Invalid model_type: {model_type}. Must be 'lasso' or 'elasticnet'.")
    fitted.log_fit_summary()

    # A proxy-mode sweep has no meaningful "fraction", so None is recorded.
    return _evaluate_simulation_model(
        fitted, sub_experiment, kept,
        config, model_type,
        fraction=None,
        seed=seed,
        diagnostics=diag,
        plot=plot,
        verbose=verbose,
    )

def run_entropy_sweep(
    base_experiment,
    config: dict,
    ordered_modes: List[str],
    n_proxies: int = 30,
    model_type: str = "lasso",
    seed: Optional[int] = None,
    plot: bool = False,
    plot_dir: Optional[str] = None,
    output_csv: Optional[str] = None,
    verbose: bool = True,
) -> pd.DataFrame:
    """
    Run a simulation sweep by gradually expanding proxy modes (sorted by entropy),
    fixing ground truth agents, and observing model performance.

    The sweep starts at the smallest prefix of ``ordered_modes`` that holds at
    least ``n_proxies`` proxy agents (as reported by
    ``find_sweep_start_idx_for_proxies``) and adds one mode per step until all
    modes are included.

    Args:
        base_experiment: SimulationExperiment with fixed GT agents assigned
        config (dict): YAML config
        ordered_modes (List[str]): Modes sorted by entropy (low → high); each
            mode may be a string or a tuple of strings
        n_proxies (int): Number of proxy agents to sample per run
        model_type (str): "lasso" or "elasticnet"
        seed (Optional[int]): Reproducibility
        plot (bool): Whether to generate plots
        plot_dir (Optional[str]): Where to save plots if plot=True; created if missing
        output_csv (Optional[str]): Where to save results
        verbose (bool): Verbosity toggle (also gates the start-of-sweep info line)

    Returns:
        pd.DataFrame: Sweep results (each row = one mode expansion step);
        ``n_modes`` is the number of leading modes included in that step.

    Raises:
        ValueError: If all modes together hold fewer than ``n_proxies`` proxies.
    """
    if plot and plot_dir:
        os.makedirs(plot_dir, exist_ok=True)
    records = []

    start_idx, proxy_mode_counts = find_sweep_start_idx_for_proxies(
        endowments=base_experiment.endowments,
        ordered_modes=ordered_modes,
        n_proxies=n_proxies
    )

    if verbose:
        n_available = sum(len(eids) for _, eids in proxy_mode_counts)
        print(f"[INFO] Starting sweep from mode index {start_idx} with {n_available} proxy agents.")

    # start_idx is the index of the mode that completes the proxy budget, so the
    # first sufficient prefix is ordered_modes[:start_idx + 1]. Starting the slice
    # at start_idx (as before) ran one step with too few proxies.
    for n_modes in tqdm(range(start_idx + 1, len(ordered_modes) + 1), desc="Entropy Sweep"):
        current_modes = ordered_modes[:n_modes]
        # Tuple modes render as "a+b"; plain-string modes are used as-is.
        mode_str = ",".join(
            mode if isinstance(mode, str) else "+".join(mode)
            for mode in current_modes
        )

        _, _, stats, fig_diag, fig_pred, fig_weights = simulate_proxy_sweep_by_mode(
            base_experiment=base_experiment,
            config=config,
            included_modes=current_modes,
            n_proxies=n_proxies,
            seed=seed,
            model_type=model_type,
            plot=plot,
            verbose=verbose,
        )

        stats["included_modes"] = mode_str
        stats["n_modes"] = n_modes

        if plot and plot_dir:
            prefix = f"{model_type}_modesweep_{n_modes:02d}_seed{seed}"
            for suffix, fig in (("diag", fig_diag), ("pred", fig_pred), ("weights", fig_weights)):
                if fig:
                    fig.savefig(os.path.join(plot_dir, f"{prefix}_{suffix}.png"), dpi=300, transparent=True)

        records.append(stats)

    df = pd.DataFrame.from_records(records)
    if output_csv:
        df.to_csv(output_csv, index=False)
        if verbose:
            print(f"[✓] Sweep results saved to {output_csv}")

    return df

def find_sweep_start_idx_for_proxies(endowments, ordered_modes, n_proxies, role="proxy"):
    """
    Find the first mode index at which enough endowments with ``role`` have accumulated.

    Walks ``ordered_modes`` in order, collecting the eids of endowments whose
    ``"mode"`` equals the current mode and whose ``"role"`` equals ``role``, and
    stops as soon as the cumulative eid set holds at least ``n_proxies`` entries.

    Args:
        endowments (ActiveEndowments): Endowment manager with role assignments.
        ordered_modes (List[tuple]): List of mode tuples in sweep order.
        n_proxies (int): Minimum number of endowments required to begin sweep.
        role (str): Role to filter on (default: "proxy").

    Returns:
        int: Index ``i`` of the mode that brings the cumulative count to
            ``n_proxies``. The mode at ``i`` is included in the count, so the
            smallest sufficient prefix is ``ordered_modes[:i + 1]``.
        List[tuple]: ``(mode, eid_set)`` pairs for modes ``0..i`` inclusive.

    Raises:
        ValueError: If all modes together hold fewer than ``n_proxies`` matching endowments.
    """
    # Loop-invariant: fetch the endowment list once instead of once per mode.
    all_endowments = list(endowments.get_endowments())
    proxy_mode_counts = []
    cumulative_eids = set()

    for i, mode in enumerate(ordered_modes):
        eids = {
            e["eid"] for e in all_endowments
            if e["mode"] == mode and e["role"] == role
        }
        proxy_mode_counts.append((mode, eids))
        cumulative_eids.update(eids)

        if len(cumulative_eids) >= n_proxies:
            return i, proxy_mode_counts

    raise ValueError(f"Not enough {role} agents available across all modes to meet n_proxies = {n_proxies}.")


def assign_ground_truth_from_modes(
    endowments: ActiveEndowments,
    gt_modes: List[str],
    n_gt: int = 10,
    seed: Optional[int] = None,
) -> ActiveEndowments:
    """
    Assign ground truth roles to a sampled subset of agents from the specified modes.

    This function filters the available endowments by the provided modes, randomly samples
    `n_gt` agents from them, sets their role to "ground_truth", and initializes their weights.
    All other agents are assigned as "proxy" (by ``assign_roles(method="from_map")``).

    When ``seed`` is given, sampling uses a private ``random.Random(seed)``, so the
    result is reproducible and identical to ``random.seed(seed); random.sample(...)``
    but the process-wide ``random`` state is left untouched. With ``seed=None`` the
    module-level ``random`` generator is used.

    Args:
        endowments (ActiveEndowments): Endowment manager with agent metadata.
        gt_modes (List[str]): List of mode names to sample GT agents from.
        n_gt (int): Number of ground truth agents to assign.
        seed (int, optional): Random seed for reproducibility.

    Returns:
        ActiveEndowments: The same (mutated) endowment manager with new role
        assignments and weights.

    Raises:
        ValueError: If fewer than ``n_gt`` endowments belong to ``gt_modes``.
    """
    # Seeding a private generator avoids clobbering global RNG state for unrelated callers.
    rng = random.Random(seed) if seed is not None else random

    candidates = [
        e["eid"] for e in endowments.get_endowments()
        if e.get("mode") in gt_modes
    ]

    if len(candidates) < n_gt:
        raise ValueError(f"Only {len(candidates)} GT candidates found in specified modes, but n_gt = {n_gt}.")

    sampled_gt = rng.sample(candidates, k=n_gt)

    # Create role_map: set selected GT agents, all others default to 'proxy' in assign_roles
    role_map = {eid: "ground_truth" for eid in sampled_gt}

    # Assign roles using from_map
    endowments.assign_roles(method="from_map", role_map=role_map)

    # Assign normalized weights to GT agents
    endowments.initialize_ground_truth_weights()

    return endowments

def reload_original_responses_from_experiment(experiment):
    """
    Re-instantiate the survey and responses from the file paths recorded on an experiment.

    Reads ``experiment.survey.csv_path``, ``experiment.survey.config_path`` and
    ``experiment.responses.source_path``, builds a fresh ``Survey`` from the first
    two, then a fresh ``Responses`` (with ``output_format="answer"``) bound to
    that new survey.

    Args:
        experiment: Experiment whose ``survey`` and ``responses`` expose the
            source paths listed above.

    Returns:
        tuple: ``(responses, survey)`` — the newly constructed objects.
    """
    survey_files = {
        "csv_path": experiment.survey.csv_path,
        "config_path": experiment.survey.config_path,
    }
    responses_file = experiment.responses.source_path

    fresh_survey = Survey(**survey_files)
    fresh_responses = Responses(source=responses_file, survey=fresh_survey, output_format="answer")
    return fresh_responses, fresh_survey