import os
import sys
import yaml
import argparse
import shutil
import subprocess
import logging
import json
import pandas as pd
from datetime import datetime

def load_system_prompt_from_cfg(em_cfg, default_prompt: str):
    """
    Resolve the endowment-model system prompt from a config mapping.

    Resolution order:
      1. ``em_cfg["system_prompt_path"]`` -- when truthy, the file at that path
         is read as UTF-8 and its full contents are returned verbatim.
      2. ``em_cfg["system_prompt"]`` -- an inline string, returned verbatim
         (not stripped) when it contains at least one non-whitespace character.
      3. ``default_prompt`` in every other case, including ``em_cfg is None``.
    """
    if em_cfg is None:
        return default_prompt

    prompt_file = em_cfg.get("system_prompt_path")
    if not prompt_file:
        # No file configured: fall back to a non-blank inline prompt, if any.
        inline = em_cfg.get("system_prompt")
        has_text = isinstance(inline, str) and bool(inline.strip())
        return inline if has_text else default_prompt

    with open(prompt_file, "r", encoding="utf-8") as handle:
        return handle.read()

# Default system prompt for the EndowmentModel. main() passes it to
# load_system_prompt_from_cfg, which returns it when the `endowment_model`
# config section provides neither `system_prompt_path` nor a non-blank inline
# `system_prompt`.
# NOTE(review): the last instruction refers to "the JSON array", but this text
# never asks for a JSON array itself -- presumably EndowmentModel appends the
# output-format instructions; confirm before editing this prompt.
EM_SYSTEM_PROMPT = """You are an expert assistant trained to generate realistic, diverse, and demographically plausible personas for social science surveys. 

Each persona should include:
- `eid`: a short, lowercase, variable-safe identifier that encodes key traits (e.g., urban_liberal_30s_female). No punctuation or spaces.
- `endow_text`: a natural language description of the persona, written as if describing a survey respondent.

Instructions:
- Represent a wide range of age, gender, race/ethnicity, education, region, and political ideology.
- Avoid repetition of phrasing or demographic combinations across personas.
- Do not include explanations or formatting outside of the JSON array.
"""

# Make the repository root (the parent of this script's directory) importable,
# so the project-local `modules` / `generators` imports below resolve when this
# file is executed directly as a script. Appended (not prepended) so it never
# shadows already-importable packages.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
    sys.path.append(project_root)

from modules.dataclasses import AgentConfig
from modules.survey_converter import Survey, BinaryExtendedSurvey
from modules.response_converter import Responses, BinaryExtendedResponses
from modules.endowment_manager import ActiveEndowments
from modules.endowment_model import EndowmentModel
from modules.attribute_learner import AttributeLearner
from modules.token_tracker import MultiAxisTokenTracker
from generators.active_endowment_generator import ActiveEndowmentGenerator
from generators.utils import *

def load_config(path):
    """
    Load a YAML config file.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    locale encoding. That way configs containing non-ASCII text (for example an
    inline ``system_prompt``) load identically on every OS, and the read matches
    how ``load_system_prompt_from_cfg`` reads prompt files.

    Args:
        path: Path to the YAML file.

    Returns:
        The parsed document as returned by ``yaml.safe_load`` (``None`` for an
        empty file).
    """
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

def convert_md_to_html(md_path, html_path, css):
    """
    Render a Markdown file to a standalone HTML page with pandoc (best effort).

    Args:
        md_path: Path of the Markdown source.
        html_path: Destination path for the generated HTML.
        css: Stylesheet reference passed to pandoc's ``--css``. Pandoc writes it
            into the page as a link, so a relative value resolves against the
            HTML file's location when the page is opened.

    Returns:
        True if pandoc ran and exited with status 0, otherwise False. A missing
        pandoc executable or a non-zero exit is logged as a warning instead of
        raised, because the HTML page is an optional add-on to the Markdown
        report and should not abort an otherwise finished run.
    """
    log = logging.getLogger(__name__)
    cmd = ["pandoc", md_path, "-o", html_path, "--standalone", f"--css={css}"]
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, check=False)
    except FileNotFoundError:
        log.warning("[Runner] pandoc not found on PATH; skipping HTML report %s", html_path)
        return False

    stderr = (proc.stderr or "").strip()
    if proc.returncode != 0:
        log.warning(
            "[Runner] pandoc exited with status %d while rendering %s: %s",
            proc.returncode, md_path, stderr,
        )
        return False
    if stderr:
        # Pandoc may emit non-fatal warnings on success; keep them available.
        log.debug("[Runner] pandoc messages: %s", stderr)
    return True

