# Code that generates Figure 2a of the paper.

import matplotlib
matplotlib.use('Agg')
import jax
import os
from jax import lax
import jax.numpy as jnp
import jax.random as jrandom
from generate_jax import generate_points_mixture_jax
from generate_centroid_manif import generate_centroid_manifold
import time
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns

# Problem constants ----------------------------------------------------------

# Ambient dimension of the data / centroid vectors.
dim = 5

# Ground-truth centroids the training should recover (up to sign):
# z0 is the last canonical basis vector, z1 is minus the first one.
z0 = jnp.eye(dim)[dim - 1]
z1 = jnp.zeros(dim).at[0].set(-1)

# Gaussian-mixture parameters: L is used as the 2/L normalisation in the
# attention head; sigma and sigma1 are the two noise levels compared.
L = 30
sigma = 0.3
sigma1 = 1

# Linear attention head
def H_lin(mu, lambda_val, points):
    """Output of a linear attention head for the query token ``points[0]``.

    Args:
        mu: direction vector; its length must match the columns of ``points``.
        lambda_val: scalar gain applied to the head output.
        points: 2-D array whose rows are tokens; row 0 is the query X1
            (presumably shape (tokens, dim) -- TODO confirm against the
            data generator).

    Returns:
        The vector ``(2 / L) * lambda_val * <mu, X1> * sum_j <mu, X_j> X_j``,
        where ``L`` is the module-level constant.
    """
    query = points[0, :]
    query_score = jnp.dot(mu, query)
    # Per-token projections onto mu, then the projection-weighted token sum.
    token_scores = points @ mu
    weighted_tokens = points.T @ token_scores
    return (2/L) * lambda_val * query_score * weighted_tokens

# Risk functions
def single_sample_risk(m0, m1, lambda_val, reg3, points):
    """Risk of one sample: query reconstruction error plus a product penalty.

    The outputs of the two linear heads (directions ``m0`` and ``m1``) are
    summed and compared with the query token ``X1 = points[0]``.

    Args:
        m0, m1: the two centroid directions.
        lambda_val: gain passed to both heads.
        reg3: weight of the penalty ``<m0, X1>^2 * <m1, X1>^2``.
        points: one sample (2-D token array, row 0 is the query).

    Returns:
        ``||X1 - (H(m0) + H(m1))||^2 + reg3 * <m0, X1>^2 * <m1, X1>^2``.
    """
    query = points[0, :]
    prediction = H_lin(m0, lambda_val, points) + H_lin(m1, lambda_val, points)
    residual = query - prediction
    penalty = (jnp.dot(m0, query))**2 * (jnp.dot(m1, query))**2
    return jnp.linalg.norm(residual)**2 + reg3 * penalty

def full_risk(m0, m1, lambda_val, reg3, batch):
    """Empirical risk: mean of ``single_sample_risk`` over the leading axis of ``batch``."""
    # Broadcast the parameters; map only over the samples (axis 0 of batch).
    per_sample_risk = jax.vmap(single_sample_risk, in_axes=(None, None, None, None, 0))
    risks = per_sample_risk(m0, m1, lambda_val, reg3, batch)
    return jnp.mean(risks)

@jax.jit
def loss(params, lambda_val, reg3, batch):
    """JIT-compiled empirical risk for the centroid pair ``params = (mu0, mu1)``."""
    first_centroid, second_centroid = params
    return full_risk(first_centroid, second_centroid, lambda_val, reg3, batch)

# Projected gradient descent update step
@jax.jit
def update(params, reg3, lr, batch, iter, sigma):
    """One projected gradient step on both centroids, followed by renormalisation.

    Args:
        params: pair ``(mu0, mu1)`` of centroid vectors.
        reg3: regularisation weight forwarded to ``loss``.
        lr: step size.
        batch: batch of samples forwarded to ``loss``.
        iter: unused by the body; kept so existing call sites still work.
        sigma: noise level; selects the head gain (0.2 if ``sigma > 0.5``,
            else 0.6).

    Returns:
        ``(new_params, loss_val, grads)``. ``new_params`` is a tuple of
        unit-norm vectors, ``loss_val`` is the loss at the incoming ``params``,
        and ``grads`` are the raw (unprojected) gradients.
    """
    lambda_v = lax.select(sigma > 0.5, 0.2, 0.6)
    loss_val, grads = jax.value_and_grad(loss)(params, lambda_v, reg3, batch)

    def _step(mu, g):
        # I - mu mu^T removes the component along mu (an exact projector when
        # mu is unit-norm); step along the projected gradient, then normalise.
        projector = jnp.eye(dim) - jnp.outer(mu, mu)
        moved = mu - lr * projector @ g
        return moved / jnp.linalg.norm(moved)

    mu0, mu1 = params
    new_params = (_step(mu0, grads[0]), _step(mu1, grads[1]))
    return new_params, loss_val, grads

# Logging structure: one entry per (run, regularization value) pair
logs = {
    "Run": [],
    "Regularization": [],
    "Dist sigma03": [],
    "Dist sigma1": []
}

