import numpy as np
import matplotlib.pyplot as plt
import sympy as sp

# Toggles for optional arrows in the 2D figure:
# the negative inverse gradient (on) and the direction parallel to the tangent line (off).
GRAD_INV_PLOT, PARALLEL_PLOT = True, False

# Symbolic coordinates of the input plane.
y1_sym, y2_sym = sp.symbols('y1 y2')

# Surface c = f(y1, y2), given as a string so it is easy to swap for a custom one.
# Other expressions tried: 'y1**3 + y2**2 - 1',
# 'y1**3 - y2**3 + y1**2 * y2 - y2**2 + 0.5'
function_expr = 'y1**2 + y2**2 - 1'
function_sym = sp.sympify(function_expr)

# Gradient components df/dy1 and df/dy2, derived symbolically.
df_dy1_sym, df_dy2_sym = (sp.diff(function_sym, var) for var in (y1_sym, y2_sym))

# NumPy-backed callables taking (y1, y2): the surface itself and its two partials.
surface_function, df_dy1_func, df_dy2_func = (
    sp.lambdify((y1_sym, y2_sym), expr, 'numpy')
    for expr in (function_sym, df_dy1_sym, df_dy2_sym)
)

# Sample the (y1, y2) plane on a square grid of half-width plot_range.
# Widen plot_range if the tangent point or the tangent-line intersection
# falls outside the plotted square.
plot_range = 5
y1_range, y2_range = (np.linspace(-plot_range, plot_range, 1000) for _ in range(2))
y1_grid, y2_grid = np.meshgrid(y1_range, y2_range)
c_grid = surface_function(y1_grid, y2_grid)

# Point at which the surface is linearised; keep it inside the plotted square.
y1_tangent, y2_tangent = -0.002, 0.1
c_tangent = surface_function(y1_tangent, y2_tangent)

# Gradient (df/dy1, df/dy2) of the surface at the tangent point.
dy1, dy2 = (deriv(y1_tangent, y2_tangent) for deriv in (df_dy1_func, df_dy2_func))

# First-order (tangent-plane) approximation of the surface.
def tangent_plane(y1, y2, *, point=None, value=None, gradient=None):
    """Evaluate the tangent plane ``value + g1*(y1 - p1) + g2*(y2 - p2)``.

    Works on scalars and, element-wise, on NumPy arrays (the script calls it
    on the whole sampling grid).

    Args:
        y1, y2: Coordinates to evaluate at (scalars or broadcast-compatible arrays).
        point: Expansion point ``(p1, p2)``. Defaults to the module-level
            ``(y1_tangent, y2_tangent)``, looked up at call time.
        value: Surface value at ``point``. Defaults to the module-level ``c_tangent``.
        gradient: Gradient ``(g1, g2)`` at ``point``. Defaults to the
            module-level ``(dy1, dy2)``.

    Returns:
        The height of the tangent plane at ``(y1, y2)``.
    """
    # Defaults are resolved lazily so the function keeps reading the current
    # module-level tangent data, exactly like the original closure-over-globals.
    if point is None:
        point = (y1_tangent, y2_tangent)
    if value is None:
        value = c_tangent
    if gradient is None:
        gradient = (dy1, dy2)
    p1, p2 = point
    g1, g2 = gradient
    return value + g1 * (y1 - p1) + g2 * (y2 - p2)

# Evaluate the linearisation over the whole sampling grid.
c_tangent_plane = tangent_plane(y1_grid, y2_grid)

# Approximate each c = 0 level set by the grid points whose |c| is below a
# tolerance. This yields a band of points rather than an exact curve; tune
# `threshold` to make the band thinner or thicker.
threshold = 0.05
mask_surface = np.abs(c_grid) < threshold
mask_tangent_plane = np.abs(c_tangent_plane) < threshold

# Coordinates and c-values of the selected points (flattened, row-major order).
y1_intersect_surface, y2_intersect_surface, c_intersect_surface = (
    grid[mask_surface] for grid in (y1_grid, y2_grid, c_grid)
)
y1_intersect_tangent, y2_intersect_tangent, c_intersect_tangent = (
    grid[mask_tangent_plane] for grid in (y1_grid, y2_grid, c_tangent_plane)
)

# 3D figure: the surface c = f(y1, y2), its tangent plane, both c = 0 level
# sets, the tangent point and the same (y1, y2) location at c = 0.
fig = plt.figure(figsize=(12, 9))
ax = fig.add_subplot(111, projection='3d')