def stringify_mode(mode):
    """
    Return a display/column-safe string for a mode key.

    Tuples and lists of names are joined with ``+`` (``("a", "b")`` ->
    ``"a+b"``); any other value is converted with ``str()``.
    """
    if not isinstance(mode, (tuple, list)):
        return str(mode)
    return "+".join(mode)


def write_summary_md(output_dir, meta, config_path, endowments_csv, responses_csv, mca_plot_path, spectrum_plot_path, entropy_path, variability_path, q_entropy_path, tracker, mode_summary_csv, mode_entropy_csv, 
                     token_tracker, token_usage_csv: str = None, token_usage_json: str = None, top_k: int = 5):
    """
    Write the Markdown run report ``summary_endowments.md`` into ``output_dir``.

    The report has these sections:
      - run metadata;
      - a list of output artifacts;
      - the diagnostic plots, embedded as ``<img>`` tags that reference files by
        basename so they resolve when the report sits next to them;
      - the ``top_k`` highest-scoring modes and lowest-entropy questions from
        ``tracker``;
      - the mode summary table read from ``mode_summary_csv``;
      - when ``token_tracker`` is given, a per-model/module token usage table.

    Args:
        output_dir: Directory the report is written into.
        meta: Metadata mapping. ``meta['name']`` is required; ``description``
            and ``seed`` are optional and rendered as ``N/A`` when absent.
        config_path: Config file path, rendered verbatim.
        endowments_csv, responses_csv, mca_plot_path, spectrum_plot_path,
        entropy_path, variability_path, q_entropy_path, mode_entropy_csv:
            Artifact paths; only their basenames appear in the report.
        tracker: Object providing ``get_scores()`` (mode -> score mapping) and
            ``compute_global_question_entropy()`` (question id -> entropy).
        mode_summary_csv: CSV with ``mode``, ``n`` and ``attributes`` columns.
            It is listed and rendered as a table. If it is missing or empty,
            the table is replaced by a short notice.
        token_tracker: Object providing ``summary()``, or None to omit the
            token usage section.
        token_usage_csv, token_usage_json: Optional token usage artifact
            paths, listed only when given.
        top_k: Number of rows in the top-mode and low-entropy-question tables.

    Returns:
        Path of the written Markdown file.
    """
    summary_path = os.path.join(output_dir, "summary_endowments.md")
    # Write UTF-8 explicitly: pandoc reads its input as UTF-8, and metadata or
    # attribute text may contain non-ASCII characters.
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("# Endowment Generation Report\n\n")
        f.write("## Metadata\n")
        f.write(f"**Name**: `{meta['name']}`  \n")
        f.write(f"**Description**: {meta.get('description', 'N/A')}  \n")
        f.write(f"**Seed**: `{meta.get('seed', 'N/A')}`  \n")
        f.write(f"**Config File**: `{config_path}`  \n")
        f.write(f"**Generated At**: `{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}`  \n\n")

        output_files = [
            ("Endowments", endowments_csv),
            ("Responses", responses_csv),
            ("Entropy Chart", entropy_path),
            ("Theme Trajectories", variability_path),
            ("Question Entropy Trajectories", q_entropy_path),
            ("MCA Plot", mca_plot_path),
            ("Spectrum Plot", spectrum_plot_path),
            ("Mode Summary", mode_summary_csv),
            ("Persistent Mode Entropy", mode_entropy_csv),
        ]
        if token_usage_csv:
            output_files.append(("Token Usage CSV", token_usage_csv))
        if token_usage_json:
            output_files.append(("Token Usage JSON", token_usage_json))

        f.write("## Output Files\n")
        for label, path in output_files:
            f.write(f"- {label}: `{os.path.basename(path)}`\n")
        # The blank line ends the list; without it pandoc folds the following
        # heading into the last list item.
        f.write("\n")

        f.write("## Diagnostic Plots\n\n")
        plots = [
            ("Effective Rank Analysis", spectrum_plot_path, "Spectrum Plot"),
            ("Multiple Correspondence Analysis", mca_plot_path, "MCA Plot"),
            ("Entropy Bar Chart", entropy_path, "Entropy Chart"),
            ("Theme Variability Trajectories", variability_path, "Theme Variability"),
            ("Question Entropy Trajectories", q_entropy_path, "Question Entropy"),
        ]
        for heading, path, alt in plots:
            f.write(f"### {heading}\n\n")
            f.write(f'<img src="{os.path.basename(path)}" alt="{alt}" style="max-width:80%; margin-bottom:1em;">\n\n')

        f.write("## Diagnostics Data\n\n")

        top_modes = sorted(tracker.get_scores().items(), key=lambda x: x[1], reverse=True)[:top_k]
        f.write("### Top Variable Modes\n\n")
        f.write("| Mode | Score |\n")
        f.write("|------|--------|\n")
        for mode, score in top_modes:
            # Same mode -> string convention as the persisted entropy CSVs.
            f.write(f"| `{stringify_mode(mode)}` | `{score:.4f}` |\n")

        entropy_dict = tracker.compute_global_question_entropy()
        low_entropy_qs = sorted(entropy_dict.items(), key=lambda x: x[1])[:top_k]

        f.write("\n### Low Entropy Questions\n\n")
        f.write("| Question ID | Entropy |\n")
        f.write("|-------------|---------|\n")
        for qid, entropy in low_entropy_qs:
            f.write(f"| `{qid}` | `{entropy:.4f}` |\n")

        f.write("\n## Mode Summary Table\n\n")
        try:
            mode_df = pd.read_csv(mode_summary_csv)
        except (FileNotFoundError, pd.errors.EmptyDataError):
            # Missing/empty export: fall through to the notice below instead
            # of failing the whole report.
            mode_df = None

        if mode_df is not None and not mode_df.empty:
            f.write("| Mode | # Endowments | Attributes |\n")
            f.write("|------|---------------|------------|\n")
            for _, row in mode_df.iterrows():
                mode_str = row["mode"]
                n = int(row["n"])
                # A literal "|" would split the pipe-table cell; escape it.
                attributes = str(row["attributes"]).replace("|", "\\|")
                f.write(f"| `{mode_str}` | `{n}` | {attributes} |\n")
        else:
            f.write("No mode summary available.\n")

        if token_tracker is not None:
            s = token_tracker.summary()
            f.write("\n## Token Usage Summary (by Model and Module)\n\n")
            f.write("| Module | Model | Input Tokens | Output Tokens | Cached Input Tokens | Cost |\n")
            f.write("|--------|-------|--------------:|--------------:|--------------------:|-----:|\n")
            for key, vals in s.get("per_model_module", {}).items():
                # key format: "model | module"
                parts = [p.strip() for p in key.split("|")]
                model = parts[0] if len(parts) > 0 else "unknown"
                module = parts[1] if len(parts) > 1 else "unspecified"
                f.write(
                    f"| {module} | {model} | "
                    f"{int(vals.get('total_input_tokens', 0))} | "
                    f"{int(vals.get('total_output_tokens', 0))} | "
                    f"{int(vals.get('total_cached_input_tokens', 0))} | "
                    f"${vals.get('total_cost', 0.0):.4f} |\n"
                )

            f.write(f"\n**Total Estimated Cost:** `${s.get('total_cost', 0.0):.4f}`\n")

    return summary_path


