import os
import sys
import yaml
import argparse
import shutil
import subprocess
import logging
import json
import pandas as pd
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from itertools import islice

# Make the repository root (the parent of this script's directory) importable
# so the `modules.*` and `generators.*` imports below resolve when this file is
# executed directly as a script. Appended (not prepended), so existing
# sys.path entries keep precedence.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
if project_root not in sys.path:
    sys.path.append(project_root)

from modules.survey_converter import Survey
from modules.response_converter import Responses
from modules.endowment_manager import ActiveEndowments
from modules.dataclasses import AgentConfig
from modules.token_tracker import MultiAxisTokenTracker
from modules.survey_conductor import SurveyConductor
from generators.theme_variability_tracker import ThemeVariabilityTracker

def load_config(path: str) -> dict:
    """Load a YAML run configuration from ``path``.

    The file is opened in binary mode so PyYAML detects the encoding itself
    (UTF-8, or UTF-16 via BOM) instead of depending on the platform's locale
    encoding. ``yaml.safe_load`` only constructs plain YAML types.

    Returns:
        The parsed document; an empty file yields ``{}`` rather than ``None``.
    """
    with open(path, "rb") as f:
        data = yaml.safe_load(f)
    return {} if data is None else data


def convert_md_to_html(md_path: str, html_path: str, css: str, pandoc_cmd: str = "pandoc") -> bool:
    """
    Mirror AEG runner behavior: render ``md_path`` to standalone HTML with pandoc.

    The stylesheet is passed via ``--css``, so ``css`` should be a path/URL
    that resolves relative to ``html_path``.

    Conversion is best-effort: a missing pandoc executable or a non-zero exit
    status is logged as a warning rather than raised. A finished elicitation
    run is therefore never lost just because report rendering failed.

    Args:
        md_path: Source Markdown file.
        html_path: Destination HTML file.
        css: Stylesheet reference written into the HTML.
        pandoc_cmd: Executable to invoke (defaults to ``pandoc`` on PATH).

    Returns:
        True if pandoc ran and exited with status 0, False otherwise.
    """
    log = logging.getLogger(__name__)
    cmd = [pandoc_cmd, md_path, "-o", html_path, "--standalone", f"--css={css}"]
    try:
        completed = subprocess.run(cmd, check=False)
    except OSError as err:  # FileNotFoundError when pandoc is not installed
        log.warning("HTML conversion skipped (could not run %r): %s", pandoc_cmd, err)
        return False
    if completed.returncode != 0:
        log.warning("pandoc exited with status %d while converting %s", completed.returncode, md_path)
        return False
    return True

