# Code that generates Figures 1 and 2b of the paper.

import matplotlib
matplotlib.use('Agg')
import os
import jax
from jax import lax
import jax.numpy as jnp
import jax.random as jrandom
from generate_jax import generate_points_mixture_jax
from generate_centroid_manif import generate_centroid_manifold
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns

# Define constants of the problem

dim = 5  # ambient dimension of points and centroids

# True centroids: z0 is the last standard basis vector, z1 is minus the first.
z0 = jnp.zeros(dim).at[dim - 1].set(1)
z1 = jnp.zeros(dim).at[0].set(-1)

L = 30       # first argument of generate_points_mixture_jax; also the 2/L scale in H_lin
sigma = 0.3  # mixture parameter (presumably noise std) of the first data stream
sigma1 = 1   # mixture parameter (presumably noise std) of the second data stream

# Create results directory (no-op if it already exists)
os.makedirs('main_experiments/results', exist_ok=True)

# Linear attention head
def H_lin(mu, lambda_val, points):
    """Linear attention head read out at the first point ``points[0]``.

    With ``X = points`` (one point per row) and ``x_1 = points[0]``, returns
    ``(2 / L) * lambda_val * (mu . x_1) * X^T X mu``, where ``L`` is the
    module-level constant.
    """
    query = points[0]
    # X^T X mu == sum over rows x_l of x_l * (x_l . mu)
    gram_mu = points.T @ (points @ mu)
    coeff = (2 / L) * lambda_val * jnp.dot(mu, query)
    return coeff * gram_mu

# Single-sample risk
def single_sample_risk(m0, m1, lambda_val, reg3, points):
    """Reconstruction risk of ``points[0]`` plus a product penalty.

    Returns ``||x_1 - H_lin(m0) - H_lin(m1)||^2 + reg3 * (m0 . x_1)^2 * (m1 . x_1)^2``
    with ``x_1 = points[0]``.
    """
    target = points[0]
    prediction = H_lin(m0, lambda_val, points) + H_lin(m1, lambda_val, points)
    residual_sq = jnp.linalg.norm(target - prediction) ** 2
    proj0_sq = jnp.dot(m0, target) ** 2
    proj1_sq = jnp.dot(m1, target) ** 2
    return residual_sq + reg3 * proj0_sq * proj1_sq

# Full batch risk
def full_risk(m0, m1, lambda_val, reg3, batch):
    """Mean of ``single_sample_risk`` over the leading axis of ``batch``.

    The centroids, ``lambda_val`` and ``reg3`` are shared by every sample;
    only ``batch`` is mapped over.
    """
    batched_risk = jax.vmap(single_sample_risk, in_axes=(None, None, None, None, 0))
    per_sample = batched_risk(m0, m1, lambda_val, reg3, batch)
    return jnp.mean(per_sample)

@jax.jit
def loss(params, lv, r3, batch):
    """Jitted ``full_risk`` for ``params = (mu0, mu1)``."""
    return full_risk(params[0], params[1], lv, r3, batch)

# Update step with dynamic lambda
def update(params, reg3, lr, batch, sigma):
    """One Riemannian gradient step on the unit sphere for both centroids.

    The Euclidean gradient of ``full_risk`` w.r.t. each centroid is projected
    with ``(I - mu mu^T)``. A step of size ``lr`` is then taken and the result
    is renormalised to unit length.

    Args:
        params: pair ``(mu0, mu1)`` of centroid vectors, assumed unit-norm so
            that ``I - mu mu^T`` is the tangent-space projector.
        reg3: weight of the product penalty in ``single_sample_risk``.
        lr: step size.
        batch: samples mapped over by ``full_risk``; presumably shaped
            ``(batch, n_points, dim)`` — TODO confirm against the generator.
        sigma: parameter of the data that produced ``batch``; it selects
            ``lambda`` = 0.2 if ``sigma > 0.5``, else 0.6.

    Returns:
        ``((mu0_new, mu1_new), loss_val, grads)``. ``loss_val`` and the
        Euclidean ``grads`` are evaluated at the input ``params``.
    """
    lambda_v = jnp.where(sigma > 0.5, 0.2, 0.6)
    loss_val, grads = jax.value_and_grad(full_risk, argnums=(0, 1))(
        params[0], params[1], lambda_v, reg3, batch
    )

    def step(mu, g):
        # (I - mu mu^T) g without materialising the projector, so the step
        # works for any vector length instead of the module-level `dim`.
        riem_grad = g - jnp.dot(mu, g) * mu
        new_mu = mu - lr * riem_grad
        return new_mu / jnp.linalg.norm(new_mu)

    mu0, mu1 = params
    return (step(mu0, grads[0]), step(mu1, grads[1])), loss_val, grads


update = jax.jit(update)

# Experiment settings:
#   n_runs         - independent initialisations per configuration
#   num_iterations - update steps per run (fresh batch each step)
#   batch_size     - samples drawn per step
#   lr             - step size passed to `update`
n_runs, num_iterations = 10, 10001
batch_size = 256
lr = 0.01

