#!/usr/bin/env python3
import argparse
import json
from pathlib import Path
from typing import Any

from evaluation import (
    MAXIMUM_SIGNIFIERS_VALUES,
    MINIMUM_RELEVANCE_VALUES,
    GoalCase,
    _display_model_label,
    _expected_filtered_goal_rows,
    _expected_goal_rows,
    _expected_r_precision_goal_rows,
    _filtered_results_entries,
    _format_cell_stats,
    _format_duration,
    _format_r_precision_cell,
    _format_ratio,
    _goal_entry_row_label,
    _goal_row_label,
    _latex_escape,
    _parameter_label,
    _profile_parameters_match,
    _runs_with_r_precision_rankings,
    build_parameter_results_table,
    build_r_precision_table,
    build_results_table,
)

# Directory that contains this script; the default paths below live next to it.
ROOT = Path(__file__).resolve().parent
# JSON results file analysed when no input path is given on the command line.
DEFAULT_INPUT_PATH = ROOT / "results_json.json"
# Default destination of the plain-text analysis (the LaTeX-table variant).
DEFAULT_OUTPUT_PATH = ROOT / "analysis.txt"
# Default destination of the Markdown analysis.
DEFAULT_MARKDOWN_OUTPUT_PATH = ROOT / "analysis.md"


def load_results_json(path: Path) -> dict[str, Any]:
    """Load a JSON results file and check its top-level shape.

    Args:
        path: Location of the JSON document to read.

    Returns:
        The decoded top-level JSON object.

    Raises:
        ValueError: If the document is not a JSON object, or if it has no
            top-level ``results`` list. (``json.JSONDecodeError``, raised for
            malformed JSON, is itself a ``ValueError``.)
        OSError: If the file cannot be read.
    """
    # Decode explicitly as UTF-8: without an encoding argument, read_text()
    # falls back to the locale encoding, which can fail on (or silently
    # mis-decode) non-ASCII labels on platforms whose locale is not UTF-8.
    payload = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(payload, dict):
        raise ValueError(f"{path} must contain a JSON object.")
    if not isinstance(payload.get("results"), list):
        raise ValueError(f"{path} must contain a top-level 'results' list.")
    return payload


def _model_labels(results_entries: list[dict[str, Any]]) -> list[str]:
    model_labels: list[str] = []
    for entry in results_entries:
        label = str(entry.get("llm", "unknown"))
        if label not in model_labels:
            model_labels.append(label)
    return model_labels


def _model_display_labels(model_labels: list[str]) -> dict[str, str]:
    return {label: f"LLM {index}" for index, label in enumerate(model_labels)}


def _format_cell_stats_without_time(goal_entry: dict[str, Any]) -> str:
    runs = goal_entry.get("runs", [])
    if not runs:
        return "-"
    n = len(runs)
    avg_precision = sum(float(run.get("precision", 0.0)) for run in runs) / n
    avg_recall = sum(float(run.get("recall", 0.0)) for run in runs) / n
    return f"P:{_format_ratio(avg_precision)}, R:{_format_ratio(avg_recall)}"


def _markdown_escape(value: str) -> str:
    return value.replace("\\", "\\\\").replace("|", "\\|").replace("\n", "<br>")


def _markdown_table(headers: list[str], rows: list[list[str]]) -> str:
    escaped_headers = [_markdown_escape(header) for header in headers]
    table_rows = [
        "| " + " | ".join(escaped_headers) + " |",
        "| " + " | ".join("---" for _header in headers) + " |",
    ]
    table_rows.extend(
        "| " + " | ".join(_markdown_escape(cell) for cell in row) + " |" for row in rows
    )
    return "\n".join(table_rows)


