{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66bc658e-c5a8-4d9c-b8c5-629a9742e300",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import argparse\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "\n",
    "# tf32 data type is faster than standard float32\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "\n",
    "# custom models and functions #\n",
    "import utils\n",
    "from models import Clipper, BrainNetwork, BrainDiffusionPrior, BrainDiffusionPriorOld, VersatileDiffusionPriorNetwork"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e364c51-071f-4019-8fb1-b8b6683d4e97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the subject-1 test-split image-id file should exist before it is loaded in the next cell.\n",
    "!ls data/new_dl/subj01/test/image_ids.npy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f0150f9-bd62-41d6-aa8c-60e9ba4d879c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image ids for the subject-1 test split; row i is paired with fmris[i] by position further down.\n",
    "subj1_image_ids = np.load('data/new_dl/subj01/test/image_ids.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "781157b8-30cb-43ac-9328-02105a4063df",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#subj1_id_to_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7a0ff01-54e4-43d2-a54d-67ec450e0b86",
   "metadata": {},
   "outputs": [],
   "source": [
    "subj = 1\n",
    "\n",
    "# Test-split file locations for this subject (only the betas are loaded in this cell).\n",
    "feature_file = f'data/new_dl/subj0{subj}/test/betas.pt'\n",
    "image_file = f'data/new_dl/subj0{subj}/test/images.pt'\n",
    "\n",
    "# Load the fMRI betas onto the CPU; individual samples are moved to `device` when embedded.\n",
    "fmris = torch.load(feature_file, map_location=torch.device('cpu'))\n",
    "fmris.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07c6333e-dda2-41f0-b844-1dc597fe4251",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc517c18-322c-46da-98a1-74b48ff43649",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adapter shapes; these must match the 'ridge.*' weights in the checkpoint loaded below.\n",
    "hidden_dim = 4_096  # adapter output width\n",
    "num_voxels_list = [15_724]  # input (voxel) width of each per-subject adapter branch\n",
    "\n",
    "class RidgeRegression(torch.nn.Module):\n",
    "    \"\"\"Per-subject voxel adapter: one Linear -> LayerNorm -> GELU branch per entry of input_sizes.\n",
    "\n",
    "    Despite the name, the module itself is only this feed-forward stack; any ridge-style\n",
    "    penalty would have to live in the training code. forward(x, subj_idx) runs the branch\n",
    "    at index subj_idx, mapping an input of width input_sizes[subj_idx] to out_features.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_sizes, out_features):\n",
    "        super().__init__()\n",
    "        self.out_features = out_features\n",
    "        branches = []\n",
    "        for input_size in input_sizes:\n",
    "            # Keep the Sequential layout (0 = Linear, 1 = LayerNorm, 2 = GELU): it fixes the\n",
    "            # state-dict keys such as 'linears.0.0.weight' that checkpoints are loaded against.\n",
    "            branches.append(nn.Sequential(\n",
    "                torch.nn.Linear(input_size, out_features),\n",
    "                nn.LayerNorm(out_features),\n",
    "                nn.GELU(),\n",
    "            ))\n",
    "        self.linears = torch.nn.ModuleList(branches)\n",
    "\n",
    "    def forward(self, x, subj_idx):\n",
    "        \"\"\"Apply the adapter branch at index subj_idx to x.\"\"\"\n",
    "        return self.linears[subj_idx](x)\n",
    "\n",
    "# Run on the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "pretrained_adapter = RidgeRegression(num_voxels_list, out_features=hidden_dim).to(device)\n",
    "\n",
    "reference_subj = 1  # 1-based subject number; its adapter branch is linears[reference_subj - 1]\n",
    "\n",
    "checkpoint_name = 'subj1_nl_sclip_basictest'\n",
    "checkpoint = torch.load(f\"train_logs/{checkpoint_name}/mid_200.pth\", map_location='cpu')\n",
    "# Keep only the adapter weights (keys prefixed 'ridge.') and strip just that leading prefix\n",
    "# so the keys match this module's own state dict (e.g. 'linears.0.0.weight').\n",
    "ridge_state_dict = {k[len('ridge.'):]: v for k, v in checkpoint['model_state_dict'].items() if k.startswith('ridge.')}\n",
    "pretrained_adapter.load_state_dict(ridge_state_dict, strict=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e328112e-162c-46dd-93ca-c9f7c50f2694",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect which adapter parameters were extracted from the checkpoint.\n",
    "ridge_state_dict.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f3e866f-78e1-4656-b623-b64a38a5e469",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the adapter architecture.\n",
    "pretrained_adapter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4afcca10-70a2-4f44-a773-66f50ee61410",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-check the shape of the loaded betas before embedding them.\n",
    "fmris.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb7c73be-d946-43d5-8d69-d86e40315dde",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Embed every test sample with the reference subject's adapter, then take the SVD of that\n",
    "# same adapter's Linear weight to find its dominant output directions.\n",
    "adapter_idx = reference_subj - 1  # one index for both the forward pass and the SVD\n",
    "embs = []  # entries: [embedding, image id, averaged voxel vector]\n",
    "pretrained_adapter.eval()\n",
    "with torch.no_grad():\n",
    "    for i, voxel in enumerate(fmris):\n",
    "        voxel = voxel.to(device)\n",
    "        # Average over the first axis of each sample (presumably repeated trials - TODO confirm).\n",
    "        voxel = torch.mean(voxel, dim=0).float()\n",
    "        embs.append([pretrained_adapter(voxel, adapter_idx), subj1_image_ids[i], voxel])\n",
    "\n",
    "# nn.Linear weight, shape [out_features, num_voxels].\n",
    "A = pretrained_adapter.state_dict()[f'linears.{adapter_idx}.0.weight']\n",
    "\n",
    "U, S, V = torch.svd(A)  # A = U diag(S) V^T; singular vectors are the COLUMNS of U and V\n",
    "\n",
    "print(\"U shape:\", U.shape)\n",
    "print(\"S shape:\", S.shape)\n",
    "print(\"V shape:\", V.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c592d59-6f20-42b6-9054-253439547a92",
   "metadata": {},
   "outputs": [],
   "source": [
    "evs = 20  # number of leading singular directions that get binned"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c01adb7-3789-409c-832b-07fb03bdc649",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project every test embedding onto the leading singular directions of the adapter weight\n",
    "# and bin each direction, giving more bins to directions with larger singular values.\n",
    "bin_ = []  # number of bins allotted to each direction\n",
    "embeddings = [emb[0] for emb in embs]\n",
    "embeddings_tensor = torch.stack(embeddings)\n",
    "print(embeddings_tensor.shape)\n",
    "\n",
    "# torch.svd returns A = U diag(S) V^T, so the singular vectors live in the COLUMNS of U.\n",
    "# Take the first `evs` columns and transpose so each row of `eigenvectors` is one direction.\n",
    "eigenvectors = U[:, :evs].T  # [evs, hidden_dim]\n",
    "eigenvalues = S[:evs]  # singular values, returned in descending order by torch.svd\n",
    "projections = torch.matmul(embeddings_tensor, eigenvectors.T)  # [num_samples, evs]\n",
    "\n",
    "# Use robust per-direction ranges (2% / 98% quantiles) so outliers do not stretch the bins.\n",
    "low_q = 0.02   # 2% quantile\n",
    "high_q = 1 - low_q  # 98% quantile\n",
    "\n",
    "min_proj = torch.quantile(projections, low_q, dim=0)   # shape [evs]\n",
    "max_proj = torch.quantile(projections, high_q, dim=0)  # shape [evs]\n",
    "\n",
    "lambda_1 = eigenvalues[0]  # largest singular value; bin counts are scaled relative to it\n",
    "\n",
    "w = 50.0  # bin budget for the leading direction; direction j gets about w * lambda_j / lambda_1\n",
    "\n",
    "bin_index_list = []  # per-direction bin assignment of every sample\n",
    "\n",
    "\n",
    "def calculate_bins(w, lambda_j, lambda_1):\n",
    "    \"\"\"Bin count for a direction: w scaled by lambda_j / lambda_1, truncated toward zero.\"\"\"\n",
    "    scaled = w * lambda_j  # multiply before dividing (same evaluation order as w * lambda_j / lambda_1)\n",
    "    return int(scaled / lambda_1)\n",
    "# Bin budget per direction, proportional to its singular value relative to the largest one.\n",
    "bin_.extend(calculate_bins(w, eigenvalues[j], lambda_1) for j in range(evs))\n",
    "\n",
    "for j in range(evs):\n",
    "    n_bins = max(1, int(bin_[j]))  # always at least one bucket\n",
    "    lo, hi = min_proj[j].item(), max_proj[j].item()\n",
    "    # n_bins - 1 edges spread over [lo, hi] yield bucket labels 0 .. n_bins - 1:\n",
    "    # values at or below lo get label 0, values above the last edge get label n_bins - 1.\n",
    "    edges = torch.linspace(lo, hi, n_bins - 1).to(device)\n",
    "    assignment = torch.bucketize(projections[:, j], edges, right=False)\n",
    "    unique_bins = torch.unique(assignment)\n",
    "    print(f\"Dimension {j}: {len(unique_bins)} bins occupied out of {bin_[j]}\")\n",
    "    bin_index_list.append(assignment)\n",
    "\n",
    "# One row per sample, one column per direction: the bin each sample falls into.\n",
    "stacked_tensor = torch.stack(bin_index_list, dim=1)\n",
    "print(stacked_tensor.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6779aa1-3b97-41ee-9cb0-445a8b880bc6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e0c20c5-6f0d-4664-ae46-2b79969fb59c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Greedy set cover: repeatedly pick the test sample whose bin assignments hit the most\n",
    "# still-uncovered (direction, bin) cells, until every bin is covered or no sample adds coverage.\n",
    "uncovered_bins = [set(range(bin_[i])) for i in range(evs)]\n",
    "\n",
    "# Copy the bin assignments to host memory once: calling .item() on a CUDA tensor inside the\n",
    "# nested loops below would force a device synchronisation for every single element.\n",
    "bin_rows = stacked_tensor.cpu().tolist()  # bin_rows[idx][i] = bin of sample idx along direction i\n",
    "\n",
    "selected_embeddings = []        # image ids of the chosen samples, in selection order\n",
    "best_idxss_ = []                # row positions of the chosen samples, in selection order\n",
    "total_empty_bins_sequence = []  # number of still-uncovered bins after each pick\n",
    "empty_bin_diff = []             # bins newly covered by each pick after the first\n",
    "\n",
    "while any(uncovered_bins):\n",
    "    best_idxss = None\n",
    "    best_coverage = 0\n",
    "\n",
    "    for idx, row in enumerate(bin_rows):\n",
    "        # Number of directions where this sample lands in a bin that is still uncovered.\n",
    "        new_coverage = sum(b in bins for bins, b in zip(uncovered_bins, row))\n",
    "        # Strict '>' keeps the earliest sample when several tie.\n",
    "        if new_coverage > best_coverage:\n",
    "            best_idxss = idx\n",
    "            best_coverage = new_coverage\n",
    "\n",
    "    if best_idxss is None:\n",
    "        print(\"No further embeddings can cover new bins.\")\n",
    "        break\n",
    "\n",
    "    selected_embeddings.append(embs[best_idxss][1])\n",
    "    best_idxss_.append(best_idxss)\n",
    "\n",
    "    # Mark the chosen sample's bins as covered.\n",
    "    for bins, b in zip(uncovered_bins, bin_rows[best_idxss]):\n",
    "        bins.discard(b)\n",
    "\n",
    "    total_empty_bins = sum(len(bins) for bins in uncovered_bins)\n",
    "    if total_empty_bins_sequence:\n",
    "        empty_bin_diff.append(total_empty_bins_sequence[-1] - total_empty_bins)\n",
    "    total_empty_bins_sequence.append(total_empty_bins)\n",
    "\n",
    "print(\"Total empty bins at each iteration:\", total_empty_bins_sequence)\n",
    "\n",
    "# selected_embeddings holds the image ids (not row indices) of the greedily chosen samples;\n",
    "# together they cover every bin that at least one sample falls into.\n",
    "print(f\"Selected embeddings: {np.sort(selected_embeddings)}, {len(selected_embeddings)}\")\n",
    "print(f\"Number of selected embeddings: {len(selected_embeddings)}\")\n",
    "print(f\"Number of differences: {empty_bin_diff}!\")\n",
    "d1 = selected_embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0634369c-e317-4699-96df-0f1ed4551367",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each test-split image id to its row position (the last occurrence wins if an id repeats),\n",
    "# then translate the selected image ids back into row indices.\n",
    "subj1_id_to_index = {img_id: idx for idx, img_id in enumerate(subj1_image_ids)}\n",
    "final_index = np.array([subj1_id_to_index[img_id] for img_id in selected_embeddings])\n",
    "final_index.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51f945a7-85f9-4288-a016-36361d3d6e73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the selected image ids; create the output directory first so np.save cannot hit a missing path.\n",
    "os.makedirs(f'indices/{checkpoint_name}', exist_ok=True)\n",
    "np.save(f'indices/{checkpoint_name}/subj1_global_indices_{evs}d_{int(w)}w_{final_index.shape[0]}.npy',selected_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07e94ce2-1443-4d53-b68c-550d20b2720f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the selected row indices; create the output directory first so np.save cannot hit a missing path.\n",
    "os.makedirs(f'indices/{checkpoint_name}', exist_ok=True)\n",
    "np.save(f'indices/{checkpoint_name}/subj1_indices_{evs}d_{int(w)}w_{final_index.shape[0]}.npy',final_index)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
