{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "629cadd1-a606-43a2-8b2f-212c5d584445",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[32m\u001b[41mERROR\u001b[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mtaoprajjwal\u001b[0m (\u001b[33mpj-runs\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import wandb\n",
    "# Log in to Weights & Biases up front; later cells call wandb.init() / wandb.log().\n",
    "wandb.login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "622f88ee-e4a0-4632-8687-d999093c43a6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x1554677c8c10>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Seed PyTorch's global RNG so the random tensors created below are reproducible.\n",
    "# torch is imported here because this cell runs before the main import cell.\n",
    "torch.manual_seed(69)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "5bec217c-b0c0-4b13-bdbb-41da31e2b93e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "import math\n",
    "# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "def generate_low_inner_product_vectors(n, epsilon, device='cpu', seed=None):\n",
    "    \"\"\"\n",
    "    Sample random sign vectors and report how close to orthogonal they are.\n",
    "\n",
    "    Draws k = int(exp(epsilon**2 * n / 4)) vectors whose entries are +1/sqrt(n) or\n",
    "    -1/sqrt(n) (so every vector has unit norm), then prints the largest absolute\n",
    "    inner product between distinct vectors and whether all of them are below epsilon.\n",
    "\n",
    "    Args:\n",
    "        n (int): Dimension of each vector.\n",
    "        epsilon (float): Inner-product threshold; it also determines k.\n",
    "        device (str or torch.device): Device on which the vectors are created.\n",
    "        seed (int, optional): If given, torch's global RNG is reseeded with it first.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: (k, n) float32 tensor of unit-norm sign vectors.\n",
    "    \"\"\"\n",
    "    if seed is not None:\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "    k = int(math.exp((epsilon**2 * n) / 4))\n",
    "\n",
    "    signs = torch.randint(0, 2, (k, n), device=device, dtype=torch.float32) * 2 - 1  # ±1\n",
    "    vectors = signs / math.sqrt(n)  # Normalize each to have norm 1\n",
    "\n",
    "    # Zero the diagonal in place (no extra k x k mask) and take the absolute value once.\n",
    "    abs_ip = torch.matmul(vectors, vectors.T).fill_diagonal_(0).abs_()\n",
    "\n",
    "    max_ip = abs_ip.max().item()\n",
    "    success = (abs_ip < epsilon).all().item()\n",
    "\n",
    "    print(f\"Generated {k} vectors in {n} dimensions.\")\n",
    "    print(f\"Maximum off-diagonal inner product: {max_ip:.4f}\")\n",
    "    print(\"All inner products < ε:\", success)\n",
    "\n",
    "    return vectors\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "6601bab3-b428-47ac-8b57-0528b50b1b5c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated 22026 vectors in 1000 dimensions.\n",
      "Maximum off-diagonal inner product: 0.1840\n",
      "All inner products < ε: True\n"
     ]
    }
   ],
   "source": [
    "# With n=1000 and epsilon=0.2 this gives k = int(exp(0.2**2 * 1000 / 4)) = 22026 vectors, created on the default 'cpu' device.\n",
    "vecs = generate_low_inner_product_vectors(1000, 0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "97e3b912-e3ea-47c9-b790-9b72bc7affac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random Gaussian 1000 x 500 matrix; a later cell multiplies `vecs` by it to build `stud_vecs`.\n",
    "proj = torch.randn(1000, 500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "ed75a32a-db40-49ef-a352-3d5b9edcb8cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Randomly initialised 500-dim vectors (22026 rows, matching the size of `vecs`), scaled to unit norm.\n",
    "stud_vec_2 = torch.randn(22026, 500)\n",
    "stud_vec_2 = stud_vec_2 / stud_vec_2.norm(dim=1, keepdim=True)\n",
    "# Keep an independent snapshot of the starting point. `.data` would alias the same storage,\n",
    "# so in-place optimizer updates to `stud_vec_2` would silently change the snapshot as well.\n",
    "stud_vec_2_init = stud_vec_2.detach().clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "d1b00409-0bc8-47e8-a16c-78db56d3add6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable gradients in place so `stud_vec_2` can be handed to an optimizer as a parameter.\n",
    "stud_vec_2 = stud_vec_2.requires_grad_(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "b8e86380-736b-4549-9673-68d06f17162c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map the 1000-dim `vecs` to 500 dims with `proj`, then rescale every row to unit norm.\n",
    "stud_vecs = vecs @ proj\n",
    "stud_vecs = stud_vecs / stud_vecs.norm(dim=1, keepdim=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "300734db-2eae-4b7c-99d1-ad39c0f861d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# requires_grad_ works in place and returns the same tensor, so `student_vecs` and `stud_vecs` are one object.\n",
    "student_vecs = stud_vecs.requires_grad_(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "95a178dc-5d8a-4cb4-80b7-2a3f0042b9f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def luby_mis_from_dense(adj_dense, max_iters=1000, device=device):\n",
    "    \"\"\"\n",
    "    Compute a maximal independent set with Luby's randomized algorithm.\n",
    "\n",
    "    In every round each remaining node draws a random priority and joins the set\n",
    "    when its priority is strictly larger than that of all its remaining neighbors;\n",
    "    the winners and their neighbors are then removed from further rounds.\n",
    "\n",
    "    Args:\n",
    "        adj_dense (torch.Tensor): Dense n x n adjacency matrix; non-zero entries are edges.\n",
    "            Treated as undirected (symmetric); self-loops are ignored.\n",
    "        max_iters (int): Maximum number of rounds. If the limit is reached while nodes\n",
    "            are still undecided, the result is independent but may not be maximal.\n",
    "        device (str or torch.device): Device on which the computation runs.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: 1-D tensor holding the indices of the nodes in the set.\n",
    "    \"\"\"\n",
    "    # Boolean copy on the target device; the caller's matrix is never modified.\n",
    "    adj = adj_dense.to(device) != 0\n",
    "    n = adj.size(0)\n",
    "    if n > 0:\n",
    "        # A self-loop would make a node compete against its own priority forever.\n",
    "        adj.fill_diagonal_(False)\n",
    "\n",
    "    in_set = torch.zeros(n, dtype=torch.bool, device=device)\n",
    "    remaining = torch.ones(n, dtype=torch.bool, device=device)\n",
    "    neg_inf = torch.tensor(float(\"-inf\"), device=device)\n",
    "    row_chunk = 2048  # rows handled at once; bounds the temporary (row_chunk x n) float buffer\n",
    "\n",
    "    for _ in range(max_iters):\n",
    "        if not remaining.any():\n",
    "            break\n",
    "\n",
    "        priorities = torch.rand(n, device=device)\n",
    "        priorities[~remaining] = float(\"-inf\")  # removed nodes can never win a comparison\n",
    "\n",
    "        # True maximum over every node's neighbors. A sparse matrix-vector product would\n",
    "        # return the SUM of the neighbor priorities instead, which allows two adjacent\n",
    "        # nodes to be selected in the same round.\n",
    "        neighbor_max = torch.empty(n, device=device)\n",
    "        for start in range(0, n, row_chunk):\n",
    "            stop = start + row_chunk\n",
    "            masked = torch.where(adj[start:stop], priorities.unsqueeze(0), neg_inf)\n",
    "            neighbor_max[start:stop] = masked.max(dim=1).values\n",
    "\n",
    "        selected = (priorities > neighbor_max) & remaining\n",
    "        in_set[selected] = True\n",
    "\n",
    "        # Retire the winners together with everything adjacent to them.\n",
    "        neighbors_of_selected = adj[selected].any(dim=0)\n",
    "        remaining[selected | neighbors_of_selected] = False\n",
    "\n",
    "    return torch.nonzero(in_set, as_tuple=False).squeeze(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "adff029d-2978-41ed-b95e-66f47dc0dcf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_adjc(vec, eps):\n",
    "    \"\"\"\n",
    "    Build the eps-near-orthogonality indicator matrix for the rows of `vec`.\n",
    "\n",
    "    Args:\n",
    "        vec (torch.Tensor): (m, d) matrix whose rows are the vectors to compare.\n",
    "        eps (float): Threshold on the absolute cosine similarity between two rows.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: (m, m) float matrix on the device of `vec`, with 1.0 where\n",
    "        |cos(v_i, v_j)| < eps and 0.0 elsewhere. Diagonal entries are compared as 0,\n",
    "        so they are 1.0 whenever eps > 0.\n",
    "    \"\"\"\n",
    "    # The thresholded result is not differentiable, so skip autograd bookkeeping.\n",
    "    with torch.no_grad():\n",
    "        # Compare unit-normalised rows so that eps bounds the cosine regardless of row norms.\n",
    "        # clamp_min guards all-zero rows against a division by zero.\n",
    "        norm_vec = vec / vec.norm(dim=1, keepdim=True).clamp_min(1e-12)\n",
    "        ip_matrix = norm_vec @ norm_vec.T\n",
    "        # In-place diagonal reset; works on whatever device `vec` lives on.\n",
    "        ip_matrix.fill_diagonal_(0)\n",
    "        return (ip_matrix.abs() < eps).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "d7d7671b-3aa3-42cf-b9be-d6ad81574452",
   "metadata": {},
   "outputs": [],
   "source": [
    "import similarity_measures as sim\n",
    "import similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "0a21e3eb-5dd8-4e76-893c-b6ba30f993fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Measure objects from the external `similarity_measures` / `similarity` packages, reused by the training cells below.\n",
    "# NOTE(review): their exact definitions are not visible in this notebook -- confirm against those packages.\n",
    "lin = sim.LinearMeasure(approx=True)\n",
    "cka = sim.CKA(biased=False)\n",
    "measure = similarity.make(\"measure/netrep/procrustes-distance=euclidean\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "52bd1459-33b4-47dd-95eb-e6c3df76993f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "def copy_and_clone(stud_vec):\n",
    "    \"\"\"\n",
    "    Make an independent, trainable copy of `stud_vec` and pair it with the global `vecs`.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (the trainable copy, a TensorDataset over `vecs` and that copy).\n",
    "    \"\"\"\n",
    "    trainable = stud_vec.detach().clone().requires_grad_(True)\n",
    "    paired = TensorDataset(vecs, trainable)\n",
    "    return trainable, paired\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9da7a467-ee6f-4b8a-845b-1335cc13c603",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_kernel_frobenius(x_vec, y_vec):\n",
    "    \"\"\"Return the Frobenius norm of the difference between the Gram matrices of `x_vec` and `y_vec`.\"\"\"\n",
    "    gram_x = torch.matmul(x_vec, x_vec.T)\n",
    "    gram_y = torch.matmul(y_vec, y_vec.T)\n",
    "    gram_gap = gram_x - gram_y\n",
    "    return torch.norm(gram_gap, p=\"fro\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "649512a4-a91d-40a7-91f1-1120c25473fa",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Mini-batch run: optimise a trainable copy of `stud_vecs` on the loss 1 - cka(...) against `vecs`.\n",
    "stud_vec_2_cka, dataset_cka = copy_and_clone(stud_vecs)\n",
    "dl = DataLoader(dataset_cka, batch_size=256, shuffle=False)\n",
    "optim = torch.optim.AdamW([stud_vec_2_cka])\n",
    "wandb.init(project=\"\", entity=\"\", name=\"cka_minim_projected_stud\")\n",
    "for epoch in range(10):\n",
    "    for batch in tqdm(dl, desc=f\"Epoch {epoch+1}\"):\n",
    "        a,b = batch  # a: rows of `vecs`, b: matching rows of the trainable copy\n",
    "        loss = 1- (cka(b.unsqueeze(1), a.unsqueeze(1)))\n",
    "        shape = lin(b.unsqueeze(1), a.unsqueeze(1))\n",
    "        k_frob = get_kernel_frobenius(a,b)\n",
    "        # Size of an independent set in the complement graph, computed over the full trainable set.\n",
    "        adj = get_adjc(stud_vec_2_cka, 0.2)\n",
    "        s = luby_mis_from_dense(1-adj)\n",
    "        # Log plain Python numbers for every metric instead of a graph-attached tensor.\n",
    "        wandb.log( {\"shape\": shape.item(), \"cka\":loss.item(), \"eps-orth-vectors\": len(s), \"kernel_frob\":k_frob.item()})\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        optim.zero_grad()\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8f64cbf",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Mini-batch run: optimise a trainable copy of `stud_vecs` on the `lin` measure against `vecs`.\n",
    "stud_vec_2_shape, dataset_shape = copy_and_clone(stud_vecs)\n",
    "dl = DataLoader(dataset_shape, batch_size=256, shuffle=False)\n",
    "optim = torch.optim.AdamW([stud_vec_2_shape])\n",
    "wandb.init(project=\"\", entity=\"\", name=\"shape_minim_projected_stud\")\n",
    "for epoch in range(10):\n",
    "    for batch in tqdm(dl, desc=f\"Epoch {epoch+1}\"):\n",
    "        a,b = batch  # a: rows of `vecs`, b: matching rows of the trainable copy\n",
    "        ck = 1- (cka(b.unsqueeze(1), a.unsqueeze(1)))\n",
    "        loss = lin(b.unsqueeze(1),  a.unsqueeze(1))\n",
    "        # Size of an independent set in the complement graph, computed over the full trainable set.\n",
    "        adj = get_adjc(stud_vec_2_shape, 0.2)\n",
    "        k_frob = get_kernel_frobenius(a,b)\n",
    "        s = luby_mis_from_dense(1-adj)\n",
    "        # Log plain Python numbers for every metric instead of a graph-attached tensor.\n",
    "        wandb.log( {\"shape\": loss.item(), \"cka\":ck.item(), \"eps-orth-vectors\": len(s) ,\"kernel_frob\":k_frob.item()})\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        optim.zero_grad()\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52bfabe5-e59b-4f38-8f4a-a0b8ff5d02e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batch run: the loss is the Frobenius norm of the Gram-matrix difference between each batch pair.\n",
    "stud_vec_2_kfrob, dataset_kfrob = copy_and_clone(stud_vecs)\n",
    "dl = DataLoader(dataset_kfrob, batch_size=256, shuffle=False)\n",
    "optim = torch.optim.AdamW([stud_vec_2_kfrob])\n",
    "wandb.init(project=\"\", entity=\"\", name=\"kernel_frobenius_diff_projected_stud\")\n",
    "for epoch in range(10):\n",
    "    for batch in tqdm(dl, desc=f\"Epoch {epoch+1}\"):\n",
    "        a, b = batch\n",
    "        b_in, a_in = b.unsqueeze(1), a.unsqueeze(1)\n",
    "        ck = 1 - cka(b_in, a_in)\n",
    "        shape = lin(b_in, a_in)\n",
    "        adj = get_adjc(stud_vec_2_kfrob, 0.2)\n",
    "        loss = get_kernel_frobenius(a, b)\n",
    "        s = luby_mis_from_dense(1 - adj)\n",
    "        # Same numbers go to stdout and to W&B.\n",
    "        metrics = {\"shape\": shape.item(), \"cka\": ck.item(), \"eps-orth-vectors\": len(s), \"kernel_frob\": loss.item()}\n",
    "        print(metrics)\n",
    "        wandb.log(metrics)\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        optim.zero_grad()\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f9f6912-3fe9-4b4e-827c-e8f4ff5e5dda",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batch run that also learns a 1000 x 500 matrix; the loss is the norm of a @ linear_transform - b.\n",
    "stud_vec_2_linear, dataset_linear = copy_and_clone(stud_vecs)\n",
    "dl = DataLoader(dataset_linear, batch_size=256, shuffle=False)\n",
    "linear_transform = torch.rand(1000,500).requires_grad_(True)\n",
    "optim = torch.optim.AdamW([stud_vec_2_linear,linear_transform])\n",
    "wandb.init(project=\"\", entity=\"\", name=\"linear_projection_diff_projected_stud\")\n",
    "for epoch in range(10):\n",
    "    for batch in tqdm(dl, desc=f\"Epoch {epoch+1}\"):\n",
    "        a,b = batch  # a: rows of `vecs`, b: matching rows of the trainable copy\n",
    "        ck = 1- (cka(b.unsqueeze(1), a.unsqueeze(1)))\n",
    "        shape = lin(b.unsqueeze(1),  a.unsqueeze(1))\n",
    "        adj = get_adjc(stud_vec_2_linear, 0.2)\n",
    "        kernel_frob = get_kernel_frobenius(a,b) \n",
    "        loss = torch.norm(a@linear_transform - b)\n",
    "        s = luby_mis_from_dense(1-adj)\n",
    "        # Convert every metric to a plain number once, so stdout and W&B receive scalars\n",
    "        # rather than graph-attached tensors.\n",
    "        metrics = {\"shape\": shape.item(), \"cka\":ck.item(), \"eps-orth-vectors\": len(s), \"kernel_frob\": kernel_frob.item(), \"loss\": loss.item()}\n",
    "        print(metrics)\n",
    "        wandb.log(metrics)\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        optim.zero_grad()\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42d74e13-1a83-4d62-9127-f8b427a3dce3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full-batch run: 100 optimizer steps on 1 - cka(...) using the whole dataset at once.\n",
    "stud_vec_2_cka, dataset_cka = copy_and_clone(stud_vecs)\n",
    "dl = DataLoader(dataset_cka, batch_size=256, shuffle=False)\n",
    "optim = torch.optim.AdamW([stud_vec_2_cka])\n",
    "wandb.init(project=\"\", entity=\"\", name=\"cka_minim_projected_stud\")\n",
    "for epoch in range(100):\n",
    "    a,b = dataset_cka[:]  # every row: a from `vecs`, b from the trainable copy\n",
    "    loss = 1- (cka(b.unsqueeze(1), a.unsqueeze(1)))\n",
    "    shape = lin(b.unsqueeze(1), a.unsqueeze(1))\n",
    "    k_frob = get_kernel_frobenius(a,b)\n",
    "    adj = get_adjc(stud_vec_2_cka, 0.2)\n",
    "    s = luby_mis_from_dense(1-adj)\n",
    "    # Log plain Python numbers for every metric instead of a graph-attached tensor.\n",
    "    wandb.log( {\"shape\": shape.item(), \"cka\":loss.item(), \"eps-orth-vectors\": len(s), \"kernel_frob\":k_frob.item()})\n",
    "    loss.backward()\n",
    "    optim.step()\n",
    "    optim.zero_grad()\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cd53cd2-6fce-430e-ab4d-bcceccf35ef0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full-batch run: 100 optimizer steps on the `lin` measure using the whole dataset at once.\n",
    "stud_vec_2_shap, dataset_shap = copy_and_clone(stud_vecs)\n",
    "dl = DataLoader(dataset_shap, batch_size=256, shuffle=False)\n",
    "optim = torch.optim.AdamW([stud_vec_2_shap])\n",
    "wandb.init(project=\"\", entity=\"\", name=\"shape_minim_projected_stud\")\n",
    "for epoch in range(100):\n",
    "    a,b = dataset_shap[:]  # every row: a from `vecs`, b from the trainable copy\n",
    "    ck = 1- (cka(b.unsqueeze(1), a.unsqueeze(1)))\n",
    "    loss = lin(b.unsqueeze(1), a.unsqueeze(1))\n",
    "    k_frob = get_kernel_frobenius(a,b)\n",
    "    adj = get_adjc(stud_vec_2_shap, 0.2)\n",
    "    s = luby_mis_from_dense(1-adj)\n",
    "    # Log plain Python numbers for every metric instead of a graph-attached tensor.\n",
    "    wandb.log( {\"shape\": loss.item(), \"cka\":ck.item(), \"eps-orth-vectors\": len(s), \"kernel_frob\":k_frob.item()})\n",
    "    loss.backward()\n",
    "    optim.step()\n",
    "    optim.zero_grad()\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17c148a7-5eb9-4593-a294-784374f947e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full-batch run: 100 optimizer steps on the Gram-matrix Frobenius difference over the whole dataset.\n",
    "stud_vec_2_kfrob, dataset_kfrob = copy_and_clone(stud_vecs)\n",
    "dl = DataLoader(dataset_kfrob, batch_size=256, shuffle=False)\n",
    "optim = torch.optim.AdamW([stud_vec_2_kfrob])\n",
    "wandb.init(project=\"\", entity=\"\", name=\"kfrob_minim_projected_stud\")\n",
    "for epoch in range(100):\n",
    "    a,b = dataset_kfrob[:]  # every row: a from `vecs`, b from the trainable copy\n",
    "    ck = 1- (cka(b.unsqueeze(1), a.unsqueeze(1)))\n",
    "    shape = lin(b.unsqueeze(1), a.unsqueeze(1))\n",
    "    loss = get_kernel_frobenius(a,b)\n",
    "    adj = get_adjc(stud_vec_2_kfrob, 0.2)\n",
    "    s = luby_mis_from_dense(1-adj)\n",
    "    # Log `shape` as a plain number, like the other metrics and the sibling cells.\n",
    "    wandb.log( {\"shape\": shape.item(), \"cka\":ck.item(), \"eps-orth-vectors\": len(s), \"kernel_frob\":loss.item()})\n",
    "    loss.backward()\n",
    "    optim.step()\n",
    "    optim.zero_grad()\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d07df1b-d649-450c-be2c-d706d6bc1597",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the randomly initialised `stud_vec_2` against `vecs` with the loss 1 - cka(...).\n",
    "# The loader is built here over (vecs, stud_vec_2): a `dl` left over from an earlier cell\n",
    "# iterates a different tensor, so `stud_vec_2` would never receive gradients from it.\n",
    "dataset_rand = TensorDataset(vecs, stud_vec_2)\n",
    "dl_rand = DataLoader(dataset_rand, batch_size=256, shuffle=False)\n",
    "optim = torch.optim.AdamW([stud_vec_2])\n",
    "wandb.init(project=\"\", entity=\"\", name=\"cka_minim\")\n",
    "for epoch in range(20):\n",
    "    for batch in dl_rand:\n",
    "        a,b = batch  # a: rows of `vecs`, b: matching rows of `stud_vec_2`\n",
    "        loss = 1- (cka(b.unsqueeze(1), a.unsqueeze(1)))\n",
    "        # `measure` is fed NumPy arrays, so the batch is detached from the graph first.\n",
    "        shape = measure(b.detach().numpy(), a.detach().numpy())\n",
    "        wandb.log( {\"shape\": shape, \"cka\":loss.item()})\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        optim.zero_grad()\n",
    "# Close the run, as every other training cell does.\n",
    "wandb.finish()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "GPU",
   "language": "python",
   "name": "gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
