{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9f3bad5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import atom3d.datasets as da\n",
    "#da.download_dataset('lba', 'atom3d')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8c6c0a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import argparse\n",
    "import datetime\n",
    "import json\n",
    "import os\n",
    "import time\n",
    "import tqdm\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.nn as nn\n",
    "from atom3d.datasets import LMDBDataset\n",
    "from scipy.stats import spearmanr\n",
    "random_seed=3\n",
    "np.random.seed(random_seed)\n",
    "torch.manual_seed(random_seed)\n",
    "\n",
    "class CNN3D_LBA(nn.Module):\n",
    "    '''3D convolutional regressor: one scalar prediction per input voxel grid.\n",
    "\n",
    "    All layers live in a single ``nn.Sequential`` (``self.model``), built in this\n",
    "    order: optional input BatchNorm3d; per conv stage ``Conv3d -> ReLU\n",
    "    [-> MaxPool3d] [-> BatchNorm3d] [-> Dropout]``; ``Flatten``; per hidden FC\n",
    "    stage ``Linear -> ReLU [-> BatchNorm1d] [-> Dropout]``; final ``Linear(., 1)``.\n",
    "\n",
    "    Args:\n",
    "        in_channels: number of channels of the input grid.\n",
    "        spatial_size: edge length of the cubic input grid.\n",
    "        conv_drop_rate: dropout probability after each conv stage.\n",
    "        fc_drop_rate: dropout probability after each hidden FC stage.\n",
    "        conv_filters: output channels of each conv layer, in order.\n",
    "        conv_kernel_size: kernel size shared by all conv layers (no padding).\n",
    "        max_pool_positions: per-conv flag; a truthy entry adds a max pool.\n",
    "        max_pool_sizes: per-conv pool kernel size (read only where pooling is on).\n",
    "        max_pool_strides: per-conv pool stride (read only where pooling is on).\n",
    "        fc_units: widths of the hidden fully connected layers.\n",
    "        batch_norm: add batch-norm layers when True.\n",
    "        dropout: add dropout layers when True.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, in_channels, spatial_size,\n",
    "                 conv_drop_rate, fc_drop_rate,\n",
    "                 conv_filters, conv_kernel_size,\n",
    "                 max_pool_positions, max_pool_sizes, max_pool_strides,\n",
    "                 fc_units,\n",
    "                 batch_norm=True,\n",
    "                 dropout=False):\n",
    "        super().__init__()\n",
    "\n",
    "        layers = []\n",
    "        if batch_norm:\n",
    "            layers.append(nn.BatchNorm3d(in_channels))\n",
    "\n",
    "        # Conv stages. spatial_size tracks the cube edge length so the first\n",
    "        # Linear layer can be sized without a dry-run forward pass.\n",
    "        for i, out_channels in enumerate(conv_filters):\n",
    "            layers.extend([\n",
    "                nn.Conv3d(in_channels, out_channels,\n",
    "                          kernel_size=conv_kernel_size,\n",
    "                          bias=True),\n",
    "                nn.ReLU()\n",
    "                ])\n",
    "            # An unpadded stride-1 convolution trims (kernel - 1) voxels per axis.\n",
    "            spatial_size -= (conv_kernel_size - 1)\n",
    "            if max_pool_positions[i]:\n",
    "                layers.append(nn.MaxPool3d(max_pool_sizes[i], max_pool_strides[i]))\n",
    "                # MaxPool3d output size for padding=0, dilation=1.\n",
    "                spatial_size = int(np.floor((spatial_size - (max_pool_sizes[i]-1) - 1)/max_pool_strides[i] + 1))\n",
    "            if batch_norm:\n",
    "                layers.append(nn.BatchNorm3d(out_channels))\n",
    "            if dropout:\n",
    "                layers.append(nn.Dropout(conv_drop_rate))\n",
    "            in_channels = out_channels\n",
    "\n",
    "        layers.append(nn.Flatten())\n",
    "        in_features = in_channels * (spatial_size**3)\n",
    "        # Hidden FC stages\n",
    "        for units in fc_units:\n",
    "            layers.extend([\n",
    "                nn.Linear(in_features, units),\n",
    "                nn.ReLU()\n",
    "                ])\n",
    "            if batch_norm:\n",
    "                # Activations are 2-D (batch, units) here, so the 1-D variant is\n",
    "                # required; BatchNorm3d only accepts 5-D input and would raise.\n",
    "                layers.append(nn.BatchNorm1d(units))\n",
    "            if dropout:\n",
    "                layers.append(nn.Dropout(fc_drop_rate))\n",
    "            in_features = units\n",
    "\n",
    "        # Final FC layer: a single regression output per sample\n",
    "        layers.append(nn.Linear(in_features, 1))\n",
    "\n",
    "        self.model = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        '''Return a 1-D tensor with one prediction per sample in the batch.'''\n",
    "        return self.model(x).view(-1)\n",
    "\n",
    "    \n",
    "from atom3d.datasets import LMDBDataset\n",
    "from atom3d.util.voxelize import dotdict, get_center, gen_rot_matrix, get_grid\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import dotenv as de\n",
    "de.load_dotenv(de.find_dotenv(usecwd=True))\n",
    "\n",
    "\n",
    "class CNN3D_TransformLBA(object):\n",
    "    '''Dataset transform turning a pocket/ligand item into a voxel-grid sample.\n",
    "\n",
    "    Extra keyword arguments override entries of the default grid configuration.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, random_seed=None, **kwargs):\n",
    "        self.random_seed = random_seed\n",
    "        defaults = {\n",
    "            # Channel index assigned to each element.\n",
    "            'element_mapping': {'H': 0, 'C': 1, 'O': 2, 'N': 3, 'F': 4},\n",
    "            # Grid radius, in angstroms.\n",
    "            'radius': 20.0,\n",
    "            # Size of one voxel, in angstroms.\n",
    "            'resolution': 1.0,\n",
    "            # Number of directions used for data augmentation.\n",
    "            'num_directions': 20,\n",
    "            # Number of rolls used for data augmentation.\n",
    "            'num_rolls': 20,\n",
    "        }\n",
    "        self.grid_config = dotdict(defaults)\n",
    "        # Caller-supplied settings take precedence over the defaults.\n",
    "        self.grid_config.update(kwargs)\n",
    "\n",
    "    def _voxelize(self, atoms_pocket, atoms_ligand):\n",
    "        '''Voxelize pocket + ligand around the ligand center; channel axis first.'''\n",
    "        # The grid is centered on the ligand.\n",
    "        center = get_center(atoms_ligand[['x', 'y', 'z']].astype(np.float32))\n",
    "        # Rotation matrix drawn from the grid config and this transform's seed.\n",
    "        rotation = gen_rot_matrix(self.grid_config, random_seed=self.random_seed)\n",
    "        complex_atoms = pd.concat([atoms_pocket, atoms_ligand])\n",
    "        voxels = get_grid(complex_atoms, center,\n",
    "                          config=self.grid_config, rot_mat=rotation)\n",
    "        # The atom channel comes out last; PyTorch wants it first.\n",
    "        return np.moveaxis(voxels, -1, 0)\n",
    "\n",
    "    def __call__(self, item):\n",
    "        '''Map a raw dataset item to a dict with keys feature, label and id.'''\n",
    "        return {\n",
    "            'feature': self._voxelize(item['atoms_pocket'], item['atoms_ligand']),\n",
    "            'label': item['scores']['neglog_aff'],\n",
    "            'id': item['id'],\n",
    "        }\n",
    "\n",
    "def conv_model(in_channels, spatial_size, args):\n",
    "    '''Build the CNN3D_LBA network described by `args`.\n",
    "\n",
    "    Reads args.num_conv, args.conv_drop_rate, args.fc_drop_rate,\n",
    "    args.batch_norm and args.no_dropout.\n",
    "    '''\n",
    "    depth = args.num_conv\n",
    "    # Filter count starts at 32 and doubles with every conv layer.\n",
    "    filters = [32 * 2 ** layer for layer in range(depth)]\n",
    "    # Alternating off/on pooling flags; every flag shares size 2 and stride 2.\n",
    "    pool_flags = [0, 1] * int((depth + 1) / 2)\n",
    "    pool_sizes = [2] * depth\n",
    "    pool_strides = [2] * depth\n",
    "\n",
    "    return CNN3D_LBA(\n",
    "        in_channels, spatial_size,\n",
    "        args.conv_drop_rate, args.fc_drop_rate,\n",
    "        filters, 3,\n",
    "        pool_flags, pool_sizes, pool_strides,\n",
    "        [512],\n",
    "        batch_norm=args.batch_norm,\n",
    "        dropout=not args.no_dropout,\n",
    "    )\n",
    "\n",
    "def train_loop(model, loader, optimizer, device):\n",
    "    '''Run one optimisation pass over `loader`; return the RMSE over all samples.'''\n",
    "    model.train()\n",
    "\n",
    "    per_sample = []\n",
    "    running_mse = 0\n",
    "    bar_text = 'train loss: {:6.6f}'\n",
    "    with tqdm.tqdm(total=len(loader), desc=bar_text.format(0)) as bar:\n",
    "        for step, batch in enumerate(loader, start=1):\n",
    "            inputs = batch['feature'].to(device).to(torch.float32)\n",
    "            targets = batch['label'].to(device).to(torch.float32)\n",
    "            # Clear stale gradients, then forward / backward / update.\n",
    "            optimizer.zero_grad()\n",
    "            sq_err = F.mse_loss(model(inputs), targets, reduction='none')\n",
    "            batch_mse = sq_err.mean()\n",
    "            batch_mse.backward()\n",
    "            optimizer.step()\n",
    "            # Incremental mean of the batch MSEs, displayed as an RMSE.\n",
    "            running_mse += (batch_mse.item() - running_mse) / float(step)\n",
    "            per_sample.extend(sq_err.tolist())\n",
    "            bar.set_description(bar_text.format(np.sqrt(running_mse)))\n",
    "            bar.update(1)\n",
    "\n",
    "    return np.sqrt(np.mean(per_sample))\n",
    "import pickle\n",
    "\n",
    "def test(model, loader, device):\n",
    "    '''Evaluate `model` on every batch of `loader` without tracking gradients.\n",
    "\n",
    "    Each batch must be a dict with keys 'feature', 'label' and 'id'.\n",
    "\n",
    "    Returns:\n",
    "        (rmse, r_p, r_s, results_df): RMSE over all samples, Pearson and\n",
    "        Spearman correlation between labels and predictions, and a DataFrame\n",
    "        with columns 'structure' (ids), 'true' and 'pred'.\n",
    "    '''\n",
    "    model.eval()\n",
    "\n",
    "    losses = []\n",
    "    ids = []\n",
    "    y_true = []\n",
    "    y_pred = []\n",
    "    with torch.no_grad():\n",
    "        for data in loader:\n",
    "            feature = data['feature'].to(device).to(torch.float32)\n",
    "            label = data['label'].to(device).to(torch.float32)\n",
    "            output = model(feature)\n",
    "            batch_losses = F.mse_loss(output, label, reduction='none')\n",
    "            losses.extend(batch_losses.tolist())\n",
    "            ids.extend(data['id'])\n",
    "            y_true.extend(label.tolist())\n",
    "            y_pred.extend(output.tolist())\n",
    "\n",
    "    # Build the frame column by column: stacking ids and floats into a single\n",
    "    # ndarray would cast everything to strings and lose the numeric dtypes.\n",
    "    results_df = pd.DataFrame({\n",
    "        'structure': ids,\n",
    "        'true': y_true,\n",
    "        'pred': y_pred,\n",
    "    })\n",
    "    r_p = np.corrcoef(y_true, y_pred)[0,1]\n",
    "    r_s = spearmanr(y_true, y_pred)[0]\n",
    "\n",
    "    return np.sqrt(np.mean(losses)), r_p, r_s, results_df\n",
    "def save_weights(model, weight_dir):\n",
    "    '''Serialise the model's state_dict to `weight_dir` with torch.save.'''\n",
    "    state = model.state_dict()\n",
    "    torch.save(state, weight_dir)\n",
    "def train(args, device, test_mode=False):\n",
    "    '''Train a CNN3D_LBA model and keep the weights of the best validation epoch.\n",
    "\n",
    "    Args:\n",
    "        args: config object; this function reads output_dir, data_dir,\n",
    "            random_seed, batch_size, learning_rate and num_epochs, and passes\n",
    "            `args` on to conv_model.\n",
    "        device: torch device the model and batches are moved to.\n",
    "        test_mode: when True, reload the best weights after training, evaluate\n",
    "            on the test split and write test_results.pkl / append a line to\n",
    "            test_results.txt inside args.output_dir.\n",
    "\n",
    "    Side effects: writes config.json and best_weights.pt to args.output_dir and\n",
    "    pickles the whole model to model_<epoch>.p in the working directory after\n",
    "    every epoch.\n",
    "\n",
    "    Returns:\n",
    "        (best_val_loss, best_rp, best_rs) taken from the best validation epoch.\n",
    "    '''\n",
    "    print(\"Training model with config:\")\n",
    "    print(str(json.dumps(args.__dict__, indent=4)) + \"\\n\")\n",
    "\n",
    "    # Make sure the output directory exists before anything is written to it.\n",
    "    if args.output_dir:\n",
    "        os.makedirs(args.output_dir, exist_ok=True)\n",
    "\n",
    "    # Save config\n",
    "    with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:\n",
    "        json.dump(args.__dict__, f, indent=4)\n",
    "\n",
    "    np.random.seed(args.random_seed)\n",
    "    torch.manual_seed(args.random_seed)\n",
    "\n",
    "    train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'),\n",
    "                                transform=CNN3D_TransformLBA(random_seed=args.random_seed))\n",
    "    val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'),\n",
    "                              transform=CNN3D_TransformLBA(random_seed=args.random_seed))\n",
    "    test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'),\n",
    "                               transform=CNN3D_TransformLBA(random_seed=args.random_seed))\n",
    "\n",
    "    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True)\n",
    "    val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False)\n",
    "    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False)\n",
    "\n",
    "    # Peek at one batch to size the network input.\n",
    "    for data in train_loader:\n",
    "        in_channels, spatial_size = data['feature'].size()[1:3]\n",
    "        print('num channels: {:}, spatial size: {:}'.format(in_channels, spatial_size))\n",
    "        break\n",
    "\n",
    "    model = conv_model(in_channels, spatial_size, args)\n",
    "    print(model)\n",
    "    model.to(device)\n",
    "\n",
    "    # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.\n",
    "    best_val_loss = np.inf\n",
    "    best_rp = 0\n",
    "    best_rs = 0\n",
    "\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)\n",
    "\n",
    "    for epoch in range(1, args.num_epochs+1):\n",
    "        start = time.time()\n",
    "        train_loss = train_loop(model, train_loader, optimizer, device)\n",
    "        val_loss, r_p, r_s, val_df = test(model, val_loader, device)\n",
    "        if val_loss < best_val_loss:\n",
    "            print(f\"\\nSave model at epoch {epoch:03d}, val_loss: {val_loss:.4f}\")\n",
    "            save_weights(model, os.path.join(args.output_dir, 'best_weights.pt'))\n",
    "            best_val_loss = val_loss\n",
    "            best_rp = r_p\n",
    "            best_rs = r_s\n",
    "        elapsed = (time.time() - start)\n",
    "        print('Epoch {:03d} finished in : {:.3f} s'.format(epoch, elapsed))\n",
    "        print('\\tTrain RMSE: {:.7f}, Val RMSE: {:.7f}, Pearson R: {:.7f}, Spearman R: {:.7f}'.format(\n",
    "            train_loss, val_loss, r_p, r_s))\n",
    "        # Per-epoch snapshot of the whole model; the context manager closes the\n",
    "        # file even if pickling fails.\n",
    "        with open(\"model_{}.p\".format(epoch), 'wb') as file:\n",
    "            pickle.dump(model, file)\n",
    "    if test_mode:\n",
    "        model.load_state_dict(torch.load(os.path.join(args.output_dir, 'best_weights.pt')))\n",
    "        rmse, pearson, spearman, test_df = test(model, test_loader, device)\n",
    "        test_df.to_pickle(os.path.join(args.output_dir, 'test_results.pkl'))\n",
    "        print('Test RMSE: {:.7f}, Pearson R: {:.7f}, Spearman R: {:.7f}'.format(\n",
    "            rmse, pearson, spearman))\n",
    "        test_file = os.path.join(args.output_dir, 'test_results.txt')\n",
    "        with open(test_file, 'a+') as out:\n",
    "            out.write('{}\\t{:.7f}\\t{:.7f}\\t{:.7f}\\n'.format(\n",
    "                args.random_seed, rmse, pearson, spearman))\n",
    "\n",
    "    return best_val_loss, best_rp, best_rs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d167ccb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class arguments:\n",
    "    '''Plain attribute container holding the run configuration.'''\n",
    "\n",
    "    def __init__(self):\n",
    "        # Locations and run mode.\n",
    "        self.data_dir = 'atom3d/split-by-sequence-identity-30/data'\n",
    "        self.mode = 'train'\n",
    "        self.output_dir = 'savedOutputs'\n",
    "        self.unobserved = False\n",
    "        # Optimisation and regularisation settings.\n",
    "        self.learning_rate = 0.001\n",
    "        self.conv_drop_rate = 0.1\n",
    "        self.fc_drop_rate = 0.25\n",
    "        self.num_epochs = 50\n",
    "        # Network shape switches.\n",
    "        self.num_conv = 4\n",
    "        self.batch_norm = False\n",
    "        self.no_dropout = False\n",
    "        # Batching and seeding.\n",
    "        self.batch_size = 16\n",
    "        self.random_seed = 3\n",
    "\n",
    "# Instantiate the run configuration defined above.\n",
    "args=arguments()\n",
    "# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "print(device)\n",
    "# Training is disabled here; the next cell loads a pickled model from disk instead.\n",
    "#train(args, device, args.mode=='train')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e2ae6a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the full-model snapshot written after epoch 50 by train().\n",
    "# NOTE(security): pickle.load can execute arbitrary code; only load files you produced yourself.\n",
    "# The context manager closes the handle instead of leaking it for the rest of the session.\n",
    "with open(\"model_{}.p\".format(50), 'rb') as file:\n",
    "    model = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f5f2bf8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(model.model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65936a2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# `arguments` is already defined by the configuration cell earlier in this\n",
    "# notebook (run that cell first); keeping a second, identical copy here only\n",
    "# invites the two definitions to drift apart.\n",
    "args = arguments()\n",
    "val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'),\n",
    "                          transform=CNN3D_TransformLBA(random_seed=args.random_seed))\n",
    "\n",
    "# One large batch: the first validation batch is the sample the pruning cells\n",
    "# below compute activations on.\n",
    "val_loader = DataLoader(val_dataset, 80, shuffle=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83bcea73",
   "metadata": {},
   "outputs": [],
   "source": [
    "data=(next(iter(val_loader)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a22ccac0",
   "metadata": {},
   "outputs": [],
   "source": [
    "x=data['feature']\n",
    "y=data['label']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c26cb033",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c398f58e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "k=.7\n",
    "import scipy\n",
    "def getID(k, Z, W=None, mode='kT'):\n",
    "    '''\n",
    "    Rank-k column interpolative decomposition (ID) of Z via pivoted QR.\n",
    "\n",
    "    Selects k columns of Z and an interpolation matrix T of shape\n",
    "    (k, Z.shape[1]) such that Z ~= Z[:, idx] @ T; the k selected columns are\n",
    "    reproduced exactly (T[:, idx] is the identity).\n",
    "    Computed with numpy/scipy because pytorch doesn't support QR pivots.\n",
    "\n",
    "    k: number of columns to keep, must be <= Z.shape[1]\n",
    "    Z: 2-D matrix, one column per candidate (layer output after nonlinearity)\n",
    "    W: optional matrix whose columns line up with Z's columns (needed for 'WT')\n",
    "    mode: 'kT' returns (indices of kept columns, T);\n",
    "          'WT' returns (W[:, kept columns], T)\n",
    "\n",
    "    Raises NotImplementedError for any other mode, or for 'WT' without W.\n",
    "    '''\n",
    "    # Import the submodule explicitly: a bare `import scipy` does not guarantee\n",
    "    # that scipy.linalg has been loaded.\n",
    "    import scipy.linalg\n",
    "\n",
    "    assert(k <= Z.shape[1])\n",
    "    if mode != 'kT' and not (mode == 'WT' and W is not None):\n",
    "        # Reject unsupported requests before paying for the factorisation.\n",
    "        raise NotImplementedError\n",
    "\n",
    "    # Z[:, P] = Q @ R with columns ordered by decreasing importance.\n",
    "    R, P = scipy.linalg.qr(Z, mode='r', pivoting=True)\n",
    "    kept = P[0:k]\n",
    "\n",
    "    # Pivoted order: the kept columns map to themselves (identity block) and\n",
    "    # the remaining ones are expressed through pinv(R11) @ R12.\n",
    "    T = np.concatenate((\n",
    "        np.identity(k),\n",
    "        np.linalg.pinv(R[0:k, 0:k]) @ R[0:k, k:None]\n",
    "        ), axis=1)\n",
    "    # Undo the pivoting so T's columns follow Z's original column order.\n",
    "    T = T[:, np.argsort(P)]\n",
    "\n",
    "    if mode == 'kT':\n",
    "        return kept, T\n",
    "    return W[:, kept], T\n",
    "\n",
    "model.load_state_dict(torch.load(os.path.join(args.output_dir, f'best_weights.pt')))\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy.linalg as ln \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43da0fa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manual, layer-by-layer pruning of the modules at Sequential indices 0, 3 and 7.\n",
    "# Each stage: collect that module's activations on batch `x`, pick fp = int(channels * k)\n",
    "# channels with getID, multiply the next weighted module's kernels by the interpolation\n",
    "# matrix T, then keep only the selected output filters of the current module.\n",
    "# Work on a copy so `model` itself stays untouched.\n",
    "pruned=copy.deepcopy(model)\n",
    "\n",
    "# --- Stage 1: prune model[0], compensate in model[3] ---\n",
    "with torch.no_grad():\n",
    "    # Activations after modules 0 and 1 (presumably Conv3d + ReLU - TODO confirm).\n",
    "    Z=pruned.model[0](x.cuda())\n",
    "    Z=pruned.model[1](Z)\n",
    "    # NOTE(review): `holder` is not read again in this cell.\n",
    "    holder=copy.deepcopy(Z)\n",
    "    fi = Z.shape[1]\n",
    "    # Move the channel axis last and flatten to a (positions, channels) matrix.\n",
    "    Zr = Z.permute(0,2,3,4,1).reshape(-1,fi)\n",
    "    # Number of channels to keep.\n",
    "    fp=int(Zr.shape[1]*k)\n",
    "    (k_idx,T) = getID(fp, Zr.cpu())\n",
    "    # Singular values are only used for the diagnostic plot below.\n",
    "    sv=ln.svd(Zr.cpu(), full_matrices=False, compute_uv=False)\n",
    "plt.figure()\n",
    "plt.semilogy(sv)\n",
    "T = torch.Tensor(T)\n",
    "with torch.no_grad():\n",
    "    Wnext=pruned.model[3].weight.clone()\n",
    "    # NOTE(review): `saved` is not read again in this cell.\n",
    "    saved=copy.deepcopy(Wnext)\n",
    "    # Put the input-channel axis last, contract it with T, then restore the axis order.\n",
    "    Wnext = Wnext.permute(0,2,3,4,1)\n",
    "    Wnext = torch.matmul(Wnext.cpu(), T.T)\n",
    "    Wnext = Wnext.permute(0,4,1,2,3)\n",
    "pruned.model[3].weight=nn.Parameter(Wnext, requires_grad=True)\n",
    "# Keep only the selected output filters (and their biases) of the pruned module.\n",
    "pruned.model[0].weight=nn.Parameter(pruned.model[0].weight[k_idx,:].clone(), requires_grad=True)\n",
    "pruned.model[0].bias=nn.Parameter(pruned.model[0].bias[k_idx].clone(), requires_grad=True)\n",
    "# Keep the modules' channel metadata in sync with the new weight shapes.\n",
    "pruned.model[0].out_channels=fp\n",
    "pruned.model[3].in_channels=fp\n",
    "\n",
    "# --- Stage 2: prune model[3], compensate in model[7] ---\n",
    "# The replaced weights were built on the CPU; move the whole model back to the GPU.\n",
    "pruned.cuda()\n",
    "# NOTE(review): these two calls run outside torch.no_grad(); module 2 is skipped\n",
    "# (presumably a Dropout layer - TODO confirm).\n",
    "Z=pruned.model[0](x.cuda())\n",
    "Z=pruned.model[1](Z)\n",
    "with torch.no_grad():\n",
    "    Z=pruned.model[3](Z)\n",
    "    Z=pruned.model[4](Z)\n",
    "    holder=copy.deepcopy(Z)\n",
    "    fi = Z.shape[1]\n",
    "    Zr = Z.permute(0,2,3,4,1).reshape(-1,fi)\n",
    "    fp=int(Zr.shape[1]*k)\n",
    "    (k_idx,T) = getID(fp, Zr.cpu())\n",
    "    sv=ln.svd(Zr.cpu(), full_matrices=False, compute_uv=False)\n",
    "plt.figure()\n",
    "# Dump this stage's singular values to a text file, one value per line.\n",
    "file=open(\"secondLayerSingularValues.txt\", 'w')\n",
    "# NOTE(review): the size in this header is a hard-coded literal, not taken from Zr.shape.\n",
    "string='matrix size: [2026120, 64]\\n'\n",
    "for value in sv:\n",
    "    string+=str(value)+'\\n'\n",
    "file.write(string)\n",
    "file.close()\n",
    "#np.savetxt(\"secondLayerMat.csv\", Zr.cpu().numpy(), delimiter=\",\")\n",
    "plt.semilogy(sv)\n",
    "T = torch.Tensor(T)\n",
    "with torch.no_grad():\n",
    "    Wnext=pruned.model[7].weight.clone()\n",
    "    saved=copy.deepcopy(Wnext)\n",
    "    Wnext = Wnext.permute(0,2,3,4,1)\n",
    "    Wnext = torch.matmul(Wnext.cpu(), T.T)\n",
    "    Wnext = Wnext.permute(0,4,1,2,3)\n",
    "pruned.model[7].weight=nn.Parameter(Wnext, requires_grad=True)\n",
    "pruned.model[3].weight=nn.Parameter(pruned.model[3].weight[k_idx,:].clone(), requires_grad=True)\n",
    "pruned.model[3].bias=nn.Parameter(pruned.model[3].bias[k_idx].clone(), requires_grad=True)\n",
    "pruned.model[3].out_channels=fp\n",
    "pruned.model[7].in_channels=fp\n",
    "\n",
    "\n",
    "# --- Stage 3: prune model[7], compensate in model[10] ---\n",
    "pruned.cuda()\n",
    "# Replay the already-pruned prefix; module 5 is applied (presumably the MaxPool3d -\n",
    "# TODO confirm) while modules 2 and 6 are skipped.\n",
    "Z=pruned.model[0](x.cuda())\n",
    "Z=pruned.model[1](Z)\n",
    "Z=pruned.model[3](Z)\n",
    "Z=pruned.model[4](Z)\n",
    "Z=pruned.model[5](Z)\n",
    "with torch.no_grad():\n",
    "    Z=pruned.model[7](Z)\n",
    "    Z=pruned.model[8](Z)\n",
    "    holder=copy.deepcopy(Z)\n",
    "    fi = Z.shape[1]\n",
    "    Zr = Z.permute(0,2,3,4,1).reshape(-1,fi)\n",
    "    fp=int(Zr.shape[1]*k)\n",
    "    (k_idx,T) = getID(fp, Zr.cpu())\n",
    "    sv=ln.svd(Zr.cpu(), full_matrices=False, compute_uv=False)\n",
    "plt.figure()\n",
    "plt.semilogy(sv)\n",
    "T = torch.Tensor(T)\n",
    "with torch.no_grad():\n",
    "    Wnext=pruned.model[10].weight.clone()\n",
    "    saved=copy.deepcopy(Wnext)\n",
    "    Wnext = Wnext.permute(0,2,3,4,1)\n",
    "    Wnext = torch.matmul(Wnext.cpu(), T.T)\n",
    "    Wnext = Wnext.permute(0,4,1,2,3)\n",
    "pruned.model[10].weight=nn.Parameter(Wnext, requires_grad=True)\n",
    "pruned.model[7].weight=nn.Parameter(pruned.model[7].weight[k_idx,:].clone(), requires_grad=True)\n",
    "pruned.model[7].bias=nn.Parameter(pruned.model[7].bias[k_idx].clone(), requires_grad=True)\n",
    "pruned.model[7].out_channels=fp\n",
    "pruned.model[10].in_channels=fp\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81472b46",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "summary_input = (5,41,41, 41)\n",
    "\n",
    "\n",
    "def prune(model, x, k=.9):\n",
    "    '''Prune one conv stage of a copy of `model` and evaluate the result.\n",
    "\n",
    "    The modules at Sequential indices 0, 3, 7 and 10 are each scored on the\n",
    "    activations of batch `x` as sv[fp]/sv[0] divided by entries of the list\n",
    "    returned by print_model_param_flops(...)[1], where sv are the singular\n",
    "    values of the (positions x channels) activation matrix and\n",
    "    fp = int(channels * k). The stage with the lowest score is reduced to fp\n",
    "    output channels via getID, and the weights of the following weighted module\n",
    "    are multiplied by the interpolation matrix.\n",
    "\n",
    "    Relies on the notebook globals `test_loader`, `device`, `getID`, `ln` and\n",
    "    `print_model_param_flops`, and on a CUDA device (`x.cuda()`).\n",
    "\n",
    "    Returns (pruned model, [rmse, pearson, spearman] measured on test_loader,\n",
    "    print_model_param_flops(pruned, input_res=41)[0]).\n",
    "    '''\n",
    "    # Never modify the caller's model.\n",
    "    pruned=copy.deepcopy(model)\n",
    "    scores=[]\n",
    "    \n",
    "    flops=print_model_param_flops(pruned,input_res=41 )[1]\n",
    "    print(flops)\n",
    "    # NOTE(review): if no score ever drops below this initial value, k_idx, T,\n",
    "    # currentLayer, nextLayer and f stay unassigned and the code after the\n",
    "    # scoring block raises NameError.\n",
    "    minScore=1\n",
    "    #calculate scores\n",
    "    with torch.no_grad():\n",
    "        # Stage at index 0; module 1 is applied right after it.\n",
    "        Z=pruned.model[0](x.cuda())\n",
    "        Z=pruned.model[1](Z)\n",
    "        fi = Z.shape[1]\n",
    "        # Channel axis last, flattened to (positions, channels).\n",
    "        Zr = Z.permute(0,2,3,4,1).reshape(-1,fi)\n",
    "        fp=int(Zr.shape[1]*k)\n",
    "        sv=ln.svd(Zr.cpu(), full_matrices=False, compute_uv=False)\n",
    "        # NOTE(review): sv[fp] is out of range when fp equals the channel count (k = 1).\n",
    "        scores.append(sv[fp]/sv[0]/(flops[0]+flops[1]))\n",
    "        if scores[-1]<minScore:\n",
    "            minScore=scores[-1]\n",
    "            (k_idx,T) = getID(fp, Zr.cpu())\n",
    "            T = torch.Tensor(T)\n",
    "            currentLayer=0\n",
    "            nextLayer=3\n",
    "            f=fp\n",
    "            \n",
    "            \n",
    "        # Stage at index 3 (module 2 is skipped - presumably Dropout, TODO confirm).\n",
    "        Z=pruned.model[3](Z)\n",
    "        Z=pruned.model[4](Z)\n",
    "        fi = Z.shape[1]\n",
    "        Zr = Z.permute(0,2,3,4,1).reshape(-1,fi)\n",
    "        fp=int(Zr.shape[1]*k)\n",
    "        sv=ln.svd(Zr.cpu(), full_matrices=False, compute_uv=False)\n",
    "        scores.append(sv[fp]/sv[0]/(flops[1]+flops[2]))\n",
    "        if scores[-1]<minScore:\n",
    "            minScore=scores[-1]\n",
    "            (k_idx,T) = getID(fp, Zr.cpu())\n",
    "            T = torch.Tensor(T)\n",
    "            currentLayer=3\n",
    "            nextLayer=7\n",
    "            f=fp\n",
    "        \n",
    "        # Stage at index 7: module 5 is applied first, module 6 is skipped.\n",
    "        Z=pruned.model[5](Z)\n",
    "        Z=pruned.model[7](Z)\n",
    "        Z=pruned.model[8](Z)\n",
    "        fi = Z.shape[1]\n",
    "        Zr = Z.permute(0,2,3,4,1).reshape(-1,fi)\n",
    "        fp=int(Zr.shape[1]*k)\n",
    "        sv=ln.svd(Zr.cpu(), full_matrices=False, compute_uv=False)\n",
    "        scores.append(sv[fp]/sv[0]/(flops[2]+flops[3])) \n",
    "        if scores[-1]<minScore:\n",
    "            minScore=scores[-1]\n",
    "            (k_idx,T) = getID(fp, Zr.cpu())\n",
    "            T = torch.Tensor(T)\n",
    "            currentLayer=7\n",
    "            nextLayer=10\n",
    "            f=fp\n",
    "\n",
    "        # Stage at index 10 (module 9 is skipped); its successor is model[15].\n",
    "        Z=pruned.model[10](Z)\n",
    "        Z=pruned.model[11](Z)\n",
    "        fi = Z.shape[1]\n",
    "        Zr = Z.permute(0,2,3,4,1).reshape(-1,fi)\n",
    "        fp=int(Zr.shape[1]*k)\n",
    "        sv=ln.svd(Zr.cpu(), full_matrices=False, compute_uv=False)\n",
    "        scores.append(sv[fp]/sv[0]/(flops[3])) \n",
    "        if scores[-1]<minScore:\n",
    "            minScore=scores[-1]\n",
    "            (k_idx,T) = getID(fp, Zr.cpu())\n",
    "            T = torch.Tensor(T)\n",
    "            currentLayer=10\n",
    "            nextLayer=15\n",
    "            f=fp\n",
    "    print(T.shape)\n",
    "    \n",
    "    #prune layer\n",
    "    if currentLayer!=10:\n",
    "        # The successor has a 5-D weight: put its input-channel axis last,\n",
    "        # contract it with T, then restore the axis order.\n",
    "        with torch.no_grad():\n",
    "            Wnext=pruned.model[nextLayer].weight.clone()\n",
    "            Wnext = Wnext.permute(0,2,3,4,1)\n",
    "            Wnext = torch.matmul(Wnext.cpu(), T.T)\n",
    "            Wnext = Wnext.permute(0,4,1,2,3)\n",
    "        pruned.model[nextLayer].weight=nn.Parameter(Wnext, requires_grad=True)\n",
    "        # Keep only the selected output filters and biases of the pruned module.\n",
    "        pruned.model[currentLayer].weight=nn.Parameter(pruned.model[currentLayer].weight[k_idx,:].clone(), requires_grad=True)\n",
    "        pruned.model[currentLayer].bias=nn.Parameter(pruned.model[currentLayer].bias[k_idx].clone(), requires_grad=True)\n",
    "        pruned.model[currentLayer].out_channels=f\n",
    "        pruned.model[nextLayer].in_channels=f\n",
    "    else:\n",
    "        # The successor exposes in_features (a Linear layer): n is the number of\n",
    "        # its input features per output channel of the pruned module, so T is\n",
    "        # expanded with a Kronecker product to act on the flattened features.\n",
    "        n = int(pruned.model[nextLayer].in_features / pruned.model[currentLayer].out_channels)\n",
    "        T = torch.kron(T.contiguous(), torch.eye(n))\n",
    "        print(T.shape)\n",
    "        pruned.model[nextLayer].weight = nn.Parameter(pruned.model[nextLayer].weight.cpu() @ T.T,\n",
    "                    requires_grad=True)\n",
    "        pruned.model[currentLayer].out_channels = f\n",
    "        pruned.model[nextLayer].in_features = f * n\n",
    "        pruned.model[currentLayer].weight=nn.Parameter(pruned.model[currentLayer].weight[k_idx,:].clone(), requires_grad=True)\n",
    "        pruned.model[currentLayer].bias=nn.Parameter(pruned.model[currentLayer].bias[k_idx].clone(), requires_grad=True)\n",
    "        # NOTE(review): out_channels was already set to f a few lines above.\n",
    "        pruned.model[currentLayer].out_channels=f\n",
    "\n",
    "    #test\n",
    "    \n",
    "    # `test_loader` and `device` are notebook globals, not parameters.\n",
    "    rmse, pearson, spearman, test_df=test(pruned.cuda(), test_loader, device)\n",
    "    return pruned,[rmse, pearson, spearman] , print_model_param_flops(pruned,input_res=41 )[0]\n",
    "losses=[]\n",
    "flops=[]\n",
    "pruned=model\n",
    "for i in range(0, 100):\n",
    "    pruned, t, fs=prune(pruned, x)\n",
    "    losses.append(t)\n",
    "    flops.append(fs)\n",
    "    print(t)\n",
    "    print(fs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66803be8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#file=open(\"atom3dDump.p\", 'wb')\n",
    "import pickle\n",
    "#pickle.dump([losses, flops], file)\n",
    "file.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9869239a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "rmse, pearson, spearman, test_df=test(model.cuda(), test_loader, device)\n",
    "print(rmse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "193ab997",
   "metadata": {},
   "outputs": [],
   "source": [
    "ls=np.array(losses)\n",
    "fs=np.array(flops)/print_model_param_flops(model,input_res=41 )[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "197b4df2",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(fs, ls[:, 0])\n",
    "plt.xlabel(\"flops\")\n",
    "plt.ylabel(\"loss\")\n",
    "plt.axhline(y=1.43, color='r', linestyle='-')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1f12fb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `test_loader` is first created in this cell, yet the `prune` cell and\n",
    "# the baseline-RMSE cell earlier in the notebook already use it; on a fresh kernel\n",
    "# this cell has to run before them.\n",
    "test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'),\n",
    "                              transform=CNN3D_TransformLBA(random_seed=args.random_seed))\n",
    "\n",
    "test_loader = DataLoader(test_dataset, 16, shuffle=False)\n",
    "\n",
    "# NOTE(review): this switches `model` to eval mode, but the call below evaluates `pruned`.\n",
    "model.eval()\n",
    "test(pruned.cuda(), test_loader, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1b4a89d",
   "metadata": {},
   "outputs": [],
   "source": [
    "rmse, pearson, spearman, test_df = test(pruned.cuda(), test_loader, device)\n",
    "print('Test RMSE: {:.7f}, Pearson R: {:.7f}, Spearman R: {:.7f}'.format(\n",
    "            rmse, pearson, spearman\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9305eb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cf8b24b",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary_input = (5,41,41, 41)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e492e83a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from train import model_summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "858551e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(model_summary(model,summary_input= summary_input, input_res=41))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0a2b63c",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(model_summary(pruned,summary_input= summary_input, input_res=41))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c7453c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.autograd import Variable\n",
    "def print_model_param_flops(model=None, input_res=224, multiply_adds=True):\n",
    "\n",
    "    prods = {}\n",
    "    def save_hook(name):\n",
    "        def hook_per(self, input, output):\n",
    "            prods[name] = np.prod(input[0].shape)\n",
    "        return hook_per\n",
    "\n",
    "    list_1=[]\n",
    "    def simple_hook(self, input, output):\n",
    "        list_1.append(np.prod(input[0].shape))\n",
    "    list_2={}\n",
    "    def simple_hook2(self, input, output):\n",
    "        list_2['names'] = np.prod(input[0].shape)\n",
    "\n",
    "    list_conv=[]\n",
    "    def conv_hook(self, input, output):\n",
    "        batch_size, input_channels, input_height, input_width, input_depth = input[0].size()\n",
    "        output_channels, output_height, output_width, output_depth = output[0].size()\n",
    "\n",
    "        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]* (self.in_channels / self.groups)\n",
    "        bias_ops = 1 if self.bias is not None else 0\n",
    "        \n",
    "        params = output_channels * (kernel_ops + bias_ops)\n",
    "        flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width*output_depth * batch_size\n",
    "\n",
    "        list_conv.append(flops)\n",
    "        print(flops)\n",
    "    list_linear=[]\n",
    "    def linear_hook(self, input, output):\n",
    "        batch_size = input[0].size(0) if input[0].dim() == 2 else 1\n",
    "\n",
    "        weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)\n",
    "        bias_ops = self.bias.nelement()\n",
    "\n",
    "        flops = batch_size * (weight_ops + bias_ops)\n",
    "        list_linear.append(flops)\n",
    "\n",
    "    list_bn=[]\n",
    "    def bn_hook(self, input, output):\n",
    "        list_bn.append(input[0].nelement() * 2)\n",
    "\n",
    "    list_relu=[]\n",
    "    def relu_hook(self, input, output):\n",
    "        list_relu.append(input[0].nelement())\n",
    "\n",
    "    list_pooling=[]\n",
    "    def pooling_hook(self, input, output):\n",
    "        \"\"\"Forward hook that appends the FLOPs of one pooling call to ``list_pooling``.\n",
    "\n",
    "        Works for any number of spatial dimensions; ``kernel_size`` may be an\n",
    "        int or a per-dimension tuple.\n",
    "        \"\"\"\n",
    "        batch_size = input[0].size(0)\n",
    "        sample_out = output[0]  # one sample of the output: (channels, *spatial)\n",
    "        spatial_rank = sample_out.dim() - 1\n",
    "\n",
    "        if isinstance(self.kernel_size, int):\n",
    "            # A scalar kernel size applies to every spatial dimension.\n",
    "            kernel_ops = self.kernel_size ** spatial_rank\n",
    "        else:\n",
    "            kernel_ops = 1\n",
    "            for extent in self.kernel_size:\n",
    "                kernel_ops *= extent\n",
    "\n",
    "        # One op per window element for every output element.\n",
    "        flops = kernel_ops * sample_out.nelement() * batch_size\n",
    "        list_pooling.append(flops)\n",
    "\n",
    "    list_upsample=[]\n",
    "    # For bilinear upsample\n",
    "    def upsample_hook(self, input, output):\n",
    "        # The unpacking below only accepts a 4-D input and a 3-D output sample.\n",
    "        batch_size, input_channels, input_height, input_width = input[0].size()\n",
    "        output_channels, output_height, output_width = output[0].size()\n",
    "\n",
    "        flops = output_height * output_width * output_channels * batch_size * 12\n",
    "        list_upsample.append(flops)\n",
    "\n",
    "    def foo(net):\n",
    "        \"\"\"Recursively register the matching FLOP hook on every leaf module of ``net``.\"\"\"\n",
    "        childrens = list(net.children())\n",
    "        if childrens:\n",
    "            # Container module: descend, hooks are attached to leaves only.\n",
    "            for c in childrens:\n",
    "                foo(c)\n",
    "            return\n",
    "\n",
    "        if isinstance(net, torch.nn.Conv3d):\n",
    "            net.register_forward_hook(conv_hook)\n",
    "        if isinstance(net, torch.nn.Linear):\n",
    "            net.register_forward_hook(linear_hook)\n",
    "        # The profiled network is volumetric (Conv3d, 5-D input), so the 3-D\n",
    "        # norm and pooling layers must be matched as well; matching only the\n",
    "        # 2-D classes leaves their FLOPs out of the total.\n",
    "        if isinstance(net, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):\n",
    "            net.register_forward_hook(bn_hook)\n",
    "        if isinstance(net, torch.nn.ReLU):\n",
    "            net.register_forward_hook(relu_hook)\n",
    "        if isinstance(net, (torch.nn.MaxPool2d, torch.nn.AvgPool2d,\n",
    "                            torch.nn.MaxPool3d, torch.nn.AvgPool3d)):\n",
    "            net.register_forward_hook(pooling_hook)\n",
    "        if isinstance(net, torch.nn.Upsample):\n",
    "            net.register_forward_hook(upsample_hook)\n",
    "\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    m = copy.deepcopy(model)\n",
    "    foo(m)\n",
    "    input = Variable(torch.rand(3, 5, input_res, input_res, input_res), requires_grad = True)\n",
    "    input = input.to(device)\n",
    "    out = m(input)\n",
    "\n",
    "    print(list_conv, list_linear, list_bn, list_pooling, list_relu)\n",
    "    total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample))\n",
    "\n",
    "    print('  + Number of FLOPs: %.5fG' % (total_flops / 3 / 1e9))\n",
    "\n",
    "    return total_flops / 3, list_conv, list_linear\n",
    "\n",
    "# Report the FLOP count of `model` for a 41x41x41 input grid.\n",
    "print_model_param_flops(model,input_res=41 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38c739f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the FLOP count of `pruned` for a 41x41x41 input grid.\n",
    "print_model_param_flops(pruned,input_res=41 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "134407f1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1c2cb86",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch.nn.utils.prune as prune\n",
    "def conv_model(in_channels, spatial_size, args, k=1.0):\n",
    "    \"\"\"Build a CNN3D_LBA whose convolution widths are scaled by ``k``.\n",
    "\n",
    "    Layer ``n`` gets ``int(32*k) * 2**n`` filters. Depth, dropout rates and\n",
    "    the batch-norm / dropout switches are read from ``args``.\n",
    "    \"\"\"\n",
    "    depth = args.num_conv\n",
    "    base_width = int(32*k)\n",
    "    widths = [base_width * (2**n) for n in range(depth)]\n",
    "    # Repeating [0, 1] puts a max-pool after every second convolution.\n",
    "    pool_flags = [0, 1] * ((depth + 1) // 2)\n",
    "\n",
    "    return CNN3D_LBA(\n",
    "        in_channels=in_channels,\n",
    "        spatial_size=spatial_size,\n",
    "        conv_drop_rate=args.conv_drop_rate,\n",
    "        fc_drop_rate=args.fc_drop_rate,\n",
    "        conv_filters=widths,\n",
    "        conv_kernel_size=3,\n",
    "        max_pool_positions=pool_flags,\n",
    "        max_pool_sizes=[2] * depth,\n",
    "        max_pool_strides=[2] * depth,\n",
    "        fc_units=[512],\n",
    "        batch_norm=args.batch_norm,\n",
    "        dropout=not args.no_dropout,\n",
    "    )\n",
    "# Fractions of channels to remove in the sweep below.\n",
    "amounts=[.05,.075,.1,.13, .15,.175, .2,.25, .3, .4,.5]\n",
    "#rmses=[]\n",
    "flops=[]\n",
    "\n",
    "# Read the channel count and grid size from the first test batch.\n",
    "data = next(iter(test_loader))\n",
    "in_channels, spatial_size = data['feature'].size()[1:3]\n",
    "print('num channels: {:}, spatial size: {:}'.format(in_channels, spatial_size))\n",
    "\n",
    "# print_model_param_flops moves its random input to CUDA only when CUDA is\n",
    "# available, so place the model the same way instead of calling .cuda()\n",
    "# unconditionally (which fails on CPU-only machines).\n",
    "flops_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "for amount in amounts:\n",
    "    # Stand-in network with every conv width scaled by (1 - amount); only its\n",
    "    # FLOP count is kept.\n",
    "    dummy = conv_model(in_channels, spatial_size, args, k=1-amount)\n",
    "    flops.append(print_model_param_flops(dummy.to(flops_device), input_res=41)[0])\n",
    "\n",
    "    # Structured L2 pruning of a copy of `model`: mask the fraction `amount` of\n",
    "    # slices along dim=1 of every Conv3d weight.\n",
    "    new = copy.deepcopy(model)\n",
    "    for name, module in new.named_modules():\n",
    "        if isinstance(module, torch.nn.Conv3d):\n",
    "            prune.ln_structured(module, name='weight', amount=amount, dim=1, n=2)\n",
    "    #rmses.append(test(new.cuda(), test_loader, device)[0])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd2e4319",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the sweep results as a pickled [flops, rmses] pair.\n",
    "# NOTE(review): the lines that build `rmses` in the sweep cell are commented\n",
    "# out - confirm it is defined before running this cell on a fresh kernel.\n",
    "# The with-block closes the file even if pickle.dump raises.\n",
    "with open(\"atom3dmag.p\", 'wb') as file:\n",
    "    pickle.dump([flops, rmses], file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b97e7a95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the FLOP count of `new` for a 41x41x41 input grid.\n",
    "# NOTE(review): torch.nn.utils.prune masks weights without changing tensor\n",
    "# shapes, so this presumably equals the unpruned count - confirm.\n",
    "print_model_param_flops(new,input_res=41 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "290dfc4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the FLOP counts collected in `flops`.\n",
    "print(flops)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3ad721b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
