# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import os
import json
import math
import yaml
import wandb
import statistics

# ----- #
# Configuration
# ----- #
# When True, main() runs `wandb sync` on the latest run before fetching data.
ENABLE_SYNC = False

# Metric keys shared by several tasks: (accuracy key, stderr key, short label).
_ACC_NORM_KEYS = ("acc_norm,none", "acc_norm_stderr,none", "acc_norm")
_ACC_KEYS = ("acc,none", "acc_stderr,none", "acc")

# Task name -> (accuracy key, stderr key, short label, display name).
# Insertion order determines the order of report lines and LaTeX columns.
EVAL_TASKS = {
    "hellaswag": (*_ACC_NORM_KEYS, "Hellaswag"),
    "piqa": (*_ACC_NORM_KEYS, "PiQA"),
    "lambada_openai": (*_ACC_KEYS, "LAMBADA"),
    "arc_easy": (*_ACC_NORM_KEYS, "Arc E."),
    "arc_challenge": (*_ACC_NORM_KEYS, "Arc C."),
}
# ----- #


def get_run_id(vault_path: str) -> str:
    """Return the run id encoded in a ``run-<id>.wandb`` file name.

    The lookup happens in ``<vault_path>/wandb/latest-run``; the id of the
    first matching entry yielded by ``os.listdir`` is returned.

    Args:
        vault_path: Directory that contains the ``wandb`` folder.

    Returns:
        The file name with the ``run-`` prefix and ``.wandb`` suffix removed.

    Raises:
        FileNotFoundError: If no ``run-*.wandb`` entry exists in the folder.
    """
    prefix, suffix = "run-", ".wandb"
    latest_run_dir = os.path.join(vault_path, "wandb", "latest-run")

    # Lazily strip prefix/suffix from every matching entry; only the first
    # match is ever consumed.
    matching_ids = (
        entry[len(prefix):-len(suffix)]
        for entry in os.listdir(latest_run_dir)
        if entry.startswith(prefix) and entry.endswith(suffix)
    )
    run_id = next(matching_ids, None)
    if run_id is None:
        raise FileNotFoundError("Could not find run-*.wandb file in latest-run folder")
    return run_id


def fetch_wandb_data(vault_path: str, config: dict) -> dict:
    """Collect the metrics needed for the summary report via the wandb API.

    Args:
        vault_path: Run directory containing ``wandb/latest-run``.
        config: Run configuration; ``project_name`` is required and
            ``project_entity`` is optional.

    Returns:
        A dict with the keys ``run_id``, ``iter_times``, ``loss_values``,
        ``perplexity_val`` and ``num_iters``.
    """
    run_id = get_run_id(vault_path)
    entity = config.get("project_entity")
    project = config["project_name"]

    # The entity prefix is only added when the config provides a truthy one.
    run_path = f"{project}/{run_id}"
    if entity:
        run_path = f"{entity}/{run_path}"
    print(f"[INFO] Fetching data for run: {run_path}")

    run = wandb.Api().run(run_path)

    print("[INFO] Fetching run history...")
    # NOTE(review): history() is called with its default arguments; the wandb
    # docs describe a default sample size, so the columns below may be a
    # subsample of the logged steps -- confirm this is intended.
    history = run.history()

    # Rows where a metric was not logged show up as NaN and are dropped.
    iter_time_samples = history["iter_time"].dropna().tolist()
    lm_loss_samples = history["loss_lm"].dropna().tolist()
    # The iteration count is taken as the largest logged step plus one.
    iteration_count = int(history["step"].max()) + 1
    final_perplexity = run.summary.get("perplexity_val")

    return {
        "run_id": run_id,
        "iter_times": iter_time_samples,
        "loss_values": lm_loss_samples,
        "perplexity_val": final_perplexity,
        "num_iters": iteration_count,
    }


def generate_latex_row(wandb_data: dict, eval_results: dict | None) -> str:
    """Build the ``&``-separated cells of a LaTeX table row for one run.

    Cell order: validation perplexity, one accuracy per entry of
    ``EVAL_TASKS`` (in its iteration order), then the mean accuracy.

    Args:
        wandb_data: Dict providing ``perplexity_val`` (may be ``None``).
        eval_results: Mapping of task name to its metrics dict, or ``None``
            when no evaluation results are available.

    Returns:
        The cells joined with ``" & "``; unavailable values appear as ``--``.
    """
    separator = " & "
    placeholder = "--"

    # Perplexity has no stderr, so the cell is a bare number.
    perplexity = wandb_data["perplexity_val"]
    cells = [placeholder if perplexity is None else f"{perplexity:.2f}"]

    if eval_results is None:
        # One placeholder per task plus one for the average column.
        cells.extend([placeholder] * (len(EVAL_TASKS) + 1))
        return separator.join(cells)

    accuracies = []
    for task_name, spec in EVAL_TASKS.items():
        acc_key, _stderr_key, _label, _display_name = spec
        accuracy = eval_results[task_name][acc_key]
        accuracies.append(accuracy)
        cells.append(f"{100*accuracy:.1f}")

    mean_accuracy = statistics.mean(accuracies)
    cells.append(f"{100*mean_accuracy:.2f}")

    return separator.join(cells)


