import os
import json
import threading
from typing import Any, Dict, List, Optional

from ast_policy import sha1
from remote_llm_socket_client import RemoteLLMSocketClient


class PolicyProvider:
    """
    Provides code generation and repair for robot skill execution.
    Supports:
      - initial code generation (static / local_llm / remote_llm)
      - step-wise revision (remote_llm)
      - failure-based repair (local_llm/remote_llm)
      - proactive repair of infeasible / invalid future statements (remote_llm)
    Acts as a central provider for code adaptation during task execution loops.

    Generated or repaired code is written to
    ``./tasks_<target_robot>/<task_name>.py`` (see ``_save``).
    """

    def __init__(
        self,
        mode: str,                 # "static"|"local_llm"|"remote_llm"
        policy_path: str,
        task_name: str,
        source_robot: str,
        target_robot: str,
        model: str,
        initial_prompt: str,
        remote_host: str = "127.0.0.1",
        remote_port: int = 9000,
        object_names: Optional[list] = None,
        descriptions: str = "",
        target_scene_info: Any = None,
        grasp_guidance: str = "",
    ):
        """
        Args:
            mode: "static" (return ``policy_path`` contents), "local_llm" or
                "remote_llm".
            policy_path: Policy file returned verbatim in static mode.
            task_name: Task/skill name; also the file stem under ``./tasks_<robot>/``.
            source_robot: Robot whose existing task file is loaded as reference code.
            target_robot: Robot the code is produced for; output is written to
                ``./tasks_<target_robot>/<task_name>.py``.
            model: Model identifier sent in the ``meta`` field of remote
                revise/repair requests.
            initial_prompt: Prompt used for the initial generation.
            remote_host: Remote LLM server host (remote_llm only).
            remote_port: Remote LLM server port (remote_llm only).
            object_names: Object names sent with the remote generate request.
            descriptions: Description text sent with the remote generate request.
            target_scene_info: Forwarded as-is with the remote generate request.
            grasp_guidance: Forwarded with every remote request.

        Raises:
            Propagates any exception from reading the reference task file or
            from ``RemoteLLMSocketClient.connect()``.
        """
        self.mode = mode
        self.policy_path = policy_path
        self.target_policy_path = f"./tasks_{target_robot}/{task_name}.py"
        self.task_name = task_name
        self.source_robot = source_robot
        self.target_robot = target_robot
        self.model = model
        self.initial_prompt = initial_prompt
        self.object_names = object_names or []
        self.descriptions = descriptions
        self.target_scene_info = target_scene_info
        self.grasp_guidance = grasp_guidance
        # Token count reported by the most recent successful LLM response.
        self.last_response_num_tokens: Optional[int] = None

        # `client` serves the main thread; `bg_client` is a separate connection
        # used only by batch_invalid_repair (called from a background thread).
        self.client: Optional[RemoteLLMSocketClient] = None
        self.bg_client: Optional[RemoteLLMSocketClient] = None
        self._bg_client_lock = threading.Lock()
        self._remote_host = remote_host
        self._remote_port = remote_port

        # Read the reference code BEFORE opening a socket, so a missing
        # reference file cannot leak an already-connected client.
        self.reference_code = self._read(self.source_robot, self.task_name)

        if self.mode == "remote_llm":
            self.client = self._new_remote_client()
            print(f"[INFO] Connecting to remote LLM server at {remote_host}:{remote_port}...")
            self.client.connect()
            print(f"[INFO] Successfully connected to {remote_host}:{remote_port}")

    def _new_remote_client(self) -> "RemoteLLMSocketClient":
        """Build (without connecting) a client for the configured remote server."""
        return RemoteLLMSocketClient(
            self._remote_host,
            self._remote_port,
            connect_timeout=30.0,
            recv_timeout=600.0,
        )

    def _require_client(self) -> "RemoteLLMSocketClient":
        """
        Return the main remote client, or raise if it is unavailable.

        This is an explicit check rather than ``assert``, because asserts are
        stripped under ``python -O``.
        """
        if self.client is None:
            raise RuntimeError(
                "Remote LLM client is not available "
                "(mode must be 'remote_llm' and close() must not have been called)."
            )
        return self.client

    def close(self):
        """
        Close the main and background connections. Safe to call repeatedly.

        Both attributes are cleared first. The background client is closed even
        if closing the main client raises. The background lock is deliberately
        not taken: it can be held for a whole in-flight request (recv_timeout
        is 600 s).
        """
        client, self.client = self.client, None
        bg_client, self.bg_client = self.bg_client, None
        try:
            if client is not None:
                client.close()
        finally:
            if bg_client is not None:
                bg_client.close()

    def get_initial_code(self) -> str:
        """
        Produce the initial policy code according to ``self.mode``.

        - static: return the contents of ``policy_path`` unchanged (not saved).
        - local_llm: generate from ``initial_prompt`` with the local model, then save.
        - remote_llm: send a "generate" request (prompt, reference code, scene
          context), then save.

        Returns:
            The code. In the LLM modes, Markdown fences have been stripped.

        Raises:
            RuntimeError: the LLM returned empty code, the remote request failed,
                or the remote client is unavailable.
            ValueError: unknown mode.
        """
        if self.mode == "static":
            with open(self.policy_path, "r", encoding="utf-8") as f:
                return f.read()

        if self.mode == "local_llm":
            from generate import generate_code_with_local_llm
            code = generate_code_with_local_llm(self.initial_prompt)
            if not code.strip():
                raise RuntimeError("local generate returned empty code.")
            return self._save(code)

        if self.mode == "remote_llm":
            client = self._require_client()
            resp = client.request({
                "type": "generate",
                "task_name": self.task_name,
                "robot_type": self.target_robot,
                "prompt": self.initial_prompt,
                "reference_code": self.reference_code,
                "available_objects": self.object_names,
                "descriptions": self.descriptions,
                "source_robot": self.source_robot,
                "target_robot": self.target_robot,
                "skill_name": self.task_name,
                "target_scene_info": self.target_scene_info,
                "grasp_guidance": self.grasp_guidance,
            })
            if resp.get("type") != "generate_result" or not resp.get("ok", False):
                raise RuntimeError(f"Remote generate failed: {resp}")
            code = resp.get("code", "")
            if not code.strip():
                raise RuntimeError("Remote generate returned empty code.")
            self.last_response_num_tokens = resp.get("num_tokens")
            return self._save(code)

        raise ValueError(f"Unknown mode: {self.mode}")

    def revise_on_step(
        self,
        current_code: str,
        executed_stmt_texts: List[str],
        next_index: int,
        scene: Dict[str, Any],
        last_step_payload: Dict[str, Any],
    ) -> Optional[str]:
        """
        Ask the remote LLM to revise the not-yet-executed part of the program.

        Returns:
            The new (saved) code, or None when the mode is not remote_llm, the
            request was not ok, the code is empty, or the code is unchanged.

        Raises:
            RuntimeError: mode is remote_llm but the client is unavailable.
        """
        if self.mode != "remote_llm":
            return None

        client = self._require_client()
        resp = client.request({
            "type": "revise",
            "task_name": self.task_name,
            "robot_type": self.target_robot,
            "meta": {"model": self.model},
            "constraints": {
                "must_keep_executed_prefix": True,
                "executed_prefix_count": len(executed_stmt_texts),
                "next_stmt_index": next_index,
            },
            "inputs": {
                "current_code": current_code,
                "executed_stmt_texts": executed_stmt_texts,
                "scene": scene,
                "last_step": last_step_payload,
                "grasp_guidance": self.grasp_guidance,
            }
        })

        if resp.get("type") != "revise_result" or not resp.get("ok", False):
            return None

        new_code = resp.get("code", "")
        print(f"[REVISE] Remote revise response received.: \n{new_code}")
        if not new_code.strip():
            return None
        if sha1(new_code) == sha1(current_code):
            return None
        self.last_response_num_tokens = resp.get("num_tokens")
        return self._save(new_code)

    def repair_on_failure(
        self,
        current_code: str,
        executed_stmt_texts: List[str],
        failed_stmt_text: str,
        error_payload: Dict[str, Any],
        scene: Dict[str, Any],
    ) -> str:
        """
        Repair the program after a statement failed at runtime.

        In local_llm mode, the failure, scene, failed statement and executed
        prefix are appended to ``initial_prompt``. In remote_llm mode, a "repair"
        request is sent, and the response is wrapped in ``run_skill`` if needed.

        Returns:
            The repaired (saved) code.

        Raises:
            RuntimeError: static mode, an empty result, a failed remote request,
                or an unavailable remote client.
            ValueError: unknown mode.
        """
        if self.mode == "static":
            raise RuntimeError("static mode cannot repair. Use local_llm or remote_llm.")

        if self.mode == "local_llm":
            from generate import generate_code_with_local_llm

            augmented = (
                self.initial_prompt
                + "\n\n[FAILURE]\n" + json.dumps(error_payload, ensure_ascii=False)
                + "\n\n[SCENE]\n" + json.dumps(scene, ensure_ascii=False)
                + "\n\n[FAILED_STATEMENT]\n" + failed_stmt_text
                + "\n\n[CONSTRAINT]\nDo NOT change already executed prefix statements.\n"
                + "\n[EXECUTED_PREFIX]\n" + "\n---\n".join(executed_stmt_texts)
            )
            new_code = generate_code_with_local_llm(augmented)
            if not new_code.strip():
                raise RuntimeError("local repair returned empty code.")
            return self._save(new_code)

        if self.mode != "remote_llm":
            # Match get_initial_code: unknown modes are a ValueError, not an
            # AssertionError from the remote path.
            raise ValueError(f"Unknown mode: {self.mode}")

        client = self._require_client()
        resp = client.request({
            "type": "repair",
            "task_name": self.task_name,
            "robot_type": self.target_robot,
            "meta": {"model": self.model},
            "constraints": {
                "must_keep_executed_prefix": True,
                "executed_prefix_count": len(executed_stmt_texts),
            },
            "inputs": {
                "current_code": current_code,
                "executed_stmt_texts": executed_stmt_texts,
                "failed_stmt_text": failed_stmt_text,
                "error": error_payload,
                "scene": scene,
                "grasp_guidance": self.grasp_guidance,
            }
        })
        if resp.get("type") != "repair_result" or not resp.get("ok", False):
            raise RuntimeError(f"Remote repair failed: {resp}")

        new_code = resp.get("code", "")
        print(f"[REPAIR] Remote repair response received.: \n{new_code}")
        if not new_code.strip():
            raise RuntimeError("Remote repair returned empty code.")
        self.last_response_num_tokens = resp.get("num_tokens")
        new_code = self._wrap_with_run_skill(new_code)
        return self._save(new_code)

    def _wrap_with_run_skill(self, code: str) -> str:
        """
        Wrap repaired code with run_skill function signature.

        Converts LLM-generated function body into complete run_skill function.
        Code is returned unwrapped (fences stripped) when it already contains
        ``def run_skill``, is blank, or starts with an import line.
        Example:
            Input:
                target_pos = [0.0, 0.0, 0.3]
                ball = Shape('ball')
                obs, reward, done = sawyer_pick(env, task, ...)
            Output:
                from skill_code import *
                def run_skill(env, task):
                    target_pos = [0.0, 0.0, 0.3]
                    ball = Shape('ball')
                    obs, reward, done = sawyer_pick(env, task, ...)
        """
        code = self._strip_code_fences(code)

        if 'def run_skill' in code:
            return code

        first_non_empty = next((line for line in code.splitlines() if line.strip()), None)
        if first_non_empty is None:
            return code

        # Output that begins with an import is treated as an already-complete module.
        if first_non_empty.lstrip().startswith(('import ', 'from ')):
            return code

        # Imported only when wrapping is actually needed.
        from headers import HEADER, RUN_SKILL_INTERFACE

        # Fall back to the panda header without eagerly requiring that key.
        header = HEADER[self.target_robot] if self.target_robot in HEADER else HEADER["panda"]
        # NOTE(review): `code` is appended without re-indentation. This presumes
        # RUN_SKILL_INTERFACE and/or the LLM output supply the body indentation
        # shown in the example above — confirm against headers.py.
        result = header + "\n\n" + RUN_SKILL_INTERFACE + code
        print(f"[POSTPROCESS] Wrapped repaired code with run_skill function signature (robot: {self.target_robot})")
        return result

    def prefetch_infeasibility_repair(
        self,
        current_code: str,
        executed_stmt_texts: List[str],
        infeasible_stmt_text: str,
        infeasible_stmt_lineno: int,
        infeasibility_info: Dict[str, Any],
        scene: Dict[str, Any],
    ) -> Optional[str]:
        """
        Request code repair when a future statement's waypoints are found infeasible
        during prefetch/lookahead analysis.

        This is called BEFORE the infeasible statement is executed, allowing
        proactive repair without actually failing at runtime.

        Args:
            current_code: Current full source code
            executed_stmt_texts: List of already executed statement texts
            infeasible_stmt_text: The statement code that has infeasible waypoints
            infeasible_stmt_lineno: Line number of the infeasible statement
            infeasibility_info: Dict with details like:
                - skill_name: 'pick', 'place', 'move', 'push'
                - waypoint_index: index of first infeasible waypoint
                - error_message: description of the feasibility error
            scene: Current scene information

        Returns:
            New code with repaired statement, or None if repair fails/not applicable

        Raises:
            RuntimeError: mode is remote_llm but the client is unavailable.
        """
        if self.mode != "remote_llm":
            return None

        client = self._require_client()

        resp = client.request({
            "type": "prefetch_repair",
            "task_name": self.task_name,
            "robot_type": self.target_robot,
            "meta": {"model": self.model},
            "constraints": {
                "must_keep_executed_prefix": True,
                "executed_prefix_count": len(executed_stmt_texts),
                "infeasible_line": infeasible_stmt_lineno,
            },
            "inputs": {
                "current_code": current_code,
                "executed_stmt_texts": executed_stmt_texts,
                "infeasible_stmt_text": infeasible_stmt_text,
                "infeasibility_info": infeasibility_info,
                "scene": scene,
                "grasp_guidance": self.grasp_guidance,
            }
        })

        if resp.get("type") != "prefetch_repair_result" or not resp.get("ok", False):
            print(f"[PREFETCH_REPAIR] Remote prefetch repair failed: {resp}")
            return None

        new_code = resp.get("code", "")
        if not new_code.strip():
            return None
        if sha1(new_code) == sha1(current_code):
            return None

        # Record token usage like every other LLM-backed path does.
        self.last_response_num_tokens = resp.get("num_tokens")
        return self._save(new_code)

    def _ensure_bg_client(self) -> Optional["RemoteLLMSocketClient"]:
        """
        Return a connected background client, (re)connecting if needed.

        The caller must hold ``_bg_client_lock``. Returns None if connecting fails.
        """
        # NOTE(review): relies on the client's private `_sock` attribute to
        # detect a closed connection — confirm RemoteLLMSocketClient keeps it.
        if self.bg_client is not None and self.bg_client._sock is not None:
            return self.bg_client

        self.bg_client = self._new_remote_client()
        print(f"[INFO] Connecting background client to {self._remote_host}:{self._remote_port}...")
        try:
            self.bg_client.connect()
        except Exception as e:
            print(f"[ERROR] Failed to connect background client: {e}")
            self.bg_client = None
            return None
        print("[INFO] Background client connected successfully")
        return self.bg_client

    def _discard_bg_client(self) -> None:
        """Close and forget the background client. The caller must hold ``_bg_client_lock``."""
        bg_client, self.bg_client = self.bg_client, None
        if bg_client is not None:
            bg_client.close()

    def batch_invalid_repair(
        self,
        current_code: str,
        executed_stmt_texts: List[str],
        invalid_statements: List[Dict[str, Any]],
        scene: Dict[str, Any],
    ) -> Optional[str]:
        """
        Request code repair for multiple invalid future statements at once.

        This is called when background validation detects one or more future
        statements that violate preconditions based on projected state.

        NOTE: This method is called from a background thread, so it uses a
        separate socket connection (bg_client) to avoid conflicts with
        the main thread's client.

        Args:
            current_code: Current full source code
            executed_stmt_texts: List of already executed statement texts
            invalid_statements: List of dicts, each containing:
                - line_number: int
                - statement_code: str
                - skill_name: str
                - violations: list of violation reasons
                - warnings: list of warning messages
                - projected_state: dict summarizing projected state at that point
            scene: Current scene information

        Returns:
            New code with repaired statements, or None if repair fails/not applicable.
            Socket errors (OSError) also yield None. Any other exception from the
            request is re-raised after the background connection is dropped.
        """
        if self.mode != "remote_llm":
            return None

        if not invalid_statements:
            return None

        with self._bg_client_lock:
            bg_client = self._ensure_bg_client()
            if bg_client is None:
                return None

            try:
                resp = bg_client.request({
                    "type": "batch_invalid_repair",
                    "task_name": self.task_name,
                    "robot_type": self.target_robot,
                    "meta": {"model": self.model},
                    "constraints": {
                        "must_keep_executed_prefix": True,
                        "executed_prefix_count": len(executed_stmt_texts),
                    },
                    "inputs": {
                        "current_code": current_code,
                        "executed_stmt_texts": executed_stmt_texts,
                        "invalid_statements": invalid_statements,
                        "scene": scene,
                        "grasp_guidance": self.grasp_guidance,
                    }
                })
            except OSError as e:  # ConnectionError and socket timeouts are OSErrors
                print(f"[BATCH_INVALID_REPAIR] Socket error, closing bg_client: {e}")
                self._discard_bg_client()
                return None
            except Exception:
                # The connection state after an unexpected failure is unknown, so
                # do not reuse it for the next request.
                self._discard_bg_client()
                raise

        if resp.get("type") != "batch_invalid_repair_result" or not resp.get("ok", False):
            print(f"[BATCH_INVALID_REPAIR] Remote batch repair failed: {resp}")
            return None

        new_code = resp.get("code", "")
        if not new_code.strip():
            return None
        if sha1(new_code) == sha1(current_code):
            return None

        # NOTE(review): this write happens on the background thread and may race
        # with main-thread updates of the same attribute.
        self.last_response_num_tokens = resp.get("num_tokens")
        return self._save(new_code)

    @staticmethod
    def _strip_code_fences(code: str) -> str:
        """
        Remove Markdown code-fence markers (```python and ```) from LLM output.

        The language-tagged fence is removed first. Otherwise, stripping only
        the bare ``` would leave a stray ``python`` line in the module.
        """
        return code.replace("```python", "").replace("```", "")

    def _save(self, code: str) -> str:
        """Strip code fences, write the code to ``target_policy_path``, and return it."""
        code = self._strip_code_fences(code)
        directory = os.path.dirname(self.target_policy_path)
        if directory:  # os.makedirs("") would raise for a bare filename
            os.makedirs(directory, exist_ok=True)
        with open(self.target_policy_path, "w", encoding="utf-8") as f:
            f.write(code)
        return code

    def _read(self, robot_name: str, task_name: str) -> str:
        """Return the contents of ``./tasks_<robot_name>/<task_name>.py``."""
        path = f"./tasks_{robot_name}/{task_name}.py"
        with open(path, "r", encoding="utf-8") as f:
            return f.read()