import numpy as np
import math


def mat2euler(M, cy_thresh=None, seq='zyx'):
    """Recover Euler angles from a 3x3 rotation matrix.

    Parameters
    ----------
    M : array-like with exactly 9 elements (e.g. shape (3, 3))
        Rotation matrix, read in row-major order through ``M.flat``.
    cy_thresh : None or scalar, optional
        Value of ``|cos(y)|`` at or below which the matrix is treated as
        being in gimbal lock.  When None, ``4 * eps`` of ``M``'s dtype is
        used; non-floating matrices (e.g. integer) fall back to the
        float64 epsilon.
    seq : {'zyx', 'xyz'}, optional
        Composition the angles are extracted for:

        * ``'zyx'`` -- ``M = Rx(x) @ Ry(y) @ Rz(z)`` (z applied first).
        * ``'xyz'`` -- ``M = Rz(z) @ Ry(y) @ Rx(x)`` (x applied first).

    Returns
    -------
    z, y, x : float
        Rotation angles in radians about the z, y and x axes, always
        returned in this order regardless of ``seq``.

    Raises
    ------
    ValueError
        If ``seq`` is not one of the supported sequences.

    Notes
    -----
    In gimbal lock (``cos(y)`` ~ 0) only a combination of the two outer
    angles is determined, so one of them is fixed to 0.0 (``x`` for
    'zyx', ``z`` for 'xyz') and the other absorbs the whole rotation.
    """
    M = np.asarray(M)
    if cy_thresh is None:
        try:
            cy_thresh = np.finfo(M.dtype).eps * 4
        except ValueError:
            # np.finfo rejects non-inexact dtypes (int, bool, ...); use
            # the default float epsilon instead of crashing.
            cy_thresh = np.finfo(float).eps * 4

    r11, r12, r13, r21, r22, r23, r31, r32, r33 = M.flat

    if seq == 'zyx':
        # Here r23 = -cos(y)*sin(x) and r33 = cos(x)*cos(y), so
        # cy = |cos(y)|.
        cy = math.sqrt(r33*r33 + r23*r23)
        if cy > cy_thresh:  # cos(y) not close to zero, standard form
            z = math.atan2(-r12, r11)  # atan2(cos(y)*sin(z), cos(y)*cos(z))
            y = math.atan2(r13, cy)    # atan2(sin(y), cy)
            x = math.atan2(-r23, r33)  # atan2(cos(y)*sin(x), cos(x)*cos(y))
        else:
            # cos(y) (close to) zero: fix x = 0.0, so r21 -> sin(z) and
            # r22 -> cos(z).
            z = math.atan2(r21, r22)
            y = math.atan2(r13, cy)
            x = 0.0
    elif seq == 'xyz':
        # Here r32 = sin(x)*cos(y) and r33 = cos(x)*cos(y), so the
        # |cos(y)| term must come from the third *row*, not the third
        # column as in the 'zyx' case.
        cy = math.sqrt(r32*r32 + r33*r33)
        if cy > cy_thresh:
            y = math.atan2(-r31, cy)   # atan2(sin(y), cy)
            x = math.atan2(r32, r33)   # atan2(sin(x)*cos(y), cos(x)*cos(y))
            z = math.atan2(r21, r11)   # atan2(cos(y)*sin(z), cos(y)*cos(z))
        else:
            # Gimbal lock: fix z = 0.0 and let x carry the rotation.
            z = 0.0
            if r31 < 0:
                # sin(y) = +1: r12 = sin(x - z), r13 = cos(x - z)
                y = math.pi / 2
                x = math.atan2(r12, r13)
            else:
                # sin(y) = -1: r12 = -sin(x + z), r13 = -cos(x + z)
                y = -math.pi / 2
                x = math.atan2(-r12, -r13)
    else:
        raise ValueError('Sequence not recognized')
    return z, y, x

