import numpy as np
import plotly.graph_objects as go
import kaleido
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
import pandas as pd
from scipy.optimize import minimize
from numpy.linalg import norm

############################################
# 3D plot of regions
############################################


# Constants shared by the level-set functions R2_eq and R3_eq
C_2 = 1
delta_0 = 0.01
# Offset that B adds to 4 and C subtracts from 4
_quarter_c2_delta = 1/4 * C_2 * delta_0
A = 8
B = 4 + _quarter_c2_delta
C = 4 - _quarter_c2_delta
# D is capped at 1
_d_uncapped = (2 - 3 * C_2 * delta_0) / (24 + 2 * C_2 * delta_0)
D = min(1, _d_uncapped)

# Define region R2
def R2_eq(x, y, z):
    """
    Level-set function used to draw region R2.

    Evaluates A * (x^2 + y^2) + B * (x^2 + y^2 + z^2) - C with the
    module-level constants A, B and C. Works element-wise on scalars or
    arrays; the plot below draws the surface where the value is 0.
    """
    planar_sq = x**2 + y**2
    return A * planar_sq + B * (planar_sq + z**2) - C

# Define region R3
def R3_eq(x, y, z):
    """
    Level-set function used to draw region R3.

    Evaluates z^2 + (r - 1)^2 - D, where r = sqrt(x^2 + y^2) and D is the
    module-level constant. Works element-wise on scalars or arrays; the
    plot below draws the surface where the value is 0.
    """
    radial_offset = np.sqrt(x**2 + y**2) - 1
    return z**2 + radial_offset**2 - D

# Sample the cube [-1.3, 1.3]^3 with 50 points along each axis
x, y, z = (np.linspace(-1.3, 1.3, 50) for _ in range(3))
X, Y, Z = np.meshgrid(x, y, z)

# Evaluate both level-set functions on the lattice
values_R2 = R2_eq(X, Y, Z)
values_R3 = R3_eq(X, Y, Z)

# Unit circle x^2 + y^2 = 1 in the plane z = 0
theta = np.linspace(0, 2 * np.pi, 100)
circle_x, circle_y = np.cos(theta), np.sin(theta)
circle_z = np.zeros_like(theta)

# The three points (0, 0, 0) and (0, 0, ±1/sqrt(2)) on the z-axis
points_x = [0] * 3
points_y = [0] * 3
points_z = [0, 1 / np.sqrt(2), -1 / np.sqrt(2)]

# Zero level sets of R2 (red) and R3 (blue) as translucent isosurfaces
isosurface_traces = [
    go.Isosurface(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=level_values.flatten(),
        isomin=0,
        isomax=0,
        surface_count=1,
        colorscale=palette,
        opacity=0.3,
        showscale=False,
        name=label,
    )
    for level_values, palette, label in (
        (values_R2, 'Reds', 'R₂'),
        (values_R3, 'Blues', 'R₃'),
    )
]

# Black circle in the z = 0 plane
minimizer_circle = go.Scatter3d(
    x=circle_x, y=circle_y, z=circle_z,
    mode='lines',
    line=dict(color='black', width=5),
    name='Global minimizers of g',
    opacity=.7,
)

# Black markers at (0,0,0), (0,0,1/sqrt(2)), (0,0,-1/sqrt(2))
critical_points = go.Scatter3d(
    x=points_x, y=points_y, z=points_z,
    mode='markers',
    marker=dict(color='black', size=2),
    name='Others critical points of g',
    opacity=.7,
)

# Marker traces with no data: they only contribute legend entries
legend_proxies = [
    go.Scatter3d(
        x=[None], y=[None], z=[None], mode='markers',
        marker=dict(color=swatch, size=10, opacity=.4),
        name=label,
    )
    for swatch, label in (('red', 'R₂'), ('blue', 'R₃'))
]

fig = go.Figure(data=[
    *isosurface_traces,
    minimizer_circle,
    critical_points,
    *legend_proxies,
])

# Axis titles, fully transparent grid lines, figure title and legend placement
hidden_grid = dict(gridcolor='rgba(0,0,0,0)')
fig.update_layout(
    scene={
        'xaxis_title': r'<x<sup>+</sup>, x<sup>#+</sup>>',
        'yaxis_title': r'<x<sup>+</sup>, x<sup>#-</sup>>',
        'zaxis_title': '||y||',
        'xaxis': dict(hidden_grid),
        'yaxis': dict(hidden_grid),
        'zaxis': dict(hidden_grid),
    },
    title="3D Plot with Points and Line in Black",
    font={'size': 10},
    legend={'x': 0.25, 'y': 1, 'traceorder': 'normal'},
)

# Save as png
fig.write_image("3D_plot.png", engine="kaleido")

############################################
# grid of norms of T-S 
############################################

# Standard deviation of the Gaussian vectors drawn in T(m, sigma, n) below;
# reassigned (to the same value) in the trajectory section further down.
sigma = 1

# Generate random unit vectors
def random_unit_vector(n):
    """
    Draw a standard-normal vector of length n and rescale it to unit
    Euclidean norm.

    Uses NumPy's global random state, so results follow np.random.seed.
    """
    sample = np.random.standard_normal(n)
    length = np.linalg.norm(sample)
    return sample / length

