import logging
from typing import List, Optional, Set

import fb_client
from pddl.core import And, Formula

from tp_lodge.motion_planning.motion_validator import (
    MotionSimulationException,
    MotionSimulationResponseCode,
    MotionValidator,
)
from enum import Enum
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.utils.pddl_parse_utils import parse_formula
from tp_lodge.utils.pddl_utils import (
    filter_formula_by_predicates,
    get_valid_predicates,
    is_predicates_subset,
    combine_predicates,
)

logger = logging.getLogger(__name__)


class FurnitureEnum(str, Enum):
    """Furniture environment identifiers.

    Mixing in ``str`` makes every member compare equal to its plain string
    value (e.g. ``FurnitureEnum.LAMP == "lamp"``).
    """

    # Member order is significant for iteration; keep LAMP first.
    LAMP = "lamp"
    ROUND_TABLE = "round_table"


class RemoteMotionValidator(MotionValidator):
    """Motion validator that delegates to a remote service via the generated ``fb_client`` API.

    Every operation (environment selection, predicate queries, motion
    execution, reset and environment-hash handling) is forwarded to a
    ``fb_client.DefaultApi`` instance stored as ``self.validator_api``.
    """

    def __init__(self, ip: str, port: int, furniture: FurnitureEnum):
        """Connect to the validator service and select the furniture environment.

        Args:
            ip: Host name or IP address of the validator service.
            port: TCP port of the validator service.
            furniture: Environment to activate remotely; its ``value`` is sent.
        """
        # NOTE(review): MotionValidator.__init__ is not called here (as in the
        # original); confirm the base class does not require initialisation.
        configuration = fb_client.Configuration(host="http://%s:%d" % (ip, port))
        api_client = fb_client.ApiClient(configuration=configuration)
        self.validator_api = fb_client.DefaultApi(api_client=api_client)
        self.validator_api.set_environment_set_environment_post(furniture=furniture.value)

    def get_supported_predicates(self, domain: PDDLDomain) -> List[str]:
        """Return the service-supported predicate names that are predefined in ``domain``.

        Makes one remote call. The result keeps the order returned by the
        service, filtered to names of ``domain`` predicates whose ``predefined``
        flag is set.
        """
        response = self.validator_api.get_supported_predicates_get_supported_predicates_get()
        # Build the lookup once instead of re-scanning domain.predicates per item.
        predefined_names = {pd.name for pd in domain.predicates if pd.predefined}
        return [p for p in response.supported_predicates if p in predefined_names]

    def _run_motion(self, motion: str) -> None:
        """Execute ``motion`` on the remote simulator.

        Raises:
            MotionSimulationException: with code ``PDDL_PY_TRANSLATION`` when the
                service response carries an ``error_response``.
        """
        response = self.validator_api.run_motion_run_motion_post(
            run_motion_request_model=fb_client.RunMotionRequestModel(motion=motion),
        )
        if response.error_response is not None:
            raise MotionSimulationException(
                message=response.error_response,
                code=MotionSimulationResponseCode.PDDL_PY_TRANSLATION,
                expected="",
                ground_truth="",
            )

    def get_predicates_evaluation(self, domain: PDDLDomain, problem: PDDLProblem) -> Formula:
        """Ask the service to evaluate the supported predicates for ``problem``.

        The domain sent to the service is a copy of ``domain`` restricted to
        predicates the service supports; the problem is serialised with
        ``to_pddl(force=True)``. The service reply is parsed into a formula.
        """
        # Fetch the supported set ONCE: calling get_supported_predicates inside
        # the comprehension condition issued one remote request per predicate.
        supported = set(self.get_supported_predicates(domain))
        reduced_domain = domain.copy_with(
            predicates=[p for p in domain.predicates if p.name in supported]
        )
        response = self.validator_api.get_predicates_evaluation_predicates_evaluation_post(
            fb_client.GetValidPredicatesRequestModel(
                domain=str(reduced_domain.to_pddl()),
                problem=str(problem.to_pddl(force=True)),
            )
        )
        return parse_formula(response.predicates, only_variables=False)

    def validate_predicates(self, predicates: Set[Formula], domain: PDDLDomain, problem: PDDLProblem):
        """Check that the service-supported subset of ``predicates`` holds remotely.

        Predicates the service does not support are filtered out before the
        comparison against the remote evaluation.

        Raises:
            MotionSimulationException: with code ``EFFECT_FAILED`` when the
                supported predicates are not a subset of the remote evaluation.
        """
        gt_predicates = self.get_predicates_evaluation(domain=domain, problem=problem)
        supported_predicates = filter_formula_by_predicates(
            And(*predicates), known_predicates=self.get_supported_predicates(domain)
        ).operands
        if not is_predicates_subset(supported_predicates, subset_of=gt_predicates):
            raise MotionSimulationException(
                ground_truth=str(And(*get_valid_predicates(gt_predicates))),
                expected=str(And(*predicates)),
                code=MotionSimulationResponseCode.EFFECT_FAILED,
            )

    def reset(self, seed: int, init_hash: Optional[str] = None):
        """Reset the remote environment with ``seed``.

        If ``init_hash`` is given, the environment is then restored to that
        hash via :meth:`set_env_hash`.
        """
        self.validator_api.reset_reset_post(reset_env_request_model=fb_client.ResetEnvRequestModel(seed=seed))
        if init_hash is not None:
            self.set_env_hash(hash=init_hash)

    def get_env_hash(self) -> str:
        """Return the environment hash reported by the remote service."""
        return self.validator_api.get_env_hash_get_env_hash_post().hash

    def set_env_hash(self, hash: str):
        """Send ``hash`` to the remote service to set the environment hash.

        The parameter name shadows the builtin ``hash`` but is kept because
        callers (e.g. :meth:`reset`) pass it by keyword.
        """
        self.validator_api.set_env_hash_set_env_hash_post(
            set_env_hash_request_model=fb_client.SetEnvHashRequestModel(hash=hash)
        )

    def inject_init_predicates(self, domain: PDDLDomain, problem: PDDLProblem):
        """Return a copy of ``problem`` whose initial state comes from the remote evaluation.

        Predicates in the problem's existing initial state that the service
        does not support (custom predicates) are combined back in. The result
        is used for both ``initial_state`` and ``grounder_initial_state``.
        """
        predicates_evaluation = self.get_predicates_evaluation(domain=domain, problem=problem)
        if problem.initial_state is not None:
            # inverse=True keeps only predicates NOT known to the service.
            custom_predicates = filter_formula_by_predicates(
                problem.initial_state, known_predicates=self.get_supported_predicates(domain), inverse=True
            )
            predicates_evaluation = combine_predicates(predicates_evaluation, custom_predicates)
        return problem.copy_with(
            initial_state=predicates_evaluation,
            grounder_initial_state=predicates_evaluation,
        )
