#Code that generates Figure 12 of the paper

import matplotlib
matplotlib.use('Agg')
import jax
from jax import lax
import jax.numpy as jnp
import jax.random as jrandom
from generate_jax_generic import generate_points_mixture_jax
from generate_centroid_manif import generate_centroid_manifold
import os
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns

# Define constants of the problem


# L is passed as the first argument of generate_points_mixture_jax and sets the
# 2 / L scaling inside H_lin; presumably the number of points (tokens) drawn
# per sample -- TODO confirm against generate_points_mixture_jax.
L = 30

# Linear attention head
def H_lin(mu, lambda_val, points):
    """Evaluate one linear attention head on a single sample.

    Args:
        mu: Head direction, a 1-D vector whose length matches the columns
            of ``points``.
        lambda_val: Scalar multiplier applied to the head output.
        points: 2-D array with one point per row; row 0 acts as the query.

    Returns:
        A 1-D vector equal to
        ``(2 / L) * lambda_val * <mu, points[0]> * points.T @ (points @ mu)``,
        where ``L`` is the module-level constant.
    """
    query_score = jnp.dot(mu, points[0])
    # Projection of every point onto mu, then the projection-weighted sum
    # of the points themselves.
    projections = points @ mu
    weighted_points = points.T @ projections
    return (2 / L) * lambda_val * query_score * weighted_points

# Risk definitions
def single_sample_risk(m0, m1, lambda_val, reg3, points):
    """Return the regularised squared reconstruction error for one sample.

    The two heads ``m0`` and ``m1`` are evaluated with ``H_lin`` and summed;
    the result is compared against the first point of the sample.

    Args:
        m0: Direction of the first head.
        m1: Direction of the second head.
        lambda_val: Scalar multiplier forwarded to both heads.
        reg3: Weight of the penalty ``<m0, x>^2 * <m1, x>^2`` where ``x`` is
            the first point.
        points: 2-D array holding the points of one sample, one per row.

    Returns:
        Scalar ``||x - (H0 + H1)||^2 + reg3 * <m0, x>^2 * <m1, x>^2``.
    """
    head_output = H_lin(m0, lambda_val, points) + H_lin(m1, lambda_val, points)
    target = points[0]
    residual = target - head_output
    reconstruction_error = jnp.linalg.norm(residual) ** 2
    alignment_penalty = reg3 * (jnp.dot(m0, target) ** 2) * (jnp.dot(m1, target) ** 2)
    return reconstruction_error + alignment_penalty

def full_risk(m0, m1, lambda_val, reg3, batch):
    """Average ``single_sample_risk`` over the leading axis of ``batch``.

    Only ``batch`` is mapped over; the head directions and the two scalars
    are shared by every sample.
    """
    shared_then_mapped = (None, None, None, None, 0)
    batched_risk = jax.vmap(single_sample_risk, in_axes=shared_then_mapped)
    per_sample_risk = batched_risk(m0, m1, lambda_val, reg3, batch)
    return jnp.mean(per_sample_risk)

@jax.jit
def loss(params, lambda_val, reg3, batch):
    """Jitted batch risk taking both head directions as one ``params`` pair.

    Args:
        params: Pair ``(mu0, mu1)`` of head directions; keeping them in a
            single argument lets ``jax.grad`` differentiate with respect to
            both at once.
        lambda_val: Scalar multiplier forwarded to the heads.
        reg3: Weight of the penalty term in ``single_sample_risk``.
        batch: Samples stacked along the leading axis.

    Returns:
        Scalar mean risk over ``batch``.
    """
    first_head, second_head = params
    return full_risk(first_head, second_head, lambda_val, reg3, batch)

# Riemannian SGD update
def update(params, reg3, lr, batch,sigma):
    """Take one Riemannian SGD step on the unit sphere for both heads.

    The Euclidean gradient of ``loss`` is multiplied by ``I - mu mu^T`` (the
    tangent-space projector at ``mu`` when ``mu`` has unit norm), a gradient
    step of size ``lr`` is taken, and the result is rescaled to unit norm.

    Args:
        params: Pair ``(mu0, mu1)`` of 1-D head directions of equal length.
        reg3: Weight of the penalty term, forwarded to ``loss``.
        lr: Step size.
        batch: Samples stacked along the leading axis, forwarded to ``loss``.
        sigma: Noise level of the data; it only selects ``lambda_v``
            (0.2 when ``sigma > 0.5``, otherwise 0.6).

    Returns:
        ``((new_mu0, new_mu1), loss_val, grads)`` where ``loss_val`` and
        ``grads`` are evaluated at the incoming ``params``.
    """
    lambda_v = lax.select(sigma > 0.5, 0.2, 0.6) #Choice of lambda_v
    loss_val, grads = jax.value_and_grad(loss)(params, lambda_v, reg3, batch)
    mu0, mu1 = params

    # Take the ambient dimension from the parameters instead of a module
    # global: the function then works wherever it is called from, and the
    # traced constant can never disagree with the shape of ``params``.
    identity = jnp.eye(mu0.shape[0])
    M0 = identity - jnp.outer(mu0, mu0)
    M1 = identity - jnp.outer(mu1, mu1)
    t0 = mu0 - lr * M0 @ grads[0]
    t1 = mu1 - lr * M1 @ grads[1]

    # Normalize
    t0 = t0 / jnp.linalg.norm(t0)
    t1 = t1 / jnp.linalg.norm(t1)

    return (t0, t1), loss_val, grads

