from enum import Enum
from pddl.core import And
import hashlib
import json
import logging
from PIL import Image
from typing import Dict, List, Optional

from pddl.core import Formula
from tp_lodge.utils.pddl_domain_syntax import parse_formula
from python_utils.data_utils import base64_to_image
from tp_lodge.motion_planning.dummy_motion_validator import PDDLDomain, PDDLProblem
from tp_lodge.utils.pddl_utils import combine_predicates, get_predicate_evaluation

import fb_client
from fb_client.models.point3d import Point3d
from tp_lodge.motion_planning.motion_validator import MotionSimulationException, MotionSimulationResponseCode

from .llm_code_interface import Part, Robot, Table
from state_estimation.motion_validation.se_motion_validator import SEMotionValidator
from state_estimation.predicate_grounder import PredicateGrounder
from state_estimation.se_variable import SEVariable
from state_estimation.motion_validation.reply_buffer import ReplyBuffer
from state_estimation.vlm_grounder import VLMGrounder


def point_to_list(p: Point3d):
    """Unpack the ``x``, ``y`` and ``z`` attributes of ``p``.

    Note that, despite the function name, the result is a 3-tuple
    ``(x, y, z)`` and not a list.
    """
    coordinates = tuple(getattr(p, axis) for axis in ("x", "y", "z"))
    return coordinates


# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class FurnitureEnum(str, Enum):
    """Furniture identifiers understood by the remote validator.

    Members are also ``str`` instances; ``RemoteSEMotionValidator.__init__``
    sends the member's ``value`` to the validator's set-environment endpoint.
    """

    LAMP = "lamp"
    ROUND_TABLE = "round_table"


