from vllm import LLM
import sys
sys.path.append('.')
from src.utils.prompt_handler import PromptHandler
from src.agent import Agent 
from src.utils.parse import parse_response
import logging
import json
from typing import List, Dict, Any
from transformers import AutoTokenizer

# Module-level logger, named after this module via __name__.
# The level is set to INFO at import time, so records below INFO emitted
# through this logger are dropped. No handlers are configured here.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


# Prompt for "completion"-style problems, where the problem text is a function
# header and the canonical solution is only the function body. The model is
# asked to return a buggy version of the *body only* (no signature), followed
# by an "Explanation:" line. This handler is the default `prompt_handler` of
# `mutator_engine` below.
CODEGEN_COMPLETE_MUTATOR_PHANDLER = PromptHandler(
    template=(
        # Task statement and constraints on what the model may return.
        "Modify the given function to introduce a subtle bug that causes some unit tests to fail.\n\n"
        "Rules:\n"
        "1. Respond with the entire modified function body (the code inside the function).\n"
        "2. Do not include the function signature, docstrings, or comments.\n\n"
        # Expected layout of the answer: fenced code block, then explanation.
        "Response Format:\n"
        "1. The modified function body, wrapped in ```python and ```\n"
        "2. A brief explanation of the bug prefixed with 'Explanation:'\n\n"
        # One worked example (one-shot demonstration) in the same layout as
        # the real query that follows.
        "Example:\n"
        "Problem:\n"
        "def add(a, b):\n"
        "Correct Implementation:\n"
        "    return a + b\n\n"
        "Response:\n"
        "```python\n"
        "    return a - b\n"
        "```\n\n"
        "Explanation: Replaced addition operator with subtraction.\n\n"
        # The actual query; {prompt} and {canonical_solution} are the two
        # placeholders named in `input_keys`.
        "Problem:\n{prompt}\n\n"
        "Correct Implementation:\n{canonical_solution}\n\n"
        "Response:\n"
    ),
    # Names of the template placeholders above.
    input_keys=["prompt", "canonical_solution"],
    output_format=str,
    # NOTE(review): presumably allows input dicts to carry extra keys beyond
    # `input_keys` -- confirm against PromptHandler.
    strict_input=False,
    name="CODEGEN_COMPLETE_MUTATOR_PHANDLER"
)

# Variant of the completion-style mutator prompt in which the model must
# return the *whole* function, header included, rather than the body alone.
# NOTE(review): the name suggests this targets the KodCode dataset -- confirm.
KODCODE_CODEGEN_COMPLETE_MUTATOR_PHANDLER = PromptHandler(
    template=(
        # Task statement.
        "Modify the code below to introduce a subtle bug that causes some unit tests to fail.\n\n"
        # Expected layout of the answer: fenced code block, then explanation.
        "Response Format:\n"
        "1. The modified function with function headers, wrapped in ```python and ```\n"
        "2. A brief explanation of the bug prefixed with 'Explanation:'\n\n"
        # One worked example; note the response repeats the function header.
        "Example:\n"
        "Problem:\n"
        "def add(a, b):\n"
        "Original Implementation:\n"
        "    return a + b\n\n"
        "Response:\n"
        "```python\n"
        "def add(a, b):\n"
        "    return a - b\n"
        "```\n\n"
        "Explanation: Replaced addition operator with subtraction.\n\n"
        # The actual query; {prompt} and {canonical_solution} are the two
        # placeholders named in `input_keys`.
        "Problem:\n{prompt}\n\n"
        "Original Implementation:\n{canonical_solution}\n\n"
        "Response:\n"
    ),
    # Names of the template placeholders above.
    input_keys=["prompt", "canonical_solution"],
    output_format=str,
    # NOTE(review): presumably allows input dicts to carry extra keys beyond
    # `input_keys` -- confirm against PromptHandler.
    strict_input=False,
    name="KODCODE_CODEGEN_COMPLETE_MUTATOR_PHANDLER"
)