def _setup_run_logger(output_dir, verbose=True):
    """
    Configure this module's logger to write to ``<output_dir>/run.log`` and stdout.

    Handlers left over from an earlier ``main`` call in the same process are
    removed AND closed first. Repeated runs (e.g. a sweep calling ``main`` in a
    loop) then do not leak open log files or keep writing into a previous run's
    log. Closing the old stdout StreamHandler does not close ``sys.stdout``.

    Args:
        output_dir: Run directory that receives ``run.log``.
        verbose: When false, the logger level is raised to WARNING.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")

    # File handler
    file_handler = logging.FileHandler(os.path.join(output_dir, "run.log"), encoding="utf-8")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    if not verbose:
        logger.setLevel(logging.WARNING)

    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    return logger


def _resolve_report_css(cfg):
    """
    Choose the stylesheet for the HTML report.

    Order: ``cfg['report']['endowment']``, then ``cfg['report']['css']``, then
    ``styles/generation_report.css``. A ``report:`` key that is present but
    null is treated as empty. A relative path that does not exist under the
    current working directory is retried relative to ``project_root``, so the
    default can still be found when the script is launched from elsewhere.

    Returns:
        The chosen path. It may not exist; the caller checks.
    """
    report_cfg = cfg.get("report") or {}
    css_path = (
        report_cfg.get("endowment")
        or report_cfg.get("css")
        or "styles/generation_report.css"
    )
    if not os.path.isabs(css_path) and not os.path.exists(css_path):
        candidate = os.path.join(project_root, css_path)
        if os.path.exists(candidate):
            return candidate
    return css_path


def _render_html_report(summary_md, output_dir, css_path, logger):
    """
    Copy the stylesheet next to the Markdown report and render it to HTML.

    A missing stylesheet is logged and the HTML is still produced, because all
    other artifacts are already written at this point. The HTML path is derived
    from the report's file extension only, so a ".md" elsewhere in the
    directory path is not touched.
    """
    css_basename = os.path.basename(css_path)
    css_dest = os.path.join(output_dir, css_basename)
    if not os.path.exists(css_dest):
        if os.path.exists(css_path):
            shutil.copy(css_path, css_dest)
        else:
            logger.warning("[Runner] Report stylesheet not found: %s", css_path)
    html_path = os.path.splitext(summary_md)[0] + ".html"
    convert_md_to_html(summary_md, html_path, css=css_basename)


def main(config_path, output_dir: str = None, return_outputs: bool = False):
    """
    Run one endowment-generation experiment described by a YAML config.

    Steps:
      1. Load the config and create the output directory and run logger.
      2. Build the survey, agent, EndowmentModel, AttributeLearner and
         ActiveEndowmentGenerator.
      3. Generate endowments.
      4. Export responses, endowments, the attribute bank, the mode summary
         and the entropy trajectories.
      5. Render the diagnostic plots and save token usage.
      6. Dump ``results.json``, then write the Markdown summary and its HTML
         rendering.

    Args:
        config_path: Path to the YAML config. Sections read here: ``metadata``,
            ``paths``, ``agent``, ``generation``, ``endowment_model``,
            ``attribute_learner``; ``report`` is optional.
        output_dir: Run directory; defaults to
            ``outputs/<metadata.name>_<YYYYmmdd_HHMMSS>`` relative to the CWD.
        return_outputs: If true, also return the generator.

    Returns:
        The ``results`` dict, with ``config`` added after ``results.json`` is
        written, or ``(generator, results)`` when ``return_outputs`` is true.
    """
    token_tracker = MultiAxisTokenTracker()

    cfg = load_config(config_path)
    meta = cfg["metadata"]
    paths = cfg["paths"]
    agent_cfg = cfg["agent"]
    gen_cfg = cfg["generation"]

    # Use provided output_dir or default to a timestamped directory
    if output_dir is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join("outputs", f"{meta['name']}_{timestamp}")

    os.makedirs(output_dir, exist_ok=True)

    logger = _setup_run_logger(output_dir, verbose=gen_cfg.get("verbose", True))

    logger.info(f"[Runner] Output directory: {output_dir}")
    logger.info("[Runner] Starting endowment generation...")

    # Output file paths
    responses_csv = os.path.join(output_dir, "responses.csv")
    endowments_csv = os.path.join(output_dir, "endowments.csv")
    attribute_bank_yaml = os.path.join(output_dir, "attribute_bank.yaml")
    metadata_yaml = os.path.join(output_dir, "metadata.yaml")

    entropy_path = os.path.join(output_dir, "entropy_bar_chart.png")
    mca_plot_path = os.path.join(output_dir, "mca_plot.png")
    spectrum_plot_path = os.path.join(output_dir, "singular_value.png")
    variability_path = os.path.join(output_dir, "theme_variability_trajectories.png")
    q_entropy_path = os.path.join(output_dir, "question_entropy_trajectories.png")
    token_usage_csv = os.path.join(output_dir, "token_usage_summary.csv")
    token_usage_json = os.path.join(output_dir, "token_usage_summary.json")

    # Save metadata snapshot
    with open(metadata_yaml, "w") as f:
        yaml.safe_dump(cfg, f)

    # Load components
    survey = Survey(paths["survey_csv"], paths["survey_yaml"])

    # Split if not already done.
    # NOTE(review): assumes each question's "split" is a string when present;
    # a None/NaN value would fail on .strip() -- confirm Survey guarantees this.
    if all(q.get("split", "").strip() == "" for q in survey.get_questions()):
        logger.info("[Runner] No split found in survey. Performing split now...")
        survey.split_questions(seed=meta["seed"], save=True)
    else:
        logger.info("[Runner] Using existing split in survey.")

    with open(paths["attribute_bank"], "r") as f:
        attribute_bank = yaml.safe_load(f)

    agent_config = AgentConfig(
        agent_type=agent_cfg["type"],
        model_name=agent_cfg["model_name"],
        formality=agent_cfg["formality"],
        agent_kwargs=agent_cfg.get("kwargs", {})
    )

    em_cfg = cfg["endowment_model"]
    system_prompt = load_system_prompt_from_cfg(em_cfg, EM_SYSTEM_PROMPT)

    endowment_model = EndowmentModel(
        backend_type=em_cfg["backend_type"],
        model_name=em_cfg["model_name"],
        temperature=em_cfg["temperature"],
        batch_size=em_cfg["batch_size"],
        delay=em_cfg["delay"],
        retry_failed=em_cfg["retry_failed"],
        max_attributes=em_cfg["max_attributes"],
        randomize_attributes=em_cfg["randomize_attributes"],
        logger=logger,
        verbose=gen_cfg.get("verbose", True),
        token_tracker=token_tracker,
        system_prompt=system_prompt,
    )

    # Report the prompt source using the same rules as
    # load_system_prompt_from_cfg (a whitespace-only inline prompt -> DEFAULT).
    inline_prompt = em_cfg.get("system_prompt")
    if em_cfg.get("system_prompt_path"):
        prompt_source = "system_prompt_path"
    elif isinstance(inline_prompt, str) and inline_prompt.strip():
        prompt_source = "system_prompt (inline)"
    else:
        prompt_source = "DEFAULT"
    logger.info("[Runner] EndowmentModel system prompt source: %s", prompt_source)

    al_cfg = cfg["attribute_learner"]
    attribute_learner = AttributeLearner(
        survey=survey,
        backend_type=al_cfg["backend_type"],
        model_name=al_cfg["model_name"],
        max_tokens=al_cfg["max_tokens"],
        token_tracker=token_tracker
    )

    generator = ActiveEndowmentGenerator(
        survey=survey,
        endowment_model=endowment_model,
        attribute_bank=attribute_bank,
        attribute_learner=attribute_learner,
        agent_config=agent_config,
        target_n=gen_cfg["target_n"],
        initial_n=gen_cfg["initial_n"],
        num_update_steps=gen_cfg["num_update_steps"],
        seed=meta["seed"],
        parallel=gen_cfg["parallel"],
        max_workers=gen_cfg.get("max_workers", 10),
        question_patch=gen_cfg.get("question_patch", {}),
        verbose=gen_cfg.get("verbose", True),
        logger=logger,
        token_tracker=token_tracker
    )

    # Run generation
    generator.generate()

    logger.info("[Runner] Saving endowments, responses, and diagnostics...")

    # Save outputs
    generator.export_responses_to_csv(responses_csv, format="answer")
    generator.endowment_manager.save(endowments_csv)
    generator.export_attribute_bank(attribute_bank_yaml)

    mode_summary_csv = os.path.join(output_dir, "mode_summary.csv")
    generator.export_mode_summary(mode_summary_csv)

    # Save entropy diagnostics
    entropy_trajectories = generator.tracker.get_entropy_trajectories()

    # Save aggregate entropy trajectory
    agg_entropy_csv = os.path.join(output_dir, "aggregate_entropy_trajectory.csv")
    pd.DataFrame({"entropy": entropy_trajectories["aggregate"]}).to_csv(agg_entropy_csv, index_label="step")

    # Save per-mode entropy trajectory (mode keys converted to strings)
    by_mode = {
        stringify_mode(mode): entropy_list
        for mode, entropy_list in entropy_trajectories["by_mode"].items()
    }

    mode_entropy_df = pd.DataFrame(by_mode)
    mode_entropy_csv = os.path.join(output_dir, "persistent_mode_entropy.csv")
    mode_entropy_df.to_csv(mode_entropy_csv, index_label="step")

    # Save per-question entropy trajectory
    question_entropy_csv = os.path.join(output_dir, "question_entropy_trajectories.csv")
    qid_df = pd.DataFrame(entropy_trajectories["by_question"])
    qid_df.to_csv(question_entropy_csv, index_label="step")

    # Final entropy per mode, keyed with the same string form as the CSV above
    final_mode_entropies = {
        stringify_mode(mode): values[-1]
        for mode, values in entropy_trajectories["by_mode"].items()
        if values
    }
    # Highest- and lowest-entropy modes (None, None when no mode has history)
    top_mode, top_mode_entropy = max(final_mode_entropies.items(), key=lambda x: x[1], default=(None, None))
    lowest_mode, lowest_mode_entropy = min(final_mode_entropies.items(), key=lambda x: x[1], default=(None, None))

    # Re-import the responses for the diagnostics
    endowments = ActiveEndowments(endowments_csv)
    responses = Responses(source=responses_csv, survey=survey, output_format='code')
    binary_survey = BinaryExtendedSurvey.from_survey(survey)
    binary_responses = BinaryExtendedResponses(source=responses_csv, survey=binary_survey, output_format='code')
    with open(attribute_bank_yaml, "r") as f:
        attribute_banks = yaml.safe_load(f)

    df_train = responses.get_matrix_by_split('train', dropna=True)
    df_val = responses.get_matrix_by_split('valid', dropna=True)
    df = pd.concat([df_train, df_val])

    plot_mca_agents_by_template_grouped(df=df,
                                        endowments=endowments,
                                        attribute_bank=attribute_banks,
                                        save_path=mca_plot_path)

    bin_df_train = binary_responses.get_matrix_by_split('train', dropna=True)
    bin_df_val = binary_responses.get_matrix_by_split('valid', dropna=True)
    bin_df = pd.concat([bin_df_train, bin_df_val])
    bin_df = bin_df.apply(pd.to_numeric, errors='coerce')
    s, probs = get_singular_value_spectrum(bin_df)
    effective_rank = compute_effective_rank(bin_df)
    plot_singular_value_spectrum(s, probs, effective_rank=effective_rank, cumulative=True, save_path=spectrum_plot_path)

    generator.tracker.plot_entropy_bar_chart(save_path=entropy_path)
    # NOTE(review): this uses the input attribute bank, while the MCA plot above
    # uses the exported (reloaded) bank -- confirm which one is intended here.
    generator.tracker.plot_theme_variability_trajectories(attribute_bank=attribute_bank, save_path=variability_path)
    generator.tracker.plot_question_entropy_trajectories(save_path=q_entropy_path)

    # Export token usage artifacts and log a human-readable report
    token_tracker.save(token_usage_csv)         # CSV
    token_tracker.to_json(token_usage_json)     # JSON
    logger.info("[Runner] Token usage reports saved.")
    logger.info("\n" + token_tracker.report())

    results = {
        "output_dir": output_dir,
        "num_endowments": len(generator),
        "num_modes": len(generator.tracker.theme_variability_history),
        "final_entropy": (
            entropy_trajectories["aggregate"][-1] if entropy_trajectories["aggregate"] else None
        ),
        "entropy_steps": len(entropy_trajectories["aggregate"]),
        "effective_rank": effective_rank,
        "top_mode": top_mode,
        "top_mode_entropy": top_mode_entropy,
        "lowest_mode": lowest_mode,
        "lowest_mode_entropy": lowest_mode_entropy,
        "token_usage_summary": token_tracker.summary(),
        "token_usage_csv": token_usage_csv,
        "token_usage_json": token_usage_json,
    }

    results_json = os.path.join(output_dir, "results.json")
    results["results_json"] = results_json
    with open(results_json, "w") as f:
        json.dump(results, f, indent=2)

    # Added after the JSON dump so the config is returned but not duplicated
    # into results.json (metadata.yaml already holds it).
    results["config"] = cfg

    # Save markdown summary
    summary_md = write_summary_md(
        output_dir, meta, config_path, endowments_csv, responses_csv,
        mca_plot_path, spectrum_plot_path, entropy_path, variability_path, q_entropy_path, generator.tracker,
        mode_summary_csv, mode_entropy_csv,
        token_tracker,
        token_usage_csv,
        token_usage_json
    )

    # Optional: convert to HTML with CSS
    _render_html_report(summary_md, output_dir, _resolve_report_css(cfg), logger)

    logger.info("[Runner] Summary report saved. Run complete.")

    return (generator, results) if return_outputs else results

if __name__ == "__main__":
    # CLI entry point: --config is required; --output_dir exposes main()'s
    # existing output_dir parameter (omitted -> timestamped default directory).
    parser = argparse.ArgumentParser(
        description="Run endowment generation from a YAML config."
    )
    parser.add_argument(
        "--config", type=str, required=True,
        help="(Required) Path to the YAML config file for endowment generation."
    )
    parser.add_argument(
        "--output_dir", type=str, default=None,
        help="Optional output directory. Defaults to outputs/<metadata.name>_<timestamp>."
    )
    args = parser.parse_args()
    main(args.config, output_dir=args.output_dir)