class RemoteSEMotionValidator(SEMotionValidator):

    def __init__(
        self,
        grounder: PredicateGrounder,
        vlm_grounder: VLMGrounder,
        reply_buffer: ReplyBuffer,
        ip: str,
        port: int,
        furniture: FurnitureEnum,
    ):
        """Connect to the remote validator and select the furniture environment.

        Args:
            grounder: Forwarded to the ``SEMotionValidator`` base class.
            vlm_grounder: Stored as ``self.vlm_grounder``.
            reply_buffer: Stored as ``self.reply_buffer``.
            ip: Host of the validator service.
            port: Port of the validator service.
            furniture: Its ``value`` is posted to the set-environment endpoint.
        """
        super().__init__(grounder=grounder)

        self.vlm_grounder = vlm_grounder

        # Build the HTTP client step by step: URL -> configuration -> client -> API facade.
        host_url = "http://%s:%d" % (ip, port)
        client_config = fb_client.Configuration(host=host_url)
        http_client = fb_client.ApiClient(configuration=client_config)
        self.validator_api = fb_client.DefaultApi(api_client=http_client)

        # Tell the remote side which furniture environment to use.
        self.validator_api.set_environment_set_environment_post(furniture=furniture.value)

        self.reply_buffer = reply_buffer

    def _run_motion(self, motion: str):
        """Run ``motion`` on the remote validator and record the resulting state.

        The environment hash is read before the motion is sent. If the
        response has no ``error_response``, the state after the motion is added
        to the reply buffer, linked to the earlier hash and to ``motion``.

        NOTE: only the state after the motion is recorded; the per-state
        entries of ``response.states`` are not added to the buffer.

        Raises:
            MotionSimulationException: if the response carries an
                ``error_response``. The code is ``PDDL_PY_TRANSLATION`` when
                ``response.translation`` is truthy, otherwise ``EFFECT_FAILED``.
        """
        hash_before = self.get_env_hash()
        request = fb_client.RunMotionRequestModel(motion=motion)
        response = self.validator_api.run_motion_run_motion_post(run_motion_request_model=request)

        error = response.error_response
        if error is not None:
            logger.info("Motion execution failed: %s", error)
            if response.translation:
                failure_code = MotionSimulationResponseCode.PDDL_PY_TRANSLATION
            else:
                failure_code = MotionSimulationResponseCode.EFFECT_FAILED
            raise MotionSimulationException(
                message=error,
                code=failure_code,
                expected="",
                ground_truth="",
            )

        hash_after = self.get_env_hash()
        self._add_rb_state(
            env_hash=hash_after,
            executed_skill=motion,
            prev_env_hash=hash_before,
        )

    def get_variables(self) -> List[SEVariable]:
        """Query the remote validator and convert the result via ``_parse_to_variables``.

        Three endpoints are called, in this order: rigid bodies, robot arm, table.
        """
        api = self.validator_api
        # Keyword arguments are evaluated left to right, which keeps the
        # request order rigid bodies -> robot arm -> table.
        return self._parse_to_variables(
            rigid_bodies=api.get_rigid_bodies_rigid_bodies_get(),
            robot_arm=api.get_robot_arm_robot_arm_get(),
            table=api.get_table_table_get(),
        )

    def _parse_to_variables(
        self, robot_arm: fb_client.RobotArm, rigid_bodies: Dict[str, fb_client.RigidBody], table: fb_client.Table
    ) -> List[SEVariable]:
        """Convert raw validator objects into the variables stored in the reply buffer.

        The result contains, in order: an ``"arm"`` variable (``Robot``), a
        ``"table"`` variable (``Table``) and one ``Part`` variable per entry of
        ``rigid_bodies`` (named by its key). Every variable is passed through
        ``var_parser.get_printable_for_llm(..., precision=3)`` before being returned.

        Args:
            robot_arm: Provides ``gripper_closed`` and ``gripper_position``.
            rigid_bodies: Maps a name to an object with ``min_bound``,
                ``max_bound`` and ``orientation``.
            table: Provides ``surface_z``.
        """
        # Imported lazily (numpy is not a module-level dependency of this file),
        # but once per call instead of once per rigid body.
        import numpy as np

        variables = [
            SEVariable(
                name="arm",
                value=Robot(
                    gripper_closed=robot_arm.gripper_closed,
                    gripper_center=point_to_list(robot_arm.gripper_position),
                ),
            ),
            SEVariable(
                name="table",
                value=Table(surface_z=table.surface_z),
            ),
        ]

        for name, rigid_body in rigid_bodies.items():
            min_bound = point_to_list(rigid_body.min_bound)
            max_bound = point_to_list(rigid_body.max_bound)
            # The center is the midpoint of the two bounding-box corners.
            center = (np.asanyarray(max_bound) + np.asanyarray(min_bound)) / 2

            variables.append(
                SEVariable(
                    name=name,
                    value=Part(
                        # Flattened as (min_x, min_y, min_z, max_x, max_y, max_z).
                        bounding_box=(*min_bound, *max_bound),
                        center=center.tolist(),
                        orientation=point_to_list(rigid_body.orientation),
                    ),
                )
            )

        to_printable = self.reply_buffer.var_parser.get_printable_for_llm
        return [to_printable(variable, precision=3) for variable in variables]

    def _add_rb_state(
        self,
        env_hash: str,
        vars: Optional[List[SEVariable]] = None,
        image: Optional[Image.Image] = None,
        prev_env_hash: Optional[str] = None,
        executed_skill: Optional[str] = None,
    ):
        """Add one state to the reply buffer.

        The state is keyed by a SHA-256 hash of its variables: each variable is
        made printable, converted with ``var_parser.to_dict`` and the resulting
        ``{name: dict}`` mapping is JSON-encoded with sorted keys.

        Args:
            env_hash: Environment hash belonging to this state.
            vars: Variables of the state. Fetched from the validator when ``None``.
            image: Image of the state. Fetched from the validator when ``None``.
            prev_env_hash: Environment hash of the preceding state, if any.
            executed_skill: Skill that led to this state, if any.
        """
        if vars is None:
            vars = self.get_variables()

        var_parser = self.reply_buffer.var_parser

        hashable_state = {v.name: var_parser.to_dict(var_parser.get_printable_for_llm(v)) for v in vars}
        state_hash = hashlib.sha256(json.dumps(hashable_state, sort_keys=True).encode("utf-8")).hexdigest()

        if image is None:
            image = base64_to_image(self.validator_api.get_image_image_get())

        # Store exactly the variables that were hashed. Re-querying the
        # validator here would ignore a caller-supplied ``vars`` and could
        # store variables that do not match ``state_hash``.
        self.reply_buffer.add_state(
            hash=state_hash,
            variables=vars,
            image=image,
            env_hash=env_hash,
            prev_env_hash=prev_env_hash,
            executed_skill=executed_skill,
        )

    def reset(self, seed: int, init_hash: Optional[str]):
        """Reset the remote environment and record the resulting state.

        Args:
            seed: Seed sent with the reset request.
            init_hash: When not ``None``, the environment is set to this hash
                right after the reset.
        """
        reset_request = fb_client.ResetEnvRequestModel(seed=seed)
        self.validator_api.reset_reset_post(reset_env_request_model=reset_request)

        if init_hash is not None:
            self.set_env_hash(hash=init_hash)

        # Record the state after the reset (and optional hash restore).
        current_hash = self.get_env_hash()
        self._add_rb_state(env_hash=current_hash)

    def _refresh_reply_buffer_with_vlm(self, domain: PDDLDomain, problem: PDDLProblem):
        """Make sure every buffered state has VLM evaluations for all domain predicates.

        For each state in the reply buffer (states that point to a
        ``similar_state`` are skipped):

        * if the state has no predicates yet, all domain predicates are evaluated;
        * otherwise only the predicates not listed in
          ``reply_buffer.evaluated_predicates`` are evaluated.

        When the state has a predecessor, the predecessor's predicate values,
        the executed skill and the predecessor's image are passed to the VLM as
        additional context. New results are merged over the state's existing
        predicates. Finally the buffer's evaluated-predicate list is set to the
        names of all domain predicates.
        """
        all_predicate_names = [p.name for p in domain.predicates]
        # TODO: which predicates get grounded for a state depends on the domain
        # that is current at that time; states are not re-grounded if a
        # different set of predicates is used later on.
        already_evaluated = set(self.reply_buffer.evaluated_predicates)
        # Follow (deduplicated) domain order rather than iterating a set, so the
        # predicate order handed to the VLM is reproducible across runs.
        new_predicate_names = [name for name in dict.fromkeys(all_predicate_names) if name not in already_evaluated]

        for state_hash, state in self.reply_buffer.get_all_states().items():
            if state.similar_state is not None:
                continue

            if state.predicates is None:
                # This state was never evaluated: evaluate all predicates.
                preds_to_evaluate = domain.predicates
            elif new_predicate_names:
                # Only the predicates that are new need to be evaluated.
                preds_to_evaluate = [domain.get_predicate(name) for name in new_predicate_names]
            else:
                preds_to_evaluate = []

            if len(preds_to_evaluate) == 0:
                continue

            # Optional context about the predecessor state, passed on to the VLM.
            prev_context = {}
            if state.prev_state_hash is not None:
                assert state.executed_skill is not None
                prev_state = self.reply_buffer.get_state(state.prev_state_hash)
                if prev_state.similar_state is not None:
                    _, prev_state = self.reply_buffer.get_similar_state(prev_state)
                names_to_evaluate = {p.name for p in preds_to_evaluate}

                # Ground the predecessor with the grounder functions ...
                prev_predicates = get_predicate_evaluation(
                    self.grounder.ground_state(
                        predicates=preds_to_evaluate, variables=prev_state.variables, verify=False
                    )
                )
                # ... but predefined predicates are grounded by the VLM, not by
                # the grounder, so take those from the stored predecessor state.
                assert prev_state.predicates is not None
                predefined_names = {p.name for p in preds_to_evaluate if p.predefined}
                prev_predicates = {
                    **prev_predicates,
                    **{k: v for k, v in prev_state.predicates.items() if k.name in predefined_names},
                }

                # Only pass on values of predicates that are evaluated now.
                prev_predicates = {k: v for k, v in prev_predicates.items() if k.name in names_to_evaluate}

                prev_context["prev_grounded"] = prev_predicates
                prev_context["executed_skill"] = state.executed_skill
                prev_context["prev_image"] = self.reply_buffer.get_image(state_hash=state.prev_state_hash)

            ground_predicates = self.vlm_grounder.ground_predicates_of_state(
                variables=[self.reply_buffer.var_parser.to_dict(v) for v in state.variables],
                image=self.reply_buffer.get_image(state_hash=state_hash),
                predicates=preds_to_evaluate,
                types=domain.types,
                objects=problem.objects,
                **prev_context,
            )

            if state.predicates is not None:
                # Keep earlier evaluations; new results win on conflicts.
                ground_predicates = {**state.predicates, **ground_predicates}
            self.reply_buffer.set_predicates(state_hash=state_hash, predicates=ground_predicates)

        self.reply_buffer.set_evaluated_predicates(all_predicate_names)

    def get_predicates_evaluation(self, domain: PDDLDomain, problem: PDDLProblem) -> Formula:
        """Evaluate the domain predicates and return them as one conjunction.

        Steps, in order:

        1. refresh the reply buffer with VLM evaluations,
        2. update the grounder functions for the domain predicates,
        3. obtain the base-class evaluation,
        4. ask the remote validator to evaluate a copy of the domain that
           contains only the predefined predicates,
        5. combine both results with ``combine_predicates`` and wrap them in ``And``.
        """
        self._refresh_reply_buffer_with_vlm(domain, problem)
        self.grounder.update_grounder_functions(domain.predicates)

        base_eval = super().get_predicates_evaluation(domain, problem)

        # The remote side only receives the predefined predicates.
        predefined_only = [pred for pred in domain.predicates if pred.predefined]
        domain_pddl = str(domain.copy_with(predicates=predefined_only).to_pddl())
        problem_pddl = str(problem.to_pddl(force=True))
        request = fb_client.GetValidPredicatesRequestModel(domain=domain_pddl, problem=problem_pddl)
        response = self.validator_api.get_predicates_evaluation_predicates_evaluation_post(request)

        predefined_eval = parse_formula(response.predicates, only_variables=False)
        merged = combine_predicates(base_eval, predefined_eval)
        return And(*merged)

    def get_env_hash(self) -> str:
        """Return the ``hash`` field of the validator's env-hash response."""
        response = self.validator_api.get_env_hash_get_env_hash_post()
        return response.hash

    def set_env_hash(self, hash: str):
        """Send ``hash`` to the validator's set-env-hash endpoint."""
        request = fb_client.SetEnvHashRequestModel(hash=hash)
        self.validator_api.set_env_hash_set_env_hash_post(set_env_hash_request_model=request)
