# This file was generated by `agenkit build`.
import asyncio
import typer
import json
from datetime import datetime
from pathlib import Path
from dotenv import load_dotenv
from services.factory import create_agent
import copy
import re
import yaml
import itertools
import numpy as np


def extract_html_tag(key: str, text: str) -> str | None:
    pattern = f"<{key}>(.*?)</{key}>"
    matches = re.findall(pattern, text, re.DOTALL)
    if matches:
        return [match.strip() for match in matches][0]


def extract_evaluation_details(evaluation_text: str) -> dict:
    """Parse a reviewer's raw evaluation text into a structured dictionary.

    The result always contains ``raw_evaluation`` (the input, untouched) and
    ``assessments`` (one entry per assessment section that was found). The
    keys ``overall_justification``, ``overall_score``, ``final_judgement``
    and ``final_decision`` are added only when matching content is present.
    """
    justification_label = "Justification:"
    score_label = "Score:"
    judgement_label = "Judgement:"

    parsed = {
        "raw_evaluation": evaluation_text,
        "assessments": {},
    }

    # Per-criterion sections; each may carry a justification and a score.
    for section in (
        "novelty_assessment",
        "technical_assessment",
        "completeness_assessment",
    ):
        body = extract_html_tag(section, evaluation_text)
        if not body:
            continue

        just_at = body.find(justification_label)
        score_at = body.find(score_label)

        justification = None
        if just_at != -1:
            # The justification runs up to the score label, or to the end of
            # the section when no score label is present.
            stop = score_at if score_at != -1 else len(body)
            justification = body[just_at + len(justification_label) : stop].strip()

        score = None
        if score_at != -1:
            # Only the rest of the line holding the label counts as the score.
            after_label = body[score_at + len(score_label) :]
            score = after_label.partition("\n")[0].strip()

        parsed["assessments"][section] = {
            "full_content": body,
            "justification": justification,
            "score": score,
        }

    # Simple-format fallbacks: standalone <justification> and <score> tags.
    for tag, field in (
        ("justification", "overall_justification"),
        ("score", "overall_score"),
    ):
        value = extract_html_tag(tag, evaluation_text)
        if value:
            parsed[field] = value

    verdict = extract_html_tag("final_judgement", evaluation_text)
    if not verdict:
        return parsed
    parsed["final_judgement"] = verdict

    # The verdict is split further only when both labels occur in it.
    just_at = verdict.find(justification_label)
    judgement_at = verdict.find(judgement_label)
    if just_at == -1 or judgement_at == -1:
        return parsed

    # A justification inside the verdict overrides the simple-format one.
    reasoning = verdict[just_at + len(justification_label) : judgement_at].strip()
    if reasoning:
        parsed["overall_justification"] = reasoning

    decision = verdict[judgement_at + len(judgement_label) :].strip()
    if decision:
        parsed["final_decision"] = decision

    return parsed


def load_config():
    """Load the solver configuration from ``configs/solver.yaml``.

    The path is resolved relative to the current working directory.

    Returns:
        The parsed YAML document, or an empty dict when the file is missing
        or holds no data.
    """
    config_path = Path("configs/solver.yaml")
    if not config_path.exists():
        return {}
    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load yields None for an empty document; normalise that to the
        # same empty dict returned for a missing file.
        return yaml.safe_load(f) or {}


