# This file was generated by `agenkit build`.
import asyncio
import typer
import json
from datetime import datetime
from pathlib import Path
from dotenv import load_dotenv
from services.factory import create_agent
import copy
import re
import yaml
import os
import itertools
import numpy as np


def extract_html_tag(key: str, text: str) -> str | None:
    pattern = f"<{key}>(.*?)</{key}>"
    matches = re.findall(pattern, text, re.DOTALL)
    if matches:
        return [match.strip() for match in matches][0]


def extract_evaluation_details(evaluation_text: str) -> dict:
    """Parse a reviewer's tagged evaluation text into a structured dict.

    The result always contains "raw_evaluation" (the input) and
    "assessments" (one entry per assessment tag that was found, each with
    "full_content", "justification" and "score"). Depending on which tags
    are present it may also contain "overall_justification",
    "overall_score", "final_judgement" and "final_decision".
    """
    justification_marker = "Justification:"
    score_marker = "Score:"
    judgement_marker = "Judgement:"

    per_assessment = {}
    for tag in (
        "semantic_fidelity_assessment",
        "jargon_reduction_assessment",
        "first_principles_assessment",
        "information_loss_assessment",
        "ambiguity_assessment",
        "solution_leakage_assessment",
    ):
        body = extract_html_tag(tag, evaluation_text)
        if not body:
            continue

        # Offsets of the first occurrence of each marker (-1 when absent).
        justification_at = body.find(justification_marker)
        score_at = body.find(score_marker)

        # Justification runs from its marker up to the "Score:" marker, or to
        # the end of the section when there is no score.
        justification = None
        if justification_at != -1:
            stop = score_at if score_at != -1 else None
            justification = body[
                justification_at + len(justification_marker) : stop
            ].strip()

        # The score is whatever follows "Score:" on the same line.
        score = None
        if score_at != -1:
            after_marker = body[score_at + len(score_marker) :]
            score = after_marker.partition("\n")[0].strip()

        per_assessment[tag] = {
            "full_content": body,
            "justification": justification,
            "score": score,
        }

    details = {
        "raw_evaluation": evaluation_text,
        "assessments": per_assessment,
    }

    # Simple-format fallbacks: standalone <justification> and <score> tags.
    tagged_justification = extract_html_tag("justification", evaluation_text)
    if tagged_justification:
        details["overall_justification"] = tagged_justification

    tagged_score = extract_html_tag("score", evaluation_text)
    if tagged_score:
        details["overall_score"] = tagged_score

    verdict_block = extract_html_tag("final_judgement", evaluation_text)
    if not verdict_block:
        return details
    details["final_judgement"] = verdict_block

    # The verdict block is only split further when both markers are present.
    if (
        justification_marker not in verdict_block
        or judgement_marker not in verdict_block
    ):
        return details

    reasoning_from = verdict_block.find(justification_marker) + len(
        justification_marker
    )
    verdict_at = verdict_block.find(judgement_marker)

    # A justification found here takes precedence over the standalone tag.
    reasoning = verdict_block[reasoning_from:verdict_at].strip()
    if reasoning:
        details["overall_justification"] = reasoning

    decision = verdict_block[verdict_at + len(judgement_marker) :].strip()
    if decision:
        details["final_decision"] = decision

    return details


def load_config():
    """Load the workflow configuration from ``configs/generalizer.yaml``.

    The path is resolved relative to the current working directory.

    Returns:
        The parsed YAML content, or an empty dict when the file does not
        exist or contains no document.
    """
    config_path = Path("configs/generalizer.yaml")
    if not config_path.exists():
        return {}
    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load yields None for an empty file; callers use ``.get`` on the
        # result, so normalise that case to an empty dict.
        return yaml.safe_load(f) or {}


