import os
import sys
import yaml
import argparse
import subprocess
import shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Put the directory one level above this script on sys.path so that the
# project-local packages imported below can be resolved when the file is run
# directly as a script.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from io import StringIO
from datetime import datetime
from sklearn.linear_model import Lasso
from modules.survey_converter import Survey, BinaryExtendedSurvey
from modules.endowment_manager import ActiveEndowments
from modules.response_converter import BinaryExtendedResponses
from modules.aggregate_responses import AggregateResponses
from experiments.experiments import EmpiricalExperiment
from experiments.utils import (
    fit_post_selection_ols, 
    plot_model_predictions,
    summarize_sparse_model_selection, assign_model_weights_to_endowments,
    save_model, load_model,
    compute_prediction_metrics, compute_mode_summary
)
from utils.plotting import plot_lasso_diagnostics, plot_elasticnet_diagnostics
from models import fit_lasso_model, fit_elastic_net_model

def stringify_keys(d):
    """
    Return a copy of `d` whose keys are all plain strings.

    Tuple or list keys are flattened by joining their elements with "+".
    Each element is passed through `str` first, so keys containing
    non-string members (e.g. ints) are handled instead of raising a
    TypeError. Every other key is converted with `str`. Values are carried
    over unchanged.

    Parameters:
        d (dict): Mapping whose keys should be converted.

    Returns:
        dict: New dict with string keys, in the iteration order of `d`.
    """
    return {
        "+".join(map(str, k)) if isinstance(k, (tuple, list)) else str(k): v
        for k, v in d.items()
    }

def get_project_root():
    """Return the absolute path of the directory one level above this script."""
    script_dir = os.path.dirname(__file__)
    parent_dir = os.path.join(script_dir, os.pardir)
    return os.path.abspath(parent_dir)

def capture_stdout(func, *args, **kwargs):
    """
    Call `func(*args, **kwargs)` and return everything it printed to stdout.

    `sys.stdout` is temporarily replaced by an in-memory buffer. The stream
    that was active before the call is always put back afterwards, even if
    `func` raises, so an outer redirection (test runner capture, notebook,
    logging wrapper) is left intact.

    Parameters:
        func (callable): Function whose printed output should be captured.
        *args, **kwargs: Forwarded to `func` unchanged.

    Returns:
        str: Text written to stdout while `func` ran. The return value of
        `func` itself is discarded.

    Raises:
        Any exception raised by `func` propagates to the caller.
    """
    buf = StringIO()
    # Remember the current stream rather than assuming sys.__stdout__, which
    # would silently undo any redirection installed by the caller.
    previous_stdout = sys.stdout
    sys.stdout = buf
    try:
        func(*args, **kwargs)
    finally:
        sys.stdout = previous_stdout
    return buf.getvalue()

def convert_md_to_html(md_path, html_path, css):
    """
    Render a Markdown file to a standalone HTML page by invoking `pandoc`.

    Parameters:
        md_path (str): Markdown file to convert.
        html_path (str): Destination passed to pandoc's `-o` option.
        css (str): Stylesheet reference passed to pandoc's `--css` option.

    The pandoc exit status is not checked; a failed conversion does not
    raise here.
    """
    pandoc_cmd = ["pandoc", md_path, "-o", html_path, "--standalone", f"--css={css}"]
    subprocess.run(pandoc_cmd)