def _find_problem_statement(full_data: dict) -> str | None:
    """Return the problem statement recorded in one paper log, if any.

    Sources are tried in order: the last internal attempt of the last
    external iteration, the run summary, then the legacy ``workflow`` layout.
    Every lookup tolerates missing sections so that a log in one layout is
    never rejected for lacking the keys of another.
    """
    custom_log = full_data.get("custom_log") or {}

    # New format: last internal attempt of the last external iteration.
    external_iterations = custom_log.get("external_iterations") or []
    if external_iterations:
        internal_attempts = external_iterations[-1].get("internal_attempts") or []
        if internal_attempts:
            statement = internal_attempts[-1].get("problem_statement")
            if statement:
                return statement

    # New format: statement stored in the run summary.
    statement = (custom_log.get("summary") or {}).get("final_problem_statement")
    if statement:
        return statement

    # Old format: statement embedded as a tag in the generalizer output.
    workflow = full_data.get("workflow") or {}
    generalizer_step = workflow.get("generalizer.generalize") or {}
    generalizer_output = generalizer_step.get("output")
    if generalizer_output:
        return extract_html_tag("problem_statement", generalizer_output)
    return None


def load_papers(config):
    """Load paper records from the dataset named in *config*.

    ``config["dataset"]["path"]`` may point at a single ``.json`` file or at a
    directory that is searched recursively for ``*.json`` files. Files are
    visited in order of file name. For each file the ``metadata`` mapping is
    taken as the paper record and extended with a
    ``current_problem_statement`` entry.

    Files that cannot be read or parsed, and files without any problem
    statement, are reported with a warning and skipped.

    Args:
        config: Configuration mapping. ``dataset.path`` is required;
            ``dataset.max_papers`` optionally caps the number of papers
            returned.

    Returns:
        A list of paper metadata dicts; empty when the dataset path does not
        exist.

    Raises:
        KeyError: If ``config`` has no ``dataset.path`` entry.
    """
    dataset_path = Path(config["dataset"]["path"])
    max_papers = config.get("dataset", {}).get("max_papers")

    if not dataset_path.exists():
        typer.secho(f"Dataset not found: {dataset_path}", fg=typer.colors.RED)
        return []

    if dataset_path.is_file() and dataset_path.suffix == ".json":
        # Single file
        json_files = [dataset_path]
    else:
        # Directory - search recursively for JSON files
        json_files = list(dataset_path.rglob("*.json"))

    papers = []
    for paper_path in sorted(json_files, key=lambda x: x.name):
        # Deliberately broad: one malformed file must not stop the others
        # from loading.
        try:
            with open(paper_path, "r", encoding="utf-8") as f:
                full_data = json.load(f)
            paper_data = full_data["metadata"]

            problem_statement = _find_problem_statement(full_data)
            if not problem_statement:
                typer.secho(
                    f"⚠️ Skipping paper {paper_data.get('id', 'unknown')} - no problem statement found",
                    fg=typer.colors.YELLOW,
                )
                continue

            paper_data["current_problem_statement"] = problem_statement
            papers.append(paper_data)
        except Exception as e:
            typer.secho(f"⚠️ Error loading {paper_path}: {e}", fg=typer.colors.YELLOW)
            continue

    if max_papers:
        papers = papers[:max_papers]

    typer.secho(
        f"Loaded {len(papers)} papers from {dataset_path}", fg=typer.colors.GREEN
    )
    return papers


def _model_dir_label(provider: str, model_name: str) -> str:
    """Return the ``provider__model`` label used in result directory names."""
    # A single trailing digit on the provider is dropped.
    # NOTE(review): presumably the digit numbers duplicate endpoints of one
    # provider - confirm against the agent factory.
    if provider[-1:].isdigit():
        provider = provider[:-1]
    return provider + "__" + model_name.replace("/", "_")


def _has_recorded_iterations(log_file: Path) -> bool:
    """Tell whether *log_file* is a readable log with at least one external iteration.

    An unreadable or malformed log is reported and treated as "not processed"
    so that the paper is simply run again instead of failing the task.
    """
    try:
        with open(log_file, encoding="utf-8") as f:
            previous = json.load(f)
    except (OSError, ValueError) as e:
        typer.secho(
            f"⚠️ Ignoring unreadable log {log_file}: {e}",
            fg=typer.colors.YELLOW,
        )
        return False
    if not isinstance(previous, dict):
        return False
    return bool((previous.get("custom_log") or {}).get("external_iterations"))


