{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from Utils import load\n",
    "from Utils import generator\n",
    "from Utils import metrics\n",
    "from train import *\n",
    "from prune import *\n",
    "from Layers import layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn import functional as F\n",
    "import torch.nn as nn\n",
    "def fc(input_shape, nonlinearity=nn.ReLU(), hidden_sizes=(5000, 900, 400, 100, 30), output_size=1):\n",
    "    \"\"\"Build a fully connected regression MLP from the project's `layers.Linear` modules.\n",
    "\n",
    "    Args:\n",
    "        input_shape: shape of one input sample (int or tuple). Inputs are\n",
    "            flattened, so the first layer has prod(input_shape) features.\n",
    "        nonlinearity: activation module placed after every hidden layer. The\n",
    "            same instance is reused at every position (and the default ReLU\n",
    "            is created once, at definition time), so it should be stateless.\n",
    "        hidden_sizes: widths of the hidden layers, in order. The default\n",
    "            reproduces the original 5000-900-400-100-30 architecture.\n",
    "        output_size: width of the final layer (no activation after it).\n",
    "\n",
    "    Returns:\n",
    "        torch.nn.Sequential laid out as\n",
    "        [Flatten, Linear, act, Linear, act, ..., Linear], so the Linear layers\n",
    "        sit at the odd indices 1, 3, 5, ...\n",
    "    \"\"\"\n",
    "    # Local import: the experiment cell rebinds the notebook-global `nn` to\n",
    "    # mxnet.gluon.nn, which would break a call-time lookup of `nn` here.\n",
    "    import torch.nn as torch_nn\n",
    "\n",
    "    # int() so the layer constructors receive a plain Python int, not a numpy scalar.\n",
    "    in_features = int(np.prod(input_shape))\n",
    "\n",
    "    modules = [torch_nn.Flatten()]\n",
    "    for out_features in hidden_sizes:\n",
    "        modules.append(layers.Linear(in_features, out_features))\n",
    "        modules.append(nonlinearity)\n",
    "        in_features = out_features\n",
    "    # Regression head: linear output, no activation.\n",
    "    modules.append(layers.Linear(in_features, output_size))\n",
    "\n",
    "    return torch_nn.Sequential(*modules)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from data import *\n",
    "from models import *\n",
    "from utils import *\n",
    "from sklearn.model_selection import KFold\n",
    "import os, shutil, pickle\n",
    "ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu(0)\n",
    "\n",
    "# ---- Experiment configuration ----\n",
    "# NOTE(review): mx, nd, ArrayDataset, DataLoader, get_data, build_FC and the\n",
    "# train/prune helpers all arrive through star imports - confirm their sources.\n",
    "N_DATASETS = 10        # datasets '0.csv' ... '9.csv'\n",
    "N_FOLDS = 5            # the results cells assume exactly this many folds per dataset\n",
    "PRUNE_EPOCHS = 10\n",
    "POST_TRAIN_EPOCHS = 100\n",
    "N_COPIED_LAYERS = 6    # Linear layers copied from the pretrained MXNet FCnet\n",
    "# 10**-2.44715803134 == 1/280, i.e. keep about 1/280 of the weights (280X).\n",
    "# The 100X run used 10**(-float(2)).\n",
    "SPARSITY = 10**(-float(2.44715803134))\n",
    "device = torch.device('cpu')\n",
    "\n",
    "\n",
    "def _nrmse(model, loss_fn, X, y, factor):\n",
    "    \"\"\"Return sqrt(loss_fn(model(X), y)) / factor as a plain Python float.\n",
    "\n",
    "    Runs under torch.no_grad() so the stored metric does not keep an autograd\n",
    "    graph (and the model it references) alive across folds.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        return (torch.sqrt(loss_fn(model(X), y)) / factor).item()\n",
    "\n",
    "\n",
    "def _copy_fcnet_weights(model, fcnet, n_layers):\n",
    "    \"\"\"Copy weight/bias of fcnet[0..n_layers-1] into the torch `model`.\n",
    "\n",
    "    fc() places its Linear layers at odd indices (1, 3, 5, ...), after Flatten.\n",
    "    \"\"\"\n",
    "    for k in range(n_layers):\n",
    "        linear = model[2 * k + 1]\n",
    "        linear.weight.data = torch.Tensor(fcnet[k].weight.data().asnumpy())\n",
    "        linear.bias.data = torch.Tensor(fcnet[k].bias.data().asnumpy())\n",
    "\n",
    "\n",
    "def _count_unpruned(pruner):\n",
    "    \"\"\"Sum of all mask entries, i.e. the surviving weights when masks are 0/1.\"\"\"\n",
    "    return int(sum(mask.sum().item() for mask, _ in pruner.masked_parameters))\n",
    "\n",
    "\n",
    "loss_before_prune = []\n",
    "loss_after_prune = []\n",
    "loss_prune_posttrain = []\n",
    "NUM_PARA = []\n",
    "for datasetindex in range(N_DATASETS):\n",
    "    X, y = get_data(str(datasetindex) + '.csv')\n",
    "\n",
    "    np.random.seed(0)\n",
    "    kf = KFold(n_splits=N_FOLDS, random_state=0, shuffle=True)\n",
    "    seed = 0  # fold counter; also selects the pretrained checkpoint for this fold\n",
    "    for train_index, test_index in kf.split(X):\n",
    "        X_tr, X_te = X[train_index], X[test_index]\n",
    "        y_tr, y_te = y[train_index], y[test_index]\n",
    "        factor = np.max(y_te) - np.min(y_te)  # normalise RMSE by the test-target range\n",
    "        print(factor)\n",
    "\n",
    "        N = X_tr.shape[0]\n",
    "        p = X_tr.shape[1]\n",
    "        batch_size = 50 if N < 250 else 500\n",
    "        X_train = nd.array(X_tr).as_in_context(ctx)\n",
    "        y_train = nd.array(y_tr).as_in_context(ctx)\n",
    "        train_data = DataLoader(ArrayDataset(X_train, y_train), batch_size=batch_size, shuffle=True)\n",
    "\n",
    "        print('start training FC')\n",
    "        # Initialise the over-parametrised MXNet network, then load the\n",
    "        # pretrained weights selected for this dataset/fold.\n",
    "        FCnet = build_FC(train_data, ctx)\n",
    "        FCnet.load_parameters('Selected_models/FCnet_' + str(datasetindex) + '_seed_' + str(seed), ctx=ctx)\n",
    "\n",
    "        # The end of each fold rebinds `nn` to mxnet.gluon.nn; switch back to torch.\n",
    "        import torch.nn as nn\n",
    "        model = fc(p)  # input width follows the data instead of a hard-coded 10\n",
    "        loss = nn.MSELoss()\n",
    "        _copy_fcnet_weights(model, FCnet, N_COPIED_LAYERS)\n",
    "\n",
    "        # reshape(-1, 1) matches the model's (n, 1) output; a flat (n,) target\n",
    "        # would silently broadcast to (n, n) inside MSELoss.\n",
    "        X_te_t = torch.Tensor(X_te)\n",
    "        y_te_t = torch.Tensor(y_te).reshape(-1, 1)\n",
    "        dataset = torch.utils.data.TensorDataset(torch.Tensor(X_tr), torch.Tensor(y_tr).reshape(-1, 1))\n",
    "\n",
    "        print('dataset:', datasetindex, 'seed', seed)\n",
    "        before = _nrmse(model, loss, X_te_t, y_te_t, factor)\n",
    "        print('before prune:', before)\n",
    "        loss_before_prune.append(before)\n",
    "\n",
    "        # ---- Prune (SynFlow) ----\n",
    "        prune_loader = load.dataloader(dataset, 64, True, 4, 1)\n",
    "        print('Pruning with {} for {} epochs.'.format('synflow', PRUNE_EPOCHS))\n",
    "        pruner = load.pruner('synflow')(generator.masked_parameters(model, False, False, False))\n",
    "        prune_loop(model, loss, pruner, prune_loader, device, SPARSITY,\n",
    "                   'exponential', 'global', PRUNE_EPOCHS, False, False, False, False)\n",
    "        pruner.apply_mask()\n",
    "\n",
    "        after = _nrmse(model, loss, X_te_t, y_te_t, factor)\n",
    "        print('after prune:', after)\n",
    "        loss_after_prune.append(after)\n",
    "\n",
    "        # ---- Post-train the pruned network ----\n",
    "        train_loader = load.dataloader(dataset, 64, True, 4)\n",
    "        # NOTE(review): test_loader is built from the *training* split, so the\n",
    "        # eval numbers inside train_eval_loop are not held-out; the metric\n",
    "        # recorded below is computed on (X_te, y_te) directly.\n",
    "        test_loader = load.dataloader(dataset, 200, False, 4)\n",
    "        optimizer = torch.optim.Adam(generator.parameters(model), betas=(0.9, 0.99))\n",
    "        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)\n",
    "        post_result = train_eval_loop(model, loss, optimizer, scheduler, train_loader,\n",
    "                                      test_loader, device, POST_TRAIN_EPOCHS, True)\n",
    "\n",
    "        posttrain = _nrmse(model, loss, X_te_t, y_te_t, factor)\n",
    "        print('after post_train:', posttrain)\n",
    "        loss_prune_posttrain.append(posttrain)\n",
    "\n",
    "        num = _count_unpruned(pruner)\n",
    "        print(num)\n",
    "        NUM_PARA.append(num)\n",
    "        seed = seed + 1\n",
    "\n",
    "        # Rebind `nn` to MXNet as the original cell did after every fold\n",
    "        # (presumably for code expecting the gluon `nn`; TODO confirm).\n",
    "        import mxnet.gluon.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0254, grad_fn=<DivBackward0>)\n",
      "tensor(0.0483, grad_fn=<DivBackward0>)\n",
      "tensor(0.0339, grad_fn=<DivBackward0>)\n",
      "tensor(0.0349, grad_fn=<DivBackward0>)\n",
      "tensor(0.0394, grad_fn=<DivBackward0>)\n",
      "tensor(0.0314, grad_fn=<DivBackward0>)\n",
      "tensor(0.0151, grad_fn=<DivBackward0>)\n",
      "tensor(0.0179, grad_fn=<DivBackward0>)\n",
      "tensor(0.0263, grad_fn=<DivBackward0>)\n",
      "tensor(0.0218, grad_fn=<DivBackward0>)\n",
      "ave: tensor(0.0294, grad_fn=<DivBackward0>)\n"
     ]
    }
   ],
   "source": [
    "##synflow results 280X\n",
    "# Per-dataset mean of the post-train NRMSE over its 5 folds, then the mean of\n",
    "# those 10 per-dataset values. Reads whatever the last run of the experiment\n",
    "# cell left in loss_prune_posttrain (entries grouped 5 per dataset, in order).\n",
    "fold_count = 5\n",
    "a = 0\n",
    "for start in range(0, 10 * fold_count, fold_count):\n",
    "    dataset_mean = sum(loss_prune_posttrain[start:start + fold_count]) / fold_count\n",
    "    print(dataset_mean)\n",
    "    a = a + dataset_mean\n",
    "print(\"ave:\", a / 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0260, grad_fn=<DivBackward0>)\n",
      "tensor(0.0503, grad_fn=<DivBackward0>)\n",
      "tensor(0.0372, grad_fn=<DivBackward0>)\n",
      "tensor(0.0394, grad_fn=<DivBackward0>)\n",
      "tensor(0.0383, grad_fn=<DivBackward0>)\n",
      "tensor(0.0344, grad_fn=<DivBackward0>)\n",
      "tensor(0.0150, grad_fn=<DivBackward0>)\n",
      "tensor(0.0185, grad_fn=<DivBackward0>)\n",
      "tensor(0.0275, grad_fn=<DivBackward0>)\n",
      "tensor(0.0218, grad_fn=<DivBackward0>)\n",
      "ave: tensor(0.0308, grad_fn=<DivBackward0>)\n"
     ]
    }
   ],
   "source": [
    "##synflow results 100X\n",
    "# The 100X label only holds if the experiment cell was last run with a\n",
    "# sparsity of 10**(-2): this cell just summarises loss_prune_posttrain,\n",
    "# 5 fold entries per dataset for 10 datasets.\n",
    "per_dataset = [sum(loss_prune_posttrain[5 * k:5 * k + 5]) / 5 for k in range(10)]\n",
    "a = 0\n",
    "for value in per_dataset:\n",
    "    print(value)\n",
    "    a = a + value\n",
    "print(\"ave:\", a / 10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml",
   "language": "python",
   "name": "ml"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
