import ast
from PIL import Image
from typing import Dict
from fastapi import FastAPI
from pddl.core import And
from pddl.parser.domain import DomainParser
from pddl.parser.problem import ProblemParser

from furniture_bench_api.api.api_predicates_validator import supported_predicates, StateValidator
from furniture_bench_api.api.api_schema import (
    GetEnvHashResponseModel,
    GetEnvSeedResponseModel,
    GetSupportedPredicatesResponseModel,
    GetValidPredicatesRequestModel,
    GetValidPredicatesResponseModel,
    ResetEnvRequestModel,
    RigidBody,
    RobotArm,
    RunMotionRequestModel,
    RunMotionResponseModel,
    SetEnvHashRequestModel,
    Table,
)
from furniture_bench_api.api.api_skills_parser import FunctionParser
from furniture_bench_api.furniture_bench_environment import FurnitureBenchEnvironment
from furniture_bench_api.utils.image_utils import get_image_with_labels
from furniture_bench_api.utils.other_utils import image_to_base64

# ASGI application object; the route decorators in this module register on it.
app = FastAPI()
# Environment instance shared by every endpoint in this module. It is built
# with default arguments at import time and rebound by /set-environment.
furniture_bench_env = FurnitureBenchEnvironment()
# Tolerance handed to FunctionParser in /run-motion.
# NOTE(review): units/meaning are not visible here -- confirm against FunctionParser.
tolerance = 5e-3

@app.post("/set-environment")
def set_environment(furniture: str):
    """Replace the shared environment with a new one built for ``furniture``.

    Args:
        furniture: Forwarded as the ``furniture`` keyword argument to
            ``FurnitureBenchEnvironment``.
    """
    global furniture_bench_env
    # Drop the module-level reference before constructing the replacement.
    # NOTE(review): presumably this lets the previous environment be released
    # before the new one is created -- confirm. If the constructor below
    # raises, the global name stays unbound and later endpoints will fail
    # with NameError until this endpoint succeeds.
    del furniture_bench_env
    furniture_bench_env = FurnitureBenchEnvironment(furniture=furniture)


@app.get("/stop-recording")
def close():
    """Log the call and ask the shared environment to stop recording."""
    print("stop-recording")
    # Only reads the module-level environment, so no ``global`` declaration is needed.
    furniture_bench_env.stop_recording()


@app.post("/reset")
def reset(args: ResetEnvRequestModel):
    """Restart recording, reset the environment and reset every predicate grounder.

    Afterwards the first ``color_image2`` observation is written to
    ``image2.png`` in the current working directory.

    Args:
        args: Request body; ``args.seed`` is passed to ``reset_env``.
    """
    furniture_bench_env.restart_recording()
    furniture_bench_env.reset_env(seed=args.seed)

    for predicate_grounder in supported_predicates.values():
        predicate_grounder.reset()

    # Take element 0 of the "color_image2" observation, move it to the CPU and
    # convert it to a NumPy array so PIL can save it.
    observation = furniture_bench_env.get_observation()
    frame = observation["color_image2"][0].cpu().numpy()
    snapshot = Image.fromarray(frame)
    snapshot.save("image2.png")


@app.post("/get-seed")
def get_seed() -> GetEnvSeedResponseModel:
    """Report the seed stored on the environment's ``env_state``.

    Returns:
        ``{"seed": <seed>}``; the seed is ``None`` when the environment has
        no ``env_state`` attribute.
    """
    if not hasattr(furniture_bench_env, "env_state"):
        return {"seed": None}
    return {"seed": furniture_bench_env.env_state.seed}


@app.post("/get-env-hash")
def get_env_hash() -> GetEnvHashResponseModel:
    """Return the hash reported by the shared environment's ``get_hash()``."""
    current_hash = furniture_bench_env.get_hash()
    return GetEnvHashResponseModel(hash=current_hash)


@app.post("/set-env-hash")
def set_env_hash(data: SetEnvHashRequestModel):
    """Set the shared environment to the state identified by ``data.hash``.

    Args:
        data: Request body; ``data.hash`` is passed to ``set_hash``.

    Returns:
        None.

    Raises:
        RuntimeError: If ``get_hash()`` does not return ``data.hash`` after
            the ``set_hash`` call.
    """
    furniture_bench_env.set_hash(hash=data.hash)

    # Verify the round trip with an explicit check rather than ``assert``:
    # assert statements are removed when Python runs with -O, which would let
    # a failed restore go unnoticed.
    restored_hash = furniture_bench_env.get_hash()
    if restored_hash != data.hash:
        raise RuntimeError(
            f"Environment hash mismatch after set_hash: expected {data.hash!r}, got {restored_hash!r}"
        )

    return None


@app.get("/rigid-bodies")
def get_rigid_bodies() -> Dict[str, RigidBody]:
    """Return the result of the shared environment's ``rigid_bodies()`` call."""
    bodies = furniture_bench_env.rigid_bodies()
    return bodies


@app.get("/table")
def get_table() -> Table:
    """Return the result of the shared environment's ``table()`` call."""
    table = furniture_bench_env.table()
    return table


