import asyncio
import typer
import json
import os
import os
from datetime import datetime
from pathlib import Path
from dotenv import load_dotenv
import re
import yaml
from services.factory import create_agent


def extract_html_tag(key: str, text: str | None) -> str | None:
    """Extracts content enclosed within specific HTML-like tags."""
    if text is None:
        return None
    pattern = f"<{key}>(.*?)</{key}>"
    matches = re.findall(pattern, text, re.DOTALL)
    if matches:
        return [match.strip() for match in matches][0]


def extract_evaluation_details(evaluation_text: str) -> dict:
    """Parse a tagged evaluation transcript into a structured dictionary.

    The result always holds ``raw_evaluation`` (the input text) and
    ``assessments``. For each of ``<novelty_assessment>``,
    ``<technical_assessment>`` and ``<completeness_assessment>`` that is present
    and non-empty, ``assessments[tag]`` records the full tag content, the text
    after ``Justification:`` (up to ``Score:``) and the first line after
    ``Score:``; either of the last two is ``None`` when not found.

    Optional top-level keys:
      * ``overall_justification`` - from ``<justification>``, replaced by the
        ``Justification:`` text inside ``<final_judgement>`` when present;
      * ``overall_score`` - from ``<score>``;
      * ``final_judgement`` - raw ``<final_judgement>`` content;
      * ``final_decision`` - text after ``Judgement:`` in ``<final_judgement>``.
    """
    flags = re.IGNORECASE | re.DOTALL
    assessments = {}
    details = {"raw_evaluation": evaluation_text, "assessments": assessments}

    for tag in (
        "novelty_assessment",
        "technical_assessment",
        "completeness_assessment",
    ):
        body = extract_html_tag(tag, evaluation_text)
        if not body:
            continue
        just_m = re.search(r"Justification:\s*(.*?)(?:\s*Score:|$)", body, flags)
        score_m = re.search(r"Score:\s*(.*)", body, flags)
        assessments[tag] = {
            "full_content": body,
            "justification": just_m.group(1).strip() if just_m else None,
            # Only the first line following "Score:" is kept.
            "score": score_m.group(1).split("\n")[0].strip() if score_m else None,
        }

    for out_key, tag in (
        ("overall_justification", "justification"),
        ("overall_score", "score"),
    ):
        value = extract_html_tag(tag, evaluation_text)
        if value:
            details[out_key] = value

    verdict = extract_html_tag("final_judgement", evaluation_text)
    if not verdict:
        return details

    details["final_judgement"] = verdict
    just_m = re.search(r"Justification:\s*(.*?)(?:\s*Judgement:|$)", verdict, flags)
    decision_m = re.search(r"Judgement:\s*(.*)", verdict, flags)
    if just_m:
        details["overall_justification"] = just_m.group(1).strip()
    if decision_m:
        details["final_decision"] = decision_m.group(1).strip()
    return details


def _parse_llm_judge_json(
    raw_text: str | None, expected_key: str, relevance_threshold: int = 4
) -> dict | None:
    """Parse standardized LLM-as-a-Judge JSON and derive binary from relevanceScore.

    Returns a dict with keys: relevanceScore, confidenceLevel, summaryStatement, binary
    or None if parsing fails or expected section missing.
    """
    # some models output in this format
    match = re.search(r"```json\s*(\{.*?\})\s*```", raw_text, re.DOTALL)
    if match:
        raw_text = match.group(1)
    elif not match:
        # try another pattern
        match = re.search(r"\s*(\{.*?\}\s*\})", raw_text, re.DOTALL)
        if match:
            raw_text = match.group(1)
    data = json.loads(raw_text)
    section = data.get(expected_key, {}) or {}
    relevance = section.get("relevanceScore")
    confidence = section.get("confidenceLevel")
    summary = section.get("summaryStatement")
    binary = "YES" if float(relevance) >= relevance_threshold else "NO"
    return {
        "relevanceScore": relevance,
        "confidenceLevel": confidence,
        "summaryStatement": summary,
        "binary": binary,
    }