update = jax.jit(update)

# Training configurations
# Noise levels swept by the experiment; 0 is currently left out
# (full list: [0, 0.3, 1.0]).
sigma_list = [0.3, 1.0]

# Dimension sweep: num_dim values evenly spaced between low_dim and big_dim
# (both included), truncated to integers.
low_dim = 4
big_dim = 100
num_dim = 10
dim_list = jnp.linspace(low_dim, big_dim, num=num_dim).astype(int)

# Independent repetitions per (dimension, sigma) pair.
n_runs = 10
# SGD steps per run; iterations are indexed 0 .. num_iterations - 1.
num_iterations = 5001
batch_size = 256
lr = 0.01

# Ensure results directory exists
os.makedirs('experiments_norm1centroid/results', exist_ok=True)


def _sq_dist_up_to_sign(estimate, target):
    """Return the squared distance from estimate to the closer of +target and -target."""
    return jnp.minimum(jnp.linalg.norm(estimate - target) ** 2,
                       jnp.linalg.norm(estimate + target) ** 2)


# One DataFrame per (dimension, sigma) pair; they are concatenated for the plot.
all_runs = []

# Iterating a jnp array yields 0-d device arrays; array shapes and the log
# entries below want plain Python ints.
for dim in map(int, dim_list):

    for sigma in sigma_list:
        logs = {"Run": [], "Dimension": [], "RMSE": [], "Sigma": []}
        # Same seed for every (dim, sigma) pair; each run gets its own subkey.
        key = jrandom.PRNGKey(0)
        run_subkeys = jrandom.split(key, n_runs)

        for run in range(n_runs):
            key, sk0, sk1, sk2, sk3 = jrandom.split(run_subkeys[run], 5)

            # Ground-truth directions, normalized to unit norm
            z0 = jrandom.normal(sk0, (dim,))
            z1 = jrandom.normal(sk1, (dim,))
            z0 = z0 / jnp.linalg.norm(z0)
            z1 = z1 / jnp.linalg.norm(z1)

            # Initial heads: random unit vectors, with mu1 orthogonalized
            # against mu0 before normalization
            mu0 = jrandom.normal(sk2, (dim,))
            mu0 = mu0 / jnp.linalg.norm(mu0)
            mu1 = jrandom.normal(sk3, (dim,))
            mu1 = mu1 - jnp.dot(mu1, mu0) * mu0
            mu1 = mu1 / jnp.linalg.norm(mu1)
            reg3 = 0.2

            params = (mu0, mu1)

            for i in range(num_iterations):
                key, subkey = jrandom.split(key)
                batch = generate_points_mixture_jax(L, sigma, dim, batch_size, z0, z1, subkey)
                params, loss_val, grads = update(params, reg3, lr, batch, sigma)

                # Progress report; it has to live inside the loop to fire
                # on more than the final index.
                if i % 5000 == 0:
                    print(f'Dimension: {dim}, Sigma: {sigma}, Run:{run+1}/{n_runs}, Iteration:{i}/{num_iterations-1}')

            # Compute Distance -- only the final iterate is logged, so this is
            # evaluated once per run. The error is invariant to the sign of
            # each head and to swapping the two heads.
            d1 = _sq_dist_up_to_sign(params[0], z0) + _sq_dist_up_to_sign(params[1], z1)
            d2 = _sq_dist_up_to_sign(params[0], z1) + _sq_dist_up_to_sign(params[1], z0)
            D = jnp.sqrt(jnp.minimum(d1, d2)) / jnp.sqrt(dim)

            logs["Run"].append(run + 1)
            logs["Dimension"].append(float(dim))
            logs["RMSE"].append(float(D))
            logs["Sigma"].append(sigma)

        # Store the table once per (dim, sigma). Appending it inside the run
        # loop would store the growing table n_runs times and over-weight the
        # early runs in the plotted statistics.
        all_runs.append(pd.DataFrame(logs))

# Create DataFrame and plot
df_all = pd.concat(all_runs, ignore_index=True)

# Large canvas with a matching font scale so the labels stay legible in the PDF.
plt.figure(figsize=(35, 20))
sns.set_context("notebook", font_scale=8.5)

# One line per sigma value; repeated runs at the same dimension are
# aggregated by seaborn's lineplot defaults.
line_options = {
    "x": "Dimension",
    "y": "RMSE",
    "hue": "Sigma",
    "linewidth": 3.0,
    "palette": ["blue", "orange"],
}
ax = sns.lineplot(data=df_all, **line_options)
ax.set_ylabel("RMSE")
plt.tight_layout()
ax.set_yscale('log')
ax.grid(True)

fname = "experiments_norm1centroid/results/plot_linear_iters_norm1centroid_changedim.pdf"
plt.savefig(fname, format="pdf", bbox_inches="tight")
plt.close()