def convert_aggregate_metrics_to_df(all_results):
    """
    Build a one-row-per-model table of prediction metrics and hyperparameters.

    Parameters:
        all_results (list of dict): Result entries. Each must have a "name"
            key and may have "prediction_metrics", "mode_summary" and
            "hyperparams" dicts.

    Returns:
        pd.DataFrame: Columns are "model", every key of the entry's
        prediction metrics, "best_alpha", "best_l1_ratio" and "num_selected"
        (the sum of "num_selected" over all modes, or None when the entry
        has no mode summary).
    """
    def _as_row(entry):
        # One flat record per model; key order determines column order.
        label = entry["name"]
        scores = entry.get("prediction_metrics", {})
        per_mode = entry.get("mode_summary", {})
        params = entry.get("hyperparams", {})

        record = {"model": label, **scores}
        record["best_alpha"] = params.get("alpha")
        record["best_l1_ratio"] = params.get("l1_ratio")
        if per_mode:
            record["num_selected"] = sum(
                stats["num_selected"] for stats in per_mode.values()
            )
        else:
            record["num_selected"] = None
        return record

    return pd.DataFrame([_as_row(entry) for entry in all_results])

def convert_mode_summary_to_df(all_results):
    """
    Build a long-format table with one row per (model, mode) pair.

    Parameters:
        all_results (list of dict): Result entries. Each must have a "name"
            key; entries without a "mode_summary" dict contribute no rows.

    Returns:
        pd.DataFrame: Columns "model", "mode" (tuple modes joined with "+",
        anything else passed through `str`) and the keys of each mode's
        stats dict.
    """
    def _iter_rows():
        for entry in all_results:
            label = entry["name"]
            for mode_key, mode_stats in entry.get("mode_summary", {}).items():
                if isinstance(mode_key, tuple):
                    mode_label = "+".join(mode_key)
                else:
                    mode_label = str(mode_key)
                yield {"model": label, "mode": mode_label, **mode_stats}

    return pd.DataFrame(list(_iter_rows()))

def run_lasso_experiment(experiment, endowments, config, output_dir):
    """
    Fit the Lasso model, save its plots, and capture the selection summary.

    Parameters:
        experiment: Experiment object passed to the fitting/plotting helpers.
        endowments: Endowments object passed to the selection summary.
        config (dict): Parsed YAML config; must contain a "lasso" section.
        output_dir (str): Directory that receives the two PNG files.

    Returns:
        tuple: (model, best_alpha, summary_text, diag_path, pred_path), where
        summary_text is whatever `summarize_sparse_model_selection` printed.
    """
    model, best_alpha, diagnostics = fit_lasso_model(experiment, config)

    # Build both figures first, then write them out and release them.
    validation_cfg = config["lasso"].get("validation", {})
    fig_diag = plot_lasso_diagnostics(
        diagnostics,
        best_alpha,
        strategy=validation_cfg.get("strategy"),
    )
    fig_pred = plot_model_predictions(
        model, experiment, config, method_label='Lasso Regression'
    )

    diag_path = os.path.join(output_dir, "diagnostics_lasso.png")
    pred_path = os.path.join(output_dir, "predictions_lasso.png")
    figures = ((fig_diag, diag_path), (fig_pred, pred_path))
    for figure, target in figures:
        figure.savefig(target, transparent=True, bbox_inches='tight')
    for figure, _ in figures:
        plt.close(figure)

    # The summary helper prints its report; grab that text for the markdown.
    summary_text = capture_stdout(
        summarize_sparse_model_selection,
        model,
        experiment,
        endowments,
        verbose=True,
    )

    return model, best_alpha, summary_text, diag_path, pred_path


