{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "aeef2de5-3ea1-4cb3-9bd2-faf27ae93d15",
   "metadata": {},
   "source": [
    "# Synthetic Data Creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "844d8800-42e5-4c1a-a222-bd413994335b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import math\n",
    "import numpy as np\n",
    "import pickle\n",
    "import random\n",
    "\n",
    "# Seed the torch, NumPy and stdlib RNGs so the cells below draw the same values on every run.\n",
    "torch.manual_seed(12)\n",
    "np.random.seed(12)\n",
    "random.seed(12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "21fe9641-4022-4efa-bb17-6d4ee4c55379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Samples per split: the generation loop draws 2*DATA_NUM samples, first half train, second half test.\n",
    "DATA_NUM = 16000\n",
    "# Number of clusters; each cluster gets one feature vector and one center vector.\n",
    "CLUSTER_NUM = 4\n",
    "# Number of experts used by the mixture-of-experts code further down.\n",
    "EXPERT_NUM = 8\n",
    "# Patches per sample (the generation loop stacks four patches per sample).\n",
    "PATCH_NUM = 4\n",
    "# Length of each patch vector.\n",
    "PATCH_LEN = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9e24a1e2-9f5c-4116-ba50-9c9b1d799218",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 200])\n",
      "torch.Size([32000])\n",
      "torch.Size([16000, 1, 200]) torch.Size([16000, 1, 200])\n",
      "torch.Size([16000]) torch.Size([16000])\n"
     ]
    }
   ],
   "source": [
    "# Build the synthetic dataset: CLUSTER_NUM feature vectors, CLUSTER_NUM center vectors and\n",
    "# 2*DATA_NUM samples of four patches each (first DATA_NUM samples train, the rest test).\n",
    "# The order of the random draws below fixes the dataset for a given seed, so do not reorder them.\n",
    "\n",
    "# Feature vectors: every new candidate has its component along each previously stored vector\n",
    "# subtracted and is then rescaled to unit norm.\n",
    "features = torch.zeros(CLUSTER_NUM, PATCH_LEN)\n",
    "x = np.random.randn(PATCH_LEN)\n",
    "x /= np.linalg.norm(x)\n",
    "current_x = []\n",
    "for i in range(CLUSTER_NUM):\n",
    "    features[i] = torch.tensor(x)\n",
    "    current_x.append(x)\n",
    "    # Draw the next candidate; after the last iteration it stays in x and becomes centers[0].\n",
    "    x = np.random.randn(PATCH_LEN)\n",
    "    x /= np.linalg.norm(x)\n",
    "    for x_prev in current_x:\n",
    "        x -= x.dot(x_prev) * x_prev\n",
    "    x /= np.linalg.norm(x)\n",
    "\n",
    "# Center vectors: same procedure, continuing with the same current_x list so the centers are\n",
    "# also made orthogonal to the feature vectors stored above.\n",
    "centers = torch.zeros(CLUSTER_NUM, PATCH_LEN)\n",
    "for i in range(CLUSTER_NUM):\n",
    "    centers[i] = torch.tensor(x)\n",
    "    # NOTE(review): 3 is hard-coded; it equals CLUSTER_NUM-1 only while CLUSTER_NUM == 4.\n",
    "    # It skips drawing an unused extra vector after the last center.\n",
    "    if i!=3:\n",
    "        current_x.append(x)\n",
    "        x = np.random.randn(PATCH_LEN)\n",
    "        x /= np.linalg.norm(x)\n",
    "        for x_prev in current_x:\n",
    "            x -= x.dot(x_prev) * x_prev\n",
    "        x /= np.linalg.norm(x)\n",
    "\n",
    "\n",
    "data = []\n",
    "labels = []\n",
    "# Per-cluster lists of sample positions, indexed within their own split.\n",
    "train_cluster_idx = [[] for x in range(CLUSTER_NUM)]\n",
    "test_cluster_idx = [[] for x in range(CLUSTER_NUM)]\n",
    "\n",
    "for i in range(DATA_NUM*2):\n",
    "    # Label in {-1, +1} and cluster id, both drawn uniformly.\n",
    "    y = np.random.choice([-1,1], 1)[0]\n",
    "    k = np.random.choice(list(range(0,CLUSTER_NUM)))\n",
    "\n",
    "    if i < DATA_NUM:\n",
    "        train_cluster_idx[k].append(i)\n",
    "    else:\n",
    "        test_cluster_idx[k].append(i-DATA_NUM)\n",
    "\n",
    "    # Noise patch: i.i.d. Gaussian entries with standard deviation 1/sqrt(PATCH_LEN)\n",
    "    xi = torch.tensor(np.random.normal(0, 1/math.sqrt(PATCH_LEN), size=(PATCH_LEN)))\n",
    "\n",
    "    # Feature noise patch: feature vector of a different cluster with a random sign\n",
    "    pos_or_neg = np.random.choice([-1,1], 1)[0]\n",
    "    k_noise = np.random.choice(list(set(range(0,CLUSTER_NUM))-set([int(k)])))\n",
    "    # Random scales for the signal, center and feature-noise patches respectively.\n",
    "    alpha, beta, gamma = np.random.uniform(0.5,2), np.random.uniform(1,2), np.random.uniform(0.5,3)\n",
    "\n",
    "    # Four patches: label-signed feature of cluster k, center of cluster k, noise, distractor feature.\n",
    "    # Note that x is reused here as the sample, replacing the vector left over from the loops above.\n",
    "    x = torch.stack([features[k]*y*alpha, centers[k]*beta, xi,\n",
    "                     pos_or_neg*features[k_noise]*gamma])\n",
    "\n",
    "    # random permutation of the patch order, then concatenate the patches into one vector\n",
    "    idx = torch.randperm(len(x))\n",
    "    x = x[idx].flatten()\n",
    "\n",
    "    data.append(x)\n",
    "    labels.append(y)\n",
    "\n",
    "data = torch.stack(data)\n",
    "print(data.shape)\n",
    "\n",
    "# Map the -1 label to 0 so labels can be used as class indices.\n",
    "labels = torch.tensor(labels)\n",
    "labels[labels==-1] = 0\n",
    "print(labels.shape)\n",
    "\n",
    "# Split into train/test, add a singleton channel axis and cast to float32.\n",
    "training_data = data[:DATA_NUM,:].unsqueeze(1).float()\n",
    "test_data = data[DATA_NUM::].unsqueeze(1).float()\n",
    "print(training_data.shape, test_data.shape)\n",
    "\n",
    "training_labels = labels[:DATA_NUM]\n",
    "test_labels = labels[DATA_NUM:]\n",
    "print(training_labels.shape, test_labels.shape)\n",
    "\n",
    "# NOTE(review): the two self-assignments below (and the one for centers/features) are no-ops.\n",
    "training_data, test_data = training_data, test_data\n",
    "training_labels, test_labels = training_labels, test_labels\n",
    "training_labels = training_labels.long()\n",
    "test_labels = test_labels.long()\n",
    "centers, features = centers, features\n",
    "# Scale all samples by 10 (in place).\n",
    "training_data *= 10\n",
    "test_data *= 10\n",
    "# Below does not affect training, just for plotting\n",
    "centers *= 10\n",
    "features *= 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "85455a9a-fc5e-4695-8166-75c49aa79b13",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Create the output folder on the first run; later runs reuse it.\n",
    "if not os.path.exists(\"./synthetic_data_s1\"):\n",
    "    os.mkdir(\"./synthetic_data_s1\")\n",
    "\n",
    "# Tensors are written with torch.save, one .pt file per artifact.\n",
    "for _tensor, _target in (\n",
    "    (training_data, './synthetic_data_s1/train_data.pt'),\n",
    "    (training_labels, './synthetic_data_s1/train_labels.pt'),\n",
    "    (test_data, './synthetic_data_s1/test_data.pt'),\n",
    "    (test_labels, './synthetic_data_s1/test_labels.pt'),\n",
    "    (centers, './synthetic_data_s1/centers.pt'),\n",
    "    (features, './synthetic_data_s1/features.pt'),\n",
    "):\n",
    "    torch.save(_tensor, _target)\n",
    "\n",
    "# The per-cluster index lists are plain Python lists, so they are pickled instead.\n",
    "for _index_lists, _target in (\n",
    "    (train_cluster_idx, \"synthetic_data_s1/train_cluster\"),\n",
    "    (test_cluster_idx, \"synthetic_data_s1/test_cluster\"),\n",
    "):\n",
    "    with open(_target, \"wb\") as fp:\n",
    "        pickle.dump(_index_lists, fp)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e598ebfb-8343-48d7-b414-1b405284f800",
   "metadata": {},
   "source": [
    "# Verification of experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5adadd35-1ce2-4515-a19c-47f74d8c4d8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.nn.parameter import Parameter\n",
    "import math\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from math import log, e\n",
    "import torch.optim as optim\n",
    "import pickle\n",
    "import random\n",
    "import torch.autograd as autograd\n",
    "\n",
    "# Seed the torch, NumPy and stdlib RNGs for the experiments (11 here; the data-creation cells use 12).\n",
    "torch.manual_seed(11)\n",
    "np.random.seed(11)\n",
    "random.seed(11)\n",
    "\n",
    "# Default font size for every matplotlib figure in this notebook.\n",
    "plt.rcParams.update({'font.size': 13})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "11cc0f84-f4fe-4e4c-91f4-3cf384b5a774",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "#utils\n",
    "\n",
    "def entropy(dispatch, base=None):\n",
    "    \"\"\"Count-weighted average entropy of the columns of a count matrix.\n",
    "\n",
    "    Every column of ``dispatch`` is normalised into a probability distribution,\n",
    "    its Shannon entropy is computed, and the column entropies are averaged with\n",
    "    weights proportional to the column totals.\n",
    "\n",
    "    Args:\n",
    "        dispatch: tensor of non-negative counts. ``train`` passes a\n",
    "            (cluster, expert) matrix, giving the entropy of the cluster\n",
    "            distribution inside each expert.\n",
    "        base: logarithm base. ``None`` (default) uses the natural logarithm.\n",
    "\n",
    "    Returns:\n",
    "        0-dim tensor holding the weighted entropy.\n",
    "    \"\"\"\n",
    "    n_expert = torch.sum(dispatch, dim=0)\n",
    "    n_total = torch.sum(dispatch)\n",
    "\n",
    "    prob = dispatch / n_expert\n",
    "    # Empty columns (0/0) and zero probabilities (0*log 0) evaluate to NaN;\n",
    "    # nansum skips them, which matches the convention 0*log(0) = 0.\n",
    "    ent = - torch.nansum(prob*torch.log(prob), dim=0)\n",
    "    if base is not None:\n",
    "        # Change of base: log_b(p) = ln(p) / ln(b).\n",
    "        ent = ent / math.log(base)\n",
    "\n",
    "    return torch.sum((n_expert / n_total) * ent)\n",
    "\n",
    "def plot_expert_acc(expert_acc,uplim=0.5):\n",
    "    \"\"\"Plot one panel per expert with one curve per cluster.\n",
    "\n",
    "    ``expert_acc[i][j]`` is drawn as the curve of cluster ``j`` in the panel\n",
    "    of expert ``i``. The panels form a fixed 4x2 grid and every visited panel\n",
    "    shares the y-range ``[0, max(expert_acc) + uplim]``.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(nrows=4, ncols=2,figsize=(15,15))\n",
    "    palette = ['#e41a1c','#377eb8','#4daf4a','#984ea3']\n",
    "    y_top = torch.max(expert_acc)+uplim\n",
    "    expert_idx = 0\n",
    "\n",
    "    for axes_row in ax:\n",
    "        for panel in axes_row:\n",
    "            panel.set_ylim([0,y_top])\n",
    "            # All experts drawn: leave the remaining panels of this row empty.\n",
    "            if expert_idx == EXPERT_NUM:\n",
    "                break\n",
    "            for cluster_idx in range(CLUSTER_NUM):\n",
    "                panel.plot(expert_acc[expert_idx][cluster_idx],\n",
    "                           label='cluser '+str(cluster_idx+1),\n",
    "                           color=palette[cluster_idx])\n",
    "            # A single legend, on the first panel, is enough.\n",
    "            if expert_idx == 0:\n",
    "                panel.legend(loc='upper left')\n",
    "            panel.title.set_text('Expert %d'%(expert_idx+1))\n",
    "            expert_idx += 1\n",
    "\n",
    "    plt.plot()\n",
    "\n",
    "\n",
    "def plot_router_acc(expert_acc,uplim=0.5):\n",
    "    \"\"\"Plot one panel per router filter with one curve per cluster.\n",
    "\n",
    "    ``expert_acc`` is converted to a tensor, its trailing singleton axis is\n",
    "    squeezed away, and ``expert_acc[j, :, i]`` is drawn as the curve of\n",
    "    cluster ``j`` in panel ``i`` of a fixed 4x2 grid. Every visited panel\n",
    "    shares the y-range ``[0, max(expert_acc) + uplim]``.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(nrows=4, ncols=2,figsize=(15,15))\n",
    "    expert_acc = torch.tensor(expert_acc).squeeze(-1)\n",
    "    palette = ['#e41a1c','#377eb8','#4daf4a','#984ea3']\n",
    "    y_top = torch.max(expert_acc)+uplim\n",
    "    theta_idx = 0\n",
    "\n",
    "    for axes_row in ax:\n",
    "        for panel in axes_row:\n",
    "            panel.set_ylim([0,y_top])\n",
    "            # All filters drawn: leave the remaining panels of this row empty.\n",
    "            if theta_idx == EXPERT_NUM:\n",
    "                break\n",
    "            for cluster_idx in range(CLUSTER_NUM):\n",
    "                panel.plot(expert_acc[cluster_idx,:,theta_idx],\n",
    "                           label='cluser '+str(cluster_idx+1),\n",
    "                           color=palette[cluster_idx])\n",
    "            # A single legend, on the first panel, is enough.\n",
    "            if theta_idx == 0:\n",
    "                panel.legend(loc='upper left')\n",
    "            panel.title.set_text('Theta %d'%(theta_idx+1))\n",
    "            theta_idx += 1\n",
    "\n",
    "    plt.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "db8bb51b-bfac-42bf-bea4-0efbafc16b4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvNet(nn.Module):\n",
    "    \"\"\"Single-layer patch-wise convolutional expert producing two logits.\n",
    "\n",
    "    A 1-D convolution whose kernel size and stride both equal\n",
    "    ``int(input_dim / patch_num)`` maps the single input channel to\n",
    "    ``2 * out_channel`` channels. The responses are optionally cubed and\n",
    "    summed over the patch axis; the first ``out_channel`` channels are then\n",
    "    summed into logit 0 and the remaining ones into logit 1.\n",
    "\n",
    "    Args:\n",
    "        input_dim: length of the flattened input vector.\n",
    "        out_channel: number of filters per output logit.\n",
    "        patch_num: number of patches the input is split into.\n",
    "        small: when True, scale the default weight and bias init by 1e-3.\n",
    "        nonlinear: when True, apply the cubic activation.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, out_channel, patch_num, small=True, nonlinear=True):\n",
    "        super().__init__()\n",
    "        patch_len = int(input_dim/patch_num)\n",
    "        self.conv1 = nn.Conv1d(1, out_channel*2, patch_len, patch_len)\n",
    "        if small:\n",
    "            # small initialization: shrink the default init by a factor of 1000\n",
    "            self.conv1.weight = nn.Parameter(self.conv1.weight*0.001)\n",
    "            self.conv1.bias = nn.Parameter(self.conv1.bias*0.001)\n",
    "        self.out_channel = out_channel\n",
    "        self.nonlinear = nonlinear\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return a (batch, 2) tensor of logits for a (batch, 1, input_dim) input.\"\"\"\n",
    "        responses = self.conv1(x)\n",
    "        if self.nonlinear:\n",
    "            responses = responses**3\n",
    "        # Collapse the patch axis, then pool each half of the channels into one logit.\n",
    "        per_channel = torch.sum(responses,2)\n",
    "        first_logit = torch.sum(per_channel[:,:self.out_channel],1)\n",
    "        second_logit = torch.sum(per_channel[:,self.out_channel:],1)\n",
    "        return torch.stack([first_logit, second_logit]).transpose(1,0)\n",
    "    \n",
    "\n",
    "# top 1 hard routing\n",
    "def top1(t):\n",
    "    \"\"\"Return the largest value along the last axis and its index.\n",
    "\n",
    "    Both results have the last axis of ``t`` removed.\n",
    "    \"\"\"\n",
    "    best = t.topk(k=1, dim=-1)\n",
    "    values = best.values.squeeze(dim=-1)\n",
    "    index = best.indices.squeeze(dim=-1)\n",
    "    return values, index\n",
    "\n",
    "\n",
    "def cumsum_exclusive(t, dim=-1):\n",
    "    \"\"\"Exclusive cumulative sum of ``t`` along ``dim``.\n",
    "\n",
    "    Element ``i`` of the result is the sum of the elements before position\n",
    "    ``i`` along ``dim`` (so the first element is 0). The output has the same\n",
    "    shape as ``t``.\n",
    "\n",
    "    Args:\n",
    "        t: input tensor.\n",
    "        dim: axis to accumulate over; negative and non-negative values are\n",
    "            both accepted.\n",
    "    \"\"\"\n",
    "    # The padding and slicing below count axes from the end, so a non-negative\n",
    "    # dim is first converted to its negative equivalent.\n",
    "    if dim >= 0:\n",
    "        dim -= t.dim()\n",
    "    num_pad_dims = - dim - 1\n",
    "    # F.pad lists (left, right) pairs starting from the last axis: leave the\n",
    "    # trailing axes alone and prepend a single zero along `dim`.\n",
    "    pre_padding = (0, 0) * num_pad_dims\n",
    "    pre_slice   = (slice(None),) * num_pad_dims\n",
    "    padded_t = F.pad(t, (*pre_padding, 1, 0)).cumsum(dim=dim)\n",
    "    # Drop the last entry along `dim` to get back to the input shape.\n",
    "    return padded_t[(..., slice(None, -1), *pre_slice)]\n",
    "\n",
    "\n",
    "def safe_one_hot(indexes, max_length):\n",
    "    \"\"\"One-hot encode ``indexes`` into exactly ``max_length`` classes.\n",
    "\n",
    "    Indices that are ``>= max_length`` do not raise; they become all-zero rows.\n",
    "    An empty ``indexes`` tensor yields an empty result with a trailing axis of\n",
    "    size ``max_length``.\n",
    "    \"\"\"\n",
    "    if indexes.numel() == 0:\n",
    "        # .max() is undefined on an empty tensor; the class count is then max_length.\n",
    "        num_classes = max_length\n",
    "    else:\n",
    "        # Encode with enough classes to hold the largest index, then cut back to\n",
    "        # max_length. int() gives F.one_hot a plain Python integer.\n",
    "        num_classes = max(int(indexes.max()) + 1, max_length)\n",
    "    return F.one_hot(indexes, num_classes)[..., :max_length]\n",
    "\n",
    "\n",
    "class Router(nn.Module):\n",
    "    \"\"\"Linear patch-wise router that scores every expert for each sample.\n",
    "\n",
    "    A bias-free 1-D convolution (kernel size and stride equal to the patch\n",
    "    length) with ``out_dim`` filters is applied to the input and summed over\n",
    "    patches, giving one score per expert. The filters start at zero.\n",
    "\n",
    "    Args:\n",
    "        input_dim: length of the flattened input vector.\n",
    "        out_dim: number of experts to score.\n",
    "        patch_num: number of patches the input is split into.\n",
    "        noise: when True, fresh uniform [0, 1) noise is added to the scores\n",
    "            on every training-mode forward pass.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, out_dim, patch_num, noise=True):\n",
    "        super(Router, self).__init__()\n",
    "        self.conv1 = nn.Conv1d(1, out_dim, int(input_dim/patch_num), int(input_dim/patch_num),bias=False)\n",
    "        self.out_dim = out_dim\n",
    "        # Fixed, tiny per-sample perturbation drawn once at construction; it is\n",
    "        # sized by out_dim so it always matches the score matrix.\n",
    "        self.break_tie_noise = torch.normal(0,1e-5,size=(DATA_NUM, out_dim))\n",
    "        self.noise = noise\n",
    "        # zero initialization\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Set every router filter to zero.\"\"\"\n",
    "        self.conv1.weight = torch.nn.Parameter(self.conv1.weight * 0)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return a (batch, out_dim) score matrix for a (batch, 1, input_dim) input.\n",
    "\n",
    "        In training mode with ``noise=True`` fresh uniform noise is added. In\n",
    "        every other case, including eval mode (which previously raised\n",
    "        UnboundLocalError), the fixed tie-break noise is added. The tie-break\n",
    "        noise holds DATA_NUM rows, so such batches may not be larger than that.\n",
    "        \"\"\"\n",
    "        x = self.conv1(x)\n",
    "        x = torch.sum(x,2)\n",
    "        if self.noise and self.training:\n",
    "            # rand_like follows the batch size, dtype and device of the scores.\n",
    "            return x + torch.rand_like(x)\n",
    "        tie_break = self.break_tie_noise[:x.shape[0]].to(x.device)\n",
    "        return x + tie_break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6dc774a0-8566-43cc-89c4-5ac4c5318f5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MoE(nn.Module):\n",
    "    \"\"\"Mixture of ConvNet experts with hard (one expert per sample) routing.\n",
    "\n",
    "    A Router scores all experts for every sample, one expert is selected per\n",
    "    sample, every expert is run on a batch in which the samples of other\n",
    "    experts are zeroed out, and the expert outputs are combined using the\n",
    "    selected router score as weight.\n",
    "\n",
    "    Args:\n",
    "        input_dim: length of the flattened input vector.\n",
    "        out_channel: filters per logit inside each ConvNet expert.\n",
    "        patch_num: number of patches the input is split into.\n",
    "        expert_num: number of experts.\n",
    "        strategy: 'top1' picks the highest-scoring expert; any other value\n",
    "            calls ``choose1``.\n",
    "        nonlinear: forwarded to every ConvNet expert.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, out_channel, patch_num, expert_num, strategy='top1', nonlinear=True):\n",
    "        super(MoE, self).__init__()\n",
    "        self.router = Router(input_dim, expert_num, patch_num)\n",
    "        self.models = nn.ModuleList()\n",
    "        for i in range(expert_num):\n",
    "            self.models.append(ConvNet(input_dim, out_channel, patch_num, nonlinear=nonlinear))\n",
    "        self.strategy = strategy\n",
    "        self.expert_num = expert_num\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Route ``x`` and return ``(output, select0, loss)``.\n",
    "\n",
    "        ``output`` is the softmax over dim 1 of the combined expert outputs,\n",
    "        ``select0`` is the 0/1 sample-to-expert assignment matrix and ``loss``\n",
    "        is the auxiliary load-balancing term.\n",
    "        \"\"\"\n",
    "        # Router scores: one row per sample, one column per expert.\n",
    "        select = self.router(x)\n",
    "        # top 1 or choose 1 according to probability\n",
    "        if self.strategy == 'top1':\n",
    "            gate, index = top1(select)\n",
    "        else:\n",
    "            # NOTE(review): choose1 is not defined in the visible part of this notebook.\n",
    "            gate, index = choose1(select)\n",
    "        \n",
    "        # One-hot encoding of the expert chosen for each sample.\n",
    "        mask = F.one_hot(index, self.expert_num).float()\n",
    "\n",
    "        # Load-balancing term: fraction of samples sent to each expert times the\n",
    "        # mean router score of that expert, averaged and scaled by expert_num**2.\n",
    "        density = mask.mean(dim=-2)\n",
    "        density_proxy = select.mean(dim=-2)\n",
    "        loss = (density_proxy * density).mean() * float(self.expert_num ** 2)\n",
    "\n",
    "        # NOTE(review): mask_count is computed but not used below.\n",
    "        mask_count = mask.sum(dim=-2, keepdim=True)\n",
    "        # Sum of the one-hot row over experts.\n",
    "        mask_flat = mask.sum(dim=-1)\n",
    "\n",
    "        # Selected router score placed in the slot of the chosen expert, with a\n",
    "        # trailing singleton axis appended.\n",
    "        combine_tensor = (gate[..., None, None] * mask_flat[..., None, None]\n",
    "                          * F.one_hot(index, self.expert_num)[..., None])\n",
    "                          \n",
    "        # 0/1 version of combine_tensor (non-zero -> 1), cast back to its dtype.\n",
    "        dispatch_tensor = combine_tensor.bool().to(combine_tensor)\n",
    "        select0 = dispatch_tensor.squeeze(-1)\n",
    "        \n",
    "        # Per-expert copy of the batch in which samples not routed to that expert\n",
    "        # are multiplied by zero; unsqueeze(2) re-inserts the channel axis.\n",
    "        expert_inputs = torch.einsum('bnd,ben->ebd', x, dispatch_tensor).unsqueeze(2)\n",
    "        \n",
    "        output = []\n",
    "        for i in range(self.expert_num):\n",
    "            output.append(self.models[i](expert_inputs[i]))\n",
    "        \n",
    "        # Stack to (expert, batch, logit) and weight each expert output by the\n",
    "        # combine tensor, summing over experts.\n",
    "        output = torch.stack(output)\n",
    "        output = torch.einsum('ijk,jil->il', combine_tensor, output)\n",
    "        output = F.softmax(output,dim=1)\n",
    "\n",
    "        return output, select0, loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "990847ff-cf01-4237-8f65-c20294fa3fbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_single(model, criterion, data, labels, optimizers, epochs):\n",
    "    \"\"\"Full-batch training loop for a model whose forward pass returns only outputs.\n",
    "\n",
    "    Each epoch performs one forward/backward pass over all of ``data`` and\n",
    "    steps every optimizer in ``optimizers``. After epoch 500 the loop stops\n",
    "    early, before the update, once the loss exceeds the best loss seen so far\n",
    "    by more than 0.02. The loss is printed every 100 epochs.\n",
    "    \"\"\"\n",
    "    best_loss = float('inf')\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        for opt in optimizers:\n",
    "            opt.zero_grad()\n",
    "        loss = criterion(model(data), labels)\n",
    "\n",
    "        # Track the running minimum; diverging from it late in training ends the loop.\n",
    "        if loss.item() <= best_loss:\n",
    "            best_loss = loss.item()\n",
    "        elif epoch > 500 and loss > best_loss+0.02:\n",
    "            break\n",
    "\n",
    "        loss.backward()\n",
    "        for opt in optimizers:\n",
    "            opt.step()\n",
    "\n",
    "        if epoch%100 == 0:\n",
    "            print('Epoch %d --- loss: %.3f' % (epoch + 1, loss.item()))\n",
    "    print('Finished Training')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2f1a91cd-cf9f-4404-bb7b-bba22d0ad3d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, criterion, data, labels, optimizers, epochs, \n",
    "          plot=False, load_balancing=False, verbose=True, early_stopping=True):\n",
    "    \"\"\"Full-batch training loop for the MoE model with routing diagnostics.\n",
    "\n",
    "    Args:\n",
    "        model: module whose forward pass returns\n",
    "            ``(outputs, select0, load_balancing_loss)``.\n",
    "        criterion: loss applied to ``(outputs, labels)``.\n",
    "        data, labels: the training set. Rows must line up with the global\n",
    "            ``train_cluster_idx`` lists used for the routing entropy.\n",
    "        optimizers: iterable of optimizers, all stepped every epoch.\n",
    "        epochs: maximum number of epochs.\n",
    "        plot: when True, record expert/router inner products every 100 epochs.\n",
    "        load_balancing: when True, add ``0.0001 * load_balancing_loss``.\n",
    "        verbose: when True, print the loss every 100 epochs.\n",
    "        early_stopping: when True, stop (before the update) as soon as the\n",
    "            loss is not a new minimum and is either more than 0.02 above the\n",
    "            minimum or at most 0.314.\n",
    "\n",
    "    Returns:\n",
    "        ``(expert_acc_train, expert_inner_train, router_acc_train,\n",
    "        router_inner_train, select0, entropy_record)``. ``select0`` is the\n",
    "        assignment matrix of the last forward pass, or ``None`` when\n",
    "        ``epochs`` is 0.\n",
    "    \"\"\"\n",
    "    expert_acc_train = [[[] for _ in range(CLUSTER_NUM)] for _ in range(EXPERT_NUM)]\n",
    "    expert_inner_train = [[[] for _ in range(CLUSTER_NUM)] for _ in range(EXPERT_NUM)]\n",
    "\n",
    "    router_acc_train = [[] for _ in range(CLUSTER_NUM)]\n",
    "    router_inner_train = [[] for _ in range(CLUSTER_NUM)]\n",
    "\n",
    "    entropy_record = []\n",
    "    min_loss = float('inf')\n",
    "    # Defined up front so the return statement also works when epochs == 0.\n",
    "    select0 = None\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "\n",
    "        for optimizer in optimizers:\n",
    "            optimizer.zero_grad()\n",
    "        outputs, select0, load_balancing_loss = model(data)\n",
    "\n",
    "        # (cluster, expert) matrix counting how many samples of each cluster\n",
    "        # were routed to each expert; built for every cluster in CLUSTER_NUM\n",
    "        # rather than a hard-coded four.\n",
    "        cluster_dispatch = torch.stack([\n",
    "            select0[train_cluster_idx[cluster]].squeeze(-1).sum(dim=0)\n",
    "            for cluster in range(CLUSTER_NUM)\n",
    "        ])\n",
    "        entropy_record.append(entropy(cluster_dispatch))\n",
    "\n",
    "        if load_balancing:\n",
    "            loss = criterion(outputs, labels) + 0.0001 * load_balancing_loss\n",
    "        else:\n",
    "            loss = criterion(outputs, labels)\n",
    "\n",
    "        if early_stopping:\n",
    "            if loss.item() <= min_loss:\n",
    "                min_loss = loss.item()\n",
    "            elif loss > min_loss+0.02 or loss <= 0.314:\n",
    "                break\n",
    "        loss.backward()\n",
    "\n",
    "        for optimizer in optimizers:\n",
    "            optimizer.step()\n",
    "\n",
    "        if epoch%100 == 0:\n",
    "            if verbose:\n",
    "                print('Epoch %d --- loss: %.3f' % (epoch + 1, loss.item()))\n",
    "            if plot:\n",
    "                # Pass this function's own data/labels instead of reaching for\n",
    "                # the training_data/training_labels globals.\n",
    "                acc_list,inner_list,router_flist,router_clist = test_each_expert(\n",
    "                    model, criterion, data, labels, datatype='training')\n",
    "                for each_cluster in range(CLUSTER_NUM):\n",
    "\n",
    "                    router_acc_train[each_cluster].append(router_flist[each_cluster])\n",
    "                    router_inner_train[each_cluster].append(router_clist[each_cluster])\n",
    "\n",
    "                    for each_expert in range(EXPERT_NUM):\n",
    "                        expert_acc_train[each_expert][each_cluster].append(acc_list[each_expert][each_cluster])\n",
    "                        expert_inner_train[each_expert][each_cluster].append(inner_list[each_expert][each_cluster])\n",
    "\n",
    "    print('Finished Training')\n",
    "    return expert_acc_train,expert_inner_train,router_acc_train,router_inner_train, select0, entropy_record\n",
    "\n",
    "\n",
    "def test_single(model, criterion, data, labels):\n",
    "    \"\"\"Print and return the accuracy (in percent) of a model that returns only outputs.\n",
    "\n",
    "    The prediction for each row is the index of its largest output along\n",
    "    dim 1. ``criterion`` is accepted but not used.\n",
    "    \"\"\"\n",
    "    n_examples = data.shape[0]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        scores = model(data)\n",
    "        predicted = torch.max(scores.data, 1).indices\n",
    "        n_correct = (predicted == labels).sum().item()\n",
    "\n",
    "    accuracy = 100 * n_correct / n_examples\n",
    "    print('Accuracy of the network on the %d test images: %.4f %%' % (n_examples, accuracy))\n",
    "\n",
    "    return accuracy\n",
    "    \n",
    "\n",
    "def test(model, criterion, data, labels, verbose=True):\n",
    "    \"\"\"Return the accuracy (in percent) of a model whose forward pass returns a 3-tuple.\n",
    "\n",
    "    Only the first element of the tuple (the outputs) is used; the prediction\n",
    "    for each row is the index of its largest output along dim 1. The result\n",
    "    is also printed when ``verbose`` is True. ``criterion`` is not used.\n",
    "    \"\"\"\n",
    "    n_examples = data.shape[0]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        scores, _, _ = model(data)\n",
    "        predicted = torch.max(scores.data, 1).indices\n",
    "        n_correct = (predicted == labels).sum().item()\n",
    "\n",
    "    accuracy = 100 * n_correct / n_examples\n",
    "    if verbose:\n",
    "        print('Accuracy of the network on the %d test images: %.4f %%' % (n_examples, accuracy))\n",
    "\n",
    "    return accuracy\n",
    "    \n",
    "    \n",
    "def test_expert(model, criterion, data, labels, cluster=0, datatype='training', verbose=False):\n",
    "    \"\"\"Accuracy (in percent) of every individual expert on the samples of one cluster.\n",
    "\n",
    "    The cluster members are looked up in the global ``train_cluster_idx``\n",
    "    when ``datatype == 'training'`` and in ``test_cluster_idx`` otherwise.\n",
    "    Each expert in ``model.models`` is evaluated directly, bypassing the\n",
    "    router. Returns a list with one accuracy per expert.\n",
    "    \"\"\"\n",
    "    cluster_idx = train_cluster_idx if datatype=='training' else test_cluster_idx\n",
    "    members = cluster_idx[cluster]\n",
    "\n",
    "    data = data[members,:,:]\n",
    "    labels = labels[members]\n",
    "    n_examples = data.shape[0]\n",
    "\n",
    "    expert_acc = []\n",
    "    with torch.no_grad():\n",
    "        for i in range(model.expert_num):\n",
    "            probs = F.softmax(model.models[i](data),dim=1)\n",
    "            predicted = torch.max(probs.data, 1).indices\n",
    "            n_correct = (predicted == labels).sum().item()\n",
    "            acc = 100 * n_correct / n_examples\n",
    "            expert_acc.append(acc)\n",
    "            if verbose:\n",
    "                print('Accuracy of the %d expert on cluster %d with %d examples: %.4f %%' % (i, cluster, n_examples, acc))\n",
    "    return expert_acc\n",
    "        \n",
    "    \n",
    "def test_each_expert(model, criterion, data, labels, datatype='training', verbose=False):\n",
    "    \"\"\"Collect the expert and router inner-product diagnostics for all clusters.\n",
    "\n",
    "    Returns ``(expert_feature, expert_center, router_feature, router_center)``\n",
    "    where the first two are indexed ``[expert][cluster]`` (CPU tensors) and\n",
    "    the last two come from ``test_router_inner``. Only ``model`` is used; the\n",
    "    remaining parameters are accepted but ignored.\n",
    "    \"\"\"\n",
    "    router_feature, router_center = test_router_inner(model)\n",
    "\n",
    "    # One (feature, center) pair of per-expert lists for every cluster.\n",
    "    per_cluster = [test_expert_inner(model, cluster=i) for i in range(CLUSTER_NUM)]\n",
    "\n",
    "    # Regroup from [cluster][expert] to [expert][cluster].\n",
    "    expert_feature = [[feat[each].cpu() for feat, _ in per_cluster]\n",
    "                      for each in range(EXPERT_NUM)]\n",
    "    expert_center = [[cent[each].cpu() for _, cent in per_cluster]\n",
    "                     for each in range(EXPERT_NUM)]\n",
    "\n",
    "    return expert_feature, expert_center, router_feature, router_center\n",
    "\n",
    "\n",
    "def test_expert_inner(model, cluster=0):\n",
    "    \"\"\"Largest absolute inner product between each expert's filters and one cluster's vectors.\n",
    "\n",
    "    For every expert, the filters of ``conv1`` are multiplied with the\n",
    "    feature vector and with the center vector of ``cluster`` (taken from the\n",
    "    globals ``features`` and ``centers``). Returns two lists of 0-dim\n",
    "    tensors, ``(expert_fea, expert_cen)``, with one entry per expert.\n",
    "    \"\"\"\n",
    "    def peak_alignment(filters, direction):\n",
    "        # Maximum |<filter, direction>| over all filters of one expert.\n",
    "        return torch.max(torch.abs(torch.matmul(filters, direction)))\n",
    "\n",
    "    expert_fea, expert_cen = [], []\n",
    "    with torch.no_grad():\n",
    "        for i in range(model.expert_num):\n",
    "            expert_fea.append(peak_alignment(model.models[i].conv1.weight.squeeze(1),\n",
    "                                             features[[cluster]].float().transpose(1,0)))\n",
    "            expert_cen.append(peak_alignment(model.models[i].conv1.weight.squeeze(1),\n",
    "                                             centers[[cluster]].float().transpose(1,0)))\n",
    "    return expert_fea, expert_cen\n",
    "\n",
    "\n",
    "def test_router_inner(model):\n",
    "    \"\"\"Absolute inner products between the router filters and every cluster's vectors.\n",
    "\n",
    "    For each cluster, the router's ``conv1`` filters are multiplied with the\n",
    "    cluster's feature vector and with its center vector (globals ``features``\n",
    "    and ``centers``). Returns ``(router_feature, router_center)``: two lists\n",
    "    with one nested Python list per cluster.\n",
    "    \"\"\"\n",
    "    def alignment(vectors, cluster):\n",
    "        # |W v| for the router filters W and the selected cluster vector v (no max taken).\n",
    "        direction = vectors[[cluster]].float().transpose(1,0)\n",
    "        return torch.abs(torch.matmul(model.router.conv1.weight.squeeze(1), direction))\n",
    "\n",
    "    router_feature, router_center = [], []\n",
    "    with torch.no_grad():\n",
    "        for cluster in range(CLUSTER_NUM):\n",
    "            feature_inner = alignment(features, cluster)\n",
    "            center_inner = alignment(centers, cluster)\n",
    "            router_feature.append(feature_inner.cpu().tolist())\n",
    "            router_center.append(center_inner.cpu().tolist())\n",
    "    return router_feature, router_center"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "7155e82e-dc74-4efa-80dc-076bac408db6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same values as the constants in the data-creation section; repeated here,\n",
    "# presumably so the verification section can be run on its own.\n",
    "# Samples per split (train and test each hold DATA_NUM samples).\n",
    "DATA_NUM = 16000\n",
    "# Number of clusters in the synthetic data.\n",
    "CLUSTER_NUM = 4\n",
    "# Number of experts (also read as a global by Router and the plotting/training helpers).\n",
    "EXPERT_NUM = 8\n",
    "# Patches per sample.\n",
    "PATCH_NUM = 4\n",
    "# Length of each patch vector.\n",
    "PATCH_LEN = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ff8e7293-0f4e-4cf0-91b0-980f264c3ebd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\1703678973.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  training_data = torch.load('synthetic_data_s1/train_data.pt')\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\1703678973.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  training_labels = torch.load('synthetic_data_s1/train_labels.pt')\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\1703678973.py:4: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  test_data = torch.load('synthetic_data_s1/test_data.pt')\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\1703678973.py:5: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  test_labels = torch.load('synthetic_data_s1/test_labels.pt')\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\1703678973.py:7: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  centers = torch.load('synthetic_data_s1/centers.pt')\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\1703678973.py:8: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  features = torch.load('synthetic_data_s1/features.pt')\n"
     ]
    }
   ],
   "source": [
    "# Load the artifacts written by the data-creation section.\n",
    "# weights_only=True restricts unpickling to tensors and plain containers, which is all\n",
    "# these files hold, and avoids the torch.load FutureWarning about the unsafe default.\n",
    "training_data = torch.load('synthetic_data_s1/train_data.pt', weights_only=True)\n",
    "training_labels = torch.load('synthetic_data_s1/train_labels.pt', weights_only=True)\n",
    "\n",
    "test_data = torch.load('synthetic_data_s1/test_data.pt', weights_only=True)\n",
    "test_labels = torch.load('synthetic_data_s1/test_labels.pt', weights_only=True)\n",
    "\n",
    "centers = torch.load('synthetic_data_s1/centers.pt', weights_only=True)\n",
    "features = torch.load('synthetic_data_s1/features.pt', weights_only=True)\n",
    "\n",
    "# SECURITY: pickle.load can execute arbitrary code from the file. Only load cluster\n",
    "# index files that were produced by this notebook.\n",
    "with open(\"synthetic_data_s1/train_cluster\", \"rb\") as fp:\n",
    "    train_cluster_idx = pickle.load(fp)\n",
    "\n",
    "with open(\"synthetic_data_s1/test_cluster\", \"rb\") as fp:\n",
    "    test_cluster_idx = pickle.load(fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f90441d4-eda9-45db-800f-79b79f885081",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim.optimizer import Optimizer, required\n",
    "from torch.optim import _functional\n",
    "\n",
    "class NormalizedGD(Optimizer):\n",
    "    def __init__(self, params, lr=required, momentum=0, dampening=0,\n",
    "                 weight_decay=0, nesterov=False):\n",
    "        if lr is not required and lr < 0.0:\n",
    "            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n",
    "        if momentum < 0.0:\n",
    "            raise ValueError(\"Invalid momentum value: {}\".format(momentum))\n",
    "        if weight_decay < 0.0:\n",
    "            raise ValueError(\"Invalid weight_decay value: {}\".format(weight_decay))\n",
    "\n",
    "        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,\n",
    "                        weight_decay=weight_decay, nesterov=nesterov)\n",
    "        if nesterov and (momentum <= 0 or dampening != 0):\n",
    "            raise ValueError(\"Nesterov momentum requires a momentum and zero dampening\")\n",
    "        super(NormalizedGD, self).__init__(params, defaults)\n",
    "\n",
    "    def __setstate__(self, state):\n",
    "        super(NormalizedGD, self).__setstate__(state)\n",
    "        for group in self.param_groups:\n",
    "            group.setdefault('nesterov', False)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def step(self, closure=None):\n",
    "        loss = None\n",
    "        if closure is not None:\n",
    "            with torch.enable_grad():\n",
    "                loss = closure()\n",
    "\n",
    "        for group in self.param_groups:\n",
    "            params_with_grad = []\n",
    "            d_p_list = []\n",
    "            momentum_buffer_list = []\n",
    "            weight_decay = group['weight_decay']\n",
    "            momentum = group['momentum']\n",
    "            dampening = group['dampening']\n",
    "            nesterov = group['nesterov']\n",
    "            lr = group['lr']\n",
    "            \n",
    "            per_expert_num = int(len(group['params'])/EXPERT_NUM)\n",
    "            per_expert_norm = [0 for i in range(EXPERT_NUM)]\n",
    "            for i in range(EXPERT_NUM):\n",
    "                for j in range(i*per_expert_num,(i+1)*per_expert_num):\n",
    "                    p = group['params'][j]\n",
    "                    if p.grad is not None:\n",
    "                        per_expert_norm[i] += p.grad.norm()\n",
    "\n",
    "            for idx,p in enumerate(group['params']):\n",
    "                if p.grad is not None:\n",
    "                    # Normalizing \n",
    "                    if per_expert_norm[idx // per_expert_num] != 0:\n",
    "                        p.grad /= per_expert_norm[idx // per_expert_num]\n",
    "                        \n",
    "                    params_with_grad.append(p)\n",
    "                    d_p_list.append(p.grad)\n",
    "\n",
    "                    state = self.state[p]\n",
    "                    if 'momentum_buffer' not in state:\n",
    "                        momentum_buffer_list.append(None)\n",
    "                    else:\n",
    "                        momentum_buffer_list.append(state['momentum_buffer'])\n",
    "\n",
    "            _functional.sgd(params_with_grad,\n",
    "                  d_p_list,\n",
    "                  momentum_buffer_list,\n",
    "                  weight_decay=weight_decay,\n",
    "                  momentum=momentum,\n",
    "                  lr=lr,\n",
    "                  dampening=dampening,\n",
    "                  nesterov=nesterov,\n",
    "                  maximize=False)\n",
    "\n",
    "            # update momentum_buffers in state\n",
    "            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):\n",
    "                state = self.state[p]\n",
    "                state['momentum_buffer'] = momentum_buffer\n",
    "\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "29d7c42b-a230-4834-8add-4f1520ae9812",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 --- loss: 8.773\n",
      "Epoch 101 --- loss: 0.528\n",
      "Epoch 201 --- loss: 0.555\n",
      "Epoch 301 --- loss: 0.550\n",
      "Epoch 401 --- loss: 0.545\n",
      "Epoch 501 --- loss: 0.556\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 69.6875 %\n",
      "Accuracy of the network on the 16000 test images: 69.3688 %\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "69.36875"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline: a single linear ConvNet (nonlinear=False) trained with Adam.\n",
    "num_epochs = 801\n",
    "\n",
    "linear_single = ConvNet(\n",
    "    200, 40, PATCH_NUM,\n",
    "    small=False, nonlinear=False,\n",
    ")\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(\n",
    "    linear_single.parameters(), lr=0.003, weight_decay=5e-4\n",
    ")\n",
    "train_single(\n",
    "    linear_single, criterion, training_data, training_labels,\n",
    "    [optimizer], num_epochs,\n",
    ")\n",
    "\n",
    "# Accuracy on the training split first, then on the test split. The second\n",
    "# call stays the last expression of the cell so its return value is displayed.\n",
    "test_single(linear_single, criterion, training_data, training_labels)\n",
    "test_single(linear_single, criterion, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "c3b7e5de-fc5f-456f-84a7-c6e343b0025c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 --- loss: 45.245\n",
      "Epoch 101 --- loss: 1.825\n",
      "Epoch 201 --- loss: 0.739\n",
      "Epoch 301 --- loss: 0.383\n",
      "Epoch 401 --- loss: 0.360\n",
      "Epoch 501 --- loss: 0.245\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 88.5062 %\n",
      "Accuracy of the network on the 16000 test images: 79.7250 %\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "79.725"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline: a single non-linear ConvNet (nonlinear=True) trained with Adam.\n",
    "num_epochs = 801\n",
    "\n",
    "nonlinear_single = ConvNet(\n",
    "    200, 40, PATCH_NUM,\n",
    "    small=False, nonlinear=True,\n",
    ")\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(\n",
    "    nonlinear_single.parameters(), lr=0.01, weight_decay=5e-4\n",
    ")\n",
    "train_single(\n",
    "    nonlinear_single, criterion, training_data, training_labels,\n",
    "    [optimizer], num_epochs,\n",
    ")\n",
    "\n",
    "# Accuracy on the training split first, then on the test split. The second\n",
    "# call stays the last expression of the cell so its return value is displayed.\n",
    "test_single(nonlinear_single, criterion, training_data, training_labels)\n",
    "test_single(nonlinear_single, criterion, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "3a4700c6-a470-4361-a19e-8c64da7b6263",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([2015., 1975., 2004., 2034., 2046., 1973., 1934., 2019.])\n",
      "Epoch 1 --- loss: 0.693\n",
      "Epoch 101 --- loss: 0.337\n",
      "Epoch 201 --- loss: 0.313\n",
      "Finished Training\n",
      "tensor([   0.,    0., 4035.,    0.,    0., 3960., 3999., 4006.])\n",
      "Accuracy of the network on the 16000 test images: 100.0000 %\n",
      "Accuracy of the network on the 16000 test images: 100.0000 %\n",
      "tensor([   0,    0, 4035,    0,    0,    0,    0,    0])\n",
      "tensor([   0,    0,    0,    0,    0,    0, 3999,    0])\n",
      "tensor([   0,    0,    0,    0,    0,    0,    0, 4006])\n",
      "tensor([   0,    0,    0,    0,    0, 3960,    0,    0])\n"
     ]
    }
   ],
   "source": [
    "# Non-linear mixture of experts with top-1 routing.\n",
    "num_epochs = 601\n",
    "\n",
    "# Positional arguments passed to MoE: 200, 8, PATCH_NUM, EXPERT_NUM.\n",
    "# NOTE(review): the original note read 'input_dim, out_channel (m), cluster_num,\n",
    "# patch_num' -- confirm the exact parameter order against the MoE definition.\n",
    "nonlinear_mixture = MoE(200, 8, PATCH_NUM, EXPERT_NUM, strategy='top1', nonlinear=True)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "# Second output of the untrained model, summed over samples (presumably the\n",
    "# per-expert dispatch counts before training -- confirm against MoE.forward).\n",
    "_, s, _ = nonlinear_mixture(training_data)\n",
    "print(s.squeeze(-1).sum(dim=0))\n",
    "\n",
    "# Experts are trained with NormalizedGD, the router with plain SGD.\n",
    "optimizer = NormalizedGD(nonlinear_mixture.models.parameters(), lr=0.001)\n",
    "optimizer2 = torch.optim.SGD(nonlinear_mixture.router.parameters(), lr=0.1)\n",
    "\n",
    "expert_feat, expert_cent, router_feat, router_cent, select, _ = train(\n",
    "    nonlinear_mixture, criterion, training_data, training_labels,\n",
    "    [optimizer,optimizer2], num_epochs, plot=True)\n",
    "print(select.squeeze(-1).sum(dim=0))\n",
    "\n",
    "test(nonlinear_mixture, criterion, training_data, training_labels)\n",
    "test(nonlinear_mixture, criterion, test_data, test_labels)\n",
    "\n",
    "# Per-cluster breakdown of `select`, one line for each of the four cluster index sets.\n",
    "for cluster in range(4):\n",
    "    print(select[train_cluster_idx[cluster]].squeeze(-1).sum(dim=0).to(torch.long))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "3c1838df-4317-487e-bb35-67b490149e64",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 100.0000 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 99.6750 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 99.2625 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 100.0000 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 99.5563 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 99.6000 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 98.9375 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 99.4125 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 99.1063 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 99.9813 %\n",
      "\n",
      "Average test accuracy: 99.55.\n",
      "Standard deviation: 0.36\n",
      "\n",
      "Average dispatch entropy 0.06199999898672104.\n",
      "Standard deviation: 0.050999999046325684\n"
     ]
    }
   ],
   "source": [
    "# Repeat the non-linear MoE experiment 10 times and summarise the final test\n",
    "# accuracy and the last recorded dispatch entropy of every run.\n",
    "acc_list, ent_list = [], []\n",
    "entropy_records = []\n",
    "num_epochs = 501  # loop-invariant, so set once instead of on every iteration\n",
    "\n",
    "for run in range(10):\n",
    "    nonlinear_mixture = MoE(200, 8, PATCH_NUM, EXPERT_NUM, strategy='top1', nonlinear=True)\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "    optimizer = NormalizedGD(nonlinear_mixture.models.parameters(), lr=0.001)\n",
    "    optimizer2 = torch.optim.SGD(nonlinear_mixture.router.parameters(), lr=0.1)\n",
    "\n",
    "    _, _, _, _, _, entropy_record = train(nonlinear_mixture, criterion, training_data, training_labels,\n",
    "                                          [optimizer,optimizer2], num_epochs, plot=False, verbose=False)\n",
    "    ent_list.append(entropy_record[-1])\n",
    "\n",
    "    acc = test(nonlinear_mixture, criterion, test_data, test_labels)\n",
    "    acc_list.append(acc)\n",
    "\n",
    "    entropy_records.append(torch.stack(entropy_record))\n",
    "\n",
    "final_entropies = torch.stack(ent_list).cpu().numpy()\n",
    "\n",
    "print()\n",
    "print(f\"Average test accuracy: {round(np.mean(acc_list),2)}.\")\n",
    "print(f\"Standard deviation: {round(np.std(acc_list),2)}\")\n",
    "print()\n",
    "# Convert to a Python float before rounding: rounding a np.float32 and printing it\n",
    "# exposes float32 artifacts such as 0.06199999898672104 instead of 0.062.\n",
    "print(f\"Average dispatch entropy {round(float(np.mean(final_entropies)),3)}.\")\n",
    "print(f\"Standard deviation: {round(float(np.std(final_entropies)),3)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "552147a6-46ed-47ad-b368-ea978aed554e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([2041., 1973., 1984., 2020., 2005., 2059., 1960., 1958.])\n",
      "Epoch 1 --- loss: 0.693\n",
      "Epoch 101 --- loss: 0.375\n",
      "Epoch 201 --- loss: 0.363\n",
      "Epoch 301 --- loss: 0.361\n",
      "Epoch 401 --- loss: 0.360\n",
      "Epoch 501 --- loss: 0.360\n",
      "Epoch 601 --- loss: 0.360\n",
      "Finished Training\n",
      "tensor([0.0000e+00, 4.0280e+03, 0.0000e+00, 3.0000e+00, 4.8140e+03, 5.0740e+03,\n",
      "        2.0810e+03, 0.0000e+00])\n",
      "Accuracy of the network on the 16000 test images: 95.4750 %\n",
      "Accuracy of the network on the 16000 test images: 94.0438 %\n",
      "tensor([   0, 1198,    0,    3, 1202, 1176,  456,    0])\n",
      "tensor([   0, 1385,    0,    0, 1142, 1224,  248,    0])\n",
      "tensor([   0,  773,    0,    0, 1591, 1208,  434,    0])\n",
      "tensor([   0,  672,    0,    0,  879, 1466,  943,    0])\n"
     ]
    }
   ],
   "source": [
    "# Linear mixture of experts with top-1 routing.\n",
    "# num_epochs is set here explicitly: the cell previously relied on whatever value an\n",
    "# earlier cell had left behind, so its result depended on execution order. 601 matches\n",
    "# the recorded output of this cell (last logged epoch: 601).\n",
    "num_epochs = 601\n",
    "\n",
    "linear_mixture = MoE(200, 8, PATCH_NUM, EXPERT_NUM, strategy='top1', nonlinear=False)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "# Second output of the untrained model, summed over samples (presumably the\n",
    "# per-expert dispatch counts before training -- confirm against MoE.forward).\n",
    "_, s, _ = linear_mixture(training_data)\n",
    "print(s.squeeze(-1).sum(dim=0))\n",
    "\n",
    "# Experts are trained with NormalizedGD, the router with plain SGD.\n",
    "optimizer = NormalizedGD(linear_mixture.models.parameters(), lr=0.001)\n",
    "optimizer2 = torch.optim.SGD(linear_mixture.router.parameters(), lr=0.1)\n",
    "\n",
    "expert_feat, expert_cent, router_feat, router_cent, select, _ = train(linear_mixture, criterion, training_data, training_labels,\n",
    "                                                           [optimizer,optimizer2], num_epochs, plot=True)\n",
    "print(select.squeeze(-1).sum(dim=0))\n",
    "\n",
    "test(linear_mixture, criterion, training_data, training_labels)\n",
    "test(linear_mixture, criterion, test_data, test_labels)\n",
    "\n",
    "# Per-cluster breakdown of `select`, one line for each of the four cluster index sets.\n",
    "for cluster in range(4):\n",
    "    print(select[train_cluster_idx[cluster]].squeeze(-1).sum(dim=0).to(torch.long))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "18e32a8d-6c42-48b0-8004-ee775a9f2b47",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 95.0375 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 88.8500 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 93.5500 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 95.1125 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 93.7750 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 93.8937 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 94.7375 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 94.5563 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 94.9688 %\n",
      "Finished Training\n",
      "Accuracy of the network on the 16000 test images: 93.7500 %\n",
      "\n",
      "Average test accuracy: 93.82.\n",
      "Standard deviation: 1.75\n",
      "\n",
      "Average dispatch entropy 1.2760000228881836.\n",
      "Standard deviation: 0.061000000685453415\n"
     ]
    }
   ],
   "source": [
    "# Repeat the linear MoE experiment 10 times and summarise the final test\n",
    "# accuracy and the last recorded dispatch entropy of every run.\n",
    "acc_list, ent_list = [], []\n",
    "entropy_records = []\n",
    "num_epochs = 501  # loop-invariant, so set once instead of on every iteration\n",
    "\n",
    "for run in range(10):\n",
    "    linear_mixture = MoE(200, 8, PATCH_NUM, EXPERT_NUM, strategy='top1', nonlinear=False)\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "    optimizer = NormalizedGD(linear_mixture.models.parameters(), lr=0.001)\n",
    "    optimizer2 = torch.optim.SGD(linear_mixture.router.parameters(), lr=0.1)\n",
    "\n",
    "    _, _, _, _, _, entropy_record = train(linear_mixture, criterion, training_data, training_labels,\n",
    "                                          [optimizer,optimizer2], num_epochs, plot=False, verbose=False)\n",
    "    ent_list.append(entropy_record[-1])\n",
    "\n",
    "    acc = test(linear_mixture, criterion, test_data, test_labels)\n",
    "    acc_list.append(acc)\n",
    "\n",
    "    entropy_records.append(torch.stack(entropy_record))\n",
    "\n",
    "final_entropies = torch.stack(ent_list).cpu().numpy()\n",
    "\n",
    "print()\n",
    "print(f\"Average test accuracy: {round(np.mean(acc_list),2)}.\")\n",
    "print(f\"Standard deviation: {round(np.std(acc_list),2)}\")\n",
    "print()\n",
    "# Convert to a Python float before rounding: rounding a np.float32 and printing it\n",
    "# exposes float32 artifacts such as 1.2760000228881836 instead of 1.276.\n",
    "print(f\"Average dispatch entropy {round(float(np.mean(final_entropies)),3)}.\")\n",
    "print(f\"Standard deviation: {round(float(np.std(final_entropies)),3)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "0eca999d-79da-45d0-b438-4842a646543a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist both trained mixtures as whole pickled modules; later cells reload\n",
    "# 'linear_moe.pth' with torch.load.\n",
    "for trained_model, checkpoint_path in ((nonlinear_mixture, 'non_linear_moe.pth'),\n",
    "                                       (linear_mixture, 'linear_moe.pth')):\n",
    "    torch.save(trained_model, checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "0c1b4d8f-3b3e-46c2-9aac-3bd5f541cf37",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([800, 8])\n",
      "torch.Size([16, 8])\n",
      "tensor(0.0395)\n",
      "tensor(-0.0413)\n",
      "tensor(0.0210)\n",
      "tensor(-0.0255)\n"
     ]
    }
   ],
   "source": [
    "# Flatten the conv1 weight of each of the 8 experts and stack them so that every\n",
    "# column of target_wt holds one expert's weights.\n",
    "target_wt = torch.stack(\n",
    "    [linear_mixture.models[i].conv1.weight.data.view(-1) for i in range(8)],\n",
    "    dim=0,\n",
    ").T\n",
    "print(target_wt.shape)\n",
    "\n",
    "# Same layout for the conv1 biases: one column per expert.\n",
    "target_b = torch.stack(\n",
    "    [linear_mixture.models[i].conv1.bias.data.view(-1) for i in range(8)],\n",
    "    dim=0,\n",
    ").T\n",
    "print(target_b.shape)\n",
    "\n",
    "# Overall value range: max then min of the weights, followed by max then min of the biases.\n",
    "for stacked in (target_wt, target_b):\n",
    "    print(torch.max(stacked))\n",
    "    print(torch.min(stacked))"
   ]
  },
  {
   "cell_type": "raw",
   "id": "44fc5ec9-3fde-44aa-a39b-c451e4114b24",
   "metadata": {},
   "source": [
    "The weights and biases of the linear mixture of experts are very small in magnitude, close to epsilon=0.01. One of two things may be true:\n",
    "(i). They are not relevant for the final accuracy and can be zeroed out.\n",
    "(ii). The scale of the parameters must be adjusted to fit the (-1,1) range that we assume. Rescaling is possible if the experts contain no non-linear activations (which is indeed the case here).\n",
    "\n",
    "First, verify which of these alternatives preserves the accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 212,
   "id": "1bbe796f-355f-4c04-b061-e72afbe5bc72",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_30836\\220291039.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  lin_moe = torch.load('linear_moe.pth')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the network on the 16000 test images: 95.4562 %\n",
      "Accuracy of the network on the 16000 test images: 94.0938 %\n",
      "Accuracy of the network on the 16000 test images: 74.0438 %\n",
      "Accuracy of the network on the 16000 test images: 73.8187 %\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_30836\\220291039.py:18: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  lin_moe = torch.load('linear_moe.pth')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the network on the 16000 test images: 95.5187 %\n",
      "Accuracy of the network on the 16000 test images: 94.1188 %\n",
      "Expert:0\n",
      "Weight range: (tensor(0.0038), tensor(-0.0037))\n",
      "Bias range: (tensor(0.0008), tensor(-0.0008))\n",
      "Expert:1\n",
      "Weight range: (tensor(0.0121), tensor(-0.0121))\n",
      "Bias range: (tensor(0.0210), tensor(-0.0210))\n",
      "Expert:2\n",
      "Weight range: (tensor(0.0025), tensor(-0.0024))\n",
      "Bias range: (tensor(0.0005), tensor(-0.0005))\n",
      "Expert:3\n",
      "Weight range: (tensor(0.0395), tensor(-0.0413))\n",
      "Bias range: (tensor(0.0069), tensor(-0.0255))\n",
      "Expert:4\n",
      "Weight range: (tensor(0.0120), tensor(-0.0120))\n",
      "Bias range: (tensor(0.0065), tensor(-0.0064))\n",
      "Expert:5\n",
      "Weight range: (tensor(0.0120), tensor(-0.0120))\n",
      "Bias range: (tensor(0.0046), tensor(-0.0047))\n",
      "Expert:6\n",
      "Weight range: (tensor(0.0114), tensor(-0.0115))\n",
      "Bias range: (tensor(0.0002), tensor(-0.0002))\n",
      "Expert:7\n",
      "Weight range: (tensor(0.0026), tensor(-0.0026))\n",
      "Bias range: (tensor(0.0005), tensor(-0.0005))\n",
      " GLOBAL RESCALING\n",
      "GLOBAL RESCALING DONE\n",
      "Expert:0\n",
      "Weight range: (tensor(0.0375), tensor(-0.0368))\n",
      "Bias range: (tensor(0.0082), tensor(-0.0081))\n",
      "Expert:1\n",
      "Weight range: (tensor(0.1209), tensor(-0.1209))\n",
      "Bias range: (tensor(0.2096), tensor(-0.2101))\n",
      "Expert:2\n",
      "Weight range: (tensor(0.0247), tensor(-0.0244))\n",
      "Bias range: (tensor(0.0049), tensor(-0.0052))\n",
      "Expert:3\n",
      "Weight range: (tensor(0.3950), tensor(-0.4134))\n",
      "Bias range: (tensor(0.0687), tensor(-0.2545))\n",
      "Expert:4\n",
      "Weight range: (tensor(0.1202), tensor(-0.1201))\n",
      "Bias range: (tensor(0.0650), tensor(-0.0643))\n",
      "Expert:5\n",
      "Weight range: (tensor(0.1195), tensor(-0.1195))\n",
      "Bias range: (tensor(0.0456), tensor(-0.0468))\n",
      "Expert:6\n",
      "Weight range: (tensor(0.1143), tensor(-0.1149))\n",
      "Bias range: (tensor(0.0019), tensor(-0.0019))\n",
      "Expert:7\n",
      "Weight range: (tensor(0.0258), tensor(-0.0257))\n",
      "Bias range: (tensor(0.0050), tensor(-0.0053))\n",
      "Accuracy of the network on the 16000 test images: 95.4250 %\n",
      "Accuracy of the network on the 16000 test images: 94.1813 %\n",
      "Zeroing out elements smaller than epsilon\n",
      "Accuracy of the network on the 16000 test images: 95.4750 %\n",
      "Accuracy of the network on the 16000 test images: 94.0250 %\n",
      "Accuracy of the network on the 16000 test images: 95.4938 %\n",
      "Accuracy of the network on the 16000 test images: 94.1750 %\n",
      "Expert:0\n",
      "Weight range: (tensor(0.0038), tensor(-0.0037))\n",
      "Bias range: (tensor(0.0008), tensor(-0.0008))\n",
      "Expert:1\n",
      "Weight range: (tensor(0.0121), tensor(-0.0121))\n",
      "Bias range: (tensor(0.0210), tensor(-0.0210))\n",
      "Expert:2\n",
      "Weight range: (tensor(0.0025), tensor(-0.0024))\n",
      "Bias range: (tensor(0.0005), tensor(-0.0005))\n",
      "Expert:3\n",
      "Weight range: (tensor(0.0395), tensor(-0.0413))\n",
      "Bias range: (tensor(0.0069), tensor(-0.0255))\n",
      "Expert:4\n",
      "Weight range: (tensor(0.0120), tensor(-0.0120))\n",
      "Bias range: (tensor(0.0065), tensor(-0.0064))\n",
      "Expert:5\n",
      "Weight range: (tensor(0.0120), tensor(-0.0120))\n",
      "Bias range: (tensor(0.0046), tensor(-0.0047))\n",
      "Expert:6\n",
      "Weight range: (tensor(0.0114), tensor(-0.0115))\n",
      "Bias range: (tensor(0.0002), tensor(-0.0002))\n",
      "Expert:7\n",
      "Weight range: (tensor(0.0026), tensor(-0.0026))\n",
      "Bias range: (tensor(0.0005), tensor(-0.0005))\n",
      " LOCAL RESCALING\n",
      "LOCAL RESCALING DONE\n",
      "Expert:0\n",
      "Weight range: (tensor(0.9000), tensor(-0.8814))\n",
      "Bias range: (tensor(0.9000), tensor(-0.8968))\n",
      "Expert:1\n",
      "Weight range: (tensor(0.8994), tensor(-0.9000))\n",
      "Bias range: (tensor(0.8979), tensor(-0.9000))\n",
      "Expert:2\n",
      "Weight range: (tensor(0.9000), tensor(-0.8891))\n",
      "Bias range: (tensor(0.8481), tensor(-0.9000))\n",
      "Expert:3\n",
      "Weight range: (tensor(0.8598), tensor(-0.9000))\n",
      "Bias range: (tensor(0.2428), tensor(-0.9000))\n",
      "Expert:4\n",
      "Weight range: (tensor(0.9000), tensor(-0.8993))\n",
      "Bias range: (tensor(0.9000), tensor(-0.8901))\n",
      "Expert:5\n",
      "Weight range: (tensor(0.8999), tensor(-0.9000))\n",
      "Bias range: (tensor(0.8772), tensor(-0.9000))\n",
      "Expert:6\n",
      "Weight range: (tensor(0.8958), tensor(-0.9000))\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_30836\\220291039.py:54: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  lin_moe = torch.load('linear_moe.pth')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bias range: (tensor(0.8778), tensor(-0.9000))\n",
      "Expert:7\n",
      "Weight range: (tensor(0.9000), tensor(-0.8972))\n",
      "Bias range: (tensor(0.8473), tensor(-0.9000))\n",
      "Accuracy of the network on the 16000 test images: 94.1875 %\n",
      "Accuracy of the network on the 16000 test images: 93.6625 %\n",
      "Zeroing out elements smaller than epsilon\n",
      "Accuracy of the network on the 16000 test images: 94.2562 %\n",
      "Accuracy of the network on the 16000 test images: 93.6312 %\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "93.63125"
      ]
     },
     "execution_count": 212,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def _load_linear_moe(path='linear_moe.pth'):\n",
    "    \"\"\"Return a freshly loaded copy of the saved linear MoE so each scenario starts from the trained weights.\"\"\"\n",
    "    # The checkpoint is a whole pickled module (written by the torch.save cell), so it\n",
    "    # needs weights_only=False. Passing it explicitly silences the FutureWarning and keeps\n",
    "    # the cell working once torch flips the default to True. Unpickling can execute\n",
    "    # arbitrary code, so only load checkpoints produced by this notebook.\n",
    "    return torch.load(path, weights_only=False)\n",
    "\n",
    "\n",
    "def _report_accuracy(moe):\n",
    "    \"\"\"Run test() on the training split, then on the test split, and return the latter's result.\"\"\"\n",
    "    test(moe, criterion, training_data, training_labels)\n",
    "    return test(moe, criterion, test_data, test_labels)\n",
    "\n",
    "\n",
    "def _print_param_ranges(moe):\n",
    "    \"\"\"Print the (max, min) of the conv1 weight and bias of each of the 8 experts.\"\"\"\n",
    "    for i in range(8):\n",
    "        conv = moe.models[i].conv1\n",
    "        print(\"Expert:\"+str(i))\n",
    "        print(\"Weight range: (\"+ str(torch.max(conv.weight.data)) + \", \"+ str(torch.min(conv.weight.data)) +\")\")\n",
    "        print(\"Bias range: (\"+ str(torch.max(conv.bias.data)) + \", \"+ str(torch.min(conv.bias.data)) +\")\")\n",
    "\n",
    "\n",
    "def _zero_small_params(moe, threshold):\n",
    "    \"\"\"Zero, in place, every conv1 weight/bias entry whose magnitude is below ``threshold``.\"\"\"\n",
    "    for i in range(8):\n",
    "        conv = moe.models[i].conv1\n",
    "        conv.weight.data *= (torch.abs(conv.weight.data) >= threshold)*1\n",
    "        conv.bias.data *= (torch.abs(conv.bias.data) >= threshold)*1\n",
    "\n",
    "\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "################ SCENARIO-I ZERO-OUT SMALL MAGNITUDE PARAMETERS #################\n",
    "lin_moe = _load_linear_moe()\n",
    "\n",
    "_report_accuracy(lin_moe)\n",
    "_zero_small_params(lin_moe, epsilon)\n",
    "_report_accuracy(lin_moe)\n",
    "\n",
    "############### SCENARIO-II ADJUST THE SCALE OF PARAMETERS GLOBAL #####################\n",
    "lin_moe = _load_linear_moe()\n",
    "\n",
    "_report_accuracy(lin_moe)\n",
    "_print_param_ranges(lin_moe)\n",
    "\n",
    "# One common factor for every weight and bias of every expert.\n",
    "scale = 10\n",
    "print(\" GLOBAL RESCALING\")\n",
    "for i in range(8):\n",
    "    lin_moe.models[i].conv1.weight.data *= scale\n",
    "    lin_moe.models[i].conv1.bias.data *= scale\n",
    "print(\"GLOBAL RESCALING DONE\")\n",
    "\n",
    "_print_param_ranges(lin_moe)\n",
    "_report_accuracy(lin_moe)\n",
    "\n",
    "print(\"Zeroing out elements smaller than epsilon\")\n",
    "_zero_small_params(lin_moe, epsilon)\n",
    "_report_accuracy(lin_moe)\n",
    "\n",
    "############### SCENARIO-III ADJUST THE SCALE OF PARAMETERS LOCAL #####################\n",
    "lin_moe = _load_linear_moe()\n",
    "\n",
    "_report_accuracy(lin_moe)\n",
    "_print_param_ranges(lin_moe)\n",
    "\n",
    "print(\" LOCAL RESCALING\")\n",
    "for i in range(8):\n",
    "    # Weights and biases of an expert are rescaled independently, each so that its\n",
    "    # largest magnitude becomes 0.9.\n",
    "    max_wt = torch.max(torch.abs(lin_moe.models[i].conv1.weight.data))\n",
    "    scale = 0.9/max_wt\n",
    "    lin_moe.models[i].conv1.weight.data *= scale\n",
    "    max_b = torch.max(torch.abs(lin_moe.models[i].conv1.bias.data))\n",
    "    scale = 0.9/max_b\n",
    "    lin_moe.models[i].conv1.bias.data *= scale\n",
    "print(\"LOCAL RESCALING DONE\")\n",
    "\n",
    "_print_param_ranges(lin_moe)\n",
    "_report_accuracy(lin_moe)\n",
    "\n",
    "print(\"Zeroing out elements smaller than epsilon\")\n",
    "_zero_small_params(lin_moe, epsilon)\n",
    "\n",
    "# Left as the last expression so the final test accuracy is displayed as the cell result.\n",
    "_report_accuracy(lin_moe)"
   ]
  },
  {
   "cell_type": "raw",
   "id": "af063976-fb60-4d9f-9146-5baa73e80cd0",
   "metadata": {},
   "source": [
    "1. Zeroing out parameters smaller than epsilon=0.01 in the MoE model leads to a significant drop in accuracy. So it is not a viable alternative.\n",
    "2. Global rescaling of parameters with a scaling factor=10 is a viable alternative as it preserves the accuracy both after rescaling the parameters as well as then zeroing out the smaller magnitude parameters smaller than epsilon.\n",
    "\n",
    "Therefore, collect the new rescaled targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 402,
   "id": "a1bd9b19-1dcf-4719-8f1c-8de99c425075",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([800, 8])\n",
      "Max: tensor(0.3950) Min: tensor(-0.4134)\n",
      "tensor(4163)\n",
      "torch.Size([16, 8])\n",
      "Max: tensor(0.2096) Min: tensor(-0.2545)\n",
      "tensor(64)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_30836\\2268568622.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  lin_moe = torch.load('linear_moe.pth')\n"
     ]
    }
   ],
   "source": [
    "lin_moe = torch.load('linear_moe.pth')\n",
    "\n",
    "scale = 10\n",
    "for i in range(8):\n",
    "    lin_moe.models[i].conv1.weight.data *= scale\n",
    "    lin_moe.models[i].conv1.bias.data *= scale\n",
    "\n",
    "target_wt = []\n",
    "for i in range(8):\n",
    "    wt_flatten = lin_moe.models[i].conv1.weight.data.view(-1)\n",
    "    target_wt.append(wt_flatten)\n",
    "target_wt = torch.stack(target_wt, dim=0)\n",
    "target_wt = target_wt.T\n",
    "print(target_wt.shape)\n",
    "print(\"Max: \"+str(torch.max(target_wt))+\" Min: \"+str(torch.min(target_wt)))\n",
    "print(torch.sum(np.abs(target_wt)>=epsilon))\n",
    "\n",
    "target_b = []\n",
    "for i in range(8):\n",
    "    b_flatten = lin_moe.models[i].conv1.bias.data.view(-1)\n",
    "    target_b.append(b_flatten)\n",
    "target_b = torch.stack(target_b, dim=0)\n",
    "target_b = target_b.T\n",
    "print(target_b.shape)\n",
    "print(\"Max: \"+str(torch.max(target_b))+\" Min: \"+str(torch.min(target_b)))\n",
    "print(torch.sum(np.abs(target_b)>=epsilon))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 300,
   "id": "5a2c7b8c-1600-4064-b9cf-37ccaedd62d9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0113,  0.0286,  0.0000,  ...,  0.0203, -0.0106,  0.0101],\n",
       "        [-0.0000, -0.0000,  0.0000,  ...,  0.0631,  0.0174, -0.0000],\n",
       "        [ 0.0144,  0.0331,  0.0000,  ...,  0.0000,  0.0375,  0.0000],\n",
       "        ...,\n",
       "        [-0.0209,  0.0712, -0.0000,  ..., -0.0588,  0.0713, -0.0110],\n",
       "        [-0.0000, -0.1209,  0.0000,  ...,  0.0000, -0.0321,  0.0000],\n",
       "        [-0.0154,  0.0458, -0.0225,  ..., -0.0804, -0.1114, -0.0000]])"
      ]
     },
     "execution_count": 300,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(np.abs(target_wt)>=epsilon)*target_wt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 228,
   "id": "a705f76b-b569-41b1-b469-dfbce9be40cd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_30836\\3790740740.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n",
      "  n &= int(n-1)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n",
      "16\n",
      "17\n",
      "18\n",
      "19\n",
      "20\n",
      "21\n",
      "22\n",
      "23\n",
      "24\n",
      "25\n",
      "26\n",
      "27\n",
      "28\n",
      "29\n",
      "30\n",
      "31\n",
      "32\n",
      "33\n",
      "34\n",
      "35\n",
      "36\n",
      "37\n",
      "38\n",
      "39\n",
      "40\n",
      "41\n",
      "42\n",
      "43\n",
      "44\n",
      "45\n",
      "46\n",
      "47\n",
      "48\n",
      "49\n",
      "50\n",
      "51\n",
      "52\n",
      "53\n",
      "54\n",
      "55\n",
      "56\n",
      "57\n",
      "58\n",
      "59\n",
      "60\n",
      "61\n",
      "62\n",
      "63\n",
      "64\n",
      "65\n",
      "66\n",
      "67\n",
      "68\n",
      "69\n",
      "70\n",
      "71\n",
      "72\n",
      "73\n",
      "74\n",
      "75\n",
      "76\n",
      "77\n",
      "78\n",
      "79\n",
      "80\n",
      "81\n",
      "82\n",
      "83\n",
      "84\n",
      "85\n",
      "86\n",
      "87\n",
      "88\n",
      "89\n",
      "90\n",
      "91\n",
      "92\n",
      "93\n",
      "94\n",
      "95\n",
      "96\n",
      "97\n",
      "98\n",
      "99\n",
      "100\n",
      "101\n",
      "102\n",
      "103\n",
      "104\n",
      "105\n",
      "106\n",
      "107\n",
      "108\n",
      "109\n",
      "110\n",
      "111\n",
      "112\n",
      "113\n",
      "114\n",
      "115\n",
      "116\n",
      "117\n",
      "118\n",
      "119\n",
      "120\n",
      "121\n",
      "122\n",
      "123\n",
      "124\n",
      "125\n",
      "126\n",
      "127\n",
      "128\n",
      "129\n",
      "130\n",
      "131\n",
      "132\n",
      "133\n",
      "134\n",
      "135\n",
      "136\n",
      "137\n",
      "138\n",
      "139\n",
      "140\n",
      "141\n",
      "142\n",
      "143\n",
      "144\n",
      "145\n",
      "146\n",
      "147\n",
      "148\n",
      "149\n",
      "150\n",
      "151\n",
      "152\n",
      "153\n",
      "154\n",
      "155\n",
      "156\n",
      "157\n",
      "158\n",
      "159\n",
      "160\n",
      "161\n",
      "162\n",
      "163\n",
      "164\n",
      "165\n",
      "166\n",
      "167\n",
      "168\n",
      "169\n",
      "170\n",
      "171\n",
      "172\n",
      "173\n",
      "174\n",
      "175\n",
      "176\n",
      "177\n",
      "178\n",
      "179\n",
      "180\n",
      "181\n",
      "182\n",
      "183\n",
      "184\n",
      "185\n",
      "186\n",
      "187\n",
      "188\n",
      "189\n",
      "190\n",
      "191\n",
      "192\n",
      "193\n",
      "194\n",
      "195\n",
      "196\n",
      "197\n",
      "198\n",
      "199\n",
      "200\n",
      "201\n",
      "202\n",
      "203\n",
      "204\n",
      "205\n",
      "206\n",
      "207\n",
      "208\n",
      "209\n",
      "210\n",
      "211\n",
      "212\n",
      "213\n",
      "214\n",
      "215\n",
      "216\n",
      "217\n",
      "218\n",
      "219\n",
      "220\n",
      "221\n",
      "222\n",
      "223\n",
      "224\n",
      "225\n",
      "226\n",
      "227\n",
      "228\n",
      "229\n",
      "230\n",
      "231\n",
      "232\n",
      "233\n",
      "234\n",
      "235\n",
      "236\n",
      "237\n",
      "238\n",
      "239\n",
      "240\n",
      "241\n",
      "242\n",
      "243\n",
      "244\n",
      "245\n",
      "246\n",
      "247\n",
      "248\n",
      "249\n",
      "250\n",
      "251\n",
      "252\n",
      "253\n",
      "254\n",
      "255\n",
      "256\n",
      "257\n",
      "258\n",
      "259\n",
      "260\n",
      "261\n",
      "262\n",
      "263\n",
      "264\n",
      "265\n",
      "266\n",
      "267\n",
      "268\n",
      "269\n",
      "270\n",
      "271\n",
      "272\n",
      "273\n",
      "274\n",
      "275\n",
      "276\n",
      "277\n",
      "278\n",
      "279\n",
      "280\n",
      "281\n",
      "282\n",
      "283\n",
      "284\n",
      "285\n",
      "286\n",
      "287\n",
      "288\n",
      "289\n",
      "290\n",
      "291\n",
      "292\n",
      "293\n",
      "294\n",
      "295\n",
      "296\n",
      "297\n",
      "298\n",
      "299\n",
      "300\n",
      "301\n",
      "302\n",
      "303\n",
      "304\n",
      "305\n",
      "306\n",
      "307\n",
      "308\n",
      "309\n",
      "310\n",
      "311\n",
      "312\n",
      "313\n",
      "314\n",
      "315\n",
      "316\n",
      "317\n",
      "318\n",
      "319\n",
      "320\n",
      "321\n",
      "322\n",
      "323\n",
      "324\n",
      "325\n",
      "326\n",
      "327\n",
      "328\n",
      "329\n",
      "330\n",
      "331\n",
      "332\n",
      "333\n",
      "334\n",
      "335\n",
      "336\n",
      "337\n",
      "338\n",
      "339\n",
      "340\n",
      "341\n",
      "342\n",
      "343\n",
      "344\n",
      "345\n",
      "346\n",
      "347\n",
      "348\n",
      "349\n",
      "350\n",
      "351\n",
      "352\n",
      "353\n",
      "354\n",
      "355\n",
      "356\n",
      "357\n",
      "358\n",
      "359\n",
      "360\n",
      "361\n",
      "362\n",
      "363\n",
      "364\n",
      "365\n",
      "366\n",
      "367\n",
      "368\n",
      "369\n",
      "370\n",
      "371\n",
      "372\n",
      "373\n",
      "374\n",
      "375\n",
      "376\n",
      "377\n",
      "378\n",
      "379\n",
      "380\n",
      "381\n",
      "382\n",
      "383\n",
      "384\n",
      "385\n",
      "386\n",
      "387\n",
      "388\n",
      "389\n",
      "390\n",
      "391\n",
      "392\n",
      "393\n",
      "394\n",
      "395\n",
      "396\n",
      "397\n",
      "398\n",
      "399\n",
      "400\n",
      "401\n",
      "402\n",
      "403\n",
      "404\n",
      "405\n",
      "406\n",
      "407\n",
      "408\n",
      "409\n",
      "410\n",
      "411\n",
      "412\n",
      "413\n",
      "414\n",
      "415\n",
      "416\n",
      "417\n",
      "418\n",
      "419\n",
      "420\n",
      "421\n",
      "422\n",
      "423\n",
      "424\n",
      "425\n",
      "426\n",
      "427\n",
      "428\n",
      "429\n",
      "430\n",
      "431\n",
      "432\n",
      "433\n",
      "434\n",
      "435\n",
      "436\n",
      "437\n",
      "438\n",
      "439\n",
      "440\n",
      "441\n",
      "442\n",
      "443\n",
      "444\n",
      "445\n",
      "446\n",
      "447\n",
      "448\n",
      "449\n",
      "450\n",
      "451\n",
      "452\n",
      "453\n",
      "454\n",
      "455\n",
      "456\n",
      "457\n",
      "458\n",
      "459\n",
      "460\n",
      "461\n",
      "462\n",
      "463\n",
      "464\n",
      "465\n",
      "466\n",
      "467\n",
      "468\n",
      "469\n",
      "470\n",
      "471\n",
      "472\n",
      "473\n",
      "474\n",
      "475\n",
      "476\n",
      "477\n",
      "478\n",
      "479\n",
      "480\n",
      "481\n",
      "482\n",
      "483\n",
      "484\n",
      "485\n",
      "486\n",
      "487\n",
      "488\n",
      "489\n",
      "490\n",
      "491\n",
      "492\n",
      "493\n",
      "494\n",
      "495\n",
      "496\n",
      "497\n",
      "498\n",
      "499\n",
      "500\n",
      "501\n",
      "502\n",
      "503\n",
      "504\n",
      "505\n",
      "506\n",
      "507\n",
      "508\n",
      "509\n",
      "510\n",
      "511\n",
      "512\n",
      "513\n",
      "514\n",
      "515\n",
      "516\n",
      "517\n",
      "518\n",
      "519\n",
      "520\n",
      "521\n",
      "522\n",
      "523\n",
      "524\n",
      "525\n",
      "526\n",
      "527\n",
      "528\n",
      "529\n",
      "530\n",
      "531\n",
      "532\n",
      "533\n",
      "534\n",
      "535\n",
      "536\n",
      "537\n",
      "538\n",
      "539\n",
      "540\n",
      "541\n",
      "542\n",
      "543\n",
      "544\n",
      "545\n",
      "546\n",
      "547\n",
      "548\n",
      "549\n",
      "550\n",
      "551\n",
      "552\n",
      "553\n",
      "554\n",
      "555\n",
      "556\n",
      "557\n",
      "558\n",
      "559\n",
      "560\n",
      "561\n",
      "562\n",
      "563\n",
      "564\n",
      "565\n",
      "566\n",
      "567\n",
      "568\n",
      "569\n",
      "570\n",
      "571\n",
      "572\n",
      "573\n",
      "574\n",
      "575\n",
      "576\n",
      "577\n",
      "578\n",
      "579\n",
      "580\n",
      "581\n",
      "582\n",
      "583\n",
      "584\n",
      "585\n",
      "586\n",
      "587\n",
      "588\n",
      "589\n",
      "590\n",
      "591\n",
      "592\n",
      "593\n",
      "594\n",
      "595\n",
      "596\n",
      "597\n",
      "598\n",
      "599\n",
      "600\n",
      "601\n",
      "602\n",
      "603\n",
      "604\n",
      "605\n",
      "606\n",
      "607\n",
      "608\n",
      "609\n",
      "610\n",
      "611\n",
      "612\n",
      "613\n",
      "614\n",
      "615\n",
      "616\n",
      "617\n",
      "618\n",
      "619\n",
      "620\n",
      "621\n",
      "622\n",
      "623\n",
      "624\n",
      "625\n",
      "626\n",
      "627\n",
      "628\n",
      "629\n",
      "630\n",
      "631\n",
      "632\n",
      "633\n",
      "634\n",
      "635\n",
      "636\n",
      "637\n",
      "638\n",
      "639\n",
      "640\n",
      "641\n",
      "642\n",
      "643\n",
      "644\n",
      "645\n",
      "646\n",
      "647\n",
      "648\n",
      "649\n",
      "650\n",
      "651\n",
      "652\n",
      "653\n",
      "654\n",
      "655\n",
      "656\n",
      "657\n",
      "658\n",
      "659\n",
      "660\n",
      "661\n",
      "662\n",
      "663\n",
      "664\n",
      "665\n",
      "666\n",
      "667\n",
      "668\n",
      "669\n",
      "670\n",
      "671\n",
      "672\n",
      "673\n",
      "674\n",
      "675\n",
      "676\n",
      "677\n",
      "678\n",
      "679\n",
      "680\n",
      "681\n",
      "682\n",
      "683\n",
      "684\n",
      "685\n",
      "686\n",
      "687\n",
      "688\n",
      "689\n",
      "690\n",
      "691\n",
      "692\n",
      "693\n",
      "694\n",
      "695\n",
      "696\n",
      "697\n",
      "698\n",
      "699\n",
      "700\n",
      "701\n",
      "702\n",
      "703\n",
      "704\n",
      "705\n",
      "706\n",
      "707\n",
      "708\n",
      "709\n",
      "710\n",
      "711\n",
      "712\n",
      "713\n",
      "714\n",
      "715\n",
      "716\n",
      "717\n",
      "718\n",
      "719\n",
      "720\n",
      "721\n",
      "722\n",
      "723\n",
      "724\n",
      "725\n",
      "726\n",
      "727\n",
      "728\n",
      "729\n",
      "730\n",
      "731\n",
      "732\n",
      "733\n",
      "734\n",
      "735\n",
      "736\n",
      "737\n",
      "738\n",
      "739\n",
      "740\n",
      "741\n",
      "742\n",
      "743\n",
      "744\n",
      "745\n",
      "746\n",
      "747\n",
      "748\n",
      "749\n",
      "750\n",
      "751\n",
      "752\n",
      "753\n",
      "754\n",
      "755\n",
      "756\n",
      "757\n",
      "758\n",
      "759\n",
      "760\n",
      "761\n",
      "762\n",
      "763\n",
      "764\n",
      "765\n",
      "766\n",
      "767\n",
      "768\n",
      "769\n",
      "770\n",
      "771\n",
      "772\n",
      "773\n",
      "774\n",
      "775\n",
      "776\n",
      "777\n",
      "778\n",
      "779\n",
      "780\n",
      "781\n",
      "782\n",
      "783\n",
      "784\n",
      "785\n",
      "786\n",
      "787\n",
      "788\n",
      "789\n",
      "790\n",
      "791\n",
      "792\n",
      "793\n",
      "794\n",
      "795\n",
      "796\n",
      "797\n",
      "798\n",
      "799\n"
     ]
    }
   ],
   "source": [
    "from itertools import combinations\n",
    "import copy\n",
    "from functools import reduce\n",
    "def countSetBits(n):\n",
    "    count = 0\n",
    "    while (n):\n",
    "        n &= int(n-1) \n",
    "        count+= 1\n",
    "    return count\n",
    "\n",
    "def subsets(arr,status,curr = 0):\n",
    "    global s\n",
    "    if(curr>=len(arr)):\n",
    "        s.append(np.sum(arr*status))\n",
    "        return\n",
    "    subsets(arr,status,curr+1)\n",
    "    status[curr] = 1\n",
    "    subsets(arr,status,curr+1)\n",
    "    status[curr] = 0\n",
    "\n",
    "def get_binary(num, digits=15):\n",
    "    bin = [0]*digits\n",
    "    start = digits-1\n",
    "    while(num>0):\n",
    "        bin[start] = num%2\n",
    "        num = num//2\n",
    "        start-=1\n",
    "    return bin\n",
    "\n",
    "def find_stats(super_set, best_len, best_ss_ind):\n",
    "    if(best_len == -1):\n",
    "        return None, None, None\n",
    "    overlap = reduce(lambda x, y: x & y, best_ss_ind)\n",
    "    overlap_len = countSetBits(overlap)\n",
    "    extra_len = 0\n",
    "    for i in range(len(best_ss_ind)):\n",
    "        extra_len+=countSetBits(overlap ^ best_ss_ind[i])\n",
    "    return best_len, overlap_len, extra_len\n",
    "\n",
    "def print_stats(super_set, best_len, best_ss_ind):\n",
    "    if(best_len==-1):\n",
    "        print(\"Problem\")\n",
    "        return\n",
    "    print(f\"{\"Best overall subset is: \"+str(super_set) : <45}{str(get_binary(super_set)) : >25}\")\n",
    "    for i in range(len(best_ss_ind)):\n",
    "        print(f\"{\"Subset \"+str(i+1)+\": \"+str(best_ss_ind[i])+\" Length: \"+str(countSetBits(best_ss_ind[i])) : <45}{str(get_binary(best_ss_ind[i])) : >25}\")\n",
    "    overlap = reduce(lambda x, y: x & y, best_ss_ind)\n",
    "    print(f\"{\"Overlap Subset: \"+str(overlap) : <45}{str(get_binary(overlap)) : >25}\")\n",
    "    print(f\"{\"Overall subset length: \" : <25}{str(countSetBits(super_set)) : >10}\")\n",
    "    print(f\"{\"Overlap length: \" : <25}{str(countSetBits(overlap)) : >10}\")\n",
    "    extra_len = 0\n",
    "    for i in range(len(best_ss_ind)):\n",
    "        extra_len+=countSetBits(overlap ^ best_ss_ind[i])\n",
    "    print(f\"{\"Extra length: \" : <25}{str(extra_len) : >10}\")\n",
    "    print(f\"{\"BEST LENGTH IS: \" : <25}{str(best_len) : >10}\")\n",
    "    print(\"-\"*100)\n",
    "    \n",
    "\n",
    "def subset_fixed_size(target, numbers, eps, subsize, errBest):\n",
    "    n = len(numbers)\n",
    "    cand = 0\n",
    "    indBest = np.array([np.NAN])\n",
    "    for ind in combinations(range(n),subsize):\n",
    "        inda = np.array(ind,dtype=\"int\")\n",
    "        napprox = np.sum(numbers[inda])\n",
    "        diff = np.abs(target-napprox)\n",
    "        if diff < errBest:\n",
    "            errBest = diff\n",
    "            cand = napprox\n",
    "            indBest = inda\n",
    "        if diff <= eps:\n",
    "            break\n",
    "    return cand, indBest, errBest\n",
    "\n",
    "def exhaustive(target, numbers, eps, nmax):\n",
    "    n = len(numbers)\n",
    "    err = np.abs(target)\n",
    "    errBest = err\n",
    "    cand = 0\n",
    "    indBest = np.array([-1])\n",
    "    nmax = min(nmax, n)\n",
    "    for k in range(nmax):\n",
    "        cank, indk, errk = subset_fixed_size(target, numbers, eps, k, errBest)\n",
    "        if errk < errBest:\n",
    "            errBest = errk\n",
    "            cand = cank\n",
    "            indBest = indk\n",
    "        if errBest <= eps:\n",
    "            break\n",
    "    return cand, indBest\n",
    "\n",
    "def find_best_subset_size(status, targets, experts, epsilon):\n",
    "    final_set = status[0].reshape(-1,1)\n",
    "    for i in range(1,experts):\n",
    "        final_set = np.bitwise_or(final_set, status[i].reshape(1,-1))\n",
    "        #print(final_set)\n",
    "        final_set = np.unique(final_set.reshape(-1)).reshape(-1,1)\n",
    "    # print(final_set)\n",
    "    final_set = final_set.reshape(-1)\n",
    "    best = 100000\n",
    "    best_id = -1\n",
    "    for i in range(len(final_set)):\n",
    "        b = countSetBits(final_set[i])\n",
    "        if(b<best):\n",
    "            best = b\n",
    "            best_id = final_set[i]\n",
    "    #print(\"Best combo is: \"+str(best_id))\n",
    "    #print(\"Binary string: \"+str(get_binary(best_id)))\n",
    "    #print()\n",
    "    #print(\"Best is: \"+str(best))\n",
    "    if(not (best <= 15)):\n",
    "        print(\"Weird: \"+str(best))\n",
    "        best = 0\n",
    "    # count[best]+=1\n",
    "    if(best==0):\n",
    "        return None, -1, None, None    \n",
    "    tot_len = 0\n",
    "    for i in range(experts):\n",
    "        status[i] = np.reshape(status[i], (-1))\n",
    "    candidates = []\n",
    "    for i in range(experts):\n",
    "        cand_id = np.argwhere((np.bitwise_and(status[i],best_id)-status[i])==0)\n",
    "        #print(cand_id)\n",
    "        combos = [status[i][id] for id in cand_id]\n",
    "        candidates.append(combos)\n",
    "        for combo in combos:\n",
    "            # print(\"Combo:\"+str(combo[0])+\" Binary:\"+str(get_binary(combo[0])))\n",
    "            assert((np.sum(get_binary(combo[0])*rand_vars) - targets[i])<=epsilon)\n",
    "        cand_len = min([countSetBits(status[i][id]) for id in cand_id])\n",
    "        tot_len+=cand_len\n",
    "    candidates = [i[0][0] for i in candidates]\n",
    "    return best_id, best, candidates, tot_len\n",
    "\n",
    "def find_best_overall_size(status, targets, experts, epsilon):\n",
    "    #final_set = status[0].reshape(-1,1)\n",
    "    overlaps = status[0].reshape(-1,1)\n",
    "    for i in range(1,experts):\n",
    "        #final_set = np.bitwise_or(final_set, status[i].reshape(1,-1)).reshape(-1,1)\n",
    "        overlaps = np.bitwise_and(overlaps, status[i].reshape(1,-1)).reshape(-1,1)\n",
    "        #combined = np.unique(np.concatenate((final_set,overlaps), axis=1), axis=0)\n",
    "        #print(combined.shape)\n",
    "        #final_set = combined[:,0].reshape(-1,1)\n",
    "        #overlaps = combined[:,1].reshape(-1,1)\n",
    "        #print(i)\n",
    "        overlaps = np.unique(overlaps.reshape(-1)).reshape(-1,1)\n",
    "    overlaps = overlaps.reshape(-1,1)\n",
    "    # print(overlaps.shape)\n",
    "    cand = []\n",
    "    for i in range(experts):\n",
    "        o_s_map = (np.bitwise_and(overlaps, status[i].reshape(1,-1))==overlaps)\n",
    "        # print(o_s_map.sum(axis=1))\n",
    "        cand.append([status[i][np.where(o_s_map[j]==True)] for j in range(len(overlaps))])\n",
    "    best_len = 1000000\n",
    "    best_ss = -1\n",
    "    best_ss_ind = None\n",
    "    best_ss_id = -1\n",
    "    for ov_id in range(len(overlaps)):\n",
    "        ov_len = countSetBits(overlaps[ov_id][0])\n",
    "        extra_len = 0\n",
    "        min_vals = []\n",
    "        for i in range(experts):\n",
    "            l = []\n",
    "            for j in range(len(cand[i][ov_id])):\n",
    "                l.append(countSetBits(cand[i][ov_id][j] ^ overlaps[ov_id][0]))\n",
    "                # print(\"Candidate: \"+str(cand[i][ov_id][j])+\" Binary: \"+str(get_binary(cand[i][ov_id][j])))\n",
    "            cand_min_id = np.argmin(l)\n",
    "            #print(l)\n",
    "            extra_len+=l[cand_min_id]\n",
    "            min_vals.append(cand[i][ov_id][cand_min_id])\n",
    "        if(ov_len+extra_len<best_len):\n",
    "            best_ss = overlaps[ov_id][0]\n",
    "            best_ss_id = ov_id\n",
    "            best_ss_ind = copy.deepcopy(min_vals)\n",
    "            # print(\"Overlap: \"+str(overlaps[ov_id][0])+\" Binary: \"+str(get_binary(overlaps[ov_id][0])))\n",
    "            # print(\"Overlap length: \"+str(ov_len))\n",
    "            # print(\"Extra length: \"+str(extra_len))\n",
    "            # for i in range(experts):\n",
    "            #     print(\"Candidate: \"+str(min_vals[i])+\" Binary: \"+str(get_binary(min_vals[i])))\n",
    "            best_len = ov_len+extra_len\n",
    "            #print()\n",
    "    if(not (best_len <= 1000)):\n",
    "        print(\"Weird: \"+str(best))\n",
    "        best_len = 0\n",
    "    if(best_len==0):\n",
    "        return None, -1, None, None\n",
    "    for i in range(experts):\n",
    "        assert(best_ss_ind[i] in status[i])\n",
    "    assert(overlaps[best_ss_id][0] == reduce(lambda x, y: x & y, best_ss_ind))\n",
    "    super_set = reduce(lambda x, y: x | y, best_ss_ind)\n",
    "    return super_set, best_len, best_ss_ind, best_ss\n",
    "\n",
    "def find_best_extra_size(status, targets, experts, epsilon):\n",
    "    #final_set = status[0].reshape(-1,1)\n",
    "    overlaps = status[0].reshape(-1,1)\n",
    "    for i in range(1,experts):\n",
    "        #final_set = np.bitwise_or(final_set, status[i].reshape(1,-1)).reshape(-1,1)\n",
    "        overlaps = np.bitwise_and(overlaps, status[i].reshape(1,-1)).reshape(-1,1)\n",
    "        #combined = np.unique(np.concatenate((final_set,overlaps), axis=1), axis=0)\n",
    "        #print(combined.shape)\n",
    "        #final_set = combined[:,0].reshape(-1,1)\n",
    "        #overlaps = combined[:,1].reshape(-1,1)\n",
    "        #print(i)\n",
    "        overlaps = np.unique(overlaps.reshape(-1)).reshape(-1,1)\n",
    "    overlaps = overlaps.reshape(-1,1)\n",
    "    # print(overlaps.shape)\n",
    "    cand = []\n",
    "    for i in range(experts):\n",
    "        o_s_map = (np.bitwise_and(overlaps, status[i].reshape(1,-1))==overlaps)\n",
    "        # print(o_s_map.sum(axis=1))\n",
    "        cand.append([status[i][np.where(o_s_map[j]==True)] for j in range(len(overlaps))])\n",
    "    best_len = 1000000\n",
    "    best_ss = -1\n",
    "    best_ss_ind = None\n",
    "    best_ss_id = -1\n",
    "    for ov_id in range(len(overlaps)):\n",
    "        ov_len = countSetBits(overlaps[ov_id][0])\n",
    "        extra_len = 0\n",
    "        min_vals = []\n",
    "        for i in range(experts):\n",
    "            l = []\n",
    "            for j in range(len(cand[i][ov_id])):\n",
    "                l.append(countSetBits(cand[i][ov_id][j] ^ overlaps[ov_id][0]))\n",
    "                # print(\"Candidate: \"+str(cand[i][ov_id][j])+\" Binary: \"+str(get_binary(cand[i][ov_id][j])))\n",
    "            cand_min_id = np.argmin(l)\n",
    "            #print(l)\n",
    "            extra_len+=l[cand_min_id]\n",
    "            min_vals.append(cand[i][ov_id][cand_min_id])\n",
    "        if(extra_len<best_len):\n",
    "            best_ss = overlaps[ov_id][0]\n",
    "            best_ss_id = ov_id\n",
    "            best_ss_ind = copy.deepcopy(min_vals)\n",
    "            # print(\"Overlap: \"+str(overlaps[ov_id][0])+\" Binary: \"+str(get_binary(overlaps[ov_id][0])))\n",
    "            # print(\"Overlap length: \"+str(ov_len))\n",
    "            # print(\"Extra length: \"+str(extra_len))\n",
    "            # for i in range(experts):\n",
    "            #     print(\"Candidate: \"+str(min_vals[i])+\" Binary: \"+str(get_binary(min_vals[i])))\n",
    "            best_len = extra_len\n",
    "            #print()\n",
    "    ############## Verification #################\n",
    "    for i in range(experts):\n",
    "        assert(best_ss_ind[i] in status[i])\n",
    "    assert(overlaps[best_ss_id][0] == reduce(lambda x, y: x & y, best_ss_ind))\n",
    "    super_set = reduce(lambda x, y: x | y, best_ss_ind)\n",
    "    return super_set, best_len, best_ss_ind, best_ss\n",
    "\n",
    "\n",
    "experts = 8\n",
    "epsilon = 0.01\n",
    "count = np.zeros(16)\n",
    "\n",
    "len_stats = {}\n",
    "len_stats['source'] = []\n",
    "len_stats['tgts_nz'] = []\n",
    "len_stats['status'] = []\n",
    "len_stats['stats'] = {}\n",
    "for i in range(3):\n",
    "    len_stats['stats'][str(i)+'best'] = []\n",
    "    len_stats['stats'][str(i)+'overlap'] = []\n",
    "    len_stats['stats'][str(i)+'extra'] = []\n",
    "    len_stats['stats'][str(i)+'total'] = []\n",
    "    len_stats['stats'][str(i)+'subsets'] = []\n",
    "    len_stats['stats'][str(i)+'errors'] = []\n",
    "for it in range(len(target_wt)):\n",
    "    if(it%1==0):\n",
    "        print(it)\n",
    "    rand_vars = np.random.uniform(-1,1,15)\n",
    "    len_stats['source'].append(rand_vars)\n",
    "    targets = ((torch.abs(target_wt[it])>=epsilon)*target_wt[it])\n",
    "    ind = torch.argwhere(targets!=0.0)\n",
    "    experts = len(ind)\n",
    "    targets = targets[ind].view(-1).numpy()\n",
    "    len_stats['tgts_nz'].append(targets)\n",
    "    if(len(targets)==0):\n",
    "        print(\"Zero target\")\n",
    "        print()\n",
    "        print(\"#\"*100)\n",
    "        len_stats['status'].append(False)\n",
    "        for i in range(3):\n",
    "            len_stats['stats'][str(i)+'best'].append(None)\n",
    "            len_stats['stats'][str(i)+'overlap'].append(None)\n",
    "            len_stats['stats'][str(i)+'extra'].append(None)\n",
    "            len_stats['stats'][str(i)+'total'].append(None)\n",
    "            len_stats['stats'][str(i)+'subsets'].append(None)\n",
    "            len_stats['stats'][str(i)+'errors'].append(None)\n",
    "        continue\n",
    "    #########\n",
    "    #print(rand_vars)\n",
    "    #########\n",
    "    s = []\n",
    "    subsets(rand_vars, np.zeros_like(rand_vars))\n",
    "    status = []\n",
    "    for i in range(experts):\n",
    "        status.append(np.argwhere(np.abs(s-targets[i])<epsilon))\n",
    "    # for i in range(experts):\n",
    "    #     print(len(status[i]))\n",
    "\n",
    "    flag = False\n",
    "    for i in range(experts):\n",
    "        if(len(status[i])==0):\n",
    "            flag = True\n",
    "    if flag:\n",
    "        len_stats['status'].append(False)\n",
    "        for i in range(3):\n",
    "            len_stats['stats'][str(i)+'best'].append(None)\n",
    "            len_stats['stats'][str(i)+'overlap'].append(None)\n",
    "            len_stats['stats'][str(i)+'extra'].append(None)\n",
    "            len_stats['stats'][str(i)+'total'].append(None)\n",
    "            len_stats['stats'][str(i)+'subsets'].append(None)\n",
    "            len_stats['stats'][str(i)+'errors'].append(None)\n",
    "        continue\n",
    "    len_stats['status'].append(True)\n",
    "    #print([status[i].shape for i in range(experts)])\n",
    "    #########\n",
    "    #for i in range(experts):\n",
    "    #    print(np.sum(status[i]))\n",
    "    #########\n",
    "    # print(\"Smallest superset\")\n",
    "    a,b,c,d = find_best_subset_size(status, targets, experts, epsilon)\n",
    "    a1,b1,c1 = find_stats(a,b,c)\n",
    "    i = 0\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print(c)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    \n",
    "    # print(\"Overall optimal\")\n",
    "    a,b,c,d = find_best_overall_size(status, targets, experts, epsilon)\n",
    "    a1,b1,c1 = find_stats(a,b,c)\n",
    "    i = 1\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    \n",
    "    # print(\"Extra optimal\")\n",
    "    a,b,c,d = find_best_extra_size(status, targets, experts, epsilon)\n",
    "    a1,b1,c1 = find_stats(a,b,c)\n",
    "    i = 2\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    # print()\n",
    "    # print(\"#\"*100)\n",
    "torch.save(len_stats, \"linear_moe_ssa_approx.pth\")    \n",
    "# print(len_stats)\n",
    "#print(\"Frequency of best subset sizes: \"+str(count))\n",
    "#print(\"Average fraction of subset size compared to individual sizes: \"+str(np.mean(np.array(best_len)/np.array(total_len))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 308,
   "id": "0d86c9e4-249e-47fe-ba89-921a67e77e7f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_30836\\3790740740.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n",
      "  n &= int(n-1)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n"
     ]
    }
   ],
   "source": [
    "experts = 8\n",
    "epsilon = 0.01\n",
    "count = np.zeros(16)\n",
    "\n",
    "len_stats = {}\n",
    "len_stats['source'] = []\n",
    "len_stats['tgts_nz'] = []\n",
    "len_stats['status'] = []\n",
    "len_stats['stats'] = {}\n",
    "for i in range(3):\n",
    "    len_stats['stats'][str(i)+'best'] = []\n",
    "    len_stats['stats'][str(i)+'overlap'] = []\n",
    "    len_stats['stats'][str(i)+'extra'] = []\n",
    "    len_stats['stats'][str(i)+'total'] = []\n",
    "    len_stats['stats'][str(i)+'subsets'] = []\n",
    "    len_stats['stats'][str(i)+'errors'] = []\n",
    "for it in range(len(target_b)):\n",
    "    if(it%1==0):\n",
    "        print(it)\n",
    "    rand_vars = np.random.uniform(-1,1,15)\n",
    "    len_stats['source'].append(rand_vars)\n",
    "    targets = ((torch.abs(target_b[it])>=epsilon)*target_b[it])\n",
    "    ind = torch.argwhere(targets!=0.0)\n",
    "    experts = len(ind)\n",
    "    targets = targets[ind].view(-1).numpy()\n",
    "    len_stats['tgts_nz'].append(targets)\n",
    "    if(len(targets)==0):\n",
    "        print(\"Zero target\")\n",
    "        print()\n",
    "        print(\"#\"*100)\n",
    "        len_stats['status'].append(False)\n",
    "        for i in range(3):\n",
    "            len_stats['stats'][str(i)+'best'].append(None)\n",
    "            len_stats['stats'][str(i)+'overlap'].append(None)\n",
    "            len_stats['stats'][str(i)+'extra'].append(None)\n",
    "            len_stats['stats'][str(i)+'total'].append(None)\n",
    "            len_stats['stats'][str(i)+'subsets'].append(None)\n",
    "            len_stats['stats'][str(i)+'errors'].append(None)\n",
    "        continue\n",
    "    #########\n",
    "    #print(rand_vars)\n",
    "    #########\n",
    "    s = []\n",
    "    subsets(rand_vars, np.zeros_like(rand_vars))\n",
    "    status = []\n",
    "    for i in range(experts):\n",
    "        status.append(np.argwhere(np.abs(s-targets[i])<epsilon))\n",
    "    # for i in range(experts):\n",
    "    #     print(len(status[i]))\n",
    "\n",
    "    flag = False\n",
    "    for i in range(experts):\n",
    "        if(len(status[i])==0):\n",
    "            flag = True\n",
    "    if flag:\n",
    "        len_stats['status'].append(False)\n",
    "        for i in range(3):\n",
    "            len_stats['stats'][str(i)+'best'].append(None)\n",
    "            len_stats['stats'][str(i)+'overlap'].append(None)\n",
    "            len_stats['stats'][str(i)+'extra'].append(None)\n",
    "            len_stats['stats'][str(i)+'total'].append(None)\n",
    "            len_stats['stats'][str(i)+'subsets'].append(None)\n",
    "            len_stats['stats'][str(i)+'errors'].append(None)\n",
    "        continue\n",
    "    len_stats['status'].append(True)\n",
    "    #print([status[i].shape for i in range(experts)])\n",
    "    #########\n",
    "    #for i in range(experts):\n",
    "    #    print(np.sum(status[i]))\n",
    "    #########\n",
    "    # print(\"Smallest superset\")\n",
    "    a,b,c,d = find_best_subset_size(status, targets, experts, epsilon)\n",
    "    a1,b1,c1 = find_stats(a,b,c)\n",
    "    i = 0\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print(c)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    \n",
    "    # print(\"Overall optimal\")\n",
    "    a,b,c,d = find_best_overall_size(status, targets, experts, epsilon)\n",
    "    a1,b1,c1 = find_stats(a,b,c)\n",
    "    i = 1\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    \n",
    "    # print(\"Extra optimal\")\n",
    "    a,b,c,d = find_best_extra_size(status, targets, experts, epsilon)\n",
    "    a1,b1,c1 = find_stats(a,b,c)\n",
    "    i = 2\n",
    "    if(a is not None):\n",
    "        len_stats['stats'][str(i)+'best'].append(a1)\n",
    "        len_stats['stats'][str(i)+'overlap'].append(b1)\n",
    "        len_stats['stats'][str(i)+'extra'].append(c1)\n",
    "        len_stats['stats'][str(i)+'total'].append(b1+c1)\n",
    "        len_stats['stats'][str(i)+'subsets'].append(c)\n",
    "        err = []\n",
    "        for j in range(experts):\n",
    "            err.append(targets[j]-np.sum(get_binary(c[j])*rand_vars))\n",
    "        len_stats['stats'][str(i)+'errors'].append(err)\n",
    "    # print_stats(a,b,c)\n",
    "    # print(err)\n",
    "    # print()\n",
    "    # print(\"#\"*100)\n",
    "torch.save(len_stats, \"linear_moe_ssa_approx_bias.pth\")    \n",
    "# print(len_stats)\n",
    "#print(\"Frequency of best subset sizes: \"+str(count))\n",
    "#print(\"Average fraction of subset size compared to individual sizes: \"+str(np.mean(np.array(best_len)/np.array(total_len))))"
   ]
  },
  {
   "cell_type": "raw",
   "id": "9ec5ec9f-c913-40e9-9c9f-343f8b22e9c4",
   "metadata": {},
   "source": [
    "Approximation of the linear MoE model by SSA."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "id": "112091e6-f32b-49b4-a728-5b60a78a712c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\1203990259.py:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  lin_moe = torch.load('linear_moe.pth')\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\1203990259.py:4: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  len_stats_w = torch.load(\"linear_moe_ssa_approx_2L.pth\")\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\1203990259.py:5: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  len_stats_b = torch.load(\"linear_moe_ssa_approx_2L_bias.pth\")\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "lin_moe = torch.load('linear_moe.pth')\n",
    "len_stats_w = torch.load(\"linear_moe_ssa_approx_2L.pth\")\n",
    "len_stats_b = torch.load(\"linear_moe_ssa_approx_2L_bias.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "id": "454508d6-e492-4d7f-9767-e8f4390bcd9b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([800, 8])\n",
      "Max: tensor(0.3950) Min: tensor(-0.4134)\n",
      "tensor(4163)\n",
      "torch.Size([16, 8])\n",
      "Max: tensor(0.2096) Min: tensor(-0.2545)\n",
      "tensor(64)\n"
     ]
    }
   ],
   "source": [
    "experts = 8\n",
    "epsilon = 0.01\n",
    "scale = 10\n",
    "for i in range(experts):\n",
    "    lin_moe.models[i].conv1.weight.data *= scale\n",
    "    lin_moe.models[i].conv1.bias.data *= scale\n",
    "\n",
    "target_wt = []\n",
    "for i in range(experts):\n",
    "    wt_flatten = lin_moe.models[i].conv1.weight.data.view(-1)\n",
    "    target_wt.append(wt_flatten)\n",
    "target_wt = torch.stack(target_wt, dim=0)\n",
    "target_wt = target_wt.T\n",
    "print(target_wt.shape)\n",
    "print(\"Max: \"+str(torch.max(target_wt))+\" Min: \"+str(torch.min(target_wt)))\n",
    "print(torch.sum(np.abs(target_wt)>=epsilon))\n",
    "\n",
    "target_b = []\n",
    "for i in range(experts):\n",
    "    b_flatten = lin_moe.models[i].conv1.bias.data.view(-1)\n",
    "    target_b.append(b_flatten)\n",
    "target_b = torch.stack(target_b, dim=0)\n",
    "target_b = target_b.T\n",
    "print(target_b.shape)\n",
    "print(\"Max: \"+str(torch.max(target_b))+\" Min: \"+str(torch.min(target_b)))\n",
    "print(torch.sum(np.abs(target_b)>=epsilon))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "id": "b8ecb979-63fc-4969-8aea-262bfb1e3ff9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Verifying Weights...\n",
      "####################################################################################################\n",
      "No solution for index: 258\n",
      "####################################################################################################\n",
      "####################################################################################################\n",
      "No solution for index: 361\n",
      "####################################################################################################\n",
      "####################################################################################################\n",
      "No solution for index: 775\n",
      "####################################################################################################\n",
      "Verifying biases\n",
      "813\n"
     ]
    }
   ],
   "source": [
    "approach = 2\n",
    "def get_binary(num, digits=15):\n",
    "    bin = [0]*digits\n",
    "    start = digits-1\n",
    "    while(num>0):\n",
    "        bin[start] = num%2\n",
    "        num = num//2\n",
    "        start-=1\n",
    "    return bin\n",
    "############### VERIFICATION OF WEIGHTS #########################################\n",
    "print(\"Verifying Weights...\")\n",
    "count = 0\n",
    "for i in range(len(target_wt)):\n",
    "    if(not len_stats_w['status'][i]):\n",
    "        print(\"#\"*100)\n",
    "        print(\"No solution for index: \"+str(i))\n",
    "        print(\"#\"*100)\n",
    "        continue\n",
    "    recon_tgt = []\n",
    "    for j in range(len(len_stats_w['tgts_nz'][i])):\n",
    "        recon_tgt.append(np.sum(get_binary(len_stats_w['stats'][str(approach)+'subsets'][i][j])*len_stats_w['source'][i]*len_stats_w['first_source']))\n",
    "    if(not np.all(np.abs(np.array(len_stats_w['tgts_nz'][i])-np.array(recon_tgt))<=epsilon)):\n",
    "        count+=1\n",
    "    #assert(np.all(np.abs(np.array(len_stats_w['tgts_nz'][i])-np.array(recon_tgt))<=epsilon))\n",
    "############## VERIFICATION OF BIASES ############################################\n",
    "print(\"Verifying biases\")\n",
    "for i in range(len(target_b)):\n",
    "    recon_tgt = []\n",
    "    for j in range(len(len_stats_b['tgts_nz'][i])):\n",
    "        recon_tgt.append(np.sum(get_binary(len_stats_b['stats'][str(approach)+'subsets'][i][j])*len_stats_b['source'][i]*len_stats_b['first_source']))\n",
    "    if(not np.all(np.abs(np.array(len_stats_b['tgts_nz'][i])-np.array(recon_tgt))<=epsilon)):\n",
    "        count+=1\n",
    "    # assert(np.all(np.abs(np.array(len_stats_b['tgts_nz'][i])-np.array(recon_tgt))<=epsilon))\n",
    "print(count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 161,
   "id": "a93d8842-e91c-4e51-a152-c04f8d5ee52d",
   "metadata": {},
   "outputs": [],
   "source": [
    "wt_recon = np.zeros_like(target_wt)\n",
    "experts = 8\n",
    "for i in range(len(target_wt)):\n",
    "    if(not len_stats_w['status'][i]):\n",
    "        continue\n",
    "    ind = np.argwhere(np.abs(target_wt[i])>=epsilon)[0]\n",
    "    recon_tgt = []\n",
    "    for j in range(len(len_stats_w['tgts_nz'][i])):\n",
    "        recon_tgt.append(np.sum(get_binary(len_stats_w['stats'][str(approach)+'subsets'][i][j])*len_stats_w['source'][i]))\n",
    "    wt_recon[i][ind] = np.array(recon_tgt)\n",
    "wt_recon = torch.Tensor(wt_recon)\n",
    "    \n",
    "b_recon = np.zeros_like(target_b)\n",
    "experts = 8\n",
    "for i in range(len(target_b)):\n",
    "    if(not len_stats_b['status'][i]):\n",
    "        continue\n",
    "    ind = np.argwhere(np.abs(target_b[i])>=epsilon)[0]\n",
    "    recon_tgt = []\n",
    "    for j in range(len(len_stats_b['tgts_nz'][i])):\n",
    "        recon_tgt.append(np.sum(get_binary(len_stats_b['stats'][str(approach)+'subsets'][i][j])*len_stats_b['source'][i]))\n",
    "    b_recon[i][ind] = np.array(recon_tgt)\n",
    "b_recon = torch.Tensor(b_recon)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 163,
   "id": "78d621d1-fbb0-49dc-962a-b9ae42d5a0b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\3734345243.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  lin_moe_normal = torch.load('linear_moe.pth')\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\3734345243.py:7: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  lin_moe_recon = torch.load('linear_moe.pth')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performance of Normal networks\n",
      "Accuracy of the network on the 16000 test images: 95.4813 %\n",
      "Accuracy of the network on the 16000 test images: 94.2625 %\n",
      "Performance of Reconstructed networks\n",
      "Accuracy of the network on the 16000 test images: 95.2438 %\n",
      "Accuracy of the network on the 16000 test images: 94.0875 %\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "94.0875"
      ]
     },
     "execution_count": 163,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lin_moe_normal = torch.load('linear_moe.pth')\n",
    "print(\"Performance of Normal networks\")\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "test(lin_moe_normal, criterion, training_data, training_labels)\n",
    "test(lin_moe_normal, criterion, test_data, test_labels)\n",
    "\n",
    "lin_moe_recon = torch.load('linear_moe.pth')\n",
    "\n",
    "for i in range(experts):\n",
    "    lin_moe_recon.models[i].conv1.weight.data = wt_recon[:,i].view(lin_moe_normal.models[i].conv1.weight.shape)\n",
    "    lin_moe_recon.models[i].conv1.bias.data = b_recon[:,i].view(lin_moe_normal.models[i].conv1.bias.shape)\n",
    "print(\"Performance of Reconstructed networks\")\n",
    "test(lin_moe_recon, criterion, training_data, training_labels)\n",
    "test(lin_moe_recon, criterion, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "id": "4bf88567-4c29-4233-81a0-195ee9793079",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\3903283536.py:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  lin_moe = torch.load('linear_moe.pth')\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\3903283536.py:4: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  len_stats_w = torch.load(\"linear_moe_ssa_approx_2L_diff.pth\")\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\3903283536.py:5: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  len_stats_b = torch.load(\"linear_moe_ssa_approx_2L_diff_bias.pth\")\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "lin_moe = torch.load('linear_moe.pth')\n",
    "len_stats_w = torch.load(\"linear_moe_ssa_approx_2L_diff.pth\")\n",
    "len_stats_b = torch.load(\"linear_moe_ssa_approx_2L_diff_bias.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "id": "55300f2a-fe3a-4ba6-903f-b2505d5786e4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([800, 8])\n",
      "Max: tensor(0.3950) Min: tensor(-0.4134)\n",
      "tensor(4163)\n",
      "torch.Size([16, 8])\n",
      "Max: tensor(0.2096) Min: tensor(-0.2545)\n",
      "tensor(64)\n"
     ]
    }
   ],
   "source": [
    "experts = 8\n",
    "epsilon = 0.01\n",
    "scale = 10\n",
    "for i in range(experts):\n",
    "    lin_moe.models[i].conv1.weight.data *= scale\n",
    "    lin_moe.models[i].conv1.bias.data *= scale\n",
    "\n",
    "target_wt = []\n",
    "for i in range(experts):\n",
    "    wt_flatten = lin_moe.models[i].conv1.weight.data.view(-1)\n",
    "    target_wt.append(wt_flatten)\n",
    "target_wt = torch.stack(target_wt, dim=0)\n",
    "target_wt = target_wt.T\n",
    "print(target_wt.shape)\n",
    "print(\"Max: \"+str(torch.max(target_wt))+\" Min: \"+str(torch.min(target_wt)))\n",
    "print(torch.sum(np.abs(target_wt)>=epsilon))\n",
    "\n",
    "target_b = []\n",
    "for i in range(experts):\n",
    "    b_flatten = lin_moe.models[i].conv1.bias.data.view(-1)\n",
    "    target_b.append(b_flatten)\n",
    "target_b = torch.stack(target_b, dim=0)\n",
    "target_b = target_b.T\n",
    "print(target_b.shape)\n",
    "print(\"Max: \"+str(torch.max(target_b))+\" Min: \"+str(torch.min(target_b)))\n",
    "print(torch.sum(np.abs(target_b)>=epsilon))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 173,
   "id": "30edc5bd-afdc-4198-a358-7858a83bb9fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Verifying Weights...\n",
      "####################################################################################################\n",
      "No solution for index: 93\n",
      "####################################################################################################\n",
      "Verifying biases\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "approach = 2\n",
    "def get_binary(num, digits=15):\n",
    "    bin = [0]*digits\n",
    "    start = digits-1\n",
    "    while(num>0):\n",
    "        bin[start] = num%2\n",
    "        num = num//2\n",
    "        start-=1\n",
    "    return bin\n",
    "############### VERIFICATION OF WEIGHTS #########################################\n",
    "print(\"Verifying Weights...\")\n",
    "count = 0\n",
    "for i in range(len(target_wt)):\n",
    "    if(not len_stats_w['status'][i]):\n",
    "        print(\"#\"*100)\n",
    "        print(\"No solution for index: \"+str(i))\n",
    "        print(\"#\"*100)\n",
    "        continue\n",
    "    recon_tgt = []\n",
    "    for j in range(len(len_stats_w['tgts_nz'][i])):\n",
    "        recon_tgt.append(np.sum(get_binary(len_stats_w['stats'][str(approach)+'subsets'][i][j])*len_stats_w['source'][i][j]))\n",
    "    if(not np.all(np.abs(np.array(len_stats_w['tgts_nz'][i])-np.array(recon_tgt))<=epsilon)):\n",
    "        count+=1\n",
    "    #assert(np.all(np.abs(np.array(len_stats_w['tgts_nz'][i])-np.array(recon_tgt))<=epsilon))\n",
    "############## VERIFICATION OF BIASES ############################################\n",
    "print(\"Verifying biases\")\n",
    "for i in range(len(target_b)):\n",
    "    recon_tgt = []\n",
    "    for j in range(len(len_stats_b['tgts_nz'][i])):\n",
    "        recon_tgt.append(np.sum(get_binary(len_stats_b['stats'][str(approach)+'subsets'][i][j])*len_stats_b['source'][i][j]))\n",
    "    if(not np.all(np.abs(np.array(len_stats_b['tgts_nz'][i])-np.array(recon_tgt))<=epsilon)):\n",
    "        count+=1\n",
    "    # assert(np.all(np.abs(np.array(len_stats_b['tgts_nz'][i])-np.array(recon_tgt))<=epsilon))\n",
    "print(count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "id": "73ca3747-5ff3-40e7-bcaf-e7f4136f3a65",
   "metadata": {},
   "outputs": [],
   "source": [
    "wt_recon = np.zeros_like(target_wt)\n",
    "experts = 8\n",
    "for i in range(len(target_wt)):\n",
    "    if(not len_stats_w['status'][i]):\n",
    "        continue\n",
    "    ind = np.argwhere(np.abs(target_wt[i])>=epsilon)[0]\n",
    "    recon_tgt = []\n",
    "    for j in range(len(len_stats_w['tgts_nz'][i])):\n",
    "        recon_tgt.append(np.sum(get_binary(len_stats_w['stats'][str(approach)+'subsets'][i][j])*len_stats_w['source'][i][j]))\n",
    "    wt_recon[i][ind] = np.array(recon_tgt)\n",
    "wt_recon = torch.Tensor(wt_recon)\n",
    "    \n",
    "b_recon = np.zeros_like(target_b)\n",
    "experts = 8\n",
    "for i in range(len(target_b)):\n",
    "    if(not len_stats_b['status'][i]):\n",
    "        continue\n",
    "    ind = np.argwhere(np.abs(target_b[i])>=epsilon)[0]\n",
    "    recon_tgt = []\n",
    "    for j in range(len(len_stats_b['tgts_nz'][i])):\n",
    "        recon_tgt.append(np.sum(get_binary(len_stats_b['stats'][str(approach)+'subsets'][i][j])*len_stats_b['source'][i][j]))\n",
    "    b_recon[i][ind] = np.array(recon_tgt)\n",
    "b_recon = torch.Tensor(b_recon)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "id": "2c4e669a-4da3-498b-98cc-5cfa9b55f2bf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\3734345243.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  lin_moe_normal = torch.load('linear_moe.pth')\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_26176\\3734345243.py:7: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  lin_moe_recon = torch.load('linear_moe.pth')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performance of Normal networks\n",
      "Accuracy of the network on the 16000 test images: 95.4750 %\n",
      "Accuracy of the network on the 16000 test images: 94.0625 %\n",
      "Performance of Reconstructed networks\n",
      "Accuracy of the network on the 16000 test images: 95.3812 %\n",
      "Accuracy of the network on the 16000 test images: 94.0625 %\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "94.0625"
      ]
     },
     "execution_count": 177,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lin_moe_normal = torch.load('linear_moe.pth')\n",
    "print(\"Performance of Normal networks\")\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "test(lin_moe_normal, criterion, training_data, training_labels)\n",
    "test(lin_moe_normal, criterion, test_data, test_labels)\n",
    "\n",
    "lin_moe_recon = torch.load('linear_moe.pth')\n",
    "\n",
    "for i in range(experts):\n",
    "    lin_moe_recon.models[i].conv1.weight.data = wt_recon[:,i].view(lin_moe_normal.models[i].conv1.weight.shape)\n",
    "    lin_moe_recon.models[i].conv1.bias.data = b_recon[:,i].view(lin_moe_normal.models[i].conv1.bias.shape)\n",
    "print(\"Performance of Reconstructed networks\")\n",
    "test(lin_moe_recon, criterion, training_data, training_labels)\n",
    "test(lin_moe_recon, criterion, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "5ff08256-2000-4525-abea-790d195c22bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "def subsets(arr,status,curr = 0):\n",
    "    global s\n",
    "    if(curr>=len(arr)):\n",
    "        s.append(np.sum(arr*status))\n",
    "        return\n",
    "    subsets(arr,status,curr+1)\n",
    "    status[curr] = 1\n",
    "    subsets(arr,status,curr+1)\n",
    "    status[curr] = 0\n",
    "\n",
    "s = []\n",
    "epsilon = 0.001\n",
    "rand_vars = 2**np.arange(11)*epsilon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "7e5c7a80-a076-4c03-982a-9e552e8c0eb7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1.000e-03, 2.000e-03, 4.000e-03, 8.000e-03, 1.600e-02, 3.200e-02,\n",
       "       6.400e-02, 1.280e-01, 2.560e-01, 5.120e-01, 1.024e+00])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rand_vars"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "b437c62f-0e48-4f2c-b1d3-c0bec4cb7031",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_17696\\2897033741.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  mdl = torch.load(\"linear_moe.pth\")\n",
      "C:\\Users\\RahulN\\AppData\\Local\\Temp\\ipykernel_17696\\2897033741.py:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  ssa_approx = torch.load(\"linear_moe_ssa_approx.pth\")\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "mdl = torch.load(\"linear_moe.pth\")\n",
    "ssa_approx = torch.load(\"linear_moe_ssa_approx.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "fca105b8-b2c1-47f0-b2cf-ef6e7f7df92e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 1, 50])\n",
      "torch.Size([16, 1, 50])\n",
      "torch.Size([16, 1, 50])\n",
      "torch.Size([16, 1, 50])\n",
      "torch.Size([16, 1, 50])\n",
      "torch.Size([16, 1, 50])\n",
      "torch.Size([16, 1, 50])\n",
      "torch.Size([16, 1, 50])\n"
     ]
    }
   ],
   "source": [
    "params = []\n",
    "for i in range(8):\n",
    "    params.append(mdl.models[i].conv1.weight.data.view(-1,16,1,50))\n",
    "    print(mdl.models[i].conv1.weight.data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "13f6526e-d9c4-4e6b-aee0-1851d1266b2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = torch.vstack(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "e90d8ffd-d061-45f6-9be0-2881166136ee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 16, 1, 50])"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "params.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "bf2731db-82da-4fbf-9927-3dd77135fd0c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Range: (tensor(0.0557), tensor(-0.2439))\n",
      "Range difference: tensor(0.2995)\n",
      "Range: (tensor(0.1498), tensor(-0.0139))\n",
      "Range difference: tensor(0.1638)\n",
      "Range: (tensor(0.0581), tensor(-0.1043))\n",
      "Range difference: tensor(0.1624)\n",
      "Range: (tensor(0.1967), tensor(-0.0847))\n",
      "Range difference: tensor(0.2813)\n",
      "Range: (tensor(0.0386), tensor(-0.1641))\n",
      "Range difference: tensor(0.2027)\n",
      "Range: (tensor(0.1842), tensor(-0.0073))\n",
      "Range difference: tensor(0.1915)\n",
      "Range: (tensor(0.2527), tensor(-0.0870))\n",
      "Range difference: tensor(0.3397)\n",
      "Range: (tensor(0.0637), tensor(-0.0636))\n",
      "Range difference: tensor(0.1273)\n",
      "Range: (tensor(0.0279), tensor(-0.1666))\n",
      "Range difference: tensor(0.1945)\n",
      "Range: (tensor(0.0361), tensor(-0.0492))\n",
      "Range difference: tensor(0.0853)\n",
      "Range: (tensor(0.0084), tensor(-0.1338))\n",
      "Range difference: tensor(0.1422)\n",
      "Range: (tensor(0.1124), tensor(-0.4134))\n",
      "Range difference: tensor(0.5258)\n",
      "Range: (tensor(0.0091), tensor(-0.1512))\n",
      "Range difference: tensor(0.1602)\n",
      "Range: (tensor(0.0551), tensor(-0.0402))\n",
      "Range difference: tensor(0.0953)\n",
      "Range: (tensor(0.1081), tensor(-0.0789))\n",
      "Range difference: tensor(0.1870)\n",
      "Range: (tensor(0.3129), tensor(-0.0309))\n",
      "Range difference: tensor(0.3439)\n",
      "Range: (tensor(0.0673), tensor(-0.0474))\n",
      "Range difference: tensor(0.1147)\n",
      "Range: (tensor(0.0226), tensor(-0.0376))\n",
      "Range difference: tensor(0.0602)\n",
      "Range: (tensor(0.1004), tensor(-0.1062))\n",
      "Range difference: tensor(0.2066)\n",
      "Range: (tensor(0.0623), tensor(-0.1202))\n",
      "Range difference: tensor(0.1825)\n",
      "Range: (tensor(0.0161), tensor(-0.1057))\n",
      "Range difference: tensor(0.1218)\n",
      "Range: (tensor(0.2642), tensor(-0.1201))\n",
      "Range difference: tensor(0.3843)\n",
      "Range: (tensor(0.1008), tensor(-0.1700))\n",
      "Range difference: tensor(0.2708)\n",
      "Range: (tensor(0.1599), tensor(-0.0592))\n",
      "Range difference: tensor(0.2190)\n",
      "Range: (tensor(0.0534), tensor(-0.0623))\n",
      "Range difference: tensor(0.1157)\n",
      "Range: (tensor(0.1918), tensor(-0.0318))\n",
      "Range difference: tensor(0.2236)\n",
      "Range: (tensor(0.0434), tensor(-0.0048))\n",
      "Range difference: tensor(0.0482)\n",
      "Range: (tensor(0.0823), tensor(-0.0206))\n",
      "Range difference: tensor(0.1029)\n",
      "Range: (tensor(0.1407), tensor(0.0036))\n",
      "Range difference: tensor(0.1370)\n",
      "Range: (tensor(0.0803), tensor(-0.1648))\n",
      "Range difference: tensor(0.2451)\n",
      "Range: (tensor(0.0637), tensor(-0.0052))\n",
      "Range difference: tensor(0.0689)\n",
      "Range: (tensor(0.0746), tensor(-0.1270))\n",
      "Range difference: tensor(0.2016)\n",
      "Range: (tensor(0.0028), tensor(-0.0627))\n",
      "Range difference: tensor(0.0655)\n",
      "Range: (tensor(0.0682), tensor(-0.0764))\n",
      "Range difference: tensor(0.1446)\n",
      "Range: (tensor(0.0802), tensor(-0.0528))\n",
      "Range difference: tensor(0.1330)\n",
      "Range: (tensor(0.1006), tensor(-0.3654))\n",
      "Range difference: tensor(0.4660)\n",
      "Range: (tensor(0.0449), tensor(-0.0224))\n",
      "Range difference: tensor(0.0673)\n",
      "Range: (tensor(0.0745), tensor(-0.0599))\n",
      "Range difference: tensor(0.1345)\n",
      "Range: (tensor(0.0609), tensor(-0.0255))\n",
      "Range difference: tensor(0.0864)\n",
      "Range: (tensor(0.3950), tensor(-0.0424))\n",
      "Range difference: tensor(0.4374)\n",
      "Range: (tensor(0.0461), tensor(-0.0411))\n",
      "Range difference: tensor(0.0871)\n",
      "Range: (tensor(0.0964), tensor(-0.0579))\n",
      "Range difference: tensor(0.1542)\n",
      "Range: (tensor(0.1177), tensor(-0.0669))\n",
      "Range difference: tensor(0.1846)\n",
      "Range: (tensor(0.0297), tensor(-0.0564))\n",
      "Range difference: tensor(0.0861)\n",
      "Range: (tensor(0.0450), tensor(-0.0511))\n",
      "Range difference: tensor(0.0961)\n",
      "Range: (tensor(0.1946), tensor(-0.0641))\n",
      "Range difference: tensor(0.2587)\n",
      "Range: (tensor(0.0353), tensor(-0.1120))\n",
      "Range difference: tensor(0.1473)\n",
      "Range: (tensor(0.0587), tensor(-0.0712))\n",
      "Range difference: tensor(0.1299)\n",
      "Range: (tensor(0.2452), tensor(-0.0021))\n",
      "Range difference: tensor(0.2473)\n",
      "Range: (tensor(0.1122), tensor(-0.1157))\n",
      "Range difference: tensor(0.2279)\n",
      "Range: (tensor(0.0582), tensor(-0.2437))\n",
      "Range difference: tensor(0.3018)\n",
      "Range: (tensor(0.1473), tensor(-0.0138))\n",
      "Range difference: tensor(0.1611)\n",
      "Range: (tensor(0.0570), tensor(-0.1030))\n",
      "Range difference: tensor(0.1600)\n",
      "Range: (tensor(0.1969), tensor(-0.0842))\n",
      "Range difference: tensor(0.2811)\n",
      "Range: (tensor(0.0380), tensor(-0.1643))\n",
      "Range difference: tensor(0.2023)\n",
      "Range: (tensor(0.1838), tensor(-0.0083))\n",
      "Range difference: tensor(0.1921)\n",
      "Range: (tensor(0.2526), tensor(-0.0871))\n",
      "Range difference: tensor(0.3396)\n",
      "Range: (tensor(0.0623), tensor(-0.0652))\n",
      "Range difference: tensor(0.1275)\n",
      "Range: (tensor(0.0277), tensor(-0.1677))\n",
      "Range difference: tensor(0.1954)\n",
      "Range: (tensor(0.0352), tensor(-0.0490))\n",
      "Range difference: tensor(0.0842)\n",
      "Range: (tensor(0.0102), tensor(-0.1353))\n",
      "Range difference: tensor(0.1455)\n",
      "Range: (tensor(0.1140), tensor(-0.4113))\n",
      "Range difference: tensor(0.5254)\n",
      "Range: (tensor(0.0084), tensor(-0.1522))\n",
      "Range difference: tensor(0.1606)\n",
      "Range: (tensor(0.0543), tensor(-0.0405))\n",
      "Range difference: tensor(0.0948)\n",
      "Range: (tensor(0.1080), tensor(-0.0794))\n",
      "Range difference: tensor(0.1874)\n",
      "Range: (tensor(0.3149), tensor(-0.0288))\n",
      "Range difference: tensor(0.3438)\n",
      "Range: (tensor(0.0693), tensor(-0.0474))\n",
      "Range difference: tensor(0.1167)\n",
      "Range: (tensor(0.0225), tensor(-0.0374))\n",
      "Range difference: tensor(0.0599)\n",
      "Range: (tensor(0.0990), tensor(-0.1055))\n",
      "Range difference: tensor(0.2046)\n",
      "Range: (tensor(0.0609), tensor(-0.1199))\n",
      "Range difference: tensor(0.1808)\n",
      "Range: (tensor(0.0160), tensor(-0.1076))\n",
      "Range difference: tensor(0.1236)\n",
      "Range: (tensor(0.2663), tensor(-0.1196))\n",
      "Range difference: tensor(0.3859)\n",
      "Range: (tensor(0.0995), tensor(-0.1699))\n",
      "Range difference: tensor(0.2694)\n",
      "Range: (tensor(0.1596), tensor(-0.0592))\n",
      "Range difference: tensor(0.2187)\n",
      "Range: (tensor(0.0538), tensor(-0.0645))\n",
      "Range difference: tensor(0.1183)\n",
      "Range: (tensor(0.1924), tensor(-0.0311))\n",
      "Range difference: tensor(0.2234)\n",
      "Range: (tensor(0.0459), tensor(-0.0061))\n",
      "Range difference: tensor(0.0520)\n",
      "Range: (tensor(0.0836), tensor(-0.0200))\n",
      "Range difference: tensor(0.1037)\n",
      "Range: (tensor(0.1399), tensor(0.0030))\n",
      "Range difference: tensor(0.1369)\n",
      "Range: (tensor(0.0797), tensor(-0.1649))\n",
      "Range difference: tensor(0.2446)\n",
      "Range: (tensor(0.0646), tensor(-0.0054))\n",
      "Range difference: tensor(0.0700)\n",
      "Range: (tensor(0.0771), tensor(-0.1275))\n",
      "Range difference: tensor(0.2046)\n",
      "Range: (tensor(0.0020), tensor(-0.0650))\n",
      "Range difference: tensor(0.0669)\n",
      "Range: (tensor(0.0671), tensor(-0.0753))\n",
      "Range difference: tensor(0.1424)\n",
      "Range: (tensor(0.0786), tensor(-0.0538))\n",
      "Range difference: tensor(0.1324)\n",
      "Range: (tensor(0.1027), tensor(-0.3660))\n",
      "Range difference: tensor(0.4687)\n",
      "Range: (tensor(0.0447), tensor(-0.0232))\n",
      "Range difference: tensor(0.0679)\n",
      "Range: (tensor(0.0750), tensor(-0.0618))\n",
      "Range difference: tensor(0.1368)\n",
      "Range: (tensor(0.0600), tensor(-0.0258))\n",
      "Range difference: tensor(0.0858)\n",
      "Range: (tensor(0.3930), tensor(-0.0448))\n",
      "Range difference: tensor(0.4378)\n",
      "Range: (tensor(0.0473), tensor(-0.0400))\n",
      "Range difference: tensor(0.0873)\n",
      "Range: (tensor(0.0972), tensor(-0.0581))\n",
      "Range difference: tensor(0.1553)\n",
      "Range: (tensor(0.1178), tensor(-0.0665))\n",
      "Range difference: tensor(0.1842)\n",
      "Range: (tensor(0.0293), tensor(-0.0576))\n",
      "Range difference: tensor(0.0869)\n",
      "Range: (tensor(0.0472), tensor(-0.0491))\n",
      "Range difference: tensor(0.0963)\n",
      "Range: (tensor(0.1938), tensor(-0.0628))\n",
      "Range difference: tensor(0.2566)\n",
      "Range: (tensor(0.0357), tensor(-0.1106))\n",
      "Range difference: tensor(0.1463)\n",
      "Range: (tensor(0.0589), tensor(-0.0723))\n",
      "Range difference: tensor(0.1312)\n",
      "Range: (tensor(0.2437), tensor(-0.0017))\n",
      "Range difference: tensor(0.2454)\n",
      "Range: (tensor(0.1125), tensor(-0.1152))\n",
      "Range difference: tensor(0.2278)\n",
      "Range: (tensor(0.0583), tensor(-0.2437))\n",
      "Range difference: tensor(0.3020)\n",
      "Range: (tensor(0.1495), tensor(-0.0131))\n",
      "Range difference: tensor(0.1626)\n",
      "Range: (tensor(0.0576), tensor(-0.1025))\n",
      "Range difference: tensor(0.1601)\n",
      "Range: (tensor(0.1967), tensor(-0.0843))\n",
      "Range difference: tensor(0.2810)\n",
      "Range: (tensor(0.0395), tensor(-0.1642))\n",
      "Range difference: tensor(0.2037)\n",
      "Range: (tensor(0.1827), tensor(-0.0089))\n",
      "Range difference: tensor(0.1915)\n",
      "Range: (tensor(0.2544), tensor(-0.0875))\n",
      "Range difference: tensor(0.3419)\n",
      "Range: (tensor(0.0642), tensor(-0.0643))\n",
      "Range difference: tensor(0.1285)\n",
      "Range: (tensor(0.0289), tensor(-0.1690))\n",
      "Range difference: tensor(0.1979)\n",
      "Range: (tensor(0.0359), tensor(-0.0494))\n",
      "Range difference: tensor(0.0853)\n",
      "Range: (tensor(0.0097), tensor(-0.1360))\n",
      "Range difference: tensor(0.1457)\n",
      "Range: (tensor(0.1142), tensor(-0.4119))\n",
      "Range difference: tensor(0.5261)\n",
      "Range: (tensor(0.0082), tensor(-0.1519))\n",
      "Range difference: tensor(0.1601)\n",
      "Range: (tensor(0.0553), tensor(-0.0409))\n",
      "Range difference: tensor(0.0962)\n",
      "Range: (tensor(0.1085), tensor(-0.0787))\n",
      "Range difference: tensor(0.1872)\n",
      "Range: (tensor(0.3139), tensor(-0.0302))\n",
      "Range difference: tensor(0.3441)\n",
      "Range: (tensor(0.0672), tensor(-0.0469))\n",
      "Range difference: tensor(0.1142)\n",
      "Range: (tensor(0.0221), tensor(-0.0392))\n",
      "Range difference: tensor(0.0613)\n",
      "Range: (tensor(0.1013), tensor(-0.1062))\n",
      "Range difference: tensor(0.2076)\n",
      "Range: (tensor(0.0605), tensor(-0.1210))\n",
      "Range difference: tensor(0.1815)\n",
      "Range: (tensor(0.0144), tensor(-0.1059))\n",
      "Range difference: tensor(0.1203)\n",
      "Range: (tensor(0.2654), tensor(-0.1192))\n",
      "Range difference: tensor(0.3846)\n",
      "Range: (tensor(0.0990), tensor(-0.1711))\n",
      "Range difference: tensor(0.2701)\n",
      "Range: (tensor(0.1594), tensor(-0.0584))\n",
      "Range difference: tensor(0.2179)\n",
      "Range: (tensor(0.0534), tensor(-0.0632))\n",
      "Range difference: tensor(0.1165)\n",
      "Range: (tensor(0.1923), tensor(-0.0318))\n",
      "Range difference: tensor(0.2241)\n",
      "Range: (tensor(0.0457), tensor(-0.0047))\n",
      "Range difference: tensor(0.0504)\n",
      "Range: (tensor(0.0828), tensor(-0.0197))\n",
      "Range difference: tensor(0.1025)\n",
      "Range: (tensor(0.1394), tensor(0.0037))\n",
      "Range difference: tensor(0.1357)\n",
      "Range: (tensor(0.0791), tensor(-0.1638))\n",
      "Range difference: tensor(0.2429)\n",
      "Range: (tensor(0.0630), tensor(-0.0043))\n",
      "Range difference: tensor(0.0673)\n",
      "Range: (tensor(0.0762), tensor(-0.1276))\n",
      "Range difference: tensor(0.2038)\n",
      "Range: (tensor(0.0043), tensor(-0.0632))\n",
      "Range difference: tensor(0.0675)\n",
      "Range: (tensor(0.0671), tensor(-0.0751))\n",
      "Range difference: tensor(0.1422)\n",
      "Range: (tensor(0.0807), tensor(-0.0530))\n",
      "Range difference: tensor(0.1337)\n",
      "Range: (tensor(0.1020), tensor(-0.3642))\n",
      "Range difference: tensor(0.4662)\n",
      "Range: (tensor(0.0459), tensor(-0.0241))\n",
      "Range difference: tensor(0.0700)\n",
      "Range: (tensor(0.0747), tensor(-0.0613))\n",
      "Range difference: tensor(0.1360)\n",
      "Range: (tensor(0.0588), tensor(-0.0240))\n",
      "Range difference: tensor(0.0828)\n",
      "Range: (tensor(0.3924), tensor(-0.0422))\n",
      "Range difference: tensor(0.4346)\n",
      "Range: (tensor(0.0455), tensor(-0.0412))\n",
      "Range difference: tensor(0.0868)\n",
      "Range: (tensor(0.0971), tensor(-0.0582))\n",
      "Range difference: tensor(0.1553)\n",
      "Range: (tensor(0.1195), tensor(-0.0651))\n",
      "Range difference: tensor(0.1846)\n",
      "Range: (tensor(0.0311), tensor(-0.0568))\n",
      "Range difference: tensor(0.0879)\n",
      "Range: (tensor(0.0455), tensor(-0.0500))\n",
      "Range difference: tensor(0.0955)\n",
      "Range: (tensor(0.1960), tensor(-0.0623))\n",
      "Range difference: tensor(0.2584)\n",
      "Range: (tensor(0.0347), tensor(-0.1122))\n",
      "Range difference: tensor(0.1469)\n",
      "Range: (tensor(0.0584), tensor(-0.0732))\n",
      "Range difference: tensor(0.1317)\n",
      "Range: (tensor(0.2446), tensor(-0.0013))\n",
      "Range difference: tensor(0.2460)\n",
      "Range: (tensor(0.1125), tensor(-0.1167))\n",
      "Range difference: tensor(0.2293)\n",
      "Range: (tensor(0.0574), tensor(-0.2441))\n",
      "Range difference: tensor(0.3016)\n",
      "Range: (tensor(0.1497), tensor(-0.0130))\n",
      "Range difference: tensor(0.1628)\n",
      "Range: (tensor(0.0583), tensor(-0.1019))\n",
      "Range difference: tensor(0.1602)\n",
      "Range: (tensor(0.1961), tensor(-0.0845))\n",
      "Range difference: tensor(0.2806)\n",
      "Range: (tensor(0.0393), tensor(-0.1633))\n",
      "Range difference: tensor(0.2026)\n",
      "Range: (tensor(0.1840), tensor(-0.0068))\n",
      "Range difference: tensor(0.1907)\n",
      "Range: (tensor(0.2524), tensor(-0.0874))\n",
      "Range difference: tensor(0.3398)\n",
      "Range: (tensor(0.0637), tensor(-0.0637))\n",
      "Range difference: tensor(0.1274)\n",
      "Range: (tensor(0.0286), tensor(-0.1676))\n",
      "Range difference: tensor(0.1962)\n",
      "Range: (tensor(0.0368), tensor(-0.0502))\n",
      "Range difference: tensor(0.0870)\n",
      "Range: (tensor(0.0105), tensor(-0.1345))\n",
      "Range difference: tensor(0.1451)\n",
      "Range: (tensor(0.1135), tensor(-0.4133))\n",
      "Range difference: tensor(0.5267)\n",
      "Range: (tensor(0.0081), tensor(-0.1520))\n",
      "Range difference: tensor(0.1601)\n",
      "Range: (tensor(0.0550), tensor(-0.0395))\n",
      "Range difference: tensor(0.0945)\n",
      "Range: (tensor(0.1065), tensor(-0.0799))\n",
      "Range difference: tensor(0.1864)\n",
      "Range: (tensor(0.3127), tensor(-0.0305))\n",
      "Range difference: tensor(0.3432)\n",
      "Range: (tensor(0.0682), tensor(-0.0478))\n",
      "Range difference: tensor(0.1160)\n",
      "Range: (tensor(0.0221), tensor(-0.0383))\n",
      "Range difference: tensor(0.0604)\n",
      "Range: (tensor(0.0999), tensor(-0.1059))\n",
      "Range difference: tensor(0.2058)\n",
      "Range: (tensor(0.0630), tensor(-0.1217))\n",
      "Range difference: tensor(0.1847)\n",
      "Range: (tensor(0.0142), tensor(-0.1077))\n",
      "Range difference: tensor(0.1219)\n",
      "Range: (tensor(0.2658), tensor(-0.1199))\n",
      "Range difference: tensor(0.3857)\n",
      "Range: (tensor(0.0989), tensor(-0.1698))\n",
      "Range difference: tensor(0.2687)\n",
      "Range: (tensor(0.1571), tensor(-0.0575))\n",
      "Range difference: tensor(0.2145)\n",
      "Range: (tensor(0.0534), tensor(-0.0646))\n",
      "Range difference: tensor(0.1180)\n",
      "Range: (tensor(0.1907), tensor(-0.0305))\n",
      "Range difference: tensor(0.2211)\n",
      "Range: (tensor(0.0453), tensor(-0.0052))\n",
      "Range difference: tensor(0.0504)\n",
      "Range: (tensor(0.0837), tensor(-0.0194))\n",
      "Range difference: tensor(0.1032)\n",
      "Range: (tensor(0.1395), tensor(0.0039))\n",
      "Range difference: tensor(0.1356)\n",
      "Range: (tensor(0.0795), tensor(-0.1644))\n",
      "Range difference: tensor(0.2439)\n",
      "Range: (tensor(0.0638), tensor(-0.0034))\n",
      "Range difference: tensor(0.0671)\n",
      "Range: (tensor(0.0744), tensor(-0.1261))\n",
      "Range difference: tensor(0.2005)\n",
      "Range: (tensor(0.0045), tensor(-0.0630))\n",
      "Range difference: tensor(0.0675)\n",
      "Range: (tensor(0.0691), tensor(-0.0753))\n",
      "Range difference: tensor(0.1444)\n",
      "Range: (tensor(0.0792), tensor(-0.0543))\n",
      "Range difference: tensor(0.1335)\n",
      "Range: (tensor(0.1014), tensor(-0.3640))\n",
      "Range difference: tensor(0.4654)\n",
      "Range: (tensor(0.0459), tensor(-0.0230))\n",
      "Range difference: tensor(0.0690)\n",
      "Range: (tensor(0.0745), tensor(-0.0608))\n",
      "Range difference: tensor(0.1353)\n",
      "Range: (tensor(0.0598), tensor(-0.0256))\n",
      "Range difference: tensor(0.0853)\n",
      "Range: (tensor(0.3942), tensor(-0.0442))\n",
      "Range difference: tensor(0.4385)\n",
      "Range: (tensor(0.0452), tensor(-0.0395))\n",
      "Range difference: tensor(0.0847)\n",
      "Range: (tensor(0.0962), tensor(-0.0580))\n",
      "Range difference: tensor(0.1542)\n",
      "Range: (tensor(0.1188), tensor(-0.0660))\n",
      "Range difference: tensor(0.1849)\n",
      "Range: (tensor(0.0299), tensor(-0.0573))\n",
      "Range difference: tensor(0.0872)\n",
      "Range: (tensor(0.0461), tensor(-0.0506))\n",
      "Range difference: tensor(0.0967)\n",
      "Range: (tensor(0.1960), tensor(-0.0625))\n",
      "Range difference: tensor(0.2585)\n",
      "Range: (tensor(0.0347), tensor(-0.1128))\n",
      "Range difference: tensor(0.1475)\n",
      "Range: (tensor(0.0608), tensor(-0.0713))\n",
      "Range difference: tensor(0.1320)\n",
      "Range: (tensor(0.2428), tensor(-0.0021))\n",
      "Range difference: tensor(0.2449)\n",
      "Range: (tensor(0.1128), tensor(-0.1154))\n",
      "Range difference: tensor(0.2283)\n",
      "Range: (tensor(0.0566), tensor(-0.2433))\n",
      "Range difference: tensor(0.3000)\n",
      "Range: (tensor(0.1490), tensor(-0.0141))\n",
      "Range difference: tensor(0.1631)\n",
      "Range: (tensor(0.0592), tensor(-0.1040))\n",
      "Range difference: tensor(0.1632)\n",
      "Range: (tensor(0.1976), tensor(-0.0841))\n",
      "Range difference: tensor(0.2817)\n",
      "Range: (tensor(0.0388), tensor(-0.1643))\n",
      "Range difference: tensor(0.2031)\n",
      "Range: (tensor(0.1826), tensor(-0.0088))\n",
      "Range difference: tensor(0.1915)\n",
      "Range: (tensor(0.2540), tensor(-0.0871))\n",
      "Range difference: tensor(0.3411)\n",
      "Range: (tensor(0.0631), tensor(-0.0647))\n",
      "Range difference: tensor(0.1278)\n",
      "Range: (tensor(0.0270), tensor(-0.1680))\n",
      "Range difference: tensor(0.1950)\n",
      "Range: (tensor(0.0341), tensor(-0.0493))\n",
      "Range difference: tensor(0.0835)\n",
      "Range: (tensor(0.0093), tensor(-0.1355))\n",
      "Range difference: tensor(0.1447)\n",
      "Range: (tensor(0.1143), tensor(-0.4111))\n",
      "Range difference: tensor(0.5254)\n",
      "Range: (tensor(0.0078), tensor(-0.1525))\n",
      "Range difference: tensor(0.1603)\n",
      "Range: (tensor(0.0551), tensor(-0.0400))\n",
      "Range difference: tensor(0.0951)\n",
      "Range: (tensor(0.1071), tensor(-0.0801))\n",
      "Range difference: tensor(0.1872)\n",
      "Range: (tensor(0.3134), tensor(-0.0288))\n",
      "Range difference: tensor(0.3422)\n",
      "Range: (tensor(0.0676), tensor(-0.0465))\n",
      "Range difference: tensor(0.1141)\n",
      "Range: (tensor(0.0206), tensor(-0.0374))\n",
      "Range difference: tensor(0.0580)\n",
      "Range: (tensor(0.0991), tensor(-0.1059))\n",
      "Range difference: tensor(0.2051)\n",
      "Range: (tensor(0.0618), tensor(-0.1214))\n",
      "Range difference: tensor(0.1832)\n",
      "Range: (tensor(0.0158), tensor(-0.1081))\n",
      "Range difference: tensor(0.1239)\n",
      "Range: (tensor(0.2645), tensor(-0.1200))\n",
      "Range difference: tensor(0.3845)\n",
      "Range: (tensor(0.1008), tensor(-0.1701))\n",
      "Range difference: tensor(0.2710)\n",
      "Range: (tensor(0.1587), tensor(-0.0571))\n",
      "Range difference: tensor(0.2157)\n",
      "Range: (tensor(0.0522), tensor(-0.0622))\n",
      "Range difference: tensor(0.1144)\n",
      "Range: (tensor(0.1907), tensor(-0.0321))\n",
      "Range difference: tensor(0.2228)\n",
      "Range: (tensor(0.0435), tensor(-0.0043))\n",
      "Range difference: tensor(0.0478)\n",
      "Range: (tensor(0.0823), tensor(-0.0206))\n",
      "Range difference: tensor(0.1029)\n",
      "Range: (tensor(0.1396), tensor(0.0030))\n",
      "Range difference: tensor(0.1366)\n",
      "Range: (tensor(0.0800), tensor(-0.1653))\n",
      "Range difference: tensor(0.2453)\n",
      "Range: (tensor(0.0642), tensor(-0.0036))\n",
      "Range difference: tensor(0.0677)\n",
      "Range: (tensor(0.0745), tensor(-0.1268))\n",
      "Range difference: tensor(0.2013)\n",
      "Range: (tensor(0.0024), tensor(-0.0636))\n",
      "Range difference: tensor(0.0660)\n",
      "Range: (tensor(0.0679), tensor(-0.0761))\n",
      "Range difference: tensor(0.1440)\n",
      "Range: (tensor(0.0804), tensor(-0.0536))\n",
      "Range difference: tensor(0.1340)\n",
      "Range: (tensor(0.1027), tensor(-0.3655))\n",
      "Range difference: tensor(0.4682)\n",
      "Range: (tensor(0.0446), tensor(-0.0228))\n",
      "Range difference: tensor(0.0675)\n",
      "Range: (tensor(0.0745), tensor(-0.0606))\n",
      "Range difference: tensor(0.1350)\n",
      "Range: (tensor(0.0605), tensor(-0.0262))\n",
      "Range difference: tensor(0.0867)\n",
      "Range: (tensor(0.3940), tensor(-0.0423))\n",
      "Range difference: tensor(0.4363)\n",
      "Range: (tensor(0.0456), tensor(-0.0413))\n",
      "Range difference: tensor(0.0869)\n",
      "Range: (tensor(0.0947), tensor(-0.0571))\n",
      "Range difference: tensor(0.1518)\n",
      "Range: (tensor(0.1188), tensor(-0.0675))\n",
      "Range difference: tensor(0.1863)\n",
      "Range: (tensor(0.0310), tensor(-0.0573))\n",
      "Range difference: tensor(0.0882)\n",
      "Range: (tensor(0.0458), tensor(-0.0510))\n",
      "Range difference: tensor(0.0968)\n",
      "Range: (tensor(0.1944), tensor(-0.0640))\n",
      "Range difference: tensor(0.2585)\n",
      "Range: (tensor(0.0357), tensor(-0.1128))\n",
      "Range difference: tensor(0.1485)\n",
      "Range: (tensor(0.0598), tensor(-0.0732))\n",
      "Range difference: tensor(0.1330)\n",
      "Range: (tensor(0.2452), tensor(-0.0021))\n",
      "Range difference: tensor(0.2473)\n",
      "Range: (tensor(0.1121), tensor(-0.1169))\n",
      "Range difference: tensor(0.2290)\n",
      "Range: (tensor(0.0576), tensor(-0.2449))\n",
      "Range difference: tensor(0.3026)\n",
      "Range: (tensor(0.1476), tensor(-0.0122))\n",
      "Range difference: tensor(0.1598)\n",
      "Range: (tensor(0.0578), tensor(-0.1031))\n",
      "Range difference: tensor(0.1610)\n",
      "Range: (tensor(0.1985), tensor(-0.0854))\n",
      "Range difference: tensor(0.2839)\n",
      "Range: (tensor(0.0395), tensor(-0.1642))\n",
      "Range difference: tensor(0.2038)\n",
      "Range: (tensor(0.1818), tensor(-0.0083))\n",
      "Range difference: tensor(0.1901)\n",
      "Range: (tensor(0.2528), tensor(-0.0867))\n",
      "Range difference: tensor(0.3395)\n",
      "Range: (tensor(0.0636), tensor(-0.0655))\n",
      "Range difference: tensor(0.1291)\n",
      "Range: (tensor(0.0282), tensor(-0.1684))\n",
      "Range difference: tensor(0.1966)\n",
      "Range: (tensor(0.0350), tensor(-0.0504))\n",
      "Range difference: tensor(0.0854)\n",
      "Range: (tensor(0.0092), tensor(-0.1339))\n",
      "Range difference: tensor(0.1432)\n",
      "Range: (tensor(0.1122), tensor(-0.4118))\n",
      "Range difference: tensor(0.5241)\n",
      "Range: (tensor(0.0078), tensor(-0.1523))\n",
      "Range difference: tensor(0.1601)\n",
      "Range: (tensor(0.0546), tensor(-0.0397))\n",
      "Range difference: tensor(0.0943)\n",
      "Range: (tensor(0.1073), tensor(-0.0802))\n",
      "Range difference: tensor(0.1874)\n",
      "Range: (tensor(0.3129), tensor(-0.0296))\n",
      "Range difference: tensor(0.3425)\n",
      "Range: (tensor(0.0688), tensor(-0.0480))\n",
      "Range difference: tensor(0.1168)\n",
      "Range: (tensor(0.0216), tensor(-0.0369))\n",
      "Range difference: tensor(0.0586)\n",
      "Range: (tensor(0.1000), tensor(-0.1056))\n",
      "Range difference: tensor(0.2056)\n",
      "Range: (tensor(0.0623), tensor(-0.1222))\n",
      "Range difference: tensor(0.1845)\n",
      "Range: (tensor(0.0156), tensor(-0.1082))\n",
      "Range difference: tensor(0.1238)\n",
      "Range: (tensor(0.2664), tensor(-0.1187))\n",
      "Range difference: tensor(0.3851)\n",
      "Range: (tensor(0.1012), tensor(-0.1704))\n",
      "Range difference: tensor(0.2716)\n",
      "Range: (tensor(0.1578), tensor(-0.0591))\n",
      "Range difference: tensor(0.2169)\n",
      "Range: (tensor(0.0547), tensor(-0.0645))\n",
      "Range difference: tensor(0.1192)\n",
      "Range: (tensor(0.1904), tensor(-0.0316))\n",
      "Range difference: tensor(0.2220)\n",
      "Range: (tensor(0.0451), tensor(-0.0060))\n",
      "Range difference: tensor(0.0510)\n",
      "Range: (tensor(0.0834), tensor(-0.0215))\n",
      "Range difference: tensor(0.1049)\n",
      "Range: (tensor(0.1390), tensor(0.0034))\n",
      "Range difference: tensor(0.1357)\n",
      "Range: (tensor(0.0798), tensor(-0.1662))\n",
      "Range difference: tensor(0.2461)\n",
      "Range: (tensor(0.0640), tensor(-0.0045))\n",
      "Range difference: tensor(0.0684)\n",
      "Range: (tensor(0.0766), tensor(-0.1278))\n",
      "Range difference: tensor(0.2045)\n",
      "Range: (tensor(0.0046), tensor(-0.0645))\n",
      "Range difference: tensor(0.0691)\n",
      "Range: (tensor(0.0697), tensor(-0.0762))\n",
      "Range difference: tensor(0.1458)\n",
      "Range: (tensor(0.0800), tensor(-0.0533))\n",
      "Range difference: tensor(0.1333)\n",
      "Range: (tensor(0.1017), tensor(-0.3650))\n",
      "Range difference: tensor(0.4667)\n",
      "Range: (tensor(0.0455), tensor(-0.0230))\n",
      "Range difference: tensor(0.0684)\n",
      "Range: (tensor(0.0760), tensor(-0.0605))\n",
      "Range difference: tensor(0.1366)\n",
      "Range: (tensor(0.0596), tensor(-0.0264))\n",
      "Range difference: tensor(0.0860)\n",
      "Range: (tensor(0.3931), tensor(-0.0449))\n",
      "Range difference: tensor(0.4380)\n",
      "Range: (tensor(0.0459), tensor(-0.0404))\n",
      "Range difference: tensor(0.0863)\n",
      "Range: (tensor(0.0969), tensor(-0.0592))\n",
      "Range difference: tensor(0.1561)\n",
      "Range: (tensor(0.1186), tensor(-0.0655))\n",
      "Range difference: tensor(0.1841)\n",
      "Range: (tensor(0.0288), tensor(-0.0567))\n",
      "Range difference: tensor(0.0855)\n",
      "Range: (tensor(0.0453), tensor(-0.0490))\n",
      "Range difference: tensor(0.0943)\n",
      "Range: (tensor(0.1940), tensor(-0.0626))\n",
      "Range difference: tensor(0.2566)\n",
      "Range: (tensor(0.0344), tensor(-0.1109))\n",
      "Range difference: tensor(0.1453)\n",
      "Range: (tensor(0.0593), tensor(-0.0714))\n",
      "Range difference: tensor(0.1308)\n",
      "Range: (tensor(0.2449), tensor(-0.0010))\n",
      "Range difference: tensor(0.2459)\n",
      "Range: (tensor(0.1133), tensor(-0.1155))\n",
      "Range difference: tensor(0.2287)\n",
      "Range: (tensor(0.0562), tensor(-0.2437))\n",
      "Range difference: tensor(0.2999)\n",
      "Range: (tensor(0.1494), tensor(-0.0138))\n",
      "Range difference: tensor(0.1632)\n",
      "Range: (tensor(0.0583), tensor(-0.1016))\n",
      "Range difference: tensor(0.1599)\n",
      "Range: (tensor(0.1970), tensor(-0.0848))\n",
      "Range difference: tensor(0.2818)\n",
      "Range: (tensor(0.0399), tensor(-0.1639))\n",
      "Range difference: tensor(0.2039)\n",
      "Range: (tensor(0.1828), tensor(-0.0063))\n",
      "Range difference: tensor(0.1891)\n",
      "Range: (tensor(0.2533), tensor(-0.0890))\n",
      "Range difference: tensor(0.3422)\n",
      "Range: (tensor(0.0634), tensor(-0.0654))\n",
      "Range difference: tensor(0.1288)\n",
      "Range: (tensor(0.0274), tensor(-0.1685))\n",
      "Range difference: tensor(0.1959)\n",
      "Range: (tensor(0.0345), tensor(-0.0509))\n",
      "Range difference: tensor(0.0854)\n",
      "Range: (tensor(0.0084), tensor(-0.1345))\n",
      "Range difference: tensor(0.1428)\n",
      "Range: (tensor(0.1131), tensor(-0.4134))\n",
      "Range difference: tensor(0.5265)\n",
      "Range: (tensor(0.0080), tensor(-0.1513))\n",
      "Range difference: tensor(0.1593)\n",
      "Range: (tensor(0.0550), tensor(-0.0418))\n",
      "Range difference: tensor(0.0969)\n",
      "Range: (tensor(0.1073), tensor(-0.0802))\n",
      "Range difference: tensor(0.1874)\n",
      "Range: (tensor(0.3150), tensor(-0.0310))\n",
      "Range difference: tensor(0.3460)\n",
      "Range: (tensor(0.0696), tensor(-0.0488))\n",
      "Range difference: tensor(0.1184)\n",
      "Range: (tensor(0.0213), tensor(-0.0373))\n",
      "Range difference: tensor(0.0586)\n",
      "Range: (tensor(0.1008), tensor(-0.1049))\n",
      "Range difference: tensor(0.2057)\n",
      "Range: (tensor(0.0621), tensor(-0.1211))\n",
      "Range difference: tensor(0.1832)\n",
      "Range: (tensor(0.0139), tensor(-0.1076))\n",
      "Range difference: tensor(0.1215)\n",
      "Range: (tensor(0.2639), tensor(-0.1177))\n",
      "Range difference: tensor(0.3816)\n",
      "Range: (tensor(0.1001), tensor(-0.1707))\n",
      "Range difference: tensor(0.2709)\n",
      "Range: (tensor(0.1575), tensor(-0.0582))\n",
      "Range difference: tensor(0.2157)\n",
      "Range: (tensor(0.0528), tensor(-0.0619))\n",
      "Range difference: tensor(0.1147)\n",
      "Range: (tensor(0.1913), tensor(-0.0319))\n",
      "Range difference: tensor(0.2231)\n",
      "Range: (tensor(0.0444), tensor(-0.0064))\n",
      "Range difference: tensor(0.0509)\n",
      "Range: (tensor(0.0838), tensor(-0.0196))\n",
      "Range difference: tensor(0.1034)\n",
      "Range: (tensor(0.1410), tensor(0.0039))\n",
      "Range difference: tensor(0.1372)\n",
      "Range: (tensor(0.0796), tensor(-0.1659))\n",
      "Range difference: tensor(0.2454)\n",
      "Range: (tensor(0.0643), tensor(-0.0047))\n",
      "Range difference: tensor(0.0689)\n",
      "Range: (tensor(0.0758), tensor(-0.1268))\n",
      "Range difference: tensor(0.2026)\n",
      "Range: (tensor(0.0034), tensor(-0.0644))\n",
      "Range difference: tensor(0.0678)\n",
      "Range: (tensor(0.0685), tensor(-0.0759))\n",
      "Range difference: tensor(0.1444)\n",
      "Range: (tensor(0.0795), tensor(-0.0524))\n",
      "Range difference: tensor(0.1320)\n",
      "Range: (tensor(0.1033), tensor(-0.3659))\n",
      "Range difference: tensor(0.4693)\n",
      "Range: (tensor(0.0445), tensor(-0.0244))\n",
      "Range difference: tensor(0.0689)\n",
      "Range: (tensor(0.0748), tensor(-0.0611))\n",
      "Range difference: tensor(0.1359)\n",
      "Range: (tensor(0.0593), tensor(-0.0265))\n",
      "Range difference: tensor(0.0857)\n",
      "Range: (tensor(0.3940), tensor(-0.0430))\n",
      "Range difference: tensor(0.4370)\n",
      "Range: (tensor(0.0458), tensor(-0.0414))\n",
      "Range difference: tensor(0.0871)\n",
      "Range: (tensor(0.0959), tensor(-0.0584))\n",
      "Range difference: tensor(0.1543)\n",
      "Range: (tensor(0.1170), tensor(-0.0674))\n",
      "Range difference: tensor(0.1844)\n",
      "Range: (tensor(0.0308), tensor(-0.0567))\n",
      "Range difference: tensor(0.0875)\n",
      "Range: (tensor(0.0459), tensor(-0.0498))\n",
      "Range difference: tensor(0.0956)\n",
      "Range: (tensor(0.1936), tensor(-0.0628))\n",
      "Range difference: tensor(0.2564)\n",
      "Range: (tensor(0.0349), tensor(-0.1127))\n",
      "Range difference: tensor(0.1476)\n",
      "Range: (tensor(0.0592), tensor(-0.0730))\n",
      "Range difference: tensor(0.1323)\n",
      "Range: (tensor(0.2440), tensor(-0.0022))\n",
      "Range difference: tensor(0.2462)\n",
      "Range: (tensor(0.1119), tensor(-0.1176))\n",
      "Range difference: tensor(0.2294)\n",
      "Range: (tensor(0.0555), tensor(-0.2440))\n",
      "Range difference: tensor(0.2995)\n",
      "Range: (tensor(0.1491), tensor(-0.0137))\n",
      "Range difference: tensor(0.1627)\n",
      "Range: (tensor(0.0583), tensor(-0.1023))\n",
      "Range difference: tensor(0.1606)\n",
      "Range: (tensor(0.1958), tensor(-0.0840))\n",
      "Range difference: tensor(0.2798)\n",
      "Range: (tensor(0.0394), tensor(-0.1633))\n",
      "Range difference: tensor(0.2027)\n",
      "Range: (tensor(0.1822), tensor(-0.0068))\n",
      "Range difference: tensor(0.1890)\n",
      "Range: (tensor(0.2526), tensor(-0.0887))\n",
      "Range difference: tensor(0.3413)\n",
      "Range: (tensor(0.0641), tensor(-0.0649))\n",
      "Range difference: tensor(0.1290)\n",
      "Range: (tensor(0.0290), tensor(-0.1693))\n",
      "Range difference: tensor(0.1983)\n",
      "Range: (tensor(0.0356), tensor(-0.0498))\n",
      "Range difference: tensor(0.0854)\n",
      "Range: (tensor(0.0081), tensor(-0.1336))\n",
      "Range difference: tensor(0.1416)\n",
      "Range: (tensor(0.1128), tensor(-0.4130))\n",
      "Range difference: tensor(0.5258)\n",
      "Range: (tensor(0.0082), tensor(-0.1511))\n",
      "Range difference: tensor(0.1594)\n",
      "Range: (tensor(0.0550), tensor(-0.0411))\n",
      "Range difference: tensor(0.0961)\n",
      "Range: (tensor(0.1065), tensor(-0.0778))\n",
      "Range difference: tensor(0.1843)\n",
      "Range: (tensor(0.3131), tensor(-0.0287))\n",
      "Range difference: tensor(0.3418)\n",
      "Range: (tensor(0.0673), tensor(-0.0486))\n",
      "Range difference: tensor(0.1159)\n",
      "Range: (tensor(0.0222), tensor(-0.0394))\n",
      "Range difference: tensor(0.0616)\n",
      "Range: (tensor(0.0988), tensor(-0.1054))\n",
      "Range difference: tensor(0.2043)\n",
      "Range: (tensor(0.0621), tensor(-0.1216))\n",
      "Range difference: tensor(0.1837)\n",
      "Range: (tensor(0.0154), tensor(-0.1066))\n",
      "Range difference: tensor(0.1220)\n",
      "Range: (tensor(0.2654), tensor(-0.1177))\n",
      "Range difference: tensor(0.3832)\n",
      "Range: (tensor(0.0987), tensor(-0.1711))\n",
      "Range difference: tensor(0.2698)\n",
      "Range: (tensor(0.1597), tensor(-0.0582))\n",
      "Range difference: tensor(0.2178)\n",
      "Range: (tensor(0.0547), tensor(-0.0637))\n",
      "Range difference: tensor(0.1185)\n",
      "Range: (tensor(0.1906), tensor(-0.0328))\n",
      "Range difference: tensor(0.2234)\n",
      "Range: (tensor(0.0445), tensor(-0.0058))\n",
      "Range difference: tensor(0.0504)\n",
      "Range: (tensor(0.0838), tensor(-0.0195))\n",
      "Range difference: tensor(0.1033)\n",
      "Range: (tensor(0.1403), tensor(0.0035))\n",
      "Range difference: tensor(0.1368)\n",
      "Range: (tensor(0.0817), tensor(-0.1641))\n",
      "Range difference: tensor(0.2459)\n",
      "Range: (tensor(0.0628), tensor(-0.0054))\n",
      "Range difference: tensor(0.0681)\n",
      "Range: (tensor(0.0763), tensor(-0.1280))\n",
      "Range difference: tensor(0.2044)\n",
      "Range: (tensor(0.0044), tensor(-0.0624))\n",
      "Range difference: tensor(0.0668)\n",
      "Range: (tensor(0.0682), tensor(-0.0758))\n",
      "Range difference: tensor(0.1440)\n",
      "Range: (tensor(0.0786), tensor(-0.0523))\n",
      "Range difference: tensor(0.1309)\n",
      "Range: (tensor(0.1022), tensor(-0.3644))\n",
      "Range difference: tensor(0.4665)\n",
      "Range: (tensor(0.0443), tensor(-0.0242))\n",
      "Range difference: tensor(0.0685)\n",
      "Range: (tensor(0.0767), tensor(-0.0611))\n",
      "Range difference: tensor(0.1377)\n",
      "Range: (tensor(0.0602), tensor(-0.0262))\n",
      "Range difference: tensor(0.0864)\n",
      "Range: (tensor(0.3933), tensor(-0.0422))\n",
      "Range difference: tensor(0.4356)\n",
      "Range: (tensor(0.0477), tensor(-0.0400))\n",
      "Range difference: tensor(0.0877)\n",
      "Range: (tensor(0.0971), tensor(-0.0572))\n",
      "Range difference: tensor(0.1543)\n",
      "Range: (tensor(0.1195), tensor(-0.0675))\n",
      "Range difference: tensor(0.1871)\n",
      "Range: (tensor(0.0306), tensor(-0.0554))\n",
      "Range difference: tensor(0.0861)\n",
      "Range: (tensor(0.0466), tensor(-0.0503))\n",
      "Range difference: tensor(0.0969)\n",
      "Range: (tensor(0.1942), tensor(-0.0640))\n",
      "Range difference: tensor(0.2582)\n",
      "Range: (tensor(0.0351), tensor(-0.1104))\n",
      "Range difference: tensor(0.1455)\n",
      "Range: (tensor(0.0592), tensor(-0.0711))\n",
      "Range difference: tensor(0.1304)\n",
      "Range: (tensor(0.2452), tensor(4.6522e-05))\n",
      "Range difference: tensor(0.2451)\n",
      "Range: (tensor(0.1132), tensor(-0.1176))\n",
      "Range difference: tensor(0.2308)\n",
      "Range: (tensor(0.1019), tensor(-0.0559))\n",
      "Range difference: tensor(0.1578)\n",
      "Range: (tensor(0.0128), tensor(-0.0609))\n",
      "Range difference: tensor(0.0737)\n",
      "Range: (tensor(0.0261), tensor(-0.0590))\n",
      "Range difference: tensor(0.0851)\n",
      "Range: (tensor(0.0841), tensor(-0.0954))\n",
      "Range difference: tensor(0.1795)\n",
      "Range: (tensor(0.0658), tensor(-0.0379))\n",
      "Range difference: tensor(0.1037)\n",
      "Range: (tensor(0.0088), tensor(-0.1010))\n",
      "Range difference: tensor(0.1098)\n",
      "Range: (tensor(0.0885), tensor(-0.0933))\n",
      "Range difference: tensor(0.1818)\n",
      "Range: (tensor(0.0214), tensor(-0.0631))\n",
      "Range difference: tensor(0.0845)\n",
      "Range: (tensor(0.0455), tensor(-0.0279))\n",
      "Range difference: tensor(0.0734)\n",
      "Range: (tensor(0.0518), tensor(-0.0367))\n",
      "Range difference: tensor(0.0885)\n",
      "Range: (tensor(0.0401), tensor(-0.0097))\n",
      "Range difference: tensor(0.0498)\n",
      "Range: (tensor(0.1144), tensor(-0.1130))\n",
      "Range difference: tensor(0.2273)\n",
      "Range: (tensor(0.0449), tensor(-0.0065))\n",
      "Range difference: tensor(0.0514)\n",
      "Range: (tensor(0.0415), tensor(-0.0542))\n",
      "Range difference: tensor(0.0957)\n",
      "Range: (tensor(0.0780), tensor(-0.1068))\n",
      "Range difference: tensor(0.1847)\n",
      "Range: (tensor(0.0298), tensor(-0.1141))\n",
      "Range difference: tensor(0.1439)\n",
      "Range: (tensor(0.0468), tensor(-0.0514))\n",
      "Range difference: tensor(0.0982)\n",
      "Range: (tensor(0.0344), tensor(-0.0225))\n",
      "Range difference: tensor(0.0570)\n",
      "Range: (tensor(0.0535), tensor(-0.0991))\n",
      "Range difference: tensor(0.1526)\n",
      "Range: (tensor(0.0625), tensor(-0.0624))\n",
      "Range difference: tensor(0.1249)\n",
      "Range: (tensor(0.0612), tensor(-0.0162))\n",
      "Range difference: tensor(0.0774)\n",
      "Range: (tensor(0.1202), tensor(-0.0922))\n",
      "Range difference: tensor(0.2124)\n",
      "Range: (tensor(0.0383), tensor(-0.1007))\n",
      "Range difference: tensor(0.1389)\n",
      "Range: (tensor(0.0570), tensor(-0.0946))\n",
      "Range difference: tensor(0.1517)\n",
      "Range: (tensor(0.0566), tensor(-0.0543))\n",
      "Range difference: tensor(0.1109)\n",
      "Range: (tensor(0.0318), tensor(-0.0642))\n",
      "Range difference: tensor(0.0960)\n",
      "Range: (tensor(0.0055), tensor(-0.0442))\n",
      "Range difference: tensor(0.0498)\n",
      "Range: (tensor(0.0148), tensor(-0.0816))\n",
      "Range difference: tensor(0.0964)\n",
      "Range: (tensor(-0.0022), tensor(-0.1096))\n",
      "Range difference: tensor(0.1074)\n",
      "Range: (tensor(0.0593), tensor(-0.0817))\n",
      "Range difference: tensor(0.1410)\n",
      "Range: (tensor(0.0031), tensor(-0.0643))\n",
      "Range difference: tensor(0.0674)\n",
      "Range: (tensor(0.0921), tensor(-0.0753))\n",
      "Range difference: tensor(0.1674)\n",
      "Range: (tensor(0.0644), tensor(-0.0043))\n",
      "Range difference: tensor(0.0686)\n",
      "Range: (tensor(0.0743), tensor(-0.0671))\n",
      "Range difference: tensor(0.1415)\n",
      "Range: (tensor(0.0543), tensor(-0.0785))\n",
      "Range difference: tensor(0.1328)\n",
      "Range: (tensor(0.1398), tensor(-0.1025))\n",
      "Range difference: tensor(0.2423)\n",
      "Range: (tensor(0.0219), tensor(-0.0438))\n",
      "Range difference: tensor(0.0657)\n",
      "Range: (tensor(0.0602), tensor(-0.0689))\n",
      "Range difference: tensor(0.1291)\n",
      "Range: (tensor(0.0265), tensor(-0.0601))\n",
      "Range difference: tensor(0.0866)\n",
      "Range: (tensor(0.0448), tensor(-0.1067))\n",
      "Range difference: tensor(0.1515)\n",
      "Range: (tensor(0.0396), tensor(-0.0465))\n",
      "Range difference: tensor(0.0861)\n",
      "Range: (tensor(0.0265), tensor(-0.0953))\n",
      "Range difference: tensor(0.1218)\n",
      "Range: (tensor(0.0144), tensor(-0.1195))\n",
      "Range difference: tensor(0.1339)\n",
      "Range: (tensor(0.0422), tensor(-0.0288))\n",
      "Range difference: tensor(0.0711)\n",
      "Range: (tensor(0.0503), tensor(-0.0151))\n",
      "Range difference: tensor(0.0654)\n",
      "Range: (tensor(0.0628), tensor(-0.0989))\n",
      "Range difference: tensor(0.1616)\n",
      "Range: (tensor(0.0628), tensor(-0.0341))\n",
      "Range difference: tensor(0.0968)\n",
      "Range: (tensor(0.0728), tensor(-0.0586))\n",
      "Range difference: tensor(0.1315)\n",
      "Range: (tensor(0.0021), tensor(-0.1199))\n",
      "Range difference: tensor(0.1220)\n",
      "Range: (tensor(0.0478), tensor(-0.1114))\n",
      "Range difference: tensor(0.1591)\n",
      "Range: (tensor(0.1023), tensor(-0.0570))\n",
      "Range difference: tensor(0.1593)\n",
      "Range: (tensor(0.0145), tensor(-0.0613))\n",
      "Range difference: tensor(0.0758)\n",
      "Range: (tensor(0.0252), tensor(-0.0578))\n",
      "Range difference: tensor(0.0830)\n",
      "Range: (tensor(0.0828), tensor(-0.0949))\n",
      "Range difference: tensor(0.1777)\n",
      "Range: (tensor(0.0652), tensor(-0.0390))\n",
      "Range difference: tensor(0.1043)\n",
      "Range: (tensor(0.0069), tensor(-0.1001))\n",
      "Range difference: tensor(0.1070)\n",
      "Range: (tensor(0.0878), tensor(-0.0914))\n",
      "Range difference: tensor(0.1792)\n",
      "Range: (tensor(0.0231), tensor(-0.0627))\n",
      "Range difference: tensor(0.0859)\n",
      "Range: (tensor(0.0439), tensor(-0.0280))\n",
      "Range difference: tensor(0.0719)\n",
      "Range: (tensor(0.0509), tensor(-0.0355))\n",
      "Range difference: tensor(0.0863)\n",
      "Range: (tensor(0.0382), tensor(-0.0078))\n",
      "Range difference: tensor(0.0461)\n",
      "Range: (tensor(0.1151), tensor(-0.1146))\n",
      "Range difference: tensor(0.2297)\n",
      "Range: (tensor(0.0447), tensor(-0.0087))\n",
      "Range difference: tensor(0.0533)\n",
      "Range: (tensor(0.0394), tensor(-0.0538))\n",
      "Range difference: tensor(0.0932)\n",
      "Range: (tensor(0.0804), tensor(-0.1079))\n",
      "Range difference: tensor(0.1882)\n",
      "Range: (tensor(0.0308), tensor(-0.1153))\n",
      "Range difference: tensor(0.1461)\n",
      "Range: (tensor(0.0461), tensor(-0.0522))\n",
      "Range difference: tensor(0.0983)\n",
      "Range: (tensor(0.0317), tensor(-0.0210))\n",
      "Range difference: tensor(0.0526)\n",
      "Range: (tensor(0.0554), tensor(-0.1008))\n",
      "Range difference: tensor(0.1562)\n",
      "Range: (tensor(0.0622), tensor(-0.0624))\n",
      "Range difference: tensor(0.1246)\n",
      "Range: (tensor(0.0624), tensor(-0.0160))\n",
      "Range difference: tensor(0.0785)\n",
      "Range: (tensor(0.1193), tensor(-0.0911))\n",
      "Range difference: tensor(0.2104)\n",
      "Range: (tensor(0.0381), tensor(-0.1010))\n",
      "Range difference: tensor(0.1391)\n",
      "Range: (tensor(0.0577), tensor(-0.0964))\n",
      "Range difference: tensor(0.1542)\n",
      "Range: (tensor(0.0567), tensor(-0.0526))\n",
      "Range difference: tensor(0.1093)\n",
      "Range: (tensor(0.0320), tensor(-0.0662))\n",
      "Range difference: tensor(0.0982)\n",
      "Range: (tensor(0.0043), tensor(-0.0438))\n",
      "Range difference: tensor(0.0481)\n",
      "Range: (tensor(0.0137), tensor(-0.0832))\n",
      "Range difference: tensor(0.0969)\n",
      "Range: (tensor(-0.0031), tensor(-0.1090))\n",
      "Range difference: tensor(0.1060)\n",
      "Range: (tensor(0.0596), tensor(-0.0801))\n",
      "Range difference: tensor(0.1398)\n",
      "Range: (tensor(0.0015), tensor(-0.0623))\n",
      "Range difference: tensor(0.0639)\n",
      "Range: (tensor(0.0911), tensor(-0.0766))\n",
      "Range difference: tensor(0.1677)\n",
      "Range: (tensor(0.0630), tensor(-0.0035))\n",
      "Range difference: tensor(0.0665)\n",
      "Range: (tensor(0.0731), tensor(-0.0673))\n",
      "Range difference: tensor(0.1403)\n",
      "Range: (tensor(0.0522), tensor(-0.0808))\n",
      "Range difference: tensor(0.1330)\n",
      "Range: (tensor(0.1380), tensor(-0.1026))\n",
      "Range difference: tensor(0.2407)\n",
      "Range: (tensor(0.0206), tensor(-0.0437))\n",
      "Range difference: tensor(0.0643)\n",
      "Range: (tensor(0.0608), tensor(-0.0691))\n",
      "Range difference: tensor(0.1300)\n",
      "Range: (tensor(0.0238), tensor(-0.0601))\n",
      "Range difference: tensor(0.0839)\n",
      "Range: (tensor(0.0432), tensor(-0.1062))\n",
      "Range difference: tensor(0.1494)\n",
      "Range: (tensor(0.0399), tensor(-0.0476))\n",
      "Range difference: tensor(0.0875)\n",
      "Range: (tensor(0.0255), tensor(-0.0970))\n",
      "Range difference: tensor(0.1225)\n",
      "Range: (tensor(0.0130), tensor(-0.1168))\n",
      "Range difference: tensor(0.1299)\n",
      "Range: (tensor(0.0412), tensor(-0.0302))\n",
      "Range difference: tensor(0.0714)\n",
      "Range: (tensor(0.0489), tensor(-0.0140))\n",
      "Range difference: tensor(0.0629)\n",
      "Range: (tensor(0.0625), tensor(-0.1000))\n",
      "Range difference: tensor(0.1625)\n",
      "Range: (tensor(0.0646), tensor(-0.0355))\n",
      "Range difference: tensor(0.1000)\n",
      "Range: (tensor(0.0722), tensor(-0.0609))\n",
      "Range difference: tensor(0.1330)\n",
      "Range: (tensor(0.0025), tensor(-0.1195))\n",
      "Range difference: tensor(0.1220)\n",
      "Range: (tensor(0.0477), tensor(-0.1130))\n",
      "Range difference: tensor(0.1607)\n",
      "Range: (tensor(0.1011), tensor(-0.0580))\n",
      "Range difference: tensor(0.1591)\n",
      "Range: (tensor(0.0131), tensor(-0.0621))\n",
      "Range difference: tensor(0.0752)\n",
      "Range: (tensor(0.0254), tensor(-0.0593))\n",
      "Range difference: tensor(0.0847)\n",
      "Range: (tensor(0.0854), tensor(-0.0939))\n",
      "Range difference: tensor(0.1793)\n",
      "Range: (tensor(0.0651), tensor(-0.0395))\n",
      "Range difference: tensor(0.1047)\n",
      "Range: (tensor(0.0065), tensor(-0.1009))\n",
      "Range difference: tensor(0.1074)\n",
      "Range: (tensor(0.0872), tensor(-0.0918))\n",
      "Range difference: tensor(0.1790)\n",
      "Range: (tensor(0.0217), tensor(-0.0624))\n",
      "Range difference: tensor(0.0841)\n",
      "Range: (tensor(0.0460), tensor(-0.0283))\n",
      "Range difference: tensor(0.0743)\n",
      "Range: (tensor(0.0509), tensor(-0.0359))\n",
      "Range difference: tensor(0.0868)\n",
      "Range: (tensor(0.0375), tensor(-0.0092))\n",
      "Range difference: tensor(0.0467)\n",
      "Range: (tensor(0.1146), tensor(-0.1123))\n",
      "Range difference: tensor(0.2268)\n",
      "Range: (tensor(0.0422), tensor(-0.0068))\n",
      "Range difference: tensor(0.0490)\n",
      "Range: (tensor(0.0414), tensor(-0.0552))\n",
      "Range difference: tensor(0.0965)\n",
      "Range: (tensor(0.0798), tensor(-0.1080))\n",
      "Range difference: tensor(0.1878)\n",
      "Range: (tensor(0.0303), tensor(-0.1159))\n",
      "Range difference: tensor(0.1462)\n",
      "Range: (tensor(0.0488), tensor(-0.0514))\n",
      "Range difference: tensor(0.1002)\n",
      "Range: (tensor(0.0335), tensor(-0.0210))\n",
      "Range difference: tensor(0.0545)\n",
      "Range: (tensor(0.0560), tensor(-0.1009))\n",
      "Range difference: tensor(0.1569)\n",
      "Range: (tensor(0.0635), tensor(-0.0613))\n",
      "Range difference: tensor(0.1248)\n",
      "Range: (tensor(0.0615), tensor(-0.0159))\n",
      "Range difference: tensor(0.0774)\n",
      "Range: (tensor(0.1195), tensor(-0.0906))\n",
      "Range difference: tensor(0.2101)\n",
      "Range: (tensor(0.0363), tensor(-0.1001))\n",
      "Range difference: tensor(0.1364)\n",
      "Range: (tensor(0.0598), tensor(-0.0940))\n",
      "Range difference: tensor(0.1538)\n",
      "Range: (tensor(0.0590), tensor(-0.0543))\n",
      "Range difference: tensor(0.1132)\n",
      "Range: (tensor(0.0321), tensor(-0.0650))\n",
      "Range difference: tensor(0.0971)\n",
      "Range: (tensor(0.0054), tensor(-0.0451))\n",
      "Range difference: tensor(0.0505)\n",
      "Range: (tensor(0.0131), tensor(-0.0834))\n",
      "Range difference: tensor(0.0965)\n",
      "Range: (tensor(-0.0022), tensor(-0.1106))\n",
      "Range difference: tensor(0.1085)\n",
      "Range: (tensor(0.0580), tensor(-0.0811))\n",
      "Range difference: tensor(0.1391)\n",
      "Range: (tensor(0.0039), tensor(-0.0641))\n",
      "Range difference: tensor(0.0679)\n",
      "Range: (tensor(0.0924), tensor(-0.0768))\n",
      "Range difference: tensor(0.1692)\n",
      "Range: (tensor(0.0626), tensor(-0.0030))\n",
      "Range difference: tensor(0.0656)\n",
      "Range: (tensor(0.0733), tensor(-0.0684))\n",
      "Range difference: tensor(0.1417)\n",
      "Range: (tensor(0.0529), tensor(-0.0793))\n",
      "Range difference: tensor(0.1322)\n",
      "Range: (tensor(0.1378), tensor(-0.1030))\n",
      "Range difference: tensor(0.2408)\n",
      "Range: (tensor(0.0216), tensor(-0.0441))\n",
      "Range difference: tensor(0.0657)\n",
      "Range: (tensor(0.0616), tensor(-0.0702))\n",
      "Range difference: tensor(0.1318)\n",
      "Range: (tensor(0.0253), tensor(-0.0595))\n",
      "Range difference: tensor(0.0848)\n",
      "Range: (tensor(0.0447), tensor(-0.1065))\n",
      "Range difference: tensor(0.1512)\n",
      "Range: (tensor(0.0399), tensor(-0.0475))\n",
      "Range difference: tensor(0.0874)\n",
      "Range: (tensor(0.0241), tensor(-0.0949))\n",
      "Range difference: tensor(0.1190)\n",
      "Range: (tensor(0.0135), tensor(-0.1191))\n",
      "Range difference: tensor(0.1327)\n",
      "Range: (tensor(0.0421), tensor(-0.0297))\n",
      "Range difference: tensor(0.0718)\n",
      "Range: (tensor(0.0499), tensor(-0.0125))\n",
      "Range difference: tensor(0.0624)\n",
      "Range: (tensor(0.0632), tensor(-0.0994))\n",
      "Range difference: tensor(0.1625)\n",
      "Range: (tensor(0.0626), tensor(-0.0344))\n",
      "Range difference: tensor(0.0970)\n",
      "Range: (tensor(0.0728), tensor(-0.0587))\n",
      "Range difference: tensor(0.1315)\n",
      "Range: (tensor(0.0013), tensor(-0.1188))\n",
      "Range difference: tensor(0.1201)\n",
      "Range: (tensor(0.0461), tensor(-0.1119))\n",
      "Range difference: tensor(0.1580)\n",
      "Range: (tensor(0.1017), tensor(-0.0567))\n",
      "Range difference: tensor(0.1584)\n",
      "Range: (tensor(0.0123), tensor(-0.0605))\n",
      "Range difference: tensor(0.0728)\n",
      "Range: (tensor(0.0236), tensor(-0.0586))\n",
      "Range difference: tensor(0.0821)\n",
      "Range: (tensor(0.0828), tensor(-0.0957))\n",
      "Range difference: tensor(0.1786)\n",
      "Range: (tensor(0.0654), tensor(-0.0389))\n",
      "Range difference: tensor(0.1043)\n",
      "Range: (tensor(0.0090), tensor(-0.1016))\n",
      "Range difference: tensor(0.1106)\n",
      "Range: (tensor(0.0884), tensor(-0.0928))\n",
      "Range difference: tensor(0.1812)\n",
      "Range: (tensor(0.0236), tensor(-0.0620))\n",
      "Range difference: tensor(0.0856)\n",
      "Range: (tensor(0.0452), tensor(-0.0271))\n",
      "Range difference: tensor(0.0722)\n",
      "Range: (tensor(0.0514), tensor(-0.0355))\n",
      "Range difference: tensor(0.0869)\n",
      "Range: (tensor(0.0394), tensor(-0.0105))\n",
      "Range difference: tensor(0.0499)\n",
      "Range: (tensor(0.1165), tensor(-0.1149))\n",
      "Range difference: tensor(0.2314)\n",
      "Range: (tensor(0.0440), tensor(-0.0085))\n",
      "Range difference: tensor(0.0525)\n",
      "Range: (tensor(0.0405), tensor(-0.0543))\n",
      "Range difference: tensor(0.0948)\n",
      "Range: (tensor(0.0781), tensor(-0.1076))\n",
      "Range difference: tensor(0.1857)\n",
      "Range: (tensor(0.0301), tensor(-0.1138))\n",
      "Range difference: tensor(0.1439)\n",
      "Range: (tensor(0.0470), tensor(-0.0521))\n",
      "Range difference: tensor(0.0990)\n",
      "Range: (tensor(0.0337), tensor(-0.0210))\n",
      "Range difference: tensor(0.0547)\n",
      "Range: (tensor(0.0556), tensor(-0.1011))\n",
      "Range difference: tensor(0.1567)\n",
      "Range: (tensor(0.0639), tensor(-0.0613))\n",
      "Range difference: tensor(0.1252)\n",
      "Range: (tensor(0.0600), tensor(-0.0154))\n",
      "Range difference: tensor(0.0754)\n",
      "Range: (tensor(0.1196), tensor(-0.0918))\n",
      "Range difference: tensor(0.2114)\n",
      "Range: (tensor(0.0371), tensor(-0.0990))\n",
      "Range difference: tensor(0.1362)\n",
      "Range: (tensor(0.0594), tensor(-0.0959))\n",
      "Range difference: tensor(0.1553)\n",
      "Range: (tensor(0.0580), tensor(-0.0535))\n",
      "Range difference: tensor(0.1115)\n",
      "Range: (tensor(0.0328), tensor(-0.0652))\n",
      "Range difference: tensor(0.0980)\n",
      "Range: (tensor(0.0048), tensor(-0.0435))\n",
      "Range difference: tensor(0.0483)\n",
      "Range: (tensor(0.0147), tensor(-0.0832))\n",
      "Range difference: tensor(0.0978)\n",
      "Range: (tensor(-0.0025), tensor(-0.1102))\n",
      "Range difference: tensor(0.1077)\n",
      "Range: (tensor(0.0589), tensor(-0.0790))\n",
      "Range difference: tensor(0.1379)\n",
      "Range: (tensor(0.0038), tensor(-0.0637))\n",
      "Range difference: tensor(0.0675)\n",
      "Range: (tensor(0.0928), tensor(-0.0757))\n",
      "Range difference: tensor(0.1685)\n",
      "Range: (tensor(0.0631), tensor(-0.0046))\n",
      "Range difference: tensor(0.0676)\n",
      "Range: (tensor(0.0754), tensor(-0.0683))\n",
      "Range difference: tensor(0.1437)\n",
      "Range: (tensor(0.0537), tensor(-0.0809))\n",
      "Range difference: tensor(0.1347)\n",
      "Range: (tensor(0.1398), tensor(-0.1026))\n",
      "Range difference: tensor(0.2425)\n",
      "Range: (tensor(0.0233), tensor(-0.0451))\n",
      "Range difference: tensor(0.0683)\n",
      "Range: (tensor(0.0617), tensor(-0.0697))\n",
      "Range difference: tensor(0.1314)\n",
      "Range: (tensor(0.0258), tensor(-0.0604))\n",
      "Range difference: tensor(0.0862)\n",
      "Range: (tensor(0.0433), tensor(-0.1050))\n",
      "Range difference: tensor(0.1482)\n",
      "Range: (tensor(0.0412), tensor(-0.0450))\n",
      "Range difference: tensor(0.0862)\n",
      "Range: (tensor(0.0246), tensor(-0.0963))\n",
      "Range difference: tensor(0.1210)\n",
      "Range: (tensor(0.0137), tensor(-0.1188))\n",
      "Range difference: tensor(0.1325)\n",
      "Range: (tensor(0.0410), tensor(-0.0299))\n",
      "Range difference: tensor(0.0709)\n",
      "Range: (tensor(0.0510), tensor(-0.0130))\n",
      "Range difference: tensor(0.0640)\n",
      "Range: (tensor(0.0630), tensor(-0.0994))\n",
      "Range difference: tensor(0.1624)\n",
      "Range: (tensor(0.0644), tensor(-0.0346))\n",
      "Range difference: tensor(0.0991)\n",
      "Range: (tensor(0.0724), tensor(-0.0603))\n",
      "Range difference: tensor(0.1327)\n",
      "Range: (tensor(0.0013), tensor(-0.1199))\n",
      "Range difference: tensor(0.1212)\n",
      "Range: (tensor(0.0464), tensor(-0.1129))\n",
      "Range difference: tensor(0.1593)\n",
      "Range: (tensor(0.1019), tensor(-0.0573))\n",
      "Range difference: tensor(0.1592)\n",
      "Range: (tensor(0.0119), tensor(-0.0609))\n",
      "Range difference: tensor(0.0728)\n",
      "Range: (tensor(0.0249), tensor(-0.0569))\n",
      "Range difference: tensor(0.0818)\n",
      "Range: (tensor(0.0852), tensor(-0.0952))\n",
      "Range difference: tensor(0.1804)\n",
      "Range: (tensor(0.0653), tensor(-0.0392))\n",
      "Range difference: tensor(0.1045)\n",
      "Range: (tensor(0.0072), tensor(-0.1019))\n",
      "Range difference: tensor(0.1091)\n",
      "Range: (tensor(0.0865), tensor(-0.0934))\n",
      "Range difference: tensor(0.1799)\n",
      "Range: (tensor(0.0211), tensor(-0.0644))\n",
      "Range difference: tensor(0.0855)\n",
      "Range: (tensor(0.0464), tensor(-0.0267))\n",
      "Range difference: tensor(0.0732)\n",
      "Range: (tensor(0.0512), tensor(-0.0359))\n",
      "Range difference: tensor(0.0871)\n",
      "Range: (tensor(0.0377), tensor(-0.0098))\n",
      "Range difference: tensor(0.0475)\n",
      "Range: (tensor(0.1160), tensor(-0.1128))\n",
      "Range difference: tensor(0.2288)\n",
      "Range: (tensor(0.0448), tensor(-0.0080))\n",
      "Range difference: tensor(0.0529)\n",
      "Range: (tensor(0.0412), tensor(-0.0548))\n",
      "Range difference: tensor(0.0960)\n",
      "Range: (tensor(0.0777), tensor(-0.1087))\n",
      "Range difference: tensor(0.1864)\n",
      "Range: (tensor(0.0312), tensor(-0.1151))\n",
      "Range difference: tensor(0.1463)\n",
      "Range: (tensor(0.0484), tensor(-0.0526))\n",
      "Range difference: tensor(0.1010)\n",
      "Range: (tensor(0.0323), tensor(-0.0219))\n",
      "Range difference: tensor(0.0542)\n",
      "Range: (tensor(0.0535), tensor(-0.1008))\n",
      "Range difference: tensor(0.1542)\n",
      "Range: (tensor(0.0621), tensor(-0.0618))\n",
      "Range difference: tensor(0.1239)\n",
      "Range: (tensor(0.0609), tensor(-0.0155))\n",
      "Range difference: tensor(0.0765)\n",
      "Range: (tensor(0.1202), tensor(-0.0922))\n",
      "Range difference: tensor(0.2124)\n",
      "Range: (tensor(0.0381), tensor(-0.0992))\n",
      "Range difference: tensor(0.1374)\n",
      "Range: (tensor(0.0597), tensor(-0.0942))\n",
      "Range difference: tensor(0.1539)\n",
      "Range: (tensor(0.0585), tensor(-0.0540))\n",
      "Range difference: tensor(0.1125)\n",
      "Range: (tensor(0.0324), tensor(-0.0638))\n",
      "Range difference: tensor(0.0962)\n",
      "Range: (tensor(0.0057), tensor(-0.0442))\n",
      "Range difference: tensor(0.0499)\n",
      "Range: (tensor(0.0151), tensor(-0.0826))\n",
      "Range difference: tensor(0.0977)\n",
      "Range: (tensor(-0.0033), tensor(-0.1101))\n",
      "Range difference: tensor(0.1068)\n",
      "Range: (tensor(0.0578), tensor(-0.0798))\n",
      "Range difference: tensor(0.1376)\n",
      "Range: (tensor(0.0024), tensor(-0.0645))\n",
      "Range difference: tensor(0.0669)\n",
      "Range: (tensor(0.0913), tensor(-0.0748))\n",
      "Range difference: tensor(0.1661)\n",
      "Range: (tensor(0.0649), tensor(-0.0044))\n",
      "Range difference: tensor(0.0693)\n",
      "Range: (tensor(0.0731), tensor(-0.0677))\n",
      "Range difference: tensor(0.1408)\n",
      "Range: (tensor(0.0541), tensor(-0.0802))\n",
      "Range difference: tensor(0.1343)\n",
      "Range: (tensor(0.1400), tensor(-0.1022))\n",
      "Range difference: tensor(0.2422)\n",
      "Range: (tensor(0.0219), tensor(-0.0434))\n",
      "Range difference: tensor(0.0653)\n",
      "Range: (tensor(0.0598), tensor(-0.0700))\n",
      "Range difference: tensor(0.1298)\n",
      "Range: (tensor(0.0259), tensor(-0.0585))\n",
      "Range difference: tensor(0.0845)\n",
      "Range: (tensor(0.0433), tensor(-0.1065))\n",
      "Range difference: tensor(0.1499)\n",
      "Range: (tensor(0.0401), tensor(-0.0466))\n",
      "Range difference: tensor(0.0867)\n",
      "Range: (tensor(0.0253), tensor(-0.0961))\n",
      "Range difference: tensor(0.1214)\n",
      "Range: (tensor(0.0132), tensor(-0.1195))\n",
      "Range difference: tensor(0.1328)\n",
      "Range: (tensor(0.0422), tensor(-0.0312))\n",
      "Range difference: tensor(0.0734)\n",
      "Range: (tensor(0.0489), tensor(-0.0131))\n",
      "Range difference: tensor(0.0620)\n",
      "Range: (tensor(0.0619), tensor(-0.1004))\n",
      "Range difference: tensor(0.1623)\n",
      "Range: (tensor(0.0641), tensor(-0.0340))\n",
      "Range difference: tensor(0.0981)\n",
      "Range: (tensor(0.0710), tensor(-0.0597))\n",
      "Range difference: tensor(0.1307)\n",
      "Range: (tensor(0.0010), tensor(-0.1194))\n",
      "Range difference: tensor(0.1205)\n",
      "Range: (tensor(0.0479), tensor(-0.1108))\n",
      "Range difference: tensor(0.1586)\n",
      "Range: (tensor(0.1004), tensor(-0.0562))\n",
      "Range difference: tensor(0.1566)\n",
      "Range: (tensor(0.0125), tensor(-0.0633))\n",
      "Range difference: tensor(0.0758)\n",
      "Range: (tensor(0.0242), tensor(-0.0579))\n",
      "Range difference: tensor(0.0822)\n",
      "Range: (tensor(0.0835), tensor(-0.0959))\n",
      "Range difference: tensor(0.1794)\n",
      "Range: (tensor(0.0659), tensor(-0.0377))\n",
      "Range difference: tensor(0.1036)\n",
      "Range: (tensor(0.0083), tensor(-0.1017))\n",
      "Range difference: tensor(0.1100)\n",
      "Range: (tensor(0.0869), tensor(-0.0939))\n",
      "Range difference: tensor(0.1808)\n",
      "Range: (tensor(0.0228), tensor(-0.0623))\n",
      "Range difference: tensor(0.0851)\n",
      "Range: (tensor(0.0445), tensor(-0.0282))\n",
      "Range difference: tensor(0.0727)\n",
      "Range: (tensor(0.0513), tensor(-0.0369))\n",
      "Range difference: tensor(0.0882)\n",
      "Range: (tensor(0.0386), tensor(-0.0106))\n",
      "Range difference: tensor(0.0492)\n",
      "Range: (tensor(0.1157), tensor(-0.1146))\n",
      "Range difference: tensor(0.2302)\n",
      "Range: (tensor(0.0437), tensor(-0.0092))\n",
      "Range difference: tensor(0.0529)\n",
      "Range: (tensor(0.0413), tensor(-0.0547))\n",
      "Range difference: tensor(0.0960)\n",
      "Range: (tensor(0.0799), tensor(-0.1080))\n",
      "Range difference: tensor(0.1879)\n",
      "Range: (tensor(0.0300), tensor(-0.1164))\n",
      "Range difference: tensor(0.1463)\n",
      "Range: (tensor(0.0470), tensor(-0.0526))\n",
      "Range difference: tensor(0.0996)\n",
      "Range: (tensor(0.0318), tensor(-0.0226))\n",
      "Range difference: tensor(0.0544)\n",
      "Range: (tensor(0.0536), tensor(-0.0993))\n",
      "Range difference: tensor(0.1530)\n",
      "Range: (tensor(0.0639), tensor(-0.0608))\n",
      "Range difference: tensor(0.1247)\n",
      "Range: (tensor(0.0615), tensor(-0.0144))\n",
      "Range difference: tensor(0.0759)\n",
      "Range: (tensor(0.1180), tensor(-0.0916))\n",
      "Range difference: tensor(0.2097)\n",
      "Range: (tensor(0.0379), tensor(-0.0988))\n",
      "Range difference: tensor(0.1366)\n",
      "Range: (tensor(0.0578), tensor(-0.0947))\n",
      "Range difference: tensor(0.1525)\n",
      "Range: (tensor(0.0564), tensor(-0.0533))\n",
      "Range difference: tensor(0.1097)\n",
      "Range: (tensor(0.0323), tensor(-0.0643))\n",
      "Range difference: tensor(0.0967)\n",
      "Range: (tensor(0.0043), tensor(-0.0437))\n",
      "Range difference: tensor(0.0480)\n",
      "Range: (tensor(0.0131), tensor(-0.0836))\n",
      "Range difference: tensor(0.0967)\n",
      "Range: (tensor(-0.0024), tensor(-0.1108))\n",
      "Range difference: tensor(0.1083)\n",
      "Range: (tensor(0.0595), tensor(-0.0799))\n",
      "Range difference: tensor(0.1394)\n",
      "Range: (tensor(0.0016), tensor(-0.0636))\n",
      "Range difference: tensor(0.0652)\n",
      "Range: (tensor(0.0912), tensor(-0.0751))\n",
      "Range difference: tensor(0.1663)\n",
      "Range: (tensor(0.0633), tensor(-0.0041))\n",
      "Range difference: tensor(0.0675)\n",
      "Range: (tensor(0.0750), tensor(-0.0685))\n",
      "Range difference: tensor(0.1434)\n",
      "Range: (tensor(0.0527), tensor(-0.0798))\n",
      "Range difference: tensor(0.1325)\n",
      "Range: (tensor(0.1395), tensor(-0.1025))\n",
      "Range difference: tensor(0.2420)\n",
      "Range: (tensor(0.0223), tensor(-0.0437))\n",
      "Range difference: tensor(0.0660)\n",
      "Range: (tensor(0.0617), tensor(-0.0698))\n",
      "Range difference: tensor(0.1315)\n",
      "Range: (tensor(0.0250), tensor(-0.0606))\n",
      "Range difference: tensor(0.0856)\n",
      "Range: (tensor(0.0433), tensor(-0.1067))\n",
      "Range difference: tensor(0.1500)\n",
      "Range: (tensor(0.0412), tensor(-0.0450))\n",
      "Range difference: tensor(0.0862)\n",
      "Range: (tensor(0.0253), tensor(-0.0957))\n",
      "Range difference: tensor(0.1209)\n",
      "Range: (tensor(0.0145), tensor(-0.1195))\n",
      "Range difference: tensor(0.1340)\n",
      "Range: (tensor(0.0407), tensor(-0.0297))\n",
      "Range difference: tensor(0.0704)\n",
      "Range: (tensor(0.0495), tensor(-0.0130))\n",
      "Range difference: tensor(0.0625)\n",
      "Range: (tensor(0.0616), tensor(-0.0977))\n",
      "Range difference: tensor(0.1594)\n",
      "Range: (tensor(0.0632), tensor(-0.0346))\n",
      "Range difference: tensor(0.0978)\n",
      "Range: (tensor(0.0713), tensor(-0.0594))\n",
      "Range difference: tensor(0.1307)\n",
      "Range: (tensor(0.0020), tensor(-0.1203))\n",
      "Range difference: tensor(0.1224)\n",
      "Range: (tensor(0.0470), tensor(-0.1129))\n",
      "Range difference: tensor(0.1599)\n",
      "Range: (tensor(0.1005), tensor(-0.0567))\n",
      "Range difference: tensor(0.1572)\n",
      "Range: (tensor(0.0119), tensor(-0.0632))\n",
      "Range difference: tensor(0.0751)\n",
      "Range: (tensor(0.0245), tensor(-0.0566))\n",
      "Range difference: tensor(0.0811)\n",
      "Range: (tensor(0.0854), tensor(-0.0942))\n",
      "Range difference: tensor(0.1796)\n",
      "Range: (tensor(0.0656), tensor(-0.0387))\n",
      "Range difference: tensor(0.1043)\n",
      "Range: (tensor(0.0083), tensor(-0.1010))\n",
      "Range difference: tensor(0.1093)\n",
      "Range: (tensor(0.0870), tensor(-0.0939))\n",
      "Range difference: tensor(0.1809)\n",
      "Range: (tensor(0.0235), tensor(-0.0641))\n",
      "Range difference: tensor(0.0876)\n",
      "Range: (tensor(0.0443), tensor(-0.0288))\n",
      "Range difference: tensor(0.0732)\n",
      "Range: (tensor(0.0497), tensor(-0.0352))\n",
      "Range difference: tensor(0.0849)\n",
      "Range: (tensor(0.0394), tensor(-0.0096))\n",
      "Range difference: tensor(0.0490)\n",
      "Range: (tensor(0.1143), tensor(-0.1130))\n",
      "Range difference: tensor(0.2273)\n",
      "Range: (tensor(0.0446), tensor(-0.0069))\n",
      "Range difference: tensor(0.0515)\n",
      "Range: (tensor(0.0395), tensor(-0.0540))\n",
      "Range difference: tensor(0.0935)\n",
      "Range: (tensor(0.0785), tensor(-0.1077))\n",
      "Range difference: tensor(0.1862)\n",
      "Range: (tensor(0.0305), tensor(-0.1146))\n",
      "Range difference: tensor(0.1450)\n",
      "Range: (tensor(0.0478), tensor(-0.0502))\n",
      "Range difference: tensor(0.0980)\n",
      "Range: (tensor(0.0329), tensor(-0.0226))\n",
      "Range difference: tensor(0.0556)\n",
      "Range: (tensor(0.0537), tensor(-0.0995))\n",
      "Range difference: tensor(0.1532)\n",
      "Range: (tensor(0.0637), tensor(-0.0631))\n",
      "Range difference: tensor(0.1268)\n",
      "Range: (tensor(0.0617), tensor(-0.0152))\n",
      "Range difference: tensor(0.0768)\n",
      "Range: (tensor(0.1178), tensor(-0.0918))\n",
      "Range difference: tensor(0.2096)\n",
      "Range: (tensor(0.0364), tensor(-0.0995))\n",
      "Range difference: tensor(0.1359)\n",
      "Range: (tensor(0.0579), tensor(-0.0955))\n",
      "Range difference: tensor(0.1534)\n",
      "Range: (tensor(0.0589), tensor(-0.0539))\n",
      "Range difference: tensor(0.1128)\n",
      "Range: (tensor(0.0321), tensor(-0.0651))\n",
      "Range difference: tensor(0.0972)\n",
      "Range: (tensor(0.0060), tensor(-0.0447))\n",
      "Range difference: tensor(0.0507)\n",
      "Range: (tensor(0.0129), tensor(-0.0829))\n",
      "Range difference: tensor(0.0959)\n",
      "Range: (tensor(-0.0024), tensor(-0.1110))\n",
      "Range difference: tensor(0.1086)\n",
      "Range: (tensor(0.0592), tensor(-0.0817))\n",
      "Range difference: tensor(0.1409)\n",
      "Range: (tensor(0.0027), tensor(-0.0648))\n",
      "Range difference: tensor(0.0675)\n",
      "Range: (tensor(0.0906), tensor(-0.0766))\n",
      "Range difference: tensor(0.1672)\n",
      "Range: (tensor(0.0649), tensor(-0.0037))\n",
      "Range difference: tensor(0.0686)\n",
      "Range: (tensor(0.0752), tensor(-0.0696))\n",
      "Range difference: tensor(0.1448)\n",
      "Range: (tensor(0.0542), tensor(-0.0804))\n",
      "Range difference: tensor(0.1346)\n",
      "Range: (tensor(0.1377), tensor(-0.1031))\n",
      "Range difference: tensor(0.2408)\n",
      "Range: (tensor(0.0209), tensor(-0.0458))\n",
      "Range difference: tensor(0.0667)\n",
      "Range: (tensor(0.0599), tensor(-0.0699))\n",
      "Range difference: tensor(0.1298)\n",
      "Range: (tensor(0.0242), tensor(-0.0588))\n",
      "Range difference: tensor(0.0830)\n",
      "Range: (tensor(0.0439), tensor(-0.1058))\n",
      "Range difference: tensor(0.1497)\n",
      "Range: (tensor(0.0392), tensor(-0.0471))\n",
      "Range difference: tensor(0.0863)\n",
      "Range: (tensor(0.0243), tensor(-0.0951))\n",
      "Range difference: tensor(0.1194)\n",
      "Range: (tensor(0.0124), tensor(-0.1192))\n",
      "Range difference: tensor(0.1317)\n",
      "Range: (tensor(0.0400), tensor(-0.0300))\n",
      "Range difference: tensor(0.0701)\n",
      "Range: (tensor(0.0514), tensor(-0.0129))\n",
      "Range difference: tensor(0.0643)\n",
      "Range: (tensor(0.0620), tensor(-0.1000))\n",
      "Range difference: tensor(0.1621)\n",
      "Range: (tensor(0.0647), tensor(-0.0362))\n",
      "Range difference: tensor(0.1009)\n",
      "Range: (tensor(0.0714), tensor(-0.0612))\n",
      "Range difference: tensor(0.1325)\n",
      "Range: (tensor(0.0012), tensor(-0.1204))\n",
      "Range difference: tensor(0.1217)\n",
      "Range: (tensor(0.0461), tensor(-0.1110))\n",
      "Range difference: tensor(0.1571)\n",
      "Range: (tensor(0.0999), tensor(-0.0581))\n",
      "Range difference: tensor(0.1580)\n",
      "Range: (tensor(0.0141), tensor(-0.0624))\n",
      "Range difference: tensor(0.0765)\n",
      "Range: (tensor(0.0240), tensor(-0.0590))\n",
      "Range difference: tensor(0.0830)\n",
      "Range: (tensor(0.0839), tensor(-0.0952))\n",
      "Range difference: tensor(0.1792)\n",
      "Range: (tensor(0.0665), tensor(-0.0388))\n",
      "Range difference: tensor(0.1052)\n",
      "Range: (tensor(0.0080), tensor(-0.1004))\n",
      "Range difference: tensor(0.1084)\n",
      "Range: (tensor(0.0890), tensor(-0.0921))\n",
      "Range difference: tensor(0.1812)\n",
      "Range: (tensor(0.0218), tensor(-0.0639))\n",
      "Range difference: tensor(0.0857)\n",
      "Range: (tensor(0.0461), tensor(-0.0287))\n",
      "Range difference: tensor(0.0748)\n",
      "Range: (tensor(0.0490), tensor(-0.0344))\n",
      "Range difference: tensor(0.0834)\n",
      "Range: (tensor(0.0379), tensor(-0.0087))\n",
      "Range difference: tensor(0.0466)\n",
      "Range: (tensor(0.1165), tensor(-0.1137))\n",
      "Range difference: tensor(0.2302)\n",
      "Range: (tensor(0.0441), tensor(-0.0092))\n",
      "Range difference: tensor(0.0533)\n",
      "Range: (tensor(0.0392), tensor(-0.0537))\n",
      "Range difference: tensor(0.0930)\n",
      "Range: (tensor(0.0786), tensor(-0.1066))\n",
      "Range difference: tensor(0.1852)\n",
      "Range: (tensor(0.0292), tensor(-0.1148))\n",
      "Range difference: tensor(0.1440)\n",
      "Range: (tensor(0.0479), tensor(-0.0506))\n",
      "Range difference: tensor(0.0984)\n",
      "Range: (tensor(0.0320), tensor(-0.0224))\n",
      "Range difference: tensor(0.0544)\n",
      "Range: (tensor(0.0533), tensor(-0.1014))\n",
      "Range difference: tensor(0.1547)\n",
      "Range: (tensor(0.0630), tensor(-0.0604))\n",
      "Range difference: tensor(0.1234)\n",
      "Range: (tensor(0.0622), tensor(-0.0137))\n",
      "Range difference: tensor(0.0759)\n",
      "Range: (tensor(0.1176), tensor(-0.0912))\n",
      "Range difference: tensor(0.2088)\n",
      "Range: (tensor(0.0363), tensor(-0.0992))\n",
      "Range difference: tensor(0.1355)\n",
      "Range: (tensor(0.0571), tensor(-0.0946))\n",
      "Range difference: tensor(0.1517)\n",
      "Range: (tensor(0.0583), tensor(-0.0527))\n",
      "Range difference: tensor(0.1111)\n",
      "Range: (tensor(0.0322), tensor(-0.0647))\n",
      "Range difference: tensor(0.0969)\n",
      "Range: (tensor(0.0066), tensor(-0.0446))\n",
      "Range difference: tensor(0.0512)\n",
      "Range: (tensor(0.0145), tensor(-0.0823))\n",
      "Range difference: tensor(0.0968)\n",
      "Range: (tensor(-0.0023), tensor(-0.1099))\n",
      "Range difference: tensor(0.1076)\n",
      "Range: (tensor(0.0589), tensor(-0.0803))\n",
      "Range difference: tensor(0.1392)\n",
      "Range: (tensor(0.0027), tensor(-0.0637))\n",
      "Range difference: tensor(0.0664)\n",
      "Range: (tensor(0.0927), tensor(-0.0754))\n",
      "Range difference: tensor(0.1681)\n",
      "Range: (tensor(0.0649), tensor(-0.0023))\n",
      "Range difference: tensor(0.0672)\n",
      "Range: (tensor(0.0752), tensor(-0.0685))\n",
      "Range difference: tensor(0.1436)\n",
      "Range: (tensor(0.0518), tensor(-0.0801))\n",
      "Range difference: tensor(0.1319)\n",
      "Range: (tensor(0.1381), tensor(-0.1032))\n",
      "Range difference: tensor(0.2412)\n",
      "Range: (tensor(0.0223), tensor(-0.0457))\n",
      "Range difference: tensor(0.0681)\n",
      "Range: (tensor(0.0612), tensor(-0.0694))\n",
      "Range difference: tensor(0.1306)\n",
      "Range: (tensor(0.0256), tensor(-0.0585))\n",
      "Range difference: tensor(0.0841)\n",
      "Range: (tensor(0.0423), tensor(-0.1051))\n",
      "Range difference: tensor(0.1474)\n",
      "Range: (tensor(0.0412), tensor(-0.0455))\n",
      "Range difference: tensor(0.0867)\n",
      "Range: (tensor(0.0256), tensor(-0.0948))\n",
      "Range difference: tensor(0.1205)\n",
      "Range: (tensor(0.0124), tensor(-0.1171))\n",
      "Range difference: tensor(0.1295)\n",
      "Range: (tensor(0.0415), tensor(-0.0294))\n",
      "Range difference: tensor(0.0709)\n",
      "Range: (tensor(0.0489), tensor(-0.0139))\n",
      "Range difference: tensor(0.0628)\n",
      "Range: (tensor(0.0628), tensor(-0.1001))\n",
      "Range difference: tensor(0.1629)\n",
      "Range: (tensor(0.0634), tensor(-0.0359))\n",
      "Range difference: tensor(0.0993)\n",
      "Range: (tensor(0.0713), tensor(-0.0588))\n",
      "Range difference: tensor(0.1301)\n",
      "Range: (tensor(0.0021), tensor(-0.1209))\n",
      "Range difference: tensor(0.1230)\n",
      "Range: (tensor(0.0458), tensor(-0.1114))\n",
      "Range difference: tensor(0.1572)\n"
     ]
    }
   ],
   "source": [
    "max_diff = 0\n",
    "min_diff = 0\n",
    "diff = []\n",
    "for i in range(16):\n",
    "    for j in range(50):\n",
    "        print(\"Range: (\"+str(torch.max(params[:,i,0,j]*10))+\", \"+str(torch.min(params[:,i,0,j]*10))+\")\")\n",
    "        print(\"Range difference: \"+str(torch.max(params[:,i,0,j]*10) - torch.min(params[:,i,0,j]*10)))\n",
    "        diff.append((torch.max(params[:,i,0,j]*10) - torch.min(params[:,i,0,j]*10)).data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "1b2f5406-1e5f-459d-8620-b5d389272cc8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 1.0, 'Target similarity in MoEs')"
      ]
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkcAAAHRCAYAAABglB00AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABF70lEQVR4nO3dd3RU5f7+/WskpBESSughhRaKotLEICSAYEFpIr2LHDwg5uBXBaRLs2BUEAtwiIIiRREBFSPSkabi4YcgNRjCARIwjZJC9vMHT+bsYRJImEnl/Vpr1iL3fe+9P/tOmYvdxmIYhiEAAABIku4q7AIAAACKEsIRAACACeEIAADAhHAEAABgQjgCAAAwIRwBAACYEI4AAABMCEcAAAAmhCMAAAATwhHueBaLJc+vyMjIwi67WJgyZYosFoumTJmS79uKjIyUxWLR4MGDc9WeX6Kjo2WxWBQYGFgg27uZwMBAWSwWRUdH5/u2zL8fn3/+eY7jrl27pqpVq1rHzps3z+FtZ835rV5F4XuC4sGlsAsACtugQYPs2o4dO6YdO3aoSpUqevTRR+3669SpUxCl5Yvo6GgFBQUpICCgQN40cd2dNO+ffPKJ+vbtm23f999/r3PnzuXbtrP7fc7i6+ubb9tFyUI4wh0vu6NAkZGR2rFjh+rXr89RIgeMGjVKvXv3LpA3pW7duqlly5by8fHJ923dTI0aNXTo0CGVLl26UOuQpI0bNyo9PV01atQosG02bdpUP/74o2JjY7PdbtbvU7NmzbRv3z6nb5/fVzgDp9UA5BtfX1/Vr1+/QMKRj4+P6tevr2rVquX7tm6mdOnSql+/vmrXrl2odUhS7dq1Vb9+/QINaoMHD1ZmZqaWLl1q1/f3339r7dq1atq0qRo1alRgNQF5RTgC8iA9PV1LlixRr169VK9ePXl5ecnLy0v33nuvpk2bpkuXLmW7XNY1D4Zh6IMPPlDTpk3l5eWlcuXKWcdkZmZq7ty5uueee+Tu7q4qVaqoT58+OnHixC2v3dm2bZt69OihatWqydXVVVWrVlXPnj21f/9+m3FTpkxRUFCQJOnUqVO3dT3G33//rWnTpqlx48YqX768PD09FRgYqMcee0wLFiyw2152dZvbT548qT59+qhy5coqU6aMWrZsqR9++ME69ptvvtFDDz0kb29vlS9fXr1799aZM2fs6srrtUXJycn66KOP1LlzZ9WuXVseHh7y9vZWixYt9O677yojI8NuGfP1ROnp6Zo5c6YaNWokDw8P3XfffXZjzPt7q3nv0KGDLBaLvvzyyxxr7tKliywWiz799NNc7WNO1xyZ2zds2KCwsDB5e3vLy8tLYWFh2rx5c67Wn50ePXqoTJky+uSTT+z6li1bptTU1Jue+pKu/y5ERkaqdevWKleunNzd3RUcHKyXXnpJ8fHxt13bjTIyMrR48WK1atVKVatWlZubm6pXr66QkBBNmDBBV69eddq2ULwQjoA8OHfunAYOHKiffvpJVatW1RNPPKGQkBCdOnVKkydPVmhoqK5cuZLj8iNHjtTo0aPl4+OjJ5980uZ/z4MHD9bo0aN15MgRhYWFqW3bttqxY4eaNWumkydP5rjOGTNmqE2bNlq9erX8/f3VtWtX+fn5aeXKlXrggQe0du1a69j77rtPTz31lCSpTJkyGjRokPXVo0ePW+7/pUuX1LJlS02ePFkXLlxQaGioOnfurJo1a2rXrl2aM2dObqbR6uTJk2revLn27dundu3aqVGjRtq9e7c6deqkLVu26L333lP37t2VkZGhDh06yNPTU8uXL9fDDz+s1NTUPG3rRr///rtGjBihX375RQEBAeratauaN2+uAwcOKDw8XN26dZNhGNkum5mZqW7duum1115TzZo11blzZ2v4yU5u5n3UqFGSpA8++CDbdZw+fVrr169XhQoV1LNnT0d23WrBggV67LHHlJiYqEceeUQBAQHa
smWLOnbsqO3bt9/WOr28vPTUU0/p0KFDdqfNPvnkE5UuXTrH65EkyTAM9enTR0OGDNHevXv14IMPqkuXLrp06ZLeeustNWnSRMeOHbut2m40aNAgDR06VP/5z3+s36OGDRvq9OnTmjFjhhISEpyyHRRDBgA7ixcvNiQZoaGhNu1JSUnGunXrjPT0dJv2hIQE4/HHHzckGbNmzbJbnyRDklG+fHnj119/tetfsWKFIcmoXLmy8ccff1jbU1NTjX79+lmXnzx5ss1ya9asMSQZ/v7+xi+//GLT98033xguLi5GuXLljIsXL1rbT548aUgyAgICcjkb/xMZGWlIMjp37mxkZGTY9F29etXYsmWLTdvkyZOzrTurXZLx4osvGteuXbP2jR8/3pBk1KtXzyhXrpyxceNGa19CQoLRsGFDQ5IRGRlps86s79mgQYNy1R4TE2Ns3rzZyMzMtGk/e/as0aRJE0OSsWzZMpu+rLmTZAQGBhonT560m6Oc5vdW856RkWEEBAQYFovF+PPPP+36J02aZJ2v3AoICDAk2dWZ1e7u7m6sXbvW2p6ZmWmMGDHCkGS0a9cu19sxjP/9jCcnJxsbN240JBmjRo2y9v/xxx+GJKNbt26GYRjGoEGDDEnG3LlzbdYzd+5cQ5JRs2ZN4+jRo9b2q1evGn379jUkGS1atLBZxvx9ya3o6Gjr9yM+Pt6uf8eOHcalS5dyvT6ULBw5AvKgbNmy6tSpk1xcbO9l8PHx0TvvvCNJNz0t8vLLL+v++++3a3///fclSePHj1eDBg2s7a6urnrvvffk5eWV7fqmTp0q6foppSZNmtj0PfnkkxoxYoQSEhKyvf7jdpw/f16SFBYWplKlStn0ubm5qU2bNnlaX1BQkGbOnKm77vrfn6KXXnpJFotFR44c0ahRo9SuXTtrn4+Pj4YPHy5JDp36kSQ/Pz+FhobKYrHYtFepUkWzZ8+WdPPv5axZs5x6a3ipUqU0YsQIGYahjz76yKbv2rVrWrRokSwWi/7xj384bZsvvPCCnnjiCevXFotF06ZNkyRt375d6enpt7Xetm3bKiAgQMuWLVNaWpokWU+z3eqUWkREhKTr82u+K9TNzU3vv/++fHx8tGfPnhyPbN3sVv7w8HDruKyf5XvvvVcVK1a0W09ISIg8PT1zv9MoUbhbDbgNe/fu1aZNm3Tq1CldvnxZhmFYT8EcOXIkx+W6detm15aRkaFdu3ZJUranSypUqKAOHTpo9erVNu1xcXH69ddf5evrq7CwsGy3Fxoaqnnz5mnXrl16/vnnc7t7OWrWrJkk6fXXX1fFihX1xBNPqEKFCre9vrCwMLm6utq0lStXThUrVlR8fLw6dOhgt0zWG2Z21x3llWEY2rp1q7Zt26YzZ87oypUrMgxDycnJkm7+vezatavD27/RsGHDNGXKFEVGRmrGjBlyd3eXdP26q9jYWHXo0EF169Z12vYee+wxu7ZKlSqpQoUKunjxouLj42/rAneLxaIBAwZo+vTpWr9+vbp06aKlS5fK19dXjz/+eI7LnT59WidOnJCrq6t69+5t11+uXDl169ZNkZGR2rJlix566CG7MTcLXy1atLD+u379+vLy8tL69es1Y8YM9e3b96anRnFnIRwBeZCSkqLevXtr/fr1OY5JSkrKsc/f39+uLT4+XqmpqdYLqbMTEBBg15Z1kW18fLzNkZfsZP0v2VFt27bVuHHj9MYbb2jQoEGyWCxq0KCBwsLC1KtXrzwfOfLz88u2vUyZMoqPj8+2v0yZMpLk8DVH586dU5cuXbR79+4cx+T0vaxcubI1uDiTr6+vevXqpU8//VTLly+3vtF/+OGHkqTnnnvOqdurWbNmtu1ly5bVxYsXHZrjwYMHa/r06fr000/l6emp2NhYjR49+qZ3zsXGxkq6/nty45HJLFl3AWaNvVFub+UvW7asIiMjNWzYME2YMEETJkxQjRo11KZNG3Xp0kVPPfWU3RFi3Dn4
zgN5MG7cOK1fv16NGjXS66+/rmbNmqlChQoqXbq00tLS5ObmdtPlPTw8cuy78fSOmZHNhcGZmZmSrh9ZevLJJ2+63fr169+0Py9mzpyp4cOHa+3atfrpp5+0fft2zZ8/X/Pnz9fAgQOzvUspJ7cKdbfqd8SwYcO0e/dutW7dWlOnTlXjxo3l4+MjFxcXHTlyRMHBwTlekH2z76OjRo4cqU8//VQffvihBg0apOPHjysqKko1atS45fc5r/JzfmvXrq1WrVpp/fr11rs4b3VKLaf5zuuY3HrqqafUvn17rV+/XlFRUdq2bZuWLVumZcuW6Z577tG2bdsK/blZKByEIyAPVq5cKUn64osvdPfdd9v03e4dNBUrVpSbm5tSU1N19uzZbE9j/PXXX3ZtWf/r9/T0LPAH3wUGBur555/X888/L8MwFBUVpT59+ujTTz9V37599cgjjxRoPXl16dIlfffddypVqpTWrl1r9wborLuhbkeLFi3UvHlz7dq1S/v379fnn38uwzD07LPPFrsjGYMHD9aOHTsUFRWle+65x+66uBtlHSn866+/dO3atWyPHmXduemsB1uWK1dO/fr1U79+/SRJf/zxhwYNGqR9+/Zp9uzZmjVrllO2g+KFC7KBPLh48aKk7E9HLFu27LbWWbp0abVs2VKStGLFimy3GRUVZddevXp1NW7cWKdPn9aePXtyvb2sa3yye47P7bBYLOrYsaP1VvX//Oc/TllvfkpMTNS1a9dUtmzZbI8M3O738mbyMu8jR46UJL377ruKjIyUi4uLnn32WafXlN969uyp6tWrq2LFiho2bNgtx/v5+SkoKEhpaWn64osv7PoTExOt196FhoY6vV5Jatiwof71r39JKh4/y8gfhCMgD7JOT82fP9+m/ccff8zzM37Mst4MZ86cqcOHD1vb09PT9cILLyglJSXb5bLuLOrdu7e2bNli13/58mUtW7ZMhw4dsrZVqlRJrq6uOnfunP7+++881bl69Wpt377d7tRGYmKitm3bJin766qKmipVqqhcuXJKSEiw+5DUpUuX6rPPPnP6NvMy71kfuRIZGam4uDh16dJF1atXd3pN+c3b21uxsbGKj4/X6NGjc7VMVjAZN26cjh8/bm1PS0vTqFGjlJCQoBYtWmR7MXZe/Pbbb1qxYoXdgx4Nw7BeU1gcfpaRP4rXMVqgkE2YMEG9evXS+PHjtXLlSgUHB+vUqVP6+eefNXbsWOst4Hn19NNPa8CAAVqyZInuu+8+tW3bVj4+Ptq5c6dSUlKsfTfe2dWlSxfNmTNHL730ksLCwtSwYUPVq1dPkhQTE6PDhw9bTyFlPSKgdOnS6tSpk1avXq37779frVq1koeHh3x9fW9Z/5YtW/Tuu++qcuXKatKkiSpWrKi///5b27dvV1JSklq1aqXu3bvf1hwUpFKlSmn8+PF6+eWX1a9fP73//vsKCAjQoUOHtH//foe+lznJy7y7ubnpmWee0euvvy7J+RdiF2UjR47Utm3btHLlSt19991q27atvL29tWPHDp0+fVp+fn43Da+3ekL6/Pnz5enpqVOnTqlXr14qU6aMmjZtqho1aujq1avat2+fYmJiVKVKFb388stO3jsUF4QjIA969uypChUqaNq0aTpw4ICOHj2qRo0a6ZNPPtHAgQMdekONjIxUs2bN9PHHH2vTpk3y9vZW+/btNWPGDM2cOVNS9p8qPmbMGLVt21bvvfeeNm/erO+++05ubm6qVq2aHn/8cXXp0kWtW7e2WWbBggWqUKGCNmzYoBUrVigjI0MBAQG3rH/w4MFyc3PTtm3b9Pvvv+vChQuqUKGC7rnnHg0cOFCDBg0qEh+4mhsvvfSS/P39NWfOHP3nP//RwYMHdf/992v9+vVq2LCh08ORlLd579Chg15//XXVq1fP5llPJd1dd92lL774Qo899pgWLVqk7du3KzU1Vf7+/nrxxRf1yiuvqFKlSjkuf6sbAt555x15
enqqZcuWmjlzprZs2aLDhw9rz5498vDwkL+/vwYPHqxRo0apcuXKzt49FBMWw5mX/gNwuoyMDN199936888/tXfvXuuzhlCyPfvss1q4cKHefvtt66kmAAWDcAQUEQcOHFDdunVtnp9z5coVvfzyy5o3b54aNWqk//f//l8hVoiCcvz4cTVq1EilS5dWTEyMzQcUA8h/nFYDioipU6dqw4YNatKkiapVq6YLFy7o999/V1xcnLy9vbV48eLCLhH5bOzYsYqJidEPP/yg1NRUTZw4kWAEFAKOHAFFxNdff61FixZp//79unDhggzDUI0aNfTwww/r5ZdfVq1atQq7ROSzwMBA/fXXX6pRo4aGDBmiKVOm5OuDGgFkj3AEAABgwn9JAAAATAhHAAAAJlyQfRsyMzN15swZlS1b9qYfFgoAAIoOwzCUnJys6tWr3/R6PsLRbThz5ky2n60FAACKvpiYGOsHHWeHcHQbypYtK+n65Hp7exdyNQAAIDeSkpJUs2ZN6/t4TghHtyHrVJq3tzfhCACAYuZWl8RwQTYAAIAJ4QgAAMCEcAQAAGBCOAIAADAhHAEAAJgQjgAAAEwIRwAAACaEIwAAABPCEQAAgAnhCAAAwIRwBAAAYEI4AgAAMCEcAQAAmBCOAAAATAhHAAAAJi6FXQCKv8Cx6wu7hNsSPbtTYZcAACiCOHIEAABgQjgCAAAwIRwBAACYEI4AAABMCEcAAAAmhCMAAAATwhEAAIAJ4QgAAMCEcAQAAGBCOAIAADAhHAEAAJgQjgAAAEwIRwAAACaEIwAAABPCEQAAgAnhCAAAwIRwBAAAYEI4AgAAMCEcAQAAmBCOAAAATAhHAAAAJoQjAAAAE8IRAACACeEIAADAhHAEAABgQjgCAAAwIRwBAACYEI4AAABMCEcAAAAmhCMAAAATwhEAAIAJ4QgAAMCEcAQAAGDiUtgFwFbg2PWFXQIAAHc0jhwBAACYEI4AAABMCEcAAAAmhCMAAAATwhEAAIAJ4QgAAMCEcAQAAGBCOAIAADAhHAEAAJgQjgAAAEyKVDg6cuSIJk2apJYtW6pSpUoqW7as7rvvPs2YMUOXLl2yG3/u3DkNHTpUVapUkbu7uxo3bqwFCxbkuP5ly5apadOm8vDwkK+vr/r06aNTp07l5y4BAIBipkiFo3//+996++23FRQUpIkTJ+rNN99UcHCwJkyYoJCQEF25csU6NiEhQQ899JC++OILPfPMM5o7d678/f01fPhwTZ061W7d8+bNU9++feXh4aGIiAiFh4crKipKISEhOnPmTEHuJgAAKMIshmEYhV1Eln379qlOnToqV66cTfuECRM0Y8YMzZs3TyNHjpQkjRs3TrNnz9aXX36p7t27W8d27txZ33//vf78808FBQVJki5cuKDAwEDVq1dPu3fvlouLi3V7LVq00NChQ7Vw4cJc15mUlCQfHx8lJibK29vbwb22xQfPFpzo2Z0KuwQAQAHK7ft3kTpy1KxZM7tgJEk9e/aUJB04cMDa9tlnnykoKMgmGEnSmDFjlJ6eruXLl1vb1qxZo5SUFI0ePdoajLK216ZNG61YsUJpaWlO3hsAAFAcFalwlJPY2FhJUuXKlSVJZ8+eVUxMjB588EG7sQ8++KAsFov27Nljbcv6d0hIiN34kJAQJScn6/Dhw/lROgAAKGaKfDi6du2apk2bJhcXF/Xr10/S/8KSn5+f3Xg3Nzf5+vrq9OnT1rabjc9qM4+/UWpqqpKSkmxeAACgZCry4Wj06NHatWuXpkyZouDgYEnS5cuXJV0PQtlxd3e3jrnVeHd3d5sx2Zk1a5Z8fHysr5o1a97ezgAAgCKvSIejCRMmaP78+Ro2bJjGjx9vbff09JR0/YhOdq5cuWIdc6vxWXfAmcffaNy4cUpMTLS+YmJi8r4zAACgWCiy4WjKlCmaMWOGBg4cqI8++kgWi8XaV6NGDUnZnwq7evWqLly4YHMK7Wbjb3bKLYub
m5u8vb1tXgAAoGQqkuFo6tSpmjp1qvr376/Fixfrrrtsy6xatar8/Pz0888/2y27a9cuGYah5s2bW9uy/r1z50678Tt37pSXl5fq16/v5L0AAADFUZELR9OmTdOUKVPUr18/RUZG2gWjLH379tXJkyf11Vdf2bS//fbbcnFxUa9evaxtXbp0kaenp9577z1lZGRY2/ft26etW7eqZ8+ecnV1zZ8dAgAAxYrLrYcUnPfff1+TJ0+Wv7+/OnTooGXLltn0V6lSRR06dJAkjR07VqtWrdKAAQP0yy+/KCgoSGvWrNG6des0ceJE1apVy7qcr6+vZs6cqfDwcIWFhWnAgAGKj49XRESEqlSpomnTphXofgIAgKKrSIWjvXv3SpL++usvDR482K4/NDTUGo7Kly+v7du3a/z48VqwYIGSkpJUp04dffDBBxoxYoTdsi+88IJ8fX01Z84chYeHy9PTUx06dNCsWbOs1yQBAAAUqY8PKS74+JCSgY8PAYA7S7H8+BAAAIDCRjgCAAAwIRwBAACYEI4AAABMCEcAAAAmhCMAAAATwhEAAIAJ4QgAAMCEcAQAAGBCOAIAADAhHAEAAJgQjgAAAEwIRwAAACaEIwAAABPCEQAAgAnhCAAAwIRwBAAAYEI4AgAAMCEcAQAAmBCOAAAATAhHAAAAJoQjAAAAE8IRAACACeEIAADAhHAEAABgQjgCAAAwIRwBAACYEI4AAABMCEcAAAAmhCMAAAATwhEAAICJS2EXABSWwLHrC7uEPIue3amwSwCAEo8jRwAAACaEIwAAABPCEQAAgAnhCAAAwIRwBAAAYEI4AgAAMCEcAQAAmBCOAAAATAhHAAAAJoQjAAAAE8IRAACACeEIAADAhHAEAABgQjgCAAAwIRwBAACYEI4AAABMCEcAAAAmhCMAAAATwhEAAIAJ4QgAAMCEcAQAAGBCOAIAADAhHAEAAJgQjgAAAEwIRwAAACaEIwAAABPCEQAAgAnhCAAAwIRwBAAAYFLkwtGsWbP09NNPq1atWrJYLAoMDMxx7JQpU2SxWLJ9hYeHZ7vMsmXL1LRpU3l4eMjX11d9+vTRqVOn8mdnAABAseNS2AXcaPz48apQoYKaNGmihISEXC0TEREhX19fm7YGDRrYjZs3b56ef/55tWrVShEREYqPj9c777yjrVu3au/evapevbozdgEAABRjRS4cHT9+XLVq1ZIk3X333UpJSbnlMl27dr3pESZJunDhgsaNG6cmTZpo8+bNcnG5vuuPPvqoWrRooUmTJmnhwoUO1w8AAIq3IndaLSsY5VVycrLS09Nz7F+zZo1SUlI0evRoazCSpGbNmqlNmzZasWKF0tLSbmvbAACg5HAoHG3cuNFZdTjk3nvvlbe3t9zd3dWsWTMtX77cbsyePXskSSEhIXZ9ISEhSk5O1uHDh/O9VgAAULQ5FI46dOigOnXqaObMmTpz5oyzasq1cuXKadiwYXr33Xf1zTffaM6cOUpISFDv3r01ffp0m7GxsbGSJD8/P7v1ZLWdPn062+2kpqYqKSnJ5gUAAEomh8LRW2+9JQ8PD02YMEEBAQHq3LmzvvnmG2VmZjqrvpsKDw/XggULNHjwYD355JMKDw/XgQMH1KhRI02dOtXmLrTLly9Lktzc3OzW4+7ubjPmRrNmzZKPj4/1VbNmzXzYGwAAUBQ4FI7GjBmjAwcO6Oeff9aQIUO0ZcsWdevWTX5+fho/fryOHTvmrDpzzcPDQy+99JIyMjL0ww8/WNs9PT0lXT8KdKMrV67YjLnRuHHjlJiYaH3FxMTkQ+UAAKAocMoF2Q888IA+/vhj/fe//9WiRYtUu3ZtzZ49W8HBwWrXrp0+//zzbENJfsm6cy0uLs7aVqNGDUnZnzq72Sk36frRJm9vb5sXAAAomZx6t5qnp6cGDx6sL774Qv3795dhGNq8ebP69++vmjVr6o033tC1a9ecuclsHT16VJJUtWpVa1vz5s0lSTt3
7rQbv3PnTnl5eal+/fr5XhsAACjanBaOrl27pq+//lpPPPGEAgMDtXTpUoWFhenzzz/XypUr1aBBA40bN04vvPCCU7aXkZGhCxcu2LUnJCRo1qxZcnV11SOPPGJt79Klizw9PfXee+8pIyPD2r5v3z5t3bpVPXv2lKurq1NqAwAAxZfDD4E8cuSIFi5cqE8//VRxcXGqWLGiwsPDNXz4cNWtW9c67qmnntI///lPLVu2TPPmzctxfUuWLLFeSB0XF6e0tDTrnWflypXTqFGjJEkpKSny8/NT9+7ddc8996hixYo6ceKE/v3vf+v8+fN65513rKfSJMnX11czZ85UeHi4wsLCNGDAAMXHxysiIkJVqlTRtGnTHJ0KAABQAlgMwzBud+HWrVtr586dMgxDYWFh+sc//qHu3burdOnS2Y5ftmyZ+vXrd9O72cLCwrRly5Zs+wICAhQdHS3p+oXVI0eO1J49exQTE6OUlBSVL19eDzzwgMLDw9W+ffts1/HZZ59pzpw5OnTokDw9PdWhQwfNmjVLQUFBud7vpKQk+fj4KDEx0enXHwWOXe/U9aFkiZ7dqbBLAIBiK7fv3w6FoypVqmjQoEEaPny46tSpc8vxcXFx+uOPPxQaGnq7mywSCEcoLIQjALh9uX3/dui02unTp3M8SpSdSpUqFftgBAAASjaHLsg+ffq01q5dm2P/2rVrrafBAAAAigOHjhy9+uqriomJ0ZNPPplt/5w5c1SzZk0tWbLEkc0AAAAUGIeOHG3fvt3mdvkbdezYUdu2bXNkEwAAAAXKoXB0/vx5mwct3qhy5co6d+6cI5sAAAAoUA6Fo3Llyun48eM59h87dkxly5Z1ZBMAAAAFyqFw1Lp1ay1cuFDnz5+36zt79qwWLlyohx56yJFNAAAAFCiHL8heu3at7r33Xo0ZM0aNGzeWJO3fv18RERFKSUnR+PHjnVIoAABAQXAoHN13331atWqVhgwZoldeeUUWi0WSZBiGfH19tXLlSjVr1swphQIAABQEhz9b7YknntBff/2lDRs26OjRozIMQ8HBwerYsaM8PDycUSMAAECBcTgcSZKHh4e6du3qjFUBAAAUKocuyAYAAChpHA5HX3zxhVq1aqXKlSurVKlSdi8XF6ccnAIAACgQDiWXN998U2PHjlXFihXVsmVLVaxY0Vl1AQAAFAqHwtH777+vBx54QBs3buTiawAAUCI4dFrt7Nmz6t+/P8EIAACUGA6Fo9q1aysxMdFZtQAAABQ6h8LRv/71Ly1cuFDJycnOqgcAAKBQOXTNkaurqypVqqQGDRpo6NChCgoKUqlSpezGDRw40JHNAAAAFBiLYRjG7S581123PvBksVh07dq1291EkZSUlCQfHx8lJibK29vbqesOHLveqetDyRI9u1NhlwAAxVZu378dOnK0adMmRxYHAAAochwKR6Ghoc6qAwAAoEhw2seHpKamKjY2Vmlpac5aJQAAQIFzOBz9+uuvateuncqWLSt/f39t375dknT+/Hm1b99eP/74o8NFAgAAFBSHwtH+/fvVunVrHT9+3O6OtMqVK+vKlSv65JNPHCoQAACgIDkUjiZNmqQaNWro4MGDmj17tm688a19+/bas2ePQwUCAAAUJIfC0bZt2zRs2DB5eXnJYrHY9fv7++vMmTOObAIAAKBAORSOrl69Kh8fnxz7k5KSHFk9AABAgXP4s9V++eWXHPs3btyohg0bOrIJAACAAuVQOOrbt6+WLFmiqKgoa1vW6bU33nhDGzZs0IABAxyrEAAAoAA59BDI//u//1NUVJQeffRR1a1bVxaLRaNHj1ZcXJzi4uLUoUMH/fOf/3RWrQAAAPnOoSNHrq6uioqK0ptvvikvLy+5u7vr+PHjqlq1qt544w2tW7cuV5+/BgAAUFQ4dORIklxcXDRmzBiNGTPGGfUAAAAUKg7rAAAAmDh05OjTTz/N1bgbn54NAABQVDkUjgYPHiyLxWL3ZOwb
HwhJOAIAAMWFQ+Fo06ZNdm0ZGRk6fvy45s+fL09PT82YMcORTQAAABQoh8JRaGhotu3t27fXoEGD1KJFC/36669q27atI5sBAAAoMPl2Qbabm5v69++v+fPn59cmAAAAnC5f71Zzc3NTbGxsfm4CAADAqfItHP33v//Vhx9+qKCgoPzaBAAAgNM5dM1Ru3btsm2/ePGiDh8+rLS0NH3yySeObAIAAKBAORSOTpw4YXfbvsViUYUKFdS9e3eNGjVKISEhDhUIAABQkBwKR9HR0U4qAwAAoGjg40MAAABMCEcAAAAmDp1Wu+uuu+yuOboVi8WijIwMRzYLAACQbxwKRwMHDtRvv/2mAwcOqF69emrQoIEk6dChQzpy5IgaN26s+++/3ymFAgAAFASHwtGAAQO0atUqrVq1St27d7fpW7VqlQYPHqy33347x1v+AQAAihqHrjmaOHGinn32WbtgJEk9evTQM888owkTJjiyCQAAgALlUDj6/fffFRwcnGN/gwYN9PvvvzuyCQAAgALlUDjy8vLS9u3bc+zfunWrvLy8HNkEAABAgXIoHHXv3l3Lli3Tq6++qoSEBGt7QkKCxo8fr+XLl+upp55ytEYAAIACYzEMw7jdhRMTE9WxY0ft3btXd911l6pUqSKLxaKzZ88qMzNTLVq00A8//CBvb29n1lzokpKS5OPjo8TERKfvW+DY9U5dH0qW6NmdCrsEACi2cvv+7dDdaj4+PtqxY4f+/e9/65tvvtHx48clSffdd5+6du2qwYMHy8XFoU0AAAAUKIeTi4uLi4YPH67hw4c7ox4AAIBC5bSPD0lNTVVsbKzS0tKctUoAAIAC53A4+vXXX9WuXTuVLVtW/v7+1rvXzp8/r/bt2+vHH390uEgAAICC4lA42r9/v1q3bq3jx49r4MCBNn2VK1fWlStX9MknnzhUIAAAQEFyKBxNmjRJNWrU0MGDBzV79mzdeONb+/bttWfPHocKBAAAKEgOhaNt27Zp2LBh8vLyksVisev39/fXmTNnHNkEAABAgXIoHF29elU+Pj459iclJTmyegAAgALnUDiqXbu2fvnllxz7N27cqIYNG+ZpnbNmzdLTTz+tWrVqyWKxKDAw8Kbjz507p6FDh6pKlSpyd3dX48aNtWDBghzHL1u2TE2bNpWHh4d8fX3Vp08fnTp1Kk81AgCAksuhcNS3b18tWbJEUVFR1ras02tvvPGGNmzYoAEDBuRpnePHj9dPP/2k2rVrq3z58jcdm5CQoIceekhffPGFnnnmGc2dO1f+/v4aPny4pk6dajd+3rx56tu3rzw8PBQREaHw8HBFRUUpJCSE038AAECSgx8fkpaWpkceeURbt25V3bp1dfToUTVs2FBxcXGKi4tThw4d9O233+quu3KfwU6cOKFatWpJku6++26lpKQoOjo627Hjxo3T7Nmz9eWXX6p79+7W9s6dO+v777/Xn3/+qaCgIEnShQsXFBgYqHr16mn37t3WJ3fv27dPLVq00NChQ7Vw4cJc1cjHh6Cw8PEhAHD7cvv+7dCRI1dXV0VFRenNN9+Ul5eX3N3ddfz4cVWtWlVvvPGG1q1bl6dgJMkajHLjs88+U1BQkE0wkqQxY8YoPT1dy5cvt7atWbNGKSkpGj16tM1HmjRr1kxt2rTRihUreIAlAAC4/Y8PSUtL065du1StWjWNGTNGY8aMcWZdt3T27FnFxMSob9++dn0PPvigLBaLzWMEsv4dEhJiNz4kJERbtmzR4cOH1bhxY7v+1NRUpaamWr/mQnMAAEqu2z5yVKpUKbVv317fffedM+vJtdjYWEmSn5+fXZ+bm5t8fX11+vTpXI3PajOPN5s1a5Z8fHysr5o1azpcPwAAKJocCkdVq1a1e/BjQbl8+bKk60EoO+7u7tYxtxrv7u5uM+ZG48aNU2JiovUVExPjUO0AAKDocuiao6efflqrVq0qlIDk6ekpSTanu8yuXLliHXOr8VeuXLEZ
cyM3Nzd5e3vbvAAAQMnkUDgaNmyYUlJS1LFjR61bt06HDx/WX3/9ZffKDzVq1JCU/amwq1ev6sKFCzan0G42/man3AAAwJ3lti/Ilq7fap/lp59+ynHctWvXHNlMtqpWrSo/Pz/9/PPPdn27du2SYRhq3ry5ta158+b66KOPtHPnTtWtW9dm/M6dO+Xl5aX69es7vU4AAFC85Dkcbd26VQ0aNFClSpU0adKkbD9TraD07dtXb7zxhr766iub2/nffvttubi4qFevXta2Ll26aPTo0XrvvffUr18/m+ccbd26VUOGDJGrq2uB7wMAACha8hyO2rZtqyVLlqhv376aMmWKUlJSNHDgQE2bNs3mSNLtWrJkifXjPOLi4pSWlqbp06dLksqVK6dRo0ZZx44dO1arVq3SgAED9MsvvygoKEhr1qzRunXrNHHiRJtnJvn6+mrmzJkKDw9XWFiYBgwYoPj4eEVERKhKlSqaNm2aw7UDAIDiL8/h6MaLr1NTU7VmzRqb0OKIRYsWacuWLTZtEydOlCQFBATYbKd8+fLavn27xo8frwULFigpKUl16tTRBx98oBEjRtit+4UXXpCvr6/mzJmj8PBweXp6qkOHDpo1a5b1miQAAHBnc+iaoyzOvFtt8+bNeRpfrVo1LV68ONfj+/Xrp379+uWxKgAAcKdw6G41AACAkoZwBAAAYHJbp9XWrVtnfV7Q5cuXZbFY9MUXX2jfvn12Yy0Wi1566SXHqgQAACggFiOPFwzddVfeDjZZLJZ8ec5RYUpKSpKPj48SExOd/rTswLHrnbo+lCzRszsVdgkAUGzl9v07z0eONm3a5FBhAAAARVmew1FoaGh+1AEAAFAkcEE2AACACeEIAADAhHAEAABgQjgCAAAwIRwBAACYEI4AAABMCEcAAAAmhCMAAAATwhEAAIAJ4QgAAMCEcAQAAGBCOAIAADAhHAEAAJgQjgAAAEwIRwAAACaEIwAAABPCEQAAgAnhCAAAwIRwBAAAYEI4AgAAMHEp7AIA5F7g2PWFXUKeRc/uVNglAECecOQIAADAhHAEAABgQjgCAAAwIRwBAACYEI4AAABMCEcAAAAmhCMAAAATwhEAAIAJ4QgAAMCEcAQAAGBCOAIAADAhHAEAAJgQjgAAAEwIRwAAACaEIwAAABPCEQAAgAnhCAAAwIRwBAAAYEI4AgAAMCEcAQAAmBCOAAAATAhHAAAAJoQjAAAAE8IRAACACeEIAADAhHAEAABgQjgCAAAwIRwBAACYEI4AAABMCEcAAAAmhCMAAAATwhEAAIAJ4QgAAMCEcAQAAGBSrMORxWLJ8ZWQkGAz9ty5cxo6dKiqVKkid3d3NW7cWAsWLCicwgEAQJHlUtgFOKp169YaPny4XXuZMmWs/05ISNBDDz2k2NhYhYeHKygoSGvWrNHw4cN15swZTZ48uSBLBgAARVixD0e1atVS//79bzrm9ddf17Fjx/Tll1+qe/fukqRnn31WnTt31owZMzRw4EAFBQUVRLkAAKCIK9an1bKkpaUpOTk5x/7PPvtMQUFB1mCUZcyYMUpPT9fy5cvzu0QAAFBMFPtwtGrVKnl6esrb21sVK1bUsGHDdPbsWWv/2bNnFRMTowcffNBu2QcffFAWi0V79uwpyJIBAEARVqxPqzVv3lw9evRQ3bp1dfnyZW3atEmLFy/WDz/8oN27d6tatWqKjY2VJPn5+dkt7+bmJl9fX50+ffqm20lNTVVqaqr166SkJOfuCAAAKDKKdTi68YhPv379FBoaqoEDB2ry5Mn6+OOPdfnyZUnXg1B23N3drWNyMmvWLE2dOtU5RQMAgCKt2J9Wu9GAAQMUGBio9evXS5I8PT0lyebIj9mVK1esY3Iybtw4JSYmWl8xMTHOLRoAABQZJS4cSVJgYKDi4uIkSTVq1JCkbE+dXb16VRcuXMj2lJuZm5ubvL29bV4AAKBkKtan1bJjGIaOHTumqlWrSpKqVq0qPz8/
/fzzz3Zjd+3aJcMw1Lx584IuE7hjBI5dX9gl5Fn07E6FXQKAQlRsjxydO3cu2/a5c+fq9OnT6ty5s7Wtb9++OnnypL766iubsW+//bZcXFzUq1evfK0VAAAUH8X2yNGsWbP0448/6oknnlBAQICuXLmizZs3a+3atapbt66mTJliHTt27FitWrVKAwYM0C+//GJ9Qva6des0ceJE1apVq/B2BAAAFCnFNhy1a9dOhw8f1tKlSxUfHy+LxaLatWvr1Vdf1UsvvSQfHx/r2PLly2v79u0aP368FixYoKSkJNWpU0cffPCBRowYUYh7AQAAihqLYRhGYRdR3CQlJcnHx0eJiYlOvzi7OF6fAZQ0XHMElEy5ff8uttccAQAA5AfCEQAAgAnhCAAAwIRwBAAAYEI4AgAAMCEcAQAAmBCOAAAATAhHAAAAJoQjAAAAE8IRAACACeEIAADAhHAEAABgQjgCAAAwIRwBAACYEI4AAABMCEcAAAAmhCMAAAATwhEAAIAJ4QgAAMCEcAQAAGBCOAIAADAhHAEAAJgQjgAAAEwIRwAAACaEIwAAABPCEQAAgAnhCAAAwIRwBAAAYEI4AgAAMCEcAQAAmBCOAAAATAhHAAAAJoQjAAAAE8IRAACACeEIAADAhHAEAABgQjgCAAAwIRwBAACYEI4AAABMCEcAAAAmhCMAAAATl8IuAADguMCx6wu7hDyLnt2psEsAssWRIwAAABOOHAEAkEscobszcOQIAADAhHAEAABgQjgCAAAwIRwBAACYEI4AAABMCEcAAAAmhCMAAAATwhEAAIAJ4QgAAMCEcAQAAGBCOAIAADDhs9UA4AbF8fOzgJwUx5/nwv48OMIRAKBQFMc3bdwZOK0GAABgQjgCAAAwIRwBAACYEI4AAABM7rhwtGzZMjVt2lQeHh7y9fVVnz59dOrUqcIuCwAAFBF3VDiaN2+e+vbtKw8PD0VERCg8PFxRUVEKCQnRmTNnCrs8AABQBNwxt/JfuHBB48aNU5MmTbR582a5uFzf9UcffVQtWrTQpEmTtHDhwkKuEgAAFLY75sjRmjVrlJKSotGjR1uDkSQ1a9ZMbdq00YoVK5SWllaIFQIAgKLgjglHe/bskSSFhITY9YWEhCg5OVmHDx8u6LIAAEARc8ecVouNjZUk+fn52fVltZ0+fVqNGze2609NTVVqaqr168TERElSUlKS0+vMTL3s9HUCAFCc5Mf7q3m9hmHcdNwdE44uX74eOtzc3Oz63N3dbcbcaNasWZo6dapde82aNZ1YIQAAkCSfd/J3/cnJyfLx8cmx/44JR56enpKuHwXy8PCw6bty5YrNmBuNGzdOY8aMsX6dmZmpixcvqmLFirJYLPlUcfGWlJSkmjVrKiYmRt7e3oVdzh2DeS8czHvBY84LR3Gfd8MwlJycrOrVq9903B0TjmrUqCHp+qmzunXr2vTd7JSbdP1o041HnMqVK+f8Iksgb2/vYvkLVNwx74WDeS94zHnhKM7zfrMjRlnumAuymzdvLknauXOnXd/OnTvl5eWl+vXrF3RZAACgiLljwlGXLl3k6emp9957TxkZGdb2ffv2aevWrerZs6dcXV0LsUIAAFAU3DGn1Xx9fTVz5kyFh4crLCxMAwYMUHx8vCIiIlSlShVNmzatsEssUdzc3DR58uRsL4BH/mHeCwfzXvCY88Jxp8y7xbjV/WwlzGeffaY5c+bo0KFD8vT0VIcOHTRr1iwFBQUVdmkAAKAIuOPCEQAAwM3cMdccAQAA5AbhCAAAwIRwBAAAYEI4Qp4sW7ZMTZs2lYeHh3x9fdWnTx+dOnUqV8uuWLFCQ4YMUePGjeXi4iKLxaLo6Oj8LbiEuN15//vvv/Xuu++qY8eOqlmzpjw8PBQcHKzhw4crJiamACovvm53ztPT0zVixAg1bdpUvr6+cnNzU1BQkHr16qX9+/fnf+HFnCN/Y27Us2dPWSwWnmGXC47Me1hYmCwWS7avr7/+
On8LzydckI1cmzdvnp5//nm1atVK/fv3V3x8vN555x25ublp7969t3wce1hYmHbv3q17771XCQkJ+vPPP3Xy5EkFBgYWzA4UU47M+/fff69OnTqpXbt2at++vXx9fXXw4EF99NFHcnV11c6dO9WwYcMC3JviwZE5v3TpkkJDQ9WqVSsFBQWpbNmy+uuvv7R48WKdPXtW3333ndq3b1+Ae1N8OPo3xmz9+vXq3Lmz3Nzc5O/vr8OHD+dj5cWbM/62Hzx4UBEREdn25fTpE0WaAeRCfHy84eXlZTRp0sRIT0+3tu/du9ewWCzGM888c8t1nDp1yrrsyJEjDUnGyZMn86vkEsHReT958qRx9OhRu/aoqChDktGjRw+n11zcOeNnPTuxsbFGqVKljA4dOjir1BLFmfOenJxs+Pv7G6NGjTICAgKM4ODg/Ci5RHDGvIeGhhoBAQH5WGXB47QacmXNmjVKSUnR6NGj5eLyv2eHNmvWTG3atNGKFSuUlpZ203X4+/vbLItbc3TeAwMDVadOHbv2hx9+WBUqVNCBAwfype7izBk/69mpWrWqPD099ffffzuz3BLDmfM+YcIEpaena8aMGflVbonhzHnPzMxUUlKSMjMz86vcAkM4Qq7s2bNHkhQSEmLXFxISouTkZA5b54P8mvfExEQlJyercuXKDtdY0jhrzq9du6b4+HidO3dO+/btU//+/ZWcnKxOnTo5veaSwFnzvnfvXs2dO1cRERHF9oNRC5Kz5j02NlZeXl7y8fFRmTJl9Pjjj2vfvn1Or7eg8N945EpsbKwkZXvuOKvt9OnTaty4cYHWVdLl17xPnz5d6enpGjRokONFljDOmvNDhw7pnnvusX5dtmxZvfzyy3r11VedWG3J4Yx5z8jI0LPPPquHH35YvXr1yp9CSxhnzHtgYKBCQkJ0zz33yM3NTb/99pvee+89tWrVSt99953atWuXP8XnI8IRcuXy5cuSlO3n6bi7u9uMgfPkx7yvWLFCc+bMUYcOHTRkyBDHiyxhnDXnQUFBioqKUlpamo4dO6bPP/9cly5dUlpamkqXLu3coksAZ8z7nDlz9Oeff+rLL790foEllDPmPTIy0ubr7t27q3///mrSpIlGjBihI0eOOKfYAkQ4Qq54enpKklJTU+Xh4WHTd+XKFZsxcB5nz/u3336rAQMG6P7779fKlSt1112cWb+Rs+a8TJkyevjhh61fDx06VE2aNNHRo0e1YcMGJ1ZcMjg678ePH9fUqVM1fvx41a5dO/8KLWHy6297cHCwevbsqcjISB09elR169Z1vNgCxF9G5EqNGjUkXT+8eqObHZaFY5w5799//726d++u+vXr64cffpCPj4/zCi1B8utn3cvLS927d9cPP/yg48ePO1ZkCeTovL/44osqX768evXqpejoaOsrIyND6enpio6O1rlz5/Kn+GIsP/+2Zz2mJS4u7vaKK0SEI+RK8+bNJUk7d+6069u5c6e8vLx40Fo+cNa8b9iwQd26dVO9evW0ceNGVaxY0em1lhT5+bOe9T/xixcv3n6BJZSj8x4dHa0zZ84oODhYQUFB1ldsbKxOnDihoKAgrrHLRn7+vB89elTS9Ts1i53CfpYAioe4uDjD09Mzx2dhDB061Np25swZ49ChQ8alS5dyXB/POcodZ8z7hg0bDHd3d+Oee+4x4uLiCqz24srROT9//rxx7do1u/X+97//NapVq2Z4eXnd9HfjTuXovP/000/G6tWr7V6VKlUyatSoYaxevdrYuXNnge5TceDovF+8eNFITU21W+/evXuN0qVLG40aNcrfHcgnhCPk2jvvvGNIMlq1amV8+OGHxvTp042KFSsaVatWNU6fPm0dN2jQIEOSsWnTJpvlt2zZYrz22mvGa6+9ZjzwwAOGJOPFF1+0tiUkJBTwHhUPjsz73r17DXd3d8PNzc2IiIgwlixZYveCPUfmPCIiwggICDDCw8ONd9991/jggw+Mf/3r
X0aFChUMi8ViLFq0qBD2qHhw9G9MdngI5K05Mu+rV682KleubIwcOdKIiIgw5s+fbwwfPtxwdXU1PD09i20gJRwhT5YuXWrcf//9hru7u1GhQgWjV69exokTJ2zG5PSHa/LkyYakHF8cRcrZ7c774sWLbzrnHDzO2e3O+b59+4y+ffsatWvXNsqUKWOULl3a8PPzM3r27Gns2LGjgPei+HHkb0x2CEe5c7vz/scffxhPP/20Ubt2bcPLy8soXbq0ERAQYAwdOtQ4cuRIAe+F8/DZagAAACZckA0AAGBCOAIAADAhHAEAAJgQjgAAAEwIRwAAACaEIwAAABPCEQAAgAnhCAAAwIRwBAAAYEI4Au4QmzdvlsVisXl5eXmpadOmevfdd3Xt2rXCLrHARUZG6p133insMgAUMS6FXQCAgtWrVy898cQTMgxDZ86cUWRkpMLDw3Xw4EF9/PHHhV1egYqMjFR0dLTCw8MLuxQARQhHjoA7zH333af+/ftrwIABeuWVV7R7925Vr15dCxcu1Llz55yyjUuXLjllPcVZSkpKnpfJyMhQampqPlQDIC8IR8AdztvbWw8++KAMw9CJEyeUmZmpGTNmqE2bNqpatapcXV3l7++v5557ThcuXLBZNjo6WhaLRVOmTNHy5cvVtGlTeXh4aOTIkZKkw4cP65///KcaNWqksmXLytPTU02bNtWCBQvs6pgyZYosFov++OMPhYeHq1q1aipTpozat2+vP//8U5L01VdfqUmTJvLw8FBAQIA+/PDDbPfpxx9/VMeOHVWuXDm5u7urcePGdmMtFou2bNmiU6dO2ZxqjI6Oto7Zt2+funXrJl9fX7m5uSk4OFgzZsxQRkaGzbrCwsIUGBioEydOqEePHqpQoYLKli1703nP2t+DBw9qzJgx8vPzk5ubm37++WdJ0vLly9W5c2f5+/vLzc1Nvr6+6tq1q/7zn//YrSswMFBhYWE6ePCgHn30UZUtW1Y+Pj7q0aOHzp49azf+4MGDevzxx+Xl5aVy5cqpS5cuOnHihHU9tzOfQEnCaTXgDmcYho4dOyZJ8vX1VVpamt566y09/fTT6tatmzw9PbVnzx4tWrRI27dv1y+//CJXV1ebdXz99deaO3eunnvuOY0YMULe3t6Srl/ntH37dnXt2lX+/v5KSUnRypUrNXz4cMXHx2vcuHF29QwcOFA+Pj4aP3684uPjNWfOHHXs2FHTp0/XK6+8ohEjRmjo0KFatGiRnnvuOTVs2FBt2rSxLv/xxx9rxIgRatmypV599VV5eXkpKipKzz33nI4fP64333xTkrRkyRLNmDFD8fHxioiIsC5fqVIlSdK3336rbt26qU6dOnrxxRdVoUIF/fzzz5o0aZL279+vlStX2tSdkpKi0NBQPfTQQ5oxY4bOnz+fq/nv16+fypQpoxdffFEWi0XVqlWTJL3//vuqVKmSnnvuOVWqVEnHjx/Xxx9/rFatWunXX39V3bp1bdYTGxurdu3aqXv37urWrZt+++03ffzxx0pKStIPP/xgHXf8+HE99NBDSk1N1ciRIxUUFKTNmzerbdu2unz5sl19uZ1PoEQxANwRNm3aZEgyJk6caMTFxRnnz583fv/9d2PYsGGGJKN58+aGYRhGZmamcfnyZbvlFy5caEgyli9fbm07efKkIckoXbq0cfjwYbtlLl26ZNd27do1IzQ01PD29jbS0tKs7ZMnTzYkGV26dDEyMzOt7XPnzjUkGWXLljViYmKs7efPnzfc3NyMXr16WdvOnDljuLm5Gb1797bb7ujRo4277rrLOHbsmLUtNDTUCAgIsBt75coVo3Llykbr1q2N9PR0m763337bkGRs2rTJZj2SjEmTJtmtKydZ+9u2bVsjIyPDrj8lJcWu7Y8//jBcXV2N5557zqY9ICDA7ntjGIbxz3/+05BkHDp0yNrWq1cvQ5KxceNGm7H/+te/DElGaGiotS2v8wmUFJxWA+4wr732mipV
qqTKlSvr3nvv1aJFi/TYY4/p66+/lnT9dJOHh4ck6dq1a0pISFB8fLzatWsnSdq9e7fdOjt16qTg4GC7dk9PT+u/r169qgsXLujixYvq2LGjkpKSdPjwYbtlRo0aJYvFYv26VatWkqQuXbrIz8/P2l6pUiUFBwdbj3pJ0qpVq5SamqohQ4YoPj7e5vXkk08qMzNTGzduvOUcRUVF6fz58xo4cKB1/7Nejz/+uCTZHI3JMmbMmFuu+0YvvPCCSpUqZddepkwZSdeP7CUlJSk+Pt66z9l9D6pXr66ePXvatGV9z7Lm6Nq1a1q3bp2aNGli7cvyyiuv2K3TWfMJFDecVgPuMM8884x69+4ti8UiT09P1atXTxUrVrQZs2LFCs2ZM0e//fab0tPTbfr+/vtvu3XeeIonS0pKiqZMmaIVK1YoJibGrj+7dQUFBdl8Xb58eUnXr6u5Ufny5XXq1Cnr14cOHZIkPfLII9nWIylXF51nrefZZ5/Vs88+m6v1VKpUST4+Prdc941ymrtff/1VkyZN0ubNm+0ucL9xjiSpVq1adm1Z39esa8Xi4uJ06dKlbINslSpVVK5cOZs2Z80nUNwQjoA7TJ06dfTwww/n2P/ll1+qV69eatGihd59913VrFlT7u7uunbtmh599FFlZmbaLWM+QmTWp08frV+/XsOHD1ebNm1UoUIFubi46Ntvv1VERES268ruKMrN2g3DsPv34sWLbY4ymWUXInJa5+zZs9W0adNsx1SvXt3m65zm4FayW+6vv/5SmzZt5OPjo4kTJyo4OFhlypSRxWJReHh4tnfC5TQ/0v/2xzxXNxt349eOzidQ3BCOANhYunSp3N3dtWnTJps37uxOgd1MQkKC1q9frwEDBtjd2fTjjz86pdYb1atXT9L1IyY3C4BZzKfvsluPp6dnrtbjbKtXr9alS5e0du1atW3b1qbvwoULcnNzu631Vq5cWWXKlMn2e3nu3DklJibatOV1PoGSgmuOANgoVaqULBaLzVEdwzA0ffr0PK8na1mz//73v1q4cKHjhWbj6aeflpubm6ZMmZLtnVeJiYk2zxHy8vJSQkKCXY2PPPKIKleurDfeeEPx8fF267ly5YqSk5OdvwP/v5zmbsGCBdnemp+X9Xbq1Em//fabfvrpJ5u+119/3W58XucTKCk4cgTARo8ePfTll1+qXbt2GjhwoNLT0/X1119n++Z4M2XLllXHjh21dOlSeXh4qHnz5jp16pQ++ugjBQUF2T0zyRn8/Pz0wQcfaNiwYWrQoIEGDhyogIAAxcXF6cCBA/r666/1xx9/WK9feuCBB7Ru3TqNHj1aLVu2VKlSpfTkk0+qTJky+vTTT9W1a1fVr19fQ4cOVd26dZWQkKDDhw/rq6++0urVq7N9JpAzPPbYY/L09NSAAQM0atQolS9fXjt27NC3336r2rVr2z1nKS+mT5+uDRs26IknnrC5lX/Pnj3y9fW1OZqW1/kESgrCEQAbvXv3VnJysiIiIvR///d/Kl++vJ588knNnj3b7sLtW1m6dKnGjh2rtWvX6pNPPlHdunU1Y8YMlS5dWkOGDMmX+ocMGaJ69erprbfe0kcffaSEhAT5+voqODhYr732mqpWrWodGx4erqNHj2rZsmV6//33ZRiGTp48qTJlyuiRRx7R3r17NXv2bH322WeKi4tT+fLlVbt2bY0ZM0aNGzfOl/olqXbt2vruu+80fvx4zZw5U6VKlVKrVq20ZcsWjRo1yuZBlXlVt25dbdu2TS+99JLmz5+v0qVLq127dtqyZYv1AZtmeZlPoKSwGLe6Qg8AUOJlPSrgH//4B0+/xh2Pa44A4A5z5coVu7aZM2dKkjp27FjQ5QBFDkeOAOAOExwcrPbt2+vuu+/W1atXFRUVpe+//15t2rTRTz/9dNPHAgB3AsIRANxhXn75Za1du1anT59WWlqa/P391aNHD02YMMH6ZG7gTkY4AgAAMOGaIwAAABPCEQAA
gAnhCAAAwIRwBAAAYEI4AgAAMCEcAQAAmBCOAAAATAhHAAAAJoQjAAAAk/8PISqlW6fTi88AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Draw a 10-bin histogram of `diff` on the current Axes and label it.\n",
    "hist_ax = plt.gca()\n",
    "hist_ax.hist(diff, bins=10)\n",
    "hist_ax.set_xlabel(\"Parameter range\")\n",
    "hist_ax.set_ylabel(\"Frequency\")\n",
    "hist_ax.set_title(\"Target similarity in MoEs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "id": "e9e71ffb-1fd8-4e18-bc79-69b2858a9b7c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['source', 'tgts_nz', 'status', 'stats'])"
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "range = []\n",
    "\n",
    "for i in range(800):\n",
    "    if(ssa_approx['status']):\n",
    "        ssa = dict_keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "id": "7068e834-7491-4c02-a165-30df8893df51",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "799"
      ]
     },
     "execution_count": 117,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "id": "f5e10d20-1e58-40d0-b454-47fc39e87789",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.01134816,  0.02860608, -0.24385521,  0.05569104,  0.02034841,\n",
       "       -0.01062901,  0.01007963], dtype=float32)"
      ]
     },
     "execution_count": 135,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the first entry stored under the 'tgts_nz' key of `len_stats_w`.\n",
    "tgts_nz_entries = len_stats_w['tgts_nz']\n",
    "tgts_nz_entries[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "id": "70d6588e-3bcb-4c80-bb9c-6c39d1df96c1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.09507355,  0.02956831,  0.00601606, -0.52719149,  0.27250724,\n",
       "       -0.01057738,  0.84489429, -0.03861785,  0.3157704 , -0.00259541,\n",
       "        0.08832736,  0.83140138,  0.24810372, -0.03533605, -0.04548484])"
      ]
     },
     "execution_count": 145,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Flatten both operands to 1-D, then multiply them (left operand evaluated first).\n",
    "first_source_flat = len_stats_w['first_source'].reshape(-1)\n",
    "source_00_flat = len_stats_w['source'][0][0].reshape(-1)\n",
    "first_source_flat * source_00_flat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "id": "ed742a7b-c507-47d5-9903-02a90c0a8a00",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['0best', '0overlap', '0extra', '0total', '0subsets', '0errors', '1best', '1overlap', '1extra', '1total', '1subsets', '1errors', '2best', '2overlap', '2extra', '2total', '2subsets', '2errors', '3best', '3overlap', '3extra', '3total', '3subsets', '3errors', '4best', '4overlap', '4extra', '4total', '4subsets', '4errors', '5best', '5overlap', '5extra', '5total', '5subsets', '5errors'])"
      ]
     },
     "execution_count": 153,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# List the keys available under `len_stats_w['stats']`.\n",
    "stats_by_name = len_stats_w['stats']\n",
    "stats_by_name.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e8c3701-0740-4e27-aa18-cea8ebc130f2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