def write_summary_plugin_md(
    output_dir: str,
    meta: dict,
    config_path: str,
    endowments_csv: str,
    responses_csv: str,
    entropy_bar_path: str,
    token_tracker=None,
    # Optional extras (include sections only if provided & exist)
    agg_entropy_csv: str = None,
    question_entropy_csv: str = None,
    mca_plot_path: str = None,
    spectrum_plot_path: str = None,
    token_usage_csv: str = None,
    token_usage_json: str = None,
    tracker=None,
) -> str:
    """
    Create a concise Parallel Elicitation report in Markdown.

    Artifact links are written relative to ``output_dir`` (where the report
    is saved), so images and files stored outside that directory still
    resolve. Missing or non-numeric values coming from the trackers are
    rendered as text (scores/entropies) or 0 (token counts) instead of
    aborting report generation.

    Sections included:
      - Metadata
      - Output Files (only those that exist)
      - Diagnostic Plots:
          * Effective Rank Analysis (if spectrum_plot_path exists)
          * Multiple Correspondence Analysis (if mca_plot_path exists)
          * Entropy Bar Chart (a placeholder note is written if it is missing)
      - Diagnostics Data (if tracker provided):
          * Top Variable Modes (up to 5, from tracker.get_scores())
          * Low Entropy Questions (up to 5, from
            tracker.compute_global_question_entropy())
      - Token Usage Summary (by Model and Module, from token_tracker.summary())
        NOTE: In the plugin pipeline, usage is logged by SurveyConductor only.

    Returns:
        str: Path to the generated Markdown file.
    """
    os.makedirs(output_dir, exist_ok=True)
    summary_path = os.path.join(output_dir, "summary_plugin_elicitation.md")

    def _exists(p):
        return isinstance(p, str) and os.path.exists(p)

    def _link(p):
        # Relative to the report's directory so the link resolves from the
        # rendered Markdown/HTML; forward slashes keep it a valid URL on Windows.
        try:
            rel = os.path.relpath(p, output_dir)
        except ValueError:  # Windows: artifact on a different drive
            rel = os.path.abspath(p)
        return rel.replace(os.sep, "/")

    def _num(value):
        # 4-decimal float, or the raw text for None / non-numeric values.
        try:
            return f"{float(value):.4f}"
        except (TypeError, ValueError):
            return str(value)

    def _count(value):
        # Token counts may be missing/None; report those as 0.
        try:
            return int(value or 0)
        except (TypeError, ValueError):
            return 0

    def _image(f, title, path, alt):
        f.write(f"### {title}\n\n")
        f.write(
            f'<img src="{_link(path)}" alt="{alt}" '
            'style="max-width:80%; margin-bottom:1em;">\n\n'
        )

    with open(summary_path, "w", encoding="utf-8") as f:
        # Title
        f.write("# Parallel Elicitation Report\n\n")

        # Metadata
        f.write("## Metadata\n")
        f.write(f"**Name**: `{meta.get('name', 'N/A')}`  \n")
        f.write(f"**Description**: {meta.get('description', 'N/A')}  \n")
        f.write(f"**Seed**: `{meta.get('seed', 'N/A')}`  \n")
        f.write(f"**Config File**: `{config_path}`  \n")
        f.write(f"**Generated At**: `{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}`  \n\n")

        # Output Files (only those that actually exist on disk)
        f.write("## Output Files\n")
        artifacts = [
            ("Endowments", endowments_csv),
            ("Responses", responses_csv),
            ("Aggregate Entropy Trajectory (CSV)", agg_entropy_csv),
            ("Question Entropy Trajectories (CSV)", question_entropy_csv),
            ("Token Usage CSV", token_usage_csv),
            ("Token Usage JSON", token_usage_json),
        ]
        for label, path in artifacts:
            if _exists(path):
                f.write(f"- {label}: `{_link(path)}`  \n")
        f.write("\n")

        # Diagnostic Plots
        f.write("## Diagnostic Plots\n\n")
        if _exists(spectrum_plot_path):
            _image(f, "Effective Rank Analysis", spectrum_plot_path, "Spectrum Plot")
        if _exists(mca_plot_path):
            _image(f, "Multiple Correspondence Analysis", mca_plot_path, "MCA Plot")
        # Entropy bar chart is expected for plugin runs; note its absence explicitly.
        if _exists(entropy_bar_path):
            _image(f, "Entropy Bar Chart", entropy_bar_path, "Entropy Chart")
        else:
            f.write("_[Entropy bar chart not found – expected image path was not provided or does not exist.]_\n\n")

        # Diagnostics Data (optional; best-effort reads from the tracker)
        if tracker is not None:
            f.write("## Diagnostics Data\n\n")

            # Top Variable Modes (plugin runs only have the synthetic 'plugin' mode)
            try:
                top_modes = sorted(tracker.get_scores().items(), key=lambda x: x[1], reverse=True)[:5]
            except Exception:
                top_modes = []

            f.write("### Top Variable Modes\n\n")
            if top_modes:
                f.write("| Mode | Score |\n")
                f.write("|------|-------:|\n")
                for mode, score in top_modes:
                    mode_str = "+".join(mode) if isinstance(mode, (tuple, list)) else str(mode)
                    f.write(f"| `{mode_str}` | `{_num(score)}` |\n")
            else:
                f.write("_No mode variability data available._\n")
            f.write("\n")

            # Low Entropy Questions (sorting is inside the try: odd values
            # must not abort the report)
            try:
                entropy_dict = tracker.compute_global_question_entropy() or {}
                low_entropy_qs = sorted(entropy_dict.items(), key=lambda x: x[1])[:5]
            except Exception:
                low_entropy_qs = []

            f.write("### Low Entropy Questions\n\n")
            if low_entropy_qs:
                f.write("| Question ID | Entropy |\n")
                f.write("|-------------|--------:|\n")
                for qid, ent in low_entropy_qs:
                    f.write(f"| `{qid}` | `{_num(ent)}` |\n")
            else:
                f.write("_No question entropy data available._\n")
            f.write("\n")

        # Token Usage Summary
        f.write("## Token Usage Summary (by Model and Module)\n\n")
        if token_tracker is None:
            f.write("_Token tracker unavailable; no usage data to report._\n\n")
        else:
            s = token_tracker.summary() or {}
            per_mm = s.get("per_model_module") or {}
            total_cost = s.get("total_cost") or 0.0

            if per_mm:
                f.write("| Module | Model | Input Tokens | Output Tokens | Cached Input Tokens | Cost |\n")
                f.write("|--------|-------|-------------:|--------------:|--------------------:|-----:|\n")
                # In plugin runs, SurveyConductor is the only module that logs usage.
                for key, vals in per_mm.items():
                    # key format used by MultiAxisTokenTracker: "model | module"
                    parts = [p.strip() for p in str(key).split("|")]
                    model = parts[0] or "unknown"
                    module = parts[1] if len(parts) > 1 and parts[1] else "unspecified"
                    vals = vals or {}
                    f.write(
                        f"| {module} | {model} | "
                        f"{_count(vals.get('total_input_tokens'))} | "
                        f"{_count(vals.get('total_output_tokens'))} | "
                        f"{_count(vals.get('total_cached_input_tokens'))} | "
                        f"${_num(vals.get('total_cost') or 0.0)} |\n"
                    )
            else:
                f.write("_No token usage was recorded._\n")

            f.write(f"\n**Total Estimated Cost:** `${_num(total_cost)}`\n")
            f.write("\n")

    return summary_path

