import numpy as np
import matplotlib.pyplot as plt


class RobotArm:
    """Planar two-link robot arm with circular obstacles and optional live drawing.

    The base joint sits at the origin. ``L1`` is the length of the first link
    (base to elbow) and ``L2`` the length of the second link (elbow to gripper).
    ``theta1`` is the absolute angle of the first link and ``theta2`` the angle
    of the second link relative to the first, both in radians.
    """

    def __init__(self, L1, L2, obstacles=None, to_draw=False):
        """Create an arm in the zero pose (both joint angles 0).

        Args:
            L1: Length of the first link.
            L2: Length of the second link.
            obstacles: List of ``(x, y, radius)`` circles; ``None`` means none.
            to_draw: If true, ``move_to``/``rotate_to`` redraw every
                interpolated frame of the motion.
        """
        self.L1 = L1
        self.L2 = L2
        self.theta1 = 0.0
        self.theta2 = 0.0
        self.obstacles = obstacles if obstacles else []
        self.joint_positions = self.forward_kinematics(self.theta1, self.theta2)
        self.gripper_position = self.joint_positions[-1]
        self.to_draw = to_draw
        # The figure is always created so draw() can be called explicitly
        # even when per-frame drawing is disabled.
        self.fig, self.ax = plt.subplots()

    def inverse_kinematics(self, x, y):
        """Return joint angles ``(theta1, theta2)`` that place the gripper at (x, y).

        Of the two analytic solutions, the one with the smallest summed
        absolute change from the current angles is returned. On a tie, the
        positive-sine branch wins.

        Raises:
            ValueError: if (x, y) lies outside the reachable annulus
                ``|L1 - L2| <= r <= L1 + L2``.
        """
        dist = np.hypot(x, y)
        if dist > self.L1 + self.L2 or dist < abs(self.L1 - self.L2):
            raise ValueError("Target is out of reach.")

        cos_theta2 = (x**2 + y**2 - self.L1**2 - self.L2**2) / (2 * self.L1 * self.L2)
        # Rounding can push the cosine slightly outside [-1, 1] for targets on
        # the workspace boundary. Without clamping, the square root below would
        # be NaN and no solution would ever be selected.
        cos_theta2 = float(np.clip(cos_theta2, -1.0, 1.0))
        sin_abs = np.sqrt(1.0 - cos_theta2**2)

        best_solution = None
        min_angle_change = float('inf')
        for sin_theta2 in (sin_abs, -sin_abs):
            theta2 = np.arctan2(sin_theta2, cos_theta2)
            k1 = self.L1 + self.L2 * cos_theta2
            k2 = self.L2 * sin_theta2
            theta1 = np.arctan2(y, x) - np.arctan2(k2, k1)

            # Raw (unwrapped) difference: interpolate_motion also interpolates
            # raw angles, so this is the actual joint travel of the move.
            delta = abs(theta1 - self.theta1) + abs(theta2 - self.theta2)
            if delta < min_angle_change:
                min_angle_change = delta
                best_solution = (theta1, theta2)

        return best_solution

    def forward_kinematics(self, theta1, theta2):
        """Return ``[base, elbow, gripper]`` positions as 2-element float arrays."""
        joint1 = np.array([self.L1 * np.cos(theta1), self.L1 * np.sin(theta1)])
        gripper = joint1 + np.array([self.L2 * np.cos(theta1 + theta2),
                                     self.L2 * np.sin(theta1 + theta2)])
        return [np.array([0.0, 0.0]), joint1, gripper]

    def interpolate_motion(self, theta1_start, theta2_start, theta1_end, theta2_end, steps=50):
        """Linearly interpolate joint angles and return one FK frame per step.

        Both endpoints are included, so the last frame is the target pose.
        """
        return [
            self.forward_kinematics(
                (1 - t) * theta1_start + t * theta1_end,
                (1 - t) * theta2_start + t * theta2_end,
            )
            for t in np.linspace(0, 1, steps)
        ]

    def check_collision(self, p1, p2):
        """Return True if segment p1-p2 touches or enters any obstacle circle.

        A zero-length segment is treated as the single point ``p1``.
        """
        p1 = np.asarray(p1, dtype=float)
        p2 = np.asarray(p2, dtype=float)
        v = p2 - p1
        seg_len_sq = float(np.dot(v, v))  # loop-invariant: same for every obstacle
        for ox, oy, r in self.obstacles:
            center = np.array([ox, oy], dtype=float)
            if seg_len_sq == 0.0:
                closest = p1
            else:
                # Parameter of the projection of the center onto the segment,
                # clamped so the closest point stays on the segment.
                t = np.clip(np.dot(center - p1, v) / seg_len_sq, 0.0, 1.0)
                closest = p1 + t * v
            if np.linalg.norm(center - closest) <= r:
                return True
        return False

    def full_collision_check(self, path):
        """Return True if any link in any frame of ``path`` hits an obstacle."""
        return any(
            self.check_collision(start, end)
            for frame in path
            for start, end in zip(frame, frame[1:])
        )

    def print_joint_positions(self) -> str:
        """Return (not print) a one-line summary of the current joint positions."""
        parts = []
        for i, (px, py) in enumerate(self.joint_positions):
            label = "Gripper" if i == 2 else f"Joint {i}"
            parts.append(f" '{label}': [{px:.2f}, {py:.2f}]")
        return "Joint positions:" + "".join(parts)

    def get_joint_positions(self):
        """Return ( joint0_pos, joint1_pos, gripper_pos ), each a list rounded to 2 decimals."""
        return tuple([round(float(c), 2) for c in pos] for pos in self.joint_positions)

    def draw(self):
        """Redraw the arm and obstacles on ``self.ax``, then pause briefly to refresh."""
        xs, ys = zip(*self.joint_positions)
        reach = self.L1 + self.L2 + 1
        self.ax.clear()
        self.ax.set_xlim(-reach, reach)
        self.ax.set_ylim(-reach, reach)
        self.ax.set_aspect('equal')
        self.ax.grid(True)
        self.ax.plot(xs, ys, 'bo-', linewidth=4)

        for ox, oy, r in self.obstacles:
            self.ax.add_patch(plt.Circle((ox, oy), r, color='gray', alpha=0.5))

        self.ax.set_title("2D Robot Arm Simulation")
        plt.pause(0.05)

    def _execute_path(self, path, theta1_target, theta2_target) -> str:
        """Reject ``path`` on collision; otherwise step through it and commit the target angles."""
        if self.full_collision_check(path):
            return "Failed! Collision detected along the path. Move aborted."

        for frame in path:
            self.joint_positions = frame
            self.gripper_position = frame[-1]
            if self.to_draw:
                self.draw()

        self.theta1, self.theta2 = theta1_target, theta2_target
        return "Success!"

    def move_to(self, x, y) -> str:
        """Move the gripper to (x, y) via inverse kinematics.

        Returns a "Success!" / "Failed! ..." status message. On failure the
        arm state is left unchanged.
        """
        try:
            theta1_target, theta2_target = self.inverse_kinematics(x, y)
        except ValueError as ve:
            return f"Failed! {ve} Move aborted."

        path = self.interpolate_motion(self.theta1, self.theta2, theta1_target, theta2_target)
        return self._execute_path(path, theta1_target, theta2_target)

    def rotate_to(self, theta1_target, theta2_target) -> str:
        """Rotate the joints to the given angles, each of which must lie in [-pi, pi].

        Returns a "Success!" / "Failed! ..." status message. On failure the
        arm state is left unchanged.
        """
        if not (-np.pi <= theta1_target <= np.pi and -np.pi <= theta2_target <= np.pi):
            return "Failed! Target is out of reach. Move aborted."

        path = self.interpolate_motion(self.theta1, self.theta2, theta1_target, theta2_target)
        return self._execute_path(path, theta1_target, theta2_target)

# ==== MAIN LOOP ====
if __name__ == "__main__":
    # Interactive driver: read "X Y" targets and animate the arm toward them.
    # to_draw=True so each move is animated; with it off, the plot would
    # only ever show the initial pose.
    robot = RobotArm(L1=5.0, L2=3.0, obstacles=[
        (3.5, 1.0, 0.5),
        (4.5, 3.0, 0.4),
    ], to_draw=True)
    robot.draw()

    while True:
        try:
            user_input = input("\nEnter target gripper X Y (or 'q' to quit): ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C: exit instead of looping forever.
            break
        if user_input.strip().lower() == 'q':
            break
        try:
            gx, gy = map(float, user_input.split())
        except ValueError:
            # Non-numeric input or the wrong number of values.
            print("❌ Invalid input. Please enter X and Y coordinates like: 5.5 2.0")
            continue
        # move_to reports success/failure by return value; show it to the user.
        print(robot.move_to(gx, gy))

    plt.close()