import matplotlib.pyplot as plt
import numpy as np
import matplotlib.cm as cm
import matplotlib.colors as mcolors

def _to_xy(raw, label):
    """Return the x/y components of every entry in *raw* as a float array of shape (n, 2).

    The .npy inputs may be a plain numeric array or an object array of per-entry
    sequences (which is why they are loaded with allow_pickle). Converting to floats
    up front guarantees that ``start + vec`` is an element-wise sum; for list/tuple
    entries ``+`` would concatenate instead. Only the first two components are kept,
    since only x and y are plotted.

    Raises:
        ValueError: if any entry has fewer than two components.
    """
    rows = [np.asarray(item, dtype=float).ravel() for item in raw]
    if any(row.size < 2 for row in rows):
        raise ValueError(f"Every entry in {label} needs at least x and y components.")
    # reshape keeps the (0, 2) shape when there are no entries at all.
    return np.array([row[:2] for row in rows], dtype=float).reshape(-1, 2)


# Load the vector data and starting points.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, which can
# execute code -- only point these paths at trusted files.
# NOTE: the paths are resolved against the current working directory, not this script.
vectors_data = np.load("../env_vectors.npy", allow_pickle=True)
starting_points = np.load("../starting_points.npy", allow_pickle=True)

# Ensure the number of vectors matches the number of starting points
if len(vectors_data) != len(starting_points):
    raise ValueError("Number of vectors and starting points must match.")

starts = _to_xy(starting_points, "starting_points")
vecs = _to_xy(vectors_data, "vectors_data")
ends = starts + vecs

# --- Visualization using Matplotlib ---

fig, ax = plt.subplots(figsize=(9, 6))  # Slightly wider to leave room for the colorbar

# --- Color mapping setup ---
num_vectors = len(vecs)
# One discrete colour per vector. Clamping to >= 1 keeps resampled() and
# Normalize (vmax >= vmin) valid when there are no vectors at all.
n_colors = max(num_vectors, 1)
# Choose a colormap (e.g., 'viridis', 'plasma', 'inferno', 'magma', 'cividis', 'jet')
cmap = plt.colormaps['viridis'].resampled(n_colors)
# Map vector indices [0, n_colors - 1] onto [0, 1] for the colormap.
norm = mcolors.Normalize(vmin=0, vmax=n_colors - 1)
# Scalar mappable backing the colorbar.
sm = cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])  # Kept for compatibility with older Matplotlib releases

# Plot every vector from its starting point in a single call each (one artist
# for all points and one for all arrows, rather than two artists per vector).
if num_vectors:
    colors = cmap(norm(np.arange(num_vectors)))  # RGBA row per vector index
    # s is marker area in points^2, so 16 matches the former markersize=4;
    # zorder=2 keeps the start points drawn above the arrows.
    ax.scatter(starts[:, 0], starts[:, 1], c=colors, s=16, zorder=2)
    # angles/scale_units='xy' with scale=1 draws each arrow exactly from start to
    # start + vec in data coordinates (also correct once the y-axis is inverted).
    ax.quiver(starts[:, 0], starts[:, 1], vecs[:, 0], vecs[:, 1],
              angles='xy', scale_units='xy', scale=1, color=colors)

# Determine axis limits from both starting and ending points.
all_coords = np.vstack((starts, ends))

if all_coords.size == 0:
    x_min, x_max, y_min, y_max = -1, 1, -1, 1  # Default view when there is nothing to show
else:
    x_min, y_min = all_coords.min(axis=0)
    x_max, y_max = all_coords.max(axis=0)

# Pad the limits by 10% of the data range, or by 1 when the range is zero.
x_range = x_max - x_min
y_range = y_max - y_min
padding_x = x_range * 0.1 if x_range > 0 else 1
padding_y = y_range * 0.1 if y_range > 0 else 1
ax.set_xlim(x_min - padding_x, x_max + padding_x)
ax.set_ylim(y_min - padding_y, y_max + padding_y)

# Equal aspect ratio so vector directions and lengths are not distorted.
ax.set_aspect('equal', adjustable='box')

# Add grid, title, and labels
ax.grid(True)
ax.set_title("Vector Visualization from Starting Points (Color Coded by Index)")
ax.set_xlabel("X coordinate")
ax.set_ylabel("Y coordinate (Inverted)")  # Label reflects the inversion below

# Invert the Y-axis (y grows downward, e.g. image/screen coordinates).
ax.invert_yaxis()

# Add the colorbar
cbar = fig.colorbar(sm, ax=ax)
cbar.set_label('Vector Index (0 to {})'.format(num_vectors - 1))

# Show the plot
plt.show()