def _setup_run_logger(log_file: str, verbose: bool = True) -> logging.Logger:
    """Attach fresh file + stdout handlers to this module's logger.

    Handlers left over from a previous call are closed, not just detached, so
    the previous run's log file descriptor is released.
    """
    logger = logging.getLogger(__name__)
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()
    # Respect verbosity from config: non-verbose runs only record warnings+.
    logger.setLevel(logging.INFO if verbose else logging.WARNING)

    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
    for handler in (logging.FileHandler(log_file), logging.StreamHandler(sys.stdout)):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


def _close_file_handlers(logger: logging.Logger) -> None:
    """Detach and close every FileHandler on ``logger`` (releases run.log)."""
    for handler in list(logger.handlers):
        if isinstance(handler, logging.FileHandler):
            logger.removeHandler(handler)
            handler.close()


def _split_is_blank(value) -> bool:
    """Return True if a question's ``split`` field is unset.

    None, NaN (which may be how an empty CSV cell surfaces) and
    whitespace-only strings all count as blank.
    """
    if value is None:
        return True
    if isinstance(value, float) and value != value:  # NaN
        return True
    return str(value).strip() == ""


def _chunked(iterable, n):
    """Yield successive lists of at most ``n`` items from ``iterable``."""
    it = iter(iterable)
    while True:
        chunk = list(islice(it, n))
        if not chunk:
            return
        yield chunk


def _collect_records(run_chunk, chunks, parallel, max_workers):
    """Run ``run_chunk`` on every chunk and concatenate the returned records.

    In parallel mode, results are reassembled in chunk (submission) order
    rather than completion order. Downstream steps that may depend on record
    order (e.g. the per-step entropy trajectory) then see the same order on
    every run.
    """
    if not parallel:
        records = []
        for chunk in chunks:
            records.extend(run_chunk(chunk))
        return records

    per_chunk = [None] * len(chunks)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        index_of = {executor.submit(run_chunk, chunk): i for i, chunk in enumerate(chunks)}
        for future in tqdm(as_completed(index_of), total=len(index_of), desc="Simulating agents"):
            per_chunk[index_of[future]] = future.result()

    records = []
    for chunk_records in per_chunk:
        records.extend(chunk_records)
    return records


def _render_html_report(cfg: dict, output_dir: str, summary_md: str, logger: logging.Logger) -> None:
    """Copy the report stylesheet next to the summary and convert it to HTML.

    The CSS is taken from ``report.plugin``, then ``report.css``, then the
    repository default. Copying is best-effort: a failure is logged and the
    HTML conversion still runs.
    """
    report_cfg = cfg.get("report") or {}  # `report:` may be present but null
    css_path = (
        report_cfg.get("plugin")
        or report_cfg.get("css")
        or "styles/generation_report.css"
    )
    css_basename = os.path.basename(css_path)
    css_dest = os.path.join(output_dir, css_basename)
    if not os.path.exists(css_dest):
        try:
            shutil.copy(css_path, css_dest)
        except OSError as err:
            logger.warning(f"[Runner] Could not copy report CSS '{css_path}': {err}")
    # splitext only swaps the extension; str.replace(".md", ...) would also
    # rewrite any ".md" that occurs in directory names.
    html_path = os.path.splitext(summary_md)[0] + ".html"
    convert_md_to_html(summary_md, html_path, css=css_basename)


