{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c3952e0-be05-4008-979d-ffa996ea6529",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.patches import Ellipse\n",
    "from matplotlib import transforms\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "import torch.optim as optim\n",
    "from tqdm.auto import tqdm\n",
    "import copy\n",
    "import scipy.linalg\n",
    "\n",
    "# ---------- Params ----------\n",
    "# Preferred accelerator; the CPU is used instead when CUDA is unavailable (see `device` below).\n",
    "device_str = \"cuda:0\"\n",
    "# Number of timesteps in the noise schedule used for training and fine-tuning.\n",
    "diffusion_steps = 1000\n",
    "# Number of generated samples behind each FID estimate.\n",
    "FID_SAMPLES = 1000000\n",
    "# Number of real samples drawn from the target Gaussian to train the base model.\n",
    "N_base = 10**3\n",
    "# Epoch count and Adam learning rate for base-model training.\n",
    "EPOCHS_BASE = 10000\n",
    "LR_BASE = 1e-4\n",
    "\n",
    "# ---------- Repro / Device ----------\n",
    "# Seed both torch and the global NumPy RNG; NumPy draws the datasets, torch draws timesteps and noise.\n",
    "seed = 42\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "device = torch.device(device_str if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "# ---------- Target Gaussian ----------\n",
    "# Zero-mean 2-D Gaussian with correlated components; it is both the data source and the FID reference.\n",
    "D = 2\n",
    "sigma_ref = np.array([[2.0, 1.0], [1.0, 2.0]], dtype=np.float64)\n",
    "mu_ref = np.zeros((D,), dtype=np.float64)\n",
    "\n",
    "# ---------- Schedules ----------\n",
    "def make_cosine_schedule(n_steps: int, s: float = 0.008, device=None):\n",
    "    device = device or torch.device(\"cpu\")\n",
    "    t = torch.arange(0, n_steps, dtype=torch.float32, device=device)\n",
    "    schedule = torch.cos((t / n_steps + s) / (1 + s) * torch.pi / 2) ** 2\n",
    "    baralphas = schedule / schedule[0]\n",
    "    betas = 1.0 - baralphas / torch.cat([baralphas[0:1], baralphas[:-1]])\n",
    "    alphas = 1.0 - betas\n",
    "    return {\"alphas\": alphas, \"betas\": betas, \"baralphas\": baralphas}\n",
    "\n",
    "# Schedule with `diffusion_steps` entries; its `baralphas` drive the noising in training and fine-tuning.\n",
    "train_sched = make_cosine_schedule(diffusion_steps, device=device)\n",
    "\n",
    "# ---------- Model & Utilities ----------\n",
    "def noise(Xbatch: torch.Tensor, t_idx: torch.Tensor, baralphas: torch.Tensor):\n",
    "    baralpha_t = baralphas[t_idx.squeeze(-1)].unsqueeze(-1)\n",
    "    eps = torch.randn_like(Xbatch)\n",
    "    noised = baralpha_t.sqrt() * Xbatch + (1.0 - baralpha_t).sqrt() * eps\n",
    "    return noised, eps\n",
    "\n",
    "class DiffusionBlock(nn.Module):\n",
    "    def __init__(self, n: int):\n",
    "        super().__init__()\n",
    "        self.l = nn.Linear(n, n)\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        return F.relu(self.l(x))\n",
    "\n",
    "class DiffusionModel(nn.Module):\n",
    "    \"\"\"MLP noise predictor conditioned on a scalar, normalised timestep.\n",
    "\n",
    "    The input is the sample concatenated with ``t / max(1, n_steps - 1)``; it is\n",
    "    passed through an input layer, ``nblocks`` ``DiffusionBlock`` layers and a\n",
    "    linear output layer mapping back to ``nfeatures``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, nfeatures: int, nblocks: int = 4, nunits: int = 32):\n",
    "        super().__init__()\n",
    "        # Submodules are created in the order inb -> mbs -> out.\n",
    "        self.inb = nn.Linear(nfeatures + 1, nunits)\n",
    "        hidden_layers = [DiffusionBlock(nunits) for _ in range(nblocks)]\n",
    "        self.mbs = nn.ModuleList(hidden_layers)\n",
    "        self.out = nn.Linear(nunits, nfeatures)\n",
    "\n",
    "    def forward(self, x: torch.Tensor, t: torch.Tensor, n_steps: int) -> torch.Tensor:\n",
    "        # Only float32/float64 timesteps are used as given; anything else is cast to float32.\n",
    "        if t.dtype != torch.float32 and t.dtype != torch.float64:\n",
    "            t = t.float()\n",
    "        denom = max(1, n_steps - 1)\n",
    "        features = torch.hstack([x, t / denom])\n",
    "        hidden = self.inb(features)\n",
    "        for layer in self.mbs:\n",
    "            hidden = layer(hidden)\n",
    "        return self.out(hidden)\n",
    "\n",
    "@torch.no_grad()\n",
    "def sample_diffusion(model: nn.Module, nsamples: int, nfeatures: int, n_steps: int, eta: float = 1.0, score_scale: float = 1.0) -> torch.Tensor:\n",
    "    \"\"\"Draw samples by running the reverse diffusion with a DDIM-style update.\n",
    "\n",
    "    A fresh cosine schedule with ``n_steps`` entries is built for sampling, so the\n",
    "    number of inference steps may differ from the number used in training.\n",
    "\n",
    "    Args:\n",
    "        model: Noise predictor called as ``model(x, t_batch, n_steps)``.\n",
    "        nsamples: Number of samples to draw.\n",
    "        nfeatures: Dimensionality of each sample.\n",
    "        n_steps: Number of reverse steps.\n",
    "        eta: Stochasticity of the update. ``0.0`` gives the deterministic DDIM\n",
    "            update; ``1.0`` injects the full DDIM noise level ``sigma_t``.\n",
    "        score_scale: Multiplier applied to the model score when forming the\n",
    "            clean-sample estimate.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape ``(nsamples, nfeatures)`` on the global ``device``.\n",
    "    \"\"\"\n",
    "    sched = make_cosine_schedule(n_steps, device=device)\n",
    "    alphas, baralphas = sched[\"alphas\"], sched[\"baralphas\"]\n",
    "    x = torch.randn((nsamples, nfeatures), device=device)\n",
    "\n",
    "    for t in range(n_steps - 1, -1, -1):\n",
    "        t_batch = torch.full((nsamples, 1), t, device=device, dtype=torch.long)\n",
    "        eps_pred = model(x, t_batch, n_steps)\n",
    "\n",
    "        bar_t = baralphas[t]\n",
    "        bar_tm1 = baralphas[t - 1] if t > 0 else torch.tensor(1.0, device=device)\n",
    "\n",
    "        sqrt_bar_t = torch.sqrt(bar_t)\n",
    "        sqrt_1m_bar_t = torch.sqrt(1.0 - bar_t)\n",
    "\n",
    "        score = -eps_pred / (sqrt_1m_bar_t + 1e-9)\n",
    "        score_scaled = score_scale * score\n",
    "\n",
    "        # Clean-sample estimate; deliberately not clamped (clamping was removed for more natural plots).\n",
    "        x0_pred = (x + (1.0 - bar_t) * score_scaled) / (sqrt_bar_t + 1e-9)\n",
    "\n",
    "        if t == 0:\n",
    "            x = x0_pred\n",
    "            break\n",
    "\n",
    "        # DDIM noise level:\n",
    "        #   sigma_t^2 = eta^2 * (1 - bar_{t-1}) / (1 - bar_t) * (1 - bar_t / bar_{t-1})\n",
    "        # alphas[t] holds bar_t / bar_{t-1}, so the last factor is 1 - alphas[t].\n",
    "        # Using 1 - alphas[t] / bar_t here instead is never positive: the clamp\n",
    "        # below would then force sigma_t to zero and eta would have no effect.\n",
    "        term1 = (1.0 - bar_tm1) / (1.0 - bar_t + 1e-9)\n",
    "        term2 = 1.0 - alphas[t]\n",
    "        sigma_t = eta * torch.sqrt(torch.clamp(term1 * term2, min=0.0))\n",
    "\n",
    "        dir_coeff = torch.sqrt(torch.clamp(1.0 - bar_tm1 - sigma_t**2, min=0.0))\n",
    "\n",
    "        # NOTE(review): the direction term uses the unscaled eps_pred while x0_pred\n",
    "        # uses the scaled score -- confirm this asymmetry is intended.\n",
    "        mean = torch.sqrt(bar_tm1) * x0_pred + dir_coeff * eps_pred\n",
    "        # t > 0 always holds here because t == 0 exits the loop above.\n",
    "        if eta > 0.0:\n",
    "            x = mean + sigma_t * torch.randn_like(x)\n",
    "        else:\n",
    "            x = mean\n",
    "    return x\n",
    "\n",
    "def merge_models(base: nn.Module, other: nn.Module, w: float):\n",
    "    m = copy.deepcopy(base)\n",
    "    with torch.no_grad():\n",
    "        for pm, pb, po in zip(m.parameters(), base.parameters(), other.parameters()):\n",
    "            pm.copy_((1.0 + w) * pb - w * po)\n",
    "    return m\n",
    "\n",
    "# ---------- Metrics ----------\n",
    "def calculate_fid(mu, sigma, mu_r, sigma_r) -> float:\n",
    "    m = np.square(mu - mu_r).sum()\n",
    "    s, _ = scipy.linalg.sqrtm(sigma @ sigma_r, disp=False)\n",
    "    if not np.isfinite(s).all():\n",
    "        return float('inf')\n",
    "    return float(np.real(m + np.trace(sigma + sigma_r - 2.0 * s)))\n",
    "\n",
    "def fid_of_model(model: nn.Module, nsamples: int, n_steps_infer: int) -> float:\n",
    "    \"\"\"Sample from ``model`` and score the samples against the reference Gaussian.\n",
    "\n",
    "    Draws ``nsamples`` points with ``n_steps_infer`` reverse steps, then compares\n",
    "    their empirical mean and covariance with ``mu_ref`` / ``sigma_ref``. Returns\n",
    "    ``inf`` if any generated value is NaN or infinite.\n",
    "    \"\"\"\n",
    "    samples = sample_diffusion(model, nsamples, D, n_steps=n_steps_infer)\n",
    "    xs = samples.cpu().numpy()\n",
    "    has_bad_values = np.isnan(xs).any() or np.isinf(xs).any()\n",
    "    if has_bad_values:\n",
    "        return float('inf')\n",
    "    sample_mean = xs.mean(0)\n",
    "    sample_cov = np.cov(xs.T)\n",
    "    return calculate_fid(sample_mean, sample_cov, mu_ref, sigma_ref)\n",
    "\n",
    "# ===============================================\n",
    "# ================== TRAIN BASE =================\n",
    "# ===============================================\n",
    "print(\"--- Training Base Model ---\")\n",
    "# Real training data: N_base float32 draws from the target Gaussian (kept on the CPU, moved per batch).\n",
    "Xb = torch.from_numpy(np.random.multivariate_normal(mu_ref, sigma_ref, N_base).astype(np.float32))\n",
    "base_model = DiffusionModel(D, nblocks=4).to(device)\n",
    "opt = optim.Adam(base_model.parameters(), lr=LR_BASE)\n",
    "# The model is trained to predict the injected noise, so the loss is MSE between prediction and eps.\n",
    "loss_fn = nn.MSELoss()\n",
    "\n",
    "for epoch in tqdm(range(EPOCHS_BASE), desc=\"Training Base\"):\n",
    "    # Batches are taken in storage order (no shuffling), 2048 rows at a time.\n",
    "    for i in range(0, len(Xb), 2048):\n",
    "        xb = Xb[i : i + 2048].to(device)\n",
    "        # One uniformly random timestep per sample.\n",
    "        t_idx = torch.randint(0, diffusion_steps, (len(xb), 1), device=device)\n",
    "        noised, eps = noise(xb, t_idx, train_sched[\"baralphas\"])\n",
    "        pred = base_model(noised, t_idx, diffusion_steps)\n",
    "        loss = loss_fn(pred, eps)\n",
    "        opt.zero_grad()\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "base_model.eval()\n",
    "print(\"Base model training complete.\")\n",
    "\n",
    "# ===============================================\n",
    "# === EXPERIMENT 1: FID vs. w ===\n",
    "# ===============================================\n",
    "print(\"\\n--- Running Experiment 1: FID vs. w ---\")\n",
    "# Fine-tuning epochs at which the merged models are evaluated.\n",
    "EPOCHS_LIST = [50, 250]\n",
    "# Score scales used to generate the synthetic fine-tuning data.\n",
    "score_scales_fid = [0.9, 1.1]\n",
    "LR_FT = 1e-5\n",
    "# Number of synthetic samples drawn from the base model for fine-tuning.\n",
    "N_aux = 10**3\n",
    "# Merge weights to sweep; w = 0 corresponds to the base model.\n",
    "ws_sweep = np.linspace(-2.0, 2.0, 101)\n",
    "# Number of reverse steps used both to generate synthetic data and to evaluate FID.\n",
    "n_steps_eval = 20\n",
    "results_fid = []\n",
    "\n",
    "for s_scale in tqdm(score_scales_fid, desc=\"Sweeping s_scale for FID\"):\n",
    "    # Synthetic dataset sampled from the base model with the current score scale.\n",
    "    Xs = sample_diffusion(base_model, nsamples=N_aux, nfeatures=D, n_steps=n_steps_eval, score_scale=s_scale).cpu()\n",
    "    # The auxiliary model starts from the base weights and is fine-tuned on the synthetic data only.\n",
    "    aux_model = copy.deepcopy(base_model).train()\n",
    "    opt_ft = optim.Adam(aux_model.parameters(), lr=LR_FT)\n",
    "    max_epochs = max(EPOCHS_LIST)\n",
    "    for epoch in range(1, max_epochs + 1):\n",
    "        for i in range(0, len(Xs), 2048):\n",
    "            xb = Xs[i : i + 2048].to(device)\n",
    "            t_idx = torch.randint(0, diffusion_steps, (len(xb), 1), device=device)\n",
    "            noised, eps = noise(xb, t_idx, train_sched[\"baralphas\"])\n",
    "            pred = aux_model(noised, t_idx, diffusion_steps)\n",
    "            loss = loss_fn(pred, eps)\n",
    "            opt_ft.zero_grad(); loss.backward(); opt_ft.step()\n",
    "        # At each checkpoint epoch, evaluate every merge weight, then resume fine-tuning.\n",
    "        if epoch in EPOCHS_LIST:\n",
    "            aux_model.eval()\n",
    "            for w in ws_sweep:\n",
    "                # Merged parameters are (1 + w) * base - w * aux.\n",
    "                merged_model = merge_models(base_model, aux_model, w)\n",
    "                fid = fid_of_model(merged_model, nsamples=FID_SAMPLES, n_steps_infer=n_steps_eval)\n",
    "                results_fid.append({'epoch': epoch, 's_scale': s_scale, 'w': w, 'fid': fid})\n",
    "            aux_model.train()\n",
    "\n",
    "# One row per (epoch, s_scale, w) combination.\n",
    "df_fid = pd.DataFrame(results_fid)\n",
    "df_fid.to_csv('fid_vs_w_results.csv', index=False)\n",
    "print(\"Experiment 1 finished. Results saved to 'fid_vs_w_results.csv'\")\n",
    "\n",
    "# ===============================================\n",
    "# === EXPERIMENT 2: Cosine Similarity vs. s_scale ===\n",
    "# ===============================================\n",
    "print(\"\\n--- Running Experiment 2: Cosine Similarity vs. s_scale ---\")\n",
    "\n",
    "def _avg_grad_over_dataset(model, X, batch_size=2048, n_t_samples=16):\n",
    "    \"\"\"Average the denoising-loss gradient of ``model`` over the dataset ``X``.\n",
    "\n",
    "    The dataset is swept ``n_t_samples`` times; every batch gets fresh random\n",
    "    timesteps and noise. Per-batch gradients are summed and divided by the total\n",
    "    number of batches processed (each batch counts equally).\n",
    "\n",
    "    Side effects: switches ``model`` to train mode and leaves the gradients of the\n",
    "    last batch in ``p.grad``.\n",
    "\n",
    "    Returns:\n",
    "        List of tensors, one per parameter, in ``model.parameters()`` order.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    criterion = nn.MSELoss()\n",
    "    params = list(model.parameters())\n",
    "    grad_sums = [torch.zeros_like(p, device=device) for p in params]\n",
    "    n_batches = (len(X) + batch_size - 1) // batch_size\n",
    "    for _ in range(n_t_samples):\n",
    "        for start in range(0, len(X), batch_size):\n",
    "            batch = X[start:start + batch_size].to(device)\n",
    "            t_idx = torch.randint(0, diffusion_steps, (len(batch), 1), device=device)\n",
    "            x_noised, eps = noise(batch, t_idx, train_sched[\"baralphas\"])\n",
    "            model.zero_grad()\n",
    "            loss = criterion(model(x_noised, t_idx, diffusion_steps), eps)\n",
    "            loss.backward()\n",
    "            for total, p in zip(grad_sums, params):\n",
    "                if p.grad is None:\n",
    "                    continue\n",
    "                total.add_(p.grad)\n",
    "    n_grad_evals = n_batches * n_t_samples\n",
    "    for total in grad_sums:\n",
    "        total.div_(n_grad_evals)\n",
    "    return [total.detach() for total in grad_sums]\n",
    "\n",
    "def _estimate_adam_preconditioner_diagonal(model, X, batch_size=2048, n_t_samples=16, betas=(0.9, 0.999), eps=1e-8):\n",
    "    \"\"\"Estimate the diagonal Adam preconditioner ``1 / (sqrt(v_hat) + eps)`` for ``model``.\n",
    "\n",
    "    An exponential moving average of squared gradients (decay ``betas[1]``) is\n",
    "    accumulated over ``n_t_samples`` sweeps of ``X`` with fresh random timesteps\n",
    "    and noise per batch, then bias-corrected by ``1 - betas[1] ** steps``. No\n",
    "    optimizer step is taken, so the model weights do not change.\n",
    "\n",
    "    Side effects: switches ``model`` to train mode and leaves the gradients of the\n",
    "    last batch in ``p.grad``.\n",
    "\n",
    "    Returns:\n",
    "        List of tensors, one per parameter, in ``model.parameters()`` order.\n",
    "    \"\"\"\n",
    "    beta2 = betas[1]\n",
    "    model.train()\n",
    "    criterion = nn.MSELoss()\n",
    "    params = list(model.parameters())\n",
    "    second_moments = [torch.zeros_like(p, device=device) for p in params]\n",
    "    n_updates = 0\n",
    "    for _ in range(n_t_samples):\n",
    "        for start in range(0, len(X), batch_size):\n",
    "            batch = X[start:start + batch_size].to(device)\n",
    "            t_idx = torch.randint(0, diffusion_steps, (len(batch), 1), device=device)\n",
    "            x_noised, eps_t = noise(batch, t_idx, train_sched[\"baralphas\"])\n",
    "            model.zero_grad()\n",
    "            loss = criterion(model(x_noised, t_idx, diffusion_steps), eps_t)\n",
    "            loss.backward()\n",
    "            for moment, p in zip(second_moments, params):\n",
    "                if p.grad is None:\n",
    "                    continue\n",
    "                # v <- beta2 * v + (1 - beta2) * grad^2\n",
    "                moment.mul_(beta2).addcmul_(p.grad, p.grad, value=(1.0 - beta2))\n",
    "            n_updates += 1\n",
    "    bias_correction = 1.0 - (beta2 ** n_updates)\n",
    "    corrected = [moment / bias_correction for moment in second_moments]\n",
    "    return [1.0 / (torch.sqrt(v_hat) + eps) for v_hat in corrected]\n",
    "\n",
    "# --- NEW HELPER FUNCTIONS FOR COSINE SIMILARITY ---\n",
    "def _dot_with_P(g1_parts, g2_parts, P_diag_parts):\n",
    "    s = 0.0\n",
    "    for g1, g2, pd in zip(g1_parts, g2_parts, P_diag_parts):\n",
    "        s += torch.sum(g1 * (pd * g2))\n",
    "    return s.item()\n",
    "\n",
    "def _norm_with_P(g_parts, P_diag_parts):\n",
    "    norm_sq = 0.0\n",
    "    for g, pd in zip(g_parts, P_diag_parts):\n",
    "        norm_sq += torch.sum(g * (pd * g)) # This is <g, P g>\n",
    "    return torch.sqrt(norm_sq).item()\n",
    "\n",
    "# Score scales to sweep, and the sizes of the real and synthetic datasets used for gradient estimates.\n",
    "S_SCALES_align = np.linspace(0.8, 1.2, 46)\n",
    "N_pop = 100_000\n",
    "N_syn = 100_000\n",
    "\n",
    "# Reference quantities from real data: average gradient, Adam preconditioner diagonal, and the P-norm of the gradient.\n",
    "X_pop = torch.from_numpy(np.random.multivariate_normal(mu_ref, sigma_ref, size=N_pop).astype(np.float32))\n",
    "g_d_pop_parts = _avg_grad_over_dataset(base_model, X_pop)\n",
    "P_adam_diag_parts = _estimate_adam_preconditioner_diagonal(base_model, X_pop)\n",
    "norm_d = _norm_with_P(g_d_pop_parts, P_adam_diag_parts)\n",
    "\n",
    "cosine_similarities = []\n",
    "for s_scale in tqdm(S_SCALES_align, desc=\"Sweeping s_scale for Cosine Similarity\"):\n",
    "    # Gradient of the base model on data it generated itself with the current score scale.\n",
    "    X_syn = sample_diffusion(base_model, nsamples=N_syn, nfeatures=D, n_steps=n_steps_eval, score_scale=s_scale)\n",
    "    g_s_parts = _avg_grad_over_dataset(base_model, X_syn)\n",
    "    \n",
    "    # Calculate cosine similarity\n",
    "    # (in the P-weighted inner product; 1e-9 guards against a zero denominator)\n",
    "    inner_product = _dot_with_P(g_d_pop_parts, g_s_parts, P_adam_diag_parts)\n",
    "    norm_s = _norm_with_P(g_s_parts, P_adam_diag_parts)\n",
    "    cos_sim = inner_product / (norm_d * norm_s + 1e-9)\n",
    "    cosine_similarities.append(cos_sim)\n",
    "\n",
    "df_align = pd.DataFrame({'s_scale': S_SCALES_align, 'cosine_similarity': cosine_similarities})\n",
    "df_align.to_csv('cosine_similarity_results.csv', index=False)\n",
    "print(\"Experiment 2 finished. Results saved to 'cosine_similarity_results.csv'\")\n",
    "\n",
    "# ===============================================\n",
    "# ================== PLOTTING ===================\n",
    "# ===============================================\n",
    "print(\"\\n--- Plotting Results ---\")\n",
    "# One figure per evaluated fine-tuning epoch, with one curve per source score scale.\n",
    "for epoch in sorted(df_fid['epoch'].unique()):\n",
    "    plt.figure(figsize=(8, 6))\n",
    "    epoch_df = df_fid[df_fid['epoch'] == epoch]\n",
    "    for s_scale in sorted(epoch_df['s_scale'].unique()):\n",
    "        subset = epoch_df[epoch_df['s_scale'] == s_scale]\n",
    "        # Scales above 1 are labelled mode-seeking, all others diversity-seeking.\n",
    "        label = f\"$\\\\zeta={s_scale:.2f}$ (mode-seeking)\" if s_scale > 1.0 else f\"$\\\\zeta={s_scale:.2f}$ (diversity-seeking)\"\n",
    "        # FID is shown on a log10 scale.\n",
    "        plt.plot(subset['w'], np.log10(subset['fid']), label=label)\n",
    "    plt.grid(True, linestyle='--', alpha=0.6)\n",
    "    plt.title(f\"FID vs. Merge Weight ($w$) after {epoch} FT Epochs\")\n",
    "    plt.xlabel(\"Merge Weight ($w$)\")\n",
    "    plt.ylabel(\"$\\\\log_{10}(\\\\mathrm{FID})$\")\n",
    "    # Vertical marker at w = 0, where the merged model equals the base model.\n",
    "    plt.axvline(0, color='k', linestyle=':', linewidth=1, label='Base Model ($w=0$)')\n",
    "    plt.legend(title=\"Source Sampler\")\n",
    "    plt.show()\n",
    "\n",
    "# Cosine similarity between real-data and synthetic-data gradients as a function of the score scale.\n",
    "plt.figure(figsize=(8, 6))\n",
    "plt.plot(df_align['s_scale'], df_align['cosine_similarity'], marker='o', linestyle='-', markersize=4)\n",
    "plt.axhline(0.0, color='k', linestyle=\"--\", linewidth=1)\n",
    "plt.xlabel(\"Sampler Score Scale ($\\\\zeta$)\")\n",
    "plt.ylabel(\"Cosine Similarity $\\\\cos(\\\\theta)$\")\n",
    "plt.title(\"Gradient Alignment vs. Sampler Type\")\n",
    "plt.grid(True, linestyle='--', alpha=0.6)\n",
    "plt.ylim(-1.1, 1.1)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.23"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