# Prompt for "instruct"-style problems, where the problem is a natural-language
# task description and the canonical solution is complete code. The model is
# asked to return the entire modified code plus an "Explanation:" line.
CODEGEN_INSTRUCT_MUTATOR_PHANDLER = PromptHandler(
    template=(
        # Task statement.
        "Modify the given code to introduce a subtle bug that causes some unit tests to fail.\n\n"
        # Expected layout of the answer: fenced code block, then explanation.
        "Response Format:\n"
        "1. The entire modified code, wrapped in ```python and ```\n"
        "2. A brief explanation of the bug prefixed with 'Explanation:'\n\n"
        # One worked example: prose problem statement, full original code,
        # full mutated code.
        "Example:\n\n"
        "Problem:\n"
        "Write a function to add two numbers.\n\n"
        "Original Implementation:\n"
        "def add(a, b):\n"
        "    return a + b\n\n"
        "Response:\n"
        "```python\n"
        "def add(a, b):\n"
        "    return a - b\n"
        "```\n\n"
        "Explanation: Replaced addition operator with subtraction.\n\n"
        # The actual query; {prompt} and {canonical_solution} are the two
        # placeholders named in `input_keys`.
        "Problem:\n{prompt}\n\n"
        "Original Implementation:\n{canonical_solution}\n\n"
        "Response:\n"
    ),
    # Names of the template placeholders above.
    input_keys=["prompt", "canonical_solution"],
    output_format=str,
    # NOTE(review): presumably allows input dicts to carry extra keys beyond
    # `input_keys` -- confirm against PromptHandler.
    strict_input=False,
    name="CODEGEN_INSTRUCT_MUTATOR_PHANDLER"
)

def mutator_post_process_func(
    response: List[str],
    input_fields: List[Dict[str, str]]
) -> List[Dict[str, Any]]:
    """Convert raw mutator responses into one result record per response.

    Responses are matched to their originating problem by position:
    ``response[i]`` belongs to ``input_fields[i]``.

    Args:
        response: Raw model outputs, one string per problem.
        input_fields: The input dict for each response, in the same order.

    Returns:
        A list with one dict per response, in input order. Every record has
        the keys ``success``, ``mutation``, ``mutation_explanation``,
        ``error``, ``problem`` and ``raw_response``. When parsing raises,
        ``success`` is ``False``, ``mutation`` and ``mutation_explanation``
        are ``None`` and ``error`` holds the exception message.

    Raises:
        IndexError: If ``response`` has more items than ``input_fields``.
    """
    results = []

    for i, resp in enumerate(response):
        input_data = input_fields[i]

        try:
            # Building the record stays inside the try so that a malformed
            # parse result (e.g. one without ``.get``) is also reported as a
            # failed mutation instead of aborting the whole batch.
            parsed = parse_response(resp)
            mutation_result = {
                "success": True,
                "mutation": parsed.get("code"),
                "mutation_explanation": parsed.get("explanation"),
                "error": None,
                "problem": input_data,
                "raw_response": resp
            }
        except Exception as e:
            # One unparsable response must not discard the others; record the
            # failure and move on.
            logger.error("Parsing failed for response %d: %s", i, e)
            mutation_result = {
                "success": False,
                "mutation": None,
                "mutation_explanation": None,
                "error": str(e),
                "problem": input_data,
                "raw_response": resp
            }

        results.append(mutation_result)

    return results


def mutator_engine(
    model_name: str,
    vllm_model: LLM = None,
    peft_dir: str = None,
    prompt_handler = CODEGEN_COMPLETE_MUTATOR_PHANDLER,
    post_process_func = mutator_post_process_func,
    verbose: bool = False,
    **kwargs
) -> Agent:
    """Build the mutator ``Agent`` for the given model.

    Args:
        model_name: Model identifier; passed to
            ``AutoTokenizer.from_pretrained`` and to ``Agent``.
        vllm_model: Optional ``LLM`` instance forwarded to ``Agent``.
        peft_dir: Optional value forwarded to ``Agent`` as ``peft_dir``.
        prompt_handler: Prompt handler the agent uses. Must not be ``None``.
        post_process_func: Callable forwarded to ``Agent`` as
            ``post_process_func``.
        verbose: Forwarded to ``Agent``.
        **kwargs: Any further keyword arguments, forwarded to ``Agent``.

    Returns:
        The constructed ``Agent``.

    Raises:
        ValueError: If ``prompt_handler`` is ``None``.
    """
    if prompt_handler is None:
        raise ValueError("A valid prompt_handler must be provided.")

    # Attaching a tokenizer is best-effort: a failure to load or attach it
    # must not prevent the agent from being created, but it is logged so the
    # problem is not invisible.
    # NOTE: this sets an attribute on the handler object that was passed in;
    # with the default argument that is the shared module-level handler, so
    # the tokenizer persists across calls.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        prompt_handler.tokenizer = tokenizer
    except Exception as e:
        logger.warning(
            "Could not attach a tokenizer for model %r to the prompt handler; "
            "continuing without it: %s",
            model_name,
            e,
        )

    # Create and return the mutator agent
    return Agent(
        model_name=model_name,
        prompt_handler=prompt_handler,
        vllm_model=vllm_model,
        peft_dir=peft_dir,
        post_process_func=post_process_func,
        verbose=verbose,
        **kwargs
    )