def load_papers(config):
    """Load the list of papers from the dataset named in *config*.

    Args:
        config: Configuration dict. ``config["dataset"]["path"]`` selects the
            JSON file (default ``data/iclr_test_sample.json``) and
            ``config["dataset"]["max_papers"]``, when truthy, caps how many
            papers are returned.

    Returns:
        The list of papers, or an empty list when the dataset file is missing
        or its top-level structure is not recognised.

    Raises:
        KeyError: If the nested layout is used but
            ``generalization_experiments`` has no ``combined_2025`` entry.
    """
    dataset_config = config.get("dataset", {})
    dataset_path = Path(dataset_config.get("path", "data/iclr_test_sample.json"))
    max_papers = dataset_config.get("max_papers")

    if not dataset_path.exists():
        typer.secho(f"Dataset not found: {dataset_path}", fg=typer.colors.RED)
        return []

    # Read as UTF-8 explicitly, matching how the result logs in this file are
    # read and written, instead of relying on the platform default encoding.
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Two layouts are accepted: a plain array of papers, or a dict that nests
    # them under generalization_experiments -> combined_2025.
    if isinstance(data, list):
        papers = data
    elif isinstance(data, dict) and "generalization_experiments" in data:
        papers = data["generalization_experiments"]["combined_2025"]
    else:
        typer.secho(
            f"Unexpected data format in {dataset_path}", fg=typer.colors.RED
        )
        return []

    if max_papers:
        papers = papers[:max_papers]

    typer.secho(
        f"Loaded {len(papers)} papers from {dataset_path}", fg=typer.colors.GREEN
    )
    return papers


def _model_dir_label(provider: str, model_name: str) -> str:
    """Return the ``provider__model`` label used in result directory names."""
    # Drop one trailing digit from the provider (e.g. "vllm2" -> "vllm");
    # presumably numbered provider instances should share a results directory
    # -- TODO confirm. Slicing with [-1:] keeps an empty provider from raising.
    if provider[-1:].isdigit():
        provider = provider[:-1]
    # "/" in a model name would otherwise create nested directories.
    return provider + "__" + model_name.replace("/", "_")


async def _async_process_paper(
    paper: dict,
    index: int,
    total_papers: int,
    semaphore: asyncio.Semaphore,
    config: dict,
):
    """
    Process a single paper under the semaphore limit and write its JSON log.

    Args:
        paper: Paper record; "id" and "abstract" are required, "title" is
            only used for progress output.
        index: Zero-based position of the paper, used in progress messages.
        total_papers: Total number of papers, used in progress messages.
        semaphore: Limits how many papers are processed at the same time.
        config: Experiment configuration. Model names must have the form
            "<provider>--<model>".

    Returns:
        None. The outcome is written to
        ``<results_dir>/monolithic/external_model=.../internal_model=.../paper_<id>.json``.
        A paper whose log already holds at least one external iteration is
        skipped.

    Raises:
        KeyError: If a required key is missing from *config* or *paper*.
        ValueError: If a model name does not split into exactly two parts
            on "--".
    """
    async with semaphore:
        internal_provider, internal_model = config["model"]["internal"]["name"].split(
            "--"
        )
        external_provider, external_model = config["model"]["external"]["name"].split(
            "--"
        )
        max_internal_tries = config["workflow"]["max_internal_refine_attempts"]
        max_external_tries = config["workflow"]["max_external_refine_attempts"]

        safe_external_model = _model_dir_label(external_provider, external_model)
        safe_internal_model = _model_dir_label(internal_provider, internal_model)
        log_dir = (
            Path(config["output"]["results_dir"])
            / "monolithic"
            / f"external_model={safe_external_model}"
            / f"internal_model={safe_internal_model}"
        )
        log_dir.mkdir(parents=True, exist_ok=True)
        log_file = log_dir / f"paper_{paper['id']}.json"

        if log_file.exists():
            try:
                with open(log_file, encoding="utf-8") as f:
                    previous = json.load(f)
            except json.JSONDecodeError:
                # A run interrupted while writing leaves truncated JSON behind;
                # reprocess the paper instead of aborting every other task.
                typer.secho(
                    f"⚠️ Existing log {log_file} is not valid JSON; reprocessing.",
                    fg=typer.colors.YELLOW,
                )
                previous = {}
            results = previous.get("custom_log", {}).get("external_iterations", [])
            if results:
                typer.secho(
                    f"⏭️ Skipping paper {index+1}/{total_papers}: Already processed.",
                    fg=typer.colors.YELLOW,
                )
                return
        typer.secho(
            f"\n➡️ Processing paper {index+1}/{total_papers}: {paper.get('title', 'Unknown')[:60]}...",
            fg=typer.colors.CYAN,
            bold=True,
        )
        workflow_log = {}
        # Fallback written to disk when the workflow raises before returning.
        custom_log = {"external_iterations": []}

        try:
            custom_log = await _async_main(
                initial_input=paper["abstract"],
                internal_model_name=internal_model,
                internal_model_provider=internal_provider,
                external_model_name=external_model,
                external_model_provider=external_provider,
                max_internal_tries=max_internal_tries,
                max_external_tries=max_external_tries,
                workflow_log=workflow_log,
            )
        except Exception as e:
            # Deliberately broad: one failing paper must not stop the others.
            typer.secho(
                f"\n❌ Workflow for paper {index+1} halted due to an error: {e}",
                fg=typer.colors.RED,
            )
        finally:
            # The log is written on success and on failure alike; a failed run
            # stores an empty iteration list, so the paper is retried next time.
            typer.secho(
                f"✅ Finished paper {index+1}. Writing log to {log_file}",
                fg=typer.colors.YELLOW,
            )
            paper_data = {"metadata": paper, "workflow": {}, "custom_log": custom_log}
            with open(log_file, "w", encoding="utf-8") as f:
                json.dump(paper_data, f, indent=2)