@app.get("/robot-arm")
def get_robot_arm() -> RobotArm:
    """Return the result of the shared environment's ``robot_arm()`` call."""
    arm = furniture_bench_env.robot_arm()
    return arm


@app.get("/image")
def get_image() -> str:
    """Return the labelled image of the shared environment, base64 encoded.

    The image is produced by ``get_image_with_labels`` and encoded with
    ``image_to_base64``.

    Returns:
        str: The value returned by ``image_to_base64`` for the labelled image.
    """
    labelled_image = get_image_with_labels(furniture_bench_env)
    encoded = image_to_base64(labelled_image)
    return encoded


@app.post("/run-motion")
def run_motion(data: RunMotionRequestModel) -> RunMotionResponseModel:
    """Run a motion script against the shared environment.

    ``data.motion`` is parsed as Python source with ``ast.parse`` and the
    resulting tree is visited by a ``FunctionParser`` bound to the shared
    environment. On success the environment records a new state via
    ``add_state(hash=data.hash)``.

    Args:
        data: Request body; ``data.motion`` is the source text to parse and
            ``data.hash`` is forwarded to ``add_state``.

    Returns:
        RunMotionResponseModel:
            * on success: ``error_response=None`` and ``states`` set to the
              environment's ``data`` list;
            * on ``RuntimeError``: ``error_response`` set to the exception
              text and empty ``states``;
            * on ``SyntaxError`` or ``TypeError``: a fixed "invalid command"
              message, empty ``states`` and ``translation=True``.
    """
    print(data.motion)

    # Start from an empty list so the response only carries what this request
    # collected. NOTE(review): presumably filled while the motion runs -- confirm.
    furniture_bench_env.data = []

    function_parser = FunctionParser(env=furniture_bench_env, tolerance=tolerance)

    try:
        # run motion
        function_parser.visit(ast.parse(data.motion))
    except RuntimeError as e:
        # TODO: reset environment
        return RunMotionResponseModel(error_response=str(e), states=[])
    except (SyntaxError, TypeError):
        # ast.parse raises SyntaxError for source that is not valid Python;
        # report it like any other invalid command instead of letting the
        # request fail with an unhandled exception.
        return RunMotionResponseModel(error_response="The motion command is invalid.", states=[], translation=True)

    prev_state = furniture_bench_env.env_state.curr_hash
    furniture_bench_env.add_state(hash=data.hash)
    post_state = furniture_bench_env.env_state.curr_hash

    print(f"State changed: {prev_state} -> {post_state}")

    return RunMotionResponseModel(error_response=None, states=furniture_bench_env.data)


@app.post("/predicates-evaluation")
def get_predicates_evaluation(data: GetValidPredicatesRequestModel) -> GetValidPredicatesResponseModel:
    """Evaluate the predicates of a PDDL domain over a PDDL problem's objects.

    Args:
        data: Request body; ``data.domain`` and ``data.problem`` are parsed
            with ``DomainParser`` and ``ProblemParser`` respectively.

    Returns:
        GetValidPredicatesResponseModel: ``predicates`` is the string form of
        an ``And`` over everything the ``StateValidator`` returned.
    """
    parsed_domain = DomainParser()(data.domain)
    parsed_problem = ProblemParser()(data.problem)

    domain_predicates = list(parsed_domain.predicates)

    validator = StateValidator(env=furniture_bench_env)
    evaluated = validator.get_predicates_evaluation(
        known_predicates=domain_predicates,
        objects=parsed_problem.objects,
    )

    conjunction = And(*evaluated)
    return GetValidPredicatesResponseModel(predicates=str(conjunction))


@app.post("/gt-predicate-evaluation")
def get_gt_predicates_evaluation() -> str:
    """Evaluate the ``assembled`` predicate over the environment's objects.

    The predicate is declared over two ``part`` variables, and every name
    returned by the environment's ``get_objects()`` becomes a ``part``
    constant.

    Returns:
        str: String form of an ``And`` over everything the
        ``StateValidator`` returned.
    """
    from pddl.core import Constant, Predicate
    from pddl.logic.base import Variable

    first_part = Variable("p1", ["part"])
    second_part = Variable("p2", ["part"])
    assembled = Predicate("assembled", first_part, second_part)

    part_constants = [Constant(name, "part") for name in furniture_bench_env.get_objects()]

    validator = StateValidator(env=furniture_bench_env)
    evaluated = validator.get_predicates_evaluation(
        known_predicates=[assembled],
        objects=part_constants,
    )

    conjunction = And(*evaluated)
    return str(conjunction)


@app.get("/get-supported-predicates")
def get_supported_predicates() -> GetSupportedPredicatesResponseModel:
    """List the keys of the validator's supported-predicates mapping.

    Returns:
        GetSupportedPredicatesResponseModel: ``supported_predicates`` holds
        the keys of ``StateValidator.get_supported_predicates()`` as a list.
    """
    validator = StateValidator(env=furniture_bench_env)
    supported = validator.get_supported_predicates()
    predicate_names = list(supported.keys())

    return GetSupportedPredicatesResponseModel(supported_predicates=predicate_names)