def euler2quat(z=0, y=0, x=0, isRadian=True):
    ''' Convert z-y-x Euler angles into a unit quaternion

    The rotation about z is applied first, then y, then x.

    Parameters
    ----------
    z : scalar
        Rotation angle around the z-axis (performed first)
    y : scalar
        Rotation angle around the y-axis
    x : scalar
        Rotation angle around the x-axis (performed last)
    isRadian : bool, optional
        When True (default) the angles are taken as radians; when False
        they are taken as degrees and converted to radians first.

    Returns
    -------
    quat : array shape (4,)
        Quaternion in w, x, y, z (real, then vector) format

    Notes
    -----
    The closed form is the Hamilton product of the three single-axis
    quaternions (each built from the half angle about its axis), see
    http://mathworld.wolfram.com/EulerParameters.html and
    http://en.wikipedia.org/wiki/Quaternions#Hamilton_product
    '''
    if not isRadian:
        deg_to_rad = (np.pi) / 180.
        z, y, x = deg_to_rad * z, deg_to_rad * y, deg_to_rad * x

    # Quaternions are expressed in terms of half angles.
    half_z, half_y, half_x = z / 2.0, y / 2.0, x / 2.0

    cos_z, sin_z = math.cos(half_z), math.sin(half_z)
    cos_y, sin_y = math.cos(half_y), math.sin(half_y)
    cos_x, sin_x = math.cos(half_x), math.sin(half_x)

    q_w = cos_x*cos_y*cos_z - sin_x*sin_y*sin_z
    q_x = cos_x*sin_y*sin_z + cos_y*cos_z*sin_x
    q_y = cos_x*cos_z*sin_y - sin_x*cos_y*sin_z
    q_z = cos_x*cos_y*sin_z + sin_x*cos_z*sin_y
    return np.array([q_w, q_x, q_y, q_z])

def visualize_optical_flow(img1, img2, flow):
    """Display two frames and several views of the flow between them.

    A 2x3 matplotlib figure is shown with: both images, the flow
    magnitude, the flow x and y components, and an HSV colour coding of
    the flow (hue = direction, brightness = magnitude).

    Parameters
    ----------
    img1, img2 : objects exposing ``.cpu().numpy()``
        The two frames.  Each is moved to axis order (1, 2, 0) before
        display, so a channel-first 3-D layout is assumed -- presumably
        torch tensors of shape (C, H, W); confirm against callers.
    flow : object exposing ``.cpu().numpy()``
        Flow field with the same channel-first layout; channel 0 is
        used as the x component and channel 1 as the y component.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """
    # Kept as function-local imports, as in the original code.
    import cv2
    import matplotlib.pyplot as plt

    img1_np = img1.cpu().numpy().transpose(1, 2, 0)
    img2_np = img2.cpu().numpy().transpose(1, 2, 0)
    flow_np = flow.cpu().numpy().transpose(1, 2, 0)

    flow_x = flow_np[:, :, 0]
    flow_y = flow_np[:, :, 1]
    flow_magnitude = np.sqrt(flow_x**2 + flow_y**2)
    flow_angle = np.arctan2(flow_y, flow_x)

    hsv = np.zeros((flow_np.shape[0], flow_np.shape[1], 3), dtype=np.uint8)
    # For 8-bit images OpenCV stores hue as degrees / 2, i.e. in 0..180.
    # Scaling to 255 would wrap around the colour wheel and give
    # different directions the same colour.
    hsv[:, :, 0] = (flow_angle + np.pi) / (2 * np.pi) * 180
    hsv[:, :, 1] = 255
    # Fixed gain of 10 on the magnitude, saturating at full brightness.
    hsv[:, :, 2] = np.clip(flow_magnitude * 10, 0, 255)
    flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    def _panel(ax, data, title, cmap=None, colorbar=False):
        # Draw one image panel with a title, hidden axes and an
        # optional colorbar.
        handle = ax.imshow(data, cmap=cmap)
        ax.set_title(title, fontsize=14)
        ax.axis('off')
        if colorbar:
            plt.colorbar(handle, ax=ax, fraction=0.046, pad=0.04)

    _panel(axes[0, 0], img1_np, 'Image 1 (t)')
    _panel(axes[0, 1], img2_np, 'Image 2 (t+1)')
    _panel(axes[0, 2], flow_magnitude, 'Flow Magnitude',
           cmap='jet', colorbar=True)
    _panel(axes[1, 0], flow_x, 'Flow X Component',
           cmap='RdBu_r', colorbar=True)
    _panel(axes[1, 1], flow_y, 'Flow Y Component',
           cmap='RdBu_r', colorbar=True)
    _panel(axes[1, 2], flow_rgb,
           'Flow Color Coding\n(Hue=Direction, Brightness=Magnitude)')

    plt.tight_layout()
    plt.show()
    