from collections import defaultdict
import copy
from abc import abstractmethod
from typing import Optional, Tuple, List

# isort: off
from furniture_bench_api.furniture_bench_environment import FurnitureBenchEnvironment
import torch

# isort: on

from python_utils.transformations import pose_to_affine



class StatePredicate:
    """Base class for a boolean predicate evaluated against the environment.

    ``validate`` returns a ``(holds, description)`` tuple. Predicates that keep
    internal state may override ``reset`` / ``get_state`` / ``set_state``; the
    base implementations keep no state at all.

    NOTE: the class does not use ``ABCMeta``, so ``@abstractmethod`` below is
    not enforced at instantiation time; calling ``validate`` on a subclass that
    does not override it raises ``NotImplementedError`` instead.
    """

    def reset(self):
        """Clear internal state (nothing to clear in the base class)."""

    def get_state(self):
        """Return a snapshot of internal state; always a fresh empty dict here."""
        return dict()

    def set_state(self, state: dict):
        """Restore internal state from ``state``; ignored by the base class."""

    @abstractmethod
    def validate(self, env, *args) -> Tuple[bool, str]:
        """Evaluate the predicate on ``env``; subclasses must override this."""
        raise NotImplementedError


class StateArmEmptyPredicate(StatePredicate):
    """Holds when the gripper is not holding anything.

    The gripper counts as empty when it is open, or when it is closed and
    element 7 of ``env.get_current_pose()`` (presumably the gripper opening
    width — TODO confirm) is below 0.01, i.e. the fingers closed on nothing.
    """

    def validate(self, env: FurnitureBenchEnvironment, arm: str):
        width = env.get_current_pose()[7]
        if env.gripper_closed:
            if width < 0.01:
                return True, "gripper is closed and has nothing grasped"
            return False, "gripper grasps something"
        return True, "gripper is open"


class StateGripperOpenPredicate(StatePredicate):
    """Holds when the environment reports the gripper as not closed."""

    def validate(self, env: FurnitureBenchEnvironment, arm: str):
        gripper_is_open = not env.gripper_closed
        return gripper_is_open, ""


class StateHoldingPredicate(StatePredicate):
    """Holds when the closed gripper grasps the given object.

    Fails early if the gripper is open, or closed with element 7 of
    ``env.get_current_pose()`` below 0.01 (presumably gripper width — nothing
    between the fingers). Otherwise defers to ``env.grasps_object``.
    """

    def validate(self, env: FurnitureBenchEnvironment, arm: str, object: str):
        width = env.get_current_pose()[7]
        if not env.gripper_closed:
            return False, "gripper is open"
        if width < 0.01:
            return False, "gripper is closed, but does not hold anything"

        # Hack: while the hood is being put onto the bulb, the bulb must not be
        # reported as held (presumably grasps_object reports it too — confirm).
        hood_in_hand = env.grasps_object(obj_name="lamp_hood")
        if hood_in_hand and object == "lamp_bulb":
            return False, ""

        if env.grasps_object(obj_name=object):
            return True, "gripper grasps object %s" % object
        return False, "gripper does not grasp object %s" % object


class StateOnTablePredicate(StatePredicate):
    """Holds when the lowest point of an object lies at table height.

    The object's mesh is moved to its current origin pose, and the minimum z of
    its axis-aligned bounding box is compared against the band
    ``[Z_MIN_LOWER, Z_MIN_UPPER]`` around the table surface.
    """

    # The table surface is at z = -0.015; accept a small band around it.
    # Subclasses may override these to adapt to a different table height.
    Z_MIN_LOWER = -0.016
    Z_MIN_UPPER = -0.013

    def validate(self, env: FurnitureBenchEnvironment, *args):
        """Return ``(on_table, description)`` for the single object name in ``args``.

        Raises:
            TypeError: if ``args`` does not contain exactly one object name.
        """
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if len(args) != 1:
            raise TypeError(
                "StateOnTablePredicate.validate expects exactly one object name, got %d" % len(args)
            )
        obj_name = args[0]

        obj_pose = env.get_object_origin(obj_name)
        # Work on a copy of the env's mesh; presumably ``transform`` mutates the
        # mesh in place, which must not leak back into ``env.objects``.
        mesh = copy.deepcopy(env.objects[obj_name])
        mesh = mesh.transform(pose_to_affine(obj_pose).cpu().numpy())
        z_min = mesh.get_axis_aligned_bounding_box().min_bound[2]

        # Written as a positive band check so a NaN height is *not* reported as on the table.
        if self.Z_MIN_LOWER <= z_min <= self.Z_MIN_UPPER:
            return True, "%s on table" % obj_name
        return False, "%s not on table" % obj_name


