{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6b476d82-6e6b-4267-b9b1-88eef5814dcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import math\n",
    "import numpy as np\n",
    "import pickle\n",
    "import random\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.nn.parameter import Parameter\n",
    "import math\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from math import log, e\n",
    "import torch.optim as optim\n",
    "import pickle\n",
    "import random\n",
    "import torch.autograd as autograd\n",
    "\n",
    "torch.manual_seed(11)\n",
    "np.random.seed(11)\n",
    "random.seed(11)\n",
    "\n",
    "plt.rcParams.update({'font.size': 13})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2151c8a1-f3c9-4814-b68c-d696c0a8dc9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvNet(nn.Module):\n",
    "    def __init__(self, input_dim, out_channel, patch_num, small=True, nonlinear=True):\n",
    "        super(ConvNet, self).__init__()\n",
    "        self.conv1 = nn.Conv1d(1, out_channel*2, int(input_dim/patch_num), int(input_dim/patch_num))        \n",
    "        # small initialization\n",
    "        if small:\n",
    "            self.conv1.weight = torch.nn.Parameter(self.conv1.weight*0.001) \n",
    "            self.conv1.bias = torch.nn.Parameter(self.conv1.bias*0.001) \n",
    "        self.out_channel = out_channel\n",
    "        self.nonlinear = nonlinear\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        if self.nonlinear:\n",
    "            x = x**3\n",
    "        x = torch.sum(x,2)\n",
    "        output = torch.stack([torch.sum(x[:,:self.out_channel],1), torch.sum(x[:,self.out_channel:],1)]).transpose(1,0)\n",
    "        return output\n",
    "    \n",
    "\n",
    "# top 1 hard routing\n",
    "def top1(t):\n",
    "    values, index = t.topk(k=1, dim=-1)\n",
    "    values, index = map(lambda x: x.squeeze(dim=-1), (values, index))\n",
    "    return values, index\n",
    "\n",
    "\n",
    "def cumsum_exclusive(t, dim=-1):\n",
    "    num_dims = len(t.shape)\n",
    "    num_pad_dims = - dim - 1\n",
    "    pre_padding = (0, 0) * num_pad_dims\n",
    "    pre_slice   = (slice(None),) * num_pad_dims\n",
    "    padded_t = F.pad(t, (*pre_padding, 1, 0)).cumsum(dim=dim)\n",
    "    return padded_t[(..., slice(None, -1), *pre_slice)]\n",
    "\n",
    "\n",
    "def safe_one_hot(indexes, max_length):\n",
    "    max_index = indexes.max() + 1\n",
    "    return F.one_hot(indexes, max(max_index + 1, max_length))[..., :max_length]\n",
    "\n",
    "\n",
    "class Router(nn.Module):\n",
    "    def __init__(self, input_dim, out_dim, patch_num, noise=True):\n",
    "        super(Router, self).__init__()\n",
    "        self.conv1 = nn.Conv1d(1, out_dim, int(input_dim/patch_num), int(input_dim/patch_num),bias=False)\n",
    "        self.out_dim = out_dim\n",
    "        self.break_tie_noise = torch.normal(0,1e-5,size=(DATA_NUM, EXPERT_NUM))\n",
    "        self.noise = noise\n",
    "        # zero initialization\n",
    "        self.reset_parameters()\n",
    "    \n",
    "    def reset_parameters(self):\n",
    "        self.conv1.weight = torch.nn.Parameter(self.conv1.weight * 0)\n",
    "            \n",
    "    def forward(self, x):      \n",
    "        x = self.conv1(x)\n",
    "        x = torch.sum(x,2)\n",
    "        if self.noise and self.training:\n",
    "            output = x + torch.rand(DATA_NUM, EXPERT_NUM) \n",
    "        elif self.training:\n",
    "            output = x + self.break_tie_noise\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fe78d104-ab76-4c0a-8e3d-ae8885bedc04",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MoE(nn.Module):\n",
    "    def __init__(self, input_dim, out_channel, patch_num, expert_num, strategy='top1', nonlinear=True):\n",
    "        super(MoE, self).__init__()\n",
    "        self.router = Router(input_dim, expert_num, patch_num)\n",
    "        self.models = nn.ModuleList()\n",
    "        for i in range(expert_num):\n",
    "            self.models.append(ConvNet(input_dim, out_channel, patch_num, nonlinear=nonlinear))\n",
    "        self.strategy = strategy\n",
    "        self.expert_num = expert_num\n",
    "\n",
    "    def forward(self, x):\n",
    "        select = self.router(x)\n",
    "        # top 1 or choose 1 according to probability\n",
    "        if self.strategy == 'top1':\n",
    "            gate, index = top1(select)\n",
    "        else:\n",
    "            gate, index = choose1(select)\n",
    "        \n",
    "        mask = F.one_hot(index, self.expert_num).float()\n",
    "\n",
    "        density = mask.mean(dim=-2)\n",
    "        density_proxy = select.mean(dim=-2)\n",
    "        loss = (density_proxy * density).mean() * float(self.expert_num ** 2)\n",
    "\n",
    "        mask_count = mask.sum(dim=-2, keepdim=True)\n",
    "        mask_flat = mask.sum(dim=-1)\n",
    "\n",
    "        combine_tensor = (gate[..., None, None] * mask_flat[..., None, None]\n",
    "                          * F.one_hot(index, self.expert_num)[..., None])\n",
    "                          \n",
    "        dispatch_tensor = combine_tensor.bool().to(combine_tensor)\n",
    "        select0 = dispatch_tensor.squeeze(-1)\n",
    "        \n",
    "        expert_inputs = torch.einsum('bnd,ben->ebd', x, dispatch_tensor).unsqueeze(2)\n",
    "        \n",
    "        output = []\n",
    "        for i in range(self.expert_num):\n",
    "            output.append(self.models[i](expert_inputs[i]))\n",
    "        \n",
    "        output = torch.stack(output)\n",
    "        output = torch.einsum('ijk,jil->il', combine_tensor, output)\n",
    "        output = F.softmax(output,dim=1)\n",
    "\n",
    "        return output, select0, loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "154938b3-19a0-4900-8c82-e6346eaa38f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, criterion, data, labels, verbose=True):\n",
    "    correct = 0\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        outputs,_,_ = model(data) # ,_\n",
    "        predicted = torch.max(outputs.data, 1).indices\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "    if verbose:\n",
    "        print('Accuracy of the network on the %d test images: %.4f %%' % (data.shape[0],\n",
    "            100 * correct / data.shape[0]))\n",
    "    \n",
    "    return 100 * correct / data.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "22e71fb5-1a1e-4fbd-9638-220f6cade3e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_NUM = 16000\n",
    "CLUSTER_NUM = 4\n",
    "EXPERT_NUM = 8\n",
    "PATCH_NUM = 4\n",
    "PATCH_LEN = 50\n",
    "epsilon = 0.01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ddab3ca3-cc24-45aa-83f6-53c70e6c71ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_23592\\1703678973.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  training_data = torch.load('synthetic_data_s1/train_data.pt')\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_23592\\1703678973.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  training_labels = torch.load('synthetic_data_s1/train_labels.pt')\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_23592\\1703678973.py:4: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  test_data = torch.load('synthetic_data_s1/test_data.pt')\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_23592\\1703678973.py:5: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  test_labels = torch.load('synthetic_data_s1/test_labels.pt')\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_23592\\1703678973.py:7: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  centers = torch.load('synthetic_data_s1/centers.pt')\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_23592\\1703678973.py:8: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  features = torch.load('synthetic_data_s1/features.pt')\n"
     ]
    }
   ],
   "source": [
    "training_data = torch.load('synthetic_data_s1/train_data.pt')\n",
    "training_labels = torch.load('synthetic_data_s1/train_labels.pt')\n",
    "\n",
    "test_data = torch.load('synthetic_data_s1/test_data.pt')\n",
    "test_labels = torch.load('synthetic_data_s1/test_labels.pt')\n",
    "\n",
    "centers = torch.load('synthetic_data_s1/centers.pt')\n",
    "features = torch.load('synthetic_data_s1/features.pt')\n",
    "\n",
    "with open(\"synthetic_data_s1/train_cluster\", \"rb\") as fp:  \n",
    "    train_cluster_idx = pickle.load(fp)\n",
    "    \n",
    "with open(\"synthetic_data_s1/test_cluster\", \"rb\") as fp:  \n",
    "    test_cluster_idx = pickle.load(fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8d85954b-d369-4add-bd2b-d48f8e2c9dcb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_23592\\220291039.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  lin_moe = torch.load('linear_moe.pth')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the network on the 16000 test images: 95.4625 %\n",
      "Accuracy of the network on the 16000 test images: 94.1312 %\n",
      "Accuracy of the network on the 16000 test images: 73.9938 %\n",
      "Accuracy of the network on the 16000 test images: 73.8250 %\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_23592\\220291039.py:18: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  lin_moe = torch.load('linear_moe.pth')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the network on the 16000 test images: 95.5625 %\n",
      "Accuracy of the network on the 16000 test images: 94.1312 %\n",
      "Expert:0\n",
      "Weight range: (tensor(0.0038), tensor(-0.0037))\n",
      "Bias range: (tensor(0.0008), tensor(-0.0008))\n",
      "Expert:1\n",
      "Weight range: (tensor(0.0121), tensor(-0.0121))\n",
      "Bias range: (tensor(0.0210), tensor(-0.0210))\n",
      "Expert:2\n",
      "Weight range: (tensor(0.0025), tensor(-0.0024))\n",
      "Bias range: (tensor(0.0005), tensor(-0.0005))\n",
      "Expert:3\n",
      "Weight range: (tensor(0.0395), tensor(-0.0413))\n",
      "Bias range: (tensor(0.0069), tensor(-0.0255))\n",
      "Expert:4\n",
      "Weight range: (tensor(0.0120), tensor(-0.0120))\n",
      "Bias range: (tensor(0.0065), tensor(-0.0064))\n",
      "Expert:5\n",
      "Weight range: (tensor(0.0120), tensor(-0.0120))\n",
      "Bias range: (tensor(0.0046), tensor(-0.0047))\n",
      "Expert:6\n",
      "Weight range: (tensor(0.0114), tensor(-0.0115))\n",
      "Bias range: (tensor(0.0002), tensor(-0.0002))\n",
      "Expert:7\n",
      "Weight range: (tensor(0.0026), tensor(-0.0026))\n",
      "Bias range: (tensor(0.0005), tensor(-0.0005))\n",
      " GLOBAL RESCALING\n",
      "GLOBAL RESCALING DONE\n",
      "Expert:0\n",
      "Weight range: (tensor(0.0375), tensor(-0.0368))\n",
      "Bias range: (tensor(0.0082), tensor(-0.0081))\n",
      "Expert:1\n",
      "Weight range: (tensor(0.1209), tensor(-0.1209))\n",
      "Bias range: (tensor(0.2096), tensor(-0.2101))\n",
      "Expert:2\n",
      "Weight range: (tensor(0.0247), tensor(-0.0244))\n",
      "Bias range: (tensor(0.0049), tensor(-0.0052))\n",
      "Expert:3\n",
      "Weight range: (tensor(0.3950), tensor(-0.4134))\n",
      "Bias range: (tensor(0.0687), tensor(-0.2545))\n",
      "Expert:4\n",
      "Weight range: (tensor(0.1202), tensor(-0.1201))\n",
      "Bias range: (tensor(0.0650), tensor(-0.0643))\n",
      "Expert:5\n",
      "Weight range: (tensor(0.1195), tensor(-0.1195))\n",
      "Bias range: (tensor(0.0456), tensor(-0.0468))\n",
      "Expert:6\n",
      "Weight range: (tensor(0.1143), tensor(-0.1149))\n",
      "Bias range: (tensor(0.0019), tensor(-0.0019))\n",
      "Expert:7\n",
      "Weight range: (tensor(0.0258), tensor(-0.0257))\n",
      "Bias range: (tensor(0.0050), tensor(-0.0053))\n",
      "Accuracy of the network on the 16000 test images: 95.5062 %\n",
      "Accuracy of the network on the 16000 test images: 94.0250 %\n",
      "Zeroing out elements smaller than epsilon\n",
      "Accuracy of the network on the 16000 test images: 95.3563 %\n",
      "Accuracy of the network on the 16000 test images: 94.0500 %\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_23592\\220291039.py:54: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  lin_moe = torch.load('linear_moe.pth')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the network on the 16000 test images: 95.4250 %\n",
      "Accuracy of the network on the 16000 test images: 94.1813 %\n",
      "Expert:0\n",
      "Weight range: (tensor(0.0038), tensor(-0.0037))\n",
      "Bias range: (tensor(0.0008), tensor(-0.0008))\n",
      "Expert:1\n",
      "Weight range: (tensor(0.0121), tensor(-0.0121))\n",
      "Bias range: (tensor(0.0210), tensor(-0.0210))\n",
      "Expert:2\n",
      "Weight range: (tensor(0.0025), tensor(-0.0024))\n",
      "Bias range: (tensor(0.0005), tensor(-0.0005))\n",
      "Expert:3\n",
      "Weight range: (tensor(0.0395), tensor(-0.0413))\n",
      "Bias range: (tensor(0.0069), tensor(-0.0255))\n",
      "Expert:4\n",
      "Weight range: (tensor(0.0120), tensor(-0.0120))\n",
      "Bias range: (tensor(0.0065), tensor(-0.0064))\n",
      "Expert:5\n",
      "Weight range: (tensor(0.0120), tensor(-0.0120))\n",
      "Bias range: (tensor(0.0046), tensor(-0.0047))\n",
      "Expert:6\n",
      "Weight range: (tensor(0.0114), tensor(-0.0115))\n",
      "Bias range: (tensor(0.0002), tensor(-0.0002))\n",
      "Expert:7\n",
      "Weight range: (tensor(0.0026), tensor(-0.0026))\n",
      "Bias range: (tensor(0.0005), tensor(-0.0005))\n",
      " LOCAL RESCALING\n",
      "LOCAL RESCALING DONE\n",
      "Expert:0\n",
      "Weight range: (tensor(0.9000), tensor(-0.8814))\n",
      "Bias range: (tensor(0.9000), tensor(-0.8968))\n",
      "Expert:1\n",
      "Weight range: (tensor(0.8994), tensor(-0.9000))\n",
      "Bias range: (tensor(0.8979), tensor(-0.9000))\n",
      "Expert:2\n",
      "Weight range: (tensor(0.9000), tensor(-0.8891))\n",
      "Bias range: (tensor(0.8481), tensor(-0.9000))\n",
      "Expert:3\n",
      "Weight range: (tensor(0.8598), tensor(-0.9000))\n",
      "Bias range: (tensor(0.2428), tensor(-0.9000))\n",
      "Expert:4\n",
      "Weight range: (tensor(0.9000), tensor(-0.8993))\n",
      "Bias range: (tensor(0.9000), tensor(-0.8901))\n",
      "Expert:5\n",
      "Weight range: (tensor(0.8999), tensor(-0.9000))\n",
      "Bias range: (tensor(0.8772), tensor(-0.9000))\n",
      "Expert:6\n",
      "Weight range: (tensor(0.8958), tensor(-0.9000))\n",
      "Bias range: (tensor(0.8778), tensor(-0.9000))\n",
      "Expert:7\n",
      "Weight range: (tensor(0.9000), tensor(-0.8972))\n",
      "Bias range: (tensor(0.8473), tensor(-0.9000))\n",
      "Accuracy of the network on the 16000 test images: 94.2062 %\n",
      "Accuracy of the network on the 16000 test images: 93.5500 %\n",
      "Zeroing out elements smaller than epsilon\n",
      "Accuracy of the network on the 16000 test images: 94.2938 %\n",
      "Accuracy of the network on the 16000 test images: 93.6312 %\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "93.63125"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "################ SCENARIO-I ZERO-OUT SMALL MAGNITUDE PRAMETERS #################\n",
    "lin_moe = torch.load('linear_moe.pth')\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "test(lin_moe, criterion, training_data, training_labels)\n",
    "test(lin_moe, criterion, test_data, test_labels)\n",
    "\n",
    "for i in range(8):\n",
    "    wt_mask = (torch.abs(lin_moe.models[i].conv1.weight.data.clone()) >=epsilon)*1\n",
    "    lin_moe.models[i].conv1.weight.data *= wt_mask\n",
    "    b_mask = (torch.abs(lin_moe.models[i].conv1.bias.data.clone()) >=epsilon)*1\n",
    "    lin_moe.models[i].conv1.bias.data *= b_mask\n",
    "    \n",
    "test(lin_moe, criterion, training_data, training_labels)\n",
    "test(lin_moe, criterion, test_data, test_labels)\n",
    "\n",
    "############### SCENARIO-II ADJUST THE SCALE OF PARAMETERS GLOBAL #####################\n",
    "lin_moe = torch.load('linear_moe.pth')\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "test(lin_moe, criterion, training_data, training_labels)\n",
    "test(lin_moe, criterion, test_data, test_labels)\n",
    "\n",
    "for i in range(8):\n",
    "    print(\"Expert:\"+str(i))\n",
    "    print(\"Weight range: (\"+ str(torch.max(lin_moe.models[i].conv1.weight.data)) + \", \"+ str(torch.min(lin_moe.models[i].conv1.weight.data)) +\")\")\n",
    "    print(\"Bias range: (\"+ str(torch.max(lin_moe.models[i].conv1.bias.data)) + \", \"+ str(torch.min(lin_moe.models[i].conv1.bias.data)) +\")\")\n",
    "\n",
    "scale = 10\n",
    "print(\" GLOBAL RESCALING\")\n",
    "for i in range(8):\n",
    "    lin_moe.models[i].conv1.weight.data *= scale\n",
    "    lin_moe.models[i].conv1.bias.data *= scale\n",
    "print(\"GLOBAL RESCALING DONE\")\n",
    "\n",
    "for i in range(8):\n",
    "    print(\"Expert:\"+str(i))\n",
    "    print(\"Weight range: (\"+ str(torch.max(lin_moe.models[i].conv1.weight.data)) + \", \"+ str(torch.min(lin_moe.models[i].conv1.weight.data)) +\")\")\n",
    "    print(\"Bias range: (\"+ str(torch.max(lin_moe.models[i].conv1.bias.data)) + \", \"+ str(torch.min(lin_moe.models[i].conv1.bias.data)) +\")\")\n",
    "    \n",
    "test(lin_moe, criterion, training_data, training_labels)\n",
    "test(lin_moe, criterion, test_data, test_labels)\n",
    "print(\"Zeroing out elements smaller than epsilon\")\n",
    "for i in range(8):\n",
    "    wt_mask = (torch.abs(lin_moe.models[i].conv1.weight.data.clone()) >=epsilon)*1\n",
    "    lin_moe.models[i].conv1.weight.data *= wt_mask\n",
    "    b_mask = (torch.abs(lin_moe.models[i].conv1.bias.data.clone()) >=epsilon)*1\n",
    "    lin_moe.models[i].conv1.bias.data *= b_mask\n",
    "    \n",
    "test(lin_moe, criterion, training_data, training_labels)\n",
    "test(lin_moe, criterion, test_data, test_labels)\n",
    "\n",
    "############### SCENARIO-III ADJUST THE SCALE OF PARAMETERS LOCAL #####################\n",
    "lin_moe = torch.load('linear_moe.pth')\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "test(lin_moe, criterion, training_data, training_labels)\n",
    "test(lin_moe, criterion, test_data, test_labels)\n",
    "\n",
    "for i in range(8):\n",
    "    print(\"Expert:\"+str(i))\n",
    "    print(\"Weight range: (\"+ str(torch.max(lin_moe.models[i].conv1.weight.data)) + \", \"+ str(torch.min(lin_moe.models[i].conv1.weight.data)) +\")\")\n",
    "    print(\"Bias range: (\"+ str(torch.max(lin_moe.models[i].conv1.bias.data)) + \", \"+ str(torch.min(lin_moe.models[i].conv1.bias.data)) +\")\")\n",
    "\n",
    "print(\" LOCAL RESCALING\")\n",
    "for i in range(8):\n",
    "    max_wt = torch.max(torch.abs(lin_moe.models[i].conv1.weight.data))\n",
    "    scale = 0.9/max_wt\n",
    "    lin_moe.models[i].conv1.weight.data *= scale\n",
    "    max_b = torch.max(torch.abs(lin_moe.models[i].conv1.bias.data))\n",
    "    scale = 0.9/max_b\n",
    "    lin_moe.models[i].conv1.bias.data *= scale\n",
    "print(\"LOCAL RESCALING DONE\")\n",
    "\n",
    "for i in range(8):\n",
    "    print(\"Expert:\"+str(i))\n",
    "    print(\"Weight range: (\"+ str(torch.max(lin_moe.models[i].conv1.weight.data)) + \", \"+ str(torch.min(lin_moe.models[i].conv1.weight.data)) +\")\")\n",
    "    print(\"Bias range: (\"+ str(torch.max(lin_moe.models[i].conv1.bias.data)) + \", \"+ str(torch.min(lin_moe.models[i].conv1.bias.data)) +\")\")\n",
    "    \n",
    "test(lin_moe, criterion, training_data, training_labels)\n",
    "test(lin_moe, criterion, test_data, test_labels)\n",
    "print(\"Zeroing out elements smaller than epsilon\")\n",
    "for i in range(8):\n",
    "    wt_mask = (torch.abs(lin_moe.models[i].conv1.weight.data.clone()) >=epsilon)*1\n",
    "    lin_moe.models[i].conv1.weight.data *= wt_mask\n",
    "    b_mask = (torch.abs(lin_moe.models[i].conv1.bias.data.clone()) >=epsilon)*1\n",
    "    lin_moe.models[i].conv1.bias.data *= b_mask\n",
    "    \n",
    "test(lin_moe, criterion, training_data, training_labels)\n",
    "test(lin_moe, criterion, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "de859ecd-d3ae-4fd1-8959-b814f11da8b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([800, 8])\n",
      "Max: tensor(0.3950) Min: tensor(-0.4134)\n",
      "tensor(4163)\n",
      "torch.Size([16, 8])\n",
      "Max: tensor(0.2096) Min: tensor(-0.2545)\n",
      "tensor(64)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_23592\\2268568622.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  lin_moe = torch.load('linear_moe.pth')\n"
     ]
    }
   ],
   "source": [
    "lin_moe = torch.load('linear_moe.pth')\n",
    "\n",
    "scale = 10\n",
    "for i in range(8):\n",
    "    lin_moe.models[i].conv1.weight.data *= scale\n",
    "    lin_moe.models[i].conv1.bias.data *= scale\n",
    "\n",
    "target_wt = []\n",
    "for i in range(8):\n",
    "    wt_flatten = lin_moe.models[i].conv1.weight.data.view(-1)\n",
    "    target_wt.append(wt_flatten)\n",
    "target_wt = torch.stack(target_wt, dim=0)\n",
    "target_wt = target_wt.T\n",
    "print(target_wt.shape)\n",
    "print(\"Max: \"+str(torch.max(target_wt))+\" Min: \"+str(torch.min(target_wt)))\n",
    "print(torch.sum(np.abs(target_wt)>=epsilon))\n",
    "\n",
    "target_b = []\n",
    "for i in range(8):\n",
    "    b_flatten = lin_moe.models[i].conv1.bias.data.view(-1)\n",
    "    target_b.append(b_flatten)\n",
    "target_b = torch.stack(target_b, dim=0)\n",
    "target_b = target_b.T\n",
    "print(target_b.shape)\n",
    "print(\"Max: \"+str(torch.max(target_b))+\" Min: \"+str(torch.min(target_b)))\n",
    "print(torch.sum(np.abs(target_b)>=epsilon))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "49d060fa-049d-4eb9-bdb2-3da7bb3b1d70",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import combinations\n",
    "import copy\n",
    "from functools import reduce\n",
    "import numpy as np\n",
    "\n",
    "def countSetBits(n):\n",
    "    count = 0\n",
    "    while (n):\n",
    "        n &= int(n-1) \n",
    "        count+= 1\n",
    "    return count\n",
    "\n",
    "def subsets(arr,status,curr = 0):\n",
    "    global s\n",
    "    if(curr>=len(arr)):\n",
    "        s.append(np.sum(arr*status))\n",
    "        return\n",
    "    subsets(arr,status,curr+1)\n",
    "    status[curr] = 1\n",
    "    subsets(arr,status,curr+1)\n",
    "    status[curr] = 0\n",
    "\n",
    "def get_binary(num, digits=15):\n",
    "    bin = [0]*digits\n",
    "    start = digits-1\n",
    "    while(num>0):\n",
    "        bin[start] = num%2\n",
    "        num = num//2\n",
    "        start-=1\n",
    "    return bin\n",
    "\n",
    "def find_stats(super_set, best_len, best_ss_ind):\n",
    "    #if(best_len == -1):\n",
    "    #    return None, None, None\n",
    "    overlap = reduce(lambda x, y: int(x) & int(y), best_ss_ind)\n",
    "    overlap_len = countSetBits(overlap)\n",
    "    extra_len = 0\n",
    "    for i in range(len(best_ss_ind)):\n",
    "        extra_len+=countSetBits(int(overlap) ^ int(best_ss_ind[i]))\n",
    "    final = reduce(lambda x, y: int(x) | int(y), best_ss_ind)\n",
    "    final_len = countSetBits(final)\n",
    "    ind_lens = [countSetBits(int(best_ss_ind[i])) for i in range(len(best_ss_ind))]\n",
    "    return best_len, overlap_len, extra_len, final_len, ind_lens\n",
    "\n",
    "def print_stats(super_set, best_len, best_ss_ind):\n",
    "    if(best_len==-1):\n",
    "        return\n",
    "    print(f'{\"Best overall subset is: \"+str(super_set) : <45}{str(get_binary(super_set)) : >25}')\n",
    "    for i in range(len(best_ss_ind)):\n",
    "        print(f'{\"Subset \"+str(i+1)+\": \"+str(best_ss_ind[i])+\" Length: \"+str(countSetBits(best_ss_ind[i])) : <45}{str(get_binary(best_ss_ind[i])) : >25}')\n",
    "    overlap = reduce(lambda x, y: x & y, best_ss_ind)\n",
    "    print(f'{\"Overlap Subset: \"+str(overlap) : <45}{str(get_binary(overlap)) : >25}')\n",
    "    print(f'{\"Overall subset length: \" : <25}{str(countSetBits(super_set)) : >10}')\n",
    "    print(f'{\"Overlap length: \" : <25}{str(countSetBits(overlap)) : >10}')\n",
    "    extra_len = 0\n",
    "    for i in range(len(best_ss_ind)):\n",
    "        extra_len+=countSetBits(overlap ^ best_ss_ind[i])\n",
    "    print(f'{\"Extra length: \" : <25}{str(extra_len) : >10}')\n",
    "    print(f'{\"BEST LENGTH IS: \" : <25}{str(best_len) : >10}')\n",
    "    print(\"-\"*100)\n",
    "    \n",
    "\n",
    "def subset_fixed_size(target, numbers, eps, subsize, errBest):\n",
    "    n = len(numbers)\n",
    "    cand = 0\n",
    "    indBest = np.array([np.NAN])\n",
    "    for ind in combinations(range(n),subsize):\n",
    "        inda = np.array(ind,dtype=\"int\")\n",
    "        napprox = np.sum(numbers[inda])\n",
    "        diff = np.abs(target-napprox)\n",
    "        if diff < errBest:\n",
    "            errBest = diff\n",
    "            cand = napprox\n",
    "            indBest = inda\n",
    "        if diff <= eps:\n",
    "            break\n",
    "    return cand, indBest, errBest\n",
    "\n",
    "def exhaustive(target, numbers, eps, nmax):\n",
    "    n = len(numbers)\n",
    "    err = np.abs(target)\n",
    "    errBest = err\n",
    "    cand = 0\n",
    "    indBest = np.array([-1])\n",
    "    nmax = min(nmax, n)\n",
    "    for k in range(nmax):\n",
    "        cank, indk, errk = subset_fixed_size(target, numbers, eps, k, errBest)\n",
    "        if errk < errBest:\n",
    "            errBest = errk\n",
    "            cand = cank\n",
    "            indBest = indk\n",
    "        if errBest <= eps:\n",
    "            break\n",
    "    return cand, indBest\n",
    "\n",
    "def find_best_subset_size(status, targets, experts, epsilon):\n",
    "    final_set = status[0].reshape(-1,1)\n",
    "    for i in range(1,experts):\n",
    "        final_set = np.bitwise_or(final_set, status[i].reshape(1,-1))\n",
    "        #print(final_set)\n",
    "        final_set = np.unique(final_set.reshape(-1)).reshape(-1,1)\n",
    "    # print(final_set)\n",
    "    final_set = final_set.reshape(-1)\n",
    "    best = 100000\n",
    "    best_id = -1\n",
    "    for i in range(len(final_set)):\n",
    "        b = countSetBits(final_set[i])\n",
    "        if(b<best):\n",
    "            best = b\n",
    "            best_id = final_set[i]\n",
    "    #print(\"Best combo is: \"+str(best_id))\n",
    "    #print(\"Binary string: \"+str(get_binary(best_id)))\n",
    "    #print()\n",
    "    #print(\"Best is: \"+str(best))\n",
    "    if(not (best <= 100)):\n",
    "        print(\"Weird1: \"+str(best))\n",
    "        best = 0\n",
    "    # count[best]+=1\n",
    "    if(best==0):\n",
    "        return None, -1, None, None    \n",
    "    tot_len = 0\n",
    "    for i in range(experts):\n",
    "        status[i] = np.reshape(status[i], (-1))\n",
    "    candidates = []\n",
    "    for i in range(experts):\n",
    "        cand_id = np.argwhere((np.bitwise_and(status[i],best_id)-status[i])==0)\n",
    "        #print(cand_id)\n",
    "        combos = [status[i][id] for id in cand_id]\n",
    "        candidates.append(combos)\n",
    "        # for combo in combos:\n",
    "        #     # print(\"Combo:\"+str(combo[0])+\" Binary:\"+str(get_binary(combo[0])))\n",
    "        #     assert((np.sum(get_binary(combo[0])*rand_vars) - targets[i])<=epsilon)\n",
    "        cand_len = min([countSetBits(status[i][id]) for id in cand_id])\n",
    "        tot_len+=cand_len\n",
    "    candidates = [i[0][0] for i in candidates]\n",
    "    return best_id, best, candidates, tot_len\n",
    "\n",
    "def find_best_overall_size(status, targets, experts, epsilon):\n",
    "    #final_set = status[0].reshape(-1,1)\n",
    "    overlaps = status[0].reshape(-1,1)\n",
    "    for i in range(1,experts):\n",
    "        #final_set = np.bitwise_or(final_set, status[i].reshape(1,-1)).reshape(-1,1)\n",
    "        overlaps = np.bitwise_and(overlaps, status[i].reshape(1,-1)).reshape(-1,1)\n",
    "        #combined = np.unique(np.concatenate((final_set,overlaps), axis=1), axis=0)\n",
    "        #print(combined.shape)\n",
    "        #final_set = combined[:,0].reshape(-1,1)\n",
    "        #overlaps = combined[:,1].reshape(-1,1)\n",
    "        #print(i)\n",
    "        overlaps = np.unique(overlaps.reshape(-1)).reshape(-1,1)\n",
    "    overlaps = overlaps.reshape(-1,1)\n",
    "    # print(overlaps.shape)\n",
    "    cand = []\n",
    "    for i in range(experts):\n",
    "        o_s_map = (np.bitwise_and(overlaps, status[i].reshape(1,-1))==overlaps)\n",
    "        # print(o_s_map.sum(axis=1))\n",
    "        cand.append([status[i][np.where(o_s_map[j]==True)] for j in range(len(overlaps))])\n",
    "    best_len = 1000000\n",
    "    best_ss = -1\n",
    "    best_ss_ind = None\n",
    "    best_ss_id = -1\n",
    "    for ov_id in range(len(overlaps)):\n",
    "        ov_len = countSetBits(overlaps[ov_id][0])\n",
    "        extra_len = 0\n",
    "        min_vals = []\n",
    "        for i in range(experts):\n",
    "            l = []\n",
    "            for j in range(len(cand[i][ov_id])):\n",
    "                l.append(countSetBits(cand[i][ov_id][j] ^ overlaps[ov_id][0]))\n",
    "                # print(\"Candidate: \"+str(cand[i][ov_id][j])+\" Binary: \"+str(get_binary(cand[i][ov_id][j])))\n",
    "            cand_min_id = np.argmin(l)\n",
    "            #print(l)\n",
    "            extra_len+=l[cand_min_id]\n",
    "            min_vals.append(cand[i][ov_id][cand_min_id])\n",
    "        if(ov_len+extra_len<best_len):\n",
    "            best_ss = overlaps[ov_id][0]\n",
    "            best_ss_id = ov_id\n",
    "            best_ss_ind = copy.deepcopy(min_vals)\n",
    "            # print(\"Overlap: \"+str(overlaps[ov_id][0])+\" Binary: \"+str(get_binary(overlaps[ov_id][0])))\n",
    "            # print(\"Overlap length: \"+str(ov_len))\n",
    "            # print(\"Extra length: \"+str(extra_len))\n",
    "            # for i in range(experts):\n",
    "            #     print(\"Candidate: \"+str(min_vals[i])+\" Binary: \"+str(get_binary(min_vals[i])))\n",
    "            best_len = ov_len+extra_len\n",
    "            #print()\n",
    "    if(not (best_len <= 100)):\n",
    "        print(\"Weird2: \"+str(best_len))\n",
    "        best_len = -1\n",
    "    if(best_len==-1):\n",
    "        return None, -1, None, None\n",
    "    for i in range(experts):\n",
    "        assert(best_ss_ind[i] in status[i])\n",
    "    assert(overlaps[best_ss_id][0] == reduce(lambda x, y: int(x) & int(y), best_ss_ind))\n",
    "    super_set = reduce(lambda x, y: int(x) | int(y), best_ss_ind)\n",
    "    return super_set, best_len, best_ss_ind, best_ss\n",
    "\n",
    "def find_best_overlap_size(status, targets, experts, epsilon):\n",
    "    #final_set = status[0].reshape(-1,1)\n",
    "    overlaps = status[0].reshape(-1,1)\n",
    "    for i in range(1,experts):\n",
    "        #final_set = np.bitwise_or(final_set, status[i].reshape(1,-1)).reshape(-1,1)\n",
    "        overlaps = np.bitwise_and(overlaps, status[i].reshape(1,-1)).reshape(-1,1)\n",
    "        #combined = np.unique(np.concatenate((final_set,overlaps), axis=1), axis=0)\n",
    "        #print(combined.shape)\n",
    "        #final_set = combined[:,0].reshape(-1,1)\n",
    "        #overlaps = combined[:,1].reshape(-1,1)\n",
    "        #print(i)\n",
    "        overlaps = np.unique(overlaps.reshape(-1)).reshape(-1,1)\n",
    "    overlaps = overlaps.reshape(-1,1)\n",
    "    # print(overlaps.shape)\n",
    "    cand = []\n",
    "    for i in range(experts):\n",
    "        o_s_map = (np.bitwise_and(overlaps, status[i].reshape(1,-1))==overlaps)\n",
    "        # print(o_s_map.sum(axis=1))\n",
    "        cand.append([status[i][np.where(o_s_map[j]==True)] for j in range(len(overlaps))])\n",
    "    best_len = 1000000\n",
    "    best_ss = -1\n",
    "    best_ss_ind = None\n",
    "    best_ss_id = -1\n",
    "    for ov_id in range(len(overlaps)):\n",
    "        ov_len = countSetBits(overlaps[ov_id][0])\n",
    "        extra_len = 0\n",
    "        min_vals = []\n",
    "        for i in range(experts):\n",
    "            l = []\n",
    "            for j in range(len(cand[i][ov_id])):\n",
    "                l.append(countSetBits(cand[i][ov_id][j] ^ overlaps[ov_id][0]))\n",
    "                # print(\"Candidate: \"+str(cand[i][ov_id][j])+\" Binary: \"+str(get_binary(cand[i][ov_id][j])))\n",
    "            cand_min_id = np.argmin(l)\n",
    "            #print(l)\n",
    "            extra_len+=l[cand_min_id]\n",
    "            min_vals.append(cand[i][ov_id][cand_min_id])\n",
    "        if(extra_len-ov_len<best_len):\n",
    "            best_ss = overlaps[ov_id][0]\n",
    "            best_ss_id = ov_id\n",
    "            best_ss_ind = copy.deepcopy(min_vals)\n",
    "            # print(\"Overlap: \"+str(overlaps[ov_id][0])+\" Binary: \"+str(get_binary(overlaps[ov_id][0])))\n",
    "            # print(\"Overlap length: \"+str(ov_len))\n",
    "            # print(\"Extra length: \"+str(extra_len))\n",
    "            # for i in range(experts):\n",
    "            #     print(\"Candidate: \"+str(min_vals[i])+\" Binary: \"+str(get_binary(min_vals[i])))\n",
    "            best_len = extra_len-ov_len\n",
    "            #print()\n",
    "    # if(not (best_len <= 100)):\n",
    "    #     print(\"Weird3: \"+str(best_len))\n",
    "    #     best_len = -1\n",
    "    # if(best_len==-1):\n",
    "    #     return None, -1, None, None\n",
    "    for i in range(experts):\n",
    "        assert(best_ss_ind[i] in status[i])\n",
    "    assert(overlaps[best_ss_id][0] == reduce(lambda x, y: int(x) & int(y), best_ss_ind))\n",
    "    super_set = reduce(lambda x, y: int(x) | int(y), best_ss_ind)\n",
    "    return super_set, best_len, best_ss_ind, best_ss\n",
    "\n",
    "def find_best_extra_size(status, targets, experts, epsilon):\n",
    "    #final_set = status[0].reshape(-1,1)\n",
    "    overlaps = status[0].reshape(-1,1)\n",
    "    for i in range(1,experts):\n",
    "        #final_set = np.bitwise_or(final_set, status[i].reshape(1,-1)).reshape(-1,1)\n",
    "        overlaps = np.bitwise_and(overlaps, status[i].reshape(1,-1)).reshape(-1,1)\n",
    "        #combined = np.unique(np.concatenate((final_set,overlaps), axis=1), axis=0)\n",
    "        #print(combined.shape)\n",
    "        #final_set = combined[:,0].reshape(-1,1)\n",
    "        #overlaps = combined[:,1].reshape(-1,1)\n",
    "        #print(i)\n",
    "        overlaps = np.unique(overlaps.reshape(-1)).reshape(-1,1)\n",
    "    overlaps = overlaps.reshape(-1,1)\n",
    "    # print(overlaps.shape)\n",
    "    cand = []\n",
    "    for i in range(experts):\n",
    "        o_s_map = (np.bitwise_and(overlaps, status[i].reshape(1,-1))==overlaps)\n",
    "        # print(o_s_map.sum(axis=1))\n",
    "        cand.append([status[i][np.where(o_s_map[j]==True)] for j in range(len(overlaps))])\n",
    "    best_len = 1000000\n",
    "    best_ss = -1\n",
    "    best_ss_ind = None\n",
    "    best_ss_id = -1\n",
    "    for ov_id in range(len(overlaps)):\n",
    "        ov_len = countSetBits(overlaps[ov_id][0])\n",
    "        extra_len = 0\n",
    "        min_vals = []\n",
    "        for i in range(experts):\n",
    "            l = []\n",
    "            for j in range(len(cand[i][ov_id])):\n",
    "                l.append(countSetBits(cand[i][ov_id][j] ^ overlaps[ov_id][0]))\n",
    "                # print(\"Candidate: \"+str(cand[i][ov_id][j])+\" Binary: \"+str(get_binary(cand[i][ov_id][j])))\n",
    "            cand_min_id = np.argmin(l)\n",
    "            #print(l)\n",
    "            extra_len+=l[cand_min_id]\n",
    "            min_vals.append(cand[i][ov_id][cand_min_id])\n",
    "        if(extra_len<best_len):\n",
    "            best_ss = overlaps[ov_id][0]\n",
    "            best_ss_id = ov_id\n",
    "            best_ss_ind = copy.deepcopy(min_vals)\n",
    "            # print(\"Overlap: \"+str(overlaps[ov_id][0])+\" Binary: \"+str(get_binary(overlaps[ov_id][0])))\n",
    "            # print(\"Overlap length: \"+str(ov_len))\n",
    "            # print(\"Extra length: \"+str(extra_len))\n",
    "            # for i in range(experts):\n",
    "            #     print(\"Candidate: \"+str(min_vals[i])+\" Binary: \"+str(get_binary(min_vals[i])))\n",
    "            best_len = extra_len\n",
    "            #print()\n",
    "    if(not (best_len <= 100)):\n",
    "        print(\"Weird: \"+str(best_len))\n",
    "        best_len = -1\n",
    "    if(best_len==-1):\n",
    "        return None, -1, None, None\n",
    "    ############## Verification #################\n",
    "    for i in range(experts):\n",
    "        assert(best_ss_ind[i] in status[i])\n",
    "    assert(overlaps[best_ss_id][0] == reduce(lambda x, y: int(x) & int(y), best_ss_ind))\n",
    "    super_set = reduce(lambda x, y: x | y, best_ss_ind)\n",
    "    return super_set, best_len, best_ss_ind, best_ss\n",
    "\n",
    "def find_sparsest(status, targets, experts, epsilon):\n",
    "    subsets = np.zeros(experts)\n",
    "    for i in range(experts):\n",
    "        best = 100000000\n",
    "        for j in range(len(status[i])):\n",
    "            l = countSetBits(status[i][j])\n",
    "            if(l<best):\n",
    "                best = l\n",
    "                subsets[i] = status[i][j]\n",
    "    overlap = int(subsets[0])\n",
    "    final = int(subsets[0])\n",
    "    for i in range(1,experts):\n",
    "        overlap = overlap & int(subsets[i])\n",
    "        final = final | int(subsets[i])\n",
    "    extra_len = 0\n",
    "    for i in range(experts):\n",
    "        extra_len += countSetBits(overlap ^ int(subsets[i]))\n",
    "    overlap_len = countSetBits(overlap)\n",
    "    best_len = overlap_len + extra_len\n",
    "    return final, best_len, list(subsets), overlap\n",
    "\n",
    "def find_representational_bins(status, targets, experts, epsilon):\n",
    "    targets = (np.floor(targets/(2*epsilon))*2 + 1)*epsilon\n",
    "    return find_sparsest(status, targets, experts, epsilon)\n",
    "\n",
    "def verify(c, rand_vars, targets, epsilon):\n",
    "    for i in range(len(c)):\n",
    "        assert((np.sum(get_binary(c[i], len(rand_vars))*rand_vars) - targets[i])<=epsilon)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f15133ed-41f7-4eb9-8684-a7b9339d105b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_23592\\2964554752.py:9: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n",
      "  n &= int(n-1)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n",
      "16\n",
      "17\n",
      "18\n",
      "19\n",
      "20\n",
      "21\n",
      "22\n",
      "23\n",
      "24\n",
      "25\n",
      "26\n",
      "27\n",
      "28\n",
      "29\n",
      "30\n",
      "31\n",
      "32\n",
      "33\n",
      "34\n",
      "35\n",
      "36\n",
      "37\n",
      "38\n",
      "39\n",
      "40\n",
      "41\n",
      "42\n",
      "43\n",
      "44\n",
      "45\n",
      "46\n",
      "47\n",
      "48\n",
      "49\n",
      "50\n",
      "51\n",
      "52\n",
      "53\n",
      "54\n",
      "55\n",
      "56\n",
      "57\n",
      "58\n",
      "59\n",
      "60\n",
      "61\n",
      "62\n",
      "63\n",
      "64\n",
      "65\n",
      "66\n",
      "67\n",
      "68\n",
      "69\n",
      "70\n",
      "71\n",
      "72\n",
      "73\n",
      "74\n",
      "75\n",
      "76\n",
      "77\n",
      "78\n",
      "79\n",
      "80\n",
      "81\n",
      "82\n",
      "83\n",
      "84\n",
      "85\n",
      "86\n",
      "87\n",
      "88\n",
      "89\n",
      "90\n",
      "91\n",
      "92\n",
      "93\n",
      "94\n",
      "95\n",
      "96\n",
      "97\n",
      "98\n",
      "99\n",
      "100\n",
      "101\n",
      "102\n",
      "103\n",
      "104\n",
      "105\n",
      "106\n",
      "107\n",
      "108\n",
      "109\n",
      "110\n",
      "111\n",
      "112\n",
      "113\n",
      "114\n",
      "115\n",
      "116\n",
      "117\n",
      "118\n",
      "119\n",
      "120\n",
      "121\n",
      "122\n",
      "123\n",
      "124\n",
      "125\n",
      "126\n",
      "127\n",
      "128\n",
      "129\n",
      "130\n",
      "131\n",
      "132\n",
      "133\n",
      "134\n",
      "135\n",
      "136\n",
      "137\n",
      "138\n",
      "139\n",
      "140\n",
      "141\n",
      "142\n",
      "143\n",
      "144\n",
      "145\n",
      "146\n",
      "147\n",
      "148\n",
      "149\n",
      "150\n",
      "151\n",
      "152\n",
      "153\n",
      "154\n",
      "155\n",
      "156\n",
      "157\n",
      "158\n",
      "159\n",
      "160\n",
      "161\n",
      "162\n",
      "163\n",
      "164\n",
      "165\n",
      "166\n",
      "167\n",
      "168\n",
      "169\n",
      "170\n",
      "171\n",
      "172\n",
      "173\n",
      "174\n",
      "175\n",
      "176\n",
      "177\n",
      "178\n",
      "179\n",
      "180\n",
      "181\n",
      "182\n",
      "183\n",
      "184\n",
      "185\n",
      "186\n",
      "187\n",
      "188\n",
      "189\n",
      "190\n",
      "191\n",
      "192\n",
      "193\n",
      "194\n",
      "195\n",
      "196\n",
      "197\n",
      "198\n",
      "199\n",
      "200\n",
      "201\n",
      "202\n",
      "203\n",
      "204\n",
      "205\n",
      "206\n",
      "207\n",
      "208\n",
      "209\n",
      "210\n",
      "211\n",
      "212\n",
      "213\n",
      "214\n",
      "215\n",
      "216\n",
      "217\n",
      "218\n",
      "219\n",
      "220\n",
      "221\n",
      "222\n",
      "223\n",
      "224\n",
      "225\n",
      "226\n",
      "227\n",
      "228\n",
      "229\n",
      "230\n",
      "231\n",
      "232\n",
      "233\n",
      "234\n",
      "235\n",
      "236\n",
      "237\n",
      "238\n",
      "239\n",
      "240\n",
      "241\n",
      "242\n",
      "243\n",
      "244\n",
      "245\n",
      "246\n",
      "247\n",
      "248\n",
      "249\n",
      "250\n",
      "251\n",
      "252\n",
      "253\n",
      "254\n",
      "255\n",
      "256\n",
      "257\n",
      "258\n",
      "259\n",
      "260\n",
      "261\n",
      "262\n",
      "263\n",
      "264\n",
      "265\n",
      "266\n",
      "267\n",
      "268\n",
      "269\n",
      "270\n",
      "271\n",
      "272\n",
      "273\n",
      "274\n",
      "275\n",
      "276\n",
      "277\n",
      "278\n",
      "279\n",
      "280\n",
      "281\n",
      "282\n",
      "283\n",
      "284\n",
      "285\n",
      "286\n",
      "287\n",
      "288\n",
      "289\n",
      "290\n",
      "291\n",
      "292\n",
      "293\n",
      "294\n",
      "295\n",
      "296\n",
      "297\n",
      "298\n",
      "299\n",
      "300\n",
      "301\n",
      "302\n",
      "303\n",
      "304\n",
      "305\n",
      "306\n",
      "307\n",
      "308\n",
      "309\n",
      "310\n",
      "311\n",
      "312\n",
      "313\n",
      "314\n",
      "315\n",
      "316\n",
      "317\n",
      "318\n",
      "319\n",
      "320\n",
      "321\n",
      "322\n",
      "323\n",
      "324\n",
      "325\n",
      "326\n",
      "327\n",
      "328\n",
      "329\n",
      "330\n",
      "331\n",
      "332\n",
      "333\n",
      "334\n",
      "335\n",
      "336\n",
      "337\n",
      "338\n",
      "339\n",
      "340\n",
      "341\n",
      "342\n",
      "343\n",
      "344\n",
      "345\n",
      "346\n",
      "347\n",
      "348\n",
      "349\n",
      "350\n",
      "351\n",
      "352\n",
      "353\n",
      "354\n",
      "355\n",
      "356\n",
      "357\n",
      "358\n",
      "359\n",
      "360\n",
      "361\n",
      "362\n",
      "363\n",
      "364\n",
      "365\n",
      "366\n",
      "367\n",
      "368\n",
      "369\n",
      "370\n",
      "371\n",
      "372\n",
      "373\n",
      "374\n",
      "375\n",
      "376\n",
      "377\n",
      "378\n",
      "379\n",
      "380\n",
      "381\n",
      "382\n",
      "383\n",
      "384\n",
      "385\n",
      "386\n",
      "387\n",
      "388\n",
      "389\n",
      "390\n",
      "391\n",
      "392\n",
      "393\n",
      "394\n",
      "395\n",
      "396\n",
      "397\n",
      "398\n",
      "399\n",
      "400\n",
      "401\n",
      "402\n",
      "403\n",
      "404\n",
      "405\n",
      "406\n",
      "407\n",
      "408\n",
      "409\n",
      "410\n",
      "411\n",
      "412\n",
      "413\n",
      "414\n",
      "415\n",
      "416\n",
      "417\n",
      "418\n",
      "419\n",
      "420\n",
      "421\n",
      "422\n",
      "423\n",
      "424\n",
      "425\n",
      "426\n",
      "427\n",
      "428\n",
      "429\n",
      "430\n",
      "431\n",
      "432\n",
      "433\n",
      "434\n",
      "435\n",
      "436\n",
      "437\n",
      "438\n",
      "439\n",
      "440\n",
      "441\n",
      "442\n",
      "443\n",
      "444\n",
      "445\n",
      "446\n",
      "447\n",
      "448\n",
      "449\n",
      "450\n",
      "451\n",
      "452\n",
      "453\n",
      "454\n",
      "455\n",
      "456\n",
      "457\n",
      "458\n",
      "459\n",
      "460\n",
      "461\n",
      "462\n",
      "463\n",
      "464\n",
      "465\n",
      "466\n",
      "467\n",
      "468\n",
      "469\n",
      "470\n",
      "471\n",
      "472\n",
      "473\n",
      "474\n",
      "475\n",
      "476\n",
      "477\n",
      "478\n",
      "479\n",
      "480\n",
      "481\n",
      "482\n",
      "483\n",
      "484\n",
      "485\n",
      "486\n",
      "487\n",
      "488\n",
      "489\n",
      "490\n",
      "491\n",
      "492\n",
      "493\n",
      "494\n",
      "495\n",
      "496\n",
      "497\n",
      "498\n",
      "499\n",
      "500\n",
      "501\n",
      "502\n",
      "503\n",
      "504\n",
      "505\n",
      "506\n",
      "507\n",
      "508\n",
      "509\n",
      "510\n",
      "511\n",
      "512\n",
      "513\n",
      "514\n",
      "515\n",
      "516\n",
      "517\n",
      "518\n",
      "519\n",
      "520\n",
      "521\n",
      "522\n",
      "523\n",
      "524\n",
      "525\n",
      "526\n",
      "527\n",
      "528\n",
      "529\n",
      "530\n",
      "531\n",
      "532\n",
      "533\n",
      "534\n",
      "535\n",
      "536\n",
      "537\n",
      "538\n",
      "539\n",
      "540\n",
      "541\n",
      "542\n",
      "543\n",
      "544\n",
      "545\n",
      "546\n",
      "547\n",
      "548\n",
      "549\n",
      "550\n",
      "551\n",
      "552\n",
      "553\n",
      "554\n",
      "555\n",
      "556\n",
      "557\n",
      "558\n",
      "559\n",
      "560\n",
      "561\n",
      "562\n",
      "563\n",
      "564\n",
      "565\n",
      "566\n",
      "567\n",
      "568\n",
      "569\n",
      "570\n",
      "571\n",
      "572\n",
      "573\n",
      "574\n",
      "575\n",
      "576\n",
      "577\n",
      "578\n",
      "579\n",
      "580\n",
      "581\n",
      "582\n",
      "583\n",
      "584\n",
      "585\n",
      "586\n",
      "587\n",
      "588\n",
      "589\n",
      "590\n",
      "591\n",
      "592\n",
      "593\n",
      "594\n",
      "595\n",
      "596\n",
      "597\n",
      "598\n",
      "599\n",
      "600\n",
      "601\n",
      "602\n",
      "603\n",
      "604\n",
      "605\n",
      "606\n",
      "607\n",
      "608\n",
      "609\n",
      "610\n",
      "611\n",
      "612\n",
      "613\n",
      "614\n",
      "615\n",
      "616\n",
      "617\n",
      "618\n",
      "619\n",
      "620\n",
      "621\n",
      "622\n",
      "623\n",
      "624\n",
      "625\n",
      "626\n",
      "627\n",
      "628\n",
      "629\n",
      "630\n",
      "631\n",
      "632\n",
      "633\n",
      "634\n",
      "635\n",
      "636\n",
      "637\n",
      "638\n",
      "639\n",
      "640\n",
      "641\n",
      "642\n",
      "643\n",
      "644\n",
      "645\n",
      "646\n",
      "647\n",
      "648\n",
      "649\n",
      "650\n",
      "651\n",
      "652\n",
      "653\n",
      "654\n",
      "655\n",
      "656\n",
      "657\n",
      "658\n",
      "659\n",
      "660\n",
      "661\n",
      "662\n",
      "663\n",
      "664\n",
      "665\n",
      "666\n",
      "667\n",
      "668\n",
      "669\n",
      "670\n",
      "671\n",
      "672\n",
      "673\n",
      "674\n",
      "675\n",
      "676\n",
      "677\n",
      "678\n",
      "679\n",
      "680\n",
      "681\n",
      "682\n",
      "683\n",
      "684\n",
      "685\n",
      "686\n",
      "687\n",
      "688\n",
      "689\n",
      "690\n",
      "691\n",
      "692\n",
      "693\n",
      "694\n",
      "695\n",
      "696\n",
      "697\n",
      "698\n",
      "699\n",
      "700\n",
      "701\n",
      "702\n",
      "703\n",
      "704\n",
      "705\n",
      "706\n",
      "707\n",
      "708\n",
      "709\n",
      "710\n",
      "711\n",
      "712\n",
      "713\n",
      "714\n",
      "715\n",
      "716\n",
      "717\n",
      "718\n",
      "719\n",
      "720\n",
      "721\n",
      "722\n",
      "723\n",
      "724\n",
      "725\n",
      "726\n",
      "727\n",
      "728\n",
      "729\n",
      "730\n",
      "731\n",
      "732\n",
      "733\n",
      "734\n",
      "735\n",
      "736\n",
      "737\n",
      "738\n",
      "739\n",
      "740\n",
      "741\n",
      "742\n",
      "743\n",
      "744\n",
      "745\n",
      "746\n",
      "747\n",
      "748\n",
      "749\n",
      "750\n",
      "751\n",
      "752\n",
      "753\n",
      "754\n",
      "755\n",
      "756\n",
      "757\n",
      "758\n",
      "759\n",
      "760\n",
      "761\n",
      "762\n",
      "763\n",
      "764\n",
      "765\n",
      "766\n",
      "767\n",
      "768\n",
      "769\n",
      "770\n",
      "771\n",
      "772\n",
      "773\n",
      "774\n",
      "775\n",
      "776\n",
      "777\n",
      "778\n",
      "779\n",
      "780\n",
      "781\n",
      "782\n",
      "783\n",
      "784\n",
      "785\n",
      "786\n",
      "787\n",
      "788\n",
      "789\n",
      "790\n",
      "791\n",
      "792\n",
      "793\n",
      "794\n",
      "795\n",
      "796\n",
      "797\n",
      "798\n",
      "799\n"
     ]
    }
   ],
   "source": [
    "experts = 8\n",
    "epsilon = 0.01\n",
    "count = np.zeros(16)\n",
    "\n",
    "len_stats = {}\n",
    "len_stats['source'] = []\n",
    "len_stats['tgts_nz'] = []\n",
    "len_stats['status'] = []\n",
    "len_stats['stats'] = {}\n",
    "for i in range(3):\n",
    "    len_stats['stats'][str(i)+'best'] = []\n",
    "    len_stats['stats'][str(i)+'overlap'] = []\n",
    "    len_stats['stats'][str(i)+'extra'] = []\n",
    "    len_stats['stats'][str(i)+'total'] = []\n",
    "    len_stats['stats'][str(i)+'subsets'] = []\n",
    "    len_stats['stats'][str(i)+'errors'] = []\n",
    "\n",
    "len_stats['first_source'] = np.random.uniform(-1,1,15)\n",
    "for it in range(len(target_wt)):\n",
    "    if(it%1==0):\n",
    "        print(it)\n",
    "    rand_vars = np.random.uniform(-1,1,15)\n",
    "    len_stats['source'].append(rand_vars)\n",
    "    rand_vars *= len_stats['first_source']\n",
    "    targets = ((torch.abs(target_wt[it])>=epsilon)*target_wt[it])\n",
    "    ind = torch.argwhere(targets!=0.0)\n",
    "    experts = len(ind)\n",
    "    targets = targets[ind].view(-1).numpy()\n",
    "    len_stats['tgts_nz'].append(targets)\n",
    "    if(len(targets)==0):\n",
    "        print(\"Zero target\")\n",
    "        print()\n",
    "        print(\"#\"*100)\n",
    "        len_stats['status'].append(False)\n",
    "        for i in range(3):\n",
    "            len_stats['stats'][str(i)+'best'].append(None)\n",
    "            len_stats['stats'][str(i)+'overlap'].append(None)\n",
    "            len_stats['stats'][str(i)+'extra'].append(None)\n",
    "            len_stats['stats'][str(i)+'total'].append(None)\n",
    "            len_stats['stats'][str(i)+'subsets'].append(None)\n",
    "            len_stats['stats'][str(i)+'errors'].append(None)\n",
    "        continue\n",
    "    #########\n",
    "    #print(rand_vars)\n",
    "    #########\n",
    "    s = []\n",
    "    subsets(rand_vars, np.zeros_like(rand_vars))\n",
    "    status = []\n",
    "    for i in range(experts):\n",
    "        status.append(np.argwhere(np.abs(s-targets[i])<epsilon))\n",
    "    # for i in range(experts):\n",
    "    #     print(len(status[i]))\n",
    "\n",
    "    flag = False\n",
    "    for i in range(experts):\n",
    "        if(len(status[i])==0):\n",
    "            flag = True\n",
    "    if flag:\n",
    "        len_stats['status'].append(False)\n",
    "        for i in range(3):\n",
    "            len_stats['stats'][str(i)+'best'].append(None)\n",
    "            len_stats['stats'][str(i)+'overlap'].append(None)\n",
    "            len_stats['stats'][str(i)+'extra'].append(None)\n",
    "            len_stats['stats'][str(i)+'total'].append(None)\n",
    "            len_stats['stats'][str(i)+'subsets'].append(None)\n",
    "            len_stats['stats'][str(i)+'errors'].append(None)\n",
    "        continue\n",
    "    len_stats['status'].append(True)\n",
    "    #print([status[i].shape for i in range(experts)])\n",
    "    #########\n",
    "    #for i in range(experts):\n",
    "    #    print(np.sum(status[i]))\n",
    "    #########\n",
    "    # print(\"Smallest superset\")\n",
    "    a,b,c,d = find_best_subset_size(status, targets, experts, epsilon)\n",
    "    a1,b1,c1,d1,e1 = find_stats(a,b,c)\n",
    "    i = 0\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print(c)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    \n",
    "    # print(\"Overall optimal\")\n",
    "    a,b,c,d = find_best_overall_size(status, targets, experts, epsilon)\n",
    "    a1,b1,c1,d1,e1 = find_stats(a,b,c)\n",
    "    i = 1\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    \n",
    "    # print(\"Extra optimal\")\n",
    "    a,b,c,d = find_sparsest(status, targets, experts, epsilon)\n",
    "    a1,b1,c1,d1,e1 = find_stats(a,b,c)\n",
    "    i = 2\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    # print()\n",
    "    # print(\"#\"*100)\n",
    "torch.save(len_stats, \"linear_moe_ssa_approx_2L.pth\")    \n",
    "# print(len_stats)\n",
    "#print(\"Frequency of best subset sizes: \"+str(count))\n",
    "#print(\"Average fraction of subset size compared to individual sizes: \"+str(np.mean(np.array(best_len)/np.array(total_len))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "958fca1d-ef3f-4ca6-82a4-ddf92de9a5ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_23592\\2964554752.py:9: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n",
      "  n &= int(n-1)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n"
     ]
    }
   ],
   "source": [
    "experts = 8\n",
    "epsilon = 0.01\n",
    "count = np.zeros(16)\n",
    "\n",
    "len_stats = {}\n",
    "len_stats['source'] = []\n",
    "len_stats['tgts_nz'] = []\n",
    "len_stats['status'] = []\n",
    "len_stats['stats'] = {}\n",
    "for i in range(3):\n",
    "    len_stats['stats'][str(i)+'best'] = []\n",
    "    len_stats['stats'][str(i)+'overlap'] = []\n",
    "    len_stats['stats'][str(i)+'extra'] = []\n",
    "    len_stats['stats'][str(i)+'total'] = []\n",
    "    len_stats['stats'][str(i)+'subsets'] = []\n",
    "    len_stats['stats'][str(i)+'errors'] = []\n",
    "len_stats['first_source'] = np.random.uniform(-1,1,15)\n",
    "for it in range(len(target_b)):\n",
    "    if(it%1==0):\n",
    "        print(it)\n",
    "    rand_vars = np.random.uniform(-1,1,15)\n",
    "    len_stats['source'].append(rand_vars)\n",
    "    rand_vars *= len_stats['first_source']\n",
    "    targets = ((torch.abs(target_b[it])>=epsilon)*target_b[it])\n",
    "    ind = torch.argwhere(targets!=0.0)\n",
    "    experts = len(ind)\n",
    "    targets = targets[ind].view(-1).numpy()\n",
    "    len_stats['tgts_nz'].append(targets)\n",
    "    if(len(targets)==0):\n",
    "        print(\"Zero target\")\n",
    "        print()\n",
    "        print(\"#\"*100)\n",
    "        len_stats['status'].append(False)\n",
    "        for i in range(3):\n",
    "            len_stats['stats'][str(i)+'best'].append(None)\n",
    "            len_stats['stats'][str(i)+'overlap'].append(None)\n",
    "            len_stats['stats'][str(i)+'extra'].append(None)\n",
    "            len_stats['stats'][str(i)+'total'].append(None)\n",
    "            len_stats['stats'][str(i)+'subsets'].append(None)\n",
    "            len_stats['stats'][str(i)+'errors'].append(None)\n",
    "        continue\n",
    "    #########\n",
    "    #print(rand_vars)\n",
    "    #########\n",
    "    s = []\n",
    "    subsets(rand_vars, np.zeros_like(rand_vars))\n",
    "    status = []\n",
    "    for i in range(experts):\n",
    "        status.append(np.argwhere(np.abs(s-targets[i])<epsilon))\n",
    "    # for i in range(experts):\n",
    "    #     print(len(status[i]))\n",
    "\n",
    "    flag = False\n",
    "    for i in range(experts):\n",
    "        if(len(status[i])==0):\n",
    "            flag = True\n",
    "    if flag:\n",
    "        len_stats['status'].append(False)\n",
    "        for i in range(3):\n",
    "            len_stats['stats'][str(i)+'best'].append(None)\n",
    "            len_stats['stats'][str(i)+'overlap'].append(None)\n",
    "            len_stats['stats'][str(i)+'extra'].append(None)\n",
    "            len_stats['stats'][str(i)+'total'].append(None)\n",
    "            len_stats['stats'][str(i)+'subsets'].append(None)\n",
    "            len_stats['stats'][str(i)+'errors'].append(None)\n",
    "        continue\n",
    "    len_stats['status'].append(True)\n",
    "    #print([status[i].shape for i in range(experts)])\n",
    "    #########\n",
    "    #for i in range(experts):\n",
    "    #    print(np.sum(status[i]))\n",
    "    #########\n",
    "    # print(\"Smallest superset\")\n",
    "    a,b,c,d = find_best_subset_size(status, targets, experts, epsilon)\n",
    "    a1,b1,c1,d1,e1 = find_stats(a,b,c)\n",
    "    i = 0\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print(c)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    \n",
    "    # print(\"Overall optimal\")\n",
    "    a,b,c,d = find_best_overall_size(status, targets, experts, epsilon)\n",
    "    a1,b1,c1,d1,e1 = find_stats(a,b,c)\n",
    "    i = 1\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    \n",
    "    # print(\"Extra optimal\")\n",
    "    a,b,c,d = find_sparsest(status, targets, experts, epsilon)\n",
    "    a1,b1,c1,d1,e1 = find_stats(a,b,c)\n",
    "    i = 2\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    # print()\n",
    "    # print(\"#\"*100)\n",
    "torch.save(len_stats, \"linear_moe_ssa_approx_2L_bias.pth\")    \n",
    "# print(len_stats)\n",
    "#print(\"Frequency of best subset sizes: \"+str(count))\n",
    "#print(\"Average fraction of subset size compared to individual sizes: \"+str(np.mean(np.array(best_len)/np.array(total_len))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "5f62e4fc-5ddf-4d31-9dc3-6480518f60bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_23592\\2964554752.py:9: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n",
      "  n &= int(n-1)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n",
      "16\n",
      "17\n",
      "18\n",
      "19\n",
      "20\n",
      "21\n",
      "22\n",
      "23\n",
      "24\n",
      "25\n",
      "26\n",
      "27\n",
      "28\n",
      "29\n",
      "30\n",
      "31\n",
      "32\n",
      "33\n",
      "34\n",
      "35\n",
      "36\n",
      "37\n",
      "38\n",
      "39\n",
      "40\n",
      "41\n",
      "42\n",
      "43\n",
      "44\n",
      "45\n",
      "46\n",
      "47\n",
      "48\n",
      "49\n",
      "50\n",
      "51\n",
      "52\n",
      "53\n",
      "54\n",
      "55\n",
      "56\n",
      "57\n",
      "58\n",
      "59\n",
      "60\n",
      "61\n",
      "62\n",
      "63\n",
      "64\n",
      "65\n",
      "66\n",
      "67\n",
      "68\n",
      "69\n",
      "70\n",
      "71\n",
      "72\n",
      "73\n",
      "74\n",
      "75\n",
      "76\n",
      "77\n",
      "78\n",
      "79\n",
      "80\n",
      "81\n",
      "82\n",
      "83\n",
      "84\n",
      "85\n",
      "86\n",
      "87\n",
      "88\n",
      "89\n",
      "90\n",
      "91\n",
      "92\n",
      "93\n",
      "94\n",
      "95\n",
      "96\n",
      "97\n",
      "98\n",
      "99\n",
      "100\n",
      "101\n",
      "102\n",
      "103\n",
      "104\n",
      "105\n",
      "106\n",
      "107\n",
      "108\n",
      "109\n",
      "110\n",
      "111\n",
      "112\n",
      "113\n",
      "114\n",
      "115\n",
      "116\n",
      "117\n",
      "118\n",
      "119\n",
      "120\n",
      "121\n",
      "122\n",
      "123\n",
      "124\n",
      "125\n",
      "126\n",
      "127\n",
      "128\n",
      "129\n",
      "130\n",
      "131\n",
      "132\n",
      "133\n",
      "134\n",
      "135\n",
      "136\n",
      "137\n",
      "138\n",
      "139\n",
      "140\n",
      "141\n",
      "142\n",
      "143\n",
      "144\n",
      "145\n",
      "146\n",
      "147\n",
      "148\n",
      "149\n",
      "150\n",
      "151\n",
      "152\n",
      "153\n",
      "154\n",
      "155\n",
      "156\n",
      "157\n",
      "158\n",
      "159\n",
      "160\n",
      "161\n",
      "162\n",
      "163\n",
      "164\n",
      "165\n",
      "166\n",
      "167\n",
      "168\n",
      "169\n",
      "170\n",
      "171\n",
      "172\n",
      "173\n",
      "174\n",
      "175\n",
      "176\n",
      "177\n",
      "178\n",
      "179\n",
      "180\n",
      "181\n",
      "182\n",
      "183\n",
      "184\n",
      "185\n",
      "186\n",
      "187\n",
      "188\n",
      "189\n",
      "190\n",
      "191\n",
      "192\n",
      "193\n",
      "194\n",
      "195\n",
      "196\n",
      "197\n",
      "198\n",
      "199\n",
      "200\n",
      "201\n",
      "202\n",
      "203\n",
      "204\n",
      "205\n",
      "206\n",
      "207\n",
      "208\n",
      "209\n",
      "210\n",
      "211\n",
      "212\n",
      "213\n",
      "214\n",
      "215\n",
      "216\n",
      "217\n",
      "218\n",
      "219\n",
      "220\n",
      "221\n",
      "222\n",
      "223\n",
      "224\n",
      "225\n",
      "226\n",
      "227\n",
      "228\n",
      "229\n",
      "230\n",
      "231\n",
      "232\n",
      "233\n",
      "234\n",
      "235\n",
      "236\n",
      "237\n",
      "238\n",
      "239\n",
      "240\n",
      "241\n",
      "242\n",
      "243\n",
      "244\n",
      "245\n",
      "246\n",
      "247\n",
      "248\n",
      "249\n",
      "250\n",
      "251\n",
      "252\n",
      "253\n",
      "254\n",
      "255\n",
      "256\n",
      "257\n",
      "258\n",
      "259\n",
      "260\n",
      "261\n",
      "262\n",
      "263\n",
      "264\n",
      "265\n",
      "266\n",
      "267\n",
      "268\n",
      "269\n",
      "270\n",
      "271\n",
      "272\n",
      "273\n",
      "274\n",
      "275\n",
      "276\n",
      "277\n",
      "278\n",
      "279\n",
      "280\n",
      "281\n",
      "282\n",
      "283\n",
      "284\n",
      "285\n",
      "286\n",
      "287\n",
      "288\n",
      "289\n",
      "290\n",
      "291\n",
      "292\n",
      "293\n",
      "294\n",
      "295\n",
      "296\n",
      "297\n",
      "298\n",
      "299\n",
      "300\n",
      "301\n",
      "302\n",
      "303\n",
      "304\n",
      "305\n",
      "306\n",
      "307\n",
      "308\n",
      "309\n",
      "310\n",
      "311\n",
      "312\n",
      "313\n",
      "314\n",
      "315\n",
      "316\n",
      "317\n",
      "318\n",
      "319\n",
      "320\n",
      "321\n",
      "322\n",
      "323\n",
      "324\n",
      "325\n",
      "326\n",
      "327\n",
      "328\n",
      "329\n",
      "330\n",
      "331\n",
      "332\n",
      "333\n",
      "334\n",
      "335\n",
      "336\n",
      "337\n",
      "338\n",
      "339\n",
      "340\n",
      "341\n",
      "342\n",
      "343\n",
      "344\n",
      "345\n",
      "346\n",
      "347\n",
      "348\n",
      "349\n",
      "350\n",
      "351\n",
      "352\n",
      "353\n",
      "354\n",
      "355\n",
      "356\n",
      "357\n",
      "358\n",
      "359\n",
      "360\n",
      "361\n",
      "362\n",
      "363\n",
      "364\n",
      "365\n",
      "366\n",
      "367\n",
      "368\n",
      "369\n",
      "370\n",
      "371\n",
      "372\n",
      "373\n",
      "374\n",
      "375\n",
      "376\n",
      "377\n",
      "378\n",
      "379\n",
      "380\n",
      "381\n",
      "382\n",
      "383\n",
      "384\n",
      "385\n",
      "386\n",
      "387\n",
      "388\n",
      "389\n",
      "390\n",
      "391\n",
      "392\n",
      "393\n",
      "394\n",
      "395\n",
      "396\n",
      "397\n",
      "398\n",
      "399\n",
      "400\n",
      "401\n",
      "402\n",
      "403\n",
      "404\n",
      "405\n",
      "406\n",
      "407\n",
      "408\n",
      "409\n",
      "410\n",
      "411\n",
      "412\n",
      "413\n",
      "414\n",
      "415\n",
      "416\n",
      "417\n",
      "418\n",
      "419\n",
      "420\n",
      "421\n",
      "422\n",
      "423\n",
      "424\n",
      "425\n",
      "426\n",
      "427\n",
      "428\n",
      "429\n",
      "430\n",
      "431\n",
      "432\n",
      "433\n",
      "434\n",
      "435\n",
      "436\n",
      "437\n",
      "438\n",
      "439\n",
      "440\n",
      "441\n",
      "442\n",
      "443\n",
      "444\n",
      "445\n",
      "446\n",
      "447\n",
      "448\n",
      "449\n",
      "450\n",
      "451\n",
      "452\n",
      "453\n",
      "454\n",
      "455\n",
      "456\n",
      "457\n",
      "458\n",
      "459\n",
      "460\n",
      "461\n",
      "462\n",
      "463\n",
      "464\n",
      "465\n",
      "466\n",
      "467\n",
      "468\n",
      "469\n",
      "470\n",
      "471\n",
      "472\n",
      "473\n",
      "474\n",
      "475\n",
      "476\n",
      "477\n",
      "478\n",
      "479\n",
      "480\n",
      "481\n",
      "482\n",
      "483\n",
      "484\n",
      "485\n",
      "486\n",
      "487\n",
      "488\n",
      "489\n",
      "490\n",
      "491\n",
      "492\n",
      "493\n",
      "494\n",
      "495\n",
      "496\n",
      "497\n",
      "498\n",
      "499\n",
      "500\n",
      "501\n",
      "502\n",
      "503\n",
      "504\n",
      "505\n",
      "506\n",
      "507\n",
      "508\n",
      "509\n",
      "510\n",
      "511\n",
      "512\n",
      "513\n",
      "514\n",
      "515\n",
      "516\n",
      "517\n",
      "518\n",
      "519\n",
      "520\n",
      "521\n",
      "522\n",
      "523\n",
      "524\n",
      "525\n",
      "526\n",
      "527\n",
      "528\n",
      "529\n",
      "530\n",
      "531\n",
      "532\n",
      "533\n",
      "534\n",
      "535\n",
      "536\n",
      "537\n",
      "538\n",
      "539\n",
      "540\n",
      "541\n",
      "542\n",
      "543\n",
      "544\n",
      "545\n",
      "546\n",
      "547\n",
      "548\n",
      "549\n",
      "550\n",
      "551\n",
      "552\n",
      "553\n",
      "554\n",
      "555\n",
      "556\n",
      "557\n",
      "558\n",
      "559\n",
      "560\n",
      "561\n",
      "562\n",
      "563\n",
      "564\n",
      "565\n",
      "566\n",
      "567\n",
      "568\n",
      "569\n",
      "570\n",
      "571\n",
      "572\n",
      "573\n",
      "574\n",
      "575\n",
      "576\n",
      "577\n",
      "578\n",
      "579\n",
      "580\n",
      "581\n",
      "582\n",
      "583\n",
      "584\n",
      "585\n",
      "586\n",
      "587\n",
      "588\n",
      "589\n",
      "590\n",
      "591\n",
      "592\n",
      "593\n",
      "594\n",
      "595\n",
      "596\n",
      "597\n",
      "598\n",
      "599\n",
      "600\n",
      "601\n",
      "602\n",
      "603\n",
      "604\n",
      "605\n",
      "606\n",
      "607\n",
      "608\n",
      "609\n",
      "610\n",
      "611\n",
      "612\n",
      "613\n",
      "614\n",
      "615\n",
      "616\n",
      "617\n",
      "618\n",
      "619\n",
      "620\n",
      "621\n",
      "622\n",
      "623\n",
      "624\n",
      "625\n",
      "626\n",
      "627\n",
      "628\n",
      "629\n",
      "630\n",
      "631\n",
      "632\n",
      "633\n",
      "634\n",
      "635\n",
      "636\n",
      "637\n",
      "638\n",
      "639\n",
      "640\n",
      "641\n",
      "642\n",
      "643\n",
      "644\n",
      "645\n",
      "646\n",
      "647\n",
      "648\n",
      "649\n",
      "650\n",
      "651\n",
      "652\n",
      "653\n",
      "654\n",
      "655\n",
      "656\n",
      "657\n",
      "658\n",
      "659\n",
      "660\n",
      "661\n",
      "662\n",
      "663\n",
      "664\n",
      "665\n",
      "666\n",
      "667\n",
      "668\n",
      "669\n",
      "670\n",
      "671\n",
      "672\n",
      "673\n",
      "674\n",
      "675\n",
      "676\n",
      "677\n",
      "678\n",
      "679\n",
      "680\n",
      "681\n",
      "682\n",
      "683\n",
      "684\n",
      "685\n",
      "686\n",
      "687\n",
      "688\n",
      "689\n",
      "690\n",
      "691\n",
      "692\n",
      "693\n",
      "694\n",
      "695\n",
      "696\n",
      "697\n",
      "698\n",
      "699\n",
      "700\n",
      "701\n",
      "702\n",
      "703\n",
      "704\n",
      "705\n",
      "706\n",
      "707\n",
      "708\n",
      "709\n",
      "710\n",
      "711\n",
      "712\n",
      "713\n",
      "714\n",
      "715\n",
      "716\n",
      "717\n",
      "718\n",
      "719\n",
      "720\n",
      "721\n",
      "722\n",
      "723\n",
      "724\n",
      "725\n",
      "726\n",
      "727\n",
      "728\n",
      "729\n",
      "730\n",
      "731\n",
      "732\n",
      "733\n",
      "734\n",
      "735\n",
      "736\n",
      "737\n",
      "738\n",
      "739\n",
      "740\n",
      "741\n",
      "742\n",
      "743\n",
      "744\n",
      "745\n",
      "746\n",
      "747\n",
      "748\n",
      "749\n",
      "750\n",
      "751\n",
      "752\n",
      "753\n",
      "754\n",
      "755\n",
      "756\n",
      "757\n",
      "758\n",
      "759\n",
      "760\n",
      "761\n",
      "762\n",
      "763\n",
      "764\n",
      "765\n",
      "766\n",
      "767\n",
      "768\n",
      "769\n",
      "770\n",
      "771\n",
      "772\n",
      "773\n",
      "774\n",
      "775\n",
      "776\n",
      "777\n",
      "778\n",
      "779\n",
      "780\n",
      "781\n",
      "782\n",
      "783\n",
      "784\n",
      "785\n",
      "786\n",
      "787\n",
      "788\n",
      "789\n",
      "790\n",
      "791\n",
      "792\n",
      "793\n",
      "794\n",
      "795\n",
      "796\n",
      "797\n",
      "798\n",
      "799\n"
     ]
    }
   ],
   "source": [
    "from itertools import combinations\n",
    "import copy\n",
    "from functools import reduce\n",
    "\n",
    "experts = 8\n",
    "epsilon = 0.01\n",
    "count = np.zeros(16)\n",
    "\n",
    "len_stats = {}\n",
    "len_stats['source'] = []\n",
    "len_stats['tgts_nz'] = []\n",
    "len_stats['status'] = []\n",
    "len_stats['stats'] = {}\n",
    "for i in range(6):\n",
    "    len_stats['stats'][str(i)+'best'] = []\n",
    "    len_stats['stats'][str(i)+'overlap'] = []\n",
    "    len_stats['stats'][str(i)+'extra'] = []\n",
    "    len_stats['stats'][str(i)+'total'] = []\n",
    "    len_stats['stats'][str(i)+'subsets'] = []\n",
    "    len_stats['stats'][str(i)+'errors'] = []\n",
    "\n",
    "len_stats['first_source'] = np.random.uniform(-1,1,15)\n",
    "for it in range(len(target_wt)):\n",
    "    if(it%1==0):\n",
    "        print(it)\n",
    "    rand_vars = np.random.uniform(-1,1,(8,15))\n",
    "    len_stats['source'].append(rand_vars)\n",
    "    rand_vars *= len_stats['first_source']\n",
    "    targets = ((torch.abs(target_wt[it])>=epsilon)*target_wt[it])\n",
    "    ind = torch.argwhere(targets!=0.0)\n",
    "    experts = len(ind)\n",
    "    targets = targets[ind].view(-1).numpy()\n",
    "    len_stats['tgts_nz'].append(targets)\n",
    "    if(len(targets)==0):\n",
    "        print(\"Zero target\")\n",
    "        print()\n",
    "        print(\"#\"*100)\n",
    "        len_stats['status'].append(False)\n",
    "        for i in range(6):\n",
    "            len_stats['stats'][str(i)+'best'].append(None)\n",
    "            len_stats['stats'][str(i)+'overlap'].append(None)\n",
    "            len_stats['stats'][str(i)+'extra'].append(None)\n",
    "            len_stats['stats'][str(i)+'total'].append(None)\n",
    "            len_stats['stats'][str(i)+'subsets'].append(None)\n",
    "            len_stats['stats'][str(i)+'errors'].append(None)\n",
    "        continue\n",
    "    #########\n",
    "    #print(rand_vars)\n",
    "    #########\n",
    "    status = []\n",
    "    for i in range(experts):\n",
    "        s = []\n",
    "        subsets(rand_vars[i], np.zeros_like(rand_vars[i]))\n",
    "        status.append(np.argwhere(np.abs(s-targets[i])<epsilon))\n",
    "    # for i in range(experts):\n",
    "    #     print(len(status[i]))\n",
    "\n",
    "    flag = False\n",
    "    for i in range(experts):\n",
    "        if(len(status[i])==0):\n",
    "            flag = True\n",
    "    if flag:\n",
    "        len_stats['status'].append(False)\n",
    "        for i in range(6):\n",
    "            len_stats['stats'][str(i)+'best'].append(None)\n",
    "            len_stats['stats'][str(i)+'overlap'].append(None)\n",
    "            len_stats['stats'][str(i)+'extra'].append(None)\n",
    "            len_stats['stats'][str(i)+'total'].append(None)\n",
    "            len_stats['stats'][str(i)+'subsets'].append(None)\n",
    "            len_stats['stats'][str(i)+'errors'].append(None)\n",
    "        continue\n",
    "    len_stats['status'].append(True)\n",
    "    #print([status[i].shape for i in range(experts)])\n",
    "    #########\n",
    "    #for i in range(experts):\n",
    "    #    print(np.sum(status[i]))\n",
    "    #########\n",
    "    # print(\"Smallest superset\")\n",
    "    a,b,c,d = find_best_subset_size(status, targets, experts, epsilon)\n",
    "    verify(c, rand_vars, targets, epsilon)\n",
    "    a1,b1,c1,d1,e1 = find_stats(a,b,c)\n",
    "    i = 0\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print(c)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    \n",
    "    # print(\"Overall optimal\")\n",
    "    a,b,c,d = find_best_overall_size(status, targets, experts, epsilon)\n",
    "    a1,b1,c1,d1,e1 = find_stats(a,b,c)\n",
    "    i = 1\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    \n",
    "    # # print(\"Extra optimal\")\n",
    "    # a,b,c,d = find_best_overlap_size(status, targets, experts, epsilon)\n",
    "    # a1,b1,c1 = find_stats(a,b,c)\n",
    "    # i = 2\n",
    "    # if(a is not None):\n",
    "    #     len_stats['stats'][str(i)+'best'].append(a1)\n",
    "    #     len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "    #     len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "    #     len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "    #     len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "    #     err = []\n",
    "    #     for j in range(experts):\n",
    "    #         err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "    #     len_stats['stats'][str(i)+'errors'].append(err)\n",
    "\n",
    "    # a,b,c,d = find_best_extra_size(status, targets, experts, epsilon)\n",
    "    # a1,b1,c1,d1,e1 = find_stats(a,b,c)\n",
    "    # i = 3\n",
    "    # if(a is not None):\n",
    "    #     len_stats['stats'][str(i)+'best'].append(a1)\n",
    "    #     len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "    #     len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "    #     len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "    #     len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "    #     err = []\n",
    "    #     for j in range(experts):\n",
    "    #         err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "    #     len_stats['stats'][str(i)+'errors'].append(err)\n",
    "\n",
    "    a,b,c,d = find_sparsest(status, targets, experts, epsilon)\n",
    "    a1,b1,c1,d1,e1 = find_stats(a,b,c)\n",
    "    i = 2\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "\n",
    "    # a,b,c,d = find_sparsest(status, targets, experts, epsilon)\n",
    "    # a1,b1,c1,d1,e1 = find_stats(a,b,c)\n",
    "    # i = 5\n",
    "    # if(a is not None):\n",
    "    #     len_stats['stats'][str(i)+'best'].append(a1)\n",
    "    #     len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "    #     len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "    #     len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "    #     len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "    #     err = []\n",
    "    #     for j in range(experts):\n",
    "    #         err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "    #     len_stats['stats'][str(i)+'errors'].append(err)\n",
    "\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    # print()\n",
    "    # print(\"#\"*100)\n",
    "torch.save(len_stats, \"linear_moe_ssa_approx_2L_diff.pth\")    \n",
    "# print(len_stats)\n",
    "#print(\"Frequency of best subset sizes: \"+str(count))\n",
    "#print(\"Average fraction of subset size compared to individual sizes: \"+str(np.mean(np.array(best_len)/np.array(total_len))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "f4ebfd6e-6265-4fa3-b59b-51ba7d0f791c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_23592\\2964554752.py:9: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n",
      "  n &= int(n-1)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n"
     ]
    }
   ],
   "source": [
    "experts = 8\n",
    "epsilon = 0.01\n",
    "count = np.zeros(16)\n",
    "\n",
    "len_stats = {}\n",
    "len_stats['source'] = []\n",
    "len_stats['tgts_nz'] = []\n",
    "len_stats['status'] = []\n",
    "len_stats['stats'] = {}\n",
    "for i in range(3):\n",
    "    len_stats['stats'][str(i)+'best'] = []\n",
    "    len_stats['stats'][str(i)+'overlap'] = []\n",
    "    len_stats['stats'][str(i)+'extra'] = []\n",
    "    len_stats['stats'][str(i)+'total'] = []\n",
    "    len_stats['stats'][str(i)+'subsets'] = []\n",
    "    len_stats['stats'][str(i)+'errors'] = []\n",
    "len_stats['first_source'] = np.random.uniform(-1,1,15)\n",
    "for it in range(len(target_b)):\n",
    "    if(it%1==0):\n",
    "        print(it)\n",
    "    rand_vars = np.random.uniform(-1,1,(8,15))\n",
    "    len_stats['source'].append(rand_vars)\n",
    "    rand_vars *= len_stats['first_source']\n",
    "    targets = ((torch.abs(target_b[it])>=epsilon)*target_b[it])\n",
    "    ind = torch.argwhere(targets!=0.0)\n",
    "    experts = len(ind)\n",
    "    targets = targets[ind].view(-1).numpy()\n",
    "    len_stats['tgts_nz'].append(targets)\n",
    "    if(len(targets)==0):\n",
    "        print(\"Zero target\")\n",
    "        print()\n",
    "        print(\"#\"*100)\n",
    "        len_stats['status'].append(False)\n",
    "        for i in range(3):\n",
    "            len_stats['stats'][str(i)+'best'].append(None)\n",
    "            len_stats['stats'][str(i)+'overlap'].append(None)\n",
    "            len_stats['stats'][str(i)+'extra'].append(None)\n",
    "            len_stats['stats'][str(i)+'total'].append(None)\n",
    "            len_stats['stats'][str(i)+'subsets'].append(None)\n",
    "            len_stats['stats'][str(i)+'errors'].append(None)\n",
    "        continue\n",
    "    #########\n",
    "    #print(rand_vars)\n",
    "    #########\n",
    "    # for i in range(experts):\n",
    "    #     print(len(status[i]))\n",
    "    status = []\n",
    "    for i in range(experts):\n",
    "        s = []\n",
    "        subsets(rand_vars[i], np.zeros_like(rand_vars[i]))\n",
    "        status.append(np.argwhere(np.abs(s-targets[i])<epsilon))\n",
    "    flag = False\n",
    "    for i in range(experts):\n",
    "        if(len(status[i])==0):\n",
    "            flag = True\n",
    "    if flag:\n",
    "        len_stats['status'].append(False)\n",
    "        for i in range(3):\n",
    "            len_stats['stats'][str(i)+'best'].append(None)\n",
    "            len_stats['stats'][str(i)+'overlap'].append(None)\n",
    "            len_stats['stats'][str(i)+'extra'].append(None)\n",
    "            len_stats['stats'][str(i)+'total'].append(None)\n",
    "            len_stats['stats'][str(i)+'subsets'].append(None)\n",
    "            len_stats['stats'][str(i)+'errors'].append(None)\n",
    "        continue\n",
    "    len_stats['status'].append(True)\n",
    "    #print([status[i].shape for i in range(experts)])\n",
    "    #########\n",
    "    #for i in range(experts):\n",
    "    #    print(np.sum(status[i]))\n",
    "    #########\n",
    "    # print(\"Smallest superset\")\n",
    "    a,b,c,d = find_best_subset_size(status, targets, experts, epsilon)\n",
    "    a1,b1,c1,d1,e1 = find_stats(a,b,c)\n",
    "    i = 0\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print(c)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    \n",
    "    # print(\"Overall optimal\")\n",
    "    a,b,c,d = find_best_overall_size(status, targets, experts, epsilon)\n",
    "    a1,b1,c1,d1,e1 = find_stats(a,b,c)\n",
    "    i = 1\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    \n",
    "    # print(\"Extra optimal\")\n",
    "    a,b,c,d = find_sparsest(status, targets, experts, epsilon)\n",
    "    a1,b1,c1,d1,e1 = find_stats(a,b,c)\n",
    "    i = 2\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    # print()\n",
    "    # print(\"#\"*100)\n",
    "torch.save(len_stats, \"linear_moe_ssa_approx_2L_diff_bias.pth\")    \n",
    "# print(len_stats)\n",
    "#print(\"Frequency of best subset sizes: \"+str(count))\n",
    "#print(\"Average fraction of subset size compared to individual sizes: \"+str(np.mean(np.array(best_len)/np.array(total_len))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "08e728d2-77b3-4c2b-9e7b-a4c21910e088",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dummy(status, targets, experts, epsilon):\n",
    "    #final_set = status[0].reshape(-1,1)\n",
    "    overlaps = status[0].reshape(-1,1)\n",
    "    for i in range(1,experts):\n",
    "        #final_set = np.bitwise_or(final_set, status[i].reshape(1,-1)).reshape(-1,1)\n",
    "        overlaps = np.bitwise_and(overlaps, status[i].reshape(1,-1)).reshape(-1,1)\n",
    "        #combined = np.unique(np.concatenate((final_set,overlaps), axis=1), axis=0)\n",
    "        #print(combined.shape)\n",
    "        #final_set = combined[:,0].reshape(-1,1)\n",
    "        #overlaps = combined[:,1].reshape(-1,1)\n",
    "        #print(i)\n",
    "        overlaps = np.unique(overlaps.reshape(-1)).reshape(-1,1)\n",
    "    overlaps = overlaps.reshape(-1,1)\n",
    "    # print(overlaps.shape)\n",
    "    cand = []\n",
    "    for i in range(experts):\n",
    "        o_s_map = (np.bitwise_and(overlaps, status[i].reshape(1,-1))==overlaps)\n",
    "        # print(o_s_map.sum(axis=1))\n",
    "        cand.append([status[i][np.where(o_s_map[j]==True)] for j in range(len(overlaps))])\n",
    "    best_len = 1000000\n",
    "    best_ss = -1\n",
    "    best_ss_ind = None\n",
    "    best_ss_id = -1\n",
    "    for ov_id in range(len(overlaps)):\n",
    "        ov_len = countSetBits(overlaps[ov_id][0])\n",
    "        extra_len = 0\n",
    "        min_vals = []\n",
    "        for i in range(experts):\n",
    "            l = []\n",
    "            for j in range(len(cand[i][ov_id])):\n",
    "                l.append(countSetBits(cand[i][ov_id][j] ^ overlaps[ov_id][0]))\n",
    "                # print(\"Candidate: \"+str(cand[i][ov_id][j])+\" Binary: \"+str(get_binary(cand[i][ov_id][j])))\n",
    "            cand_min_id = np.argmin(l)\n",
    "            #print(l)\n",
    "            extra_len+=l[cand_min_id]\n",
    "            min_vals.append(cand[i][ov_id][cand_min_id])\n",
    "        if(extra_len-ov_len<best_len):\n",
    "            best_ss = overlaps[ov_id][0]\n",
    "            best_ss_id = ov_id\n",
    "            best_ss_ind = copy.deepcopy(min_vals)\n",
    "            # print(\"Overlap: \"+str(overlaps[ov_id][0])+\" Binary: \"+str(get_binary(overlaps[ov_id][0])))\n",
    "            # print(\"Overlap length: \"+str(ov_len))\n",
    "            # print(\"Extra length: \"+str(extra_len))\n",
    "            # for i in range(experts):\n",
    "            #     print(\"Candidate: \"+str(min_vals[i])+\" Binary: \"+str(get_binary(min_vals[i])))\n",
    "            best_len = extra_len-ov_len\n",
    "            #print()\n",
    "    # if(not (best_len <= 100)):\n",
    "    #     print(\"Weird3: \"+str(best_len))\n",
    "    #     best_len = -1\n",
    "    # if(best_len==-1):\n",
    "    #     return None, -1, None, None\n",
    "    for i in range(experts):\n",
    "        assert(best_ss_ind[i] in status[i])\n",
    "    assert(overlaps[best_ss_id][0] == reduce(lambda x, y: int(x) & int(y), best_ss_ind))\n",
    "    super_set = reduce(lambda x, y: int(x) | int(y), best_ss_ind)\n",
    "    return super_set, best_len, best_ss_ind, best_ss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "9e8b8c5f-387a-49f4-b9e9-1856fe03d31e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(31677, -1, [31396, 15012, 31420, 15269], 15012)"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dummy(status, targets, experts, epsilon)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a37067ed-ad03-43f6-a943-fcb06e0ac8d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_6428\\951309503.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  st = torch.load(\"final_stats_ssa_4_exp_diff0.001.pth\")\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_6428\\951309503.py:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  st_s = torch.load(\"final_stats_ssa_4_exp0.001.pth\")\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "st = torch.load(\"final_stats_ssa_4_exp_diff0.001.pth\")\n",
    "st_s = torch.load(\"final_stats_ssa_4_exp0.001.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "e09ca25c-8f26-470d-931c-563703025646",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overlap\n",
      "0.5427104722792607\n",
      "0.6797066014669927\n",
      "0.7144182771285037\n",
      "0.7835798018902356\n",
      "Total\n",
      "16.40400410677618\n",
      "16.01650366748166\n",
      "2.7858406256944384\n",
      "2.6958593322556914\n",
      "\n",
      "Overlap\n",
      "1.693634496919918\n",
      "1.9546658516707416\n",
      "1.1942626938022929\n",
      "1.2346377234103108\n",
      "Total\n",
      "12.487371663244353\n",
      "12.053789731051344\n",
      "1.4635701726830386\n",
      "1.3522596856134776\n",
      "\n",
      "Overlap\n",
      "5.130082135523614\n",
      "6.229421352893236\n",
      "1.609270688638568\n",
      "1.5513271887598756\n",
      "Total\n",
      "15.325154004106777\n",
      "15.689079054604727\n",
      "2.2351779019778553\n",
      "2.1608379361378325\n",
      "\n",
      "Overlap\n",
      "3.525770020533881\n",
      "4.1312143439282805\n",
      "1.6446831680570404\n",
      "1.6571753765701163\n",
      "Total\n",
      "13.211704312114989\n",
      "12.987061939690301\n",
      "1.8156140517323431\n",
      "1.7743347315091813\n",
      "\n",
      "Overlap\n",
      "0.04496919917864477\n",
      "0.08425020374898126\n",
      "0.2111626879355238\n",
      "0.28678534036560127\n",
      "Total\n",
      "14.5717659137577\n",
      "14.421352893235534\n",
      "2.1002306421276122\n",
      "2.080424646757769\n",
      "\n",
      "Overlap\n",
      "0.04496919917864477\n",
      "0.08445669934640523\n",
      "0.2111626879355238\n",
      "0.2878167319046429\n",
      "Total\n",
      "14.5717659137577\n",
      "14.416053921568627\n",
      "2.1002306421276122\n",
      "2.0688959632848594\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "for i in range(6):\n",
    "    print(\"Overlap\")\n",
    "    print(np.mean(st[16][i]['overlap']))\n",
    "    print(np.mean(st_s[16][i]['overlap']))\n",
    "    print(np.std(st[16][i]['overlap']))\n",
    "    print(np.std(st_s[16][i]['overlap']))\n",
    "    print(\"Total\")\n",
    "    print(np.mean(st[16][i]['total']))\n",
    "    print(np.mean(st_s[16][i]['total']))\n",
    "    print(np.std(st[16][i]['total']))\n",
    "    print(np.std(st_s[16][i]['total']))\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "07cfb470-6f4d-45a9-902f-72b52a0f7816",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7124965116121322"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.std(st[17][0]['overlap'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49b8e8b0-5625-4fb8-a64f-d1373a101873",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