def build_results_markdown_table(
    results_entries: list[dict[str, Any]],
    *,
    include_wot_robot_control: bool = False,
    model_labels: list[str] | None = None,
    model_display_labels: dict[str, str] | None = None,
    include_time: bool = True,
) -> str:
    """Render the task-by-model statistics as a Markdown table.

    Rows are the labels returned by ``_expected_goal_rows`` followed by any
    other row label found in the data, in first-seen order; columns are the
    models. A cell aggregates the runs of every goal entry that shares its
    (model, row label) pair, and shows ``-`` when that pair has no runs.

    Args:
        results_entries: Result entries; each is read for its ``llm``,
            ``interface`` and ``goals`` keys.
        include_wot_robot_control: When false, rows labelled
            ``Robot Control (WoT)`` are left out. The flag is also forwarded
            to ``_filtered_results_entries`` and ``_expected_goal_rows``.
        model_labels: Column order; derived from the entries when ``None``.
        model_display_labels: Passed to ``_display_model_label`` to produce
            the header text of each column.
        include_time: Selects ``_format_cell_stats`` when true and
            ``_format_cell_stats_without_time`` when false.

    Returns:
        The Markdown table, or ``No results available.`` when filtering
        leaves no entries.
    """
    entries = _filtered_results_entries(results_entries, include_wot_robot_control)
    if not entries:
        return "No results available."

    columns = _model_labels(entries) if model_labels is None else model_labels
    row_order = _expected_goal_rows(include_wot_robot_control)

    # Single pass over the data: remember row labels in first-seen order and
    # pool the runs of each (model, row label) pair.
    seen_rows: list[str] = []
    pooled_runs: dict[tuple[str, str], list[dict[str, Any]]] = {}
    for entry in entries:
        model_name = str(entry.get("llm", "unknown"))
        interface = str(entry.get("interface", ""))
        for goal in entry.get("goals", []):
            label = _goal_entry_row_label(goal, interface)
            if label == "Robot Control (WoT)" and not include_wot_robot_control:
                continue
            if label not in seen_rows:
                seen_rows.append(label)
            pooled_runs.setdefault((model_name, label), []).extend(goal.get("runs", []))
    row_order.extend([label for label in seen_rows if label not in row_order])

    header_row = ["Task/Model"]
    header_row.extend(
        _display_model_label(name, model_display_labels) for name in columns
    )

    render_cell = _format_cell_stats if include_time else _format_cell_stats_without_time
    body: list[list[str]] = []
    for row_name in row_order:
        line = [row_name]
        for name in columns:
            collected = pooled_runs.get((name, row_name), [])
            line.append(render_cell({"runs": collected}) if collected else "-")
        body.append(line)
    return _markdown_table(header_row, body)


def build_parameter_results_markdown_table(
    results_entries: list[dict[str, Any]],
    maximum_signifiers: int | None,
    minimum_relevance_value: float | None,
    *,
    include_wot_robot_control: bool = False,
    model_labels: list[str] | None = None,
    model_display_labels: dict[str, str] | None = None,
    include_time: bool = True,
) -> str:
    """Render the task-by-model Markdown table for one parameter setting.

    Only goal entries accepted by ``_profile_parameters_match`` for the given
    ``maximum_signifiers`` / ``minimum_relevance_value`` contribute. Rows are
    the labels from ``_expected_filtered_goal_rows`` followed by any other
    matching row label found in the data, in first-seen order.

    Args:
        results_entries: Result entries; each is read for its ``llm``,
            ``interface`` and ``goals`` keys.
        maximum_signifiers: First parameter handed to the profile match.
        minimum_relevance_value: Second parameter handed to the profile match.
        include_wot_robot_control: When false, rows labelled
            ``Robot Control (WoT)`` are left out.
        model_labels: Column order; derived from the entries when ``None``.
        model_display_labels: Passed to ``_display_model_label`` to produce
            the header text of each column.
        include_time: Selects ``_format_cell_stats`` when true and
            ``_format_cell_stats_without_time`` when false.

    Returns:
        The Markdown table, or ``No results available.`` when filtering
        leaves no entries.
    """
    entries = _filtered_results_entries(results_entries, include_wot_robot_control)
    if not entries:
        return "No results available."

    columns = _model_labels(entries) if model_labels is None else model_labels
    row_order = _expected_filtered_goal_rows(include_wot_robot_control)

    # Single pass over the data: keep only goals run with the requested
    # parameters, remember their row labels and pool their runs per model.
    seen_rows: list[str] = []
    pooled_runs: dict[tuple[str, str], list[dict[str, Any]]] = {}
    for entry in entries:
        model_name = str(entry.get("llm", "unknown"))
        interface = str(entry.get("interface", ""))
        for goal in entry.get("goals", []):
            if not _profile_parameters_match(
                goal, maximum_signifiers, minimum_relevance_value
            ):
                continue
            label = _goal_row_label(str(goal.get("goal", "unknown")), interface)
            if label == "Robot Control (WoT)" and not include_wot_robot_control:
                continue
            if label not in seen_rows:
                seen_rows.append(label)
            pooled_runs.setdefault((model_name, label), []).extend(goal.get("runs", []))
    row_order.extend([label for label in seen_rows if label not in row_order])

    header_row = ["Task/Model"]
    header_row.extend(
        _display_model_label(name, model_display_labels) for name in columns
    )

    render_cell = _format_cell_stats if include_time else _format_cell_stats_without_time
    body: list[list[str]] = []
    for row_name in row_order:
        line = [row_name]
        for name in columns:
            collected = pooled_runs.get((name, row_name), [])
            line.append(render_cell({"runs": collected}) if collected else "-")
        body.append(line)
    return _markdown_table(header_row, body)


