#Code that generates Figure 15a of the paper

import matplotlib
matplotlib.use('Agg')
import jax
from jax import lax
import jax.numpy as jnp
import jax.random as jrandom
from generate_jax import generate_points_mixture_jax
from matplotlib import pyplot as plt
import os
import pandas as pd
import seaborn as sns

# Constants of the problem

dim = 5

# Centroids to retrieve: z0 has a single 1 in its last coordinate,
# z1 has a single -1 in its first coordinate.
z0 = jnp.zeros(dim).at[dim - 1].set(1)
z1 = jnp.zeros(dim).at[0].set(-1)

# L is the first argument given to generate_points_mixture_jax and also
# divides the centering term of single_sample_risk
# (presumably the number of points per sample -- TODO confirm against generate_jax).
L = 30
# Two values handed to the data generator as its second argument;
# one training trajectory is run for each of them.
sigma = 0.3
sigma1 = 1

def H_soft(mu, lambda_val, points):
    """Softmax-weighted combination of the rows of `points` along direction `mu`.

    Row i is scored by lambda_val * <points[0], mu> * <points[i], mu>; the
    scores go through a softmax and the resulting weights are used to
    average the rows.

    Args:
        mu: direction vector, contracted with each row of `points`.
        lambda_val: scalar multiplying the scores before the softmax.
        points: 2-D array whose rows are the points; row 0 is the reference row.

    Returns:
        The weighted sum of the rows of `points` (one vector).
    """
    reference_proj = jnp.dot(points[0, :], mu)
    row_proj = points @ mu
    weights = jax.nn.softmax(lambda_val * reference_proj * row_proj)
    return weights @ points

def single_sample_risk(m0, m1, b0, b1, g, lambda_val, reg3, points):
    """Regularized squared reconstruction error for one sample.

    The first row of `points` is compared with b0*H_soft(m0) + b1*H_soft(m1),
    after adding a centering term (g / L) * sum of all rows. A penalty scaled
    by `reg3` is added on top.

    Args:
        m0, m1: the two direction vectors.
        b0, b1: scalar weights of the two H_soft outputs.
        g: scalar weight of the centering term (divided by the global L).
        lambda_val: scalar forwarded to H_soft.
        reg3: regularization strength.
        points: 2-D array of points for one sample; row 0 is the target row.

    Returns:
        Scalar risk for this sample.
    """
    target = points[0, :]
    reconstruction = b0 * H_soft(m0, lambda_val, points) + b1 * H_soft(m1, lambda_val, points)
    centering = (g / L) * jnp.sum(points, axis=0)
    data_fit = jnp.linalg.norm(target - reconstruction + centering) ** 2

    # The product term vanishes as soon as one of <m0, X1>, <m1, X1> equals 1;
    # the last term pushes m0 and m1 towards orthogonality.
    alignment = (jnp.dot(m0, target) - 1) ** 2 * (jnp.dot(m1, target) - 1) ** 2
    overlap = jnp.dot(m0, m1) ** 2
    return data_fit + reg3 * (alignment + overlap)

def full_risk(m0, m1, b0, b1, g, lambda_val, reg3, batch):
    """Mean of `single_sample_risk` over the leading axis of `batch`.

    All arguments except `batch` are shared by every sample.
    """
    # Seven broadcast arguments followed by the batched data argument.
    axes = (None,) * 7 + (0,)
    per_sample_risk = jax.vmap(single_sample_risk, in_axes=axes)
    risks = per_sample_risk(m0, m1, b0, b1, g, lambda_val, reg3, batch)
    return jnp.mean(risks)

@jax.jit
def loss(params, reg3, batch):
    """Batch risk as a function of the packed parameter tuple.

    Args:
        params: 6-tuple (mu0, mu1, b0, b1, g, lambda_val).
        reg3: regularization strength.
        batch: batch of samples, forwarded to `full_risk`.

    Returns:
        Scalar value of `full_risk` on the batch.
    """
    dir0, dir1, coef0, coef1, center_coef, lam = params
    return full_risk(dir0, dir1, coef0, coef1, center_coef, lam, reg3, batch)


@jax.jit
def update(params, reg3, lr, batch):
    """One SGD step on `loss`.

    mu0 and mu1 take a gradient step projected by I - mu mu^T and are then
    renormalized to unit norm; g and lambda_val take a plain gradient step;
    b0 and b1 are returned unchanged.

    Args:
        params: 6-tuple (mu0, mu1, b0, b1, g, lambda_val).
        reg3: regularization strength forwarded to `loss`.
        lr: step size.
        batch: batch of samples forwarded to `loss`.

    Returns:
        (new_params, loss_val, grads) where `loss_val` and `grads` are the
        loss and its gradient with respect to `params`, evaluated before
        the step.
    """
    loss_val, grads = jax.value_and_grad(loss)(params, reg3, batch)
    mu0, mu1, b0, b1, g, lambda_val = params

    def sphere_step(mu, grad_mu):
        # Project the gradient with I - mu mu^T, step, then renormalize.
        projector = jnp.eye(dim) - jnp.outer(mu, mu)
        moved = mu - lr * projector @ grad_mu
        return moved / jnp.linalg.norm(moved)

    new_params = (
        sphere_step(mu0, grads[0]),
        sphere_step(mu1, grads[1]),
        b0,  # kept fixed
        b1,  # kept fixed
        g - lr * grads[4],
        lambda_val - lr * grads[5],
    )

    return new_params, loss_val, grads