def run_elasticnet_experiment(experiment, endowments, config, output_dir):
    """
    Fit the Elastic Net model, save its plots, and capture the selection summary.

    Parameters:
        experiment: Experiment object passed to the fitting/plotting helpers.
        endowments: Endowments object passed to the selection summary.
        config (dict): Parsed YAML config; the optional
            "elasticnet.plot_style" entry selects the diagnostics plot style
            (default "2D").
        output_dir (str): Directory that receives the two PNG files.

    Returns:
        tuple: (model, (best_alpha, best_l1_ratio), summary_text, diag_path,
        pred_path), where summary_text is whatever
        `summarize_sparse_model_selection` printed.
    """
    model, best_alpha, best_l1_ratio, diagnostics = fit_elastic_net_model(experiment, config)

    # Build both figures first, then write them out and release them.
    plot_style = config.get("elasticnet", {}).get("plot_style", "2D")
    fig_diag = plot_elasticnet_diagnostics(
        diagnostics, best_alpha, best_l1_ratio, style=plot_style
    )
    fig_pred = plot_model_predictions(
        model, experiment, config, method_label='Elastic Net Regression'
    )

    diag_path = os.path.join(output_dir, "diagnostics_elasticnet.png")
    pred_path = os.path.join(output_dir, "predictions_elasticnet.png")
    figures = ((fig_diag, diag_path), (fig_pred, pred_path))
    for figure, target in figures:
        figure.savefig(target, transparent=True, bbox_inches='tight')
    for figure, _ in figures:
        plt.close(figure)

    # The summary helper prints its report; grab that text for the markdown.
    summary_text = capture_stdout(
        summarize_sparse_model_selection,
        model,
        experiment,
        endowments,
        verbose=True,
    )

    best_params = (best_alpha, best_l1_ratio)
    return model, best_params, summary_text, diag_path, pred_path


def run_post_selection_ols(model, experiment, config, output_dir, model_name="lasso"):
    """
    Refit OLS on the features chosen by a sparse model and save the prediction plot.

    The train and validation splits named in
    config["split_settings"]["train_val_split"] are concatenated and used
    together for the refit.

    Parameters:
        model: Fitted sparse model handed to `fit_post_selection_ols`.
        experiment: EmpiricalExperiment object providing the split dataframes.
        config: YAML config dict (needs "split_settings").
        output_dir: Directory that receives the prediction plot.
        model_name: Selection method; used in the plot title and file name.

    Returns:
        tuple: (ols_model, selected_features, plot_path)
    """
    split_cfg = config["split_settings"]
    target_col = "aggregate"

    # Pool the first two configured splits (train, then validation).
    split_frames = [
        experiment.get_dataframe_by_split(
            split_cfg["train_val_split"][position],
            proxy_only=split_cfg["use_proxy_only"],
        )
        for position in (0, 1)
    ]
    df_all = pd.concat(split_frames)

    # Every column except the target is a candidate feature.
    feature_cols = [name for name in df_all.columns if name != target_col]
    X_all = df_all.drop(columns=[target_col]).astype(float).values
    y_all = df_all[target_col].astype(float).values

    ols_model, selected_features = fit_post_selection_ols(model, X_all, y_all, feature_cols)

    # Plot the refit model's predictions and write the figure to disk.
    plot_title = f'Post-Selection OLS Refit ({model_name.upper()})'
    fig = plot_model_predictions(
        ols_model,
        experiment,
        config,
        selected_features=selected_features,
        method_label=plot_title,
    )
    plot_path = os.path.join(output_dir, f"predictions_post_selection_ols_{model_name}.png")
    fig.savefig(plot_path, transparent=True, bbox_inches='tight')
    plt.close(fig)

    return ols_model, selected_features, plot_path

