import ast
import logging
import re
from pathlib import Path
from pddl.core import Predicate
from pddl.logic.base import Not
from typing import Dict, List, Optional, Tuple, Union

from llm_utils.openai_api import Chat, TextMessageContent
from llm_utils.prompt_generation import Prompt
from llm_utils.textgen_api import TextGenApi
from python_utils.string_utils import get_markup_from_text, snake_to_camel
from tp_lodge.task_planning.models.pddl.pddl_predicate import PDDLPredicate
from tp_lodge.utils.python_parse_utils import (
    remove_function_docstrings,
    get_top_level_definitions_from_code_str,
    remove_objects_from_code_str,
    remove_global_instructions,
    get_function_arg_variables,
)

from state_estimation.grounding_callable import GroundingCallable, get_global_namespace
from state_estimation.se_variable import SEVariable

# Module-wide logger plus the directory holding this module and its prompt templates.
logger = logging.getLogger(__name__)
data_dir = Path(__file__).parent
prompts_dir = data_dir.joinpath("prompts")


class PredicateGrounder:
    """Generate, cache and evaluate LLM-written Python grounding functions for PDDL predicates.

    Grounding functions are kept in ``self.grounding_callable``, keyed by ``predicate.name``.
    On construction, every ``*.py`` file in ``out_dir`` is loaded into that dictionary, so
    predicates that already have code there are not regenerated.

    ``set_namespace_annotation`` must be called before any code is generated or refined,
    because its annotation is part of every code prompt.
    """

    # Expected LLM response layouts: group(1) captures the code section, group(2) the description.
    _GENERATE_RESPONSE_PATTERN = r"# Predicate Grounding(.*?)# Grounder Description(.*?)(?:\[END OUTLINE\]|$)"
    _REFINE_RESPONSE_PATTERN = r"# Fixed Code(.*?)# Grounder Description(.*?)(?:\[END OUTLINE\]|$)"

    def __init__(self, code_api_file: Path, out_dir: Path, textgen_api: TextGenApi, domain_knowledge: str) -> None:
        """Create a grounder and load the grounding functions already stored in ``out_dir``.

        Args:
            code_api_file: Python file with the low-level API the generated code may use. Its
                top-level class names are the only accepted argument type hints.
            out_dir: Directory whose ``*.py`` files are loaded as grounding callables. It is
                created (including missing parents) if it does not exist.
            textgen_api: LLM backend used to generate and refine grounding code.
            domain_knowledge: Free-text domain description inserted into the prompts.

        Raises:
            FileNotFoundError: If ``code_api_file`` is not an existing file.
        """
        self.out_dir = out_dir
        self.textgen_api = textgen_api
        self.domain_knowledge = domain_knowledge
        # Set by `set_namespace_annotation`; required before any prompt is built.
        self.gn_annotation: Optional[str] = None
        out_dir.mkdir(parents=True, exist_ok=True)

        if not code_api_file.is_file():
            raise FileNotFoundError(f"Low-level code API file not found: {code_api_file}")
        self.low_level_code_api = code_api_file.read_text()
        self._low_level_class_names = [
            e.name for e in ast.parse(self.low_level_code_api).body if isinstance(e, ast.ClassDef)
        ]

        self._load_from_disk(out_dir=out_dir)

    def set_namespace_annotation(self, vars: List[SEVariable]):
        """Store the global-namespace annotation for ``vars``; it is appended to every code prompt."""
        self.gn_annotation = get_global_namespace(vars)[1]

    @staticmethod
    def _function_name(predicate: PDDLPredicate) -> str:
        """Return the Python function name of ``predicate`` (PDDL ``-`` becomes ``_``)."""
        return predicate.definition.name.replace("-", "_")

    def _load_from_disk(self, out_dir: Path):
        """Rebuild ``self.grounding_callable`` from every ``*.py`` file in ``out_dir``."""
        loaded: Dict[str, GroundingCallable] = {}
        for py_file in out_dir.glob("*.py"):
            g_callable = GroundingCallable.load_from_disk(
                out_dir=out_dir, name=py_file.stem, code_context=self.low_level_code_api
            )
            loaded[g_callable.name] = g_callable
        self.grounding_callable = loaded

    def _add_grounding_code(
        self, predicate: PDDLPredicate, code: str, description: str, chat: Chat, differentiable: bool
    ):
        """Add grounding code for a predicate to the internal dictionary.

        ``GroundingCallable.construct`` receives ``out_dir``; presumably it persists the code
        there so that ``_load_from_disk`` finds it again (verify in GroundingCallable).

        Args:
            predicate: Predicate the code grounds; must not have grounding code yet.
            code: Python source of the grounding function.
            description: Natural-language description of the grounding.
            chat: Conversation that produced the code (stored as its string form).
            differentiable: Currently unused.
        """
        assert predicate.name not in self.grounding_callable
        self.grounding_callable[predicate.name] = GroundingCallable.construct(
            code=code,
            out_dir=self.out_dir,
            name=predicate.name,
            description=description,
            code_context=self.low_level_code_api,
            chat=str(chat),
        )

    def differentiable_stub_of_predicate(self, predicate: PDDLPredicate) -> Optional[ast.FunctionDef]:
        """Build the signature of a differentiable variant of an already-grounded predicate.

        The predicate's function is taken from its stored (non-differentiable) code. Its body is
        replaced by the differentiable stub (docstring + ``...``) and its return type becomes
        ``torch.Tensor``. Plain ``float`` parameter annotations become ``torch.Tensor``, and
        ``float`` constant defaults are wrapped in ``torch.tensor(...)``.

        Returns:
            The rewritten function definition, or ``None`` if the stored code does not define
            the predicate's function at the top level.

        Raises:
            KeyError: If ``predicate`` has no grounding code.
        """
        non_diff_code_str = self.grounding_callable[predicate.name].code
        func_stub = self.predicate_to_function_def(predicate, differentiable=True)

        # The stored function uses the sanitized name (hyphens replaced), not the raw PDDL name.
        non_diff_func = next(
            (
                node
                for node in ast.parse(non_diff_code_str).body
                if isinstance(node, ast.FunctionDef) and node.name == func_stub.name
            ),
            None,
        )
        if non_diff_func is None:
            return None

        def map_annotations(args: List[ast.arg]) -> None:
            for arg in args:
                assert isinstance(arg, ast.arg), f"Expected ast.arg, got {type(arg)} for {arg}"
                # Only plain `float` annotations are rewritten; missing or composite ones
                # (e.g. `Optional[float]`) are left untouched instead of crashing.
                if isinstance(arg.annotation, ast.Name) and arg.annotation.id == "float":
                    arg.annotation.id = "torch.Tensor"

        def map_default(default: Optional[ast.expr]) -> Optional[ast.expr]:
            # `kw_defaults` holds None for keyword-only arguments without a default.
            if isinstance(default, ast.Constant) and isinstance(default.value, float):
                return ast.Call(
                    func=ast.Name(id="torch.tensor", ctx=ast.Load()),
                    args=[default],
                    keywords=[],
                )
            return default

        non_diff_func.body = func_stub.body
        non_diff_func.returns = func_stub.returns

        signature = non_diff_func.args
        map_annotations(signature.args)
        map_annotations(signature.kwonlyargs)
        map_annotations(signature.posonlyargs)
        signature.defaults = [map_default(d) for d in signature.defaults]
        signature.kw_defaults = [map_default(d) for d in signature.kw_defaults]

        return non_diff_func

    def predicate_to_function_def(self, predicate: PDDLPredicate, *, differentiable: bool = False) -> ast.FunctionDef:
        """Convert a PDDL predicate to a Python function stub (signature, docstring, ``...`` body).

        The function name is the predicate name with ``-`` replaced by ``_``. Each PDDL term
        becomes a positional argument. The argument is annotated with the term's type tag (the
        first in iteration order) converted to CamelCase, and that class must exist in the
        low-level code API.

        Args:
            predicate: Predicate to convert.
            differentiable: If True the stub returns ``torch.Tensor`` (a probability), else ``bool``.

        Raises:
            ValueError: If a term has no type tag.
            AssertionError: If a term's type has no matching class in the low-level code API.
        """
        function_name = self._function_name(predicate)

        # Render PDDL variables (`?obj`) as code references (`obj`) in the docstring.
        docstring = re.sub(r"\?(\w+)", r"`\1`", predicate.description)
        if differentiable:
            docstring += "\nReturns: A scalar representing the probability of the predicate being true (0.0 to 1.0).\n"
        else:
            docstring += "\nReturns: True if the predicate holds, False otherwise.\n"

        args = []
        for term in predicate.definition.terms:
            type_tag = next(iter(term.type_tags), None)
            if type_tag is None:
                raise ValueError(f"Term {term.name} of predicate {predicate.definition.name} has no type.")
            type_hint = snake_to_camel(type_tag)
            # Ensure first character is uppercase for proper type naming
            type_hint = type_hint[0].upper() + type_hint[1:] if type_hint else ""
            assert type_hint in self._low_level_class_names, f"Type hint {type_hint} not found in low-level code API."

            args.append(ast.arg(arg=term.name, annotation=ast.Name(id=type_hint, ctx=ast.Load())))

        func_node = ast.FunctionDef(
            name=function_name,
            args=ast.arguments(
                posonlyargs=[],
                args=args,
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[],
            ),
            body=[
                ast.Expr(value=ast.Constant(value=docstring)),
                ast.Expr(value=ast.Constant(value=...)),
            ],
            decorator_list=[],
            returns=ast.Name(id="torch.Tensor" if differentiable else "bool", ctx=ast.Load()),
        )
        return ast.fix_missing_locations(func_node)

    def _parse_grounder_function(
        self, text: str, predicate: PDDLPredicate, code: str
    ) -> Tuple[Optional[str], Optional[str]]:
        """Extract and validate the grounding function from a section of an LLM answer.

        Args:
            text: Response section expected to contain exactly one fenced ``python`` block.
            predicate: Predicate whose function must be defined in that block.
            code: Code context shown to the LLM. Its top-level definitions are removed from the
                generated code. Context functions reported by ``get_function_arg_variables`` are
                rejected.

        Returns:
            ``(function_code, None)`` on success. Otherwise ``(None, error_message)``, where the
            message is meant to be sent back to the LLM.
        """
        code_blocks = get_markup_from_text(text, ["python"])
        if len(code_blocks) != 1:
            return (
                None,
                "The response did not contain a single Python code block. Multiple code blocks are not supported. Try again.",
            )
        py_function_str = code_blocks[0]

        # Definitions are keyed by name, so the target function is either defined or missing.
        func_name = self._function_name(predicate)
        if func_name not in get_top_level_definitions_from_code_str(py_function_str):
            return (
                None,
                f"The code block must define the function `{func_name}` at the top level. "
                "Return its full definition (and any new helper functions it needs).",
            )

        # Drop definitions already provided by the context, and any global statements.
        context_objects = get_top_level_definitions_from_code_str(code)
        py_function_str = remove_objects_from_code_str(py_function_str, list(context_objects.keys()))
        py_function_str = remove_global_instructions(py_function_str)

        # NOTE(review): presumably these are names passed as call arguments; context functions
        # must not be used that way.
        arg_vars = get_function_arg_variables(py_function_str)
        func_arg_vars = [var for var in arg_vars if context_objects.get(var) == 'function']
        if func_arg_vars:
            return None, f"Function arguments {func_arg_vars} cannot be used. Directly use them or define new functions."

        return py_function_str, None

    def _get_code_section(
        self,
        predicate: PDDLPredicate,
        existing_predicates: List[PDDLPredicate],
        differentiable: bool = False,
    ) -> str:
        """Build the code context shown to the LLM for ``predicate``.

        The context consists of the low-level API, the namespace annotation, and stubs of every
        other visual predicate in ``existing_predicates``, separated by blank lines.

        Raises:
            RuntimeError: If ``set_namespace_annotation`` has not been called.
        """
        if self.gn_annotation is None:
            raise RuntimeError("set_namespace_annotation() must be called before generating grounding code.")

        other_stubs = "\n".join(
            ast.unparse(self.predicate_to_function_def(p, differentiable=differentiable))
            for p in existing_predicates
            if p.definition.name != predicate.definition.name and p.is_visual
        )
        return "\n\n".join([self.low_level_code_api, self.gn_annotation, other_stubs])

    def _request_grounder_code(
        self,
        chat: Chat,
        *,
        predicate: PDDLPredicate,
        code: str,
        response_pattern: str,
        missing_section_error: str,
        max_attempts: Optional[int],
    ) -> Tuple[str, str, Chat]:
        """Query the LLM until it returns a valid grounding function, feeding back each error.

        Args:
            chat: Conversation to send; its last message is the request.
            predicate: Predicate whose grounding function is requested.
            code: Code context shown to the LLM (see ``_parse_grounder_function``).
            response_pattern: Regex (``re.DOTALL``) whose group 1 captures the code section and
                group 2 the description.
            missing_section_error: Feedback sent when ``response_pattern`` does not match.
            max_attempts: Maximum number of LLM calls; ``None`` retries without limit.

        Returns:
            ``(function_code, description, chat)``, where ``chat`` contains all exchanged messages.

        Raises:
            ValueError: If ``max_attempts`` is smaller than 1.
            RuntimeError: If no valid function was produced within ``max_attempts`` calls.
        """
        if max_attempts is not None and max_attempts < 1:
            raise ValueError(f"max_attempts must be at least 1 (or None), got {max_attempts}")

        attempt = 0
        error: Optional[str] = None
        while max_attempts is None or attempt < max_attempts:
            attempt += 1
            response = self.textgen_api.do_call(chat)
            chat = chat.add_message(response)

            content = response.content[0]
            assert isinstance(content, TextMessageContent)
            text = content.text.strip()

            matches = re.search(response_pattern, text, re.DOTALL)
            if matches is None:
                error = missing_section_error
            else:
                py_function_str, error = self._parse_grounder_function(matches.group(1).strip(), predicate, code)
                if error is None:
                    assert py_function_str is not None, "No valid Python function string was generated."
                    return py_function_str, matches.group(2).strip(), chat

            logger.warning(
                "Invalid grounding response for predicate %s (attempt %d): %s", predicate.name, attempt, error
            )
            # Send the problem back so the next call can correct it.
            chat = chat.add_user_text(error)

        raise RuntimeError(
            f"No valid grounding function for predicate {predicate.name} after {max_attempts} attempts. "
            f"Last error: {error}"
        )

    def refine_grounder_function_for_predicate(
        self,
        predicate: PDDLPredicate,
        existing_predicates: List[PDDLPredicate],
        error: str,
        *,
        reprompt_chat: Optional[Chat] = None,
        differentiable: bool = False,
        max_attempts: Optional[int] = 10,
    ) -> Tuple[str, str, Chat]:
        """Ask the LLM to fix the stored grounding function of a predicate.

        Args:
            predicate: Predicate whose existing grounding code is refined.
            existing_predicates: Known predicates; the other visual ones are shown as stubs.
            error: What is wrong with the current code (used when building a new prompt).
            reprompt_chat: Continue this conversation instead of building a new prompt.
            differentiable: Refine ``differentiable_code`` instead of ``code``.
            max_attempts: Maximum number of LLM calls; ``None`` retries without limit.

        Returns:
            ``(fixed_function_code, description, chat)``.

        Raises:
            KeyError: If ``predicate`` has no grounding code yet.
            RuntimeError: If the namespace annotation is unset, or if no valid code was produced
                within ``max_attempts`` calls.
        """
        logger.info("Refining grounder function for predicate: %s", predicate.name)
        code = self._get_code_section(
            predicate=predicate,
            existing_predicates=existing_predicates,
            differentiable=differentiable,
        )

        grounding = self.grounding_callable[predicate.name]
        func_str = grounding.differentiable_code if differentiable else grounding.code
        assert func_str is not None
        func_str = remove_function_docstrings(func_str)

        if reprompt_chat is None:
            prompt = Prompt.load_from_file(prompts_dir / "refine_predicate_code.xml")
            prompt.replace_all(
                domain_knowledge=self.domain_knowledge,
                predicate=str(predicate),
                py_predicate_stub=func_str,
                code=code,
                errors=error,
            )
            chat = prompt.to_chat()
        else:
            chat = reprompt_chat

        return self._request_grounder_code(
            chat,
            predicate=predicate,
            code=code,
            response_pattern=self._REFINE_RESPONSE_PATTERN,
            missing_section_error=(
                "You should list your fixed code below `# Fixed Code`, "
                "followed by a short description below `# Grounder Description`."
            ),
            max_attempts=max_attempts,
        )

    def generate_grounder_function_for_predicate(
        self,
        predicate: PDDLPredicate,
        existing_predicates: List[PDDLPredicate],
        *,
        reprompt_chat: Optional[Chat] = None,
        differentiable: bool = False,
        max_attempts: Optional[int] = 10,
    ) -> Tuple[str, str, Chat]:
        """Generate Python code that implements the grounding of a PDDL predicate.

        In the non-differentiable case the LLM receives the predicate's function stub. In the
        differentiable case it receives the existing grounding code and the differentiable
        signature, and is asked to translate the code.

        Args:
            predicate: Predicate to ground.
            existing_predicates: Known predicates; the other visual ones are shown as stubs.
            reprompt_chat: Continue this conversation instead of building a new prompt.
            differentiable: Translate existing code into a differentiable variant.
            max_attempts: Maximum number of LLM calls; ``None`` retries without limit.

        Returns:
            ``(function_code, description, chat)``.

        Raises:
            KeyError: If ``differentiable`` is set and ``predicate`` has no grounding code.
            RuntimeError: If the namespace annotation is unset, or if no valid code was produced
                within ``max_attempts`` calls.
        """
        code = self._get_code_section(
            predicate=predicate,
            existing_predicates=existing_predicates,
            differentiable=differentiable,
        )

        if differentiable:
            func_str = self.grounding_callable[predicate.name].code
        else:
            func_str = ast.unparse(self.predicate_to_function_def(predicate, differentiable=False))

        if reprompt_chat is None:
            if not differentiable:
                prompt = Prompt.load_from_file(prompts_dir / "get_predicate_code.xml")
                prompt.replace_all(
                    domain_knowledge=self.domain_knowledge,
                    predicate=str(predicate),
                    py_predicate_stub=func_str,
                    code=code,
                )
            else:
                diff_signature = self.differentiable_stub_of_predicate(predicate)
                assert diff_signature is not None, f"Stored code does not define the function of {predicate.name}."
                prompt = Prompt.load_from_file(prompts_dir / "trans_diff_predicate_code.xml")
                prompt.replace_all(
                    predicate=str(predicate),
                    py_predicate_stub=func_str,
                    code=code,
                    diff_signature=ast.unparse(diff_signature),
                )
            chat = prompt.to_chat()
        else:
            chat = reprompt_chat

        return self._request_grounder_code(
            chat,
            predicate=predicate,
            code=code,
            response_pattern=self._GENERATE_RESPONSE_PATTERN,
            missing_section_error=(
                "You should list your code below `# Predicate Grounding`, "
                "followed by a short description below `# Grounder Description`."
            ),
            max_attempts=max_attempts,
        )

    def get_grounder_function(
        self,
        predicate: PDDLPredicate,
        existing_predicates: List[PDDLPredicate],
        differentiable: bool = False,
        verify: bool = True,
    ) -> GroundingCallable:
        """Return the grounding callable of a predicate, generating it with the LLM if missing.

        The callable's ``referenced_groundings`` is (re)set to the groundings of every other
        predicate whose function its code calls directly; these are resolved recursively.
        Each recursion level drops the current predicate from the candidates it passes down.
        The candidate list therefore shrinks with depth, so mutually dependent predicates
        cannot recurse forever.

        Args:
            predicate: Predicate to ground.
            existing_predicates: Predicates whose functions the code may call.
            differentiable: Must be False (differentiable grounding is not supported here).
            verify: Currently unused apart from being forwarded to recursive calls.
        """
        assert not differentiable

        if predicate.name not in self.grounding_callable:
            grounding_code_str, grounding_description, chat = self.generate_grounder_function_for_predicate(
                predicate,
                existing_predicates=existing_predicates,
                differentiable=differentiable,
            )
            self._add_grounding_code(
                predicate=predicate,
                code=grounding_code_str,
                description=grounding_description,
                chat=chat,
                differentiable=differentiable,
            )

        g_callable = self.grounding_callable[predicate.name]

        ex_predicates_by_f_name = {
            self._function_name(p): p for p in existing_predicates if p.name != predicate.name
        }
        # Avoid cycles: callees never see this predicate as a candidate to resolve.
        child_existing_predicates = [v for k, v in ex_predicates_by_f_name.items() if k != g_callable.func_name]

        called_predicates: Dict[str, GroundingCallable] = {}
        for node in ast.walk(ast.parse(g_callable.code)):
            if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)):
                continue
            callee = node.func.id
            if callee not in ex_predicates_by_f_name or callee in called_predicates:
                continue

            called_predicates[callee] = self.get_grounder_function(
                predicate=ex_predicates_by_f_name[callee],
                existing_predicates=child_existing_predicates,
                differentiable=differentiable,
                verify=verify,
            )

        g_callable.referenced_groundings = called_predicates
        return g_callable

    def has_grounder(self, predicate: PDDLPredicate, differentiable: bool) -> bool:
        """Return whether grounding code exists for ``predicate`` (non-differentiable only)."""
        assert not differentiable, "Differentiable predicates are not supported yet."
        return self.get_grounder_for_predicate(predicate) is not None

    def get_grounder_for_predicate(self, predicate: PDDLPredicate) -> Optional[GroundingCallable]:
        """Return the stored grounding callable of ``predicate``, or ``None`` without generating one."""
        return self.grounding_callable.get(predicate.name)

    def up_to_date(self, predicates: List[PDDLPredicate]) -> bool:
        """Return True if every visual predicate in ``predicates`` already has grounding code."""
        return all(self.has_grounder(p, differentiable=False) for p in predicates if p.is_visual)

    def update_grounder_functions(self, predicates: List[PDDLPredicate], verify: bool = True):
        """Ensure every visual predicate in ``predicates`` has a grounding function, generating missing ones."""
        for predicate in predicates:
            if not predicate.is_visual:
                continue

            self.get_grounder_function(
                predicate=predicate,
                existing_predicates=predicates,
                differentiable=False,
                verify=verify,
            )

    def ground_state(
        self,
        predicates: List[PDDLPredicate],
        variables: List[SEVariable],
        verify: bool = True,
    ) -> List[Union[Predicate, Not]]:
        """Evaluate every visual predicate on ``variables``.

        Missing grounding functions are generated first.

        Returns:
            The concatenation of what ``GroundingCallable.ground`` returns for each visual
            predicate, in the order of ``predicates``.
        """
        predicates_evaluation: List[Union[Predicate, Not]] = []
        for predicate in predicates:
            if not predicate.is_visual:
                continue

            func = self.get_grounder_function(
                predicate=predicate,
                existing_predicates=predicates,
                differentiable=False,
                verify=verify,
            )
            predicates_evaluation.extend(func.ground(predicate.definition, variables))

        return predicates_evaluation