def load_config():
    """Load configuration from config.yaml"""
    config_path = Path("configs/sol_evaluator.yaml")
    if config_path.exists():
        with open(config_path, "r") as f:
            return yaml.safe_load(f)
    return {}


def find_solution_dirs(config):
    """Return the solution directories exactly three levels below the dataset base path.

    Reads ``config["dataset"]["base_path"]`` and collects every directory
    matching ``<base_path>/*/*/*``. Results are resolved to absolute paths and
    sorted, so processing order does not depend on filesystem listing order.

    Returns:
        A sorted list of absolute directory paths. Returns an empty list after
        printing an error when the base path is missing or not a directory.
    """
    base_path = Path(config["dataset"]["base_path"])
    # is_dir() is False for non-existent paths too, covering both cases.
    if not base_path.is_dir():
        typer.secho(f"Base dataset path not found: {base_path}", fg=typer.colors.RED)
        return []

    solution_dirs = sorted(
        p.resolve() for p in base_path.glob("*/*/*") if p.is_dir()
    )

    typer.secho(
        f"Found {len(solution_dirs)} nested solution directories to evaluate in {base_path}",
        fg=typer.colors.BLUE,
    )
    return solution_dirs


def load_papers(config, papers_path: Path, generalizer_base_path: Path):
    """Load solver outputs under *papers_path* and attach the matching generalizer logs.

    Every ``*.json`` file below *papers_path*, sorted by file name, is read as a
    solver log with ``metadata`` and ``custom_log`` sections. The generalizer
    log is looked up at
    ``<generalizer_base_path>/external_model=<X>/<same file name>``, where
    ``<X>`` is the path segment after ``ps_from=`` with ``--`` replaced by
    the OS path separator.

    A paper is skipped, and its id recorded, when any of these hold:
      * the solver's ``final_solution`` is empty or missing;
      * the generalizer's ``final_problem_statement`` is empty, including
        when its log file is missing;
      * any error occurs while loading it.

    Args:
        config: Parsed config; ``config["dataset"]["max_papers"]`` (optional)
            limits how many papers are returned.
        papers_path: Directory searched recursively for solver JSON files.
        generalizer_base_path: Root of the generalizer output tree.

    Returns:
        ``(papers, paper_paths, skipped_paper_ids)``, where ``paper_paths[i]``
        is the solver file that ``papers[i]`` was loaded from.
    """
    max_papers = config.get("dataset", {}).get("max_papers")

    papers = []
    paper_paths = []
    skipped_paper_ids = []
    for paper_path in sorted(papers_path.rglob("*.json"), key=lambda p: p.name):
        try:
            with open(paper_path, "r", encoding="utf-8") as f:
                full_data = json.load(f)
            paper_data = full_data["metadata"]
            solver_log = full_data.get("custom_log", {})
            paper_data["solver_workflow_log"] = solver_log

            # Prefer custom_log.summary.final_solution; fall back to custom_log.final_solution.
            final_solution = solver_log.get("summary", {}).get("final_solution")
            if final_solution is None:
                final_solution = solver_log.get("final_solution")
            if not final_solution:
                skipped_paper_ids.append(paper_data["id"])
                typer.secho(
                    f"⚠️ Skipping paper {paper_data['id']}: Solver's final_solution is empty or missing.",
                    fg=typer.colors.YELLOW,
                )
                continue  # Skip this paper and move to the next
            paper_data["final_solution"] = final_solution

            # as_posix() so the "[^/]+" segment match also works with Windows separators.
            ps_from = re.search(r"ps_from=([^/]+)", paper_path.as_posix())
            if ps_from is None:
                raise ValueError(f"no 'ps_from=' segment in path {paper_path}")
            relative_model_dir = "external_model=" + ps_from[1].replace("--", os.sep)
            generalizer_log_file = (
                generalizer_base_path / relative_model_dir / paper_path.name
            )

            generalizer_problem_statement = ""
            generalizer_workflow = {}
            if generalizer_log_file.exists():
                with open(generalizer_log_file, "r", encoding="utf-8") as gf:
                    generalizer_data = json.load(gf)
                generalizer_workflow = generalizer_data.get("custom_log", {})
                generalizer_problem_statement = generalizer_workflow.get(
                    "summary", {}
                ).get("final_problem_statement", "")
            else:
                typer.secho(
                    f"⚠️ Generalizer log file not found for {paper_data['id']}: {generalizer_log_file}",
                    fg=typer.colors.YELLOW,
                )
            paper_data["final_problem_statement"] = generalizer_problem_statement
            # Store generalizer's full workflow log
            paper_data["generalizer_workflow_log"] = generalizer_workflow

            if not generalizer_problem_statement:
                skipped_paper_ids.append(paper_data["id"])
                typer.secho(
                    f"⚠️ Skipping paper {paper_data['id']}: Generalizer's final_problem_statement is empty or missing.",
                    fg=typer.colors.YELLOW,
                )
                continue  # Skip this paper and move to the next

            # paper_data is full_data["metadata"]; default the field to "" if absent.
            paper_data.setdefault("current_problem_statement", "")

            # Record the path only once the paper is accepted, so that
            # paper_paths[i] always corresponds to papers[i].
            papers.append(paper_data)
            paper_paths.append(paper_path)
        except Exception as e:
            # Deliberately best-effort: one unreadable paper must not stop the batch.
            paper_id_from_path = paper_path.stem.replace("paper_", "")
            skipped_paper_ids.append(paper_id_from_path)
            typer.secho(
                f"⚠️ Skipping paper {paper_id_from_path} due to error loading: {e}",
                fg=typer.colors.YELLOW,
            )
            continue  # Skip this paper and move to the next

    if max_papers:
        papers = papers[:max_papers]
        paper_paths = paper_paths[:max_papers]

    typer.secho(
        f"  Loaded {len(papers)} papers from {Path(*papers_path.parts[-2:])}",
        fg=typer.colors.GREEN,
    )

    return papers, paper_paths, skipped_paper_ids


