{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b65c9114",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch import Tensor\n",
    "\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import torch.utils.data as data_utils\n",
    "from typing import List, Tuple\n",
    "\n",
    "import sys\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import math\n",
    "from scipy.io import loadmat, savemat\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.special import eval_legendre, gammaln\n",
    "from sympy import Poly, legendre, Symbol\n",
    "import h5py\n",
    "import pickle as pk\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.model_selection import train_test_split, KFold, ShuffleSplit\n",
    "from sklearn.utils import shuffle\n",
    "\n",
    "import operator\n",
    "from functools import reduce\n",
    "from functools import partial\n",
    "from timeit import default_timer\n",
    "\n",
    "from utils import train, test, LpLoss, get_filter, exp_pade_coeff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "11540a66",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fa94ba8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a073b431",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "61da64ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "din = np.load('data/Corona/data_input.npy')\n",
    "dout = np.load('data/Corona/data_output.npy')\n",
    "\n",
    "# Later cells unpack inputs as (n, Nx, Ny, Tin) and targets as (n, Nx, Ny, Tout) and compare\n",
    "# predictions with targets element-wise; fail fast here if the files do not match that layout.\n",
    "assert din.ndim == 4 and dout.ndim == 4, (din.shape, dout.shape)\n",
    "assert din.shape[:3] == dout.shape[:3], (din.shape, dout.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0fae570b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 10% of the samples for testing (fixed random_state -> reproducible split).\n",
    "x_train, x_test, y_train, y_test = train_test_split(din, dout,\n",
    "                                                    test_size=0.1,\n",
    "                                                    random_state=42)\n",
    "\n",
    "# Optimisation hyper-parameters.\n",
    "# NOTE(review): the model-configuration cell further down reassigns all four of these\n",
    "# (learning_rate becomes 1e-3) before any training happens.\n",
    "batch_size, learning_rate = 10, 1e-4\n",
    "step_size, gamma = 100, 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "20e385de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cast every split to a float32 torch tensor (same order as the unpacking target).\n",
    "x_train, x_test, y_train, y_test = (\n",
    "    torch.from_numpy(arr.astype('float32'))\n",
    "    for arr in (x_train, x_test, y_train, y_test)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "96e3c0f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs are (n, Nx, Ny, Tin), targets (n, Nx, Ny, Tout). Each of the Tout output steps gets a\n",
    "# copy of the full Tin-step input window: (n, Nx, Ny, Tin) -> (n, Nx, Ny, Tout, Tin).\n",
    "ntrain, Nx, Ny = x_train.shape[:3]\n",
    "Tin = x_train.shape[-1]\n",
    "_, _, _, Tout = y_train.shape\n",
    "ntest = x_test.shape[0]\n",
    "\n",
    "# The ndim guards make the cell safe to re-run; a second run used to try to unpack the\n",
    "# already-expanded 5-D tensors into four names and raise.\n",
    "if x_train.ndim == 4:\n",
    "    x_train = x_train.reshape(ntrain, Nx, Ny, 1, Tin).repeat([1, 1, 1, Tout, 1])\n",
    "if x_test.ndim == 4:\n",
    "    x_test = x_test.reshape(ntest, Nx, Ny, 1, Tin).repeat([1, 1, 1, Tout, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bb33fe7d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([435, 64, 4, 7, 14])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Expected layout after the repeat: (ntrain, Nx, Ny, Tout, Tin)\n",
    "x_train.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a414b020",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_initializer(name):\n",
    "    \n",
    "    if name == 'xavier_normal':\n",
    "        init_ = partial(nn.init.xavier_normal_)\n",
    "    elif name == 'kaiming_uniform':\n",
    "        init_ = partial(nn.init.kaiming_uniform_)\n",
    "    elif name == 'kaiming_normal':\n",
    "        init_ = partial(nn.init.kaiming_normal_)\n",
    "    return init_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "39d99f23",
   "metadata": {},
   "outputs": [],
   "source": [
    "class sparseKernel(nn.Module):\n",
    "    def __init__(self,\n",
    "                 k, alpha, c=1, \n",
    "                 nl = 1,\n",
    "                 initializer = None,\n",
    "                 **kwargs):\n",
    "        super(sparseKernel,self).__init__()\n",
    "        \n",
    "        self.k = k\n",
    "        self.conv = self.convBlock(alpha*k**2, alpha*k**2)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.conv(x)\n",
    "        x = torch.tanh(x)\n",
    "        \n",
    "        return x\n",
    "        \n",
    "        \n",
    "    def convBlock(self, ich, och):\n",
    "        net = nn.Sequential(\n",
    "            nn.Conv3d(och, och, 3, 1, 1),\n",
    "            nn.Tanh(),\n",
    "            nn.Conv3d(och, och, 3, 1, 1),\n",
    "        )\n",
    "        return net \n",
    "        \n",
    "\n",
    "class pade_exponential(nn.Module):\n",
    "    \"\"\"Polynomial-in-operator block built around one shared ``sparseKernel`` K.\n",
    "\n",
    "    Input/output shape: (B, c, ich, Nx, Ny, T); ``ich`` is expected to equal k**2 so that\n",
    "    c*ich matches the c*k**2 channels of ``LinOperator`` and ``Linear``.\n",
    "\n",
    "    forward computes\n",
    "        a = relu(Linear(sum_{i<q} Pq[i] * K^i(x)))\n",
    "        y = sum_{i<p} Pp[i] * K^i(a)\n",
    "    with coefficients from ``exp_pade_coeff(p, q)`` (utils) -- presumably the Pade\n",
    "    coefficients of exp; TODO confirm. ``alpha`` and ``initializer`` are unused.\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                k, alpha, c=1,\n",
    "                p = 3, q = 4,\n",
    "                initializer = None,\n",
    "                **kwargs):\n",
    "        super().__init__()\n",
    "\n",
    "        self.p, self.q = p, q\n",
    "        coeff_p, coeff_q = exp_pade_coeff(p, q)\n",
    "\n",
    "        # Module/buffer creation order is kept so parameter init and state_dict layout match.\n",
    "        self.LinOperator = sparseKernel(k, c, c)\n",
    "        self.Linear = nn.Conv1d(c*k**2, c*k**2, 1)\n",
    "        self.register_buffer('Pp', torch.Tensor(coeff_p))\n",
    "        self.register_buffer('Pq', torch.Tensor(coeff_q))\n",
    "\n",
    "    def forward(self, x):\n",
    "        B, c, ich, Nx, Ny, T = x.shape\n",
    "        channels = c * ich\n",
    "        term = x.reshape(B, -1, Nx, Ny, T)\n",
    "\n",
    "        # First polynomial: Pq[0]*x + Pq[1]*K(x) + ... + Pq[q-1]*K^(q-1)(x)\n",
    "        acc = self.Pq[0] * term\n",
    "        for i in range(1, self.q):\n",
    "            term = self.LinOperator(term)\n",
    "            acc = acc + self.Pq[i] * term\n",
    "\n",
    "        # Channel mixing with a 1x1 Conv1d over the flattened grid, then ReLU.\n",
    "        mixed = self.Linear(acc.view(B, channels, -1)).view(B, channels, Nx, Ny, T)\n",
    "        mixed = F.relu(mixed)\n",
    "\n",
    "        # Second polynomial, applied to the mixed features.\n",
    "        out = self.Pp[0] * mixed\n",
    "        for i in range(1, self.p):\n",
    "            mixed = self.LinOperator(mixed)\n",
    "            out = out + self.Pp[i] * mixed\n",
    "\n",
    "        return out.view(B, c, ich, Nx, Ny, T)\n",
    "    \n",
    "    \n",
    "class MWT_CZ(nn.Module):\n",
    "    \"\"\"One multiwavelet cell on a 2-D grid with a trailing time axis T that is never split.\n",
    "\n",
    "    forward: (B, c, k**2, Nx, Ny, T) -> (B, c, k**2, Nx, Ny, T) for power-of-two Nx >= Ny.\n",
    "      1. decompose: repeatedly split the grid into even/odd samples and project onto the\n",
    "         ec_* filter buffers (both spatial axes while Ny allows it, then x only);\n",
    "      2. at level i keep Ud[i] = A(d) + B(s) and Us[i] = C(d), with A, B, C being\n",
    "         pade_exponential blocks;\n",
    "      3. apply T0 (a 1x1 Conv1d over channels) at the coarsest scale;\n",
    "      4. reconstruct: add Us[i], concatenate Ud[i] on the channel axis and interleave back to\n",
    "         the finer grid with the rc_* buffers (evenOdd).\n",
    "    NOTE(review): the coarsest-scale view assumes Nx, Ny are powers of two and Nx >= Ny\n",
    "    (64 x 4 in this notebook) -- TODO validate the input shape explicitly.\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 k = 3, alpha = 5, \n",
    "                 L = 0, c = 1,\n",
    "                 p = 3, q = 4,\n",
    "                 base = 'legendre',\n",
    "                 initializer = None,\n",
    "                 **kwargs):\n",
    "        super(MWT_CZ, self).__init__()\n",
    "        \n",
    "        self.k = k\n",
    "        self.L = L\n",
    "        # Filter matrices from utils.get_filter; presumably k x k each (TODO confirm) -- that is\n",
    "        # what makes the Kronecker products below k**2 wide, matching evenOdd's 2*k**2 assert.\n",
    "        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)\n",
    "        # Reconstruction filters, consumed by evenOdd through the rc_* buffers below.\n",
    "        H0r = H0@PHI0\n",
    "        G0r = G0@PHI0\n",
    "        H1r = H1@PHI1\n",
    "        G1r = G1@PHI1\n",
    "        \n",
    "        # Flush numerically negligible entries to exact zeros.\n",
    "        H0r[np.abs(H0r)<1e-8]=0\n",
    "        H1r[np.abs(H1r)<1e-8]=0\n",
    "        G0r[np.abs(G0r)<1e-8]=0\n",
    "        G1r[np.abs(G1r)<1e-8]=0\n",
    "        \n",
    "        # forward uses A(d) + B(s) for the detail path and C(d) for the smooth path.\n",
    "        self.A = pade_exponential(k, alpha, c, p, q)\n",
    "        self.B = pade_exponential(k, alpha, c, p, q)\n",
    "        self.C = pade_exponential(k, alpha, c, p, q)\n",
    "        \n",
    "        # Coarsest-scale operator: 1x1 convolution over the c*k**2 channels.\n",
    "        self.T0 = nn.Conv1d(c*k**2, c*k**2, 1)\n",
    "\n",
    "        # Only T0 is re-initialised; A, B, C are not given the initializer.\n",
    "        if initializer is not None:\n",
    "            self.reset_parameters(initializer)\n",
    "\n",
    "        # Decomposition filters: one block per even/odd polyphase component that\n",
    "        # wavelet_transform concatenates on the channel axis (4 for 'both', 2 for the *_x pair).\n",
    "        self.register_buffer('ec_s', torch.Tensor(\n",
    "            np.concatenate((np.kron(H0, H0).T, \n",
    "                            np.kron(H0, H1).T,\n",
    "                            np.kron(H1, H0).T,\n",
    "                            np.kron(H1, H1).T,\n",
    "                           ), axis=0)))\n",
    "        self.register_buffer('ec_d', torch.Tensor(\n",
    "            np.concatenate((np.kron(G0, G0).T,\n",
    "                            np.kron(G0, G1).T,\n",
    "                            np.kron(G1, G0).T,\n",
    "                            np.kron(G1, G1).T,\n",
    "                           ), axis=0)))\n",
    "        \n",
    "        self.register_buffer('ec_s_x', torch.Tensor(\n",
    "            np.concatenate((np.kron(H0, H0).T, \n",
    "                            np.kron(H1, H0).T,\n",
    "                           ), axis=0)))\n",
    "        self.register_buffer('ec_d_x', torch.Tensor(\n",
    "            np.concatenate((np.kron(G0, G0).T,\n",
    "                            np.kron(G1, G0).T,\n",
    "                           ), axis=0)))\n",
    "        \n",
    "        # Reconstruction filters, one per output polyphase position (ee/eo/oe/oo). Each stacks\n",
    "        # an H part over a G part along axis 0, matching the [smooth, detail] channel order that\n",
    "        # forward builds with torch.cat((x, Ud[i]), 2) before calling evenOdd.\n",
    "        self.register_buffer('rc_ee', torch.Tensor(\n",
    "            np.concatenate((np.kron(H0r, H0r), \n",
    "                            np.kron(G0r, G0r),\n",
    "                           ), axis=0)))\n",
    "        self.register_buffer('rc_eo', torch.Tensor(\n",
    "            np.concatenate((np.kron(H0r, H1r), \n",
    "                            np.kron(G0r, G1r),\n",
    "                           ), axis=0)))\n",
    "        self.register_buffer('rc_oe', torch.Tensor(\n",
    "            np.concatenate((np.kron(H1r, H0r), \n",
    "                            np.kron(G1r, G0r),\n",
    "                           ), axis=0)))\n",
    "        self.register_buffer('rc_oo', torch.Tensor(\n",
    "            np.concatenate((np.kron(H1r, H1r), \n",
    "                            np.kron(G1r, G1r),\n",
    "                           ), axis=0)))\n",
    "        \n",
    "        \n",
    "    def forward(self, x):\n",
    "        \n",
    "        B, c, ich, Nx, Ny, T = x.shape # (B, c, k^2, Nx, Ny, T) assume: Nx>=Ny\n",
    "        # Number of dyadic levels available along each spatial axis.\n",
    "        ns_x = math.floor(np.log2(Nx))\n",
    "        ns_y = math.floor(np.log2(Ny))\n",
    "\n",
    "        # Per-level terms: Ud[i] is concatenated as the detail part during reconstruction,\n",
    "        # Us[i] is added to the smooth part.\n",
    "        Ud = torch.jit.annotate(List[Tensor], [])\n",
    "        Us = torch.jit.annotate(List[Tensor], [])\n",
    "\n",
    "#        decompose\n",
    "        # Split both axes for the first ns_y levels, then only x once Ny is exhausted.\n",
    "        oe_flag = 'both'\n",
    "        for i in range(ns_x-self.L):\n",
    "            if i > ns_y-1:\n",
    "                oe_flag = 'x'\n",
    "            d, x = self.wavelet_transform(x, oe_flag)\n",
    "            Ud += [self.A(d) + self.B(x)]\n",
    "            Us += [self.C(d)]\n",
    "\n",
    "#        coarsest scale transform\n",
    "        # T0 mixes channels on the flattened coarsest grid; the target spatial sizes below\n",
    "        # assume power-of-two Nx and Ny.\n",
    "        if oe_flag == 'both':\n",
    "            x = self.T0(x.reshape(B, c*ich, -1)).view(\n",
    "                B, c, ich, 2**self.L, 2**(ns_y-(ns_x-self.L)), T) \n",
    "        else:\n",
    "            x = self.T0(x.reshape(B, c*ich, -1)).view(\n",
    "                B, c, ich, 2**self.L, 1, T)          \n",
    "        \n",
    "\n",
    "#        reconstruct            \n",
    "        # Walk the levels coarse-to-fine, switching back to 2-D interleaving once i < ns_y.\n",
    "        for i in range(ns_x-1-self.L,-1,-1):\n",
    "            if i < ns_y:\n",
    "                oe_flag = 'both'\n",
    "            x = x + Us[i]\n",
    "            # Channels become [smooth, detail] = 2*k**2, the layout evenOdd asserts.\n",
    "            x = torch.cat((x, Ud[i]), 2)\n",
    "            x = self.evenOdd(x, oe_flag)\n",
    "\n",
    "        return x\n",
    "\n",
    "    \n",
    "    def wavelet_transform(self, x, oe_flag):\n",
    "        # One analysis step: stack the even/odd polyphase components on the channel axis\n",
    "        # (4 components for 'both', halving x and y; 2 for 'x', halving x only) and contract\n",
    "        # them with the decomposition filters. Returns (detail d, smooth s).\n",
    "        # NOTE(review): any other oe_flag value leaves d and s unbound (UnboundLocalError).\n",
    "        waveFil = partial(torch.einsum, 'bcixyt,io->bcoxyt') \n",
    "        if oe_flag == 'both':\n",
    "            xa = torch.cat([x[:, :, :, ::2 , ::2 , :], \n",
    "                            x[:, :, :, ::2 , 1::2, :], \n",
    "                            x[:, :, :, 1::2, ::2 , :], \n",
    "                            x[:, :, :, 1::2, 1::2, :]\n",
    "                           ], 2)\n",
    "            d = waveFil(xa, self.ec_d)\n",
    "            s = waveFil(xa, self.ec_s)\n",
    "        elif oe_flag == 'x':\n",
    "            xa = torch.cat([x[:, :, :, ::2 , : , :], \n",
    "                            x[:, :, :, 1::2, : , :], \n",
    "                           ], 2)\n",
    "            d = waveFil(xa, self.ec_d_x)\n",
    "            s = waveFil(xa, self.ec_s_x)\n",
    "            \n",
    "        return d, s\n",
    "        \n",
    "        \n",
    "    def evenOdd(self, x, oe_flag):\n",
    "        # One synthesis step mirroring wavelet_transform: each rc_* filter maps the 2*k**2\n",
    "        # [smooth, detail] channels to k**2 channels for one polyphase position, and the results\n",
    "        # are interleaved into a grid twice as large along the split axes.\n",
    "        # NOTE(review): the output buffer uses torch's default dtype, not x.dtype.\n",
    "        \n",
    "        B, c, ich, Nx, Ny, T = x.shape # (B, c, 2*k^2, Nx, Ny, T)\n",
    "        assert ich == 2*self.k**2\n",
    "        evOd = partial(torch.einsum, 'bcixyt,io->bcoxyt')\n",
    "        \n",
    "        if oe_flag == 'both':\n",
    "            x_ee = evOd(x, self.rc_ee)\n",
    "            x_eo = evOd(x, self.rc_eo)\n",
    "            x_oe = evOd(x, self.rc_oe)\n",
    "            x_oo = evOd(x, self.rc_oo)\n",
    "\n",
    "            x = torch.zeros(B, c, self.k**2, Nx*2, Ny*2, T,\n",
    "                device = x.device)\n",
    "            x[:, :, :, ::2 , ::2 , :] = x_ee\n",
    "            x[:, :, :, ::2 , 1::2, :] = x_eo\n",
    "            x[:, :, :, 1::2, ::2 , :] = x_oe\n",
    "            x[:, :, :, 1::2, 1::2, :] = x_oo\n",
    "        elif oe_flag == 'x':\n",
    "            x_ee = evOd(x, self.rc_ee)\n",
    "            x_oe = evOd(x, self.rc_oe)\n",
    "\n",
    "            x = torch.zeros(B, c, self.k**2, Nx*2, Ny, T,\n",
    "                device = x.device)\n",
    "            x[:, :, :, ::2 , : , :] = x_ee\n",
    "            x[:, :, :, 1::2, : , :] = x_oe\n",
    "            \n",
    "        return x\n",
    "    \n",
    "    def reset_parameters(self, initializer):\n",
    "        # Apply `initializer` to the coarsest-scale operator's weight only.\n",
    "        initializer(self.T0.weight)\n",
    "    \n",
    "    \n",
    "class MWT_Exp(nn.Module):\n",
    "    def __init__(self,\n",
    "                 ich = 1, k = 3, alpha = 2, c = 1,\n",
    "                 p = 3, q = 4,\n",
    "                 nCZ = 3,\n",
    "                 L = 0,\n",
    "                 base = 'legendre',\n",
    "                 initializer = None,\n",
    "                 **kwargs):\n",
    "        super(MWT_Exp,self).__init__()\n",
    "        \n",
    "        self.k = k\n",
    "        self.c = c\n",
    "        self.L = L\n",
    "        self.nCZ = nCZ\n",
    "        self.Lk = nn.Linear(ich, c*k**2)\n",
    "        \n",
    "        self.MWT_CZ = nn.ModuleList(\n",
    "            [MWT_CZ(k, alpha, L, c, \n",
    "                p, q, base, \n",
    "                initializer) \n",
    "                for _ in range(nCZ)]\n",
    "        )\n",
    "        self.Lc0 = nn.Linear(c*k**2, 128)\n",
    "        self.Lc1 = nn.Linear(128, 1)\n",
    "        \n",
    "        if initializer is not None:\n",
    "            self.reset_parameters(initializer)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        \n",
    "        B, Nx, Ny, T, ich = x.shape # (B, Nx, Ny, T, d)\n",
    "        ns = math.floor(np.log2(Nx))\n",
    "        x = model.Lk(x)\n",
    "        x = x.view(B, Nx, Ny, T, self.c, self.k**2)\n",
    "        x = x.permute(0, 4, 5, 1, 2, 3)\n",
    "    \n",
    "        for i in range(self.nCZ):\n",
    "            x = self.MWT_CZ[i](x)\n",
    "            if i < self.nCZ-1:\n",
    "                x = F.relu(x)\n",
    "\n",
    "        x = x.view(B, -1, Nx, Ny, T) # collapse c and k**2\n",
    "        x = x.permute(0, 2, 3, 4, 1)\n",
    "        x = model.Lc0(x)\n",
    "        x = F.relu(x)\n",
    "        x = model.Lc1(x)\n",
    "        return x.squeeze(-1)\n",
    "    \n",
    "    def reset_parameters(self, initializer):\n",
    "        initializer(self.Lc0.weight)\n",
    "        initializer(self.Lc1.weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "d11e8aa6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- training hyper-parameters ---\n",
    "learning_rate = 0.001\n",
    "epochs = 750\n",
    "step_size = 100  # StepLR: decay the learning rate every `step_size` epochs ...\n",
    "gamma = 0.5      # ... by this factor\n",
    "batch_size = 10\n",
    "\n",
    "# --- architecture ---\n",
    "ich = 14  # features per grid point = length of the input window (last dim of x_train)\n",
    "c = 4\n",
    "alpha = c\n",
    "k = 4\n",
    "nCZ = 1  # number of MWT cells\n",
    "L = 0\n",
    "initializer = get_initializer('xavier_normal')  # xavier_normal, kaiming_normal, kaiming_uniform\n",
    "\n",
    "# Re-seed immediately before construction so the initial weights are reproducible.\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "model = MWT_Exp(ich, alpha=alpha, c=c, k=k, p=4, q=2, base='legendre',\n",
    "                nCZ=nCZ, L=L, initializer=initializer).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "9c94920a",
   "metadata": {},
   "outputs": [],
   "source": [
    "ntrain, ntest = x_train.shape[0], x_test.shape[0]\n",
    "\n",
    "# Shuffle training batches every epoch; keep the test order fixed.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    torch.utils.data.TensorDataset(x_train, y_train),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    ")\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    torch.utils.data.TensorDataset(x_test, y_test),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01e23628-34b5-4471-8711-776ad4219a0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "myloss = LpLoss(size_average=False)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)\n",
    "\n",
    "# `train` / `test` come from utils; the StepLR scheduler is handed to `train` as `lr_schedule`.\n",
    "for epoch in range(1, epochs + 1):\n",
    "    train_l2 = train(model, train_loader, optimizer, epoch, device,\n",
    "                     lossFn=myloss, lr_schedule=scheduler)\n",
    "    test_l2 = test(model, test_loader, device, lossFn=myloss)\n",
    "    print(f'epoch: {epoch}, train l2 = {train_l2}, test l2 = {test_l2}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55e8d35b-cc54-4b13-82b5-e3cc499027c7",
   "metadata": {},
   "source": [
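    "# 10-fold cross-validation experiment\n",
    "\n",
    "Each fold trains a freshly seeded `MWT_Exp` model on the fold's training split. It then pickles the hyper-parameters, the per-epoch test errors (full and cropped), and the test-split targets and predictions to `Pade_Exp_corona_3d_split_<fold>.p`."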
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1e94adf",
   "metadata": {},
   "outputs": [],
   "source": [
    "RND_STATE = 42\n",
    "fg = KFold(n_splits=10, shuffle=True, random_state=RND_STATE)\n",
    "\n",
    "ich = 14\n",
    "initializer = get_initializer('xavier_normal') # xavier_normal, kaiming_normal, kaiming_uniform\n",
    "\n",
    "learning_rate = 0.001\n",
    "\n",
    "epochs = 750\n",
    "step_size = 100\n",
    "gamma = 0.5\n",
    "batch_size = 10\n",
    "\n",
    "# model params\n",
    "c = 4\n",
    "alpha = c\n",
    "k = 4\n",
    "nCZ = 1\n",
    "L = 0\n",
    "\n",
    "# Sub-window for the extra 'cropped' test error: x indices 7..56, y indices 0..2.\n",
    "# NOTE(review): the meaning of this window is not documented in the notebook -- TODO confirm.\n",
    "CROP_X = slice(7, 7 + 50)\n",
    "CROP_Y = slice(None, 3)\n",
    "\n",
    "# Optional extra inputs/targets that every fold's model also predicts on.\n",
    "# NOTE(review): `final_in` / `final_out` are not created anywhere in this notebook, so this cell\n",
    "# used to raise NameError on a fresh kernel. TODO: load them explicitly. When they are absent,\n",
    "# the 'f_out' / 'f_pred' entries are saved as None.\n",
    "final_in = globals().get('final_in')\n",
    "final_out = globals().get('final_out')\n",
    "\n",
    "out_log = {'epochs':epochs, 'lr':learning_rate, 'step_size':step_size, \n",
    "           'gamma':gamma, 'batch_size':batch_size,\n",
    "           'c':c, 'alpha':alpha, 'k':k, 'nCZ':nCZ, 'L':L,\n",
    "          'fg': 'KFold split', 'random_seed':RND_STATE}\n",
    "for ii, (train_ind, test_ind) in enumerate(fg.split(din)):\n",
    "    train_ind = shuffle(train_ind, random_state=RND_STATE)\n",
    "    \n",
    "    x_train = torch.from_numpy(din[train_ind].astype('float32'))\n",
    "    x_test = torch.from_numpy(din[test_ind].astype('float32'))\n",
    "    y_train = torch.from_numpy(dout[train_ind].astype('float32'))\n",
    "    y_test = torch.from_numpy(dout[test_ind].astype('float32'))\n",
    "    \n",
    "    # Give every output step a copy of the input window: (n, Nx, Ny, Tin) -> (n, Nx, Ny, Tout, Tin).\n",
    "    ntrain, Nx, Ny, Tin = x_train.shape\n",
    "    _, _, _, Tout = y_train.shape\n",
    "    ntest = x_test.shape[0]\n",
    "    x_train = x_train.reshape(ntrain, Nx, Ny, 1, Tin).repeat([1, 1, 1, Tout, 1])\n",
    "    x_test = x_test.reshape(ntest, Nx, Ny, 1, Tin).repeat([1, 1, 1, Tout, 1])\n",
    "\n",
    "    if final_in is not None:\n",
    "        # Same expansion for the extra inputs; -1 keeps however many samples they hold\n",
    "        # instead of a hard-coded count.\n",
    "        f_in = torch.from_numpy(final_in.astype('float32'))\n",
    "        f_in = f_in.reshape(-1, Nx, Ny, 1, Tin).repeat([1, 1, 1, Tout, 1])\n",
    "    else:\n",
    "        f_in = None\n",
    "    f_out = None if final_out is None else final_out.astype('float32')\n",
    "    \n",
    "    train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)\n",
    "    test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)\n",
    "    \n",
    "    # Identical initial weights in every fold.\n",
    "    torch.manual_seed(0)\n",
    "    np.random.seed(0)\n",
    "\n",
    "    model = MWT_Exp(ich, alpha=alpha, c=c, k=k, p=4, q=2, base='legendre',\n",
    "                    nCZ=nCZ, L=L, initializer=initializer).to(device)\n",
    "    \n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)\n",
    "    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)\n",
    "\n",
    "    myloss = LpLoss(size_average=False)\n",
    "    error = []\n",
    "    error_cropped = []\n",
    "    for ep in range(epochs):\n",
    "        model.train()\n",
    "        t1 = default_timer()\n",
    "        train_l2 = 0\n",
    "        for x, y in train_loader:\n",
    "            x, y = x.to(device), y.to(device)\n",
    "            bs = x.shape[0]\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            out = model(x)\n",
    "\n",
    "            l2 = myloss(out.view(bs, -1), y.view(bs, -1))\n",
    "            l2.backward() # use the l2 relative loss\n",
    "\n",
    "            optimizer.step()\n",
    "            train_l2 += l2.item()\n",
    "\n",
    "        scheduler.step()\n",
    "        model.eval()\n",
    "        test_l2 = 0.0\n",
    "        test_cropped_l2 = 0.0\n",
    "        with torch.no_grad():\n",
    "            for x, y in test_loader:\n",
    "                x, y = x.to(device), y.to(device)\n",
    "                bs = x.shape[0]\n",
    "\n",
    "                out = model(x)\n",
    "                test_l2 += myloss(out.view(bs, -1), y.view(bs, -1)).item()\n",
    "                test_cropped_l2 += myloss(out[:, CROP_X, CROP_Y, :].reshape(bs, -1),\n",
    "                                          y[:, CROP_X, CROP_Y, :].reshape(bs, -1)).item()\n",
    "\n",
    "        # LpLoss(size_average=False) presumably sums over the batch, hence the per-sample division.\n",
    "        train_l2 /= ntrain\n",
    "        test_l2 /= ntest\n",
    "        test_cropped_l2 /= ntest\n",
    "        error.append(test_l2)\n",
    "        error_cropped.append(test_cropped_l2)\n",
    "\n",
    "        t2 = default_timer()\n",
    "        if ep % 50 == 0:\n",
    "            print(ep, t2-t1, train_l2, test_l2, test_cropped_l2)\n",
    "    \n",
    "    # Final-model predictions on this fold's test split (and on the extra inputs, if any).\n",
    "    model.eval()\n",
    "    actual = []\n",
    "    pred = []\n",
    "    with torch.no_grad():\n",
    "        for x, y in test_loader:\n",
    "            x, y = x.to(device), y.to(device)\n",
    "            out = model(x)\n",
    "            actual.append(y)\n",
    "            pred.append(out)\n",
    "        f_pred = None if f_in is None else model(f_in.to(device)).cpu().numpy()\n",
    "    out_log_data = {\n",
    "        'train_ind': train_ind,\n",
    "        'test_ind': test_ind,\n",
    "        'error_l2': error,\n",
    "        'error_l2_cropped': error_cropped,\n",
    "        'y_test': torch.cat(actual).cpu().numpy(),\n",
    "        'y_pred': torch.cat(pred).cpu().numpy(),\n",
    "        'f_out' : f_out,\n",
    "        'f_pred': f_pred,\n",
    "    }\n",
    "    file_name = f'Pade_Exp_corona_3d_split_{ii}.p'\n",
    "    # `with` closes (and flushes) the file; the old open() handle was never closed.\n",
    "    with open(file_name, 'wb') as fh:\n",
    "        pk.dump({**out_log, **out_log_data}, fh)\n",
    "    print(f'Saved data for iteration:{ii}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9451cd89",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "environment": {
   "name": "pytorch-gpu.1-9.m75",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/pytorch-gpu.1-9:m75"
  },
  "kernelspec": {
   "display_name": "pde18",
   "language": "python",
   "name": "pde18"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