def build_r_precision_markdown_table(
    results_entries: list[dict[str, Any]],
    *,
    include_wot_robot_control: bool = False,
    model_labels: list[str] | None = None,
    model_display_labels: dict[str, str] | None = None,
) -> str:
    """Render the task-by-model R-precision cells as a Markdown table.

    Rows are the labels from ``_expected_r_precision_goal_rows`` followed by
    any other row label found in the data, in first-seen order. For every
    (model, row label) pair the runs returned by
    ``_runs_with_r_precision_rankings`` are pooled and formatted with
    ``_format_r_precision_cell``; pairs without such runs show ``-``.

    Args:
        results_entries: Result entries; each is read for its ``llm``,
            ``interface`` and ``goals`` keys.
        include_wot_robot_control: When false, rows labelled
            ``Robot Control (WoT)`` are left out.
        model_labels: Column order; derived from the entries when ``None``.
        model_display_labels: Passed to ``_display_model_label`` to produce
            the header text of each column.

    Returns:
        The Markdown table, or ``No R-precision results available.`` when
        filtering leaves no entries.
    """
    entries = _filtered_results_entries(results_entries, include_wot_robot_control)
    if not entries:
        return "No R-precision results available."

    columns = _model_labels(entries) if model_labels is None else model_labels
    row_order = _expected_r_precision_goal_rows(include_wot_robot_control)

    # Single pass over the data: remember row labels in first-seen order and
    # pool the ranking-bearing runs of each (model, row label) pair.
    seen_rows: list[str] = []
    ranked_runs: dict[tuple[str, str], list[dict[str, Any]]] = {}
    for entry in entries:
        model_name = str(entry.get("llm", "unknown"))
        interface = str(entry.get("interface", ""))
        for goal in entry.get("goals", []):
            label = _goal_row_label(str(goal.get("goal", "unknown")), interface)
            if label == "Robot Control (WoT)" and not include_wot_robot_control:
                continue
            if label not in seen_rows:
                seen_rows.append(label)
            ranked_runs.setdefault((model_name, label), []).extend(
                _runs_with_r_precision_rankings(goal)
            )
    row_order.extend([label for label in seen_rows if label not in row_order])

    header_row = ["Task/Model"]
    header_row.extend(
        _display_model_label(name, model_display_labels) for name in columns
    )

    body: list[list[str]] = []
    for row_name in row_order:
        line = [row_name]
        for name in columns:
            collected = ranked_runs.get((name, row_name), [])
            line.append(_format_r_precision_cell(collected) if collected else "-")
        body.append(line)
    return _markdown_table(header_row, body)