def main(config_path, output_dir=None, return_outputs=False, token_tracker=None):
    """Run a plugin-mode parallel elicitation end to end.

    Loads the YAML config, optionally subsamples endowments, runs
    SurveyConductor over chunks of endowments (in a thread pool when
    ``runner.parallel`` is true), then writes the following into
    ``output_dir``:
      - responses and endowments CSVs
      - entropy diagnostics
      - token-usage reports
      - ``results.json``
      - a Markdown/HTML summary

    Args:
        config_path: YAML config with ``metadata``, ``paths``, ``agent`` and
            ``runner`` sections.
        output_dir: Destination directory; defaults to
            ``outputs/<metadata.name>_parallel_<timestamp>``.
        return_outputs: If true, also return the ``Responses`` object.
        token_tracker: Optional shared ``MultiAxisTokenTracker``; a new one is
            created when omitted.

    Returns:
        The ``results`` dict, or ``(responses, results)`` if ``return_outputs``.
    """
    # Reuse shared tracker if provided
    if token_tracker is None:
        token_tracker = MultiAxisTokenTracker()
    cfg = load_config(config_path)  # closes the file (open() was leaked before)

    meta = cfg["metadata"]
    paths = cfg["paths"]
    agent_cfg = cfg["agent"]
    run_cfg = cfg["runner"]

    # --- Output directory ---
    if output_dir is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join("outputs", f"{meta['name']}_parallel_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    # --- Logging setup ---
    logger = _setup_run_logger(os.path.join(output_dir, "run.log"), verbose=run_cfg.get("verbose", True))
    try:
        logger.info(f"[Runner] Output directory: {output_dir}")
        logger.info("[Runner] Starting parallel elicitation...")

        # --- Predeclare output paths ---
        responses_csv = os.path.join(output_dir, "responses.csv")
        sampled_endowments_csv = os.path.join(output_dir, "endowments.csv")
        metadata_yaml = os.path.join(output_dir, "metadata.yaml")
        entropy_bar_path = os.path.join(output_dir, "question_entropy_bar_chart.png")
        agg_entropy_csv = os.path.join(output_dir, "aggregate_entropy_trajectory.csv")
        question_entropy_csv = os.path.join(output_dir, "question_entropy_trajectories.csv")
        token_usage_csv = os.path.join(output_dir, "token_usage_summary.csv")
        token_usage_json = os.path.join(output_dir, "token_usage_summary.json")

        # Save metadata snapshot
        with open(metadata_yaml, "w") as f:
            yaml.safe_dump(cfg, f)

        # --- Load survey & endowments ---
        survey = Survey(paths["survey_csv"], paths["survey_yaml"])

        # Split if not already done
        if all(_split_is_blank(q.get("split")) for q in survey.get_questions()):
            logger.info("[Runner] No split found in survey. Performing split now...")
            survey.split_questions(seed=meta["seed"], save=True)
        else:
            logger.info("[Runner] Using existing split in survey.")

        endowments = ActiveEndowments(paths["endowments_csv"])

        # --- Optional budget-based subsampling ---
        target_n = run_cfg.get("target_n", None)
        if target_n is not None and target_n < len(endowments):
            logger.info(f"[Runner] Subsampling {target_n} endowments from {len(endowments)} total...")
            # role-aware subsampling: preserve ground_truth/proxy ratio
            endowments = endowments.clone_with_fraction(fraction=target_n / len(endowments), seed=meta.get("seed", 101))
        else:
            logger.info(f"[Runner] Using all {len(endowments)} endowments.")

        # Save working endowments to outputs
        endowments.save(sampled_endowments_csv)
        logger.info(f"[Runner] Saved working endowments to {sampled_endowments_csv}")

        # --- Build agent config ---
        agent_config = AgentConfig(
            agent_type=agent_cfg["type"],
            model_name=agent_cfg["model_name"],
            formality=agent_cfg["formality"],
            agent_kwargs=agent_cfg.get("kwargs", {})
        )

        # --- Parallelism config ---
        parallel = run_cfg.get("parallel", True)
        # ThreadPoolExecutor rejects max_workers <= 0.
        max_workers = max(1, int(run_cfg.get("max_workers", 10)))

        # --- Prepare chunking (outer-level parallelism) ---
        endowment_list = endowments.endowments
        total = len(endowment_list)
        if parallel:
            factor = 2  # like ActiveEndowmentGenerator: ~2 chunks per worker
            chunk_size = max(1, total // (max_workers * factor))
        else:
            chunk_size = max(1, total)
        chunks = list(_chunked(endowment_list, chunk_size))
        logger.info(
            f"[Runner] Chunking {total} endowments into chunks of size {chunk_size}"
        )

        # --- Elicitation function (per chunk) ---
        def run_conductor(endowments_chunk):
            manager = ActiveEndowments.from_endowment_list(endowments_chunk)
            # Inject a synthetic mode for plugin runs
            for endowment in manager.endowments:
                endowment["mode"] = ("plugin",)
            conductor = SurveyConductor(
                survey=survey,
                endowments=manager,
                agent_type=agent_config.agent_type,
                model_name=agent_config.model_name,
                formality=agent_config.formality,
                verbose=run_cfg.get("verbose", True),
                token_tracker=token_tracker
            )
            conductor.run(save=False, parallel_worker=parallel, **agent_config.agent_kwargs)
            return conductor.to_agent_records()

        # --- Run survey (records kept in deterministic chunk order) ---
        all_records = _collect_records(run_conductor, chunks, parallel, max_workers)

        # Convert back into Responses object
        responses = Responses.from_agent_records(all_records, survey, output_format="answer")
        responses.save(responses_csv)
        logger.info(f"[Runner] Saved responses to {responses_csv}")

        # --- Entropy diagnostics ---
        tracker = ThemeVariabilityTracker(survey.get_questions_by_split("train"), smoothing=0)
        filtered_records = tracker.filter_records(all_records, survey=survey)
        tracker.update_from_records(filtered_records)
        entropy_trajectories = tracker.get_entropy_trajectories()
        aggregate = entropy_trajectories["aggregate"]

        pd.DataFrame({"entropy": aggregate}).to_csv(agg_entropy_csv, index_label="step")
        pd.DataFrame(entropy_trajectories["by_question"]).to_csv(question_entropy_csv, index_label="step")
        tracker.plot_entropy_bar_chart(save_path=entropy_bar_path)

        # --- Export token usage artifacts ---
        token_tracker.save(token_usage_csv)     # CSV
        token_tracker.to_json(token_usage_json) # JSON
        logger.info("[Runner] Token usage reports saved.")
        logger.info("\n" + token_tracker.report())

        # --- Build and save results.json ---
        results = {
            "output_dir": output_dir,
            "num_endowments": len(endowments),
            "num_responses": len(responses),  # __len__ on Responses returns #unique EIDs
            # len() rather than truthiness: works for lists and numpy arrays alike.
            "final_entropy": aggregate[-1] if len(aggregate) else None,
            "entropy_steps": len(aggregate),
            "token_usage_summary": token_tracker.summary(),
            "token_usage_csv": token_usage_csv,
            "token_usage_json": token_usage_json,
        }
        results_json = os.path.join(output_dir, "results.json")
        with open(results_json, "w") as f:
            json.dump(results, f, indent=2)

        # --- Summary markdown + optional HTML conversion (mirror AEG flow) ---
        summary_md = write_summary_plugin_md(
            output_dir=output_dir,
            meta=meta,
            config_path=config_path,
            endowments_csv=sampled_endowments_csv,
            responses_csv=responses_csv,
            entropy_bar_path=entropy_bar_path,
            token_tracker=token_tracker,
            agg_entropy_csv=agg_entropy_csv,
            question_entropy_csv=question_entropy_csv,
            # Optional images left out here (MCA / spectrum) but supported by the writer:
            spectrum_plot_path=None,
            token_usage_csv=token_usage_csv,
            token_usage_json=token_usage_json,
            tracker=tracker,
        )
        _render_html_report(cfg, output_dir, summary_md, logger)

        logger.info("[Runner] Summary report saved. Run complete.")
        return (responses, results) if return_outputs else results
    finally:
        # Release run.log even if the run fails part-way.
        _close_file_handlers(logger)


if __name__ == "__main__":
    # Command-line entry point: `python <this script> --config path/to/run.yaml`
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        required=True,
        type=str,
        help="Path to YAML config file for running parallel survey.",
    )
    cli_args = cli.parse_args()
    main(cli_args.config)