{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch import Tensor\n",
    "\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import torch.utils.data as data_utils\n",
    "from typing import List, Tuple\n",
    "\n",
    "import sys\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import math\n",
    "from scipy.io import loadmat, savemat\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.special import eval_legendre\n",
    "from sympy import Poly, legendre, Symbol\n",
    "import h5py\n",
    "\n",
    "\n",
    "import operator\n",
    "from functools import reduce\n",
    "from functools import partial\n",
    "\n",
    "from utils import LpLoss, get_filter, exp_pade_coeff, train, test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expose only GPU index 0 to this process; set before the first CUDA query below.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "# Run on the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_initializer(name):\n",
    "    \n",
    "    if name == 'xavier_normal':\n",
    "        init_ = partial(nn.init.xavier_normal_)\n",
    "    elif name == 'kaiming_uniform':\n",
    "        init_ = partial(nn.init.kaiming_uniform_)\n",
    "    elif name == 'kaiming_normal':\n",
    "        init_ = partial(nn.init.kaiming_normal_)\n",
    "    return init_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compl_mul1d(x, weights):\n",
    "    # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)\n",
    "    return torch.einsum(\"bix,iox->box\", x, weights)\n",
    "\n",
    "\n",
    "class sparseKernelFT(nn.Module):\n",
    "    \"\"\"Spectral layer that mixes channels on the lowest Fourier modes.\n",
    "\n",
    "    The input of shape (B, N, c*k) is transformed with a real FFT along N,\n",
    "    the first ``alpha`` modes (or fewer if the signal is shorter) are\n",
    "    multiplied by learned complex weights, every other mode is set to zero,\n",
    "    and the result is transformed back to length N.\n",
    "\n",
    "    Args:\n",
    "        k, c: Their product is the channel width of the layer.\n",
    "        alpha: Number of Fourier modes that carry learned weights.\n",
    "        nl, initializer, **kwargs: Accepted but not used by this layer.\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 k, alpha, c=1, \n",
    "                 nl = 1,\n",
    "                 initializer = None,\n",
    "                 **kwargs):\n",
    "        super().__init__()\n",
    "\n",
    "        width = c * k\n",
    "        self.modes1 = alpha\n",
    "        self.scale = 1 / (width * width)\n",
    "        # complex weights drawn uniformly and scaled by 1 / width**2\n",
    "        raw_weights = torch.rand(width, width, self.modes1, dtype=torch.cfloat)\n",
    "        self.weights1 = nn.Parameter(self.scale * raw_weights)\n",
    "        self.k = k\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the truncated spectral multiplication to ``x`` of shape (B, N, c*k).\"\"\"\n",
    "        batch, length, channels = x.shape\n",
    "        n_freq = length // 2 + 1\n",
    "        # never index past the modes the signal actually has\n",
    "        kept = min(self.modes1, n_freq)\n",
    "\n",
    "        # move the grid axis last so the FFT runs along it\n",
    "        spectrum = torch.fft.rfft(x.permute(0, 2, 1))\n",
    "\n",
    "        out_ft = torch.zeros(batch, channels, n_freq, device=x.device, dtype=torch.cfloat)\n",
    "        out_ft[:, :, :kept] = compl_mul1d(spectrum[:, :, :kept], self.weights1[:, :, :kept])\n",
    "\n",
    "        # back to physical space and to the (B, N, c*k) layout\n",
    "        restored = torch.fft.irfft(out_ft, n=length)\n",
    "        return restored.permute(0, 2, 1)\n",
    "    \n",
    "\n",
    "class pade_exponential(nn.Module):\n",
    "    \"\"\"Two weighted power sums of a shared spectral operator around a Linear + ReLU.\n",
    "\n",
    "    ``forward`` first accumulates ``Pq[i] * LinOperator^i(x)`` for ``i < q``,\n",
    "    passes the sum through ``Linear`` and ReLU, then accumulates\n",
    "    ``Pp[i] * LinOperator^i(.)`` for ``i < p`` on that result. The coefficient\n",
    "    vectors come from ``exp_pade_coeff(p, q)`` in the project's ``utils``\n",
    "    module (presumably Pade coefficients of the exponential -- confirm there).\n",
    "\n",
    "    Args:\n",
    "        k, c: Their product is the flattened channel width.\n",
    "        alpha: Number of Fourier modes given to ``sparseKernelFT``.\n",
    "        p, q: Upper loop bounds for the second and first sum respectively.\n",
    "        initializer, **kwargs: Accepted but not used.\n",
    "    \"\"\"\n",
    "    def __init__(self, \n",
    "                k, alpha, c=1,\n",
    "                p = 3, q = 4,\n",
    "                initializer = None,\n",
    "                **kwargs):\n",
    "        super().__init__()\n",
    "\n",
    "        self.p, self.q = p, q\n",
    "        coeffs_p, coeffs_q = exp_pade_coeff(p, q)\n",
    "\n",
    "        width = c * k\n",
    "        self.LinOperator = sparseKernelFT(k, alpha, c)\n",
    "        self.Linear = nn.Linear(width, width)\n",
    "\n",
    "        # fixed (non-trainable) coefficients that follow the module across devices\n",
    "        self.register_buffer('Pp', torch.Tensor(coeffs_p))\n",
    "        self.register_buffer('Pq', torch.Tensor(coeffs_q))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map ``x`` of shape (B, N, c, k) to a tensor of the same shape.\"\"\"\n",
    "        batch, length, n_c, n_k = x.shape\n",
    "\n",
    "        # flatten (c, k) so the operator sees one channel axis\n",
    "        power = x.view(batch, length, -1)\n",
    "        q_sum = self.Pq[0] * power\n",
    "        for order in range(1, self.q):\n",
    "            power = self.LinOperator(power)\n",
    "            q_sum += self.Pq[order] * power\n",
    "\n",
    "        hidden = F.relu(self.Linear(q_sum))\n",
    "\n",
    "        p_sum = self.Pp[0] * hidden\n",
    "        for order in range(1, self.p):\n",
    "            hidden = self.LinOperator(hidden)\n",
    "            p_sum += self.Pp[order] * hidden\n",
    "\n",
    "        return p_sum.view(batch, length, n_c, n_k)\n",
    "\n",
    "    \n",
    "class MWT_CZ(nn.Module):\n",
    "    def __init__(self,\n",
    "                 k = 3, alpha = 5, \n",
    "                 L = 0, c = 1,\n",
    "                 p = 4, q = 2,\n",
    "                 base = 'legendre',\n",
    "                 initializer = None,\n",
    "                 **kwargs):\n",
    "        super(MWT_CZ, self).__init__()\n",
    "        \n",
    "        self.k = k\n",
    "        self.L = L\n",
    "        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)\n",
    "        H0r = H0@PHI0\n",
    "        G0r = G0@PHI0\n",
    "        H1r = H1@PHI1\n",
    "        G1r = G1@PHI1\n",
    "        \n",
    "        H0r[np.abs(H0r)<1e-8]=0\n",
    "        H1r[np.abs(H1r)<1e-8]=0\n",
    "        G0r[np.abs(G0r)<1e-8]=0\n",
    "        G1r[np.abs(G1r)<1e-8]=0\n",
    "        \n",
    "        self.A = pade_exponential(k, alpha, c, p, q)\n",
    "        self.B = pade_exponential(k, alpha, c, p, q)\n",
    "        self.C = pade_exponential(k, alpha, c, p, q)\n",
    "        \n",
    "        self.T0 = nn.Linear(k, k)\n",
    "\n",
    "        self.register_buffer('ec_s', torch.Tensor(\n",
    "            np.concatenate((H0.T, H1.T), axis=0)))\n",
    "        self.register_buffer('ec_d', torch.Tensor(\n",
    "            np.concatenate((G0.T, G1.T), axis=0)))\n",
    "        \n",
    "        self.register_buffer('rc_e', torch.Tensor(\n",
    "            np.concatenate((H0r, G0r), axis=0)))\n",
    "        self.register_buffer('rc_o', torch.Tensor(\n",
    "            np.concatenate((H1r, G1r), axis=0)))\n",
    "        \n",
    "        \n",
    "    def forward(self, x):\n",
    "        \n",
    "        B, N, c, ich = x.shape # (B, N, k)\n",
    "        ns = math.floor(np.log2(N))\n",
    "\n",
    "        Ud = torch.jit.annotate(List[Tensor], [])\n",
    "        Us = torch.jit.annotate(List[Tensor], [])\n",
    "#         decompose\n",
    "        for i in range(ns-self.L):\n",
    "            d, x = self.wavelet_transform(x)\n",
    "            Ud += [self.A(d) + self.B(x)]\n",
    "            Us += [self.C(d)]\n",
    "        x = self.T0(x) # coarsest scale transform\n",
    "\n",
    "#        reconstruct            \n",
    "        for i in range(ns-1-self.L,-1,-1):\n",
    "            x = x + Us[i]\n",
    "            x = torch.cat((x, Ud[i]), -1)\n",
    "            x = self.evenOdd(x)\n",
    "        return x\n",
    "\n",
    "    \n",
    "    def wavelet_transform(self, x):\n",
    "        xa = torch.cat([x[:, ::2, :, :], \n",
    "                        x[:, 1::2, :, :], \n",
    "                       ], -1)\n",
    "        d = torch.matmul(xa, self.ec_d)\n",
    "        s = torch.matmul(xa, self.ec_s)\n",
    "        return d, s\n",
    "        \n",
    "        \n",
    "    def evenOdd(self, x):\n",
    "        \n",
    "        B, N, c, ich = x.shape # (B, N, c, k)\n",
    "        assert ich == 2*self.k\n",
    "        x_e = torch.matmul(x, self.rc_e)\n",
    "        x_o = torch.matmul(x, self.rc_o)\n",
    "        \n",
    "        x = torch.zeros(B, N*2, c, self.k, \n",
    "            device = x.device)\n",
    "        x[..., ::2, :, :] = x_e\n",
    "        x[..., 1::2, :, :] = x_o\n",
    "        return x\n",
    "    \n",
    "    \n",
    "class MWT_exp(nn.Module):\n",
    "    def __init__(self,\n",
    "                 ich = 1, k = 3, alpha = 2, c = 1,\n",
    "                 p = 4, q = 2,\n",
    "                 nCZ = 3,\n",
    "                 L = 0,\n",
    "                 base = 'legendre',\n",
    "                 initializer = None,\n",
    "                 **kwargs):\n",
    "        super(MWT_exp,self).__init__()\n",
    "        \n",
    "        self.k = k\n",
    "        self.c = c\n",
    "        self.L = L\n",
    "        self.nCZ = nCZ\n",
    "        self.Lk = nn.Linear(ich, c*k)\n",
    "        \n",
    "        self.MWT_CZ = nn.ModuleList(\n",
    "            [MWT_CZ(k, alpha, L, c, \n",
    "                p, q, base, \n",
    "                initializer) \n",
    "                for _ in range(nCZ)]\n",
    "        )\n",
    "        self.Lc0 = nn.Linear(c*k, 128)\n",
    "        self.Lc1 = nn.Linear(128, 1)\n",
    "        \n",
    "        if initializer is not None:\n",
    "            self.reset_parameters(initializer)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        \n",
    "        B, N, ich = x.shape # (B, N, d)\n",
    "        ns = math.floor(np.log2(N))\n",
    "        x = self.Lk(x)\n",
    "        x = x.view(B, N, self.c, self.k)\n",
    "    \n",
    "        for i in range(self.nCZ):\n",
    "            x = self.MWT_CZ[i](x)\n",
    "#             \n",
    "            if i < self.nCZ-1:\n",
    "                x = F.relu(x)\n",
    "\n",
    "        x = x.view(B, N, -1) # collapse c and k\n",
    "        x = self.Lc0(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.Lc1(x)\n",
    "        return x.squeeze(-1)\n",
    "    \n",
    "    def reset_parameters(self, initializer):\n",
    "        initializer(self.Lc0.weight)\n",
    "        initializer(self.Lc1.weight)    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batch size and the number of samples taken for training / testing.\n",
    "batch_size = 20\n",
    "ntrain = 1000\n",
    "ntest = 200\n",
    "\n",
    "# `sub` is the stride used later to subsample the grid axis of the data.\n",
    "# `h` and `s` are not referenced by any other cell of this notebook.\n",
    "sub = 2**3 #subsampling rate\n",
    "h = 2**13 // sub #total grid size divided by the subsampling rate\n",
    "s = h\n",
    "\n",
    "# Load the dataset (path relative to the working directory) and cast both\n",
    "# arrays to float32 so they match PyTorch's default floating-point dtype.\n",
    "rw_sqd = loadmat('data/KdV/kdv_train_test.mat')\n",
    "x_data = rw_sqd['input'].astype(np.float32)\n",
    "y_data = rw_sqd['output'].astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimisation hyper-parameters consumed by the training cell.\n",
    "learning_rate = 0.001\n",
    "epochs = 500\n",
    "step_size = 100\n",
    "gamma = 0.5\n",
    "\n",
    "# One input feature per grid point.\n",
    "ich = 1\n",
    "initializer = get_initializer('xavier_normal') # xavier_normal, kaiming_normal, kaiming_uniform\n",
    "\n",
    "# Seed both RNGs before the model is built so its random initial weights are reproducible.\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "model = MWT_exp(\n",
    "    ich,\n",
    "    k = 4,\n",
    "    c = 4*4,\n",
    "    alpha = 16,\n",
    "    p = 5,\n",
    "    q = 6,\n",
    "    nCZ = 1,\n",
    "    L = 0,\n",
    "    base = 'legendre',\n",
    "    initializer = initializer,\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep every `sub`-th grid point; the first `ntrain` rows are used for training\n",
    "# and the last `ntest` rows for testing. Inputs get a trailing singleton feature axis.\n",
    "x_train = torch.from_numpy(x_data[:ntrain, ::sub]).unsqueeze(-1)\n",
    "y_train = torch.from_numpy(y_data[:ntrain, ::sub])\n",
    "\n",
    "x_test = torch.from_numpy(x_data[-ntest:, ::sub]).unsqueeze(-1)\n",
    "y_test = torch.from_numpy(y_data[-ntest:, ::sub])\n",
    "\n",
    "# Only the training batches are shuffled.\n",
    "train_loader = data_utils.DataLoader(\n",
    "    data_utils.TensorDataset(x_train, y_train),\n",
    "    batch_size=batch_size, shuffle=True)\n",
    "test_loader = data_utils.DataLoader(\n",
    "    data_utils.TensorDataset(x_test, y_test),\n",
    "    batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam with weight decay; StepLR multiplies the learning rate by `gamma`\n",
    "# after every `step_size` calls to scheduler.step().\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)\n",
    "\n",
    "# LpLoss, train and test come from the project-local utils module. What\n",
    "# size_average=False selects and when the scheduler is stepped are decided\n",
    "# there -- NOTE(review): confirm the scheduler is stepped once per epoch.\n",
    "myloss = LpLoss(size_average=False)\n",
    "for epoch in range(1, epochs+1):\n",
    "    train_l2 = train(model, train_loader, optimizer, epoch, device,\n",
    "        lossFn = myloss, lr_schedule = scheduler)\n",
    "    \n",
    "    test_l2 = test(model, test_loader, device, lossFn=myloss)\n",
    "    print(f'epoch: {epoch}, train l2 = {train_l2}, test l2 = {test_l2}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pde18",
   "language": "python",
   "name": "pde18"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