# Compute the contraction of a 4th-order tensor with vectors x, y, z, w
def tensor_contraction(T, x, y, z, w):
    """
    Evaluate the multilinear form T(x, y, z, w) of a 4th-order tensor.

    Returns the scalar sum over all indices of
    T[a, b, c, d] * x[a] * y[b] * z[c] * w[d].
    """
    # One index letter per tensor axis, each paired with one vector;
    # the empty output spec reduces everything to a scalar.
    subscripts = 'abcd,a,b,c,d->'
    return np.einsum(subscripts, T, x, y, z, w)

# Objective function for minimization
def objective_function(v, T, n):
    """
    Objective function to minimize (negative contraction).

    v is a flat vector holding x, y, z, w back to back; it is split into
    four equal parts, each part is rescaled to unit Euclidean norm, and the
    negated value of T(x, y, z, w) is returned so that minimizing this
    function maximizes the contraction over unit vectors.

    v itself is left unmodified.

    n is not used by the computation; it is accepted so the signature
    matches the extra arguments passed by operator_norm.
    """
    # np.split returns views into v, so an in-place `x /= norm(x)` would
    # overwrite the caller's array (and fail on integer input). Divide
    # out-of-place instead.
    x, y, z, w = (part / norm(part) for part in np.split(v, 4))
    # Return the negative contraction (because we are maximizing)
    return -tensor_contraction(T, x, y, z, w)

# Approximate the operator norm of the 4th-order tensor
def operator_norm(T, n):
    """
    Approximate the operator norm of a 4th-order tensor T by numerically
    maximizing T(x, y, z, w) over unit vectors.

    The four vectors have length 2 * n, so T is expected to have shape
    (2n, 2n, 2n, 2n), as produced by T(m, sigma, n) and S(n).

    A single BFGS run is started from four random unit vectors, and the
    contraction value it ends on is returned. The result therefore depends
    on NumPy's global random state.
    """
    dim = 2 * n
    # Random starting point: four unit vectors stacked into one flat array
    starts = [random_unit_vector(dim) for _ in range(4)]
    v0 = np.hstack(starts)

    # Minimizing the negated contraction maximizes the contraction
    result = minimize(objective_function, v0, args=(T, dim), method='BFGS')

    # Flip the sign back to obtain the maximal contraction found
    return -result.fun

# Tensor construction (T and S)
def T(m, sigma, n):
    """
    Build the empirical 4th-order moment tensor of m Gaussian vectors.

    Draws m vectors a with i.i.d. N(0, sigma^2) entries of length 2 * n
    (from NumPy's global random state), sums the rank-one tensors
    a ⊗ a ⊗ a ⊗ a, and divides the sum by m * sigma^4.

    Returns an array of shape (2n, 2n, 2n, 2n).
    """
    dim = 2 * n
    total = np.zeros((dim, dim, dim, dim))
    for _ in range(m):
        a = np.random.normal(0, sigma, dim)
        # Grow the outer product one axis at a time so that entry
        # [i1, i2, i3, i4] equals a[i1] * a[i2] * a[i3] * a[i4]
        pair = np.multiply.outer(a, a)
        triple = np.multiply.outer(pair, a)
        total += np.multiply.outer(triple, a)
    return total / (m * sigma ** 4)

# Construct S tensor
def S(n):
    """
    Build the 4th-order tensor of shape (2n, 2n, 2n, 2n) whose entry
    [i1, i2, i3, i4] is

        1_(i1=i2, i3=i4) + 1_(i1=i3, i2=i4) + 1_(i1=i4, i2=i3).
    """
    eye = np.eye(2 * n)
    # One term per way of pairing up the four indices
    pair_12_34 = np.einsum('ab,cd->abcd', eye, eye)
    pair_13_24 = np.einsum('ac,bd->abcd', eye, eye)
    pair_14_23 = np.einsum('ad,bc->abcd', eye, eye)
    return pair_12_34 + pair_13_24 + pair_14_23

# Define the different values of m and n to compute
m_values = range(2000, 22000, 2000)
n_values = range(1,9)
# Number of random repetitions averaged for each (m, n) cell
K=5
# Create an empty list to store results
results_list = []

# Compute ||T - S|| for different m and n using the operator norm approximation
for m in m_values:
    for n in n_values:
        print(f"Computing for m={m}, n={n}...")
        # S depends only on n, so build it once rather than once per repetition
        S_val = S(n)
        score=0
        for k in range(K):
            # Fresh random tensor for every repetition
            T_val = T(m, sigma, n)
            # Compute the operator norm of T - S
            diff_tensor = T_val - S_val
            score += operator_norm(diff_tensor, n)

        # Append the average over the K repetitions to the list
        results_list.append({"m": m, "n": n, "||T-S||": score/K})