# Original surface and tangent plane.
surface = ax.plot_surface(y1_grid, y2_grid, c_grid, alpha=0.7, cmap='viridis', edgecolor='none')
tangent_surface = ax.plot_surface(y1_grid, y2_grid, c_tangent_plane, alpha=0.5, color='orange', edgecolor='none')

# The intersection points are boolean-mask selections: an unordered band of
# grid points in row-major order. Joining them with a line draws zig-zag
# stripes back and forth across the figure, so plot them as markers instead.
ax.plot(y1_intersect_surface, y2_intersect_surface, c_intersect_surface,
        color='darkgreen', linestyle='none', marker='.', markersize=2,
        label='Surface Intersection with c=0')
ax.plot(y1_intersect_tangent, y2_intersect_tangent, c_intersect_tangent,
        color='darkorange', linestyle='none', marker='.', markersize=2,
        label='Tangent Plane Intersection with c=0')

# Tangent point on the surface (red) and the same (y1, y2) at c = 0 (blue).
# Drawn after the surfaces in an attempt to keep them on top.
ax.scatter(y1_tangent, y2_tangent, c_tangent, color='red', s=100, label='Tangent Point',
           edgecolors='black', zorder=5)
ax.scatter(y1_tangent, y2_tangent, 0, color='blue', s=100, label='Original prediction',
           edgecolors='black', zorder=5)

ax.set_xlabel("y1")
ax.set_ylabel("y2")
ax.set_zlabel("c")
# NOTE(review): the title mentions an inverse gradient vector, but none is
# drawn in this 3D view.
ax.set_title("3D Surface with Tangent Plane, Intersection Curves, Tangent Point, and Inverse Gradient Vector")

# Clip to the sampled square and to the surface's value range.
ax.set_xlim(-plot_range, plot_range)
ax.set_ylim(-plot_range, plot_range)
ax.set_zlim(np.min(c_grid), np.max(c_grid))

ax.view_init(elev=30, azim=-60)
ax.legend()

# Direction vectors for the 2D figure.
# The tangent plane crosses c = 0 along the line dy1 * y1 + dy2 * y2 = k, and
# the gradient (dy1, dy2) is normal to that line.
grad_vector = np.array([dy1, dy2])
grad_vector_length = np.linalg.norm(grad_vector)
if grad_vector_length == 0:
    # Every direction below is built from the gradient. At a stationary point
    # (e.g. the origin for the default circle expression) the tangent plane is
    # horizontal, so there is no single c = 0 line. Without this check the
    # code would produce NaN vectors and then divide by zero.
    raise ValueError(
        "The gradient of the surface vanishes at the tangent point "
        f"({y1_tangent}, {y2_tangent}); choose a different tangent point."
    )
grad_vector_normalized = grad_vector / grad_vector_length

# Direction along the tangent line: the gradient rotated by 90 degrees.
parallel_vector = np.array([-dy2, dy1])
parallel_vector_normalized = parallel_vector / np.linalg.norm(parallel_vector)

# Right-hand side k of the line, obtained by setting tangent_plane(y1, y2) = 0.
k = -c_tangent + dy1 * y1_tangent + dy2 * y2_tangent

# Any point on the line: fix one coordinate and solve for the other, using a
# coefficient that is non-zero (the guard above ensures one exists).
if dy2 != 0:
    y1_intersect = y1_tangent + 1  # Arbitrary choice
    y2_intersect = (k - dy1 * y1_intersect) / dy2
else:
    y2_intersect = y2_tangent + 1
    y1_intersect = (k - dy2 * y2_intersect) / dy1

x0 = np.array([y1_tangent, y2_tangent])  # Tangent point
x1 = np.array([y1_intersect, y2_intersect])  # Point on the intersection line

vector_to_intersection = x1 - x0

# Orient the vectors toward the line. For the gradient this is independent of
# which x1 was picked, because x1 - x0 differs from the perpendicular offset
# only by a component along the line.
# NOTE: for the parallel and inverse-gradient vectors, the resulting sign does
# depend on the arbitrary choice of x1.
if np.dot(parallel_vector_normalized, vector_to_intersection) < 0:
    parallel_vector_normalized = -parallel_vector_normalized

if np.dot(grad_vector_normalized, vector_to_intersection) < 0:
    grad_vector_normalized = -grad_vector_normalized

# The orthogonal direction to the line is the (oriented) unit gradient.
orthogonal_vector_normalized = grad_vector_normalized