# Experiment configuration
n_runs = 10
key = jrandom.PRNGKey(0)
run_subkeys = jrandom.split(key, n_runs)
num_iterations = 5001
batch_size = 256
lr = 0.01
grid_reg3v = 30
reg3v = jnp.linspace(0, 3, grid_reg3v)


def _sign_invariant_distance(mus):
    """Distance from an estimated centroid pair to the ground truth ``(z0, z1)``.

    Each estimate is matched to a target up to a global sign (``v`` and ``-v``
    count as equal). Both assignments of estimates to targets are tried, and
    the smaller total squared error is kept.

    Args:
        mus: pair ``(mu_a, mu_b)`` of estimated centroid vectors.

    Returns:
        Scalar array: ``sqrt(min over assignments of the summed squared errors)``.
    """
    def _sq_err_up_to_sign(v, z):
        return jnp.minimum(jnp.linalg.norm(v - z)**2, jnp.linalg.norm(v + z)**2)

    mu_a, mu_b = mus
    direct = _sq_err_up_to_sign(mu_a, z0) + _sq_err_up_to_sign(mu_b, z1)
    swapped = _sq_err_up_to_sign(mu_a, z1) + _sq_err_up_to_sign(mu_b, z0)
    return jnp.sqrt(jnp.minimum(direct, swapped))


# Outer loop over random initializations
for run in range(n_runs):
    print(f'Run {run+1}/{n_runs} of initial points')

    key, subkey0, subkey1 = jrandom.split(run_subkeys[run], 3)

    # Initialize orthonormal mu0, mu1: random unit mu0, then Gram-Schmidt for mu1
    mu0 = jrandom.normal(subkey0, shape=(dim,))
    mu0 = mu0 / jnp.linalg.norm(mu0)
    mu1 = jrandom.normal(subkey1, shape=(dim,))
    mu1 = mu1 / jnp.linalg.norm(mu1)
    mu1 = mu1 - jnp.dot(mu1, mu0) * mu0
    mu1 = mu1 / jnp.linalg.norm(mu1)

    print(f'Initial mu0: {mu0}')
    print(f'Initial mu1: {mu1}')

    # Params for two experiments (sigma=0.3 and sigma=1), same initialization
    params = (mu0, mu1)
    params1 = (mu0, mu1)

    # Inner loop over regularization values.
    # NOTE(review): params/params1 are initialised once per run and are NOT
    # reset here, so each regularization value warm-starts from the previous
    # value's final estimate. Confirm this continuation is intended.
    for k, reg in enumerate(reg3v):
        print(f'Run {k+1}/{grid_reg3v} of the algorithm')
        for i in range(num_iterations):
            key, subkey = jrandom.split(key)
            # The same subkey is used for both noise levels; presumably
            # intentional, so the two datasets share randomness -- TODO confirm.
            dataset = generate_points_mixture_jax(L, sigma, dim, batch_size, subkey)
            dataset1 = generate_points_mixture_jax(L, sigma1, dim, batch_size, subkey)

            params, current_loss, _ = update(params, reg, lr, dataset, i, sigma)
            params1, current_loss1, _ = update(params1, reg, lr, dataset1, i, sigma1)

            if i % 1000 == 0:
                print(f"Iteration {i}, Loss on the batch: {current_loss}, {current_loss1}")

        # Distance to ground-truth centroids (up to sign and permutation)
        dis = _sign_invariant_distance(params)
        dissigma = _sign_invariant_distance(params1)

        # Log results
        logs["Run"].append(run + 1)
        logs["Regularization"].append(float(reg))
        logs["Dist sigma03"].append(float(dis))
        logs["Dist sigma1"].append(float(dissigma))

# Format results for plotting: long format, one row per (regularization, sigma)
# measurement. Repeated Regularization values (one per run) are aggregated by
# seaborn's lineplot.
df = pd.DataFrame(logs)
df_long = df.melt(
    id_vars="Regularization",
    value_vars=["Dist sigma03", "Dist sigma1"],
    var_name="Sigma",
    value_name="Distance"
)

# Use LaTeX-style labels for sigmas
sigma_map = {
    "Dist sigma03": r"$\sigma=0.3$",
    "Dist sigma1": r"$\sigma=1$"
}
df_long["Sigma"] = df_long["Sigma"].map(sigma_map)

# Plotting
plt.figure(figsize=(35, 20))
sns.set_context("notebook", font_scale=8.5)

ax = sns.lineplot(
    data=df_long,
    x="Regularization",
    y="Distance",
    hue="Sigma",
    linewidth=3.0,
    palette=["blue", "orange"]
)

ax.set_ylabel("Distance to centroid \n (up to a sign)")
ax.set_yscale('log')
ax.grid(True)
# Compute the layout only after the axes are fully configured: the log scale
# changes the tick labels that tight_layout measures.
plt.tight_layout()

# Save plot
os.makedirs("main_experiments/results", exist_ok=True)
plt.savefig("main_experiments/results/plot_linear_reg_0.3and1.pdf", format="pdf", bbox_inches="tight")
plt.close()
