# Code that generates Figures 9 and 10 of the paper.
# Warning: may take more than 12 hours to run.

import os
import jax
import jax.numpy as jnp
import jax.random as jrandom
from jax import lax
from generate_jax import generate_points_mixture_jax
from generate_centroid_manif import generate_centroid_manifold
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns

# Linear attention head function

def H_lin(mu, lambda_val, L, points):
    """Output of a linear attention head parameterised by one direction ``mu``.

    Computes ``(2 / L) * lambda_val * <mu, x_1> * sum_l <mu, x_l> x_l``, where
    ``x_l`` are the rows of ``points`` and ``x_1 = points[0]``.

    Args:
        mu: direction vector; its length must match ``points.shape[1]``.
        lambda_val: scalar scale factor.
        L: normaliser in the ``2 / L`` prefactor (presumably the number of
            tokens; it is not checked against ``points.shape[0]``).
        points: 2-D array of tokens, one per row; row 0 plays the query role.

    Returns:
        A vector of the same length as one row of ``points``.
    """
    query = points[0, :]
    token_scores = points @ mu              # <mu, x_l> for every row l
    aggregated = points.T @ token_scores    # sum_l <mu, x_l> x_l
    prefactor = (2 / L) * lambda_val * jnp.dot(mu, query)
    return prefactor * aggregated


def single_sample_risk(m0, m1, lambda_val, L, reg3, points):
    """Reconstruction risk for one token sequence plus a product regulariser.

    The outputs of the two linear heads ``H_lin(m0, ...)`` and
    ``H_lin(m1, ...)`` are summed and compared with the first row ``x_1`` of
    ``points``:

        ||x_1 - H(m0) - H(m1)||^2 + reg3 * <m0, x_1>^2 * <m1, x_1>^2

    Args:
        m0, m1: the two head directions.
        lambda_val: scale passed to ``H_lin``.
        L: normaliser passed to ``H_lin``.
        reg3: weight of the product regulariser (0 disables it).
        points: 2-D token array; row 0 is the reconstruction target.

    Returns:
        Scalar risk.
    """
    query = points[0, :]
    head_sum = H_lin(m0, lambda_val, L, points) + H_lin(m1, lambda_val, L, points)
    residual = query - head_sum
    # Sum of squares rather than ``jnp.linalg.norm(residual) ** 2``: the norm
    # goes through ``sqrt``, whose derivative is infinite at 0, so the gradient
    # would be NaN for an exactly-zero residual. The value is the same.
    squared_error = jnp.sum(residual ** 2)
    penalty = reg3 * ((jnp.dot(m0, query)) ** 2 * (jnp.dot(m1, query)) ** 2)
    return squared_error + penalty


def full_risk(m0, m1, lambda_val, L, reg3, batch):
    """Mean of ``single_sample_risk`` over the leading axis of ``batch``.

    Each slice ``batch[i]`` is used as ``points``; the head directions,
    ``lambda_val``, ``L`` and ``reg3`` are shared by every slice.
    """
    def per_sequence(points):
        return single_sample_risk(m0, m1, lambda_val, L, reg3, points)

    risks = jax.vmap(per_sequence)(batch)
    return jnp.mean(risks)


@jax.jit
def loss(params, lambda_val, L, reg3, batch):
    """JIT-compiled batch risk for a pair of head directions.

    ``params`` must unpack into exactly two vectors ``(mu0, mu1)``; the result
    is ``full_risk`` evaluated with them on ``batch``.
    """
    first_direction, second_direction = params
    return full_risk(first_direction, second_direction, lambda_val, L, reg3, batch)


@jax.jit
def update(params, L, reg3, lr, batch, sigma):
    """JIT-compiled projected-gradient step for both head directions.

    Each gradient is multiplied by ``I - mu mu^T`` (the tangent-space
    projector when ``mu`` has unit norm). The direction then takes a step of
    size ``lr`` and is renormalised to unit length.

    Args:
        params: pair ``(mu0, mu1)`` of vectors of equal length.
        L: normaliser forwarded to ``loss``.
        reg3: regulariser weight forwarded to ``loss``.
        lr: step size.
        batch: batch of token arrays forwarded to ``loss``.
        sigma: noise level; only used to pick ``lambda_val``.

    Returns:
        ``(new_params, loss_val, grads)``. ``new_params`` is a 2-tuple of unit
        vectors. ``loss_val`` is the loss at the incoming ``params``. ``grads``
        holds the Euclidean gradients of the loss with respect to ``params``.
    """
    # Head scale picked from the noise level: 0.2 if sigma > 0.5, else 0.6.
    # NOTE(review): the origin of these constants is not shown here — confirm.
    lambda_val = lax.select(sigma > 0.5, 0.2, 0.6)
    loss_val, grads = jax.value_and_grad(loss)(params, lambda_val, L, reg3, batch)

    mu0, _ = params
    identity = jnp.eye(mu0.shape[0])

    def retract(mu, g):
        # Project the gradient, take the step, then map back onto the sphere.
        projector = identity - jnp.outer(mu, mu)
        stepped = mu - lr * projector @ g
        return stepped / jnp.linalg.norm(stepped)

    new_params = tuple(retract(mu, g) for mu, g in zip(params, grads))
    return new_params, loss_val, grads


def _alignment_rmse(mu0, mu1, z0, z1, dim):
    """RMSE between (mu0, mu1) and the references (z0, z1), ignoring sign and label order."""
    def sq_dist_up_to_sign(u, v):
        return jnp.minimum(jnp.linalg.norm(u - v) ** 2, jnp.linalg.norm(u + v) ** 2)

    matched = sq_dist_up_to_sign(mu0, z0) + sq_dist_up_to_sign(mu1, z1)
    swapped = sq_dist_up_to_sign(mu0, z1) + sq_dist_up_to_sign(mu1, z0)
    return jnp.sqrt(jnp.minimum(matched, swapped)) / jnp.sqrt(dim)


