import cv2 as cv 
import numpy as np 
  
  
# Load two consecutive simulator frames (state 0 and state 1) from disk and
# downscale both to 64x64; the dense flow computed below has that resolution.
# NOTE(review): both frames used to be read from state0_Control.png, which
# makes the optical flow identically zero. image2 now reads
# state1_Control.png to match the 'state1.png' output name -- confirm this
# file exists in the dataset directory.
DATA_DIR = '/mnt/zhang-nas/carlq/research/ac_infer/datasets/box2d_default/0'


def _load_resized(path, size=(64, 64)):
    """Read a BGR image from *path* and resize it to *size* (width, height).

    Raises FileNotFoundError when OpenCV cannot read the file, because
    cv.imread signals failure by returning None instead of raising.
    """
    img = cv.imread(path)
    if img is None:
        raise FileNotFoundError(f'could not read image: {path}')
    return cv.resize(img, size)


image1 = _load_resized(f'{DATA_DIR}/state0_Control.png')
image2 = _load_resized(f'{DATA_DIR}/state1_Control.png')

# Save the resized inputs next to the script for visual comparison.
cv.imwrite('state0.png', image1)
cv.imwrite('state1.png', image2)

# The Farneback method works on single-channel images, so both frames are
# reduced to grayscale (luminance only) before the flow is computed.
frame = image2
prev_gray, gray = (cv.cvtColor(img, cv.COLOR_BGR2GRAY) for img in (image1, frame))

# HSV canvas used to colour-code the flow: same shape and dtype as the input
# frame, saturation pinned to its maximum (255). Hue and value are filled in
# once the flow field is known.
mask = np.zeros_like(image1)
mask[:, :, 1] = 255

# # Create a grid of points
# height, width = prev_gray.shape
# y, x = np.mgrid[0:height, 0:width].reshape(2,-1).astype(float)
# initial_points = np.vstack((x, y)).T
# # Reshape points for calcOpticalFlowPyrLK
# initial_points = initial_points.reshape(-1, 1, 2).astype(np.float32)
# print(initial_points.shape, initial_points.dtype)

# # Calculate optical flow
# next_points, status, _ = cv.calcOpticalFlowPyrLK(prev_gray, gray, initial_points, None)
# # Ensure valid points after calculation
# valid_points = next_points[status == 1]
# valid_initial_points = initial_points[status == 1]

# print(initial_points)
# print(next_points)
# flow = (next_points - initial_points).reshape(64, 64, 2)

# # Visualize the flow
# for i, (new, old) in enumerate(zip(valid_points, valid_initial_points)):
#     a, b = new.ravel()
#     c, d = old.ravel()

#     # Ensure coordinates are within image bounds and are integers
#     a, b, c, d = map(lambda x: int(min(max(x, 0), width - 1)), [a, b, c, d])
#     frame1 = cv.arrowedLine(image1, (a, b), (c, d), (0, 255, 0), 1, tipLength=1)
# # Opens a new window and displays the output frame 
# cv.imwrite("dense optical flow.png", frame1)







  

    
# Dense optical flow from prev_gray (first frame) to gray (second frame)
# using the Farneback method. Each flow[y, x] holds the (dx, dy)
# displacement of that pixel.
flow = cv.calcOpticalFlowFarneback(
    prev_gray,
    gray,
    None,
    pyr_scale=0.5,
    levels=5,
    winsize=15,
    iterations=3,
    poly_n=5,
    poly_sigma=1.2,
    flags=0,
)
print(flow.dtype)

# Polar form of every flow vector. angle is in radians because
# angleInDegrees is left at its default (False).
magnitude, angle = cv.cartToPolar(flow[..., 0], flow[..., 1])

# Hue <- flow direction: radians -> degrees, halved to fit OpenCV's 8-bit
# hue range of [0, 180).
mask[..., 0] = angle * 180 / np.pi / 2

# Value <- flow magnitude, min-max stretched to [0, 255].
mask[..., 2] = cv.normalize(magnitude, None, 0, 255, cv.NORM_MINMAX)

# HSV -> BGR for cv.imwrite (despite the name, rgb is in BGR channel order).
rgb = cv.cvtColor(mask, cv.COLOR_HSV2BGR)

# Write the colour-coded flow image to disk; no window is opened. The input
# frames are already written right after they are loaded.
cv.imwrite("dense optical flow.png", rgb)

print(flow.shape)
# Smallest and largest per-pixel flow magnitude, printed on one line.
print(magnitude.min(), magnitude.max())
    
# # Updates previous frame 
# prev_gray = gray 

# #ITR:1	Action:-0.13417 -0.61869	Control:-0.17933 -0.82693 -0.22361 -1.03115 1.0	Ball0:-3.99511 1.79323 -0.0 9.31363 1.0	Ball1:-3.71428 -0.24712 0.0183 5.36697 1.0	Poly3vert0form:1.44777 1.96946 6.60889 8.25463 0.0 1.0 0.0 0.7285	Poly4vert3form:0.0 0.0 0.0 0.0 0.0 0.0 0.0	Poly5vert7form:0.0 0.0 0.0 0.0 0.0 0.0 0.0	Target:0.1756 0.80975 0.0 0.0 1.0	Reward:0.0	Done:False	VALID_NAMES:1 1 1 1 1 0 0 1 1 1	
# # ITR:2	Action:0.625 -0.42384	Control:-0.19958 -1.01216 0.81806 -1.73756 1.0	Ball0:-3.99509 1.94846 0.0 9.31363 1.0	Ball1:-3.71398 -0.15767 0.0183 5.36697 1.0	Poly3vert0form:1.55792 2.10703 6.60889 8.25463 0.0 1.0 0.0 0.7285	Poly4vert3form:0.0 0.0 0.0 0.0 0.0 0.0 0.0	Poly5vert7form:0.0 0.0 0.0 0.0 0.0 0.0 0.0	Target:0.20949 0.96602 0.0 0.0 1.0	Reward:0.0	Done:False	VALID_NAMES:1 1 1 1 1 0 0 1 1 1	


import numpy as np
import matplotlib.pyplot as plt

# Quiver plot of the flow field: one arrow per pixel, placed on a grid that
# spans the unit square.
U = flow[:, :, 0]  # horizontal displacement (+x = right)
V = flow[:, :, 1]  # vertical displacement (+y = down, image convention)

# Grid derived from the flow's own shape instead of a hard-coded 64.
height, width = U.shape
x = np.linspace(0, 1, width)
y = np.linspace(0, 1, height)
X, Y = np.meshgrid(x, y)

fig, ax = plt.subplots(figsize=(8, 8))
# angles='xy' makes arrow directions follow the data axes. Combined with the
# inverted y-axis below, row 0 is drawn at the top and a positive V points
# down, so the plot matches the image orientation instead of being mirrored.
ax.quiver(X, Y, U, V, angles='xy', scale=50, color='blue')
ax.invert_yaxis()
ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_title('Vector Field with Flow Lines')
fig.savefig('vector_field.png')
# Release the figure's memory; it is only needed for the saved file.
plt.close(fig)