import re
import time
from openai import OpenAI 
from langchain_community.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate

# Prompt asking the LLM to write a self-contained ``solve()`` function that
# redoes the faulty computation. Filled via ``str.format`` with the
# placeholders ``question``, ``error_step`` and ``error_explanation``.
CODE_AGENT_PROMPT = """
You are tasked with correcting a computation step of a generated solution for a physics problem.

Question: {question}

Error Step: {error_step}

Error Explanation: {error_explanation}

Task:
1. Carefully review the provided error step and identify the incorrect calculation that was performed.
2. Generate Python code that correctly performs the computation and calculation where the failure occurred.
3. The code should contain a function "def solve()" which returns a string describing the final computation and calculation result.
4. Make sure all the variables are initialized inside solve(), that it doesn't require any input, and that all the required libraries are imported.
5. Make sure to correctly use variable values with appropriate unit conversions and that variables are properly initialized before they are used.

Ensure that the code includes the following:
- All required arithmetic operations.
- Algebraic manipulations.
- Application of mathematical procedures (e.g., integration, differentiation).
- Value substitution.
- Handling of fractions, exponents, and radicals.
- Numerical approximations or rounding.
- Dimensional analysis.

Use the following format for the code:
```python\n<--Your Code-->\n```
"""


# Prompt asking the LLM to turn the correction code and its output into
# natural-language feedback. Filled via ``str.format`` with the placeholders
# ``question``, ``error_step``, ``error_explanation``, ``code`` and
# ``code_output``.
REFINE_CALCULATION_PROMPT = """
You are a physics expert assistant.

You are given:
A physics problem: {question}
A code written to fix a computation mistake in the following step: {error_step}
Explanation of the original mistake: {error_explanation}
Correction Code: {code}
Correction Code Output: {code_output}

Task:
Using the code and its output, generate natural language feedback. The feedback should:
    1. Explain the correct computation performed
    2. Clarify the mistake made
    3. Describe the correct logic and units
    4. Only write the correction feedback.

Use the following output format:
Feedback: <feedback>
"""

class CodeAgent:
    """Repair a faulty computation step by having an LLM write and run Python code.

    Workflow (see :meth:`run`): prompt the model for a ``solve()`` function
    (:meth:`generate_code`), execute it (:meth:`execute_code`), then ask the
    model to turn the code and its output into natural-language feedback
    (:meth:`generate_feedback`). Each intermediate artefact is appended to
    ``self.scratch_pad`` as a human-readable trace.
    """

    # Fenced block explicitly tagged as Python (```python, ```python3, ```py;
    # case-insensitive). ``\b`` stops e.g. ```pyx from matching as "py".
    _PYTHON_FENCE_RE = re.compile(
        r"```[ \t]*(?:python3?|py)\b(.*?)```", re.DOTALL | re.IGNORECASE
    )
    # Any fenced block; the rest of the opening line (language tag) is skipped.
    _ANY_FENCE_RE = re.compile(r"```[^\n]*\n(.*?)```", re.DOTALL)

    def __init__(self, *args, **kwargs):
        """Create the agent, forwarding any arguments to :meth:`init`.

        The class originally defined only ``init`` (not ``__init__``), so the
        sole working usage was ``agent = CodeAgent(); agent.init(...)``. That
        pattern keeps working: with no arguments nothing is initialised here.
        ``CodeAgent(question, error_step, ...)`` now works as well.
        """
        if args or kwargs:
            self.init(*args, **kwargs)

    def init(self, question, error_step, error_explanation, llm_model_name, llm_api_key, llm_api_base):
        """Store the problem context and create the OpenAI-compatible client.

        Args:
            question: The physics problem text.
            error_step: The solution step that contains the computation error.
            error_explanation: Explanation of what is wrong in ``error_step``.
            llm_model_name: Model identifier passed to the chat completions API.
            llm_api_key: API key for the OpenAI-compatible endpoint.
            llm_api_base: Base URL of the OpenAI-compatible endpoint.
        """
        self.question = question
        self.error_step = error_step
        self.error_explanation = error_explanation
        self.llm_model_name = llm_model_name
        self.client = OpenAI(api_key=llm_api_key, base_url=llm_api_base)
        self.scratch_pad = ""

    def _llama_response(self, prompt):
        """Send *prompt* as a single user message and return the reply text.

        Returns ``""`` when the response message carries no text content (the
        SDK types ``message.content`` as optional), so callers can always
        treat the result as a string.
        """
        response = self.client.chat.completions.create(
            model=self.llm_model_name,
            messages=[{"role": "user", "content": prompt}],
        )
        return response.choices[0].message.content or ""

    def _parse_code_block(self, raw_response):
        """Extract Python source from an LLM reply.

        Preference order: the first Python-tagged fenced block, then the
        first fenced block of any kind, then the whole reply stripped.
        ``None`` or an empty reply yields ``""``.
        """
        if not raw_response:
            return ""
        match = self._PYTHON_FENCE_RE.search(raw_response) or self._ANY_FENCE_RE.search(raw_response)
        return match.group(1).strip() if match else raw_response.strip()

    def generate_code(self):
        """Ask the LLM for correction code and return the extracted source."""
        prompt = CODE_AGENT_PROMPT.format(
            question=self.question,
            error_step=self.error_step,
            error_explanation=self.error_explanation,
        )
        response = self._llama_response(prompt)
        code = self._parse_code_block(response)
        self.scratch_pad += f"\n\n--- Python Code ---\n{code}\n-------------------\n"
        return code

    def execute_code(self, code):
        """Execute *code* and return the result of its ``solve()`` function.

        Args:
            code: Python source expected to define a zero-argument ``solve()``.

        Returns:
            Whatever ``solve()`` returns (the prompt asks for a string).

        Raises:
            RuntimeError: if executing the code or calling ``solve()`` raises,
                or if no callable ``solve`` is defined. The original exception
                is chained as ``__cause__``.
        """
        # NOTE(security): this runs model-generated code with full interpreter
        # privileges in the current process. Only use it in a trusted/sandboxed
        # environment.
        # A single dict acts as both globals and locals. That way top-level
        # imports and helper functions in the generated code are visible
        # inside solve(). ``__name__`` is defined (but not "__main__") so an
        # ``if __name__ == "__main__":`` guard neither crashes nor runs.
        namespace = {"__name__": "__llm_solution__"}
        try:
            exec(code, namespace)
            solve = namespace.get("solve")
            if not callable(solve):
                raise NameError("generated code does not define a callable solve()")
            result = solve()
        except Exception as e:
            self.scratch_pad += f"\n\nExecution Error: {e}\n"
            raise RuntimeError(f"Failed to execute code: {e}") from e
        self.scratch_pad += f"\n\nCode Output: {result}\n"
        return result

    def generate_feedback(self, code, code_output):
        """Ask the LLM to explain the correction given the code and its output."""
        prompt = REFINE_CALCULATION_PROMPT.format(
            question=self.question,
            error_step=self.error_step,
            error_explanation=self.error_explanation,
            code=code,
            code_output=code_output,
        )
        feedback = self._llama_response(prompt)
        self.scratch_pad += f"\n\n--- Feedback ---\n{feedback}\n----------------\n"
        return feedback

    def run(self):
        """Generate, execute, and explain correction code; return the feedback.

        A ``RuntimeError`` from :meth:`execute_code` propagates to the caller.
        """
        code = self.generate_code()
        result = self.execute_code(code)
        return self.generate_feedback(code, result)
