from typeguard import check_type, TypeCheckError
import typing
import inspect
from typing import get_origin, get_args
from dataclasses import dataclass, field
import functools
from itertools import product
import json
from pathlib import Path
from collections import defaultdict
from pddl.core import Predicate
from pddl.logic.base import Not
from typing import Callable, Dict, List, Optional, Tuple, Union

from llm_utils.openai_api import Chat

from state_estimation.se_variable import SEVariable

# Directory containing this module, and its `prompts` subdirectory.
data_dir = Path(__file__).parents[0]
prompts_dir = data_dir.joinpath("prompts")


def get_global_namespace(variables: List["SEVariable"]) -> Tuple[Dict[str, Callable], str]:
    """Create a global namespace for the grounding code.

    Variables are grouped by ``pddl_type``. A type with several variables gets a
    zero-argument accessor ``get_all_<type_tag_plural>()`` returning the list of
    their ``value``s; a type with exactly one variable gets ``get_<pddl_type>()``
    returning that single ``value``.

    Args:
        variables: the state-estimation variables to expose.

    Returns:
        A pair ``(namespace, annotation)``: the accessor functions keyed by name,
        and a newline-joined string of stub signatures (``def name() -> T: ...``),
        one per accessor, in the same order as ``namespace``.
    """
    global_namespace = {}
    # keyed by accessor name so a name collision keeps one stub, like the namespace
    annotations = {}

    gen_vars_by_t = defaultdict(list)
    for v_ in variables:
        gen_vars_by_t[v_.pddl_type].append(v_)

    # every group is non-empty: groups are only created by appending a variable
    for c_type, vars_for_type in gen_vars_by_t.items():
        values = [v.value for v in vars_for_type]
        # partial binds `values` eagerly; a plain `lambda: values` would late-bind
        # to the list of the last loop iteration
        if len(vars_for_type) > 1:
            accessor = f"get_all_{vars_for_type[0].type_tag_plural}"
            global_namespace[accessor] = functools.partial(lambda e: e, values)
            annotations[accessor] = f"def {accessor}() -> List[{type(values[0]).__name__}]: ..."
        else:
            accessor = f"get_{c_type}"
            global_namespace[accessor] = functools.partial(lambda e: e[0], values)
            # annotate from the value's own type: a single variable whose value is a
            # list is still a single object, not a plural accessor
            annotations[accessor] = f"def {accessor}() -> {type(values[0]).__name__}: ..."

    global_namespace_annotation = "\n".join(annotations.values())

    return global_namespace, global_namespace_annotation


def _parse_grounding_code(func_name: str, code: str, code_context: str) -> Callable:
    func_name = func_name.replace("-", "_")
    gn = {
        "__builtins__": __builtins__,  # Preserve original builtins including __import__
    }
    exec(code_context, gn, gn)

    # Create a custom global namespace that includes:
    # 1. Original builtins (including __import__)
    # 2. The low-level code API definitions

    exec(code, gn)  # define llm-generated predicate description function

    assert callable(gn[func_name]), f"Function {func_name} is not callable."

    defined_function = gn[func_name]

    return defined_function