def build_results_table_without_time(
    results_entries: list[dict[str, Any]], include_wot_robot_control: bool = False
) -> str:
    """Render the task-by-model precision/recall statistics as a LaTeX table.

    Rows are the labels from ``_expected_goal_rows`` followed by any other
    row label found in the data, in first-seen order. Columns are the models,
    shown under the anonymised names from ``_model_display_labels``. Cells
    come from ``_format_cell_stats_without_time``; (model, row) pairs without
    runs show ``-``.

    Args:
        results_entries: Result entries; each is read for its ``llm``,
            ``interface`` and ``goals`` keys.
        include_wot_robot_control: When false, rows labelled
            ``Robot Control (WoT)`` are left out.

    Returns:
        The LaTeX ``table`` environment as text, or the LaTeX comment
        ``% No results available.`` when filtering leaves no entries.
    """
    entries = _filtered_results_entries(results_entries, include_wot_robot_control)
    if not entries:
        return "% No results available."

    columns = _model_labels(entries)
    display_names = _model_display_labels(columns)
    row_order = _expected_goal_rows(include_wot_robot_control)

    # Single pass over the data: remember row labels in first-seen order and
    # pool the runs of each (model, row label) pair.
    seen_rows: list[str] = []
    pooled_runs: dict[tuple[str, str], list[dict[str, Any]]] = {}
    for entry in entries:
        model_name = str(entry.get("llm", "unknown"))
        interface = str(entry.get("interface", ""))
        for goal in entry.get("goals", []):
            label = _goal_entry_row_label(goal, interface)
            if label == "Robot Control (WoT)" and not include_wot_robot_control:
                continue
            if label not in seen_rows:
                seen_rows.append(label)
            pooled_runs.setdefault((model_name, label), []).extend(goal.get("runs", []))
    row_order.extend([label for label in seen_rows if label not in row_order])

    # One left-aligned label column followed by a centred column per model.
    column_spec = "l" + "c" * len(columns)
    header_cells = [
        _latex_escape(_display_model_label(name, display_names)) for name in columns
    ]
    lines = [
        r"\begin{table}[htbp]",
        r"\centering",
        "\\begin{tabular}{" + column_spec + "}",
        "  Task/Model & " + " & ".join(header_cells) + r"\\",
    ]

    for row_name in row_order:
        rendered: list[str] = []
        for name in columns:
            collected = pooled_runs.get((name, row_name), [])
            rendered.append(
                _format_cell_stats_without_time({"runs": collected}) if collected else "-"
            )
        lines.append(
            "  "
            + _latex_escape(row_name)
            + " & "
            + " & ".join(_latex_escape(cell) for cell in rendered)
            + r"\\"
        )

    lines.append(r"\end{tabular}")
    lines.append(r"\caption{Results without time}")
    lines.append(r"\label{table:results-without-time}")
    lines.append(r"\end{table}")
    return "\n".join(lines)


def build_parameter_results_table_without_time(
    results_entries: list[dict[str, Any]],
    maximum_signifiers: int | None,
    minimum_relevance_value: float | None,
    include_wot_robot_control: bool = False,
) -> str:
    """Render the precision/recall LaTeX table for one parameter setting.

    Only goal entries accepted by ``_profile_parameters_match`` for the given
    ``maximum_signifiers`` / ``minimum_relevance_value`` contribute. Rows are
    the labels from ``_expected_filtered_goal_rows`` followed by any other
    matching row label found in the data, in first-seen order. The caption
    and the ``\\label`` embed the text produced by ``_parameter_label``.

    Args:
        results_entries: Result entries; each is read for its ``llm``,
            ``interface`` and ``goals`` keys.
        maximum_signifiers: First parameter handed to the profile match.
        minimum_relevance_value: Second parameter handed to the profile match.
        include_wot_robot_control: When false, rows labelled
            ``Robot Control (WoT)`` are left out.

    Returns:
        The LaTeX ``table`` environment as text, or the LaTeX comment
        ``% No results available.`` when filtering leaves no entries.
    """
    entries = _filtered_results_entries(results_entries, include_wot_robot_control)
    if not entries:
        return "% No results available."

    columns = _model_labels(entries)
    display_names = _model_display_labels(columns)
    row_order = _expected_filtered_goal_rows(include_wot_robot_control)

    # Single pass over the data: keep only goals run with the requested
    # parameters, remember their row labels and pool their runs per model.
    seen_rows: list[str] = []
    pooled_runs: dict[tuple[str, str], list[dict[str, Any]]] = {}
    for entry in entries:
        model_name = str(entry.get("llm", "unknown"))
        interface = str(entry.get("interface", ""))
        for goal in entry.get("goals", []):
            if not _profile_parameters_match(
                goal, maximum_signifiers, minimum_relevance_value
            ):
                continue
            label = _goal_row_label(str(goal.get("goal", "unknown")), interface)
            if label == "Robot Control (WoT)" and not include_wot_robot_control:
                continue
            if label not in seen_rows:
                seen_rows.append(label)
            pooled_runs.setdefault((model_name, label), []).extend(goal.get("runs", []))
    row_order.extend([label for label in seen_rows if label not in row_order])

    # A placeholder case carries just the two parameters so the shared label
    # helper can describe this setting.
    parameter_label = _parameter_label(
        GoalCase(
            goal="",
            relevant_tools=[],
            maximum_signifiers=maximum_signifiers,
            minimum_relevance_value=minimum_relevance_value,
        )
    )
    # Turn the label into a key for \label by replacing its separators.
    table_label = parameter_label
    for separator in ("=", ", ", "."):
        table_label = table_label.replace(separator, "-")

    # One left-aligned label column followed by a centred column per model.
    column_spec = "l" + "c" * len(columns)
    header_cells = [
        _latex_escape(_display_model_label(name, display_names)) for name in columns
    ]
    lines = [
        r"\begin{table}[htbp]",
        r"\centering",
        "\\begin{tabular}{" + column_spec + "}",
        "  Task/Model & " + " & ".join(header_cells) + r"\\",
    ]

    for row_name in row_order:
        rendered: list[str] = []
        for name in columns:
            collected = pooled_runs.get((name, row_name), [])
            rendered.append(
                _format_cell_stats_without_time({"runs": collected}) if collected else "-"
            )
        lines.append(
            "  "
            + _latex_escape(row_name)
            + " & "
            + " & ".join(_latex_escape(cell) for cell in rendered)
            + r"\\"
        )

    lines.append(r"\end{tabular}")
    lines.append(f"\\caption{{Results without time ({_latex_escape(parameter_label)})}}")
    lines.append(f"\\label{{table:results-without-time-{table_label}}}")
    lines.append(r"\end{table}")
    return "\n".join(lines)