def generate_report(wandb_data: dict, config: dict, eval_results: dict | None) -> str:
    """Generate the summary report as a string."""
    lines = []
    lines.append("\n" * 8)
    lines.append("=" * 60)
    lines.append("                    RUN SUMMARY REPORT")
    lines.append("=" * 60)

    # Run name and WandB ID
    lines.append(f"\nrun_name: {config.get('run_name', 'N/A')}")
    lines.append(f"wandb_id: {wandb_data['run_id']}")

    # Perplexity
    perplexity = wandb_data["perplexity_val"]
    if perplexity is not None:
        lines.append(f"perplexity: {perplexity:.2f}")
        lines.append(f"loss_val (ln(perplexity)): {math.log(perplexity):.2f}")
    else:
        lines.append("perplexity: N/A")

    # Iteration time stats
    iter_times = wandb_data["iter_times"]
    assert iter_times is not None
    avg_time = statistics.mean(iter_times)
    median_time = statistics.median(iter_times)
    std_time = statistics.stdev(iter_times) if len(iter_times) > 1 else 0.0
    num_gpus = config.get("num_gpu", 1)
    num_iters = wandb_data["num_iters"]
    total_time = avg_time * num_iters
    h100_hours = total_time * num_gpus / 3600
    assert num_iters == 15028
    h100_hours_median = median_time * 15028 * num_gpus / 3600

    lines.append(f"\niter_time_avg: {avg_time:.4f}s")
    lines.append(f"iter_time_median: {median_time:.4f}s")
    lines.append(f"iter_time_std: {std_time:.4f}s")
    lines.append(f"iter_time_total: {total_time:.2f}s ({total_time/3600:.2f}h)")
    lines.append(f"num_gpus: {num_gpus}")
    lines.append(f"num_iters: {num_iters}")
    lines.append(f"h100_hours: {h100_hours:.2f}")
    lines.append(f"h100_hours_(median): {h100_hours_median:.2f}")

    # Evaluation results
    lines.append("\n" + "-" * 60)
    lines.append("EVALUATION RESULTS")
    lines.append("-" * 60)

    if eval_results is None:
        lines.append("No eval_results.txt found")
    else:
        acc_values = []
        for task, (acc_key, stderr_key, label, display_name) in EVAL_TASKS.items():
            task_data = eval_results[task]
            acc = task_data[acc_key]
            stderr = task_data[stderr_key]
            acc_values.append(acc)
            lines.append(f"{display_name} ({label}): {100*acc:.1f}% +/- {100*stderr:.1f}%")

        avg = statistics.mean(acc_values)
        lines.append(f"\naverage: {100*avg:.2f}%")

    # LaTeX row
    lines.append("\n" + "-" * 60)
    lines.append("LATEX ROW")
    lines.append("-" * 60)
    lines.append(generate_latex_row(wandb_data, eval_results))

    lines.append("\n" + "=" * 60)
    lines.append("\n" * 8)

    return "\n".join(lines)


def main(vault_path: str) -> str:
    """Build, print and return the summary report for one run directory.

    Steps: optionally run ``wandb sync`` on the latest run (controlled by
    ``ENABLE_SYNC``), load ``config.yaml``, fetch the run data through the
    wandb API, load ``eval_results.txt`` if it exists, then generate the
    report.

    Args:
        vault_path: Directory containing ``config.yaml``, the ``wandb``
            folder and, optionally, ``eval_results.txt``.

    Returns:
        The report text that was printed.
    """
    print(f"[INFO] Vault path: {vault_path}")

    # Sync wandb if enabled
    if ENABLE_SYNC:
        import subprocess
        wandb_latest = os.path.join(vault_path, "wandb", "latest-run")
        print(f"[INFO] Syncing wandb run at: {wandb_latest}")
        subprocess.run(["wandb", "sync", wandb_latest], cwd=vault_path, timeout=300, check=True)
        print("[INFO] wandb sync completed")
    else:
        print("[INFO] Skipping wandb sync (ENABLE_SYNC=False)")

    # Load config
    with open(os.path.join(vault_path, "config.yaml")) as f:
        config = yaml.safe_load(f)
    print(f"[INFO] Loaded config: run_name={config.get('run_name')}")

    # Fetch wandb data
    wandb_data = fetch_wandb_data(vault_path, config)

    # Load eval results (stays None if eval hasn't run yet). The file is read
    # inside a context manager so the handle is closed deterministically.
    eval_path = os.path.join(vault_path, "eval_results.txt")
    eval_results = None
    if os.path.isfile(eval_path):
        with open(eval_path) as eval_file:
            eval_results = json.load(eval_file)

    # Generate report
    report = generate_report(wandb_data, config, eval_results)
    print(report)
    return report


if __name__ == "__main__":
    # Run directories to summarise. The list is empty as written, so the
    # output file ends up empty unless paths are added here.
    vault_paths = []

    # Reports are concatenated in input order without any extra separator.
    out = "".join(main(vault_path) for vault_path in vault_paths)

    with open("run_summary_reports.txt", "w") as report_file:
        report_file.write(out)
    print("[INFO] Written to run_summary_reports.txt")