@dataclass
class GroundingCallable:
    """A generated grounding function bundled with its source, metadata and persistence.

    Wraps the function that ``code`` defines (see ``_parse_grounding_code``) so that
    every call merges in the stored hyperparameters ``hps`` and type-checks the
    arguments against the function's annotations. Code, chat transcript and
    metadata are written to ``out_dir`` on construction and whenever they change.
    """

    code: str  # source that defines the grounding function
    callable: Callable  # function object obtained by executing `code`
    out_dir: Path  # directory holding `<name>.py`, `<name>.chat` and `<name>.json`
    name: str  # predicate name; `func_name` maps it to the Python function name
    chat: str  # transcript(s) of the conversation(s) that produced `code`

    description: str  # free-text description persisted in `<name>.json`
    code_context: str  # source executed before `code` to provide the API it uses

    hps: dict = field(default_factory=dict)  # keyword arguments forced into every call
    referenced_groundings: dict[str, "GroundingCallable"] = field(default_factory=dict)

    def __post_init__(self):
        assert isinstance(self.callable, Callable), "The callable must be a valid Python callable."
        # cached signature; `__call__` binds arguments against it
        self.__signature__ = inspect.signature(self.callable)
        self._save_to_disk()

    @staticmethod
    def construct(
        code: str,
        out_dir: Path,
        name: str,
        code_context: str,
        description: str,
        chat: str,
        *,
        hps: Optional[dict] = None,
    ) -> "GroundingCallable":
        """Compile ``code`` (after ``code_context``) and wrap the resulting function."""
        return GroundingCallable(
            code=code,
            callable=_parse_grounding_code(func_name=name, code=code, code_context=code_context),
            out_dir=out_dir,
            code_context=code_context,
            description=description,
            name=name,
            chat=chat,
            hps=hps or {},
        )

    @staticmethod
    def load_from_disk(out_dir: Path, name: str, code_context: str) -> "GroundingCallable":
        """Rebuild a grounding from the files written by ``_save_to_disk``."""
        code = (out_dir / f"{name}.py").read_text()
        chat = (out_dir / f"{name}.chat").read_text()
        metadata = json.loads((out_dir / f"{name}.json").read_text())
        hps = metadata.get("hps", {})
        description = metadata.get("description", "")
        return GroundingCallable.construct(
            code=code,
            out_dir=out_dir,
            name=name,
            code_context=code_context,
            chat=chat,
            hps=hps,
            description=description,
        )

    def _save_to_disk(self):
        """Write code, chat and metadata (hps, description) to ``out_dir``."""
        (self.out_dir / f"{self.name}.py").write_text(self.code)
        (self.out_dir / f"{self.name}.chat").write_text(self.chat)
        metadata = {"hps": self.hps, "description": self.description}
        (self.out_dir / f"{self.name}.json").write_text(json.dumps(metadata, indent=4))

    @property
    def func_name(self) -> str:
        """Python identifier of the grounding function (``-`` replaced by ``_``)."""
        return self.name.replace("-", "_")

    def __call__(self, *args, **kwargs):
        """Call the wrapped grounding function after validating its arguments.

        The stored hyperparameters ``self.hps`` are merged into ``kwargs`` and win
        over caller-supplied keyword arguments of the same name. The arguments are
        then bound to the function's signature, and each annotated argument is
        type-checked (see ``_check_argument_type``).

        Raises:
            TypeError: on a signature mismatch or an argument of the wrong type.
        """
        # we want to pass alternative hyperparameters to the grounding function
        kwargs = {**kwargs, **self.hps}

        type_hints = typing.get_type_hints(self.callable)

        try:
            bound = self.__signature__.bind(*args, **kwargs)
        except TypeError as e:
            raise TypeError(f"Signature mismatch: {e}") from e
        bound.apply_defaults()

        for arg_name, value in bound.arguments.items():
            if arg_name in type_hints:
                self._check_argument_type(arg_name, value, type_hints[arg_name])

        return self.callable(*args, **kwargs)

    @staticmethod
    def _matches_by_name(value, expected) -> bool:
        """Return True if ``value``'s class has the same name as the type ``expected``.

        Special case to allow injected code: classes defined by exec'd code are
        distinct objects from the ones referenced in annotations.
        """
        expected_name = getattr(expected, "__name__", None)
        return expected_name is not None and value.__class__.__name__ == expected_name

    def _check_argument_type(self, arg_name: str, value, expected) -> None:
        """Raise TypeError if ``value`` does not satisfy the annotation ``expected``.

        A value whose class name matches the annotation (or, for a union, any of
        its members) is accepted as-is. Everything else is validated with
        ``typeguard.check_type``, which also handles subclasses and generic union
        members such as ``Optional[List[float]]``.
        """
        import types  # local import: only needed to recognise PEP 604 `X | Y` unions

        origin = get_origin(expected)
        is_union = origin is typing.Union or origin is getattr(types, "UnionType", typing.Union)
        members = get_args(expected) if is_union else (expected,)

        if any(self._matches_by_name(value, member) for member in members):
            return

        try:
            check_type(value, expected)
        except TypeCheckError as e:
            if is_union:
                member_names = [getattr(t, "__name__", repr(t)) for t in members]
                raise TypeError(
                    f"Argument '{arg_name}' of function {self.func_name} expects one of {member_names}, got {value.__class__.__name__}"
                ) from e
            raise TypeError(f"Argument '{arg_name}' of function {self.func_name} has invalid type: {e}") from e

    def _update_ns(self, ns: dict):
        """Inject ``ns`` into the globals of this grounding and of every referenced one.

        The referenced groundings are also exposed in this function's globals under
        their keys in ``referenced_groundings``.
        """
        # NOTE(review): the recursion assumes `referenced_groundings` contains no cycles.
        self.callable.__globals__.update(ns)
        for g in self.referenced_groundings.values():
            g._update_ns(ns)
        self.callable.__globals__.update(self.referenced_groundings)

    def update_hps(self, hps: dict[str, float], save: bool = True):
        """Replace the stored hyperparameters; persist them unless ``save`` is False."""
        self.hps = hps
        if save:
            self._save_to_disk()

    def get_hps_names(self) -> List[str]:
        """Return the names of parameters annotated with ``float`` (or a subclass)."""
        return [
            param_name
            for param_name, param in self.__signature__.parameters.items()
            if isinstance(param.annotation, type) and issubclass(param.annotation, float)
        ]

    def update_callable(self, code: str, description: str, chat: Chat):
        """Replace the grounding code, recompile it and persist the result.

        The new ``chat`` is appended to the transcript after a separator line.
        ``hps`` is cleared, presumably because it targeted the previous function's
        parameters.
        """
        self.code = code
        self.description = description
        self.callable = _parse_grounding_code(func_name=self.name, code=code, code_context=self.code_context)
        self.__signature__ = inspect.signature(self.callable)
        self.chat = self.chat + "\n" + "-" * 100 + "\n" + str(chat)
        self.hps = {}

        self._save_to_disk()

    def ground(self, predicate: Predicate, variables: list[SEVariable]) -> list[Union[Predicate, Not]]:
        """Evaluate the grounding on every assignment of ``variables`` to ``predicate``'s terms.

        Each term must carry exactly one type tag. The candidates for a term are
        the variables whose ``pddl_type`` equals that tag. Assignments that use the
        same variable name twice are skipped.

        Returns:
            One entry per remaining assignment: the ground ``Predicate`` if the
            grounding function returned a truthy value, otherwise ``Not`` of it.
        """
        ll_objects_per_type = defaultdict(list)
        for var in variables:
            ll_objects_per_type[var.pddl_type].append(var)

        gn, _ = get_global_namespace(variables)

        assert all(len(t.type_tags) == 1 for t in predicate.terms)
        term_types = [list(term.type_tags)[0] for term in predicate.terms]
        # get all child objects for each type in `term_types`
        ll_objects_per_arg = [list(ll_objects_per_type[term_type]) for term_type in term_types]

        # the namespace does not depend on the assignment, so inject it once
        self._update_ns(gn)

        predicates_evaluation = []
        for args in product(*ll_objects_per_arg):
            arg_names = [a.name for a in args]
            if len(arg_names) != len(set(arg_names)):
                # predicates with multiple times same args are ignored
                continue

            is_valid = self(*[a.value for a in args])
            ground_predicate = Predicate(self.name, *[arg.pddl_object for arg in args])
            predicates_evaluation.append(ground_predicate if is_valid else Not(ground_predicate))

        return predicates_evaluation