async def _async_process_paper(
    paper: dict,
    index: int,
    total_papers: int,
    semaphore: asyncio.Semaphore,
    config,
):
    """Run the solver workflow for one paper and write its JSON log.

    The work happens while holding *semaphore*. The log is written to
    ``<results_dir>/monolithic/external_model=.../internal_model=.../paper_<id>.json``.
    A paper whose log already contains at least one external iteration is
    skipped. Errors raised by the workflow itself are reported and do not
    propagate; the log file is written in either case.

    Args:
        paper: Paper record; ``id``, ``abstract`` and
            ``current_problem_statement`` are read, ``title`` is optional.
        index: Zero-based position of the paper, used for progress messages.
        total_papers: Total number of papers, used for progress messages.
        semaphore: Limits how many papers are processed at the same time.
        config: Experiment configuration; model names are expected in
            ``provider--model`` form.
    """
    async with semaphore:
        internal_provider, internal_model = config["model"]["internal"]["name"].split(
            "--"
        )
        external_provider, external_model = config["model"]["external"]["name"].split(
            "--"
        )
        max_internal_tries = config["workflow"]["max_internal_refine_attempts"]
        max_external_tries = config["workflow"]["max_external_refine_attempts"]

        safe_external_model = _model_dir_label(external_provider, external_model)
        safe_internal_model = _model_dir_label(internal_provider, internal_model)
        log_dir = (
            Path(config["output"]["results_dir"])
            / "monolithic"
            / f"external_model={safe_external_model}"
            / f"internal_model={safe_internal_model}"
        )
        log_dir.mkdir(parents=True, exist_ok=True)
        log_file = log_dir / f"paper_{paper['id']}.json"

        if log_file.exists() and _has_recorded_iterations(log_file):
            typer.secho(
                f"⏭️ Skipping paper {index+1}/{total_papers}: Already processed.",
                fg=typer.colors.YELLOW,
            )
            return

        typer.secho(
            f"\n➡️ Processing paper {index+1}/{total_papers}: {paper.get('title', 'Unknown')[:60]}...",
            fg=typer.colors.CYAN,
            bold=True,
        )
        workflow_log = {}
        # Written as-is when the workflow fails; a log without iterations is
        # not counted as processed, so the paper is retried on the next run.
        custom_log = {"external_iterations": []}

        try:
            custom_log = await _async_main(
                initial_input=paper["abstract"],
                current_problem_statement=paper["current_problem_statement"],
                internal_model_name=internal_model,
                internal_model_provider=internal_provider,
                external_model_name=external_model,
                external_model_provider=external_provider,
                max_internal_tries=max_internal_tries,
                max_external_tries=max_external_tries,
                workflow_log=workflow_log,
            )
        except Exception as e:
            # Per-paper boundary: report and carry on so that other papers
            # are unaffected.
            typer.secho(
                f"\n❌ Workflow for paper {index+1} halted due to an error: {e}",
                fg=typer.colors.RED,
            )
        finally:
            typer.secho(
                f"✅ Finished paper {index+1}. Writing log to {log_file}",
                fg=typer.colors.YELLOW,
            )
            paper_data = {"metadata": paper, "workflow": {}, "custom_log": custom_log}
            with open(log_file, "w", encoding="utf-8") as f:
                json.dump(paper_data, f, indent=2)