# Run both on_manifold settings:
#   on_manifold=0: random orthonormal initial pair, penalty weight reg3=0.2
#   on_manifold=1: initial pair from generate_centroid_manifold, reg3=0


@jax.jit
def _dist_up_to_sign_and_swap(params, c0, c1):
    """Distance from the estimated pair ``params`` to the true pair ``(c0, c1)``.

    Each estimate is matched to a centroid up to a global sign. The better
    of the two assignments is kept: (params[0]->c0, params[1]->c1) or
    (params[0]->c1, params[1]->c0). Returns the square root of the summed
    squared distances for that assignment.
    """
    p0, p1 = params

    def sq_dist_up_to_sign(p, c):
        return jnp.minimum(jnp.linalg.norm(p - c) ** 2, jnp.linalg.norm(p + c) ** 2)

    direct = sq_dist_up_to_sign(p0, c0) + sq_dist_up_to_sign(p1, c1)
    # Both terms must be summed here. The previous inline version overwrote
    # the first term and reduced with jnp.min(d, d1), which treats d1 as an axis.
    swapped = sq_dist_up_to_sign(p0, c1) + sq_dist_up_to_sign(p1, c0)
    return jnp.sqrt(jnp.minimum(direct, swapped))


for on_manifold in [0, 1]:
    logs = {"Run": [], "Iterations": [], "Dist sigma03": [], "Dist sigma1": []}
    key = jrandom.PRNGKey(0)
    run_subkeys = jrandom.split(key, n_runs)

    for run in range(n_runs):
        print(f'Run {run+1}/{n_runs} (on_manifold={on_manifold})')
        key, subkey0, subkey1 = jrandom.split(run_subkeys[run], 3)
        if on_manifold:
            mu0, mu1 = generate_centroid_manifold(z0, z1, subkey0, subkey1, dim)
            reg3 = 0
        else:
            mu0 = jrandom.normal(subkey0, (dim,))
            mu0 = mu0 / jnp.linalg.norm(mu0)
            mu1 = jrandom.normal(subkey1, (dim,))
            # Gram-Schmidt: remove the mu0 component so mu1 is orthogonal to mu0.
            mu1 = mu1 - jnp.dot(mu1, mu0) * mu0
            mu1 = mu1 / jnp.linalg.norm(mu1)
            reg3 = 0.2
        params = (mu0, mu1)  # trained on data generated with `sigma`
        params1 = params     # trained on data generated with `sigma1`

        for i in range(num_iterations):
            key, subkey = jrandom.split(key)
            # NOTE(review): both datasets use the same subkey, presumably so the
            # two sigma settings see coupled samples — confirm this is intended.
            dataset = generate_points_mixture_jax(L, sigma, dim, batch_size, subkey)
            dataset1 = generate_points_mixture_jax(L, sigma1, dim, batch_size, subkey)
            params, current_loss, _ = update(params, reg3, lr, dataset, sigma)
            params1, current_loss1, _ = update(params1, reg3, lr, dataset1, sigma1)

            dis = _dist_up_to_sign_and_swap(params, z0, z1)
            diss = _dist_up_to_sign_and_swap(params1, z0, z1)

            if i % 1000 == 0:
                print(f"Iteration {i}, Loss sigma0.3: {current_loss}, sigma1: {current_loss1}")

            logs["Run"].append(run+1)
            logs["Iterations"].append(i+1)
            logs["Dist sigma03"].append(float(dis))
            logs["Dist sigma1"].append(float(diss))

        print(f"Finished run {run+1}")

    # Long format: one row per (iteration, sigma, run). sns.lineplot then
    # aggregates the runs at each iteration with its default estimator.
    df = pd.DataFrame(logs)
    df_long = df.melt(id_vars="Iterations", value_vars=["Dist sigma03", "Dist sigma1"], var_name="Sigma", value_name="Distance")
    sigma_map = {"Dist sigma03": r"$\sigma=0.3$", "Dist sigma1": r"$\sigma=1$"}
    df_long["Sigma"] = df_long["Sigma"].map(sigma_map)

    plt.figure(figsize=(35, 20))
    sns.set_context("notebook", font_scale=8.5)
    ax = sns.lineplot(data=df_long, x="Iterations", y="Distance", hue="Sigma", linewidth=3.0)
    ax.set_ylabel("Distance to centroid \n (up to a sign)")
    plt.yscale('log')
    plt.grid(True)
    # Lay out after the log scale is set, since it changes the tick labels.
    plt.tight_layout()
    suffix = "manifold" if on_manifold else ""
    # With an empty suffix the double underscore collapses to a single one.
    filename = f"main_experiments/results/plot_linear_{suffix}_iters_0.3and1.pdf".replace("__", "_")
    plt.savefig(filename, format="pdf", bbox_inches="tight")
    plt.close()
