{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaa01c19",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "339efb10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved overlaps (last entry of each saved list) and losses of the ELU run, one row per alpha.\n",
    "folder='results_SE_ERM/act=elu_betaU=1.000_betaV=2.000_lambda=0.000_iters=5000_gamma=0.000_delta=1.000_'\n",
    "alpha_values_elu=np.load(folder+'/alpha_values_cosine.npy')\n",
    "# m-type overlaps have two components per alpha; every other quantity has one.\n",
    "hatm_elu, m_elu = (np.zeros((len(alpha_values_elu), 2)) for _ in range(2))\n",
    "hatq_elu, hatV_elu, q_elu, V_elu, train_loss_elu, test_loss_elu = (\n",
    "    np.zeros((len(alpha_values_elu), 1)) for _ in range(6))\n",
    "# allow_pickle=True can run arbitrary code from a malicious file: only load results you produced.\n",
    "for row, a in enumerate(alpha_values_elu):\n",
    "    tag = f'{a:.3f}'\n",
    "    # Keep only the last entry of each saved list (presumably the final iterate).\n",
    "    hats = np.load(folder+f'/hat_list_alpha_{tag}.npy', allow_pickle=True)[-1]\n",
    "    hatm_elu[row], hatq_elu[row], hatV_elu[row] = hats[0], hats[1], hats[2]\n",
    "    state = np.load(folder+f'/state_list_alpha_{tag}.npy', allow_pickle=True)[-1]\n",
    "    m_elu[row], q_elu[row], V_elu[row] = state[0], state[1], state[2]\n",
    "    losses = np.load(folder+f'/losses_{tag}.npy', allow_pickle=True)\n",
    "    # Entries 0 and 2 are read as train / test loss; entry 1 is not used here.\n",
    "    train_loss_elu[row], test_loss_elu[row] = losses[0], losses[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05cb36fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved overlaps (last entry of each saved list) and losses of the linear run, one row per alpha.\n",
    "folder='results_SE_ERM/act=linear_betaU=1.000_betaV=2.000_lambda=0.000_iters=5000_gamma=0.000_delta=1.000_'\n",
    "alpha_values_linear=np.load(folder+'/alpha_values_cosine.npy')\n",
    "# m-type overlaps have two components per alpha; every other quantity has one.\n",
    "hatm_linear, m_linear = (np.zeros((len(alpha_values_linear), 2)) for _ in range(2))\n",
    "hatq_linear, hatV_linear, q_linear, V_linear, train_loss_linear, test_loss_linear = (\n",
    "    np.zeros((len(alpha_values_linear), 1)) for _ in range(6))\n",
    "# allow_pickle=True can run arbitrary code from a malicious file: only load results you produced.\n",
    "for row, a in enumerate(alpha_values_linear):\n",
    "    tag = f'{a:.3f}'\n",
    "    # Keep only the last entry of each saved list (presumably the final iterate).\n",
    "    hats = np.load(folder+f'/hat_list_alpha_{tag}.npy', allow_pickle=True)[-1]\n",
    "    hatm_linear[row], hatq_linear[row], hatV_linear[row] = hats[0], hats[1], hats[2]\n",
    "    state = np.load(folder+f'/state_list_alpha_{tag}.npy', allow_pickle=True)[-1]\n",
    "    m_linear[row], q_linear[row], V_linear[row] = state[0], state[1], state[2]\n",
    "    losses = np.load(folder+f'/losses_{tag}.npy', allow_pickle=True)\n",
    "    # Entries 0 and 2 are read as train / test loss; entry 1 is not used here.\n",
    "    train_loss_linear[row], test_loss_linear[row] = losses[0], losses[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdd25fad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved overlaps (last entry of each saved list) and losses of the tanh run, one row per alpha.\n",
    "folder='results_SE_ERM/act=tanh_betaU=1.000_betaV=2.000_lambda=0.000_iters=5000_gamma=0.000_delta=1.000_'\n",
    "alpha_values_tanh=np.load(folder+'/alpha_values_cosine.npy')\n",
    "# m-type overlaps have two components per alpha; every other quantity has one.\n",
    "hatm_tanh, m_tanh = (np.zeros((len(alpha_values_tanh), 2)) for _ in range(2))\n",
    "hatq_tanh, hatV_tanh, q_tanh, V_tanh, train_loss_tanh, test_loss_tanh = (\n",
    "    np.zeros((len(alpha_values_tanh), 1)) for _ in range(6))\n",
    "# allow_pickle=True can run arbitrary code from a malicious file: only load results you produced.\n",
    "for row, a in enumerate(alpha_values_tanh):\n",
    "    tag = f'{a:.3f}'\n",
    "    # Keep only the last entry of each saved list (presumably the final iterate).\n",
    "    hats = np.load(folder+f'/hat_list_alpha_{tag}.npy', allow_pickle=True)[-1]\n",
    "    hatm_tanh[row], hatq_tanh[row], hatV_tanh[row] = hats[0], hats[1], hats[2]\n",
    "    state = np.load(folder+f'/state_list_alpha_{tag}.npy', allow_pickle=True)[-1]\n",
    "    m_tanh[row], q_tanh[row], V_tanh[row] = state[0], state[1], state[2]\n",
    "    losses = np.load(folder+f'/losses_{tag}.npy', allow_pickle=True)\n",
    "    # Entries 0 and 2 are read as train / test loss; entry 1 is not used here.\n",
    "    train_loss_tanh[row], test_loss_tanh[row] = losses[0], losses[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1baa957",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved overlaps (last entry of each saved list) and losses of the ReLU run, one row per alpha.\n",
    "folder='results_SE_ERM/act=relu_betaU=1.000_betaV=2.000_lambda=0.000_iters=5000_gamma=0.000_delta=1.000_'\n",
    "alpha_values_relu=np.load(folder+'/alpha_values_cosine.npy')\n",
    "# m-type overlaps have two components per alpha; every other quantity has one.\n",
    "hatm_relu, m_relu = (np.zeros((len(alpha_values_relu), 2)) for _ in range(2))\n",
    "hatq_relu, hatV_relu, q_relu, V_relu, train_loss_relu, test_loss_relu = (\n",
    "    np.zeros((len(alpha_values_relu), 1)) for _ in range(6))\n",
    "# allow_pickle=True can run arbitrary code from a malicious file: only load results you produced.\n",
    "for row, a in enumerate(alpha_values_relu):\n",
    "    tag = f'{a:.3f}'\n",
    "    # Keep only the last entry of each saved list (presumably the final iterate).\n",
    "    hats = np.load(folder+f'/hat_list_alpha_{tag}.npy', allow_pickle=True)[-1]\n",
    "    hatm_relu[row], hatq_relu[row], hatV_relu[row] = hats[0], hats[1], hats[2]\n",
    "    state = np.load(folder+f'/state_list_alpha_{tag}.npy', allow_pickle=True)[-1]\n",
    "    m_relu[row], q_relu[row], V_relu[row] = state[0], state[1], state[2]\n",
    "    losses = np.load(folder+f'/losses_{tag}.npy', allow_pickle=True)\n",
    "    # Entries 0 and 2 are read as train / test loss; entry 1 is not used here.\n",
    "    train_loss_relu[row], test_loss_relu[row] = losses[0], losses[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b74afc20",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import random, jit\n",
    "from jax.scipy.stats import norm\n",
    "from functools import partial\n",
    "\n",
    "@partial(jit, static_argnames=['N', 'num_experiments'])\n",
    "def aggregated_downstream_task_loss_jax(q_vars, key, beta=(1.0, 2.0), gamma=0.0, delta=1.0,\n",
    "                                        N=50, num_experiments=10_000_000):\n",
    "    \"\"\"\n",
    "    JAX-accelerated Monte-Carlo estimate of the downstream loss with explicit batching.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    q_vars : tuple (m, q, V)\n",
    "        m is raveled and its first two entries are read as (m_u, m_v); only the first\n",
    "        entry of q is used; V is unpacked but not used by this estimator.\n",
    "    key : jax.random PRNG key\n",
    "    beta : tuple (beta_u, beta_v)\n",
    "    gamma, delta : accepted for interface compatibility; not used in the computation.\n",
    "    N : int, static\n",
    "        Each experiment pools M = 2 * N samples.\n",
    "    num_experiments : int, static\n",
    "        Number of independent pools. All num_experiments * 2 * N draws are generated at\n",
    "        once (1e9 per array with the defaults, roughly 4 GB each assuming the default\n",
    "        float32), so memory grows with this product.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Scalar jax array: mean, over every non-empty (experiment, group) pair, of\n",
    "    (sign(mean H) - label)**2, where each pool is split into a positive group (T > 0,\n",
    "    label +1) and a negative group (T <= 0, label -1). 0.0 if no group is non-empty.\n",
    "    \"\"\"\n",
    "    # Unpack inputs\n",
    "    m, q, V = q_vars\n",
    "    # Ensure m is flattened (handles shapes (1, 2) as well as (2,))\n",
    "    m = jnp.ravel(jnp.array(m))\n",
    "    m_u, m_v = m[0], m[1]\n",
    "\n",
    "    # Reduce q to a scalar (the notebook passes shape-(1,) rows)\n",
    "    q_val = jnp.ravel(jnp.array(q))[0]\n",
    "\n",
    "    beta_u, beta_v = beta\n",
    "\n",
    "    # nu = +1 when |lambda| exceeds the 0.75 quantile of N(0, 1), else -1 (see nu_vals)\n",
    "    ppf_075 = norm.ppf(0.75)\n",
    "    s_v = 1.0 / jnp.sqrt(1.0 + beta_v)\n",
    "\n",
    "    # Pool size M = 2*N per experiment\n",
    "    M = 2 * N\n",
    "    total_samples = num_experiments * M\n",
    "\n",
    "    # Variance of the student noise G, floored so std_G stays strictly positive\n",
    "    var_G = q_val - (m_v**2) * (1.0 - s_v**2)\n",
    "    var_G = jnp.maximum(var_G, 1e-12)\n",
    "    std_G = jnp.sqrt(var_G)\n",
    "\n",
    "    # Covariance between student noise G and teacher noise T_noise\n",
    "    cov_GT = m_v * s_v\n",
    "\n",
    "    # Correlation coefficient rho. |rho| > 1 whenever var_G < (m_v * s_v)**2\n",
    "    # (i.e. q < m_v**2 before the 1e-12 floor), which would make sqrt(1 - rho**2)\n",
    "    # NaN and poison the whole estimate; clip so inconsistent overlaps degrade to\n",
    "    # perfectly correlated noise instead. For |rho| <= 1 this is a no-op.\n",
    "    rho = jnp.clip(cov_GT / std_G, -1.0, 1.0)\n",
    "\n",
    "    key, k1, k2, k3 = random.split(key, 4)\n",
    "\n",
    "    # Latents\n",
    "    lam_vals = random.normal(k1, (total_samples,))\n",
    "    nu_vals = jnp.where(jnp.abs(lam_vals) > ppf_075, 1.0, -1.0)\n",
    "\n",
    "    # Correlated noises: corr(G, T_noise) = rho by construction\n",
    "    z1 = random.normal(k2, (total_samples,))\n",
    "    z2 = random.normal(k3, (total_samples,))\n",
    "\n",
    "    T_noise_vals = z2\n",
    "    G_vals = std_G * (rho * z2 + jnp.sqrt(jnp.maximum(1.0 - rho**2, 0.0)) * z1)\n",
    "\n",
    "    # Student field H\n",
    "    mu_H = (jnp.sqrt(beta_u) * m_u * lam_vals) + (jnp.sqrt(beta_v) * s_v * m_v * nu_vals)\n",
    "    H_raw = mu_H + G_vals\n",
    "\n",
    "    # Teacher field T\n",
    "    T_raw = jnp.sqrt(beta_v) * nu_vals + T_noise_vals\n",
    "\n",
    "    # Batching and aggregation: one row per experiment\n",
    "    H_matrix = H_raw.reshape(num_experiments, M)\n",
    "    T_matrix = T_raw.reshape(num_experiments, M)\n",
    "\n",
    "    mask_pos = (T_matrix > 0).astype(jnp.float32)\n",
    "    mask_neg = 1.0 - mask_pos\n",
    "\n",
    "    count_pos = jnp.sum(mask_pos, axis=1)\n",
    "    count_neg = jnp.sum(mask_neg, axis=1)\n",
    "\n",
    "    sum_H_pos = jnp.sum(H_matrix * mask_pos, axis=1)\n",
    "    sum_H_neg = jnp.sum(H_matrix * mask_neg, axis=1)\n",
    "\n",
    "    # Group means; max(count, 1) only avoids 0/0 -- empty groups are masked out below\n",
    "    H_a_pos = jnp.divide(sum_H_pos, jnp.maximum(count_pos, 1.0))\n",
    "    H_a_neg = jnp.divide(sum_H_neg, jnp.maximum(count_neg, 1.0))\n",
    "\n",
    "    # Squared error between the sign of the group mean and the group label (+1 / -1)\n",
    "    loss_pos = (jnp.sign(H_a_pos) - 1.0)**2\n",
    "    loss_neg = (jnp.sign(H_a_neg) + 1.0)**2\n",
    "\n",
    "    # Average over non-empty groups only\n",
    "    valid_pos = (count_pos > 0).astype(jnp.float32)\n",
    "    valid_neg = (count_neg > 0).astype(jnp.float32)\n",
    "\n",
    "    total_loss = jnp.sum(loss_pos * valid_pos) + jnp.sum(loss_neg * valid_neg)\n",
    "    total_samples_count = jnp.sum(valid_pos) + jnp.sum(valid_neg)\n",
    "\n",
    "    return jnp.where(total_samples_count > 0, total_loss / total_samples_count, 0.0)\n",
    "\n",
    "# Wrapper\n",
    "def run_jax_downstream(q_vars, beta=(1,2), gamma=0, delta=1, N=50, num_experiments=10_000_000):\n",
    "    \"\"\"Evaluate aggregated_downstream_task_loss_jax with a fixed PRNGKey(42) and return a Python float.\n",
    "\n",
    "    Because the key is fixed, calls with the same N and num_experiments reuse the same\n",
    "    random draws (common random numbers) across different q_vars.\n",
    "    \"\"\"\n",
    "    key = random.PRNGKey(42)\n",
    "    return float(aggregated_downstream_task_loss_jax(q_vars, key, beta, gamma, delta, N, num_experiments))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cdc0794",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _fill_downstream(label, alphas, ms, qs, Vs, out):\n",
    "    \"\"\"Print progress for one activation and write its downstream loss for each alpha into `out`.\"\"\"\n",
    "    print(label)\n",
    "    for k, a in enumerate(alphas):\n",
    "        print(f'alpha={a}')\n",
    "        out[k] = run_jax_downstream(q_vars=(ms[k], qs[k], Vs[k]))\n",
    "\n",
    "loss_down_elu = np.zeros(len(alpha_values_elu))\n",
    "loss_down_tanh = np.zeros(len(alpha_values_tanh))\n",
    "loss_down_relu = np.zeros(len(alpha_values_relu))\n",
    "loss_down_linear = np.zeros(len(alpha_values_linear))\n",
    "\n",
    "# Evaluation (and print) order: tanh, relu, linear, elu.\n",
    "_fill_downstream('Tanh', alpha_values_tanh, m_tanh, q_tanh, V_tanh, loss_down_tanh)\n",
    "_fill_downstream('Relu', alpha_values_relu, m_relu, q_relu, V_relu, loss_down_relu)\n",
    "_fill_downstream('Linear', alpha_values_linear, m_linear, q_linear, V_linear, loss_down_linear)\n",
    "_fill_downstream('Elu', alpha_values_elu, m_elu, q_elu, V_elu, loss_down_elu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65d66a86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.save does not create missing directories, so create the output folder first\n",
    "# (on a fresh checkout the saves would otherwise fail with FileNotFoundError).\n",
    "out_dir = Path('downstream_theoretical')\n",
    "out_dir.mkdir(parents=True, exist_ok=True)\n",
    "# Same files and save order as before: loss curve, then its alpha grid, per activation.\n",
    "for act, loss_arr, alpha_arr in (\n",
    "    ('relu', loss_down_relu, alpha_values_relu),\n",
    "    ('tanh', loss_down_tanh, alpha_values_tanh),\n",
    "    ('elu', loss_down_elu, alpha_values_elu),\n",
    "    ('linear', loss_down_linear, alpha_values_linear),\n",
    "):\n",
    "    np.save(out_dir / f'loss_down_{act}.npy', loss_arr)\n",
    "    np.save(out_dir / f'alpha_{act}.npy', alpha_arr)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cedric_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