def build_analysis(
    payload: dict[str, Any],
    *,
    include_wot_robot_control: bool = False,
    include_json: bool = False,
) -> str:
    """Assemble the plain-text analysis made of LaTeX tables.

    The document contains, in order: an optional runtime section, the overall
    statistics table, one statistics table per combination of
    ``MAXIMUM_SIGNIFIERS_VALUES`` and ``MINIMUM_RELEVANCE_VALUES``, the same
    tables again without time, and the R-precision table. Sections are
    separated by blank lines and each starts with a ``# `` heading.

    Args:
        payload: Loaded results document. ``payload["results"]`` is required;
            ``total_elapsed_seconds`` is used only when it is an int or float.
        include_wot_robot_control: Forwarded to the filtering step and to
            every table builder.
        include_json: When true, the filtered payload is prepended as
            indented, key-sorted JSON.

    Returns:
        The analysis text, terminated by a newline.
    """
    entries = _filtered_results_entries(payload["results"], include_wot_robot_control)
    labels = _model_labels(entries)
    column_kwargs: dict[str, Any] = {
        "include_wot_robot_control": include_wot_robot_control,
        "model_labels": labels,
        "model_display_labels": _model_display_labels(labels),
    }

    def per_parameter_tables(render: Any) -> list[tuple[str, str]]:
        # Pair the label of every parameter combination with its table.
        return [
            (
                _parameter_label(
                    GoalCase(
                        goal="",
                        relevant_tools=[],
                        maximum_signifiers=signifiers,
                        minimum_relevance_value=relevance,
                    )
                ),
                render(signifiers, relevance),
            )
            for signifiers in MAXIMUM_SIGNIFIERS_VALUES
            for relevance in MINIMUM_RELEVANCE_VALUES
        ]

    overall_table = build_results_table(entries, **column_kwargs)
    parameter_tables = per_parameter_tables(
        lambda signifiers, relevance: build_parameter_results_table(
            entries, signifiers, relevance, **column_kwargs
        )
    )
    overall_table_without_time = build_results_table_without_time(
        entries,
        include_wot_robot_control=include_wot_robot_control,
    )
    parameter_tables_without_time = per_parameter_tables(
        lambda signifiers, relevance: build_parameter_results_table_without_time(
            entries,
            signifiers,
            relevance,
            include_wot_robot_control=include_wot_robot_control,
        )
    )
    r_precision_table = build_r_precision_table(entries, **column_kwargs)

    json_payload: dict[str, Any] = {"results": entries}
    sections: list[str] = []

    # The runtime section exists only when the payload recorded a numeric total.
    elapsed = payload.get("total_elapsed_seconds")
    if isinstance(elapsed, (int, float)):
        json_payload["total_elapsed_seconds"] = elapsed
        sections.append(
            "# Evaluation Runtime\n\n"
            f"Total elapsed time: {_format_duration(float(elapsed))}"
            f" ({float(elapsed):.2f}s)\n"
        )

    sections.append("# Task x Model Statistics (LaTeX)\n\n" + overall_table)
    for label, parameter_table in parameter_tables:
        sections.append(
            f"# Task x Model Statistics ({_latex_escape(label)}) (LaTeX)\n\n"
            + parameter_table
        )
    sections.append(
        "# Task x Model Statistics Without Time (LaTeX)\n\n" + overall_table_without_time
    )
    for label, parameter_table in parameter_tables_without_time:
        sections.append(
            f"# Task x Model Statistics Without Time ({_latex_escape(label)}) (LaTeX)\n\n"
            + parameter_table
        )
    sections.append("# Task x Model R-Precision Bounds (LaTeX)\n\n" + r_precision_table)

    analysis = "\n\n".join(sections)
    if not include_json:
        return analysis + "\n"
    return json.dumps(json_payload, indent=2, sort_keys=True) + "\n\n" + analysis + "\n"