async def _async_process_paper(
    paper: dict,
    index: int,
    total_papers: int,
    semaphore: asyncio.Semaphore,
    config: dict,
    solution_path: Path,
    base_path: Path,
    relevance_threshold: int = 4,
):
    """Evaluate one paper while holding *semaphore* and write the result JSON.

    The output path mirrors *solution_path* relative to *base_path* under
    ``<results_dir>/solver/monolithic/evaluator_model=<model>/solution_model=...``.
    Its final component must be ``paper_<id>.json``.

    Errors raised by the evaluation itself are caught and printed. The log file
    is written in every case, with an empty ``custom_log`` on failure.

    Args:
        paper: Paper record from ``load_papers``.
        index: Zero-based position of the paper, used for progress output.
        total_papers: Number of papers in the batch, used for progress output.
        semaphore: Limits how many papers are evaluated concurrently.
        config: Parsed config; reads ``model`` and ``output.results_dir``.
        solution_path: Solver output file this paper was loaded from.
        base_path: Dataset base path that *solution_path* is relative to.
        relevance_threshold: Forwarded to the judge evaluation.

    Raises:
        ValueError: if the mirrored log file name is not ``paper_<id>.json``.
            Writing it would otherwise clobber or misfile another paper's results.
    """
    async with semaphore:
        model_name_full = config["model"]["name"]
        if "--" in model_name_full:
            # "provider--model": split once so model names that contain "--" survive.
            model_provider, model_name = model_name_full.split("--", 1)
        else:
            model_provider = config["model"].get("provider", "unknown")
            model_name = model_name_full

        # A single trailing digit is dropped from the provider for the directory
        # name (presumably to merge numbered provider aliases); [-1:] avoids an
        # IndexError for an empty provider string.
        provider_dir = (
            model_provider[:-1] if model_provider[-1:].isdigit() else model_provider
        )
        safe_eval_model = provider_dir + "__" + model_name.replace("/", "_")

        safe_sol_model = str(solution_path.relative_to(base_path))
        log_file = (
            Path(config["output"]["results_dir"])
            / "solver"
            / "monolithic"
            / f"evaluator_model={safe_eval_model}"
            / f"solution_model={safe_sol_model}"
        )
        expected_name = f"paper_{paper['id']}.json"
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if log_file.name != expected_name:
            raise ValueError(
                f"Log file name {log_file.name!r} does not match expected "
                f"{expected_name!r} for solution file {solution_path}"
            )
        log_file.parent.mkdir(parents=True, exist_ok=True)

        title = paper.get("title") or "Unknown"
        typer.secho(
            f"    ➡️ ({safe_sol_model}) Processing paper {index+1}/{total_papers}: {title[:40]}...",
            fg=typer.colors.CYAN,
        )
        custom_log = {}
        try:
            custom_log = await _async_main(
                initial_input=paper["abstract"],
                problem_statement=paper["current_problem_statement"],
                solution=paper["final_solution"],
                original_abstract=paper["abstract"],
                generalizer_problem_statement=paper.get("final_problem_statement", ""),
                # Pass the raw workflow logs for statistical analysis
                solver_workflow_log=paper.get("solver_workflow_log", {}),
                generalizer_workflow_log=paper.get("generalizer_workflow_log", {}),
                model_name=model_name,
                model_provider=model_provider,
                workflow_log={},  # This is for internal tracking of the *evaluator's* own workflow
                paper_id=paper["id"],
                relevance_threshold=relevance_threshold,
            )
        except Exception as e:
            typer.secho(
                f"\n    ❌ ({safe_sol_model}) Workflow for paper {index+1} halted: {e}",
                fg=typer.colors.RED,
            )
        finally:
            # Always persist a result file (custom_log is {} when evaluation failed).
            paper_data = {"metadata": paper, "workflow": {}, "custom_log": custom_log}
            with open(log_file, "w", encoding="utf-8") as f:
                json.dump(paper_data, f, indent=2)


