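# Generates Figure 14b of the paper: distance between the learned directions and
# the true centroids (up to permutation and sign) versus SGD iteration, one curve
# per noise level sigma.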

import matplotlib
matplotlib.use('Agg')
import jax
from jax import lax
import jax.numpy as jnp
import jax.random as jrandom
from generate_jax_generic3 import generate_points_mixture_jax_3
from generate_centroid_manif import generate_centroid_manifold
import os
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns
import itertools

# Problem constants.
# dim: length of every centroid z_i and learned direction mu_i.
# L:   first argument of the data generator (presumably the number of points
#      per sample - TODO confirm); also sets the 2/L scale inside H_lin.
dim, L = 6, 30

def min_signed_assignment_distance_3(params, zs):
    """Return the distance between two sets of vectors, up to order and sign.

    Computes::

        min over permutations pi and signs s_i in {+1, -1} of
        sqrt( sum_i || params[pi(i)] - s_i * zs[i] ||^2 )

    i.e. the Euclidean distance after optimally matching each target
    ``zs[i]`` to a distinct entry of ``params`` and flipping its sign when
    that is closer. Although named for the 3-vector case used in this script,
    it accepts any number n >= 0 of equally sized vectors.

    Args:
        params: sequence of n 1-D arrays (learned directions).
        zs: sequence of n 1-D arrays of the same length (targets).

    Returns:
        Scalar JAX array holding the minimal distance (0.0 when n == 0).

    Raises:
        ValueError: if ``params`` and ``zs`` have different lengths.
    """
    n = len(params)
    if len(zs) != n:
        raise ValueError(
            f"params and zs must have the same length, got {n} and {len(zs)}")
    if n == 0:
        return jnp.zeros(())

    P = jnp.stack([jnp.asarray(p) for p in params])  # (n, d)
    Z = jnp.stack([jnp.asarray(z) for z in zs])      # (n, d)

    # cost[p, i] = squared distance from params[p] to whichever of +zs[i] or
    # -zs[i] is closer. Signs are chosen independently per target, so
    # minimising over them pair-by-pair before the assignment search gives
    # the same optimum as enumerating all 2^n sign patterns per permutation.
    d_plus = jnp.sum((P[:, None, :] - Z[None, :, :]) ** 2, axis=-1)
    d_minus = jnp.sum((P[:, None, :] + Z[None, :, :]) ** 2, axis=-1)
    cost = jnp.minimum(d_plus, d_minus)

    # Exhaustive assignment search: n! candidates (6 for n = 3).
    # cost[perms, arange(n)][k, i] == cost[perms[k, i], i].
    perms = jnp.asarray(list(itertools.permutations(range(n))))  # (n!, n)
    totals = jnp.sum(cost[perms, jnp.arange(n)], axis=1)
    return jnp.sqrt(jnp.min(totals))



# Linear attention head
def H_lin(mu, lambda_val, points):
    """Output of one linear attention head with direction ``mu``.

    Evaluates (2 / L) * lambda_val * (mu . points[0]) * points.T @ (points @ mu),
    where L is the module-level constant.
    """
    first_proj = jnp.dot(mu, points[0])   # scalar projection of the first point
    scores = points @ mu                  # projection of every point onto mu
    pooled = points.T @ scores            # points weighted by their scores
    scale = (2 / L) * lambda_val
    return scale * first_proj * pooled

# Risk definitions
def single_sample_risk(m0, m1, m2, lambda_val, reg3, points):
    """Return the regularized squared reconstruction error for one sample.

    The prediction is the sum of three linear attention heads (directions
    m0, m1, m2) and the target is the first point ``points[0]``.

    Args:
        m0, m1, m2: direction vectors, one per head.
        lambda_val: scalar scale passed to every head.
        reg3: weight of the product penalty.
        points: one sample; ``points[0]`` is used as the target X1.
            Presumably an (L, dim) array - TODO confirm against the generator.

    Returns:
        ||X1 - H_sum||^2 + reg3 * (m0.X1)^2 * (m1.X1)^2 * (m2.X1)^2
    """
    H_sum = (H_lin(m0, lambda_val, points)
             + H_lin(m1, lambda_val, points)
             + H_lin(m2, lambda_val, points))
    X1 = points[0]
    diff = X1 - H_sum
    # Sum of squares instead of jnp.linalg.norm(diff) ** 2: same value, but
    # the gradient of norm is NaN at diff == 0, and this skips a needless
    # sqrt/square round trip.
    recon = jnp.sum(diff ** 2)
    # Penalty on the product of all three squared projections of X1.
    # An alternative pairwise penalty,
    #   (m0.X1)^2 (m1.X1)^2 + (m0.X1)^2 (m2.X1)^2 + (m1.X1)^2 (m2.X1)^2,
    # was kept commented out in an earlier version of this function.
    penalty = jnp.dot(m0, X1) ** 2 * jnp.dot(m1, X1) ** 2 * jnp.dot(m2, X1) ** 2
    return recon + reg3 * penalty

def full_risk(m0, m1, m2, lambda_val, reg3, batch):
    """Average ``single_sample_risk`` over the leading axis of ``batch``."""
    # Only the batch is mapped; directions and scalars are shared by all samples.
    per_sample_risk = jax.vmap(
        single_sample_risk, in_axes=(None, None, None, None, None, 0))
    risks = per_sample_risk(m0, m1, m2, lambda_val, reg3, batch)
    return jnp.mean(risks)

@jax.jit
def loss(params, lambda_val, reg3, batch):
    """JIT-compiled batch risk for the 3-tuple ``params = (mu0, mu1, mu2)``."""
    first, second, third = params
    return full_risk(first, second, third, lambda_val, reg3, batch)

# Riemannian SGD update
@jax.jit
def update(params, reg3, lr, batch, sigma):
    """One Riemannian SGD step on the unit sphere for each of the three directions.

    Each Euclidean gradient is projected onto the tangent space at the current
    point with (I - mu mu^T), a step of size ``lr`` is taken, and the result is
    renormalised back onto the unit sphere.

    Returns:
        (new_params, loss_val, grads): the updated 3-tuple of unit vectors, the
        loss evaluated at the incoming ``params``, and the raw Euclidean gradients.
    """
    # Choice of lambda_v: 0.2 when sigma > 0.5, otherwise 0.6.
    lambda_v = lax.select(sigma > 0.5, 0.2, 0.6)
    loss_val, grads = jax.value_and_grad(loss)(params, lambda_v, reg3, batch)

    def _sphere_step(mu, g):
        tangent_proj = jnp.eye(dim) - jnp.outer(mu, mu)
        moved = mu - (lr * tangent_proj) @ g
        return moved / jnp.linalg.norm(moved)

    new_params = tuple(_sphere_step(mu, g) for mu, g in zip(params, grads))
    return new_params, loss_val, grads

# Training configurations.
# Alternative settings left in earlier comments:
# sigma_list = [0, 0.3, 1.0], n_runs = 10, num_iterations = 10001.
sigma_list = [0.3, 1.0]  # noise levels swept; one plotted curve per value
n_runs = 1               # independent seeds per sigma
num_iterations = 20001   # SGD steps per run (iterations 0..20000)
batch_size, lr = 256, 0.01

# Ensure results directory exists
os.makedirs('experiments_mixture3/results', exist_ok=True)

# One DataFrame per (sigma, run); concatenated for plotting afterwards.
all_runs = []

# Ground-truth centroids: basis vectors e_{dim-1}, e_0 and e_{dim/2}.
# They depend on neither sigma nor the run, so build them once.
dim2 = int(dim/2)
basis = jnp.eye(dim)
z0, z1, z2 = basis[dim - 1], basis[0], basis[dim2]
zs = [z0, z1, z2]
reg3 = 0.2

for sigma in sigma_list:
    # Same root key for every sigma, so each run index starts from the same
    # initialisation and data stream across noise levels.
    key = jrandom.PRNGKey(0)
    run_subkeys = jrandom.split(key, n_runs)

    for run in range(n_runs):
        # Fresh log per run, so each DataFrame appended below holds only this
        # run's rows. Appending a shared per-sigma log after every run would
        # duplicate earlier runs and skew seaborn's means and error bands.
        logs = {"Run": [], "Iterations": [], "Distance to centroid (up to a sign)": [], "Sigma": []}
        key, sk0, sk1, sk2 = jrandom.split(run_subkeys[run], 4)

        # Random orthonormal initialisation (Gram-Schmidt on Gaussian draws).
        mu0 = jrandom.normal(sk0, (dim,))
        mu0 = mu0 / jnp.linalg.norm(mu0)
        mu1 = jrandom.normal(sk1, (dim,))
        mu1 = mu1 - jnp.dot(mu1, mu0) * mu0
        mu1 = mu1 / jnp.linalg.norm(mu1)
        mu2 = jrandom.normal(sk2, (dim,))
        mu2 = mu2 - jnp.dot(mu2, mu0) * mu0 - jnp.dot(mu2, mu1) * mu1
        mu2 = mu2 / jnp.linalg.norm(mu2)
        params = (mu0, mu1, mu2)

        for k in range(num_iterations):
            key, subkey = jrandom.split(key)
            batch = generate_points_mixture_jax_3(L, sigma, dim, batch_size, z0, z1, z2, subkey)
            params, loss_val, grads = update(params, reg3, lr, batch, sigma)

            # Distance to the true centroids, up to permutation and sign.
            dist = min_signed_assignment_distance_3(list(params), zs)

            if k % 5000 == 0:
                print(f'Sigma: {sigma}, Run:{run+1}/{n_runs}, Iteration:{k}/{num_iterations-1}')
                print(f'mu0:{params[0]}, mu1:{params[1]}, mu2:{params[2]}, D:{dist}')

            logs["Run"].append(run + 1)
            logs["Iterations"].append(float(k + 1))
            logs["Distance to centroid (up to a sign)"].append(float(dist))
            logs["Sigma"].append(sigma)

        all_runs.append(pd.DataFrame(logs))

df_all = pd.concat(all_runs, ignore_index=True)

# One colour per sigma level. The first two keep the original blue/orange
# pair. The extra entries let the sweep grow (e.g. to three sigmas); with a
# palette shorter than the number of hue levels, seaborn would repeat colours
# (with a warning), making the curves indistinguishable.
n_sigma_levels = df_all["Sigma"].nunique()
base_colors = ["blue", "orange", "green", "red", "purple", "brown"]
if n_sigma_levels <= len(base_colors):
    palette = base_colors[:n_sigma_levels]
else:
    palette = sns.color_palette(n_colors=n_sigma_levels)

plt.figure(figsize=(35, 20))
sns.set_context("notebook", font_scale=8.5)
ax = sns.lineplot(
    data=df_all,
    x="Iterations",
    y="Distance to centroid (up to a sign)",
    linewidth=3.0,
    hue="Sigma",
    palette=palette,
)
ax.set_ylabel("Distance to centroid \n (up to a sign)")
plt.tight_layout()
# Widen the left margin: the y label spans two lines at font_scale=8.5.
plt.subplots_adjust(left=0.3)
plt.yscale('log')
plt.grid(True)

plt.savefig("experiments_mixture3/results/plot_linear_iters_mixture3.pdf")
plt.close()