def write_summary_md(
    output_dir,
    exp_name,
    exp_desc,
    config_path,
    config,
    timestamp,
    all_results,
    updated_endow_path,
    weight_assignment_key
):
    """
    Write a markdown summary report combining multiple regression results,
    then render it to HTML with pandoc.

    Parameters:
        output_dir (str): Existing directory that receives summary.md,
            summary.html and a copy of the stylesheet.
        exp_name (str): Experiment name.
        exp_desc (str): Description of the experiment.
        config_path (str): Path to the YAML config file used to run the experiment.
        config (dict): Parsed YAML configuration used in the experiment. The
            optional "report.css" entry names the stylesheet; when absent,
            "styles/lasso_report.css" (relative to the working directory) is used.
        timestamp (str): Timestamp for the run.
        all_results (list of dict): Each dict is read for these keys:
            - name (str, required): e.g. 'lasso', 'elasticnet', 'lasso_post_ols'
            - hyperparams (dict, optional): "alpha" and, for elastic net,
              "l1_ratio". Entries containing "selected_features" (the
              post-selection OLS refits) get no hyperparameter lines.
            - summary (str or None, optional): captured selection summary text
            - diagnostics (str or None, optional): path of the diagnostics plot
            - predictions (str or None, optional): path of the prediction plot
        updated_endow_path (str): Path to saved updated endowments.
        weight_assignment_key (str): Which model's weights were assigned.

    Raises:
        OSError: If the report cannot be written or the stylesheet cannot be
            copied (e.g. the CSS source file does not exist).
    """
    summary_md = os.path.join(output_dir, "summary.md")
    # Pandoc reads its input as UTF-8, so do not rely on the platform default
    # encoding when the captured summaries contain non-ASCII characters.
    with open(summary_md, "w", encoding="utf-8") as f:
        f.write("# Regression Experiment Report\n\n")
        f.write("## Basic Information\n\n")
        f.write(f"**Experiment Name**: `{exp_name}`  \n")
        f.write(f"**Description**: {exp_desc}  \n")
        f.write(f"**Config File**: `{config_path}`  \n")
        f.write(f"**Run Timestamp**: {timestamp}  \n")
        f.write(f"**Updated Endowments Saved To**: `{updated_endow_path}`  \n")
        f.write(f"**Weights Assigned From**: `{weight_assignment_key}` model  \n\n")

        f.write("## Model Diagnostics and Results\n\n")

        for result in all_results:
            name = result["name"]
            best_params = result.get("hyperparams")
            # Plots live next to summary.md, so reference them by file name only.
            diagnostics_path = (
                os.path.basename(result["diagnostics"]) if result.get("diagnostics") else None
            )
            prediction_path = (
                os.path.basename(result["predictions"]) if result.get("predictions") else None
            )
            summary_text = result.get("summary", "")

            f.write(f"### {name.upper()} Model\n\n")
            if best_params and 'selected_features' not in best_params:
                # Hyperparameters arrive as a dict, so decide what to print by
                # key presence. The previous tuple check never matched, which
                # silently dropped the elastic-net L1 ratio from the report.
                alpha = best_params.get("alpha")
                l1_ratio = best_params.get("l1_ratio")
                if alpha is not None:
                    f.write(f"**Best Alpha**: {alpha:.2e}  \n")
                if l1_ratio is not None:
                    f.write(f"**Best L1 Ratio**: {l1_ratio:.2f}  \n")
            f.write("\n")

            if diagnostics_path:
                f.write("**Diagnostics Plot**:\n\n")
                f.write(f'<img src="{diagnostics_path}" alt="Diagnostics Plot" style="max-width:80%; margin-bottom:1em;">\n\n')

            if prediction_path:
                f.write("**Prediction Plot**:\n\n")
                f.write(f'<img src="{prediction_path}" alt="Prediction Plot" style="max-width:80%; margin-bottom:1em;">\n\n')

            if summary_text and summary_text.strip():
                f.write("**Selection Summary**:\n\n")
                f.write("```text\n")
                f.write(summary_text.strip())
                f.write("\n```\n\n")

    # Copy the stylesheet next to the report (unless one is already there) so
    # the HTML can link to it by bare file name, then render the HTML.
    css_path = (
        (config.get("report") or {}).get("css") or
        "styles/lasso_report.css"
    )
    css_basename = os.path.basename(css_path)
    css_dest = os.path.join(output_dir, css_basename)
    if not os.path.exists(css_dest):
        shutil.copy(css_path, css_dest)

    # Swap only the file extension; str.replace would also rewrite any ".md"
    # that happens to appear in a directory name.
    html_path = os.path.splitext(summary_md)[0] + ".html"
    convert_md_to_html(summary_md, html_path, css=css_basename)

