from __future__ import annotations

import pathlib
import re
import json
import logging
import traceback

import robotic as ry

from vtamp.policies.code_as_policies.policy import CaP
from vtamp.environments.utils import Action, Environment, State
from vtamp.policies.utils import query_llm
from vtamp.utils import parse_text_prompt, save_log, write_prompt


# Instantiate the imported Action/State types once and discard the results.
# Presumably this keeps the otherwise-unused imports from being flagged or
# removed by linters — confirm before deleting.
_ = Action()
_ = State()
log = logging.getLogger(__name__)


class KomoPolicy(CaP):
    """Policy that asks an LLM for KOMO setup code and keeps the first good plan.

    The LLM reply is expected to contain a JSON object with a ``"code"`` field
    (and optionally ``"n_phases"``). That code is appended to the contents of
    ``komo_setup.txt``, executed, and the resulting ``komo`` object is evaluated
    on ``self.twin``. Execution errors and too-high costs are fed back to the
    LLM as new user messages, for at most ``retry_limit`` queries per call.
    """

    def __init__(
        self,
        twin=None,
        max_feedbacks=0,
        seed=0,
        gaussian_blur=False,
        max_csp_samples=0,
        use_komo=False,
        retry_limit=10,
        **kwargs,
    ):
        """Load the system prompt and initialise the policy state.

        Args:
            twin: Environment used to evaluate candidate plans; ``get_action``
                reads it back as ``self.twin`` (presumably stored by
                ``CaP.__init__`` — confirm).
            max_feedbacks, seed, gaussian_blur, max_csp_samples: Forwarded
                positionally to ``CaP.__init__``.
            use_komo: Accepted for config compatibility; not used here.
            retry_limit: Maximum number of LLM queries per ``get_action`` call.
            **kwargs: Accepted for config compatibility; ignored.
        """
        super().__init__(twin, max_feedbacks, seed, gaussian_blur, max_csp_samples)
        prompt_path = pathlib.Path(__file__).parent / "prompt_BridgeEnv.txt"
        with open(prompt_path, "r", encoding="utf-8") as f:
            self.prompt = [{"role": "system", "content": f.read()}]
        self.retry_limit = retry_limit
        # Last KOMO object whose cost met the threshold; None until one does.
        self.plan = None

    def get_action(self, belief, goal: str) -> tuple[ry.Komo, dict]:
        """Query the LLM repeatedly until its KOMO code yields a low-cost plan.

        Each attempt extracts code from the LLM reply, executes it to build a
        ``komo`` object, runs it on ``self.twin`` and checks
        ``env.compute_cost() < 2e-2``. Failures are reported back to the LLM in
        the chat history before the next attempt.

        Args:
            belief: Current belief. It is interpolated into the user prompt, and
                ``belief.config`` is exposed to the generated code as ``config``.
            goal: Natural-language instruction.

        Returns:
            ``(self.plan, statistics)``, where ``statistics["llm_query_time"]``
            is the summed query time reported by ``query_llm`` (absent if
            ``retry_limit`` is 0).
        """
        statistics = {}
        env = self.twin
        content = f"1. Configuration: {belief}\n2. Instruction: {goal}"
        chat_history = self.prompt + [{"role": "user", "content": content}]

        for attempt in range(self.retry_limit):
            # NOTE(review): the LLM seed is fixed at 0 rather than the `seed`
            # given to __init__ — confirm this is intended.
            llm_response, query_time = query_llm(chat_history, seed=0)
            statistics["llm_query_time"] = statistics.get("llm_query_time", 0) + query_time
            write_prompt("llm_input.txt", chat_history)
            chat_history.append({"role": "assistant", "content": llm_response})
            save_log("llm_output.txt", llm_response)

            try:
                # The generated code starts with `n_phases = ...` (added by
                # cleanup_komo); the setup prefix is expected to use `config`.
                config = belief.config
                llm_code = self.cleanup_komo(llm_response)
                # Execute in one explicit namespace. Reading results back through
                # a function's locals() is unreliable (under PEP 667, Python
                # 3.13+, every locals() call returns a fresh snapshot). A single
                # dict also lets functions defined in the generated code see its
                # top-level names.
                namespace = {
                    **globals(),
                    "self": self,
                    "belief": belief,
                    "goal": goal,
                    "env": env,
                    "config": config,
                }
                exec(llm_code, namespace)
                komo_object = namespace["komo"]

                # Evaluate the candidate plan on a freshly reset twin.
                env.reset()
                env.step_komo(komo_object, vis=False)
                cost = env.compute_cost()
                state = env.getState()

                if cost < 2e-2:
                    self.plan = komo_object
                    break

                log.info("Iter %d: Valid komo, but cost too high: %s", attempt, cost)
                chat_history.append({
                    "role": "user",
                    "content": f"Your code is runnable. Your solution has a cost of {cost}. "
                               f"The task is not solved because the cost must be < 2e-2. "
                               f"Please try to improve your solution. "
                               f"The final environment state after executing your code was: {state}."
                })

            except Exception:
                # Any failure in the generated code is fed back to the LLM rather
                # than aborting the retry loop.
                error_message = traceback.format_exc()
                log.warning("Iter %d: Code execution error:\n%s", attempt, error_message)
                chat_history.append({
                    "role": "user",
                    "content": f"That didn't work! Error: {error_message}"
                })
        else:
            log.warning(
                "No KOMO plan with cost < 2e-2 after %d attempts", self.retry_limit
            )

        # NOTE(review): self.plan is not reset at the start of this call, so if
        # every attempt fails this returns the plan from an earlier call (or
        # None). Confirm callers expect that.
        return self.plan, statistics

    @staticmethod
    def extract_json_between_markers(llm_output):
        """Return the first parseable JSON value found in ``llm_output``.

        Candidates are the bodies of ```json ... ``` fences. If there are no
        fences, the greedy span from the first ``{`` to the last ``}`` is tried.
        Each candidate is parsed:

        1. strictly;
        2. with ``strict=False``, which allows literal control characters such
           as raw newlines inside strings (LLMs often emit multi-line code in
           the ``"code"`` field this way);
        3. with ``strict=False`` after deleting non-whitespace control
           characters, which are illegal outside strings.

        Returns:
            The decoded JSON value, or ``None`` if no candidate parses.
        """
        candidates = re.findall(r"```json(.*?)```", llm_output, re.DOTALL)
        if not candidates:
            # Fallback: try any brace-delimited span in the output.
            candidates = re.findall(r"({.*})", llm_output, re.DOTALL)

        for raw in candidates:
            text = raw.strip()
            try:
                return json.loads(text)
            except json.JSONDecodeError as e:
                log.debug("Strict JSON parse failed: %s", e)
            # Don't try to escape newlines via regex: a regex cannot reliably
            # tell an escaped \" inside the code from the end of the string.
            try:
                return json.loads(text, strict=False)
            except json.JSONDecodeError as e:
                log.debug("Lenient JSON parse failed: %s", e)
            # Keep \t, \n and \r so that multi-line code is not merged.
            cleaned = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", "", text)
            try:
                return json.loads(cleaned, strict=False)
            except json.JSONDecodeError as e:
                log.debug("JSON parse after control-char cleanup failed: %s", e)

        return None

    def cleanup_komo(self, original_func: str) -> str:
        """Turn an LLM reply into executable KOMO setup code.

        The result is ``n_phases = <int>``, then the contents of
        ``komo_setup.txt``, then the ``"code"`` field of the JSON object
        extracted from the reply. ``n_phases`` defaults to 1.

        Raises:
            ValueError: If no JSON object can be extracted, it has no string
                ``"code"`` field, or ``"n_phases"`` is not an integer string.
            TypeError: If ``"n_phases"`` has a type ``int()`` cannot convert.
        """
        json_output = self.extract_json_between_markers(original_func)
        # Raise explicitly: an `assert` would be stripped under `python -O`.
        if not isinstance(json_output, dict):
            raise ValueError("Failed to extract JSON from LLM output")
        komo_content = json_output.get("code")
        if not isinstance(komo_content, str):
            raise ValueError('LLM JSON output has no string "code" field')
        n_phases = int(json_output.get("n_phases", 1))

        setup_path = pathlib.Path(__file__).parent / "komo_setup.txt"
        with open(setup_path, "r", encoding="utf-8") as f:
            komo_prefix = f.read()
        return f"n_phases = {n_phases}\n" + komo_prefix + komo_content
