{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# ---------------------------------------------------------\n",
    "# 1D MLP block used in AM-FNO\n",
    "# ---------------------------------------------------------\n",
    "class MLP1d(nn.Module):\n",
    "    def __init__(self, in_channels, out_channels, mid_channels, dropout: float = 0.0):\n",
    "        \"\"\"\n",
    "        Simple 1D MLP implemented as 1x1 convolutions:\n",
    "        input:  (B, C_in, N)\n",
    "        output: (B, C_out, N)\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.linear1 = nn.Conv1d(in_channels, mid_channels, kernel_size=1)\n",
    "        self.linear2 = nn.Conv1d(mid_channels, out_channels, kernel_size=1)\n",
    "        self.act = nn.GELU()\n",
    "        self.dropout = dropout\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.linear1(x)\n",
    "        x = self.act(x)\n",
    "        if self.dropout > 0:\n",
    "            x = F.dropout(x, p=self.dropout, training=self.training)\n",
    "        x = self.linear2(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "# ---------------------------------------------------------\n",
    "# Amortized spectral convolution (1D, MLP-based)\n",
    "# ---------------------------------------------------------\n",
    "class SpectralConv1dMLP(nn.Module):\n",
    "    \"\"\"\n",
    "    Amortized Fourier layer in 1D.\n",
    "    - Learns a complex-valued kernel in Fourier space from a low-dimensional\n",
    "      frequency embedding Tx via small MLPs (mlpxr, mlpxi).\n",
    "    - input:  (B, C_in, N)\n",
    "    - output: (B, C_out, N)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, n_basis: int, dropout: float = 0.0):\n",
    "        super().__init__()\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.n_basis = n_basis\n",
    "\n",
    "        # MLPs that map frequency features Tx -> real/imag parts of the kernel\n",
    "        # Tx will have shape (1, n_basis, N_fft_modes)\n",
    "        self.mlpxr = MLP1d(n_basis, in_channels * out_channels, 2 * n_basis, dropout=dropout)\n",
    "        self.mlpxi = MLP1d(n_basis, in_channels * out_channels, 2 * n_basis, dropout=dropout)\n",
    "\n",
    "    @staticmethod\n",
    "    def compl_mul1d(input_ft, weights):\n",
    "        \"\"\"\n",
    "        input_ft: (B, C_in, K)\n",
    "        weights:  (C_in, C_out, K)\n",
    "        returns:  (B, C_out, K)\n",
    "        \"\"\"\n",
    "        return torch.einsum(\"bix, iox -> box\", input_ft, weights)\n",
    "\n",
    "    def _build_Tx(self, N: int, device, dtype):\n",
    "        \"\"\"\n",
    "        Build the Chebyshev-like frequency embedding Tx used by mlpxr/mlpxi.\n",
    "        N is the *spatial* length; the rfft has K = N//2 + 1 modes.\n",
    "        Tx shape: (1, n_basis, K)\n",
    "        \"\"\"\n",
    "        K = N // 2 + 1\n",
    "        # frequencies in [0, 0.5], shape (1, K)\n",
    "        gridx = torch.fft.rfftfreq(N, d=1.0, device=device).unsqueeze(0).to(dtype)\n",
    "        # Chebyshev grades\n",
    "        grade = torch.arange(1, self.n_basis + 1, device=device, dtype=dtype).view(self.n_basis, 1)\n",
    "        # Tx_{k,j} = cos(k * arccos(ω_j))\n",
    "        Tx = torch.cos(grade @ torch.acos(gridx))  # (n_basis, K)\n",
    "        Tx = Tx.unsqueeze(0)                       # (1, n_basis, K)\n",
    "        return Tx\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        x: (B, C_in, N)\n",
    "        \"\"\"\n",
    "        B, C_in, N = x.shape\n",
    "        device, dtype = x.device, x.dtype\n",
    "\n",
    "        # Build frequency embedding Tx\n",
    "        Tx = self._build_Tx(N, device, dtype)      # (1, n_basis, K)\n",
    "        Tx = Tx.to(device=device, dtype=dtype)\n",
    "\n",
    "        # Generate real/imag parts of kernel via MLPs\n",
    "        xr = self.mlpxr(Tx)                        # (1, C_in*C_out, K)\n",
    "        xi = self.mlpxi(Tx)                        # (1, C_in*C_out, K)\n",
    "\n",
    "        K = N // 2 + 1\n",
    "        xr = xr.view(self.in_channels, self.out_channels, K)\n",
    "        xi = xi.view(self.in_channels, self.out_channels, K)\n",
    "        kernel = xr + 1j * xi                      # (C_in, C_out, K)\n",
    "\n",
    "        # FFT, multiply, inverse FFT\n",
    "        x_ft = torch.fft.rfft(x)                   # (B, C_in, K)\n",
    "        out_ft = self.compl_mul1d(x_ft, kernel)    # (B, C_out, K)\n",
    "        x_out = torch.fft.irfft(out_ft, n=N)       # (B, C_out, N)\n",
    "        return x_out\n",
    "\n",
    "\n",
    "# ---------------------------------------------------------\n",
    "# 1D AM-FNO (AMFNO1d) model\n",
    "# ---------------------------------------------------------\n",
    "class AMFNO1d(nn.Module):\n",
    "    \"\"\"\n",
    "    1D AM-FNO for your G_1, G_2, G_3 benchmarks.\n",
    "\n",
    "    Expected I/O:\n",
    "      - input:  (B, 1, N)   (your f samples)\n",
    "      - output: (B, 1, N)   (prediction of Gf)\n",
    "\n",
    "    Architecture:\n",
    "      - lift:  Linear([u(x), x] -> width)\n",
    "      - 4 amortized spectral conv blocks + local MLPs with residual\n",
    "      - final MLP to 1 channel\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        width: int,\n",
    "        n_basis: int = 16,\n",
    "        padding: int = 0,\n",
    "        in_channels: int = 1,\n",
    "        out_channels: int = 1,\n",
    "        mlp_dropout: float = 0.0,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.width = width\n",
    "        self.padding = padding\n",
    "        self.n_basis = n_basis\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "\n",
    "        # Lift input: concatenate u(x) and coordinate x, then project to width\n",
    "        # input_dim = in_channels + 1 (for the coordinate x)\n",
    "        self.p = nn.Linear(in_channels + 1, width)\n",
    "\n",
    "        # 4 amortized spectral conv layers\n",
    "        self.conv0 = SpectralConv1dMLP(width, width, n_basis, dropout=mlp_dropout)\n",
    "        self.conv1 = SpectralConv1dMLP(width, width, n_basis, dropout=mlp_dropout)\n",
    "        self.conv2 = SpectralConv1dMLP(width, width, n_basis, dropout=mlp_dropout)\n",
    "        self.conv3 = SpectralConv1dMLP(width, width, n_basis, dropout=mlp_dropout)\n",
    "\n",
    "        # 4 local channel-wise MLPs\n",
    "        self.mlp0 = MLP1d(width, width, 4 * width, dropout=mlp_dropout)\n",
    "        self.mlp1 = MLP1d(width, width, 4 * width, dropout=mlp_dropout)\n",
    "        self.mlp2 = MLP1d(width, width, 4 * width, dropout=mlp_dropout)\n",
    "        self.mlp3 = MLP1d(width, width, 4 * width, dropout=mlp_dropout)\n",
    "\n",
    "        # Final projection back to out_channels=1\n",
    "        self.q = MLP1d(width, out_channels, 4 * width, dropout=mlp_dropout)\n",
    "\n",
    "    def get_grid(self, shape, device):\n",
    "        \"\"\"\n",
    "        Generate 1D coordinate grid x in [-π, π) per batch.\n",
    "        shape: input shape (B, C, N)\n",
    "        returns: (B, N, 1)\n",
    "        \"\"\"\n",
    "        batch_size, _, N = shape\n",
    "        x = torch.linspace(-torch.pi, torch.pi, N + 1, device=device)[:-1]  # periodic grid\n",
    "        x = x.view(1, N, 1).repeat(batch_size, 1, 1)\n",
    "        return x\n",
    "\n",
    "    def forward(self, u):\n",
    "        \"\"\"\n",
    "        u: (B, 1, N)\n",
    "        returns: (B, 1, N)\n",
    "        \"\"\"\n",
    "        B, C, N = u.shape\n",
    "        device = u.device\n",
    "\n",
    "        # Build spatial grid and form input [u(x), x]\n",
    "        grid = self.get_grid(u.shape, device)           # (B, N, 1)\n",
    "        x = u.permute(0, 2, 1)                          # (B, N, 1)\n",
    "        x = torch.cat((x, grid), dim=-1)                # (B, N, in_channels+1)\n",
    "\n",
    "        # Lift to width channels\n",
    "        x = self.p(x)                                   # (B, N, width)\n",
    "        x = x.permute(0, 2, 1).contiguous()             # (B, width, N)\n",
    "\n",
    "        # Optional padding in spatial dimension (usually padding=0 for your setup)\n",
    "        if self.padding > 0:\n",
    "            x = F.pad(x, (0, self.padding))             # pad last dimension\n",
    "\n",
    "        # Block 0\n",
    "        x1 = self.conv0(x)\n",
    "        x1 = self.mlp0(x1)\n",
    "        x = x + x1\n",
    "        x = F.gelu(x)\n",
    "\n",
    "        # Block 1\n",
    "        x1 = self.conv1(x)\n",
    "        x1 = self.mlp1(x1)\n",
    "        x = x + x1\n",
    "        x = F.gelu(x)\n",
    "\n",
    "        # Block 2\n",
    "        x1 = self.conv2(x)\n",
    "        x1 = self.mlp2(x1)\n",
    "        x = x + x1\n",
    "        x = F.gelu(x)\n",
    "\n",
    "        # Block 3\n",
    "        x1 = self.conv3(x)\n",
    "        x1 = self.mlp3(x1)\n",
    "        x = x + x1\n",
    "        # last block often left without activation; you can add F.gelu(x) if desired\n",
    "\n",
    "        # Remove padding if any\n",
    "        if self.padding > 0:\n",
    "            x = x[..., :-self.padding]\n",
    "\n",
    "        # Final projection to 1 channel\n",
    "        x = self.q(x)                                   # (B, out_channels, N) with out_channels=1\n",
    "\n",
    "        # Return in (B, 1, N) format\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "\n",
    "\n",
    "\n",
    "def relative_l2_loss(g_pred, Gf):\n",
    "    \"\"\"\n",
    "    Computes relative L2 error: norm(g_pred - Gf) / (norm(Gf) + 1e-8)\n",
    "    \"\"\"\n",
    "    eps = 1e-8\n",
    "    return torch.norm(g_pred - Gf) / (torch.norm(Gf) + eps)\n",
    "\n",
    "def train_operator(\n",
    "    dataset_raw, \n",
    "    model, \n",
    "    num_epochs,\n",
    "    lr,\n",
    "    device,\n",
    "    batch_size=16,\n",
    "    test_dataset_raw=None,  # (NEW) for test data\n",
    "    reduce_on='test',       # which metric to monitor: 'train' or 'test'\n",
    "    factor=0.5,             # factor to reduce LR on plateau\n",
    "    patience=5,             # epochs of no improvement\n",
    "    min_lr=1e-6,            # minimal LR\n",
    "    cooldown=0\n",
    "):\n",
    "    \"\"\"\n",
    "    Train the FNO model with:\n",
    "      1) A 'ReduceLROnPlateau' scheduler that halves LR on plateau.\n",
    "      2) An optional test dataset to compute and print test rel L2 each epoch.\n",
    "      3) Plot both train and test curves at the end.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset_raw : list of dict\n",
    "        Train samples in FNO style: 'f' (1,n), 'Gf' (1,n), 'x' (n,), 'func_type'.\n",
    "    model : torch.nn.Module\n",
    "        The FNO model to train.\n",
    "    num_epochs : int\n",
    "        Number of epochs to train.\n",
    "    lr : float\n",
    "        Initial learning rate for the optimizer.\n",
    "    device : torch.device\n",
    "        CPU/CUDA device.\n",
    "    batch_size : int\n",
    "        Batch size.\n",
    "    test_dataset_raw : list of dict, optional\n",
    "        If provided, used to compute test loss each epoch.\n",
    "    reduce_on : str\n",
    "        Either 'test' or 'train' - which metric to monitor for LR schedule.\n",
    "    factor : float\n",
    "        Multiplicative factor by which LR is reduced on plateau.\n",
    "    patience : int\n",
    "        Number of epochs of no improvement before reducing LR.\n",
    "    min_lr : float\n",
    "        Lower bound on LR.\n",
    "    cooldown : int\n",
    "        Number of epochs to wait after LR is reduced before next reduce.\n",
    "    \"\"\"\n",
    "\n",
    "    model.to(device)\n",
    "\n",
    "    # 1) Prepare training dataset\n",
    "    #    We'll unify to a single resolution from first sample\n",
    "    target_n = dataset_raw[0][\"f\"].shape[-1]\n",
    "    dataset_fixed = [s for s in dataset_raw if s[\"f\"].shape[-1] == target_n]\n",
    "    if len(dataset_fixed) == 0:\n",
    "        raise ValueError(\"No samples with matching resolution found in training.\")\n",
    "\n",
    "    train_inputs = torch.stack([s[\"f\"] for s in dataset_fixed], dim=0)   # (N, 1, n)\n",
    "    train_targets = torch.stack([s[\"Gf\"] for s in dataset_fixed], dim=0) # (N, 1, n)\n",
    "    train_dataset = TensorDataset(train_inputs, train_targets)\n",
    "    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "    # 2) Prepare test dataset (if any)\n",
    "    test_loader = None\n",
    "    if test_dataset_raw is not None:\n",
    "        # unify resolution as well\n",
    "        test_fixed = [s for s in test_dataset_raw if s[\"f\"].shape[-1] == target_n]\n",
    "        if len(test_fixed) == 0:\n",
    "            print(\"Warning: no test samples with matching resolution found. Test dataset ignored.\")\n",
    "            test_dataset_raw = None\n",
    "        else:\n",
    "            test_inputs = torch.stack([s[\"f\"] for s in test_fixed], dim=0)\n",
    "            test_targets = torch.stack([s[\"Gf\"] for s in test_fixed], dim=0)\n",
    "            test_dataset = TensorDataset(test_inputs, test_targets)\n",
    "            test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "    # 3) Define optimizer + LR scheduler\n",
    "    optimizer = optim.Adam(model.parameters(), lr=lr)\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
    "        optimizer,\n",
    "        mode='min',     # we want to minimize loss\n",
    "        factor=factor,  \n",
    "        patience=patience,\n",
    "        threshold=1e-4,\n",
    "        cooldown=cooldown,\n",
    "        min_lr=min_lr,\n",
    "        verbose=True\n",
    "    )\n",
    "\n",
    "    train_losses = []\n",
    "    test_losses = []\n",
    "    start_time = time.time()\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(num_epochs):\n",
    "        epoch_start = time.time()\n",
    "\n",
    "        # (a) train loop\n",
    "        running_loss = 0.0\n",
    "        for batch_f, batch_Gf in train_loader:\n",
    "            batch_f = batch_f.to(device)\n",
    "            batch_Gf = batch_Gf.to(device)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            pred_Gf = model(batch_f)  # shape (batch, 1, n)\n",
    "            loss = relative_l2_loss(pred_Gf, batch_Gf)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            running_loss += loss.item() * batch_f.size(0)\n",
    "\n",
    "        epoch_train_loss = running_loss / len(train_dataset)\n",
    "        train_losses.append(epoch_train_loss)\n",
    "\n",
    "        # (b) test loop (if provided)\n",
    "        if test_loader is not None:\n",
    "            model.eval()\n",
    "            test_running_loss = 0.0\n",
    "            with torch.no_grad():\n",
    "                for batch_f, batch_Gf in test_loader:\n",
    "                    batch_f = batch_f.to(device)\n",
    "                    batch_Gf = batch_Gf.to(device)\n",
    "                    pred_Gf = model(batch_f)\n",
    "                    test_loss = relative_l2_loss(pred_Gf, batch_Gf)\n",
    "                    test_running_loss += test_loss.item() * batch_f.size(0)\n",
    "            epoch_test_loss = test_running_loss / len(test_loader.dataset)\n",
    "            test_losses.append(epoch_test_loss)\n",
    "            model.train()\n",
    "        else:\n",
    "            epoch_test_loss = None\n",
    "\n",
    "        # (c) update LR via scheduler\n",
    "        # decide which metric to monitor: train or test\n",
    "        if reduce_on == 'test' and epoch_test_loss is not None:\n",
    "            scheduler.step(epoch_test_loss)\n",
    "        else:\n",
    "            # fallback: train\n",
    "            scheduler.step(epoch_train_loss)\n",
    "\n",
    "        # (d) print progress\n",
    "        epoch_time = time.time() - epoch_start\n",
    "        current_lr = optimizer.param_groups[0]['lr']\n",
    "        if epoch_test_loss is not None:\n",
    "            print(f\"Epoch [{epoch+1}/{num_epochs}] | \"\n",
    "                  f\"Train Loss={epoch_train_loss:.9f} | Test Loss={epoch_test_loss:.9f} | \"\n",
    "                  f\"Time={epoch_time:.2f}s | LR={current_lr:.2e}\")\n",
    "        else:\n",
    "            print(f\"Epoch [{epoch+1}/{num_epochs}] | \"\n",
    "                  f\"Train Loss={epoch_train_loss:.9f} | Time={epoch_time:.2f}s | \"\n",
    "                  f\"LR={current_lr:.2e}\")\n",
    "\n",
    "    total_time = time.time() - start_time\n",
    "    print(f\"\\nTotal training time: {total_time:.2f} seconds\")\n",
    "\n",
    "    # (e) plot train & test\n",
    "    plt.figure()\n",
    "    plt.semilogy(train_losses, label=\"Train Rel L2\")\n",
    "    if test_loader is not None:\n",
    "        plt.semilogy(test_losses, label=\"Test Rel L2\")\n",
    "    plt.xlabel(\"Epoch\")\n",
    "    plt.ylabel(\"Relative L2 Error\")\n",
    "    plt.legend()\n",
    "    plt.grid()\n",
    "    plt.show()\n",
    "\n",
    "    return model, train_losses, test_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import defaultdict\n",
    "\n",
    "# ------------------------------------------------------------------\n",
    "# 1. Dataset conversion (same as before)\n",
    "# ------------------------------------------------------------------\n",
    "def convert_dataset_for_fno(dataset):\n",
    "    \"\"\"\n",
    "    Convert each sample in the dataset to the evaluation format:\n",
    "      - 'x': (n,)\n",
    "      - 'f': (1, n)\n",
    "      - 'Gf': (1, n)\n",
    "      - 'func_type': str\n",
    "    \"\"\"\n",
    "    new_dataset = []\n",
    "    for sample in dataset:\n",
    "        if len(sample) < 4:\n",
    "            raise ValueError(\"Each sample must have (x, f, Gf, func_type).\")\n",
    "        x, f, Gf, func_type = sample[:4]\n",
    "\n",
    "        # Force (1, n)\n",
    "        if f.ndim == 1:\n",
    "            f = f.unsqueeze(0)\n",
    "        if Gf.ndim == 1:\n",
    "            Gf = Gf.unsqueeze(0)\n",
    "\n",
    "        new_dataset.append({\n",
    "            'x': x,\n",
    "            'f': f,\n",
    "            'Gf': Gf,\n",
    "            'func_type': func_type\n",
    "        })\n",
    "    return new_dataset\n",
    "\n",
    "\n",
    "# ------------------------------------------------------------------\n",
    "# 2. Same relative L2 definition as training\n",
    "# ------------------------------------------------------------------\n",
    "def relative_l2_loss(g_pred, Gf, eps: float = 1e-8):\n",
    "    \"\"\"\n",
    "    Same as in the training loop:\n",
    "        ||g_pred - Gf|| / ( ||Gf|| + eps )\n",
    "    Works for any matching shapes.\n",
    "    \"\"\"\n",
    "    return torch.norm(g_pred - Gf) / (torch.norm(Gf) + eps)\n",
    "\n",
    "\n",
    "# ------------------------------------------------------------------\n",
    "# 3. Evaluation with *global* relative L2 per type (matches training)\n",
    "# ------------------------------------------------------------------\n",
    "def evaluate_operator(dataset, model, device=torch.device(\"cpu\")):\n",
    "    \"\"\"\n",
    "    Evaluate a 1D operator model (FNO, U-FNO, etc.) on a dataset.\n",
    "\n",
    "    - Feeds inputs with shape (batch=1, channels=1, n), same as training.\n",
    "    - Aggregates numerator and denominator over all samples of a given type:\n",
    "          E_type = sqrt( sum ||err_i||^2 ) / sqrt( sum ||Gf_i||^2 )\n",
    "      This is the dataset-level relative L2, consistent with your training loss\n",
    "      (just with batch_size = 1 instead of 16).\n",
    "\n",
    "    Returns:\n",
    "        type_errors: dict func_type -> relative L2 (float)\n",
    "        overall_error: single relative L2 over the whole dataset\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    eps = 1e-8\n",
    "\n",
    "    # Accumulate squared errors and squared norms\n",
    "    num_total = 0.0\n",
    "    den_total = 0.0\n",
    "    num_by_type = defaultdict(float)\n",
    "    den_by_type = defaultdict(float)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for sample in dataset:\n",
    "            func_type = sample[\"func_type\"]\n",
    "            f = sample[\"f\"]    # (1, n)\n",
    "            Gf = sample[\"Gf\"]  # (1, n)\n",
    "\n",
    "            # Make a batch dimension: (1, 1, n)\n",
    "            if f.ndim == 2:\n",
    "                f_in = f.unsqueeze(0)\n",
    "            elif f.ndim == 3:\n",
    "                f_in = f\n",
    "            else:\n",
    "                raise ValueError(f\"Unexpected f.ndim={f.ndim}, expected 2 or 3.\")\n",
    "\n",
    "            if Gf.ndim == 2:\n",
    "                Gf_in = Gf.unsqueeze(0)\n",
    "            elif Gf.ndim == 3:\n",
    "                Gf_in = Gf\n",
    "            else:\n",
    "                raise ValueError(f\"Unexpected Gf.ndim={Gf.ndim}, expected 2 or 3.\")\n",
    "\n",
    "            f_in = f_in.to(device)   # (1, 1, n)\n",
    "            Gf_in = Gf_in.to(device) # (1, 1, n)\n",
    "\n",
    "            # Forward pass\n",
    "            g_pred = model(f_in)\n",
    "\n",
    "            # Try to coerce output into same shape as Gf_in\n",
    "            if g_pred.shape != Gf_in.shape:\n",
    "                # Common 1D cases: (1, n) or (1, n, 1)\n",
    "                if g_pred.ndim == 2 and g_pred.shape[0] == Gf_in.shape[0]:\n",
    "                    # (1, n) -> (1, 1, n)\n",
    "                    g_pred = g_pred.unsqueeze(1)\n",
    "                elif g_pred.ndim == 3 and g_pred.shape[1] == 1 \\\n",
    "                     and g_pred.shape[-1] == Gf_in.shape[-1]:\n",
    "                    # already (1, 1, n) – fine\n",
    "                    pass\n",
    "                else:\n",
    "                    raise RuntimeError(\n",
    "                        f\"Shape mismatch: g_pred {g_pred.shape}, Gf_in {Gf_in.shape}\"\n",
    "                    )\n",
    "\n",
    "            # Accumulate squared errors and norms\n",
    "            diff = g_pred - Gf_in\n",
    "            num = torch.sum(diff**2).item()\n",
    "            den = torch.sum(Gf_in**2).item()\n",
    "\n",
    "            num_total += num\n",
    "            den_total += den\n",
    "            num_by_type[func_type] += num\n",
    "            den_by_type[func_type] += den\n",
    "\n",
    "    # Compute dataset-level relative L2 per type\n",
    "    type_errors = {}\n",
    "    for ttype in sorted(num_by_type.keys()):\n",
    "        n = num_by_type[ttype]\n",
    "        d = den_by_type[ttype]\n",
    "        err = (n**0.5) / (d**0.5 + eps) if d > 0 else 0.0\n",
    "        type_errors[ttype] = err\n",
    "        print(f\"Function type: {ttype:20s} | Relative L2: {err:.8f}\")\n",
    "\n",
    "    # Overall dataset-level relative L2\n",
    "    overall_error = (num_total**0.5) / (den_total**0.5 + eps) if den_total > 0 else 0.0\n",
    "    print(f\"\\nOverall relative L2 error: {overall_error:.8f}\")\n",
    "\n",
    "    return type_errors, overall_error\n",
    "\n",
    "\n",
    "# ------------------------------------------------------------------\n",
    "# 4. Plotting helper (uses same shapes & loss)\n",
    "# ------------------------------------------------------------------\n",
    "def plot_functions(dataset, model, device=torch.device(\"cpu\"), max_plots=8):\n",
    "    \"\"\"\n",
    "    Plot input f(x), target Gf(x), and model prediction for up to max_plots samples.\n",
    "    Uses the same (1,1,n) input convention and the same relative L2 definition.\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    num_samples = min(len(dataset), max_plots)\n",
    "    num_cols = 4\n",
    "    num_rows = (num_samples + num_cols - 1) // num_cols\n",
    "\n",
    "    plt.figure(figsize=(20, 5 * num_rows))\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for idx in range(num_samples):\n",
    "            sample = dataset[idx]\n",
    "            x = sample[\"x\"]    # (n,)\n",
    "            f = sample[\"f\"]    # (1, n)\n",
    "            Gf = sample[\"Gf\"]  # (1, n)\n",
    "            func_type = sample[\"func_type\"]\n",
    "\n",
    "            # (1,1,n)\n",
    "            if f.ndim == 2:\n",
    "                f_in = f.unsqueeze(0)\n",
    "            else:\n",
    "                f_in = f\n",
    "            if Gf.ndim == 2:\n",
    "                Gf_in = Gf.unsqueeze(0)\n",
    "            else:\n",
    "                Gf_in = Gf\n",
    "\n",
    "            f_in = f_in.to(device)\n",
    "            Gf_in = Gf_in.to(device)\n",
    "\n",
    "            g_pred = model(f_in)\n",
    "            if g_pred.ndim == 2:\n",
    "                g_pred = g_pred.unsqueeze(1)  # (1,1,n)\n",
    "\n",
    "            rel_err = relative_l2_loss(g_pred, Gf_in).item()\n",
    "\n",
    "            # Move to CPU and flatten to (n,)\n",
    "            x_np = x.cpu().numpy()\n",
    "            f_np = f.view(-1).cpu().numpy()\n",
    "            Gf_np = Gf.view(-1).cpu().numpy()\n",
    "            g_pred_np = g_pred.view(-1).cpu().numpy()\n",
    "\n",
    "            plt.subplot(num_rows, num_cols, idx + 1)\n",
    "            plt.plot(x_np, f_np, label=\"Input $f(x)$\", linewidth=2)\n",
    "            plt.plot(x_np, Gf_np, label=\"Target $Gf(x)$\", linestyle=\"--\", linewidth=2)\n",
    "            plt.plot(x_np, g_pred_np, label=\"Pred $g_{pred}(x)$\",\n",
    "                     linestyle=\"-.\", linewidth=2)\n",
    "            plt.xlabel(\"$x$\")\n",
    "            plt.ylabel(\"Value\")\n",
    "            plt.title(f\"{func_type}\\nSample {idx+1}, Rel-L2: {rel_err:.4f}\")\n",
    "            plt.legend()\n",
    "            plt.grid(True)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = convert_dataset_for_fno(torch.load('train_3.pt', weights_only=False))\n",
    "test_data = convert_dataset_for_fno(torch.load('test_3.pt', weights_only=False))\n",
    "sample_data = convert_dataset_for_fno(torch.load('sample_3.pt', weights_only=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AMFNO1d params: 547,729\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from neuralop.models import FNO1d\n",
    "from neuralop.layers.embeddings import GridEmbeddingND\n",
    "\n",
    "\n",
    "# device\n",
    "device = 'cuda:1'\n",
    "# AMFNO model: choose width & n_basis comparable to your 2-layer FNO (64 hidden, 128 modes)\n",
    "amfno = AMFNO1d(\n",
    "    width=48,\n",
    "    n_basis=12,         # #Chebyshev-like basis in frequency embedding\n",
    "    padding=0,\n",
    "    in_channels=1,\n",
    "    out_channels=1,\n",
    "    mlp_dropout=0.0,\n",
    ").to(device)\n",
    "\n",
    "# Optional: parameter count\n",
    "def count_params(m):\n",
    "    return sum(p.numel() for p in m.parameters() if p.requires_grad)\n",
    "print(f\"AMFNO1d params: {count_params(amfno):,}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jinlee/kn_mlp/kano/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:62: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/1000] | Train Loss=0.286321641 | Test Loss=0.044191527 | Time=1.98s | LR=1.00e-03\n",
      "Epoch [2/1000] | Train Loss=0.049260058 | Test Loss=0.083464422 | Time=1.69s | LR=1.00e-03\n",
      "Epoch [3/1000] | Train Loss=0.050010887 | Test Loss=0.038502580 | Time=1.69s | LR=1.00e-03\n",
      "Epoch [4/1000] | Train Loss=0.043989524 | Test Loss=0.104565120 | Time=1.69s | LR=1.00e-03\n",
      "Epoch [5/1000] | Train Loss=0.045799947 | Test Loss=0.030977814 | Time=1.71s | LR=1.00e-03\n",
      "Epoch [6/1000] | Train Loss=0.030598759 | Test Loss=0.020386922 | Time=1.79s | LR=1.00e-03\n",
      "Epoch [7/1000] | Train Loss=0.028933001 | Test Loss=0.024046360 | Time=1.36s | LR=1.00e-03\n",
      "Epoch [8/1000] | Train Loss=0.028498358 | Test Loss=0.036491234 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [9/1000] | Train Loss=0.035136399 | Test Loss=0.059516872 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [10/1000] | Train Loss=0.034757489 | Test Loss=0.025690305 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [11/1000] | Train Loss=0.024921818 | Test Loss=0.035553168 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [12/1000] | Train Loss=0.029438545 | Test Loss=0.026953488 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [13/1000] | Train Loss=0.028154641 | Test Loss=0.022959080 | Time=1.75s | LR=1.00e-03\n",
      "Epoch [14/1000] | Train Loss=0.026662254 | Test Loss=0.023606858 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [15/1000] | Train Loss=0.023444871 | Test Loss=0.028801391 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [16/1000] | Train Loss=0.026779027 | Test Loss=0.019425216 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [17/1000] | Train Loss=0.032506680 | Test Loss=0.027078299 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [18/1000] | Train Loss=0.029249592 | Test Loss=0.024656185 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [19/1000] | Train Loss=0.024159416 | Test Loss=0.021239211 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [20/1000] | Train Loss=0.023676153 | Test Loss=0.051329618 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [21/1000] | Train Loss=0.028895415 | Test Loss=0.036913850 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [22/1000] | Train Loss=0.022005572 | Test Loss=0.023381725 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [23/1000] | Train Loss=0.022534386 | Test Loss=0.017666276 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [24/1000] | Train Loss=0.026411426 | Test Loss=0.027271644 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [25/1000] | Train Loss=0.028199531 | Test Loss=0.036552098 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [26/1000] | Train Loss=0.040598384 | Test Loss=0.018888872 | Time=1.63s | LR=1.00e-03\n",
      "Epoch [27/1000] | Train Loss=0.021306022 | Test Loss=0.024635903 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [28/1000] | Train Loss=0.032087805 | Test Loss=0.031222798 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [29/1000] | Train Loss=0.029980783 | Test Loss=0.063451846 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [30/1000] | Train Loss=0.027421822 | Test Loss=0.017570290 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [31/1000] | Train Loss=0.018634630 | Test Loss=0.020952654 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [32/1000] | Train Loss=0.035106904 | Test Loss=0.052340715 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [33/1000] | Train Loss=0.026900571 | Test Loss=0.046543280 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [34/1000] | Train Loss=0.023271625 | Test Loss=0.026472964 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [35/1000] | Train Loss=0.027527432 | Test Loss=0.029012890 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [36/1000] | Train Loss=0.025469348 | Test Loss=0.033017781 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [37/1000] | Train Loss=0.020400035 | Test Loss=0.022901535 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [38/1000] | Train Loss=0.022423178 | Test Loss=0.023138505 | Time=1.71s | LR=1.00e-03\n",
      "Epoch [39/1000] | Train Loss=0.020690542 | Test Loss=0.027729676 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [40/1000] | Train Loss=0.024976188 | Test Loss=0.025806608 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [41/1000] | Train Loss=0.019844205 | Test Loss=0.065669642 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [42/1000] | Train Loss=0.028660533 | Test Loss=0.031062770 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [43/1000] | Train Loss=0.020040778 | Test Loss=0.033105531 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [44/1000] | Train Loss=0.032296024 | Test Loss=0.060397533 | Time=1.58s | LR=1.00e-03\n",
      "Epoch [45/1000] | Train Loss=0.027476157 | Test Loss=0.014827017 | Time=1.67s | LR=1.00e-03\n",
      "Epoch [46/1000] | Train Loss=0.033682367 | Test Loss=0.027219445 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [47/1000] | Train Loss=0.023547051 | Test Loss=0.037024910 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [48/1000] | Train Loss=0.022472488 | Test Loss=0.028318854 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [49/1000] | Train Loss=0.019937076 | Test Loss=0.026808482 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [50/1000] | Train Loss=0.019466308 | Test Loss=0.051708649 | Time=1.75s | LR=1.00e-03\n",
      "Epoch [51/1000] | Train Loss=0.038193156 | Test Loss=0.033200696 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [52/1000] | Train Loss=0.022068934 | Test Loss=0.025051217 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [53/1000] | Train Loss=0.021515422 | Test Loss=0.016924898 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [54/1000] | Train Loss=0.020327656 | Test Loss=0.017215120 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [55/1000] | Train Loss=0.031737543 | Test Loss=0.025852801 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [56/1000] | Train Loss=0.025415994 | Test Loss=0.025347431 | Time=1.64s | LR=1.00e-03\n",
      "Epoch [57/1000] | Train Loss=0.018390979 | Test Loss=0.019829101 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [58/1000] | Train Loss=0.031776875 | Test Loss=0.044618657 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [59/1000] | Train Loss=0.039391323 | Test Loss=0.014310582 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [60/1000] | Train Loss=0.024021650 | Test Loss=0.016034992 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [61/1000] | Train Loss=0.018613383 | Test Loss=0.020869233 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [62/1000] | Train Loss=0.020951319 | Test Loss=0.020688273 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [63/1000] | Train Loss=0.017870479 | Test Loss=0.023805745 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [64/1000] | Train Loss=0.016894424 | Test Loss=0.023000332 | Time=1.72s | LR=1.00e-03\n",
      "Epoch [65/1000] | Train Loss=0.022549824 | Test Loss=0.031749858 | Time=1.83s | LR=1.00e-03\n",
      "Epoch [66/1000] | Train Loss=0.021583718 | Test Loss=0.012195684 | Time=1.82s | LR=1.00e-03\n",
      "Epoch [67/1000] | Train Loss=0.014631060 | Test Loss=0.010667902 | Time=1.81s | LR=1.00e-03\n",
      "Epoch [68/1000] | Train Loss=0.023732395 | Test Loss=0.023913170 | Time=1.79s | LR=1.00e-03\n",
      "Epoch [69/1000] | Train Loss=0.023158810 | Test Loss=0.030860372 | Time=1.82s | LR=1.00e-03\n",
      "Epoch [70/1000] | Train Loss=0.031874003 | Test Loss=0.029922354 | Time=1.78s | LR=1.00e-03\n",
      "Epoch [71/1000] | Train Loss=0.022245338 | Test Loss=0.024350789 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [72/1000] | Train Loss=0.021229764 | Test Loss=0.018462920 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [73/1000] | Train Loss=0.015135599 | Test Loss=0.012129271 | Time=1.85s | LR=1.00e-03\n",
      "Epoch [74/1000] | Train Loss=0.017593756 | Test Loss=0.020098493 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [75/1000] | Train Loss=0.025589645 | Test Loss=0.016876508 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [76/1000] | Train Loss=0.014716155 | Test Loss=0.010421320 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [77/1000] | Train Loss=0.023476948 | Test Loss=0.032705409 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [78/1000] | Train Loss=0.016443718 | Test Loss=0.028265370 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [79/1000] | Train Loss=0.021861652 | Test Loss=0.073799379 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [80/1000] | Train Loss=0.023831376 | Test Loss=0.015900472 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [81/1000] | Train Loss=0.023905101 | Test Loss=0.015139786 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [82/1000] | Train Loss=0.017025760 | Test Loss=0.012026656 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [83/1000] | Train Loss=0.015798437 | Test Loss=0.020750548 | Time=1.73s | LR=1.00e-03\n",
      "Epoch [84/1000] | Train Loss=0.016950501 | Test Loss=0.028231430 | Time=1.84s | LR=1.00e-03\n",
      "Epoch [85/1000] | Train Loss=0.024194096 | Test Loss=0.011018839 | Time=1.83s | LR=1.00e-03\n",
      "Epoch [86/1000] | Train Loss=0.018193213 | Test Loss=0.013732928 | Time=1.82s | LR=1.00e-03\n",
      "Epoch [87/1000] | Train Loss=0.016779652 | Test Loss=0.016689505 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [88/1000] | Train Loss=0.018723049 | Test Loss=0.013779929 | Time=1.80s | LR=1.00e-03\n",
      "Epoch [89/1000] | Train Loss=0.017200490 | Test Loss=0.015772834 | Time=1.78s | LR=1.00e-03\n",
      "Epoch [90/1000] | Train Loss=0.023845984 | Test Loss=0.017576328 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [91/1000] | Train Loss=0.025053129 | Test Loss=0.016535178 | Time=1.72s | LR=1.00e-03\n",
      "Epoch [92/1000] | Train Loss=0.017470750 | Test Loss=0.015328557 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [93/1000] | Train Loss=0.015795735 | Test Loss=0.027493582 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [94/1000] | Train Loss=0.020162104 | Test Loss=0.050045011 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [95/1000] | Train Loss=0.027104003 | Test Loss=0.022771450 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [96/1000] | Train Loss=0.015964200 | Test Loss=0.023204472 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [97/1000] | Train Loss=0.022243856 | Test Loss=0.021614029 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [98/1000] | Train Loss=0.016859436 | Test Loss=0.014490484 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [99/1000] | Train Loss=0.021974719 | Test Loss=0.051521123 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [100/1000] | Train Loss=0.014846944 | Test Loss=0.009875755 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [101/1000] | Train Loss=0.016064169 | Test Loss=0.014650204 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [102/1000] | Train Loss=0.023585076 | Test Loss=0.037385628 | Time=1.74s | LR=1.00e-03\n",
      "Epoch [103/1000] | Train Loss=0.014721341 | Test Loss=0.021590664 | Time=1.82s | LR=1.00e-03\n",
      "Epoch [104/1000] | Train Loss=0.015664910 | Test Loss=0.035309284 | Time=1.82s | LR=1.00e-03\n",
      "Epoch [105/1000] | Train Loss=0.017300628 | Test Loss=0.014192342 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [106/1000] | Train Loss=0.022454511 | Test Loss=0.017171880 | Time=1.80s | LR=1.00e-03\n",
      "Epoch [107/1000] | Train Loss=0.013530774 | Test Loss=0.025218476 | Time=1.79s | LR=1.00e-03\n",
      "Epoch [108/1000] | Train Loss=0.014868569 | Test Loss=0.010239879 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [109/1000] | Train Loss=0.020663920 | Test Loss=0.015976788 | Time=1.71s | LR=1.00e-03\n",
      "Epoch [110/1000] | Train Loss=0.023690615 | Test Loss=0.011051419 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [111/1000] | Train Loss=0.014599033 | Test Loss=0.012777318 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [112/1000] | Train Loss=0.018154190 | Test Loss=0.016501982 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [113/1000] | Train Loss=0.025828299 | Test Loss=0.020361761 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [114/1000] | Train Loss=0.025293433 | Test Loss=0.012202368 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [115/1000] | Train Loss=0.013610672 | Test Loss=0.017584160 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [116/1000] | Train Loss=0.014138807 | Test Loss=0.026146063 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [117/1000] | Train Loss=0.020057485 | Test Loss=0.035322176 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [118/1000] | Train Loss=0.016040401 | Test Loss=0.020322786 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [119/1000] | Train Loss=0.016842839 | Test Loss=0.020033114 | Time=1.70s | LR=1.00e-03\n",
      "Epoch [120/1000] | Train Loss=0.027132475 | Test Loss=0.036445652 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [121/1000] | Train Loss=0.016648687 | Test Loss=0.018802919 | Time=1.82s | LR=1.00e-03\n",
      "Epoch [122/1000] | Train Loss=0.022036496 | Test Loss=0.014410821 | Time=1.82s | LR=1.00e-03\n",
      "Epoch [123/1000] | Train Loss=0.024882249 | Test Loss=0.023883609 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [124/1000] | Train Loss=0.023296508 | Test Loss=0.017240746 | Time=1.81s | LR=1.00e-03\n",
      "Epoch [125/1000] | Train Loss=0.015397506 | Test Loss=0.016990841 | Time=1.80s | LR=1.00e-03\n",
      "Epoch [126/1000] | Train Loss=0.014380482 | Test Loss=0.011651464 | Time=1.79s | LR=1.00e-03\n",
      "Epoch [127/1000] | Train Loss=0.012116173 | Test Loss=0.013631743 | Time=1.68s | LR=1.00e-03\n",
      "Epoch [128/1000] | Train Loss=0.014415020 | Test Loss=0.031064336 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [129/1000] | Train Loss=0.023300636 | Test Loss=0.012197589 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [130/1000] | Train Loss=0.024623399 | Test Loss=0.010372832 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [131/1000] | Train Loss=0.014093699 | Test Loss=0.018989619 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [132/1000] | Train Loss=0.023875406 | Test Loss=0.013557786 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [133/1000] | Train Loss=0.019133084 | Test Loss=0.010642661 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [134/1000] | Train Loss=0.017177297 | Test Loss=0.034722236 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [135/1000] | Train Loss=0.017959512 | Test Loss=0.014263117 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [136/1000] | Train Loss=0.012855386 | Test Loss=0.014463738 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [137/1000] | Train Loss=0.016536507 | Test Loss=0.010780150 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [138/1000] | Train Loss=0.013978913 | Test Loss=0.025497833 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [139/1000] | Train Loss=0.014899983 | Test Loss=0.016620519 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [140/1000] | Train Loss=0.017229106 | Test Loss=0.009183191 | Time=1.65s | LR=1.00e-03\n",
      "Epoch [141/1000] | Train Loss=0.017775477 | Test Loss=0.012487877 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [142/1000] | Train Loss=0.015258681 | Test Loss=0.021552090 | Time=1.79s | LR=1.00e-03\n",
      "Epoch [143/1000] | Train Loss=0.013136732 | Test Loss=0.014182012 | Time=1.78s | LR=1.00e-03\n",
      "Epoch [144/1000] | Train Loss=0.014373159 | Test Loss=0.012827940 | Time=1.76s | LR=1.00e-03\n",
      "Epoch [145/1000] | Train Loss=0.012744490 | Test Loss=0.013052632 | Time=1.74s | LR=1.00e-03\n",
      "Epoch [146/1000] | Train Loss=0.013579930 | Test Loss=0.011859368 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [147/1000] | Train Loss=0.015032002 | Test Loss=0.009871467 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [148/1000] | Train Loss=0.014436917 | Test Loss=0.035125303 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [149/1000] | Train Loss=0.019574842 | Test Loss=0.014465917 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [150/1000] | Train Loss=0.016246093 | Test Loss=0.019358379 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [151/1000] | Train Loss=0.012797174 | Test Loss=0.011700120 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [152/1000] | Train Loss=0.012319692 | Test Loss=0.017025892 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [153/1000] | Train Loss=0.013105163 | Test Loss=0.037310640 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [154/1000] | Train Loss=0.015937931 | Test Loss=0.012521866 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [155/1000] | Train Loss=0.013679239 | Test Loss=0.015255251 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [156/1000] | Train Loss=0.020771852 | Test Loss=0.013660915 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [157/1000] | Train Loss=0.021249762 | Test Loss=0.019196767 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [158/1000] | Train Loss=0.015843426 | Test Loss=0.016719114 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [159/1000] | Train Loss=0.013006811 | Test Loss=0.007055372 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [160/1000] | Train Loss=0.015680274 | Test Loss=0.018998266 | Time=1.72s | LR=1.00e-03\n",
      "Epoch [161/1000] | Train Loss=0.012210768 | Test Loss=0.009702132 | Time=1.82s | LR=1.00e-03\n",
      "Epoch [162/1000] | Train Loss=0.010454313 | Test Loss=0.018517655 | Time=1.82s | LR=1.00e-03\n",
      "Epoch [163/1000] | Train Loss=0.011807802 | Test Loss=0.012092944 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [164/1000] | Train Loss=0.014235598 | Test Loss=0.011066155 | Time=1.62s | LR=1.00e-03\n",
      "Epoch [165/1000] | Train Loss=0.021740604 | Test Loss=0.024392795 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [166/1000] | Train Loss=0.013378515 | Test Loss=0.014516368 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [167/1000] | Train Loss=0.011665452 | Test Loss=0.017099380 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [168/1000] | Train Loss=0.012766481 | Test Loss=0.015301677 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [169/1000] | Train Loss=0.015107603 | Test Loss=0.009434779 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [170/1000] | Train Loss=0.009315505 | Test Loss=0.021999629 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [171/1000] | Train Loss=0.012702054 | Test Loss=0.007737099 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [172/1000] | Train Loss=0.013995299 | Test Loss=0.013206258 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [173/1000] | Train Loss=0.011916973 | Test Loss=0.017574128 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [174/1000] | Train Loss=0.015218516 | Test Loss=0.012057918 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [175/1000] | Train Loss=0.010550761 | Test Loss=0.020623040 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [176/1000] | Train Loss=0.020854627 | Test Loss=0.037892735 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [177/1000] | Train Loss=0.013688298 | Test Loss=0.017669543 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [178/1000] | Train Loss=0.014154352 | Test Loss=0.017234593 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [179/1000] | Train Loss=0.011533509 | Test Loss=0.017657225 | Time=1.60s | LR=1.00e-03\n",
      "Epoch [180/1000] | Train Loss=0.020050955 | Test Loss=0.023764076 | Time=1.79s | LR=1.00e-03\n",
      "Epoch [181/1000] | Train Loss=0.020884520 | Test Loss=0.008311835 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [182/1000] | Train Loss=0.017774845 | Test Loss=0.015125889 | Time=1.73s | LR=1.00e-03\n",
      "Epoch [183/1000] | Train Loss=0.010737523 | Test Loss=0.009104400 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [184/1000] | Train Loss=0.010575741 | Test Loss=0.013168866 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [185/1000] | Train Loss=0.016427081 | Test Loss=0.017364398 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [186/1000] | Train Loss=0.012394348 | Test Loss=0.023523842 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [187/1000] | Train Loss=0.014757471 | Test Loss=0.030786187 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [188/1000] | Train Loss=0.020584452 | Test Loss=0.025604134 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [189/1000] | Train Loss=0.017829664 | Test Loss=0.006780781 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [190/1000] | Train Loss=0.009295526 | Test Loss=0.019060050 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [191/1000] | Train Loss=0.015806454 | Test Loss=0.008798435 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [192/1000] | Train Loss=0.010121709 | Test Loss=0.009646257 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [193/1000] | Train Loss=0.014981583 | Test Loss=0.024063332 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [194/1000] | Train Loss=0.016246052 | Test Loss=0.010319642 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [195/1000] | Train Loss=0.014825736 | Test Loss=0.019797101 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [196/1000] | Train Loss=0.011445136 | Test Loss=0.013273343 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [197/1000] | Train Loss=0.014176919 | Test Loss=0.008295156 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [198/1000] | Train Loss=0.015456608 | Test Loss=0.012235371 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [199/1000] | Train Loss=0.010592930 | Test Loss=0.008813516 | Time=1.65s | LR=1.00e-03\n",
      "Epoch [200/1000] | Train Loss=0.010656401 | Test Loss=0.028178163 | Time=1.80s | LR=1.00e-03\n",
      "Epoch [201/1000] | Train Loss=0.015798726 | Test Loss=0.016801234 | Time=1.67s | LR=1.00e-03\n",
      "Epoch [202/1000] | Train Loss=0.009234674 | Test Loss=0.009090753 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [203/1000] | Train Loss=0.008153973 | Test Loss=0.013887029 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [204/1000] | Train Loss=0.010168195 | Test Loss=0.017379817 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [205/1000] | Train Loss=0.011916126 | Test Loss=0.019972793 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [206/1000] | Train Loss=0.011494529 | Test Loss=0.028922875 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [207/1000] | Train Loss=0.017309712 | Test Loss=0.012422063 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [208/1000] | Train Loss=0.010602325 | Test Loss=0.010120640 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [209/1000] | Train Loss=0.015008785 | Test Loss=0.015378133 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [210/1000] | Train Loss=0.009540380 | Test Loss=0.011865717 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [211/1000] | Train Loss=0.010044513 | Test Loss=0.009120913 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [212/1000] | Train Loss=0.014609052 | Test Loss=0.009658461 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [213/1000] | Train Loss=0.010449588 | Test Loss=0.014534015 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [214/1000] | Train Loss=0.009387262 | Test Loss=0.006456372 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [215/1000] | Train Loss=0.009040630 | Test Loss=0.012277640 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [216/1000] | Train Loss=0.008985227 | Test Loss=0.011055193 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [217/1000] | Train Loss=0.009899203 | Test Loss=0.011491880 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [218/1000] | Train Loss=0.009983163 | Test Loss=0.008966735 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [219/1000] | Train Loss=0.009506866 | Test Loss=0.012910188 | Time=1.67s | LR=1.00e-03\n",
      "Epoch [220/1000] | Train Loss=0.010662508 | Test Loss=0.016918881 | Time=1.63s | LR=1.00e-03\n",
      "Epoch [221/1000] | Train Loss=0.013031515 | Test Loss=0.008284952 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [222/1000] | Train Loss=0.010293285 | Test Loss=0.018586449 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [223/1000] | Train Loss=0.008725396 | Test Loss=0.009077109 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [224/1000] | Train Loss=0.009811476 | Test Loss=0.009899091 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [225/1000] | Train Loss=0.010123812 | Test Loss=0.012349967 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [226/1000] | Train Loss=0.010354900 | Test Loss=0.005515075 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [227/1000] | Train Loss=0.013796168 | Test Loss=0.023704638 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [228/1000] | Train Loss=0.014510186 | Test Loss=0.008930883 | Time=1.58s | LR=1.00e-03\n",
      "Epoch [229/1000] | Train Loss=0.010039835 | Test Loss=0.014190253 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [230/1000] | Train Loss=0.009574046 | Test Loss=0.013398019 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [231/1000] | Train Loss=0.009683912 | Test Loss=0.010740340 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [232/1000] | Train Loss=0.014319896 | Test Loss=0.012330639 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [233/1000] | Train Loss=0.014597493 | Test Loss=0.015419357 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [234/1000] | Train Loss=0.010571083 | Test Loss=0.005998607 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [235/1000] | Train Loss=0.013035742 | Test Loss=0.019205687 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [236/1000] | Train Loss=0.011255915 | Test Loss=0.007476071 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [237/1000] | Train Loss=0.014853228 | Test Loss=0.009798053 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [238/1000] | Train Loss=0.010133093 | Test Loss=0.007524387 | Time=1.59s | LR=1.00e-03\n",
      "Epoch [239/1000] | Train Loss=0.009915308 | Test Loss=0.011216426 | Time=1.44s | LR=1.00e-03\n",
      "Epoch [240/1000] | Train Loss=0.008451843 | Test Loss=0.006826894 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [241/1000] | Train Loss=0.010811124 | Test Loss=0.019705206 | Time=1.78s | LR=1.00e-03\n",
      "Epoch [242/1000] | Train Loss=0.018744766 | Test Loss=0.028739230 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [243/1000] | Train Loss=0.019345279 | Test Loss=0.023248717 | Time=1.77s | LR=1.00e-03\n",
      "Epoch [244/1000] | Train Loss=0.010940438 | Test Loss=0.007139184 | Time=1.82s | LR=5.00e-04\n",
      "Epoch [245/1000] | Train Loss=0.004651311 | Test Loss=0.006738756 | Time=1.80s | LR=5.00e-04\n",
      "Epoch [246/1000] | Train Loss=0.004952674 | Test Loss=0.005485796 | Time=1.77s | LR=5.00e-04\n",
      "Epoch [247/1000] | Train Loss=0.004820728 | Test Loss=0.005753783 | Time=1.77s | LR=5.00e-04\n",
      "Epoch [248/1000] | Train Loss=0.004111154 | Test Loss=0.006552751 | Time=1.80s | LR=5.00e-04\n",
      "Epoch [249/1000] | Train Loss=0.006638076 | Test Loss=0.017585233 | Time=1.80s | LR=5.00e-04\n",
      "Epoch [250/1000] | Train Loss=0.008102331 | Test Loss=0.008338084 | Time=1.78s | LR=5.00e-04\n",
      "Epoch [251/1000] | Train Loss=0.003869395 | Test Loss=0.008228407 | Time=1.76s | LR=5.00e-04\n",
      "Epoch [252/1000] | Train Loss=0.006107310 | Test Loss=0.010701233 | Time=1.77s | LR=5.00e-04\n",
      "Epoch [253/1000] | Train Loss=0.005571680 | Test Loss=0.008951404 | Time=1.77s | LR=5.00e-04\n",
      "Epoch [254/1000] | Train Loss=0.004923180 | Test Loss=0.006443322 | Time=1.82s | LR=5.00e-04\n",
      "Epoch [255/1000] | Train Loss=0.005034020 | Test Loss=0.007950306 | Time=1.82s | LR=5.00e-04\n",
      "Epoch [256/1000] | Train Loss=0.003838269 | Test Loss=0.007927450 | Time=1.61s | LR=5.00e-04\n",
      "Epoch [257/1000] | Train Loss=0.004949025 | Test Loss=0.007828483 | Time=1.59s | LR=5.00e-04\n",
      "Epoch [258/1000] | Train Loss=0.003602091 | Test Loss=0.005110211 | Time=1.78s | LR=5.00e-04\n",
      "Epoch [259/1000] | Train Loss=0.004661274 | Test Loss=0.005732413 | Time=1.77s | LR=5.00e-04\n",
      "Epoch [260/1000] | Train Loss=0.004354290 | Test Loss=0.006506987 | Time=1.77s | LR=5.00e-04\n",
      "Epoch [261/1000] | Train Loss=0.004757305 | Test Loss=0.010369515 | Time=1.82s | LR=5.00e-04\n",
      "Epoch [262/1000] | Train Loss=0.007664856 | Test Loss=0.007498787 | Time=1.82s | LR=5.00e-04\n",
      "Epoch [263/1000] | Train Loss=0.005053780 | Test Loss=0.011978223 | Time=1.77s | LR=5.00e-04\n",
      "Epoch [264/1000] | Train Loss=0.005347833 | Test Loss=0.006716546 | Time=1.81s | LR=5.00e-04\n",
      "Epoch [265/1000] | Train Loss=0.004856655 | Test Loss=0.006201205 | Time=1.80s | LR=5.00e-04\n",
      "Epoch [266/1000] | Train Loss=0.004256435 | Test Loss=0.004861069 | Time=1.80s | LR=5.00e-04\n",
      "Epoch [267/1000] | Train Loss=0.005767566 | Test Loss=0.007679269 | Time=1.79s | LR=5.00e-04\n",
      "Epoch [268/1000] | Train Loss=0.008114210 | Test Loss=0.007712509 | Time=1.77s | LR=5.00e-04\n",
      "Epoch [269/1000] | Train Loss=0.008273947 | Test Loss=0.008184503 | Time=1.77s | LR=5.00e-04\n",
      "Epoch [270/1000] | Train Loss=0.007911842 | Test Loss=0.008211769 | Time=1.78s | LR=5.00e-04\n",
      "Epoch [271/1000] | Train Loss=0.004614591 | Test Loss=0.007125329 | Time=1.82s | LR=5.00e-04\n",
      "Epoch [272/1000] | Train Loss=0.004992995 | Test Loss=0.006666105 | Time=1.82s | LR=5.00e-04\n",
      "Epoch [273/1000] | Train Loss=0.004695534 | Test Loss=0.006026357 | Time=1.59s | LR=5.00e-04\n",
      "Epoch [274/1000] | Train Loss=0.004309446 | Test Loss=0.005533183 | Time=1.59s | LR=5.00e-04\n",
      "Epoch [275/1000] | Train Loss=0.003934974 | Test Loss=0.005057023 | Time=1.59s | LR=5.00e-04\n",
      "Epoch [276/1000] | Train Loss=0.005328283 | Test Loss=0.006108774 | Time=1.74s | LR=5.00e-04\n",
      "Epoch [277/1000] | Train Loss=0.004852713 | Test Loss=0.007050169 | Time=1.77s | LR=5.00e-04\n",
      "Epoch [278/1000] | Train Loss=0.004797000 | Test Loss=0.009593302 | Time=1.82s | LR=5.00e-04\n",
      "Epoch [279/1000] | Train Loss=0.005001409 | Test Loss=0.005674952 | Time=1.82s | LR=5.00e-04\n",
      "Epoch [280/1000] | Train Loss=0.004980199 | Test Loss=0.007516134 | Time=1.76s | LR=5.00e-04\n",
      "Epoch [281/1000] | Train Loss=0.004838164 | Test Loss=0.005294175 | Time=1.81s | LR=5.00e-04\n",
      "Epoch [282/1000] | Train Loss=0.005159906 | Test Loss=0.004771134 | Time=1.79s | LR=5.00e-04\n",
      "Epoch [283/1000] | Train Loss=0.005399353 | Test Loss=0.006813499 | Time=1.79s | LR=5.00e-04\n",
      "Epoch [284/1000] | Train Loss=0.004104132 | Test Loss=0.004912550 | Time=1.78s | LR=5.00e-04\n",
      "Epoch [285/1000] | Train Loss=0.003938031 | Test Loss=0.004747078 | Time=1.78s | LR=5.00e-04\n",
      "Epoch [286/1000] | Train Loss=0.004535353 | Test Loss=0.006276592 | Time=1.77s | LR=5.00e-04\n",
      "Epoch [287/1000] | Train Loss=0.004887426 | Test Loss=0.006878133 | Time=1.82s | LR=5.00e-04\n",
      "Epoch [288/1000] | Train Loss=0.004570683 | Test Loss=0.006727584 | Time=1.82s | LR=5.00e-04\n",
      "Epoch [289/1000] | Train Loss=0.004790989 | Test Loss=0.006987091 | Time=1.82s | LR=5.00e-04\n",
      "Epoch [290/1000] | Train Loss=0.004190181 | Test Loss=0.005570428 | Time=1.61s | LR=5.00e-04\n",
      "Epoch [291/1000] | Train Loss=0.006108760 | Test Loss=0.012869818 | Time=1.59s | LR=5.00e-04\n",
      "Epoch [292/1000] | Train Loss=0.009093952 | Test Loss=0.008446228 | Time=1.59s | LR=5.00e-04\n",
      "Epoch [293/1000] | Train Loss=0.007298711 | Test Loss=0.010078445 | Time=1.60s | LR=5.00e-04\n",
      "Epoch [294/1000] | Train Loss=0.007073690 | Test Loss=0.014322596 | Time=1.70s | LR=5.00e-04\n",
      "Epoch [295/1000] | Train Loss=0.005855039 | Test Loss=0.008821521 | Time=1.82s | LR=5.00e-04\n",
      "Epoch [296/1000] | Train Loss=0.007335057 | Test Loss=0.005544430 | Time=1.81s | LR=5.00e-04\n",
      "Epoch [297/1000] | Train Loss=0.007981779 | Test Loss=0.016018535 | Time=1.77s | LR=5.00e-04\n",
      "Epoch [298/1000] | Train Loss=0.007136343 | Test Loss=0.015172978 | Time=1.80s | LR=5.00e-04\n",
      "Epoch [299/1000] | Train Loss=0.007523652 | Test Loss=0.004468164 | Time=1.78s | LR=2.50e-04\n",
      "Epoch [300/1000] | Train Loss=0.001935575 | Test Loss=0.004661603 | Time=1.77s | LR=2.50e-04\n",
      "Epoch [301/1000] | Train Loss=0.002215147 | Test Loss=0.004906749 | Time=1.77s | LR=2.50e-04\n",
      "Epoch [302/1000] | Train Loss=0.002205084 | Test Loss=0.004398832 | Time=1.78s | LR=2.50e-04\n",
      "Epoch [303/1000] | Train Loss=0.002337400 | Test Loss=0.005015539 | Time=1.77s | LR=2.50e-04\n",
      "Epoch [304/1000] | Train Loss=0.001839723 | Test Loss=0.006412080 | Time=1.82s | LR=2.50e-04\n",
      "Epoch [305/1000] | Train Loss=0.003479043 | Test Loss=0.005130875 | Time=1.82s | LR=2.50e-04\n",
      "Epoch [306/1000] | Train Loss=0.002238937 | Test Loss=0.004790126 | Time=1.80s | LR=2.50e-04\n",
      "Epoch [307/1000] | Train Loss=0.001995295 | Test Loss=0.004682370 | Time=1.65s | LR=2.50e-04\n",
      "Epoch [308/1000] | Train Loss=0.002356057 | Test Loss=0.004561868 | Time=1.59s | LR=2.50e-04\n",
      "Epoch [309/1000] | Train Loss=0.002361238 | Test Loss=0.006208621 | Time=1.59s | LR=2.50e-04\n",
      "Epoch [310/1000] | Train Loss=0.002540573 | Test Loss=0.004379115 | Time=1.60s | LR=2.50e-04\n",
      "Epoch [311/1000] | Train Loss=0.002529034 | Test Loss=0.004329353 | Time=1.59s | LR=2.50e-04\n",
      "Epoch [312/1000] | Train Loss=0.002154304 | Test Loss=0.004599692 | Time=1.59s | LR=2.50e-04\n",
      "Epoch [313/1000] | Train Loss=0.002650811 | Test Loss=0.004871121 | Time=1.79s | LR=2.50e-04\n",
      "Epoch [314/1000] | Train Loss=0.003337140 | Test Loss=0.004828154 | Time=1.77s | LR=2.50e-04\n",
      "Epoch [315/1000] | Train Loss=0.003225535 | Test Loss=0.005628781 | Time=1.77s | LR=2.50e-04\n",
      "Epoch [316/1000] | Train Loss=0.003321780 | Test Loss=0.005981723 | Time=1.82s | LR=2.50e-04\n",
      "Epoch [317/1000] | Train Loss=0.002035020 | Test Loss=0.004428983 | Time=1.82s | LR=2.50e-04\n",
      "Epoch [318/1000] | Train Loss=0.002509685 | Test Loss=0.004437572 | Time=1.82s | LR=2.50e-04\n",
      "Epoch [319/1000] | Train Loss=0.002517852 | Test Loss=0.004967369 | Time=1.82s | LR=2.50e-04\n",
      "Epoch [320/1000] | Train Loss=0.002378296 | Test Loss=0.004823300 | Time=1.76s | LR=2.50e-04\n",
      "Epoch [321/1000] | Train Loss=0.003266126 | Test Loss=0.005496332 | Time=1.80s | LR=2.50e-04\n",
      "Epoch [322/1000] | Train Loss=0.004547771 | Test Loss=0.008116354 | Time=1.79s | LR=2.50e-04\n",
      "Epoch [323/1000] | Train Loss=0.004425282 | Test Loss=0.008184951 | Time=1.78s | LR=2.50e-04\n",
      "Epoch [324/1000] | Train Loss=0.003685424 | Test Loss=0.005429763 | Time=1.73s | LR=2.50e-04\n",
      "Epoch [325/1000] | Train Loss=0.002490175 | Test Loss=0.004436566 | Time=1.59s | LR=2.50e-04\n",
      "Epoch [326/1000] | Train Loss=0.002358598 | Test Loss=0.004296576 | Time=1.59s | LR=2.50e-04\n",
      "Epoch [327/1000] | Train Loss=0.002566951 | Test Loss=0.005455723 | Time=1.59s | LR=2.50e-04\n",
      "Epoch [328/1000] | Train Loss=0.002555611 | Test Loss=0.005825796 | Time=1.59s | LR=2.50e-04\n",
      "Epoch [329/1000] | Train Loss=0.002308327 | Test Loss=0.004882907 | Time=1.60s | LR=2.50e-04\n",
      "Epoch [330/1000] | Train Loss=0.003075807 | Test Loss=0.006515483 | Time=1.59s | LR=2.50e-04\n",
      "Epoch [331/1000] | Train Loss=0.004539238 | Test Loss=0.004940712 | Time=1.68s | LR=2.50e-04\n",
      "Epoch [332/1000] | Train Loss=0.005389344 | Test Loss=0.004886448 | Time=1.76s | LR=2.50e-04\n",
      "Epoch [333/1000] | Train Loss=0.002554405 | Test Loss=0.005091384 | Time=1.80s | LR=2.50e-04\n",
      "Epoch [334/1000] | Train Loss=0.002460880 | Test Loss=0.005085692 | Time=1.81s | LR=2.50e-04\n",
      "Epoch [335/1000] | Train Loss=0.002436109 | Test Loss=0.004447679 | Time=1.79s | LR=2.50e-04\n",
      "Epoch [336/1000] | Train Loss=0.002716991 | Test Loss=0.004224158 | Time=1.78s | LR=2.50e-04\n",
      "Epoch [337/1000] | Train Loss=0.002748584 | Test Loss=0.008979656 | Time=1.78s | LR=2.50e-04\n",
      "Epoch [338/1000] | Train Loss=0.004530673 | Test Loss=0.005465327 | Time=1.77s | LR=2.50e-04\n",
      "Epoch [339/1000] | Train Loss=0.002595369 | Test Loss=0.004442082 | Time=1.83s | LR=2.50e-04\n",
      "Epoch [340/1000] | Train Loss=0.002419281 | Test Loss=0.004457318 | Time=1.82s | LR=2.50e-04\n",
      "Epoch [341/1000] | Train Loss=0.002519995 | Test Loss=0.004527700 | Time=1.82s | LR=2.50e-04\n",
      "Epoch [342/1000] | Train Loss=0.002273848 | Test Loss=0.004357913 | Time=1.62s | LR=2.50e-04\n",
      "Epoch [343/1000] | Train Loss=0.002471091 | Test Loss=0.004262357 | Time=1.59s | LR=2.50e-04\n",
      "Epoch [344/1000] | Train Loss=0.002242549 | Test Loss=0.004305383 | Time=1.59s | LR=2.50e-04\n",
      "Epoch [345/1000] | Train Loss=0.002033445 | Test Loss=0.004164132 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [346/1000] | Train Loss=0.001168695 | Test Loss=0.004793173 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [347/1000] | Train Loss=0.001152206 | Test Loss=0.004116571 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [348/1000] | Train Loss=0.001196921 | Test Loss=0.004635724 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [349/1000] | Train Loss=0.000970812 | Test Loss=0.004060778 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [350/1000] | Train Loss=0.000995356 | Test Loss=0.004151147 | Time=1.71s | LR=1.25e-04\n",
      "Epoch [351/1000] | Train Loss=0.001280656 | Test Loss=0.004060220 | Time=1.83s | LR=1.25e-04\n",
      "Epoch [352/1000] | Train Loss=0.001267635 | Test Loss=0.004369719 | Time=1.82s | LR=1.25e-04\n",
      "Epoch [353/1000] | Train Loss=0.001129107 | Test Loss=0.004374504 | Time=1.82s | LR=1.25e-04\n",
      "Epoch [354/1000] | Train Loss=0.001656477 | Test Loss=0.004418362 | Time=1.81s | LR=1.25e-04\n",
      "Epoch [355/1000] | Train Loss=0.001218244 | Test Loss=0.004021831 | Time=1.78s | LR=1.25e-04\n",
      "Epoch [356/1000] | Train Loss=0.001265531 | Test Loss=0.003960659 | Time=1.79s | LR=1.25e-04\n",
      "Epoch [357/1000] | Train Loss=0.001165155 | Test Loss=0.004539403 | Time=1.78s | LR=1.25e-04\n",
      "Epoch [358/1000] | Train Loss=0.001211996 | Test Loss=0.004469794 | Time=1.77s | LR=1.25e-04\n",
      "Epoch [359/1000] | Train Loss=0.001704641 | Test Loss=0.004339864 | Time=1.77s | LR=1.25e-04\n",
      "Epoch [360/1000] | Train Loss=0.001496744 | Test Loss=0.004010826 | Time=1.60s | LR=1.25e-04\n",
      "Epoch [361/1000] | Train Loss=0.001209549 | Test Loss=0.004281523 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [362/1000] | Train Loss=0.001131935 | Test Loss=0.004270575 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [363/1000] | Train Loss=0.001030986 | Test Loss=0.004080498 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [364/1000] | Train Loss=0.001643044 | Test Loss=0.004047329 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [365/1000] | Train Loss=0.002323779 | Test Loss=0.004387712 | Time=1.60s | LR=1.25e-04\n",
      "Epoch [366/1000] | Train Loss=0.001797392 | Test Loss=0.004381192 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [367/1000] | Train Loss=0.002043901 | Test Loss=0.004725376 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [368/1000] | Train Loss=0.002198511 | Test Loss=0.003970881 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [369/1000] | Train Loss=0.001216820 | Test Loss=0.003922125 | Time=1.73s | LR=1.25e-04\n",
      "Epoch [370/1000] | Train Loss=0.001260287 | Test Loss=0.004533096 | Time=1.83s | LR=1.25e-04\n",
      "Epoch [371/1000] | Train Loss=0.002160160 | Test Loss=0.003883773 | Time=1.82s | LR=1.25e-04\n",
      "Epoch [372/1000] | Train Loss=0.001082141 | Test Loss=0.003911163 | Time=1.82s | LR=1.25e-04\n",
      "Epoch [373/1000] | Train Loss=0.001255242 | Test Loss=0.003959359 | Time=1.77s | LR=1.25e-04\n",
      "Epoch [374/1000] | Train Loss=0.001183984 | Test Loss=0.003959202 | Time=1.81s | LR=1.25e-04\n",
      "Epoch [375/1000] | Train Loss=0.001033170 | Test Loss=0.003918951 | Time=1.81s | LR=1.25e-04\n",
      "Epoch [376/1000] | Train Loss=0.000845718 | Test Loss=0.003964726 | Time=1.79s | LR=1.25e-04\n",
      "Epoch [377/1000] | Train Loss=0.001062122 | Test Loss=0.004244553 | Time=1.71s | LR=1.25e-04\n",
      "Epoch [378/1000] | Train Loss=0.001224326 | Test Loss=0.003920684 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [379/1000] | Train Loss=0.001717178 | Test Loss=0.004576565 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [380/1000] | Train Loss=0.002306726 | Test Loss=0.005763690 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [381/1000] | Train Loss=0.001750017 | Test Loss=0.003802255 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [382/1000] | Train Loss=0.001187276 | Test Loss=0.004476613 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [383/1000] | Train Loss=0.001260431 | Test Loss=0.004279447 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [384/1000] | Train Loss=0.001228685 | Test Loss=0.003890966 | Time=1.60s | LR=1.25e-04\n",
      "Epoch [385/1000] | Train Loss=0.001298870 | Test Loss=0.004330422 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [386/1000] | Train Loss=0.000925459 | Test Loss=0.003876874 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [387/1000] | Train Loss=0.001229409 | Test Loss=0.004187929 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [388/1000] | Train Loss=0.001200225 | Test Loss=0.004499081 | Time=1.73s | LR=1.25e-04\n",
      "Epoch [389/1000] | Train Loss=0.001197943 | Test Loss=0.004186841 | Time=1.82s | LR=1.25e-04\n",
      "Epoch [390/1000] | Train Loss=0.001213427 | Test Loss=0.004128042 | Time=1.83s | LR=1.25e-04\n",
      "Epoch [391/1000] | Train Loss=0.001613426 | Test Loss=0.003839786 | Time=1.77s | LR=1.25e-04\n",
      "Epoch [392/1000] | Train Loss=0.001215231 | Test Loss=0.003837968 | Time=1.81s | LR=1.25e-04\n",
      "Epoch [393/1000] | Train Loss=0.001231510 | Test Loss=0.004087323 | Time=1.80s | LR=1.25e-04\n",
      "Epoch [394/1000] | Train Loss=0.001143299 | Test Loss=0.004342484 | Time=1.82s | LR=1.25e-04\n",
      "Epoch [395/1000] | Train Loss=0.001222715 | Test Loss=0.003956354 | Time=1.69s | LR=1.25e-04\n",
      "Epoch [396/1000] | Train Loss=0.001227424 | Test Loss=0.004481652 | Time=1.59s | LR=1.25e-04\n",
      "Epoch [397/1000] | Train Loss=0.001149895 | Test Loss=0.004665480 | Time=1.59s | LR=1.25e-04\n"
     ]
    }
   ],
   "source": [
    "print(device)\n",
    "\n",
    "# Train using your existing training loop\n",
    "amfno, train_losses, test_losses = train_operator(\n",
    "    dataset_raw      = train_data,\n",
    "    model            = amfno,\n",
    "    num_epochs       = 1000,\n",
    "    lr               = 1e-3,\n",
    "    device           = device,\n",
    "    batch_size       = 16,\n",
    "    test_dataset_raw = test_data,\n",
    "    reduce_on        = 'train',\n",
    "    factor           = 0.5,\n",
    "    patience         = 40,\n",
    "    min_lr           = 1e-10,\n",
    "    cooldown         = 0,\n",
    ")\n",
    "\n",
    "torch.save(amfno.state_dict(), 'amfno_G3.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(device)\n",
    "\n",
    "# Train using your existing training loop\n",
    "amfno, train_losses, test_losses = train_operator(\n",
    "    dataset_raw      = train_data,\n",
    "    model            = amfno,\n",
    "    num_epochs       = 1000,\n",
    "    lr               = 1e-4,\n",
    "    device           = device,\n",
    "    batch_size       = 16,\n",
    "    test_dataset_raw = test_data,\n",
    "    reduce_on        = 'train',\n",
    "    factor           = 0.5,\n",
    "    patience         = 40,\n",
    "    min_lr           = 1e-10,\n",
    "    cooldown         = 0,\n",
    ")\n",
    "\n",
    "torch.save(amfno.state_dict(), 'amfno_G3.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------train----------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Function type: chirped_cosine, Average Rel L2 error: 0.00029363\n",
      "Function type: sine_beats, Average Rel L2 error: 0.00120699\n",
      "Function type: periodic, Average Rel L2 error: 0.00190618\n",
      "\n",
      "Overall average relative L2 error: 0.00113632\n",
      "--------------test----------------\n",
      "Function type: sinc_pulse, Average Rel L2 error: 0.01165068\n",
      "Function type: gaussian_hermite, Average Rel L2 error: 0.00833317\n",
      "Function type: wave_packet, Average Rel L2 error: 0.00413562\n",
      "\n",
      "Overall average relative L2 error: 0.00813788\n",
      "-------------sample----------------\n",
      "Function type: periodic, Average Rel L2 error: 0.00203312\n",
      "Function type: chirped_cosine, Average Rel L2 error: 0.00041351\n",
      "Function type: sine_beats, Average Rel L2 error: 0.00131568\n",
      "\n",
      "Overall average relative L2 error: 0.00125236\n"
     ]
    }
   ],
   "source": [
    "device = 'cuda:1'\n",
    "# Loss in train loop is batch-wise average loss. Metric is sample-wise average loss.\n",
    "print('--------------train----------------')\n",
    "loss = evaluate_operator(dataset = train_data, model = amfno, device = device)\n",
    "print('--------------test----------------')\n",
    "loss = evaluate_operator(dataset = test_data, model = amfno, device = device)\n",
    "print('-------------sample----------------')\n",
    "loss = evaluate_operator(dataset = sample_data, model = amfno, device = device)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kano",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