async def run_async_workflow():
    """Process every loaded paper under every internal/external model pairing.

    The base configuration's ``model.internal.name`` and
    ``model.external.name`` may each be a single name or a list; one
    experiment configuration is built per combination. A task is queued per
    (experiment, paper) pair, and all tasks run concurrently behind one shared
    semaphore.

    A task that raises does not stop the others: failures are collected and
    reported once all tasks have finished, and are not re-raised.
    """
    # 1. Load the base configuration and the papers
    base_config = load_config()
    papers = load_papers(base_config)

    if not papers:
        typer.secho("No papers to process. Exiting.", fg=typer.colors.RED)
        return

    model_section = base_config.get("model", {})
    internal_models = model_section.get("internal", {}).get("name", [])
    external_models = model_section.get("external", {}).get("name", [])

    # A single model name is accepted as shorthand for a one-element list.
    if not isinstance(internal_models, list):
        internal_models = [internal_models]
    if not isinstance(external_models, list):
        external_models = [external_models]

    experiment_configs = []
    for internal_model, external_model in itertools.product(
        internal_models, external_models
    ):
        # Deep copy so experiments never share nested config objects.
        exp_config = copy.deepcopy(base_config)
        exp_config["model"]["internal"]["name"] = internal_model
        exp_config["model"]["external"]["name"] = external_model
        experiment_configs.append(exp_config)

    # 2. Create a master list to hold all tasks from all experiments
    all_tasks = []
    semaphore = asyncio.Semaphore(300)
    total_papers = len(papers)

    typer.secho(
        f"🚀 Queuing up tasks for {len(experiment_configs)} experiments...",
        fg=typer.colors.CYAN,
        bold=True,
    )

    # 3. Loop through configs to CREATE and QUEUE tasks, but do not run them yet
    for config in experiment_configs:
        internal_name = config["model"]["internal"]["name"]
        external_name = config["model"]["external"]["name"]
        typer.secho(
            f"  - Queuing tasks for: Internal='{internal_name}', External='{external_name}'"
        )
        all_tasks.extend(
            _async_process_paper(paper, i, total_papers, semaphore, config)
            for i, paper in enumerate(papers)
        )

    typer.secho(
        f"\n🔥 Running {len(all_tasks)} total tasks concurrently to maximize GPU usage...",
        fg=typer.colors.MAGENTA,
        bold=True,
    )
    # Shuffling mixes tasks from different experiments in the start order.
    np.random.shuffle(all_tasks)

    # 4. Run ALL queued tasks from ALL experiments in parallel. Exceptions are
    # returned as results so one failing task cannot abort the whole run.
    results = await asyncio.gather(*all_tasks, return_exceptions=True)

    failures = [result for result in results if isinstance(result, BaseException)]
    if failures:
        typer.secho(
            f"\n❌ {len(failures)} of {len(results)} tasks failed:",
            fg=typer.colors.RED,
            bold=True,
        )
        for failure in failures:
            typer.secho(
                f"  - {type(failure).__name__}: {failure}",
                fg=typer.colors.RED,
            )

    typer.secho(
        "\n\n✅ All tasks from all experiments have finished processing.",
        fg=typer.colors.GREEN,
        bold=True,
    )


def main():
    """
    Executes the agentic workflow defined in your plan.
    """
    # Build the top-level coroutine first, then let asyncio create and own
    # the event loop for the duration of the run.
    workflow = run_async_workflow()
    asyncio.run(workflow)


def _read_verdict(final_judgement: str | None) -> bool | None:
    """Interpret the content of a reviewer's ``final_judgement`` tag.

    Returns ``True`` when the text from the ``Judgement:`` marker onwards
    contains "yes" (case-insensitive), ``False`` when it does not, and
    ``None`` when the tag content or the marker is missing.
    """
    if not final_judgement:
        return None
    marker_at = final_judgement.find("Judgement:")
    if marker_at == -1:
        return None
    return "yes" in final_judgement[marker_at:].strip().lower()