async def run_async_workflow():
    """
    Build one experiment per (internal model, external model) pair and process
    every paper under every experiment concurrently.

    Returns early, after printing a message, when no papers were loaded.
    """

    def _as_list(value):
        # A single model name in the config is treated as a one-element list.
        return value if isinstance(value, list) else [value]

    base_config = load_config()
    papers = load_papers(base_config)
    if not papers:
        typer.secho("No papers to process. Exiting.", fg=typer.colors.RED)
        return

    internal_models = _as_list(
        base_config.get("model", {}).get("internal", {}).get("name", [])
    )
    external_models = _as_list(
        base_config.get("model", {}).get("external", {}).get("name", [])
    )

    # Every experiment gets its own deep copy of the base config with exactly
    # one internal and one external model name filled in.
    experiment_configs = []
    for internal_model, external_model in itertools.product(
        internal_models, external_models
    ):
        exp_config = copy.deepcopy(base_config)
        exp_config["model"]["internal"]["name"] = internal_model
        exp_config["model"]["external"]["name"] = external_model
        experiment_configs.append(exp_config)

    # One semaphore is shared by all experiments.
    semaphore = asyncio.Semaphore(1000)
    total_papers = len(papers)
    all_tasks = []

    typer.secho(
        f"🚀 Queuing up tasks for {len(experiment_configs)} experiments...",
        fg=typer.colors.CYAN,
        bold=True,
    )

    # Coroutines are only created here; nothing runs until the gather below.
    for config in experiment_configs:
        internal_name = config["model"]["internal"]["name"]
        external_name = config["model"]["external"]["name"]
        typer.secho(
            f"  - Queuing tasks for: Internal='{internal_name}', External='{external_name}'"
        )
        for i, paper in enumerate(papers):
            all_tasks.append(
                _async_process_paper(paper, i, total_papers, semaphore, config)
            )

    typer.secho(
        f"\n🔥 Running {len(all_tasks)} total tasks concurrently to maximize GPU usage...",
        fg=typer.colors.MAGENTA,
        bold=True,
    )
    # Shuffle in place so tasks from different experiments are interleaved.
    np.random.shuffle(all_tasks)
    await asyncio.gather(*all_tasks)

    typer.secho(
        "\n\n✅ All tasks from all experiments have finished processing.",
        fg=typer.colors.GREEN,
        bold=True,
    )


def main():
    """
    Executes the agentic workflow defined in your plan.
    """

    # Synchronous entry point: runs the async workflow to completion on a new
    # event loop. The docstring above is left verbatim because this function
    # is passed to `typer.run` below, which may surface it as CLI help text.
    asyncio.run(run_async_workflow())


def _judgement_is_yes(final_judgement: str) -> bool:
    """Return True when the text after "Judgement:" contains "yes".

    The comparison is case-insensitive. A verdict without the "Judgement:"
    marker counts as a rejection instead of raising, so a single malformed
    reviewer reply cannot abort the whole workflow.
    """
    _, marker, verdict = final_judgement.partition("Judgement:")
    return bool(marker) and "yes" in verdict.lower()