async def run_async_workflow(relevance_threshold: int = 4, max_concurrency: int = 1000):
    """Set up and run asynchronous evaluation for all nested solution directories.

    Papers from every directory returned by ``find_solution_dirs`` are queued
    as ``_async_process_paper`` tasks and run concurrently, with at most
    *max_concurrency* in flight at once. A task that raises is reported, and
    the remaining tasks keep running.

    Args:
        relevance_threshold: Threshold forwarded to each paper's evaluation.
        max_concurrency: Size of the semaphore bounding concurrent evaluations.

    Returns:
        ``(run_start_ts, skipped_paper_ids)``. ``run_start_ts`` is the POSIX
        timestamp taken just before the tasks start. Returns ``(None, None)``
        when no solution directories were found.
    """
    config = load_config()
    solution_dirs = find_solution_dirs(config)
    if not solution_dirs:
        typer.secho(
            "No solution directories found to process. Exiting.", fg=typer.colors.RED
        )
        return None, None

    # Resolve base paths once to ensure consistency for Path.relative_to()
    base_path = Path(config["dataset"]["base_path"]).resolve()
    generalizer_base_path = Path(config["dataset"]["generalizer_base_path"]).resolve()
    semaphore = asyncio.Semaphore(max_concurrency)  # Limit concurrent tasks
    tasks = []
    all_skipped_paper_ids = []

    for solution_dir in solution_dirs:
        typer.secho(
            f"\n--- Queuing evaluation for: {Path(*solution_dir.parts[-2:])} ---",
            fg=typer.colors.MAGENTA,
            bold=True,
        )
        # Load papers, including captured workflow logs from solver and generalizer
        papers, paper_paths, skipped_paper_ids = load_papers(
            config, solution_dir, generalizer_base_path
        )
        all_skipped_paper_ids.extend(skipped_paper_ids)
        if not papers:
            typer.secho(
                f"No papers found in {solution_dir.name}. Skipping.",
                fg=typer.colors.YELLOW,
            )
            continue

        total_papers = len(papers)
        tasks.extend(
            _async_process_paper(
                paper,
                i,
                total_papers,
                semaphore,
                config,
                paper_path,
                base_path,
                relevance_threshold,
            )
            for i, (paper, paper_path) in enumerate(zip(papers, paper_paths))
        )

    typer.secho(
        "\n🚀 Starting parallel processing for all queued tasks...",
        fg=typer.colors.CYAN,
        bold=True,
    )
    run_start_ts = datetime.now().timestamp()
    # return_exceptions=True: one paper's task failing (e.g. a log-path mismatch)
    # must not abort the run and cancel every other in-flight evaluation.
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for outcome in results:
        if isinstance(outcome, BaseException):
            typer.secho(
                f"\n❌ A paper evaluation task failed: {outcome!r}",
                fg=typer.colors.RED,
            )
    typer.secho(
        "\n🎉 All directories have been processed.", fg=typer.colors.CYAN, bold=True
    )
    return run_start_ts, all_skipped_paper_ids


