import os
import sys
import yaml
import argparse
import subprocess
import shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Put the parent of this script's directory at the front of sys.path so the
# project-local packages imported below (``modules``, ``experiments``) resolve
# when the script is executed directly. Skipped if the entry is already present.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if all(entry != project_root for entry in sys.path):
    sys.path.insert(0, project_root)

from io import StringIO
from datetime import datetime
from sklearn.linear_model import Lasso
from modules.survey_converter import Survey, BinaryExtendedSurvey
from modules.endowment_manager import ActiveEndowments
from modules.response_converter import BinaryExtendedResponses
from modules.aggregate_responses import AggregateResponses
from experiments.experiments import EmpiricalExperiment
from experiments.utils import (
    fit_lasso_model, plot_lasso_diagnostics, plot_model_predictions,
    summarize_sparse_model_selection, assign_model_weights_to_endowments
)

def get_project_root():
    """Return the absolute path of the directory one level above this script.

    Config-relative paths elsewhere in this script are resolved against it.
    """
    script_dir = os.path.dirname(__file__)
    parent_dir = os.path.join(script_dir, os.pardir)
    return os.path.abspath(parent_dir)

def capture_stdout(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and return everything it printed to stdout.

    ``sys.stdout`` is redirected only for the duration of the call. Afterwards it
    is restored to whatever stream was active beforehand, even if ``func``
    raises. Outer redirections (test-runner capture, notebooks, nested calls)
    therefore keep working. The return value of ``func`` is discarded.

    Args:
        func: Callable whose printed output should be captured.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The captured stdout text as a single string.

    Raises:
        Whatever ``func`` raises; stdout is restored before it propagates.
    """
    from contextlib import redirect_stdout

    buf = StringIO()
    # redirect_stdout restores the *previous* sys.stdout on exit (normal or
    # exceptional), unlike hard-resetting to sys.__stdout__.
    with redirect_stdout(buf):
        func(*args, **kwargs)
    return buf.getvalue()

def convert_md_to_html(md_path, html_path, css, *, check=False):
    """Render a Markdown file to a standalone HTML page with pandoc.

    Args:
        md_path: Path to the Markdown source file.
        html_path: Destination path for the generated HTML.
        css: Stylesheet reference passed to pandoc's ``--css`` option. It is
            written into the HTML as a link, so a relative value is resolved
            by the browser relative to ``html_path``.
        check: If True, raise ``subprocess.CalledProcessError`` when pandoc
            exits with a non-zero status. If False (the default, which keeps
            the previous non-raising behaviour), a warning is printed to
            stderr instead so the failure is not silent.

    Returns:
        The ``subprocess.CompletedProcess`` for the pandoc invocation.

    Raises:
        OSError (typically FileNotFoundError): If the ``pandoc`` executable
            cannot be launched.
        subprocess.CalledProcessError: If ``check`` is True and pandoc fails.
    """
    result = subprocess.run(
        ["pandoc", md_path, "-o", html_path, "--standalone", f"--css={css}"],
        check=check,
    )
    # Only reachable with a non-zero code when check=False; surface it rather
    # than leaving the caller to discover a missing/stale HTML file later.
    if result.returncode != 0:
        print(
            f"Warning: pandoc exited with status {result.returncode}; "
            f"{html_path} may be missing or incomplete.",
            file=sys.stderr,
        )
    return result

def _resolve_css_source(css_path, root_dir):
    """Locate the report stylesheet that should be copied into the output dir.

    The path is first tried as given: absolute, or relative to the current
    working directory, as before. If that does not exist, it is retried
    relative to ``root_dir``, the same base used for ``config["paths"]``.
    If neither exists, the configured path is returned unchanged so the
    subsequent copy fails with an error naming it.
    """
    if os.path.isabs(css_path) or os.path.exists(css_path):
        return css_path
    candidate = os.path.join(root_dir, css_path)
    return candidate if os.path.exists(candidate) else css_path


def _write_markdown_report(
    summary_md,
    *,
    exp_name,
    exp_desc,
    config_path,
    timestamp,
    best_alpha,
    updated_endow_path,
    diagnostics_plot,
    predictions_plot,
    summary_text,
):
    """Write the Markdown report for one Lasso run to ``summary_md``.

    Plot images are referenced by basename, so the report must sit in the same
    directory as the PNG files for the image links to resolve.
    """
    # Without a trailing newline the closing fence would be appended to the
    # last summary line, leaving the code block unterminated.
    if summary_text and not summary_text.endswith("\n"):
        summary_text += "\n"

    img_style = "max-width:80%; margin-bottom:1em;"
    parts = [
        "# Lasso Experiment Report\n\n",
        "## Basic Information\n\n",
        f"**Experiment Name**: `{exp_name}`  \n",
        f"**Description**: {exp_desc}  \n",
        f"**Config File**: `{config_path}`  \n",
        f"**Run Timestamp**: {timestamp}  \n",
        f"**Best Alpha**: {best_alpha:.2e}  \n",
        f"**Updated Endowments Saved to**: `{updated_endow_path}`\n\n",
        "## Exploratory Data Analysis\n\n",
        "### Diagnostics Plot\n\n",
        f'<img src="{os.path.basename(diagnostics_plot)}" alt="Diagnostics Plot" style="{img_style}">\n\n',
        "### Predictions Plot\n\n",
        f'<img src="{os.path.basename(predictions_plot)}" alt="Predictions Plot" style="{img_style}">\n\n',
        "## Lasso Selection Summary\n\n",
        "```text\n",
        summary_text,
        "```\n",
    ]
    # pandoc reads its input as UTF-8, so don't rely on the platform default.
    with open(summary_md, "w", encoding="utf-8") as f:
        f.write("".join(parts))


def _export_html_report(summary_md, output_dir, config, root_dir):
    """Render ``summary_md`` to HTML beside it, bundling its stylesheet.

    The stylesheet comes from the first of these that is set:
    ``config["report"]["lasso"]``, ``config["report"]["css"]``, or
    ``styles/lasso_report.css``. It is copied into ``output_dir`` unless a
    file of that name is already there. The HTML links it by basename, so the
    HTML and CSS can be moved together.
    """
    # A bare ``report:`` key parses as None, so guard before calling .get().
    report_cfg = config.get("report") or {}
    css_path = (
        report_cfg.get("lasso")
        or report_cfg.get("css")
        or "styles/lasso_report.css"
    )
    css_basename = os.path.basename(css_path)
    css_dest = os.path.join(output_dir, css_basename)
    if not os.path.exists(css_dest):
        shutil.copy(_resolve_css_source(css_path, root_dir), css_dest)

    # Swap only the extension; str.replace(".md", ...) would also rewrite any
    # ".md" that appears earlier in the directory path.
    html_path = os.path.splitext(summary_md)[0] + ".html"
    convert_md_to_html(summary_md, html_path, css=css_basename)


def run_lasso_experiment(config_path, output_dir=None):
    """Run the configured Lasso experiment and write all artifacts to disk.

    Steps:
        1. Load the YAML config.
        2. Build the survey, response, endowment and aggregate objects.
        3. Fit the Lasso model.
        4. Save the diagnostics and prediction plots.
        5. Capture the model-selection summary.
        6. Assign model weights to the endowments and save them.
        7. Snapshot the config.
        8. Write a Markdown report plus a pandoc-rendered HTML copy.

    Args:
        config_path: Path to the experiment YAML. Entries under its ``paths``
            mapping are resolved relative to the project root, not the
            current working directory.
        output_dir: Directory for all artifacts. Defaults to
            ``outputs/<experiment name>_<timestamp>``, relative to the current
            working directory. Created if missing.
    """
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # Resolve data paths against the project root so the run does not depend
    # on the working directory it was launched from.
    root_dir = get_project_root()
    paths = {
        key: os.path.abspath(os.path.join(root_dir, rel_path))
        for key, rel_path in config["paths"].items()
    }

    exp_name = config["experiment"]["name"]
    # ``description:`` with no value parses as None; render it as empty text.
    exp_desc = config["experiment"].get("description") or ""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    if output_dir is None:
        # NOTE: relative to the current working directory, unlike ``paths``.
        output_dir = os.path.join("outputs", f"{exp_name}_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    # Load components
    survey = BinaryExtendedSurvey(csv_path=paths['survey_csv'], config_path=paths['survey_yaml'])
    responses = BinaryExtendedResponses(source=paths['responses_csv'], survey=survey, output_format='code')
    endowments = ActiveEndowments.load(path=paths['endowments_csv'])
    endowments.assign_roles()
    aggregate = AggregateResponses(survey=survey, json_path=paths['aggregate_json'])
    aggregate_dict = aggregate.get_all_binary()

    experiment = EmpiricalExperiment(
        responses=responses,
        survey=survey,
        endowments=endowments,
        aggregate_stats=aggregate_dict,
        filter_binary=True,
        drop_na=True
    )

    # Run Lasso
    model, best_alpha, diagnostics = fit_lasso_model(experiment, config)

    # Save plots
    diagnostics_plot = os.path.join(output_dir, "diagnostics.png")
    predictions_plot = os.path.join(output_dir, "predictions.png")

    # NOTE(review): plt.savefig writes the *current* figure; this relies on
    # each plotting helper leaving its figure current. TODO confirm.
    fig1 = plot_lasso_diagnostics(diagnostics, best_alpha, strategy=config["validation"]["strategy"])
    plt.savefig(diagnostics_plot, transparent=True, bbox_inches='tight')
    plt.close(fig1)

    fig2 = plot_model_predictions(model, experiment, config)
    plt.savefig(predictions_plot, transparent=True, bbox_inches='tight')
    plt.close(fig2)

    # The summary is captured before assign_model_weights_to_endowments runs,
    # preserving the original call order.
    summary_text = capture_stdout(
        summarize_sparse_model_selection, model, experiment, endowments, verbose=True
    )
    assign_model_weights_to_endowments(model, experiment, endowments)

    # Save updated endowments
    updated_endow_path = os.path.join(output_dir, "updated_endowments.csv")
    endowments.save(updated_endow_path)

    # Save a full config snapshot next to the outputs for reproducibility.
    with open(os.path.join(output_dir, "metadata.yaml"), "w") as f:
        yaml.dump(config, f)

    # Markdown report, then an HTML rendering of it.
    summary_md = os.path.join(output_dir, "summary_lasso.md")
    _write_markdown_report(
        summary_md,
        exp_name=exp_name,
        exp_desc=exp_desc,
        config_path=config_path,
        timestamp=timestamp,
        best_alpha=best_alpha,
        updated_endow_path=updated_endow_path,
        diagnostics_plot=diagnostics_plot,
        predictions_plot=predictions_plot,
        summary_text=summary_text,
    )
    _export_html_report(summary_md, output_dir, config, root_dir)

if __name__ == "__main__":
    # Command-line entry point: run a single experiment from a YAML config.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True, help="Path to YAML config file")
    cli.add_argument("--output", type=str, default=None, help="Optional output directory")
    cli_args = cli.parse_args()
    run_lasso_experiment(cli_args.config, output_dir=cli_args.output)