def build_analysis_markdown(
    payload: dict[str, Any],
    *,
    include_wot_robot_control: bool = False,
    include_json: bool = False,
) -> str:
    """Assemble the Markdown analysis.

    The document contains, in order: an optional runtime section, the overall
    statistics table, one statistics table per combination of
    ``MAXIMUM_SIGNIFIERS_VALUES`` and ``MINIMUM_RELEVANCE_VALUES``, the same
    tables again without time, and the R-precision table. Sections are
    separated by blank lines and each starts with a ``# `` heading.

    Args:
        payload: Loaded results document. ``payload["results"]`` is required;
            ``total_elapsed_seconds`` is used only when it is an int or float.
        include_wot_robot_control: Forwarded to the filtering step and to
            every table builder.
        include_json: When true, the filtered payload is prepended as a
            fenced ``json`` code block.

    Returns:
        The Markdown text, terminated by a newline.
    """
    entries = _filtered_results_entries(payload["results"], include_wot_robot_control)
    labels = _model_labels(entries)
    column_kwargs: dict[str, Any] = {
        "include_wot_robot_control": include_wot_robot_control,
        "model_labels": labels,
        "model_display_labels": _model_display_labels(labels),
    }

    def per_parameter_tables(include_time: bool) -> list[tuple[str, str]]:
        # Pair the label of every parameter combination with its table.
        return [
            (
                _parameter_label(
                    GoalCase(
                        goal="",
                        relevant_tools=[],
                        maximum_signifiers=signifiers,
                        minimum_relevance_value=relevance,
                    )
                ),
                build_parameter_results_markdown_table(
                    entries,
                    signifiers,
                    relevance,
                    include_time=include_time,
                    **column_kwargs,
                ),
            )
            for signifiers in MAXIMUM_SIGNIFIERS_VALUES
            for relevance in MINIMUM_RELEVANCE_VALUES
        ]

    overall_table = build_results_markdown_table(entries, **column_kwargs)
    parameter_tables = per_parameter_tables(True)
    overall_table_without_time = build_results_markdown_table(
        entries, include_time=False, **column_kwargs
    )
    parameter_tables_without_time = per_parameter_tables(False)
    r_precision_table = build_r_precision_markdown_table(entries, **column_kwargs)

    json_payload: dict[str, Any] = {"results": entries}
    sections: list[str] = []

    # The runtime section exists only when the payload recorded a numeric total.
    elapsed = payload.get("total_elapsed_seconds")
    if isinstance(elapsed, (int, float)):
        json_payload["total_elapsed_seconds"] = elapsed
        sections.append(
            "# Evaluation Runtime\n\n"
            f"Total elapsed time: {_format_duration(float(elapsed))}"
            f" ({float(elapsed):.2f}s)\n"
        )

    sections.append("# Task x Model Statistics\n\n" + overall_table)
    for label, parameter_table in parameter_tables:
        sections.append(f"# Task x Model Statistics ({label})\n\n" + parameter_table)
    sections.append(
        "# Task x Model Statistics Without Time\n\n" + overall_table_without_time
    )
    for label, parameter_table in parameter_tables_without_time:
        sections.append(
            f"# Task x Model Statistics Without Time ({label})\n\n" + parameter_table
        )
    sections.append("# Task x Model R-Precision Bounds\n\n" + r_precision_table)

    analysis = "\n\n".join(sections)
    if not include_json:
        return analysis + "\n"
    fenced_json = (
        "```json\n" + json.dumps(json_payload, indent=2, sort_keys=True) + "\n```"
    )
    return fenced_json + "\n\n" + analysis + "\n"


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the analysis script.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, so existing callers are
            unaffected; passing a list lets the parser be driven
            programmatically.

    Returns:
        Namespace with ``input``, ``output``, ``markdown_output`` (``None``
        unless given), ``include_wot_robot_control`` and ``include_json``.
    """
    parser = argparse.ArgumentParser(
        description="Build the evaluation2.py analysis tables from a JSON-only results file."
    )
    parser.add_argument(
        "input",
        nargs="?",
        type=Path,
        default=DEFAULT_INPUT_PATH,
        help=f"JSON results file to analyze. Defaults to {DEFAULT_INPUT_PATH.name}.",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        default=DEFAULT_OUTPUT_PATH,
        help=f"Output file for the generated analysis. Defaults to {DEFAULT_OUTPUT_PATH.name}.",
    )
    # Left as None so the caller can derive the Markdown path from --output.
    parser.add_argument(
        "--markdown-output",
        type=Path,
        default=None,
        help=(
            "Output file for the generated Markdown analysis. "
            f"Defaults to {DEFAULT_MARKDOWN_OUTPUT_PATH.name} when using the default text output, "
            "or the text output path with a .md suffix when --output is set."
        ),
    )
    parser.add_argument(
        "--include-wot-robot-control",
        action="store_true",
        help="Include WoT robot-control rows that evaluation2.py normally hides unless enabled.",
    )
    parser.add_argument(
        "--include-json",
        action="store_true",
        help="Prepend the filtered JSON payload, matching the combined structure of results2.txt.",
    )
    return parser.parse_args(argv)


def main() -> int:
    """Generate the text and Markdown analyses named on the command line.

    The Markdown path is ``--markdown-output`` when given; otherwise it is
    the default Markdown path when the text output is the default one, or the
    text output path with a ``.md`` suffix.

    Returns:
        ``0`` after both files have been written.

    Raises:
        SystemExit: If the text and Markdown outputs resolve to the same
            file. The message is printed to stderr and the exit status is 1.
    """
    args = parse_args()

    if args.markdown_output is not None:
        markdown_output_path = args.markdown_output
    elif args.output == DEFAULT_OUTPUT_PATH:
        markdown_output_path = DEFAULT_MARKDOWN_OUTPUT_PATH
    else:
        markdown_output_path = args.output.with_suffix(".md")

    # With "-o report.md" the derived Markdown path is report.md again, so the
    # second write would silently replace the text analysis. Refuse up front,
    # before any work is done or any file is touched.
    if markdown_output_path.resolve() == args.output.resolve():
        raise SystemExit(
            f"error: the text and Markdown outputs both resolve to {args.output}; "
            "pass a distinct path with --markdown-output."
        )

    payload = load_results_json(args.input)
    output = build_analysis(
        payload,
        include_wot_robot_control=args.include_wot_robot_control,
        include_json=args.include_json,
    )
    markdown_output = build_analysis_markdown(
        payload,
        include_wot_robot_control=args.include_wot_robot_control,
        include_json=args.include_json,
    )

    # Write as UTF-8 explicitly so labels with non-ASCII characters do not
    # depend on the locale's default encoding.
    args.output.write_text(output, encoding="utf-8")
    markdown_output_path.write_text(markdown_output, encoding="utf-8")
    print(f"Wrote analysis to {args.output}")
    print(f"Wrote Markdown analysis to {markdown_output_path}")
    return 0


# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