class StateArmAbovePartPredicate(StatePredicate):
    """Holds when the end-effector flange is horizontally aligned with a part.

    Returns False early when the arm is already at the part, already holds the
    part, or (for ``lamp_base``) the bulb is screwed into the base. Otherwise
    the flange position must lie within ``XY_TOLERANCE`` of the part's center
    in the x/y plane. A height (z) check is currently disabled.
    """

    # Max x/y distance between flange and part center; 4 cm because the bulb can roll.
    XY_TOLERANCE = 0.04

    def validate(self, env: FurnitureBenchEnvironment, arm, obj):
        if StateArmAtPartPredicate().validate(env, arm, obj)[0]:
            return False, "arm at %s, so cannot be above it" % obj
        if StateHoldingPredicate().validate(env, arm, obj)[0]:
            return False, "arm holds %s, so cannot be above it" % obj
        if obj == "lamp_base":  # TODO: REMOVE!?
            # TODO: rather check whether any other object is between obj and arm along z.
            # NOTE(review): StateScrewedInPredicate is not defined in the visible part of
            # this module — confirm it exists, otherwise this branch raises NameError.
            if StateScrewedInPredicate().validate(env, "lamp_bulb", "lamp_base")[0]:
                return False, "lamp_bulb is screwed into lamp_base, so arm cannot be above it"

        obj_center = env.get_transformed_pose(obj, pose="center")
        arm_center = env.get_current_pose(at_flange=True)[:7]

        xy_distance = (arm_center[:2] - obj_center[:2]).norm().item()
        # The z-check (arm_center[2] > obj_center[2]) is intentionally disabled;
        # only horizontal alignment is required.
        if xy_distance < self.XY_TOLERANCE:
            return True, "arm above %s" % obj
        return False, "arm not above %s. arm deviates in x- and y-axis too much" % obj


class StateArmAtPartPredicate(StatePredicate):
    """Holds when the end-effector flange is within 3 cm of a part's center.

    The distance is the Euclidean norm over the first three pose components.
    Special case: while the hood is grasped, the arm is never "at" the bulb.
    """

    def validate(self, env: FurnitureBenchEnvironment, arm, obj):
        hood_in_hand = env.grasps_object(obj_name="lamp_hood")
        if hood_in_hand and obj == "lamp_bulb":
            return False, ""

        part_pose = env.get_transformed_pose(obj, pose="center")
        flange_pose = env.get_current_pose(at_flange=True)[:7]
        distance = (flange_pose[:3] - part_pose[:3]).norm().item()

        # 3 cm tolerance because the bulb can roll.
        if distance < 0.03:
            return True, "arm at %s" % obj
        return False, "arm not at %s. " % obj


class StateTouchingPredicate(StatePredicate):
    """Holds when ``obj1`` sits aligned on ``obj2`` and is inserted deep enough.

    Alignment: the first two components of ``obj1``'s "center_for_align" pose and
    ``obj2``'s "center" pose each differ by less than 0.01. Depth: the bottom (min z)
    of ``obj1``'s bounding box must be below the top (max z) of ``obj2``'s box by more
    than the insert depth. That depth is looked up per pair in
    ``env.config["parts"][<obj1>]["insert_depth"][<obj2>]`` and falls back to
    ``env.config["predicates"]["insert_depth"]``.
    """

    def validate(self, env: FurnitureBenchEnvironment, obj1: str, obj2: str):
        try:
            align_pose = env.get_transformed_pose(obj1, pose="center_for_align")
            target_pose = env.get_transformed_pose(obj2, pose="center")
        except KeyError:
            # One of the objects does not define the requested pose.
            return False, ""

        aligned = ((align_pose[:2] - target_pose[:2]).abs() < 0.01).all()

        bottom_of_obj1 = env.get_object_bounding_box(obj1)[0][2]
        top_of_obj2 = env.get_object_bounding_box(obj2)[1][2]

        prefix = f"{env.furniture}_"
        parts_cfg = env.config["parts"]
        depth_overrides = parts_cfg[obj1.split(prefix)[1]].get("insert_depth", {})
        insert_depth = depth_overrides.get(obj2.split(prefix)[1])
        if insert_depth is None:
            insert_depth = env.config["predicates"]["insert_depth"]
        deep_enough = bottom_of_obj1 < top_of_obj2 - insert_depth

        return bool(aligned.item()) and bool(deep_enough), "touching or not"


