import ast
from functools import partial

from furniture_bench_api.api.api_skills import APISkills
from furniture_bench_api.furniture_bench_environment import FurnitureBenchEnvironment


class FunctionParser(ast.NodeTransformer):
    """Evaluate a small subset of Python by walking its AST.

    Although this derives from ``ast.NodeTransformer``, the ``visit_*``
    methods defined here return evaluated Python values (or ``None``) rather
    than replacement AST nodes.  Calls are dispatched by name to the bound
    ``APISkills`` methods listed in ``known_functions``; assigned names are
    kept in ``variables`` so later statements can read them.

    Supported constructs: calls to known functions with positional and
    keyword arguments, constants, tuples, unary minus, name lookups and
    assignments to plain names (optionally tuple-unpacked or chained).
    """

    def __init__(self, env: "FurnitureBenchEnvironment", tolerance: float):
        """Create the skill set and the name -> callable dispatch table.

        Args:
            env: Passed to ``APISkills`` as ``furniture_bench_env``.
            tolerance: Passed to ``APISkills`` as ``tolerance``.
        """
        super().__init__()

        skills = APISkills(furniture_bench_env=env, tolerance=tolerance)

        # Values bound by visit_Assign and read back by visit_Name.  It has
        # to exist before the first statement is visited, otherwise the first
        # variable access fails with AttributeError.
        self.variables = {}

        # Name used in the interpreted script -> skill that implements it.
        self.known_functions = {
            "move_linear": skills.move_linear,
            "move_linear_relative": skills.move_linear_relative,
            "move_linear_up": skills.move_linear_up,
            "move_linear_down_until_touching": skills.move_linear_relative_until_touching,
            "close_gripper": skills.close_gripper,
            "open_gripper": skills.open_gripper,
            "set_gripper_around_part": skills.move_to_part_center,
            "hover_above_part": skills.move_to_part_above,
            "align_orientation_for_assembly": skills.align,
            "screw_touching_parts_together": skills.rotate_arm_until_screwed,
        }

    def visit_Call(self, node):
        """Invoke the known function named by ``node`` and return its result.

        Raises:
            RuntimeError: If the callee is not a plain name, or the name is
                not a key of ``known_functions``.
        """
        callee = node.func
        if not isinstance(callee, ast.Name):
            # e.g. ``obj.method()`` or ``f()()``: there is no ``.id`` to look up.
            raise RuntimeError(
                "only plain function names can be called, got %s"
                % type(callee).__name__
            )

        func_name = callee.id
        known_f = self.known_functions.get(func_name)
        if known_f is None:
            raise RuntimeError("function %s unknown" % func_name)

        # Arguments are evaluated only after the callee has been validated,
        # so nested calls are never executed on behalf of a call that is
        # going to be rejected anyway.
        func_args = [self.visit(arg) for arg in node.args]
        func_kwargs = {
            keyword.arg: self.visit(keyword.value) for keyword in node.keywords
        }

        return known_f(*func_args, **func_kwargs)

    def visit_Constant(self, node):
        """Return the literal value stored in the constant node."""
        return node.value

    def visit_UnaryOp(self, node: ast.UnaryOp):
        """Evaluate a unary expression; only negation is supported.

        Raises:
            NotImplementedError: For any operator other than unary minus.
        """
        if isinstance(node.op, ast.USub):
            return -self.visit(node.operand)
        raise NotImplementedError(
            "unary operator %s is not supported" % type(node.op).__name__
        )

    def visit_Name(self, node):
        """Return the value previously assigned to the variable ``node.id``.

        Raises:
            RuntimeError: If the variable has not been assigned yet.
        """
        if node.id not in self.variables:
            raise RuntimeError("variable %s not defined" % node.id)

        return self.variables[node.id]

    def visit_Tuple(self, node: ast.Tuple):
        """Evaluate every element and return the results as a tuple."""
        return tuple(self.visit(element) for element in node.elts)

    def visit_Assign(self, node):
        """Evaluate the right-hand side and bind it to the target names.

        Every target of a chained assignment (``a = b = value``) is bound.
        A tuple value is spread positionally over the names of a target,
        including a single-name target (``x = (1, 2)`` binds ``x`` to ``1``);
        any other value is treated as a single item.  Surplus names or values
        are silently ignored, which is the pairing behaviour this class has
        always had.

        Returns ``None``.

        Raises:
            RuntimeError: If a target is not a plain name or a tuple of plain
                names.  Nothing is bound in that case.
        """
        value = self.visit(node.value)
        values = value if isinstance(value, tuple) else [value]

        # Collect and validate every binding first so that an unsupported
        # target cannot leave the variables half-updated.
        bindings = []
        for target in node.targets:
            names = target.elts if isinstance(target, ast.Tuple) else [target]
            for name in names:
                if not isinstance(name, ast.Name):
                    raise RuntimeError(
                        "unsupported assignment target %s" % type(name).__name__
                    )
            bindings.extend(zip(names, values))

        for name, bound_value in bindings:
            self.variables[name.id] = bound_value
