import sys
sys.path.append('.')
try:
    from vllm import LLM  
except ImportError:
    LLM = None  
from src.utils.prompt_handler import PromptHandler
from src.agent import Agent
from src.utils.parse import parse_response
import logging
import json
from typing import List, Dict, Any
from transformers import AutoTokenizer

# Module-level logger; its level is pinned to INFO here regardless of the root configuration.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)


# Solver prompt for completion-style tasks. The buggy function is rendered as
# ``{prompt}`` immediately followed by ``{mutation}`` (no separator between
# them); the model must answer with only the corrected function body (no
# header) in a ```python``` fence, followed by an "Explanation:" line.
# The template is assembled line-by-line; "" entries produce blank lines.
CODEGEN_COMPLETE_SOLVER_PHANDLER = PromptHandler(
    name="CODEGEN_COMPLETE_SOLVER_PHANDLER",
    input_keys=["prompt", "mutation"],
    output_format=str,
    strict_input=False,
    template="\n".join([
        "Your task is to fix the buggy implementation of a function.",
        "",
        "Rules:",
        "1. Respond with the entire corrected function body (the full code inside the function, not just the lines that were wrong).",
        "2. Do not include any function headers, docstrings, comments, or test cases.",
        "3. Preserve correct parts of the original function — your response must be a complete and corrected function body.",
        "",
        "Response Format:",
        "1. The entire corrected function body (excluding the function header) surrounded by ```python ```",
        "2. A brief explanation of the fix prefixed with 'Explanation:'",
        "",
        "Example:",
        "Buggy Implementation:",
        "def add(a, b):",
        "    return a - b",
        "",
        "Response:",
        "```python",
        "    return a + b",
        "```",
        "",
        "Explanation: Fixed the incorrect subtraction operator to addition to properly implement the addition function.",
        "",
        "Buggy Implementation:",
        "{prompt}{mutation}",
        "",
        "Response:",
        "",
        "",
    ]),
)

# Solver prompt for KodCode-style tasks: only ``{mutation}`` (the full buggy
# function) is shown, and the model must return the whole corrected function,
# header included, followed by an "Explanation:" line.
# NOTE(review): "prompt" is declared in input_keys but never referenced by the
# template; presumably tolerated because strict_input=False — confirm.
# The template is assembled line-by-line; "" entries produce blank lines.
KODCODE_CODEGEN_COMPLETE_SOLVER_PHANDLER = PromptHandler(
    name="KODCODE_CODEGEN_COMPLETE_SOLVER_PHANDLER",
    input_keys=["prompt", "mutation"],
    output_format=str,
    strict_input=False,
    template="\n".join([
        "Your task is to fix the buggy implementation of a function.",
        "",
        "Do not include docstrings, comments, or test cases.",
        "",
        "Respond ONLY with:",
        "1. The corrected function surrounded by ```python ```",
        "2. A brief explanation of the fix prefixed with 'Explanation:'",
        "",
        "Example:",
        "Buggy Implementation:",
        "def add(a, b):",
        "    return a - b",
        "",
        "Response:",
        "```python",
        "def add(a, b):",
        "    return a + b",
        "```",
        "",
        "Explanation: Fixed the incorrect subtraction operator to addition to properly implement the addition function.",
        "",
        "Buggy Implementation:",
        "{mutation}",
        "",
        "Response:",
        "",
        "",
    ]),
)

# Solver prompt for instruction-style tasks: the natural-language problem
# (``{prompt}``) and the buggy code (``{mutation}``) are shown separately, and
# the model must return the entire corrected code plus an "Explanation:" line.
# The template is assembled line-by-line; "" entries produce blank lines.
CODEGEN_INSTRUCT_SOLVER_PHANDLER = PromptHandler(
    name="CODEGEN_INSTRUCT_SOLVER_PHANDLER",
    input_keys=["prompt", "mutation"],
    output_format=str,
    strict_input=False,
    template="\n".join([
        "Your task is to fix buggy code.",
        "",
        "Respond with:",
        "1. The entire correct code surrounded by ```python ```",
        "2. A brief explanation of the fix prefixed with 'Explanation:'",
        "",
        "Example:",
        "Problem:",
        "Write a function to add two numbers.",
        "Buggy Implementation:",
        "def add(a, b):",
        "    return a - b",
        "",
        "Response:",
        "```python",
        "def add(a, b):",
        "    return a + b",
        "```",
        "",
        "Explanation: Fixed the incorrect subtraction operator to addition to properly implement the addition function.",
        "",
        "Problem:",
        "{prompt}",
        "Buggy Implementation:",
        "{mutation}",
        "",
        "Response:",
        "",
        "",
    ]),
)

