import os
import sys
import yaml
import argparse
import shutil
import subprocess
import pandas as pd
from datetime import datetime

# Put the repository root (the parent of this script's directory) at the front
# of sys.path so the project-local packages imported below resolve even when
# this file is executed directly as a script. Skipped if already present.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Project-local imports (resolvable because project_root was added to sys.path)
from modules.survey_converter import Survey, BinaryExtendedSurvey
from modules.response_converter import Responses, BinaryExtendedResponses, ResponseUtils
from modules.endowment_manager import ActiveEndowments
from modules.aggregate_responses import AggregateResponses
from experiments.experiments import EmpiricalExperiment
from experiments.fewshot_utils import *


def load_config(path):
    """Load a YAML configuration file.

    The file is decoded as UTF-8 explicitly so that loading does not depend on
    the platform's locale encoding. ``yaml.safe_load`` is used, so only plain
    YAML data types are constructed (no arbitrary Python objects).

    Args:
        path: Path to the YAML file.

    Returns:
        The parsed document (typically a dict); ``None`` for an empty file.
    """
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def convert_md_to_html(md_path, html_path, css):
    """Render a Markdown file to a standalone HTML page with pandoc.

    This is a best-effort step: the HTML report is a convenience artifact
    produced after all results are saved, so a missing pandoc binary or a
    failed conversion is reported on stderr instead of raising.

    Args:
        md_path: Input Markdown file.
        html_path: Output HTML file.
        css: Stylesheet reference passed to pandoc's ``--css`` option. Pandoc
            links it from the HTML, so it should be relative to the HTML
            file's location (or absolute).

    Returns:
        True if pandoc ran and exited with status 0, False otherwise.
    """
    cmd = [
        "pandoc", md_path,
        "-o", html_path,
        "--standalone",
        f"--css={css}",
    ]
    try:
        result = subprocess.run(cmd)
    except OSError as err:
        # e.g. FileNotFoundError when pandoc is not installed / not on PATH.
        print(f"warning: could not run pandoc ({err}); skipping HTML report",
              file=sys.stderr)
        return False
    if result.returncode != 0:
        print(f"warning: pandoc exited with status {result.returncode}; "
              f"{html_path} may be missing or incomplete", file=sys.stderr)
        return False
    return True


def write_summary_md(output_dir, exp_name, exp_desc, config_path, timestamp,
                     metrics, plot_path=None):
    """Write the Markdown summary report for a simple-average run.

    The report is written as UTF-8, which is also what pandoc expects when the
    summary is later converted to HTML.

    Args:
        output_dir: Existing directory the report is written into.
        exp_name: Experiment name (rendered as inline code).
        exp_desc: Free-text experiment description.
        config_path: Path of the config file used for the run.
        timestamp: Pre-formatted run timestamp string.
        metrics: Mapping of metric name -> value. Values that accept the
            ``.4f`` format spec are shown with four decimals; anything else
            (e.g. ``None`` or a string) is shown via ``str()``.
        plot_path: Optional plot image. It is embedded only if the file
            exists, and the ``<img>`` tag references its basename, so the
            plot is expected to live in ``output_dir``.

    Returns:
        Path of the written ``summary.md``.
    """
    md_path = os.path.join(output_dir, "summary.md")
    with open(md_path, "w", encoding="utf-8") as f:
        f.write("# Simple Average Experiment Report\n\n")
        f.write("## Basic Information\n\n")
        f.write(f"**Experiment Name**: `{exp_name}`  \n")
        f.write(f"**Description**: {exp_desc}  \n")
        f.write(f"**Config File**: `{config_path}`  \n")
        f.write(f"**Run Timestamp**: {timestamp}  \n\n")

        f.write("## Metrics\n\n")
        for name, value in metrics.items():
            try:
                rendered = f"{value:.4f}"
            except (TypeError, ValueError):
                # Non-numeric metric: don't abort the whole report over it.
                rendered = str(value)
            f.write(f"- {name}: {rendered}\n")

        if plot_path and os.path.exists(plot_path):
            f.write("\n## Test Performance Plot\n\n")
            f.write(f'<img src="{os.path.basename(plot_path)}" alt="Simple Average Performance" '
                    f'style="max-width:80%; margin-bottom:1em;">\n\n')
    return md_path