# Negative element-wise reciprocal of the gradient. epsilon avoids dividing by
# an exactly-zero component.
epsilon = 1e-7
inv_grad_vector = -1 / (grad_vector + epsilon)
inv_grad_vector_normalized = inv_grad_vector / np.linalg.norm(inv_grad_vector)

if np.dot(inv_grad_vector_normalized, vector_to_intersection) < 0:
    inv_grad_vector_normalized = -inv_grad_vector_normalized

# Arrow lengths used in the 2D plot.
vector_scale = 1
parallel_vector_scaled = parallel_vector_normalized * vector_scale
orthogonal_vector_scaled = orthogonal_vector_normalized * vector_scale
inv_grad_vector_scaled = inv_grad_vector_normalized * vector_scale

# 2D figure: the c = 0 level sets seen from above, plus the direction arrows
# drawn from the tangent point.
fig2, ax2 = plt.subplots(figsize=(8, 8))

# Level-set points are unordered mask selections (row-major grid order).
# Draw them as markers: joining them with a line produces zig-zag stripes
# instead of a curve.
ax2.plot(y1_intersect_surface, y2_intersect_surface, color='darkgreen', linestyle='none',
         marker='.', markersize=2, label='Surface Intersection with c=0')
ax2.plot(y1_intersect_tangent, y2_intersect_tangent, color='darkorange', linestyle='none',
         marker='.', markersize=2, label='Tangent Plane Intersection with c=0')

# NOTE: both points share the same (y1, y2), so in this view the blue marker
# is drawn over the red one.
ax2.scatter(y1_tangent, y2_tangent, color='red', s=50, label='Tangent Point', edgecolors='black', zorder=5)
ax2.scatter(y1_tangent, y2_tangent, color='blue', s=50, label='Original Prediction', edgecolors='black', zorder=5)

# Unit gradient of c, oriented toward the tangent line.
ax2.quiver(y1_tangent, y2_tangent,
           grad_vector_normalized[0], grad_vector_normalized[1],
           color='red', angles='xy', scale_units='xy', scale=1, linewidth=1, label='Gradient Vector of c')

if GRAD_INV_PLOT:
    ax2.quiver(y1_tangent, y2_tangent,
               inv_grad_vector_scaled[0], inv_grad_vector_scaled[1],
               color='blue', angles='xy', scale_units='xy', linewidth=1, label='Negative Inverse Gradient Vector', scale=1)

# Shortest (Euclidean) direction to the tangent line.
ax2.quiver(y1_tangent, y2_tangent,
           orthogonal_vector_scaled[0], orthogonal_vector_scaled[1],
           color='purple', angles='xy', scale_units='xy', scale=1, linewidth=2, label='Orthogonal Projection to Intersection')

if PARALLEL_PLOT:
    ax2.quiver(y1_tangent, y2_tangent,
               parallel_vector_scaled[0], parallel_vector_scaled[1],
               color='orange', angles='xy', scale_units='xy', scale=1, linewidth=2, label='Parallel to Intersection')

# Weighted projection: find the point x_p = x1 + t*u on the tangent line that
# minimises (x - x0)^T W (x - x0), with W = diag(1 / |gradient|). Moving along
# an axis with a small gradient component is therefore penalised more.
dy = np.array([dy1, dy2])
inv_dy_abs = 1 / np.abs(dy + epsilon)
W = np.diag(inv_dy_abs)

u = parallel_vector_normalized
delta_x = x0 - x1
numerator = np.dot(u.T, W @ delta_x)
denominator = np.dot(u.T, W @ u)  # > 0: W is positive diagonal and u is a unit vector
t = numerator / denominator

# x_p is identical for (t, u) and (-t, -u), so no sign adjustment is needed.
# The arrow direction comes from d = x_p - x0.
x_p = x1 + t * u
d = x_p - x0
d_normalized = d / np.linalg.norm(d)
d_scaled = d_normalized * vector_scale

ax2.quiver(y1_tangent, y2_tangent,
           d_scaled[0], d_scaled[1],
           color='green', angles='xy', scale_units='xy', scale=1, linewidth=2, label='Weighted Projection Direction')

ax2.set_xlabel("y1")
ax2.set_ylabel("y2")
ax2.set_title("Intersection Curves on y1-y2 Plane (c=0) with Tangent Point and Vectors")
ax2.set_aspect('equal')
ax2.legend()

plt.show()