# Convert the list to a DataFrame
results = pd.DataFrame(results_list)
# Pivot the data to have 'm' as columns and 'n' as rows
pivot_table = results.pivot(index="n", columns="m", values="||T-S||")
# sort values of n by descending order
pivot_table = pivot_table.sort_index(ascending=False)
# Plot the heatmap using Seaborn
plt.figure(figsize=(8, 6))
sns.heatmap(pivot_table, annot=True, cmap="coolwarm", fmt=".1f")
plt.title("||T - S|| for different values of m and n")

# save as png
plt.savefig("heatmap.png")

############################################
# Plot of trajectories
############################################

# Experiment settings for the trajectory plots
n = 4                                # vectors have length 2 * n
m_list = [10, 20, 30, 40]            # one figure per entry
learning_rate_list = [.00005] * 4    # step size paired with each m
n_step_list = [5000] * 4             # iteration count paired with each m
sigma = 1
K = 20  # Number of independent gradient descent instances
# Seed NumPy's global random state for reproducibility
np.random.seed(2)

# Function to construct vectors and matrices
def construct_A_i(n, sigma):
    """
    Draw one random (2n x 2n) matrix A_i = a⁺ a⁺ᵀ + a⁻ a⁻ᵀ.

    a⁺ has i.i.d. N(0, sigma^2) entries (from NumPy's global random state).
    a⁻ is built from a⁺ by swapping its two halves and negating the half
    that moves to the front: a⁻ = (-a⁺[n:], a⁺[:n]).
    """
    plus = np.random.normal(0, sigma, 2 * n)
    head, tail = plus[:n], plus[n:]
    minus = np.concatenate([-tail, head])
    return np.outer(plus, plus) + np.outer(minus, minus)

# Function f(x, x_0, a_i)
def f(x, x_0, A_i_list):
    """
    Mean squared mismatch between the quadratic forms of x and of x_0.

    Returns (1/m) * sum_i (xᵀ A_i x - x_0ᵀ A_i x_0)^2, where m is the
    number of matrices in A_i_list.
    """
    m = len(A_i_list)
    squared_residuals = (
        (x.T.dot(A_i.dot(x)) - x_0.T.dot(A_i.dot(x_0))) ** 2
        for A_i in A_i_list
    )
    return sum(squared_residuals) / m

# Gradient of the function f with respect to x
def gradient_f(x, x_0, A_i_list):
    """
    Gradient with respect to x of
    f(x) = (1/m) * sum_i (xᵀ A_i x - x_0ᵀ A_i x_0)^2.

    By the chain rule each term contributes
    2 * (xᵀ A_i x - x_0ᵀ A_i x_0) * (A_i + A_i^T) x / m,
    which reduces to 4 * residual * A_i x / m for symmetric A_i.

    NOTE: the previous implementation used A_i x in place of
    (A_i + A_i^T) x and therefore returned half of the true gradient for
    symmetric A_i. Step sizes tuned against that version correspond to
    half their value here.

    Returns a float array with the same shape as x.
    """
    m = len(A_i_list)
    # Float accumulator so integer-valued x does not break the in-place add
    grad = np.zeros_like(x, dtype=float)
    for A_i in A_i_list:
        residual = np.dot(x.T, np.dot(A_i, x)) - np.dot(x_0.T, np.dot(A_i, x_0))
        # d/dx (xᵀ A x) = (A + Aᵀ) x, valid for symmetric and non-symmetric A
        grad += 2 * residual * np.dot(A_i + A_i.T, x) / m
    return grad

# Gradient descent function
def gradient_descent(x_0,x, A_i_list, n_step, learning_rate):
    """
    Run n_step iterations of fixed-step gradient descent on f(., x_0, A_i_list).

    x is the starting iterate and is UPDATED IN PLACE: when the function
    returns, the caller's array holds the final iterate. Pass x.copy() to
    keep the starting point.

    Returns the list of loss values, one per iteration, each recorded
    before that iteration's update.
    """
    losses = []
    for _ in range(n_step):
        losses.append(f(x, x_0, A_i_list))
        # In-place step: this is the only way the final iterate reaches the caller
        x -= learning_rate * gradient_f(x, x_0, A_i_list)
    return losses

# Main loop for each m
for m,n_step, learning_rate in zip(m_list, n_step_list, learning_rate_list):
    plt.figure()
    # One set of matrices and one starting point per value of m
    A_i_list = [construct_A_i(n, sigma) for _ in range(m)]
    x=np.random.normal(0, 1, 2 * n)
    for k in tqdm(range(K)):
        # New reference vector for every run
        x_0=np.random.uniform(-10, 10, 2 * n)
        # gradient_descent updates its iterate in place. Hand it a copy so
        # every run starts from the same initial point instead of from
        # wherever the previous run ended.
        losses = gradient_descent(x_0, x.copy(), A_i_list, n_step, learning_rate)
        plt.plot(range(n_step), losses)
    plt.ylim(bottom=1e-20)
    plt.yscale('log')
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.title(f'Loss vs Iterations for m={m}, n={n}')
    # No legend: the curves carry no labels, so plt.legend() would only
    # emit a "no artists with labels" warning.
    # Save the plot 
    plt.savefig(f"loss_vs_iterations_m_{m}.png")