def calculate_and_report_metrics(
    modified_after_ts: float | None = None,
    skipped_paper_ids: list[str] | None = None,
    relevance_threshold: int = 4,
):
    """Aggregate LLM-as-a-Judge decisions from evaluator logs, print and save them.

    Scans ``<results_dir>/solver/monolithic/evaluator_model=*`` for
    ``paper_*.json`` logs. For each (evaluator model, solution model) pair it
    counts papers judged YES for rediscovery, sr_solver, sr_baseline and
    sr_generalizer. It also counts "novel and valid" papers: rediscovery NO and
    sr_solver YES.

    Each decision is recomputed from the stored ``relevanceScore`` against
    *relevance_threshold* (YES iff score >= threshold), so existing logs can be
    re-scored at other thresholds. The stored ``binary`` is used only when no
    numeric score is available.

    Args:
        modified_after_ts: If given, only log files whose mtime is >= this POSIX
            timestamp are counted.
        skipped_paper_ids: Paper ids written to ``skipped_papers*.txt``.
        relevance_threshold: Decision threshold; also selects the output file names.

    Returns:
        ``{model_key: {metric_name: "xx.xx% (count/total)"}}``. The same data is
        written to ``evaluation_results.json`` (threshold 4) or
        ``evaluation_results_th<T>.json``.
    """
    config = load_config()
    results_dir = Path(config["output"]["results_dir"])
    judge_names = ("rediscovery", "sr_solver", "sr_baseline", "sr_generalizer")

    def _decision(entry) -> bool | None:
        """Map one judge entry to True (YES), False (NO) or None (undecidable)."""
        if not isinstance(entry, dict):
            return None
        try:
            return float(entry.get("relevanceScore")) >= relevance_threshold
        except (TypeError, ValueError):
            return {"YES": True, "NO": False}.get(entry.get("binary"))

    # model_key -> per-judge YES counts plus "novel_and_valid" and "total"
    counts: dict[str, dict[str, int]] = {}
    for eval_model_dir in sorted(results_dir.glob("solver/monolithic/evaluator_model=*")):
        eval_model_name = eval_model_dir.name.split("=", 1)[1]
        solution_model_dirs = sorted({p.parent for p in eval_model_dir.rglob("*.json")})
        for solution_model_dir in solution_model_dirs:
            solution_model_name = (
                str(solution_model_dir.resolve())
                .split("solution_model=", 1)[1]
                .replace(os.sep, "<>")
            )
            model_key = f"{eval_model_name}__eval__{solution_model_name}"
            tally = counts[model_key] = dict.fromkeys(
                (*judge_names, "novel_and_valid", "total"), 0
            )

            for log_file in solution_model_dir.glob("paper_*.json"):
                try:
                    if (
                        modified_after_ts is not None
                        and log_file.stat().st_mtime < modified_after_ts
                    ):
                        continue
                    with open(log_file, "r", encoding="utf-8") as f:
                        data = json.load(f)
                    evaluation = data.get("custom_log", {}).get("evaluation", {})
                    decisions = {
                        name: _decision(evaluation.get(name)) for name in judge_names
                    }
                except Exception as e:
                    typer.secho(
                        f"⚠️ Error processing log file {log_file}: {e}",
                        fg=typer.colors.YELLOW,
                    )
                    continue

                # Counts are updated only after the whole file parsed successfully.
                for name in judge_names:
                    if decisions[name]:
                        tally[name] += 1
                if decisions["rediscovery"] is False and decisions["sr_solver"] is True:
                    tally["novel_and_valid"] += 1
                tally["total"] += 1

    typer.secho("\n--- Evaluation Results ---", fg=typer.colors.GREEN, bold=True)

    # (console label, JSON key, tally key)
    metric_rows = (
        ("Rediscovery (R)", "Rediscovery (R)", "rediscovery"),
        ("SR_solver", "SR_solver", "sr_solver"),
        ("Valid_solver", "NV_solver", "novel_and_valid"),
        ("SR_baseline", "SR_baseline", "sr_baseline"),
        ("SR_generalizer", "SR_generalizer", "sr_generalizer"),
    )

    all_metrics = {}
    for model_key, tally in counts.items():
        total = tally["total"]
        if total == 0:
            typer.secho(f"No evaluations for {model_key}", fg=typer.colors.YELLOW)
            continue

        typer.secho(
            f"\nMetrics for: {model_key} (Threshold: {relevance_threshold})",
            fg=typer.colors.BLUE,
        )
        model_metrics = {}
        for label, json_key, tally_key in metric_rows:
            count = tally[tally_key]
            summary = f"{(count / total) * 100:.2f}% ({count}/{total})"
            typer.secho(f"  {label}: {summary}", fg=typer.colors.GREEN)
            model_metrics[json_key] = summary
        all_metrics[model_key] = model_metrics

    # Write all_metrics to a JSON file for easy parsing and storage
    output_file_path = (
        results_dir / "evaluation_results.json"
        if relevance_threshold == 4
        else results_dir / f"evaluation_results_th{relevance_threshold}.json"
    )
    with open(output_file_path, "w", encoding="utf-8") as f:
        json.dump(all_metrics, f, indent=2)
    typer.secho(
        f"\nAll evaluation metrics saved to {output_file_path}",
        fg=typer.colors.GREEN,
        bold=True,
    )

    # Write skipped paper IDs to a file if any
    if skipped_paper_ids:
        skipped_file_path = (
            results_dir / "skipped_papers.txt"
            if relevance_threshold == 4
            else results_dir / f"skipped_papers_th{relevance_threshold}.txt"
        )
        with open(skipped_file_path, "w", encoding="utf-8") as f:
            for paper_id in skipped_paper_ids:
                f.write(f"{paper_id}\n")
        typer.secho(
            f"Skipped paper IDs saved to {skipped_file_path}",
            fg=typer.colors.YELLOW,
            bold=True,
        )
    return all_metrics  # Return the metrics for aggregation by analyze_thresholds