def run_simple_average_experiment(config_path, output_dir=None, return_outputs=False):
    """Run the simple-average baseline end to end and write its report.

    Loads the experiment config and data components, aggregates the raw
    responses into predicted answer distributions, scores them against the
    ground-truth aggregates, and writes the following into ``output_dir``:
    - a copy of the config
    - the metrics CSV
    - a performance plot
    - a Markdown summary
    - an HTML summary, when the configured stylesheet exists

    Args:
        config_path: Path to the YAML config. It must provide
            ``experiment.name``, ``experiment.description`` and the ``paths``
            entries read below; ``report.css`` is optional.
        output_dir: Directory for all artifacts. Defaults to
            ``outputs/<name>_simple_avg_<YYYYmmdd_HHMMSS>`` (relative to the
            current working directory). Created if missing.
        return_outputs: If True, return a dict of the main results.

    Returns:
        ``None`` unless ``return_outputs`` is True, in which case a dict with
        keys ``experiment``, ``metrics``, ``metrics_csv`` and ``plot``.
    """
    # === Load config
    cfg = load_config(config_path)
    exp_name = cfg["experiment"]["name"]
    exp_desc = cfg["experiment"]["description"]
    paths = cfg["paths"]

    if output_dir is None:
        dir_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join("outputs", f"{exp_name}_simple_avg_{dir_stamp}")
    os.makedirs(output_dir, exist_ok=True)

    # === Copy config next to the results so the run can be reproduced
    shutil.copy(config_path, os.path.join(output_dir, "original_config.yaml"))

    # === Load components
    # The same survey/response files are loaded twice:
    # - as binary-extended "code" data, for the experiment and the ground truth;
    # - as raw "answer" data, for the simple-average aggregation in Step 2.
    survey = BinaryExtendedSurvey(csv_path=paths["survey_csv"], config_path=paths["survey_yaml"])
    survey_raw = Survey(csv_path=paths["survey_csv"], config_path=paths["survey_yaml"])
    responses_raw = Responses(source=paths["responses_csv"], survey=survey_raw, output_format="answer")
    responses = BinaryExtendedResponses(source=paths["responses_csv"], survey=survey, output_format="code")
    endowments = ActiveEndowments.load(path=paths["endowments_csv"])
    endowments.assign_roles()
    truth = AggregateResponses(json_path=paths["aggregate_json"], survey=survey)

    experiment = EmpiricalExperiment(
        responses=responses,
        survey=survey,
        endowments=endowments,
        aggregate_stats=truth.get_all_binary(),
        filter_binary=True,
        drop_na=True,  # ensures NA agents are dropped consistently
    )

    # === Step 1: Determine valid agents (respecting drop_na + proxy role)
    df = experiment.get_augmented_matrix(proxy_only=True)
    valid_eids = [eid for eid in df.columns if eid != "aggregate"]
    # NOTE(review): valid_eids is not used below. Step 2 aggregates all of
    # responses_raw, so the drop_na / proxy-only filtering above does not
    # restrict the prediction. Applying it requires subsetting responses_raw
    # (or a weights argument on the aggregator). Confirm the intended behaviour.

    # === Step 2: Aggregate raw responses into answer distributions.
    # The helper receives an output path, so presumably it also writes the
    # distributions to agent_agg_results.json. Confirm in ResponseUtils.
    agg_pred = ResponseUtils.save_answer_distributions(
        responses=responses_raw,
        path=os.path.join(output_dir, "agent_agg_results.json"),
    )

    # === Step 3: Wrap predictions in AggregateResponses
    pred = AggregateResponses(aggregate_dict=agg_pred, survey=survey)

    # === Step 4: Compute metrics
    metrics = evaluate_fewshot_predictions(pred, truth, cfg)

    metrics_csv = os.path.join(output_dir, "simple_average_metrics.csv")
    pd.DataFrame([metrics]).to_csv(metrics_csv, index=False)

    # === Step 5: Plot predictions vs truth
    plot_path = os.path.join(output_dir, "simple_average_performance.png")
    fig = plot_fewshot_test_performance(pred, truth, save_path=None, method_label="Simple Average")
    fig.savefig(plot_path, transparent=True, bbox_inches="tight")
    # NOTE(review): clf() empties the figure but does not close it. If the
    # helper creates figures through pyplot, matplotlib.pyplot.close(fig)
    # would also release them when this function is called repeatedly.
    fig.clf()

    # === Step 6: Write summary report
    run_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    summary_md = write_summary_md(
        output_dir, exp_name, exp_desc, config_path, run_timestamp, metrics, plot_path
    )

    # HTML rendering is optional and only happens when the stylesheet exists.
    # `or {}` also covers a `report:` key that is present but empty (None).
    css_path = (cfg.get("report") or {}).get("css", "styles/lasso_report.css")
    if css_path and os.path.exists(css_path):
        css_basename = os.path.basename(css_path)
        # Copy the stylesheet next to the HTML so the relative --css link resolves.
        shutil.copy(css_path, os.path.join(output_dir, css_basename))
        # splitext swaps only the final extension. str.replace(".md", ...) would
        # also rewrite any ".md" appearing in the directory part of the path.
        summary_html = os.path.splitext(summary_md)[0] + ".html"
        convert_md_to_html(summary_md, summary_html, css=css_basename)

    if return_outputs:
        return {"experiment": experiment, "metrics": metrics, "metrics_csv": metrics_csv, "plot": plot_path}


if __name__ == "__main__":
    # Command-line entry point: parse --config / --output, then run once.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True, help="Path to YAML config file")
    cli.add_argument("--output", type=str, default=None, help="Optional output directory")
    cli_args = cli.parse_args()

    run_simple_average_experiment(
        config_path=cli_args.config,
        output_dir=cli_args.output,
    )