def solve(mu0, mu1, L, sigma, dim, num_iterations, batch_size, key, on_manifold_flag,
          learning_rate=None):
    """Train the two head directions with ``update`` on freshly sampled batches.

    Args:
        mu0, mu1: initial directions, each of length ``dim``.
        L: forwarded to ``generate_points_mixture_jax`` and ``update``.
        sigma: noise level, forwarded to both as well.
        dim: ambient dimension.
        num_iterations: number of gradient steps; 0 returns the initial pair.
        batch_size: number of samples drawn per step.
        key: JAX PRNG key; split once per step to sample that step's batch.
        on_manifold_flag: truthy -> regulariser weight 0.0, falsy -> 0.2.
        learning_rate: step size; ``None`` falls back to the module-level ``lr``.

    Returns:
        ``(mu0_final, mu1_final, rmse)``. ``rmse`` is the distance to the
        reference directions ``z0 = e_{dim-1}`` and ``z1 = -e_0``, minimised
        over sign flips and label swap, and normalised by ``sqrt(dim)``.
    """
    if learning_rate is None:
        learning_rate = lr  # module-level experiment setting
    reg3 = 0.0 if on_manifold_flag else 0.2

    # Reference centroids
    z0 = jnp.zeros(dim).at[dim - 1].set(1)
    z1 = jnp.zeros(dim).at[0].set(-1)

    # Start with a tuple, the same pytree structure ``update`` returns, so the
    # jitted ``update`` is traced once rather than for a list and then a tuple.
    params = (mu0, mu1)
    for i in range(num_iterations):
        key, subkey = jrandom.split(key)
        dataset = generate_points_mixture_jax(L, sigma, dim, batch_size, subkey)
        params, _, _ = update(params, L, reg3, learning_rate, dataset, sigma)

        # The distance is only needed for progress output here.
        if i % 1000 == 0:
            dis = _alignment_rmse(params[0], params[1], z0, z1, dim)
            print(f"Iteration {i}/{num_iterations-1}, RMSE: {dis}")

    # Final distance; also defined when num_iterations == 0.
    dis = _alignment_rmse(params[0], params[1], z0, z1, dim)
    return params[0], params[1], dis

def _initial_centroids(on_manifold, z0, z1, sub0, sub1, dim):
    """Return the starting pair (mu0, mu1) for one run."""
    if on_manifold:
        return generate_centroid_manifold(z0, z1, sub0, sub1, dim)
    # Random orthonormal pair: normalise mu0, then one Gram-Schmidt step for mu1.
    mu0 = jrandom.normal(sub0, (dim,))
    mu0 /= jnp.linalg.norm(mu0)
    mu1 = jrandom.normal(sub1, (dim,))
    mu1 -= jnp.dot(mu1, mu0) * mu0
    mu1 /= jnp.linalg.norm(mu1)
    return mu0, mu1


def _save_rmse_plot(df, fname):
    """Plot RMSE against dimension on a log y-axis and save it as a PDF."""
    plt.figure(figsize=(35, 20))
    sns.set_context("notebook", font_scale=8.5)
    sns.lineplot(data=df, x="Dimension", y="RMSE", linewidth=3.0)
    plt.yscale('log')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(fname, format="pdf", bbox_inches="tight")
    plt.close()


# Experiment configuration
noise_levels = [0.0, 0.3, 1.0]
manifold_flags = [0, 1]

n_runs = 10
num_iterations = 5001
batch_size = 256
L = 30

low_dim = 4
big_dim = 200
num_dim = 10
lr = 0.01

dimensions = jnp.linspace(low_dim, big_dim, num_dim, dtype=int)

os.makedirs('experiments_varying_dimension/results', exist_ok=True)

# Outer loops over sigma and on_manifold
for sigma in noise_levels:
    for on_manifold in manifold_flags:
        logs = {"Run": [], "Dimension": [], "RMSE": []}
        # Initialisation keys: one per run.
        key = jrandom.PRNGKey(0)
        run_subkeys = jrandom.split(key, n_runs)
        # Data-sampling keys come from a separate root. Passing ``key`` itself
        # to ``solve`` after it was already split above would reuse a consumed
        # key, which JAX forbids. It would also give every run the same data.
        data_subkeys = jrandom.split(jrandom.PRNGKey(1), n_runs)

        for dim in dimensions:
            print(f'Sigma:{sigma}, manifold={on_manifold}, dim={dim}')
            z0 = jnp.zeros(dim).at[dim-1].set(1)
            z1 = jnp.zeros(dim).at[0].set(-1)

            for run in range(n_runs):
                sub0, sub1 = jrandom.split(run_subkeys[run], 2)
                mu0, mu1 = _initial_centroids(on_manifold, z0, z1, sub0, sub1, dim)

                mu0_opt, mu1_opt, dis_opt = solve(
                    mu0, mu1, L, sigma, dim, num_iterations, batch_size,
                    data_subkeys[run], on_manifold,
                )

                logs["Run"].append(run+1)
                logs["Dimension"].append(int(dim))
                logs["RMSE"].append(float(dis_opt))

        df = pd.DataFrame(logs)
        suffix = "manifold" if on_manifold else ""
        fname = f"experiments_varying_dimension/results/plot_linear_{suffix}_iters_{sigma}.pdf".replace('__', '_')
        _save_rmse_plot(df, fname)