# Generation prompt for completion-style tasks: the model continues the
# function in ``{prompt}`` and must NOT repeat its header.
# The template is assembled line-by-line; "" entries produce blank lines.
CODEGEN_COMPLETE_PHANDLER = PromptHandler(
    name="CODEGEN_COMPLETE_PHANDLER",
    input_keys=["prompt"],
    output_format=str,
    strict_input=False,
    template="\n".join([
        "Complete the function below. Do not include the function header. Surround code with ```python ```.",
        "",
        "Example:",
        "def add(a, b):",
        "",
        "Response:",
        "```python",
        "    return a + b",
        "```",
        "",
        "{prompt}",
        "Response:",
        "",
    ]),
)

# Generation prompt for KodCode-style completion tasks: unlike
# CODEGEN_COMPLETE_PHANDLER, the model is asked to return the whole function,
# header included, inside a ```python``` fence.
KODCODE_CODEGEN_COMPLETE_PHANDLER = PromptHandler(
    template=(
        "Complete the function below. Include the function definition. Surround code with ```python ```.\n\n"
        "Example:\n"
        "def add(a, b):\n\n"
        "Response:\n"
        "```python\n"
        # The example answer puts the body directly under the header (no stray
        # blank line), matching the other few-shot answers in this module.
        "def add(a, b):\n"
        "    return a + b\n"
        "```\n\n"
        "{prompt}\n"
        "Response:\n"
    ),
    input_keys=["prompt"],
    output_format=str,
    strict_input=False,
    name="KODCODE_CODEGEN_COMPLETE_PHANDLER"
)

# Generation prompt for instruction-style tasks: the problem text verbatim,
# followed by a request to fence the answer as ```python```.
CODEGEN_INSTRUCT_PHANDLER = PromptHandler(
    name="CODEGEN_INSTRUCT_PHANDLER",
    input_keys=["prompt"],
    output_format=str,
    strict_input=False,
    template="\n".join([
        "{prompt}",
        "Surround code with ```python ```.",
        "",
    ]),
)

# Prompt that asks the model to deliberately complete the function in
# ``{prompt}`` with incorrect code, without repeating the header.
# The template is assembled line-by-line; "" entries produce blank lines.
INCORRECT_CODEGEN_COMPLETE_PHANDLER = PromptHandler(
    name="INCORRECT_CODEGEN_COMPLETE_PHANDLER",
    input_keys=["prompt"],
    output_format=str,
    strict_input=False,
    template="\n".join([
        "Complete the function below with incorrect code. Do not include the function header. Surround code with ```python ```.",
        "",
        "Example:",
        "def add(a, b):",
        "",
        "Response:",
        "```python",
        "    return a - b",
        "```",
        "",
        "{prompt}",
        "Response:",
        "",
    ]),
)

# Prompt that asks the model to deliberately complete the function in
# ``{prompt}`` with incorrect code, returning the whole function (header
# included) for KodCode-style tasks.
KODCODE_INCORRECT_CODEGEN_COMPLETE_PHANDLER = PromptHandler(
    template=(
        "Complete the function below with incorrect code. Include the function header. Surround code with ```python ```.\n\n"
        "Example:\n"
        "def add(a, b):\n\n"
        "Response:\n"
        "```python\n"
        # The example answer must itself include the header; a body-only
        # example contradicts the "Include the function header" instruction.
        "def add(a, b):\n"
        "    return a - b\n"
        "```\n\n"
        "{prompt}\n"
        "Response:\n"
    ),
    input_keys=["prompt"],
    output_format=str,
    strict_input=False,
    name="KODCODE_INCORRECT_CODEGEN_COMPLETE_PHANDLER"
)

# Prompt that asks the model to write deliberately incorrect code for the
# instruction-style problem in ``{prompt}``, fenced as ```python```.
INCORRECT_CODEGEN_INSTRUCT_PHANDLER = PromptHandler(
    name="INCORRECT_CODEGEN_INSTRUCT_PHANDLER",
    input_keys=["prompt"],
    output_format=str,
    strict_input=False,
    template="\n".join([
        "Write incorrect code for the following problem:",
        "{prompt}",
        "Surround code with ```python ```.",
        "",
    ]),
)