_JUDGE_MAX_ATTEMPTS = 5  # attempts per judge agent before giving up on empty output


async def _run_judge_agent(
    agent_type: str,
    label: str,
    context: dict,
    model_name: str,
    model_provider: str,
    paper_id: str,
):
    """Create one judge agent and run it until it returns non-empty output.

    The output is read from ``"<agent_type>.measure.output"`` in the context the
    agent returns. Exceptions raised by the agent itself are not retried.

    Raises:
        ValueError: if every one of ``_JUDGE_MAX_ATTEMPTS`` attempts yields empty output.
    """
    agent = create_agent(
        agent_type=agent_type,
        model_name=model_name,
        model_provider=model_provider,
    )
    output_key = f"{agent_type}.measure.output"
    for _ in range(_JUDGE_MAX_ATTEMPTS):
        agent_context = await agent.run(workflow_log={}, **context)
        output = agent_context.get(output_key)
        if output:
            return output
    raise ValueError(
        f"{label} agent failed to produce output after retries for paper {paper_id}."
    )


async def _async_main(
    initial_input: str,
    problem_statement: str,
    solution: str,
    original_abstract: str,
    generalizer_problem_statement: str,
    solver_workflow_log: dict,  # Raw workflow log from the solver output
    generalizer_workflow_log: dict,  # Raw workflow log from the generalizer output
    model_name: str,
    model_provider: str,
    workflow_log: dict,  # Accepted for interface compatibility; not used
    paper_id: str,
    relevance_threshold: int = 4,
):
    """Run all LLM-as-a-Judge agents for one paper and collect their outputs.

    The five judges run sequentially in this order: solution_equivalence,
    rediscovery, sr_solver, sr_baseline, sr_generalizer. Each is retried up to
    ``_JUDGE_MAX_ATTEMPTS`` times on empty output. Raw outputs are stored under
    ``"<judge>_output"`` and parsed decisions under
    ``custom_log["evaluation"][<judge>]``. Solution-equivalence scores are on a
    0-10 scale, so its threshold is ``2 * relevance_threshold``.

    Returns:
        The ``custom_log`` dict containing the inputs, raw judge outputs and
        the ``evaluation`` summary.

    Raises:
        ValueError: if the generalizer problem statement or the solution is
            missing, a judge produces no output after all attempts, or a
            judge's JSON output cannot be parsed.
    """
    load_dotenv()

    # SR_solver needs both inputs; check them before spending any LLM calls.
    if not generalizer_problem_statement:
        raise ValueError(
            f"SR_solver: Missing generalizer_output (Problem Statement) for paper {paper_id}"
        )
    if not solution:
        raise ValueError(f"SR_solver: Missing solution for paper {paper_id}")

    custom_log = {
        "original_abstract": original_abstract,
        "problem_statement": problem_statement,
        "model_solution": solution,
        "raw_solver_workflow_log": solver_workflow_log,
        "raw_generalizer_workflow_log": generalizer_workflow_log,
    }
    context = {
        "initial_input": initial_input,
        "problem_statement": problem_statement,
        "solution": solution,
        "original_abstract": original_abstract,
        "generalizer_output": generalizer_problem_statement,
    }

    # Each row: agent type (also the log/evaluation key), label used in messages,
    # JSON section key, decision threshold, whether to echo the raw output.
    judges = (
        ("solution_equivalence", "Solution Equivalence", "solution_equivalence",
         relevance_threshold * 2, False),  # outputs in range 0 - 10
        ("rediscovery", "Rediscovery", "solution_rediscovery",
         relevance_threshold, True),
        ("sr_solver", "SR_solver", "problem_solution_alignment",
         relevance_threshold, True),
        ("sr_baseline", "SR_baseline", "problem_generalization_quality",
         relevance_threshold, False),
        ("sr_generalizer", "SR_generalizer", "problem_generalization_quality",
         relevance_threshold, False),
    )

    # --- LLM-as-a-Judge Agent Calls ---
    for agent_type, label, _section, _threshold, echo in judges:
        output = await _run_judge_agent(
            agent_type, label, context, model_name, model_provider, paper_id
        )
        custom_log[f"{agent_type}_output"] = output
        if echo:
            typer.secho(
                f"  {label} Raw Output:\n{output[:500]}...", fg=typer.colors.BLUE
            )

    # All judges emit standardized JSON; derive binary decisions from relevanceScore.
    evaluation = {}
    for agent_type, label, section, threshold, _echo in judges:
        raw_output = custom_log[f"{agent_type}_output"]
        parsed = _parse_llm_judge_json(raw_output, section, threshold)
        if not parsed:
            raise ValueError(
                f"Failed to parse {label} JSON output for paper {paper_id}. Raw output: {raw_output}"
            )
        evaluation[agent_type] = {
            "justification": parsed.get("summaryStatement"),
            "binary": parsed["binary"],
            "relevanceScore": parsed.get("relevanceScore"),
            "confidenceLevel": parsed.get("confidenceLevel"),
            "summaryStatement": parsed.get("summaryStatement"),
        }

    custom_log["evaluation"] = evaluation
    return custom_log