def get_connection_path(target_connection: Tuple[str, str], valid_connections: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """
    Find a shortest chain of connections linking the two nodes of a target connection.

    Connections are treated as undirected edges. If the target is itself a valid
    connection (in either orientation) it is returned directly; otherwise a
    breadth-first search finds a shortest path from ``a`` to ``b``. Neighbours are
    explored in the order they appear in ``valid_connections``, so ties between
    equally short paths are broken deterministically (independent of hashing).

    Args:
        target_connection: The connection to reach (a, b)
        valid_connections: List of available connections

    Returns:
        List of connections forming a path from a to b, each returned exactly as
        given in valid_connections. Empty if no path exists, or if a == b and
        (a, a) is not itself a valid connection.
    """
    from collections import deque

    target_a, target_b = target_connection

    # Both orientations of every edge map back to the original tuple (a later
    # duplicate overrides an earlier one). The adjacency lists keep list order.
    connection_lookup = {}
    graph = defaultdict(list)
    for conn in valid_connections:
        a, b = conn
        connection_lookup[(a, b)] = conn
        connection_lookup[(b, a)] = conn
        graph[a].append(b)
        graph[b].append(a)

    # Direct connection available.
    if (target_a, target_b) in connection_lookup:
        return [connection_lookup[(target_a, target_b)]]

    # BFS from target_a; parent pointers avoid copying a path list per node.
    parent = {target_a: target_a}
    queue = deque([target_a])
    while queue:
        node = queue.popleft()
        if node == target_b:
            # Walk back to the start, collecting the original connection tuples.
            path = []
            while node != target_a:
                prev = parent[node]
                path.append(connection_lookup[(prev, node)])
                node = prev
            path.reverse()
            return path
        for neighbor in graph[node]:
            if neighbor not in parent:
                parent[neighbor] = node
                queue.append(neighbor)

    # No path found
    return []

class StateAssembledPredicate(StatePredicate):
    """Checks whether two parts are assembled by walking the furniture's assembly graph.

    Both objects are resolved to part indices, and a chain of connections between
    them is obtained from ``get_connection_path``. The predicate holds only if every
    connection on that chain is assembled. Connections once found assembled are
    cached in ``self.response`` (keyed by ``str(connection)``) and are not
    re-checked until :meth:`reset`.
    """

    def __init__(self):
        self.reset()

    def get_state(self):
        """Return a plain-dict copy of the per-connection cache."""
        return dict(self.response)

    def set_state(self, state: dict):
        """Restore the per-connection cache from ``state``."""
        self.response = defaultdict(bool, state)

    def reset(self):
        """Forget all cached connection results."""
        self.response = defaultdict(bool)

    def validate(
        self, env: FurnitureBenchEnvironment, obj1: str, obj2: str, *, pos_threshold: Optional[torch.Tensor] = None
    ):
        """Return ``(True, "")`` if ``obj2`` and ``obj1`` are linked through assembled connections.

        Args:
            env: environment wrapper; ``env.env.env.env`` supplies furniture and part poses.
            obj1, obj2: object names carrying the ``f"{env.furniture}_"`` prefix.
            pos_threshold: optional tensor forwarded (via ``tolist()``) to
                ``is_assembled_idx`` as ``assembled_pos_threshold``.

        Side effect (only when a connection path exists): overwrites ``ori_bound``
        and ``should_assembled_first`` on the first furniture of the inner env.
        """
        if obj1 == obj2:
            return False, ""
        inner = env.env.env.env

        # Map each object to the part it belongs to (config key "part") or to itself.
        # The order (obj2, obj1) fixes the orientation of the queried connection.
        resolved = []
        for name in (obj2, obj1):
            part_cfg = env.config["parts"][name.split(f"{env.furniture}_")[1]]
            resolved.append(f"{env.furniture}_{part_cfg['part']}" if "part" in part_cfg else name)

        idx_by_name = {p.name: p.part_idx for p in inner.furnitures[0].parts}
        target = tuple(idx_by_name[name] for name in resolved)
        path = get_connection_path(target_connection=target, valid_connections=inner.furnitures[0].should_be_assembled)
        if not path:
            return False, ""

        env_idx = 0
        poses_all, founds_all = inner._get_parts_poses(sim_coord=True)
        poses = poses_all[env_idx].cpu().numpy()
        founds = founds_all[env_idx].cpu().numpy()
        furniture = inner.furnitures[env_idx]
        furniture.ori_bound = env.config["predicates"]["ori_bound"]
        furniture.should_assembled_first = {}  # due to custom assembled_pos_threshold?

        for connection in path:
            cache_key = str(connection)
            # NB: reading the defaultdict records a False entry for unseen connections.
            if self.response[cache_key]:
                continue
            assembled = furniture.is_assembled_idx(
                *connection,
                poses,
                founds,
                assembled_pos_threshold=None if pos_threshold is None else pos_threshold.tolist(),
            )
            if not assembled:
                return False, ""
            self.response[cache_key] = True

        return True, ""
