import json
import re
import textwrap

from bigcodebench.eval import untrusted_check
from bigcodebench.sanitize import sanitize

from core.arch.base import BaseComponent
from core.arch.system import CompoundAISystem
from core.utils.api import get_llm_output
from core.utils.textresnet import (
    BACKWARD_SYSTEM_PROMPT,
    DEFAULT_BACKWARD_OUTPUT_CONSTRAINTS,
    STOP_GRADIENT,
    BackwardResponse,
    GradientSignal,
    apply_unified_diff,
    build_backward_context_prompt,
    parse_gradient_response,
    short_text,
)


def _make_code_prompt(instruction: str, question: str) -> str:
    if not isinstance(instruction, str):
        instruction = "" if instruction is None else str(instruction)
    if not isinstance(question, str):
        question = "" if question is None else str(question)
    instruction = instruction.strip()
    question = question.strip()
    return (
        f"{instruction}\n\n"
        f"Problem:\n{question}\n\n"
        "Return ONLY valid Python code (no markdown, no explanations).\n"
    )


class CodeGenerator(BaseComponent):
    """BigCodeBench step 1/4: produce the base Python solution for a problem.

    ``forward`` prompts the forward model with the component's ``variable``
    (the instruction text) plus the problem and sanitizes the reply.
    ``backward`` asks the backward model for feedback on that instruction.
    """

    BACKWARD_OUTPUT_CONSTRAINTS = DEFAULT_BACKWARD_OUTPUT_CONSTRAINTS

    def __init__(
        self,
        model="anthropic/claude-3-haiku",
        backward_model="openai/gpt-4o-mini",
        max_tokens=2048,
        temperature=0.0,
    ):
        """Register the component's description, fields, prompt and LLM config.

        Args:
            model: Model identifier used by ``forward``.
            backward_model: Model identifier used by ``backward``.
            max_tokens: Generation budget passed as ``max_new_tokens``.
            temperature: Sampling temperature for the forward call.
        """
        component_description = (
            "BigCodeBench step 1/4 (CodeGenerator): generate a ROBUST, SELF-CONTAINED Python solution. "
            "This code must handle edge cases and include all necessary standard imports. "
            "It serves as the foundation for execution; syntax errors here are fatal."
        )
        llm_settings = {
            "model": model,  # forward model
            "backward_model": backward_model,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        super().__init__(
            description=component_description,
            input_fields=["question"],
            output_fields=["initial_code", "local_fix", "upstream_grad"],
            variable="Provide a self-contained Python solution:",
            config=llm_settings,
        )

    def forward(self, **inputs):
        """Generate and sanitize the initial solution.

        Args:
            **inputs: Reads ``question`` (defaults to ``""``) and the optional
                ``entry_point`` (falls back to ``"task_func"`` when missing
                or falsy).

        Returns:
            Dict with ``initial_code`` (sanitized model output), a fixed
            ``local_fix`` note, and ``upstream_grad`` set to ``STOP_GRADIENT``.
        """
        generation_prompt = _make_code_prompt(self.variable, inputs.get("question", ""))
        settings = self.config

        model_reply = get_llm_output(
            message=generation_prompt,
            model=settings.model,
            max_new_tokens=settings.max_tokens,
            temperature=settings.temperature,
        )

        target_function = inputs.get("entry_point") or "task_func"

        return {
            "initial_code": sanitize(model_reply, target_function),
            "local_fix": "base code generated",
            "upstream_grad": STOP_GRADIENT,
        }

    def backward(self, signal: GradientSignal, full_traj, component_traj) -> BackwardResponse:
        """Request backward-model feedback on the generation instruction.

        Args:
            signal: Incoming gradient signal; only ``signal.feedback`` is read.
            full_traj: Unused here.
            component_traj: This component's recorded trajectory; the question
                is read from ``["input"]`` and the code from ``["output"]``.

        Returns:
            The parsed backward response with a ``debug`` dict attached that
            records every prompt ingredient and the raw model output.
        """
        recorded_input = component_traj.get("input", {})
        recorded_output = component_traj.get("output", {})

        # Rebuild the forward prompt so the backward model sees the same input.
        lm_input = _make_code_prompt(self.variable, recorded_input.get("question", ""))
        lm_output = recorded_output.get("initial_code", "")
        objective = signal.feedback
        response_desc = "the base code used for execution and patching"

        backward_prompt = "\n\n".join(
            [
                build_backward_context_prompt(
                    variable_desc=self.description,
                    variable_value=self.variable,
                    lm_input=lm_input,
                    lm_output=lm_output,
                    objective_feedback=objective,
                    response_desc=response_desc,
                ),
                f"{self.BACKWARD_OUTPUT_CONSTRAINTS}",
            ]
        )

        # Use the dedicated backward model when configured, else the forward one.
        critic_model = getattr(self.config, "backward_model", None) or self.config.model
        raw = get_llm_output(
            message=backward_prompt,
            model=critic_model,
            temperature=0.4,
            system_prompt=BACKWARD_SYSTEM_PROMPT,
        )

        short_limit = 400
        resp = parse_gradient_response(raw)
        resp.debug = {
            "variable_desc": self.description,
            "variable_value": self.variable,
            "variable_short_max": short_limit,
            "variable_short": short_text(self.variable, short_limit),
            "lm_system_prompt": "",
            "lm_input": lm_input,
            "lm_output": lm_output,
            "response_desc": response_desc,
            "objective_feedback": objective,
            "raw_backward_output": raw,
        }
        return resp


class UnitTestGenerator(BaseComponent):
    """BigCodeBench step 2/4: generate additional unit tests (incremental constraints).

    ``forward`` prompts the forward model with the test-writing instructions
    (the component's ``variable``) plus the problem, then extracts plain
    Python from the reply. ``backward`` asks the backward model for feedback
    on those instructions.
    """

    BACKWARD_OUTPUT_CONSTRAINTS = DEFAULT_BACKWARD_OUTPUT_CONSTRAINTS

    def __init__(
        self,
        model="anthropic/claude-3-haiku",
        backward_model="openai/gpt-4o-mini",
        max_tokens=1024,
        temperature=0.0,
    ):
        """Register the component's description, fields, prompt and LLM config.

        Args:
            model: Model identifier used by ``forward``.
            backward_model: Model identifier used by ``backward``.
            max_tokens: Generation budget passed as ``max_new_tokens``.
            temperature: Sampling temperature for the forward call.
        """
        UNIT_TEST_PROMPT = textwrap.dedent(
            """\
            You are an expert Python QA engineer.
            Your goal is to write unit tests to verify the correctness of a given problem solution.

            **CRITICAL FORMATTING REQUIREMENT**:
            The test runner REQUIRES the tests to be wrapped in a class named `TestCases` that inherits from `unittest.TestCase`.

            **Instructions**:
            1. Import `unittest` and any other necessary libraries.
            2. Define a class: `class TestCases(unittest.TestCase):`
            3. Inside the class, write 3-5 test methods (must start with `test_`).
            4. Use `self.assertEqual`, `self.assertTrue`, `self.assertRaises`, etc.

            **CRITICAL RULES FOR FILE I/O**:
            - If the problem involves reading/writing files, you MUST create temporary files within the test method and clean them up.
            - NEVER assume a file exists locally. Use `tempfile` or `unittest.mock`.

            **CRITICAL RULES FOR FUNCTION**:
            - **DO NOT REDEFINE THE FUNCTION**: The function `task_func` is already available in the environment.
            - **NO PLACEHOLDERS**: Do not write `def task_func... pass`.

            **Output Format Example**:
            ```python
            import unittest
            import os

            class TestCases(unittest.TestCase):
                def test_basic_case(self):
                    self.assertEqual(task_func(1, 2), 3)

                def test_error_case(self):
                    with self.assertRaises(ValueError):
                        task_func(-1, -1)
            ```
            """
        )

        super().__init__(
            description=(
                "BigCodeBench step 2/4 (UnitTestGenerator): generate VALID assertion tests to verify the code. "
                "CRITICAL: Do not invent requirements not present in the prompt. "
                "If the prompt is ambiguous, do not generate a test for that case. "
                "Tests must be deterministic and use standard `assert` statements."
            ),
            variable=UNIT_TEST_PROMPT,
            input_fields=["question"],
            output_fields=["additional_unit_tests", "local_fix", "upstream_grad"],
            config={
                "model": model,  # forward model
                "backward_model": backward_model,
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
        )

    @staticmethod
    def _extract_test_code(raw):
        """Return plain, left-aligned Python source from a raw model reply.

        The first ```` ```python ```` fence wins; if there is none, the first
        fence of any kind (its info line is skipped) is used; if there is no
        fence at all, the whole reply is used. A ``None``/empty reply yields
        ``""``.

        The fence body is captured *without* trimming so that
        ``textwrap.dedent`` still sees the first line's indentation. Trimming
        before dedenting would strip only the first line and leave the rest
        of a uniformly indented block mis-indented.
        """
        text = raw or ""
        fenced = re.search(r"```python(.*?)```", text, re.DOTALL | re.IGNORECASE)
        if fenced is None:
            # Fallback: a fence with no (or another) language tag.
            fenced = re.search(r"```[^\n`]*\n(.*?)```", text, re.DOTALL)
        body = fenced.group(1) if fenced else text
        # Remove the shared indentation first, then trim the edges.
        return textwrap.dedent(body).strip()

    def forward(self, question):
        """Generate unit tests for ``question``.

        Args:
            question: Problem statement interpolated into the prompt.

        Returns:
            Dict with ``additional_unit_tests`` (extracted test source), a
            fixed ``local_fix`` note, and ``upstream_grad`` set to
            ``STOP_GRADIENT``.
        """
        prompt = f"{self.variable}\nProblem:\n{question}\n\nUnit tests (python):\n"
        raw = get_llm_output(
            message=prompt,
            model=self.config.model,
            max_new_tokens=self.config.max_tokens,
            temperature=self.config.temperature,
        )

        return {
            "additional_unit_tests": self._extract_test_code(raw),
            "local_fix": "tests generated",
            "upstream_grad": STOP_GRADIENT,
        }

    def backward(self, signal: GradientSignal, full_traj, component_traj) -> BackwardResponse:
        """Request backward-model feedback on the test-writing instructions.

        Args:
            signal: Incoming gradient signal; ``signal.feedback`` and
                ``signal.context["execution_result"]`` are read.
            full_traj: Unused here.
            component_traj: This component's recorded trajectory; the question
                is read from ``["input"]`` and the tests from ``["output"]``.

        Returns:
            The parsed backward response with a ``debug`` dict attached that
            records every prompt ingredient and the raw model output.
        """
        question = component_traj.get("input", {}).get("question", "")
        unit_tests = component_traj.get("output", {}).get("additional_unit_tests", "")
        exec_result = signal.context.get("execution_result", {}) or {}

        # Rebuild the forward prompt so the backward model sees the same input.
        lm_input = f"{self.variable}\nProblem:\n{question}\n\nUnit tests (python):\n"
        lm_output = unit_tests or ""
        objective = f"{signal.feedback}\nExecution result: {exec_result}"
        response_desc = "unit tests for execution"

        context_prompt = build_backward_context_prompt(
            variable_desc=self.description,
            variable_value=self.variable,
            lm_input=lm_input,
            lm_output=lm_output,
            objective_feedback=objective,
            response_desc=response_desc,
        )
        prompt = f"{context_prompt}\n\n{self.BACKWARD_OUTPUT_CONSTRAINTS}"

        raw = get_llm_output(
            message=prompt,
            # Use the dedicated backward model when configured, else the forward one.
            model=getattr(self.config, "backward_model", None) or self.config.model,
            temperature=0.4,
            system_prompt=BACKWARD_SYSTEM_PROMPT,
        )
        resp = parse_gradient_response(raw)
        resp.debug = {
            "variable_desc": self.description,
            "variable_value": self.variable,
            "variable_short_max": 400,
            "variable_short": short_text(self.variable, 400),
            "lm_system_prompt": "",
            "lm_input": lm_input,
            "lm_output": lm_output,
            "response_desc": response_desc,
            "objective_feedback": objective,
            "raw_backward_output": raw,
        }
        return resp


class Executor(BaseComponent):
    """BigCodeBench step 3/4: execute code + tests (tool component, not optimized)."""

    def __init__(self):
        """Register the component's description and input/output fields."""
        super().__init__(
            description="Execute Python code with the provided unit tests and return results.",
            input_fields=["initial_code", "additional_unit_tests"],
            output_fields=["execution_result"],
        )

    def forward(self, initial_code, additional_unit_tests, entry_point="task_func"):
        """Run the code against the tests and summarise the outcome.

        Args:
            initial_code: Source of the candidate solution.
            additional_unit_tests: Source of the tests to run against it.
            entry_point: Name of the function under test.

        Returns:
            ``{"execution_result": {...}}`` where the inner dict holds
            ``passed`` (bool), ``failed_tests`` (list of str; empty on
            success), and ``trace`` / ``stderr`` (each capped at 3000 chars).
        """
        raw = untrusted_check(
            initial_code,
            additional_unit_tests,
            entry_point,
            max_as_limit=300 * 1024,
            max_data_limit=300 * 1024,
            max_stack_limit=300 * 1024,
            min_time_limit=2,
            gt_time_limit=5,
        )
        passed = bool(raw[0] == "pass")
        # Index defensively: the result may carry fewer than three items.
        trace_raw = raw[1] if len(raw) > 1 else ""
        stderr_raw = raw[2] if len(raw) > 2 else ""

        failed_tests = []
        if isinstance(trace_raw, dict):
            trace = "\n".join([f"{k}: {v}" for k, v in trace_raw.items() if v])
            # Report the dict keys as the failing tests instead of a blanket
            # "unknown". NOTE(review): assumes keys are test identifiers -
            # confirm against untrusted_check.
            failed_tests = [str(name) for name in trace_raw]
        elif isinstance(trace_raw, list):
            trace = "\n".join([str(t) for t in trace_raw if t])
        else:
            trace = str(trace_raw)

        stderr = str(stderr_raw)
        if not passed and not trace.strip() and not stderr.strip():
            trace = "Unknown Error: The execution failed but captured no trace. This often means a SyntaxError in the generated code or tests, or a missing dependency that crashed the runner immediately."

        if passed:
            failed_tests = []
        elif not failed_tests:
            # Failure with no per-test labels: keep the original placeholder.
            failed_tests = ["unknown"]

        exec_result = {
            "passed": passed,
            "failed_tests": failed_tests,
            "trace": trace[:3000],
            "stderr": stderr[:3000],
        }
        return {"execution_result": exec_result}


class FinalCodeGenerator(BaseComponent):
    """BigCodeBench step 4/4: Refine the code based on execution feedback.

    ``forward`` sends the repair instructions (the component's ``variable``),
    the problem, the current code and the execution summary to the forward
    model and extracts the corrected code from the reply. ``backward`` asks
    the backward model for feedback on the repair instructions.
    """

    BACKWARD_OUTPUT_CONSTRAINTS = DEFAULT_BACKWARD_OUTPUT_CONSTRAINTS

    REPAIR_PROMPT = """        
        You are an expert Python Code Repair Agent.
        You will receive:
        1. A coding problem.
        2. The current buggy code.
        3. The execution error trace (traceback/failed tests).

        Your goal: **Rewrite the code to FIX the errors.**
        Output the **COMPLETE, CORRECTED** function/script.
        Do not use `pass` or `...` as placeholders.
        Return ONLY valid Python code inside markdown code fences.
        """

    def __init__(
        self,
        model="anthropic/claude-3-haiku",
        backward_model="openai/gpt-4o-mini",
        max_tokens=2048,
        temperature=0.0,
    ):
        """Register the component's description, fields, prompt and LLM config.

        Args:
            model: Model identifier used by ``forward``.
            backward_model: Model identifier used by ``backward``.
            max_tokens: Generation budget passed as ``max_new_tokens``.
            temperature: Sampling temperature for the forward call.
        """
        super().__init__(
            description=(
                "BigCodeBench step 4/4 (FinalCodeGenerator): A Code REFINER. "
                "Responsibility: Fix logic/syntax errors in the base code based on execution traces. "
                "Output: The fully corrected, executable code. "
                "Limitations: This component is only for reparing logic/syntax errors in the base code based on execution trace."
                "If in the backward process, the error has not been shown in the execution trace, it is an UPSTREAM failure."
            ),
            input_fields=["question", "initial_code", "additional_unit_tests", "execution_result"],
            output_fields=["code", "local_fix", "upstream_grad", "patch_applied"],
            variable=self.REPAIR_PROMPT,
            config={
                "model": model,
                "backward_model": backward_model,
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
        )

    @staticmethod
    def _fill_placeholders(template, values):
        """Substitute ``{name}`` placeholders in ``template`` in a single pass.

        Only the names in ``values`` are treated as placeholders; every other
        brace is left untouched. ``str.format`` is avoided because it raises
        on stray braces and re-scans nothing but also escapes ``{{``; here an
        arbitrary prompt can never raise, and substituted text is never
        re-scanned.

        Returns:
            ``(rendered, missing)`` where ``missing`` lists the names from
            ``values`` that did not occur in ``template``.
        """
        seen = set()

        def _substitute(match):
            name = match.group(1)
            seen.add(name)
            return str(values[name])

        pattern = r"\{(" + "|".join(re.escape(name) for name in values) + r")\}"
        rendered = re.sub(pattern, _substitute, template)
        return rendered, [name for name in values if name not in seen]

    @staticmethod
    def _extract_code(raw):
        """Pull the code out of a model reply.

        The first ```` ```python ```` fence wins; if there is none, the first
        fence of any kind (its info line is skipped) is used; if there is no
        fence at all, the whole reply is used. The result is stripped; a
        ``None``/empty reply yields ``""``.
        """
        text = raw or ""
        fenced = re.search(r"```python\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE)
        if fenced is None:
            # The prompt only asks for "markdown code fences", so also accept
            # fences without a ``python`` tag.
            fenced = re.search(r"```[^\n`]*\n(.*?)```", text, re.DOTALL)
        return (fenced.group(1) if fenced else text).strip()

    def forward(self, **inputs):
        """Ask the forward model to repair the code given the execution result.

        Args:
            **inputs: Reads ``question``, ``initial_code`` and
                ``execution_result`` (a dict with ``passed``, ``trace`` and
                ``stderr``; ``None`` is treated as empty).

        Returns:
            Dict with ``code`` (the repaired code, or the initial code when
            the reply is shorter than 10 characters), a fixed ``local_fix``
            note, ``upstream_grad`` set to ``STOP_GRADIENT``, and
            ``patch_applied`` (``False`` only when the fallback was taken).
        """
        exec_result = inputs.get("execution_result", {}) or {}
        initial_code = inputs.get("initial_code", "")
        exec_summary = (
            f"Passed: {exec_result.get('passed')}\n"
            f"Trace:\n{exec_result.get('trace', '')}\n"
            f"Stderr:\n{exec_result.get('stderr', '')}"
        )

        instructions, missing = self._fill_placeholders(
            self.variable,
            {"initial_code": initial_code, "exec_result": exec_summary},
        )

        sections = [instructions, f"Problem Description:\n{inputs.get('question', '')}"]
        # The default prompt has no placeholders, so the code and the trace
        # would otherwise never reach the model. Append whatever is missing.
        if "initial_code" in missing:
            sections.append(f"Current Code:\n{initial_code}")
        if "exec_result" in missing:
            sections.append(f"Execution Result:\n{exec_summary}")
        prompt = "\n\n".join(sections) + "\n\nCorrected Code:\n"

        raw = get_llm_output(
            message=prompt,
            model=self.config.model,
            max_new_tokens=self.config.max_tokens,
            temperature=self.config.temperature,
        )

        patched_code = self._extract_code(raw)

        # A near-empty reply is treated as a failed repair: keep the base code
        # and report that no patch was applied.
        patch_applied = len(patched_code) >= 10
        if not patch_applied:
            patched_code = initial_code

        return {
            "code": patched_code,
            "local_fix": "code refined based on trace",
            "upstream_grad": STOP_GRADIENT,
            "patch_applied": patch_applied,
        }

    def backward(self, signal: GradientSignal, full_traj, component_traj) -> BackwardResponse:
        """Request backward-model feedback on the repair instructions.

        Args:
            signal: Incoming gradient signal; ``signal.feedback`` and
                ``signal.context["execution_result"]`` are read.
            full_traj: Unused here.
            component_traj: This component's recorded trajectory; the question
                and base code are read from ``["input"]`` and the refined
                code from ``["output"]``.

        Returns:
            The parsed backward response with a ``debug`` dict attached that
            records every prompt ingredient and the raw model output.
        """
        exec_result = signal.context.get("execution_result", {}) or {}
        question = component_traj.get("input", {}).get("question", "")

        trace = exec_result.get("trace", "")

        objective = f"{signal.feedback}\n" f"Execution result after refinement: {exec_result}"

        # Condensed view of the forward input, not the literal forward prompt.
        lm_input = (
            f"Problem: {question}\n"
            f"Base Code: {component_traj.get('input', {}).get('initial_code', '')}\n"
            f"Trace: {trace}\n"
        )
        lm_output = component_traj.get("output", {}).get("code", "")

        response_desc = "the refined/fixed code"

        context_prompt = build_backward_context_prompt(
            variable_desc=self.description,
            variable_value=self.variable,
            lm_input=lm_input,
            lm_output=lm_output,
            objective_feedback=objective,
            response_desc=response_desc,
        )

        prompt = f"{context_prompt}\n\n{self.BACKWARD_OUTPUT_CONSTRAINTS}"

        raw = get_llm_output(
            message=prompt,
            # Use the dedicated backward model when configured, else the forward one.
            model=getattr(self.config, "backward_model", None) or self.config.model,
            temperature=0.4,
            system_prompt=BACKWARD_SYSTEM_PROMPT,
        )
        resp = parse_gradient_response(raw)
        resp.debug = {
            "variable_desc": self.description,
            "variable_value": self.variable,
            "variable_short_max": 400,
            "variable_short": short_text(self.variable, 400),
            "lm_system_prompt": "",
            "lm_input": lm_input,
            "lm_output": lm_output,
            "response_desc": response_desc,
            "objective_feedback": objective,
            "raw_backward_output": raw,
        }
        return resp


def system_engine(*args, **kwargs):
    """Build the four-component BigCodeBench system.

    The components are the code generator, unit-test generator, executor and
    final code generator. The system's final output is ``code``, scored with
    ``untrusted_check`` against the ground-truth ``unit_tests`` and
    ``entry_point``: ``1.0`` when the status is ``"pass"``, else ``0.0``.

    Args:
        *args: Extra positional arguments forwarded to ``CompoundAISystem``.
        **kwargs: Extra keyword arguments forwarded to ``CompoundAISystem``.

    Returns:
        The configured ``CompoundAISystem`` instance.
    """
    pipeline = {
        "code_generator": CodeGenerator(),
        "unit_test_generator": UnitTestGenerator(),
        "executor": Executor(),
        "final_code_generator": FinalCodeGenerator(),
    }

    # Resource and time limits handed to the sandboxed checker on every call.
    sandbox_limits = {
        "max_as_limit": 300 * 1024,
        "max_data_limit": 300 * 1024,
        "max_stack_limit": 300 * 1024,
        "min_time_limit": 2,
        "gt_time_limit": 5,
    }

    # Alternative scorer kept from the original for reference: eval_func=pass_rate
    return CompoundAISystem(
        *args,
        components=pipeline,
        final_output_fields=["code"],
        ground_fields=["unit_tests", "entry_point"],
        eval_func=lambda code, unit_tests, entry_point="task_func": float(
            untrusted_check(code, unit_tests, entry_point, **sandbox_limits)[0] == "pass"
        ),
        **kwargs,
    )