async def _async_main(
    initial_input: str,
    current_problem_statement: str,
    internal_model_name: str,
    internal_model_provider: str,
    external_model_name: str,
    external_model_provider: str,
    max_internal_tries: int,
    max_external_tries: int,
    workflow_log: dict,
):
    """Run the solve / internal-review / external-review loop for one problem.

    Each external iteration runs up to ``max_internal_tries`` internal
    attempts. An attempt produces a solution and has it judged by the internal
    reviewer, stopping early once that reviewer approves. The external
    reviewer then judges the current solution. Up to ``max_external_tries``
    external iterations run, stopping early on external approval.

    The very first attempt, and any attempt made while no reviewer feedback is
    available, uses the ``solver`` agent. Every later attempt uses
    ``solver_retry``.

    A review whose verdict cannot be parsed counts as a rejection, and a
    generic error message is used as the feedback for the next attempt.

    Args:
        initial_input: Passed to the agents as ``initial_input``.
        current_problem_statement: Passed to the agents as
            ``current_problem_statement``.
        internal_model_name: Model for the solver agents and internal reviewer.
        internal_model_provider: Provider of the internal model.
        external_model_name: Model for the external reviewer.
        external_model_provider: Provider of the external model.
        max_internal_tries: Internal attempt limit per external iteration.
        max_external_tries: External iteration limit.
        workflow_log: Accepted for interface compatibility; not read here.
            Every agent call receives its own fresh empty dict.

    Returns:
        A dict with ``external_iterations`` (one record per external
        iteration, each holding its internal attempts) and ``summary``
        (attempt counts, the final external verdict and the final solution).
    """
    load_dotenv()
    context = {
        "initial_input": initial_input,
        "current_problem_statement": current_problem_statement,
    }
    parse_error_feedback = (
        "There was an error in generating the feedback. Try solving again."
    )

    # Initialize custom logging structure
    custom_log = {"external_iterations": []}

    typer.secho("🚀 Starting compiled workflow...", fg=typer.colors.CYAN, bold=True)

    ## Start solving
    typer.secho(
        "\n🔄 Starting repeat block for ['solver', ['solution_creativity', 'solution_technical_feasibility', 'solution_completeness', 'solver_monolithic_reviewer'], 'solver_retry', ['solution_creativity', 'solution_technical_feasibility', 'solution_completeness', 'solver_monolithic_reviewer'], 'solver_monolithic_reviewer_external']...",
        fg=typer.colors.MAGENTA,
    )
    num_external_tries = 0
    external_happy = False
    while not external_happy and num_external_tries < max_external_tries:
        # Initialize current external iteration log
        external_iteration = {
            "external_iteration": num_external_tries + 1,
            "internal_attempts": [],
            "external_evaluation": None,
            "external_happy": False,
        }

        num_internal_tries = 0
        internal_happy = False
        while not internal_happy and num_internal_tries < max_internal_tries:
            typer.secho(
                f"\n Solver => Internal Try: {num_internal_tries+1} | External Try: {num_external_tries+1}",
                fg=typer.colors.BLUE,
            )

            # Initialize current internal attempt log
            internal_attempt = {
                "internal_attempt": num_internal_tries + 1,
                "agent_used": None,
                "solution_output": None,
                "solution": None,
                "internal_evaluation": None,
                "internal_feedback": None,
                "internal_happy": False,
            }

            # Retrying needs reviewer feedback to act on. The feedback is
            # stored under "feedback_for_solver", so that is the key that
            # decides between a fresh solve and a retry.
            is_first_attempt = num_internal_tries == 0 and num_external_tries == 0
            if is_first_attempt or not context.get("feedback_for_solver"):
                agent_type, output_key = "solver", "solver.solve.output"
            else:
                agent_type, output_key = "solver_retry", "solver_retry.retry.output"

            typer.secho(f"\n▶  Running Agent: {agent_type}", fg=typer.colors.BLUE)
            internal_attempt["agent_used"] = agent_type
            solving_agent = create_agent(
                agent_type=agent_type,
                model_name=internal_model_name,
                model_provider=internal_model_provider,
            )
            # The agent's return value replaces the context.
            context = await solving_agent.run(workflow_log={}, **context)
            internal_attempt["solution_output"] = context[output_key]
            solution = extract_html_tag("solution", context[output_key])
            context["current_solution"] = solution
            internal_attempt["solution"] = solution

            typer.secho(
                "\n Starting INTERNAL EVALUATION block for ['solution_creativity', 'solution_technical_feasibility', 'solution_completeness', 'solver_monolithic_reviewer']...",
                fg=typer.colors.CYAN,
            )
            solver_monolithic_reviewer_agent = create_agent(
                agent_type="solver_monolithic_reviewer",
                model_name=internal_model_name,
                model_provider=internal_model_provider,
            )
            # The reviewer receives a deep copy of the context.
            context = await solver_monolithic_reviewer_agent.run(
                workflow_log={}, **copy.deepcopy(context)
            )

            internal_evaluation_raw = context[
                "solver_monolithic_reviewer.measure.output"
            ]
            internal_attempt["internal_evaluation"] = extract_evaluation_details(
                internal_evaluation_raw
            )

            verdict = _read_verdict(
                extract_html_tag("final_judgement", internal_evaluation_raw)
            )
            if verdict is None:
                # Unparseable review: count it as a rejection and ask for
                # another attempt.
                context["feedback_for_solver"] = parse_error_feedback
                internal_attempt["internal_happy"] = False
            else:
                internal_happy = verdict
                internal_attempt["internal_happy"] = internal_happy
                internal_attempt["internal_feedback"] = internal_evaluation_raw
                context["feedback_for_solver"] = internal_evaluation_raw

            external_iteration["internal_attempts"].append(internal_attempt)
            num_internal_tries += 1

        # start external evaluation
        typer.secho(
            "\n▶  Running Agent: solver_monolithic_reviewer_external",
            fg=typer.colors.BLUE,
        )
        solver_monolithic_reviewer_external_agent = create_agent(
            agent_type="solver_monolithic_reviewer_external",
            model_name=external_model_name,
            model_provider=external_model_provider,
        )
        context = await solver_monolithic_reviewer_external_agent.run(
            workflow_log={}, **copy.deepcopy(context)
        )

        # Capture external evaluation with structured details
        external_evaluation_raw = context[
            "solver_monolithic_reviewer_external.measure.output"
        ]
        external_iteration["external_evaluation"] = extract_evaluation_details(
            external_evaluation_raw
        )

        verdict = _read_verdict(
            extract_html_tag("final_judgement", external_evaluation_raw)
        )
        if verdict is None:
            context["feedback_for_solver"] = parse_error_feedback
        else:
            # Internal feedback is absent when no internal attempt ran
            # (max_internal_tries < 1); fall back to an empty string.
            internal_feedback = context.get("feedback_for_solver", "")
            context["feedback_for_solver"] = (
                f"Feedback from internal evaluation:\n\n{internal_feedback}\n\n"
                + f"Feedback from external evaluation:\n\n{external_evaluation_raw}"
            )
            external_happy = verdict

        external_iteration["external_happy"] = external_happy
        custom_log["external_iterations"].append(external_iteration)
        num_external_tries += 1

    ## End solving

    # Add final summary counts
    total_internal_attempts = sum(
        len(ext_iter["internal_attempts"])
        for ext_iter in custom_log["external_iterations"]
    )
    total_external_attempts = len(custom_log["external_iterations"])

    custom_log["summary"] = {
        "total_external_attempts": total_external_attempts,
        "total_internal_attempts": total_internal_attempts,
        "workflow_completed": external_happy,
        "final_external_happy": external_happy,
        "final_solution": context.get("current_solution", ""),
    }

    typer.secho(
        f"\n📊 Workflow Summary: {total_external_attempts} external attempts, {total_internal_attempts} total internal attempts",
        fg=typer.colors.GREEN,
    )
    typer.secho("\n🎉 Workflow finished successfully!", fg=typer.colors.CYAN, bold=True)

    return custom_log


# Script entry point: only when the file is executed directly, hand `main`
# over to typer, which runs it as the command-line program.
if __name__ == "__main__":
    typer.run(main)