# Typer CLI application; the commands defined below register on it via @app.command(...).
app = typer.Typer()


@app.command("evaluate")
def evaluate_command(
    relevance_threshold: int = typer.Option(
        4,
        "--relevance-threshold",
        "-t",
        help="The relevance score threshold for binary decisions (0-5).",
    )
):
    """Executes the agentic workflow and reports metrics.

    Loads environment variables from a ``.env`` file, runs the async workflow
    to completion, and then hands its results to the metrics reporter.

    The workflow returns a pair: the run-start timestamp and the IDs of
    skipped papers. Both are forwarded to ``calculate_and_report_metrics``.
    The timestamp is passed as ``modified_after_ts``, which (judging by the
    name) presumably limits scoring to outputs written during this run.
    TODO confirm.
    """
    load_dotenv()

    workflow = run_async_workflow(relevance_threshold=relevance_threshold)
    started_at, skipped_ids = asyncio.run(workflow)

    calculate_and_report_metrics(
        modified_after_ts=started_at,
        skipped_paper_ids=skipped_ids,
        relevance_threshold=relevance_threshold,
    )


@app.command("recalculate-metrics")
def recalculate_metrics_command(
    relevance_threshold: int = typer.Option(
        3,
        "--relevance-threshold",
        "-t",
        help="The relevance score threshold for binary decisions (0-5).",
    )
):
    """Recalculates and reports metrics from existing evaluation outputs.

    Runs no workflow. It prints a banner naming the threshold, then calls the
    metrics reporter with no timestamp filter and no skip list.

    NOTE(review): the default threshold here (3) differs from the default of
    the ``evaluate`` command (4). Confirm this difference is intentional.
    """
    load_dotenv()

    banner = (
        "\n--- Recalculating metrics from existing outputs "
        f"(Threshold: {relevance_threshold}) --- "
    )
    typer.secho(banner, fg=typer.colors.BLUE, bold=True)

    # Passing ``None`` for both filters presumably means "consider every
    # existing output" -- confirm against calculate_and_report_metrics.
    calculate_and_report_metrics(
        modified_after_ts=None,
        skipped_paper_ids=None,
        relevance_threshold=relevance_threshold,
    )


