"""
Analyze and visualize the trajectory geometry to understand the control point behavior.

The control point parameterization creates curves that bend PERPENDICULAR to the
start->handle line. This means:
- For close_drawer, start->handle is mostly in -Y direction
- perp1 is computed as cross(line_vec, [0,0,1]) -> X direction
- perp2 is computed as cross(line_vec, perp1) -> Z direction
- angle=0: offset in +X direction (curve bends RIGHT when looking from above)
- angle=90: offset in +Z direction (curve bends UP)
- angle=180: offset in -X direction (curve bends LEFT)
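- angle=270: offset in -Z direction (curve bends DOWN)

NOTE: the signs listed above depend on the actual start->handle direction and on
the cross-product order. For a line direction of exactly (0, -1, 0),
build_local_frame() in this file returns perp1 = (-1, 0, 0) and
perp2 = (0, 0, -1), so verify the +X / +Z labels against the frame that main()
prints for each episode before relying on them.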

The push phase is always in -Y direction (straight into cabinet).
So reach and push are NOT coplanar unless the reach curve stays in the Y-Z plane (angle 90 or 270).
"""
import os
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


def build_local_frame(start_pos, target_pos):
    """Return (line_vec_norm, perp1, perp2) forming an orthonormal frame.

    Args:
        start_pos: 3-element position (ndarray, list or tuple) where the line
            starts.
        target_pos: 3-element position (ndarray, list or tuple) where the line
            ends.

    Returns:
        A tuple of three arrays:
        - line_vec_norm: unit vector pointing from start_pos to target_pos.
        - perp1: unit vector perpendicular to the line, obtained from
          cross(line_vec_norm, [0, 0, 1]), or from cross(line_vec_norm,
          [1, 0, 0]) when the line is close to the Z axis.
        - perp2: cross(line_vec_norm, perp1), perpendicular to both.

        If the two points are closer than 1e-6, no direction can be derived
        and the world axes ([1,0,0], [0,1,0], [0,0,1]) are returned instead.
    """
    # Accept plain sequences as well as ndarrays; ndarray inputs are unchanged.
    start_pos = np.asarray(start_pos)
    target_pos = np.asarray(target_pos)

    line_vec = target_pos - start_pos
    line_len = np.linalg.norm(line_vec)
    if line_len < 1e-6:
        # Degenerate line: fall back to the world axes. Float literals keep the
        # dtype consistent with the non-degenerate path.
        return (
            np.array([1.0, 0.0, 0.0]),
            np.array([0.0, 1.0, 0.0]),
            np.array([0.0, 0.0, 1.0]),
        )

    line_vec_norm = line_vec / line_len

    # Pick a reference axis that is not (nearly) parallel to the line so the
    # cross product below cannot collapse to a near-zero vector.
    if abs(line_vec_norm[2]) < 0.9:
        reference_axis = np.array([0, 0, 1])
    else:
        reference_axis = np.array([1, 0, 0])

    perp1 = np.cross(line_vec_norm, reference_axis)
    perp1 = perp1 / np.linalg.norm(perp1)
    # Both operands are unit length and orthogonal, so perp2 is already unit.
    perp2 = np.cross(line_vec_norm, perp1)
    return line_vec_norm, perp1, perp2


# Dataset location used when main() is called without arguments.
_DEFAULT_BASE_DIR = '/scratch4/workspace/placeholder-hdp1/dppo/data/close_drawer/variation2/encoder/z2/generalize_test'

# Timestep sampled as the "mid-reach" point when measuring curve deviation.
_MID_REACH_INDEX = 40

# (episode index, legend label, colour) of the trajectories compared in the
# plots. The labels are taken as given; they are not checked against metadata.
_COMPARISON_EPISODES = [(0, 'angle=0', 'blue'), (6, 'angle=180', 'red')]


def _load_trajectory(distance_test_dir, ep_idx):
    """Return the saved end-effector trajectory of an episode, or None if absent."""
    traj_path = os.path.join(distance_test_dir, f'episode{ep_idx}', 'ee_trajectory.npy')
    if not os.path.isfile(traj_path):
        return None
    return np.load(traj_path)


def _report_episode(distance_test_dir, ep_idx):
    """Print the local frame and the mid-reach deviation of one episode."""
    ep_path = os.path.join(distance_test_dir, f'episode{ep_idx}')
    meta_path = os.path.join(ep_path, 'metadata.npy')
    ee = _load_trajectory(distance_test_dir, ep_idx)
    if ee is None or len(ee) == 0 or not os.path.isfile(meta_path):
        # Mirror the plotting code, which also tolerates missing episodes.
        print(f"Episode {ep_idx}: no usable data in {ep_path}, skipping\n")
        return

    # NOTE(review): allow_pickle=True can execute arbitrary code embedded in
    # the file; only run this script on metadata files you trust.
    meta = np.load(meta_path, allow_pickle=True).item()

    start = ee[0]
    end = ee[-1]  # After push
    # Clamp so trajectories shorter than the nominal mid-reach step still work.
    mid = ee[min(_MID_REACH_INDEX, len(ee) - 1)]

    print(f"Episode {ep_idx}: angle={meta['angle']}, distance={meta['distance']}")
    print(f"  Start: [{start[0]:.4f}, {start[1]:.4f}, {start[2]:.4f}]")
    print(f"  End:   [{end[0]:.4f}, {end[1]:.4f}, {end[2]:.4f}]")

    # Local frame of the straight line from the first to the last sample.
    line_vec_norm, perp1, perp2 = build_local_frame(start, end)
    print(f"  Line direction:  [{line_vec_norm[0]:.4f}, {line_vec_norm[1]:.4f}, {line_vec_norm[2]:.4f}]")
    print(f"  perp1 (angle=0): [{perp1[0]:.4f}, {perp1[1]:.4f}, {perp1[2]:.4f}]")
    print(f"  perp2 (angle=90):[{perp2[0]:.4f}, {perp2[1]:.4f}, {perp2[2]:.4f}]")

    # Deviation = mid-reach sample minus its orthogonal projection on the line.
    t = np.dot(mid - start, line_vec_norm)
    closest = start + t * line_vec_norm
    deviation = mid - closest
    dev_mag = np.linalg.norm(deviation)
    if dev_mag > 0.001:
        # Split the deviation into its components along the two frame axes.
        proj_perp1 = np.dot(deviation, perp1)
        proj_perp2 = np.dot(deviation, perp2)
        print(f"  Mid-point deviation: {dev_mag*100:.2f} cm")
        print(f"    Projection on perp1 (X): {proj_perp1*100:.2f} cm")
        print(f"    Projection on perp2 (Z): {proj_perp2*100:.2f} cm")
    print()


def _print_explanation():
    """Print the textual explanation of the control point parameterization."""
    print("="*60)
    print("EXPLANATION:")
    print("="*60)
    print("""
The control point parameterization works as follows:

1. For close_drawer, the start->handle direction is mostly -Y (toward cabinet)
2. perp1 = cross(line_vec, [0,0,1]) = X direction (horizontal, perpendicular to motion)
3. perp2 = cross(line_vec, perp1) = Z direction (vertical)

With angle=0 and distance>0:
- Control point offset = distance * cos(0) * perp1 = distance * perp1
- This creates a curve that bends in the +X direction (to the right)
- NOT in the Y-Z plane (not coplanar with push direction)

With angle=180 and distance>0:
- Control point offset = distance * cos(180) * perp1 = -distance * perp1
- This creates a curve that bends in the -X direction (to the left)
- Also NOT in the Y-Z plane

To have reach trajectory in the same plane as push (Y-Z plane):
- Use angle=90 (bends up, +Z) or angle=270 (bends down, -Z)
- OR use distance=0 (straight line, no bend)

This is BY DESIGN - the control point creates perpendicular curves.
The current visualization is CORRECT for the parameterization being used.
""")


def _plot_projection(ax, trajectories, col_a, col_b):
    """Plot trajectory columns col_a vs col_b; circle = first sample, square = last."""
    for ee, label, color in trajectories:
        ax.plot(ee[:, col_a], ee[:, col_b], color=color, linewidth=2, label=label)
        ax.scatter(ee[0, col_a], ee[0, col_b], color=color, s=80, marker='o', zorder=5)
        ax.scatter(ee[-1, col_a], ee[-1, col_b], color=color, s=80, marker='s', zorder=5)
    if trajectories:
        # legend() warns when no labelled artist exists, so only call it then.
        ax.legend()
    ax.grid(True, alpha=0.3)


def _draw_offset_diagram(ax):
    """Draw the schematic of control point offset directions per angle."""
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)

    # Primary frame axes (black) and their opposites (gray).
    ax.arrow(0, 0, 1, 0, head_width=0.1, head_length=0.05, fc='black', ec='black')
    ax.text(1.1, 0, 'perp1\n(angle=0)\n(+X)', fontsize=10, ha='left', va='center')

    ax.arrow(0, 0, 0, 1, head_width=0.1, head_length=0.05, fc='black', ec='black')
    ax.text(0, 1.2, 'perp2\n(angle=90)\n(+Z)', fontsize=10, ha='center', va='bottom')

    ax.arrow(0, 0, -1, 0, head_width=0.1, head_length=0.05, fc='gray', ec='gray')
    ax.text(-1.1, 0, 'angle=180\n(-X)', fontsize=10, ha='right', va='center', color='gray')

    ax.arrow(0, 0, 0, -1, head_width=0.1, head_length=0.05, fc='gray', ec='gray')
    ax.text(0, -1.2, 'angle=270\n(-Z)', fontsize=10, ha='center', va='top', color='gray')

    # The line direction points into the page, drawn as a cross at the origin.
    ax.scatter(0, 0, s=200, c='red', marker='x', linewidths=3, zorder=10)
    ax.text(0.15, -0.15, 'line_vec\n(-Y, into page)', fontsize=9, ha='left', va='top', color='red')

    ax.set_title('Control Point Offset Directions\n(looking from +Y toward cabinet)')
    ax.set_aspect('equal')
    ax.axis('off')


def main(base_dir=_DEFAULT_BASE_DIR):
    """Print a geometry report for the distance test and save an explanatory figure.

    Reads ``<base_dir>/distance_test/episode<N>/ee_trajectory.npy`` and
    ``metadata.npy``, prints the local frame and mid-reach deviation of
    episodes 0-2, then writes ``trajectory_geometry_explanation.png`` into
    ``<base_dir>/plots`` (created if missing). Episodes whose files are absent
    are skipped instead of aborting the run.

    Args:
        base_dir: Root directory of the generalization test data. Defaults to
            the dataset path this script was originally written for.
    """
    distance_test_dir = os.path.join(base_dir, 'distance_test')
    plots_dir = os.path.join(base_dir, 'plots')

    # Load a few trajectories
    print("Analyzing trajectory geometry for distance test...\n")
    for ep_idx in range(3):
        _report_episode(distance_test_dir, ep_idx)

    _print_explanation()

    # Load each comparison trajectory once and reuse it for both projections.
    trajectories = []
    for ep_idx, label, color in _COMPARISON_EPISODES:
        ee = _load_trajectory(distance_test_dir, ep_idx)
        if ee is not None:
            trajectories.append((ee, label, color))

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 6))

    # Left: Top-down view (X-Y plane) showing how angle=0,180 curves sideways
    _plot_projection(ax1, trajectories, 0, 1)
    ax1.set_xlabel('X (m)')
    ax1.set_ylabel('Y (m)')
    ax1.set_title('Top-Down View (X-Y)\nShows sideways bending')
    ax1.axis('equal')

    # Middle: Side view (Y-Z plane)
    _plot_projection(ax2, trajectories, 1, 2)
    ax2.set_xlabel('Y (m)')
    ax2.set_ylabel('Z (m)')
    ax2.set_title('Side View (Y-Z)\nPush plane - minimal deviation here')

    # Right: Diagram explaining the geometry
    _draw_offset_diagram(ax3)

    fig.tight_layout()
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(plots_dir, exist_ok=True)
    fig.savefig(os.path.join(plots_dir, 'trajectory_geometry_explanation.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"\nSaved geometry explanation to {plots_dir}/trajectory_geometry_explanation.png")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