def solver_post_process_func(
    response: List[str],
    input_fields: List[Dict[str, str]]
    ) -> List[Dict[str, Any]]:
    """Turn raw solver generations into one result record per sample.

    Each raw response is parsed with ``parse_response``; the ``"code"`` and
    ``"explanation"`` entries it returns become ``solution`` and
    ``solution_explanation``. A response that fails to parse does not abort
    the batch: its record gets ``success=False`` and the exception text in
    ``error``.

    Args:
        response: Raw model outputs, one per sample.
        input_fields: Input records aligned index-by-index with ``response``.
            Each must contain a ``"problem"`` key; every other key is copied
            into the result record. The caller's dicts are not modified.

    Returns:
        A list with one dict per element of ``response``, in order, holding
        ``success``, ``solution``, ``solution_explanation``, ``error``,
        ``problem`` and ``raw_response``, followed by the remaining input
        fields. Computed keys take precedence over same-named input fields.

    Raises:
        KeyError: If an input record has no ``"problem"`` key.
        IndexError: If ``input_fields`` is shorter than ``response``.
    """
    results = []

    for i, resp in enumerate(response):
        # Work on a copy: popping "problem" in place would mutate the caller's
        # records and make a second pass over the same inputs fail.
        input_data = dict(input_fields[i])
        problem = input_data.pop("problem")

        try:
            parsed = parse_response(resp)
            outcome = {
                "success": True,
                "solution": parsed.get("code"),
                "solution_explanation": parsed.get("explanation"),
                "error": None,
            }
        except Exception as e:
            logger.error("Parsing failed for response %d: %s", i, e)
            outcome = {
                "success": False,
                "solution": None,
                "solution_explanation": None,
                "error": str(e),
            }

        record = {**outcome, "problem": problem, "raw_response": resp}
        # Carry the input fields through without letting them overwrite the
        # solver's own output (e.g. a dataset "solution" column would
        # otherwise replace the parsed model solution).
        for key, value in input_data.items():
            record.setdefault(key, value)
        results.append(record)

    return results

def solver_engine(
    model_name: str,
    vllm_model: LLM = None,
    peft_dir: str = None,
    prompt_handler = CODEGEN_COMPLETE_SOLVER_PHANDLER,
    post_process_func = solver_post_process_func,
    verbose: bool = False,
    **kwargs
) -> Agent:
    """Build an ``Agent`` configured to repair buggy code.

    Before constructing the agent, tries to load the tokenizer for
    ``model_name`` and attach it to ``prompt_handler.tokenizer``. This step is
    best effort: on failure a warning is logged and the agent is still built.

    Args:
        model_name: Model identifier, passed to
            ``AutoTokenizer.from_pretrained`` and to ``Agent``.
        vllm_model: Optional pre-constructed vLLM model forwarded to ``Agent``.
        peft_dir: Optional adapter directory forwarded to ``Agent``.
        prompt_handler: Prompt template used to format solver inputs.
        post_process_func: Callable that turns raw generations into results.
        verbose: Forwarded to ``Agent``.
        **kwargs: Extra keyword arguments forwarded to ``Agent``.

    Returns:
        The constructed ``Agent``.

    Raises:
        ValueError: If ``prompt_handler`` is None.
    """
    if prompt_handler is None:
        raise ValueError("A valid prompt_handler must be provided.")
    # NOTE(review): this sets an attribute on the handler object passed in,
    # which by default is the shared module-level
    # CODEGEN_COMPLETE_SOLVER_PHANDLER, so engines for different models share
    # whichever tokenizer was attached last — confirm this is intended.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        prompt_handler.tokenizer = tokenizer
    except Exception as e:
        # Keep going without a tokenizer, but make the failure visible instead
        # of silently swallowing it.
        logger.warning(
            "Could not load tokenizer for %s; prompt handler tokenizer not updated: %s",
            model_name,
            e,
        )

    # Create and return the solver agent
    return Agent(
        model_name=model_name,
        prompt_handler=prompt_handler,
        vllm_model=vllm_model,
        peft_dir=peft_dir,
        post_process_func=post_process_func,
        verbose=verbose,
        **kwargs
    )