def run_regression_experiment(config_path, output_dir=None, return_all=False):
    """
    Run the regression experiment described by a YAML config end to end.

    Steps, in order:
        1. Load the config and create the output directory.
        2. Build the survey, responses, endowments, aggregate statistics and
           the EmpiricalExperiment from the files listed under config["paths"].
        3. If the config has a "lasso" section, fit Lasso (and, unless
           "post_selection_refit" is set false, a post-selection OLS refit)
           and save the Lasso model as lasso_model.joblib.
        4. Same for an "elasticnet" section (elasticnet_model.joblib).
        5. Assign weights to the endowments from the model named by
           config["weight_assignment_model"] and save updated_endowments.csv.
        6. Write metadata.yaml, summary.md (plus HTML), summary_aggregate.csv
           and summary_by_mode.csv into the output directory.

    Parameters:
        config_path (str): Path to the YAML config. Keys read here:
            experiment.name, experiment.description, paths.survey_csv,
            paths.survey_yaml, paths.responses_csv, paths.endowments_csv,
            paths.aggregate_json, and optionally lasso, elasticnet and
            weight_assignment_model.
        output_dir (str or None): Destination directory. Defaults to
            "outputs/<experiment name>_<timestamp>" relative to the working
            directory.
        return_all (bool): When true, return the experiment and result tables.

    Returns:
        tuple or None: (experiment, df_agg, df_mode, all_results) when
        return_all is true; otherwise None.

    Raises:
        ValueError: If no result is named like the weight-assignment model.
    """
    # Load config
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # Prepare output directory (timestamped by default so runs do not collide)
    exp_name = config["experiment"]["name"]
    exp_desc = config["experiment"]["description"]
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if output_dir is None:
        output_dir = os.path.join("outputs", f"{exp_name}_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    # Load components
    paths = config["paths"]
    survey = BinaryExtendedSurvey(csv_path=paths['survey_csv'], config_path=paths['survey_yaml'])
    responses = BinaryExtendedResponses(source=paths['responses_csv'], survey=survey, output_format='code')
    endowments = ActiveEndowments.load(path=paths['endowments_csv'])
    endowments.assign_roles()
    aggregate = AggregateResponses(survey=survey, json_path=paths['aggregate_json'])
    experiment = EmpiricalExperiment(
        responses=responses,
        survey=survey,
        endowments=endowments,
        aggregate_stats=aggregate.get_all_binary(),
        filter_binary=True,
        drop_na=True
    )

    # One dict per fitted model (sparse fits and their OLS refits), in run order
    all_results = []
    # Maps model name -> saved .joblib path; stored in metadata.yaml below
    model_paths = {}

    # Run Lasso
    if "lasso" in config:
        lasso_model, alpha, summary, diag_path, pred_path = run_lasso_experiment(experiment, endowments, config, output_dir)
        # NOTE(review): weights are assigned before compute_mode_summary is
        # called with the same endowments object; presumably the summary
        # reads those weights -- confirm before reordering.
        assign_model_weights_to_endowments(lasso_model, experiment, endowments)
        metrics = compute_prediction_metrics(lasso_model, experiment, config)
        mode_summary = compute_mode_summary(lasso_model, endowments)
        all_results.append({
            "name": "lasso",
            "model": lasso_model,
            "hyperparams": {"alpha": alpha},
            "summary": summary,
            "diagnostics": diag_path,
            "predictions": pred_path,
            "prediction_metrics": metrics,
            "mode_summary": stringify_keys(mode_summary),
        })
        # The OLS refit is on by default; set post_selection_refit: false to skip
        if config["lasso"].get("post_selection_refit", True):
            ols_model, selected_features, ols_pred_path = run_post_selection_ols(lasso_model, experiment, config, output_dir, model_name='lasso')
            all_results.append({
                "name": "lasso_post_ols",
                "model": ols_model,
                "hyperparams": {"selected_features": selected_features},
                "summary": None,
                "diagnostics": None,
                "predictions": ols_pred_path
            })
        lasso_save_path = os.path.join(output_dir, "lasso_model.joblib") 
        save_model(lasso_model, lasso_save_path)
        model_paths["lasso"] = lasso_save_path

    # Run ElasticNet (same sequence as the Lasso branch above)
    if "elasticnet" in config:
        enet_model, (alpha, l1_ratio), summary, diag_path, pred_path = run_elasticnet_experiment(experiment, endowments, config, output_dir)
        assign_model_weights_to_endowments(enet_model, experiment, endowments)
        metrics = compute_prediction_metrics(enet_model, experiment, config)
        mode_summary = compute_mode_summary(enet_model, endowments)
        all_results.append({
            "name": "elasticnet",
            "model": enet_model,
            "hyperparams": {"alpha": alpha, "l1_ratio": l1_ratio},
            "summary": summary,
            "diagnostics": diag_path,
            "predictions": pred_path,
            "prediction_metrics": metrics,
            "mode_summary": stringify_keys(mode_summary),
        })
        if config["elasticnet"].get("post_selection_refit", True):
            ols_model, selected_features, ols_pred_path = run_post_selection_ols(enet_model, experiment, config, output_dir, model_name='elasticnet')
            all_results.append({
                "name": "elasticnet_post_ols",
                "model": ols_model,
                "hyperparams": {"selected_features": selected_features},
                "summary": None,
                "diagnostics": None,
                "predictions": ols_pred_path
            })
        enet_save_path = os.path.join(output_dir, "elasticnet_model.joblib")
        save_model(enet_model, enet_save_path)
        model_paths["elasticnet"] = enet_save_path

    # Save final endowments
    # Assign weights from the specified model. The key must match a result
    # "name" ("lasso", "lasso_post_ols", "elasticnet", "elasticnet_post_ols").
    # NOTE(review): with the "lasso" default, a config that only has an
    # "elasticnet" section reaches the ValueError below after all fitting.
    weight_assignment_key = config.get("weight_assignment_model", "lasso")  # default to lasso
    selected_result = next((r for r in all_results if r["name"] == weight_assignment_key), None)

    if selected_result is None:
        raise ValueError(f"No model named '{weight_assignment_key}' found in experiment results for weight assignment.")
    
    # Only the post-selection OLS results carry "selected_features"; for the
    # sparse models this is None.
    selected_features = selected_result.get("hyperparams", {}).get("selected_features")

    assign_model_weights_to_endowments(selected_result["model"], experiment, endowments, selected_features=selected_features)

    updated_endow_path = os.path.join(output_dir, "updated_endowments.csv")
    endowments.save(updated_endow_path)

    # Save full config, extended in place with the saved model locations
    config["model_paths"] = model_paths
    with open(os.path.join(output_dir, "metadata.yaml"), "w") as f:
        yaml.dump(config, f)

    # Write summary markdown + HTML
    write_summary_md(
        output_dir=output_dir,
        exp_name=exp_name,
        exp_desc=exp_desc,
        config_path=config_path,
        config=config,
        timestamp=timestamp,
        all_results=all_results,
        updated_endow_path=updated_endow_path,
        weight_assignment_key=weight_assignment_key
    )

    # Save aggregate and mode summary tables
    df_agg = convert_aggregate_metrics_to_df(all_results)
    df_mode = convert_mode_summary_to_df(all_results)

    df_agg.to_csv(os.path.join(output_dir, "summary_aggregate.csv"), index=False)
    df_mode.to_csv(os.path.join(output_dir, "summary_by_mode.csv"), index=False)

    # Without return_all the function falls through and returns None
    if return_all:
        return experiment, df_agg, df_mode, all_results

if __name__ == "__main__":
    # Command-line entry point: --config is mandatory, --output is optional
    # and falls back to the timestamped default directory when omitted.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True, help="Path to YAML config file")
    cli.add_argument("--output", type=str, default=None, help="Optional output directory")
    options = cli.parse_args()
    run_regression_experiment(config_path=options.config, output_dir=options.output)