@app.command("analyze-thresholds")
def analyze_thresholds_command(
    min_threshold: int = typer.Option(
        0,
        "--min-threshold",
        "-min",
        help="Minimum relevance score threshold to analyze.",
    ),
    max_threshold: int = typer.Option(
        5,
        "--max-threshold",
        "-max",
        help="Maximum relevance score threshold to analyze.",
    ),
):
    """Analyzes metric trends across a range of relevance score thresholds.

    For every integer threshold in ``[min_threshold, max_threshold]``
    (inclusive), metrics are recomputed from existing evaluation outputs via
    ``calculate_and_report_metrics``. Each result is stored under the key
    ``"threshold_<n>"``. The combined mapping is written as JSON to
    ``<results_dir>/threshold_trend_analysis.json``, where ``results_dir``
    comes from ``config["output"]["results_dir"]``.

    Exits with status 1, without computing or writing anything, when
    ``min_threshold > max_threshold``. Previously an empty range silently
    produced an empty ``{}`` report.
    """
    # Fail fast on an inverted range instead of writing an empty report.
    if min_threshold > max_threshold:
        typer.secho(
            f"Invalid threshold range: --min-threshold ({min_threshold}) is "
            f"greater than --max-threshold ({max_threshold}).",
            fg=typer.colors.RED,
            err=True,
        )
        raise typer.Exit(code=1)

    load_dotenv()
    config = load_config()
    results_dir = Path(config["output"]["results_dir"])
    all_threshold_metrics = {}

    typer.secho(
        f"\n--- Analyzing metric trends from threshold {min_threshold} to {max_threshold} --- ",
        fg=typer.colors.MAGENTA,
        bold=True,
    )

    for threshold in range(min_threshold, max_threshold + 1):
        typer.secho(
            f"\nProcessing for relevance threshold: {threshold}", fg=typer.colors.CYAN
        )
        # No timestamp filter / skip list: score every existing output.
        metrics_for_threshold = calculate_and_report_metrics(
            modified_after_ts=None,
            skipped_paper_ids=None,
            relevance_threshold=threshold,
        )
        all_threshold_metrics[f"threshold_{threshold}"] = metrics_for_threshold

    # Make sure the destination directory exists before writing the report.
    results_dir.mkdir(parents=True, exist_ok=True)
    output_file_path = results_dir / "threshold_trend_analysis.json"
    with open(output_file_path, "w", encoding="utf-8") as f:
        json.dump(all_threshold_metrics, f, indent=2)
    typer.secho(
        f"\nThreshold trend analysis saved to {output_file_path}",
        fg=typer.colors.GREEN,
        bold=True,
    )


# Script entry point: dispatch to the Typer CLI when run directly.
if __name__ == "__main__":
    app()