# Column-wise storage of the results: one entry per (run, regularization value)
# is appended to every list.
logs = {
    column: []
    for column in ("Run", "Regularization", "Dist sigma03", "Dist sigma1")
}

# Random initializations: one PRNG key per run, all derived from seed 0
n_runs = 10
key = jrandom.PRNGKey(0)
run_subkeys = jrandom.split(key, n_runs)

# SGD hyper-parameters
num_iterations = 5001
batch_size = 256
lr = 0.01

# Grid of regularization strengths, evenly spaced on [0, 3]
grid_reg3v = 30
reg3v = jnp.linspace(0, 3, grid_reg3v)
 

def _matching_distance(first, second):
    """Distance from (first, second) to (z0, z1), minimized over the two pairings."""
    straight = jnp.linalg.norm(first - z0) ** 2 + jnp.linalg.norm(second - z1) ** 2
    swapped = jnp.linalg.norm(first - z1) ** 2 + jnp.linalg.norm(second - z0) ** 2
    return jnp.sqrt(jnp.minimum(straight, swapped))


for run in range(n_runs):
    print(f'Run {run+1}/{n_runs} of initial points')
    key, subkey0, subkey1 = jrandom.split(run_subkeys[run], 3)

    # First direction: random Gaussian vector scaled to unit norm
    draw0 = jrandom.normal(subkey0, shape=(dim,))
    mu0 = draw0 / jnp.linalg.norm(draw0)
    print(f'Initial mu0: {mu0}')

    # Second direction: independent draw, made orthogonal to mu0 and rescaled
    draw1 = jrandom.normal(subkey1, shape=(dim,))
    draw1 = draw1 / jnp.linalg.norm(draw1)
    orthogonal = draw1 - jnp.dot(draw1, mu0) * mu0
    mu1 = orthogonal / jnp.linalg.norm(orthogonal)
    print(f'Initial mu1: {mu1}')

    b0, b1 = 1.0, 1.0
    g, lambda_val = 2.0, 3.0

    # Same starting point for both trajectories (sigma and sigma1).
    # NOTE(review): params/params1 are not re-initialized between
    # regularization values, so each value of k continues from the iterate
    # reached with the previous one -- confirm this continuation is intended.
    params = (mu0, mu1, b0, b1, g, lambda_val)
    params1 = params

    for k in range(len(reg3v)):
        print(f'Run {k+1}/{grid_reg3v} of the algorithm')
        reg3 = reg3v[k]
        for i in range(num_iterations):
            # Fresh subkey at every iteration; both datasets are drawn from
            # the same subkey and differ only in the sigma argument.
            key, subkey = jrandom.split(key)
            dataset = generate_points_mixture_jax(L, sigma, dim, batch_size, subkey)
            dataset1 = generate_points_mixture_jax(L, sigma1, dim, batch_size, subkey)

            params, current_loss, current_grad = update(params, reg3, lr, dataset)
            params1, current_loss1, current_grad1 = update(params1, reg3, lr, dataset1)
            if i % 1000 == 0:
                print(f"Iteration {i}, Loss on the batch: {current_loss}, {current_loss1}")

        dis = _matching_distance(params[0], params[1])
        dissigma = _matching_distance(params1[0], params1[1])
        print(f'Distance: {dis}, {dissigma}')

        record = (
            ("Run", run + 1),
            ("Regularization", float(reg3)),
            ("Dist sigma03", float(dis)),
            ("Dist sigma1", float(dissigma)),
        )
        for column, value in record:
            logs[column].append(value)



# Long-format table: one row per (run, regularization, sigma) with its distance
df = pd.DataFrame(logs)

sigma_map = {"Dist sigma03": r"$\sigma=0.3$", "Dist sigma1": r"$\sigma=1$"}
df_long = df.melt(
    id_vars="Regularization",
    value_vars=list(sigma_map),
    var_name="Sigma",
    value_name="Distance",
)
# Replace the column names by the legend labels
df_long["Sigma"] = df_long["Sigma"].map(sigma_map)

plt.figure(figsize=(35, 20))
sns.set_context("notebook", font_scale=8.5)
# seaborn aggregates the repeated runs at each regularization value
ax = sns.lineplot(
    data=df_long,
    x="Regularization",
    y="Distance",
    hue="Sigma",
    linewidth=3.0,
)
ax.set_ylabel("Distance to centroid \n (up to a sign)")
ax.set_yscale('log')
ax.grid(True)
plt.tight_layout()

# Make sure the output directory exists before writing the figure
os.makedirs('softmax_experiments/results', exist_ok=True)
plt.savefig(
    "softmax_experiments/results/plot_softmax_reg_0.3and1.pdf",
    format="pdf",
    bbox_inches="tight",
)
plt.close()