async def _async_main(
    initial_input: str,
    internal_model_name: str,
    internal_model_provider: str,
    external_model_name: str,
    external_model_provider: str,
    max_internal_tries: int,
    max_external_tries: int,
    workflow_log: dict,
):
    """Run the generalize -> internal review -> external review loop for one input.

    Each external iteration first runs up to *max_internal_tries* rounds of
    "generalizer" (or "generalizer_retry" once feedback exists) followed by
    the internal reviewer, stopping early when the internal verdict contains
    "yes". The external reviewer is then run once. The outer loop ends when
    the external verdict contains "yes" or after *max_external_tries*
    iterations.

    Args:
        initial_input: Text handed to the generalizer as "initial_input".
        internal_model_name: Model for the generalizer and internal reviewer.
        internal_model_provider: Provider for the internal model.
        external_model_name: Model for the external reviewer.
        external_model_provider: Provider for the external model.
        max_internal_tries: Upper bound on internal attempts per external
            iteration.
        max_external_tries: Upper bound on external iterations.
        workflow_log: Accepted for interface compatibility; not used here
            (agents receive a fresh empty dict instead).

    Returns:
        A dict with "external_iterations" (one record per external iteration,
        each holding its internal attempts and the parsed external
        evaluation) and "summary" (attempt counts, final verdict and the
        last extracted problem statement).
    """
    load_dotenv()
    context = {"initial_input": initial_input}

    # Initialize custom logging structure
    custom_log = {"external_iterations": []}

    typer.secho("🚀 Starting compiled workflow...", fg=typer.colors.CYAN, bold=True)
    ## Start generalization
    typer.secho(
        "\n🔄 Starting repeat block for ['generalizer', ['generalizer_monolithic_reviewer'], 'generalizer_retry', ['generalizer_monolithic_reviewer'], 'generalizer_retry']...",
        fg=typer.colors.MAGENTA,
    )
    num_external_tries = 0
    external_happy = False
    while not external_happy and num_external_tries < max_external_tries:
        # Initialize current external iteration log
        external_iteration = {
            "external_iteration": num_external_tries + 1,
            "internal_attempts": [],
            "external_evaluation": None,
            "external_happy": False,
        }

        num_internal_tries = 0
        internal_happy = False
        while not internal_happy and num_internal_tries < max_internal_tries:
            typer.secho(
                f"\n Generalization => Internal Try: {num_internal_tries+1} | External Try: {num_external_tries+1}",
                fg=typer.colors.BLUE,
            )

            # Initialize current internal attempt log
            internal_attempt = {
                "internal_attempt": num_internal_tries + 1,
                "agent_used": None,
                "generalization_output": None,
                "problem_statement": None,
                "internal_evaluation": None,
                "internal_feedback": None,
                "internal_happy": False,
            }

            # The plain generalizer handles the very first attempt and any
            # attempt for which no reviewer feedback exists yet; otherwise the
            # retry agent is used so it can act on that feedback.
            is_first_attempt = num_internal_tries == 0 and num_external_tries == 0
            if is_first_attempt or not context.get("feedback_for_generalizer"):
                agent_type = "generalizer"
                output_key = "generalizer.generalize.output"
            else:
                agent_type = "generalizer_retry"
                output_key = "generalizer_retry.retry.output"

            typer.secho(f"\n▶  Running Agent: {agent_type}", fg=typer.colors.BLUE)
            internal_attempt["agent_used"] = agent_type
            generalizer_agent = create_agent(
                agent_type=agent_type,
                model_name=internal_model_name,
                model_provider=internal_model_provider,
            )
            context = await generalizer_agent.run(
                workflow_log={}, **context  # Pass empty workflow_log
            )
            internal_attempt["generalization_output"] = context[output_key]
            # May be None when the output has no <problem_statement> tag.
            context["current_problem_statement"] = extract_html_tag(
                "problem_statement",
                context[output_key],
            )
            internal_attempt["problem_statement"] = context[
                "current_problem_statement"
            ]

            # trigger self reflection
            typer.secho(
                "\n Starting INTERNAL EVALUATION block: ['generalizer_monolithic_reviewer']...",
                fg=typer.colors.CYAN,
            )
            typer.secho(
                "\n▶  Running Agent: generalizer_monolithic_reviewer",
                fg=typer.colors.BLUE,
            )
            generalizer_monolithic_reviewer_agent = create_agent(
                agent_type="generalizer_monolithic_reviewer",
                model_name=internal_model_name,
                model_provider=internal_model_provider,
            )
            context = await generalizer_monolithic_reviewer_agent.run(
                workflow_log={}, **copy.deepcopy(context)  # Pass empty workflow_log
            )
            internal_evaluation_raw = context[
                "generalizer_monolithic_reviewer.measure.output"
            ]
            internal_attempt["internal_evaluation"] = extract_evaluation_details(
                internal_evaluation_raw
            )

            final_judgement_str = extract_html_tag(
                "final_judgement",
                internal_evaluation_raw,
            )
            if not final_judgement_str:
                # No verdict at all: record the attempt as unsuccessful and
                # leave any earlier feedback untouched.
                internal_attempt["internal_happy"] = False
                external_iteration["internal_attempts"].append(internal_attempt)
                num_internal_tries += 1
                continue

            context["feedback_for_generalizer"] = copy.deepcopy(
                internal_evaluation_raw
            )
            internal_attempt["internal_feedback"] = context["feedback_for_generalizer"]

            internal_happy = _judgement_is_yes(final_judgement_str)
            internal_attempt["internal_happy"] = internal_happy
            external_iteration["internal_attempts"].append(internal_attempt)
            num_internal_tries += 1

        # The external evaluation runs once the internal loop ends, whether the
        # internal reviewer was satisfied or the attempt budget ran out.

        typer.secho(
            "\n▶  Running Agent: generalizer_monolithic_reviewer_external",
            fg=typer.colors.BLUE,
        )
        generalizer_monolithic_reviewer_external_agent = create_agent(
            agent_type="generalizer_monolithic_reviewer_external",
            model_name=external_model_name,
            model_provider=external_model_provider,
        )
        context = await generalizer_monolithic_reviewer_external_agent.run(
            workflow_log={}, **copy.deepcopy(context)  # Pass empty workflow_log
        )
        typer.secho(" Finished EXTERNAL parallel block.", fg=typer.colors.CYAN)

        # Capture external evaluation with structured details
        external_evaluation_raw = context[
            "generalizer_monolithic_reviewer_external.measure.output"
        ]
        external_iteration["external_evaluation"] = extract_evaluation_details(
            external_evaluation_raw
        )

        final_judgement_str = extract_html_tag(
            "final_judgement",
            external_evaluation_raw,
        )
        if not final_judgement_str:
            external_iteration["external_happy"] = False
            custom_log["external_iterations"].append(external_iteration)
            num_external_tries += 1
            continue

        # Append internal and external feedback. The internal part may be
        # missing when no internal attempt produced a verdict, so fall back to
        # an empty string instead of raising KeyError.
        internal_feedback = context.get("feedback_for_generalizer", "")
        context["feedback_for_generalizer"] = (
            f"Feedback from internal evaluation:\n\n{internal_feedback}\n\n"
            + f"Feedback from external evaluation:\n\n{external_evaluation_raw}"
        )

        external_happy = _judgement_is_yes(final_judgement_str)
        external_iteration["external_happy"] = external_happy
        custom_log["external_iterations"].append(external_iteration)
        num_external_tries += 1

    ## End generalization

    # Add final summary counts
    total_internal_attempts = sum(
        len(ext_iter["internal_attempts"])
        for ext_iter in custom_log["external_iterations"]
    )
    total_external_attempts = len(custom_log["external_iterations"])

    custom_log["summary"] = {
        "total_external_attempts": total_external_attempts,
        "total_internal_attempts": total_internal_attempts,
        "workflow_completed": external_happy,
        "final_external_happy": external_happy,
        "final_problem_statement": context.get("current_problem_statement", ""),
    }

    typer.secho(
        f"\n📊 Workflow Summary: {total_external_attempts} external attempts, {total_internal_attempts} total internal attempts",
        fg=typer.colors.GREEN,
    )
    typer.secho("\n🎉 Workflow finished successfully!", fg=typer.colors.CYAN, bold=True)

    return custom_log


# Script entry point: when the file is executed directly, Typer wraps `main`
# as a command-line application and invokes it.
if __name__ == "__main__":
    typer.run(main)
