{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import statements\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import samplers\n",
    "import numpy as np\n",
    "import importlib\n",
    "import matplotlib.pyplot as plt\n",
    "importlib.reload(samplers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the first GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class InputMapping(nn.Module):\n",
    "    \"\"\"Random Fourier-feature input mapping.\n",
    "\n",
    "    Maps an input ``xi`` whose last dimension is ``d_in`` to\n",
    "    ``[sin(B xi) * mask, cos(B xi) * mask, xi]``, where ``B`` is a fixed random\n",
    "    Gaussian frequency matrix of shape ``(n_freq, d_in)`` whose rows are sorted\n",
    "    by increasing norm. When ``Tperiod`` is given, input column 0 is treated as\n",
    "    periodic: its frequencies are snapped to integer multiples of\n",
    "    ``2*pi/Tperiod`` and column 0 is dropped from the passthrough part.\n",
    "\n",
    "    Args:\n",
    "        d_in: input dimension.\n",
    "        n_freq: number of random frequencies (rows of ``B``).\n",
    "        sigma: scale of the Gaussian frequencies.\n",
    "        tdiv: divisor applied to the frequencies of input column 0.\n",
    "        incrementalMask: if True the mask starts at zero and is opened\n",
    "            progressively by ``step``; otherwise it is all ones.\n",
    "        Tperiod: optional period of input column 0.\n",
    "        kill: if True the mask is all zeros (Fourier features switched off);\n",
    "            takes precedence over ``incrementalMask``.\n",
    "\n",
    "    The module is built on the CPU; move it with ``.to(device)`` so that ``B``\n",
    "    and ``mask`` always end up on the same device.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, d_in, n_freq, sigma=1, tdiv=2, incrementalMask=True, Tperiod=None, kill=False\n",
    "    ):\n",
    "        super().__init__()\n",
    "        Bmat = torch.randn(n_freq, d_in) * np.pi * sigma / np.sqrt(d_in)  # gaussian\n",
    "        # Frequencies along input column 0 (presumably time) are divided by tdiv,\n",
    "        # i.e. halved with the default tdiv=2.\n",
    "        Bmat[:, 0] /= tdiv\n",
    "\n",
    "        self.Tperiod = Tperiod\n",
    "        if Tperiod is not None:\n",
    "            # Snap column-0 frequencies to a whole number of cycles per period.\n",
    "            Tcycles = (Bmat[:, 0] * Tperiod / (2 * np.pi)).round()\n",
    "            Bmat[:, 0] = Tcycles * (2 * np.pi) / Tperiod\n",
    "\n",
    "        # Order frequencies from lowest to highest norm so the incremental mask\n",
    "        # opens the low frequencies first.\n",
    "        _, sortIndices = torch.sort(torch.norm(Bmat, p=2, dim=1))\n",
    "        Bmat = Bmat[sortIndices, :]\n",
    "\n",
    "        self.d_in = d_in\n",
    "        self.n_freq = n_freq\n",
    "        # sin + cos features plus the passthrough input (minus periodic column 0).\n",
    "        self.d_out = n_freq * 2 + d_in if Tperiod is None else n_freq * 2 + d_in - 1\n",
    "        # B maps d_in -> n_freq (not d_out); its weight is the fixed matrix Bmat.\n",
    "        self.B = nn.Linear(d_in, n_freq, bias=False)\n",
    "        self.B.weight = nn.Parameter(Bmat, requires_grad=False)\n",
    "\n",
    "        self.incrementalMask = incrementalMask\n",
    "        mask_open = not incrementalMask and not kill\n",
    "        mask = torch.ones(1, n_freq) if mask_open else torch.zeros(1, n_freq)\n",
    "        self.mask = nn.Parameter(mask, requires_grad=False)\n",
    "\n",
    "    def step(self, progressPercent, ramp_fraction=0.7):\n",
    "        \"\"\"Unmask frequencies according to training progress.\n",
    "\n",
    "        Only acts when ``incrementalMask`` is True. The lowest\n",
    "        ``floor(progressPercent * n_freq / ramp_fraction)`` frequencies are set\n",
    "        to 1 (never reset to 0), so every frequency is open once\n",
    "        ``progressPercent >= ramp_fraction``. ``progressPercent`` therefore\n",
    "        appears to be a fraction in [0, 1] rather than a percentage.\n",
    "\n",
    "        Args:\n",
    "            progressPercent: fraction of training completed.\n",
    "            ramp_fraction: progress at which all frequencies are open\n",
    "                (0.7, the previously hard-coded value, by default).\n",
    "        \"\"\"\n",
    "        if not self.incrementalMask:\n",
    "            return\n",
    "        n_open = int((progressPercent * self.n_freq) / ramp_fraction // 1)\n",
    "        # Clamp: a negative slice end would otherwise open all but the last few.\n",
    "        n_open = max(0, min(n_open, self.n_freq))\n",
    "        self.mask[0, :n_open] = 1\n",
    "\n",
    "    def forward(self, xi):\n",
    "        \"\"\"Apply the Fourier mapping.\n",
    "\n",
    "        Args:\n",
    "            xi: tensor whose last dimension is ``d_in``; a 1-D input is treated\n",
    "                as a batch of one.\n",
    "\n",
    "        Returns:\n",
    "            Tensor whose last dimension is ``d_out``.\n",
    "        \"\"\"\n",
    "        y = self.B(xi)\n",
    "        # Promote a single un-batched sample to a batch of one.\n",
    "        if y.dim() == 1:\n",
    "            y = y.unsqueeze(0)\n",
    "        if xi.dim() == 1:\n",
    "            xi = xi.unsqueeze(0)\n",
    "        features = [torch.sin(y) * self.mask, torch.cos(y) * self.mask]\n",
    "        if self.Tperiod is None:\n",
    "            features.append(xi)\n",
    "        else:\n",
    "            # The periodic column 0 is represented only by its Fourier features.\n",
    "            features.append(xi[..., 1:])\n",
    "        return torch.cat(features, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def smooth_leaky_relu(x, alpha=0.1):\n",
    "    \"\"\"Smooth leaky ReLU: a linear leak of slope ``alpha`` plus a scaled softplus.\n",
    "\n",
    "    Tends to ``alpha * x`` for very negative ``x`` and to ``x`` for very positive\n",
    "    ``x``; equals ``(1 - alpha) * log(2)`` at zero.\n",
    "    \"\"\"\n",
    "    leak = alpha * x\n",
    "    smooth_part = (1 - alpha) * F.softplus(x)\n",
    "    return leak + smooth_part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a two-layer MLP with output_dim=hidden_dims -- this is our \"h\" function\n",
    "\n",
    "class MLPh(nn.Module):\n",
    "    \"\"\"Two-layer MLP ``h``: optional Fourier map -> Linear -> ELU -> Linear.\n",
    "\n",
    "    Args:\n",
    "        base_dims: input dimension when no Fourier map is used.\n",
    "        hidden_dims: width of the hidden layer and of the output.\n",
    "        fourier_map: optional input mapping applied first; when given, its\n",
    "            ``d_out`` replaces ``base_dims`` as the width fed to ``fc1``.\n",
    "        residual: stored on the module but not used by ``forward``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, base_dims, hidden_dims, fourier_map=None, residual=False):\n",
    "        super().__init__()\n",
    "        self.fourier_map = fourier_map\n",
    "        self.residual = residual\n",
    "        # fc1 consumes the Fourier features when a mapping is present.\n",
    "        self.base_dims = fourier_map.d_out if fourier_map is not None else base_dims\n",
    "        self.fc1 = nn.Linear(self.base_dims, hidden_dims)\n",
    "        self.fc2 = nn.Linear(hidden_dims, hidden_dims)\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.fourier_map is not None:\n",
    "            x = self.fourier_map(x)\n",
    "        x = F.elu(self.fc1(x))  # elu works well!\n",
    "        return self.fc2(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a two-layer MLP with output_dim=1 -- this is our \"g\" function\n",
    "\n",
    "class MLPg(nn.Module):\n",
    "    \"\"\"Two-layer MLP ``g``: Linear -> ELU -> Linear(out_dims).\n",
    "\n",
    "    Args:\n",
    "        hidden_dims: input and hidden width.\n",
    "        out_dims: output width (1 for a scalar-valued function).\n",
    "        residual: stored on the module but not used by ``forward``.\n",
    "\n",
    "    ``forward`` drops only a trailing output axis of size 1, so ``(N, 1)``\n",
    "    becomes ``(N,)`` and a batch of one stays a 1-element tensor instead of\n",
    "    collapsing to a 0-d scalar (which the previous ``squeeze()`` did).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hidden_dims, out_dims=1, residual=False):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(hidden_dims, hidden_dims)\n",
    "        self.fc2 = nn.Linear(hidden_dims, out_dims)\n",
    "        self.residual = residual\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.elu(self.fc1(x))  # elu works well!\n",
    "        return self.fc2(x).squeeze(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target function is the indicator of a ball of radius R (any dimension)\n",
    "\n",
    "def target_fn(x, radius=1):\n",
    "    \"\"\"Indicator of the closed Euclidean ball ``{||x|| <= radius}``.\n",
    "\n",
    "    Args:\n",
    "        x: tensor whose last dimension holds the coordinates of each point\n",
    "            (any number of leading batch dimensions).\n",
    "        radius: ball radius.\n",
    "\n",
    "    Returns:\n",
    "        Float tensor of 0/1 values with the last dimension reduced away.\n",
    "    \"\"\"\n",
    "    # Alternative 2D targets kept for reference:\n",
    "    # - indicator of a square:\n",
    "    #   (torch.abs(x[:, 0]) <= radius).float() * (torch.abs(x[:, 1]) <= radius).float()\n",
    "    # - union of two squares centred at (0.5, 0.5) and (-0.5, -0.5):\n",
    "    #   (torch.abs(x[:, 0] - 0.5) <= radius).float() * (torch.abs(x[:, 1] - 0.5) <= radius).float() + (torch.abs(x[:, 0] + 0.5) <= radius).float() * (torch.abs(x[:, 1] + 0.5) <= radius).float()\n",
    "    return (torch.linalg.norm(x, dim=-1) <= radius).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot target function on grid.\n",
    "# indexing='xy' lays the reshaped values out as rows = x2, columns = x1, which\n",
    "# is the layout pcolormesh expects (the implicit 'ij' default transposes the\n",
    "# plot relative to the axis labels and emits a deprecation warning).\n",
    "\n",
    "n_grid = 100\n",
    "x_grid = torch.linspace(-2, 2, n_grid)\n",
    "X_grid = torch.stack(torch.meshgrid(x_grid, x_grid, indexing='xy'), dim=-1).reshape(-1, 2).to(device)\n",
    "y_grid_target = target_fn(X_grid).cpu()\n",
    "\n",
    "# Heatmap\n",
    "\n",
    "plt.figure()\n",
    "plt.pcolormesh(x_grid, x_grid, y_grid_target.reshape(n_grid, n_grid), cmap='coolwarm')\n",
    "plt.colorbar()\n",
    "plt.xlabel('x1')\n",
    "plt.ylabel('x2')\n",
    "plt.title('Target function');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_ball(num_points, dimension, radius=1):\n",
    "    \"\"\"Draw ``num_points`` points uniformly from a ``dimension``-ball of radius ``radius``.\n",
    "\n",
    "    Returns a float32 tensor of shape ``(num_points, dimension)`` on ``device``.\n",
    "    Uses NumPy's global random state.\n",
    "    \"\"\"\n",
    "    # Normalized isotropic Gaussian vectors are uniformly distributed on the unit sphere.\n",
    "    directions = np.random.normal(size=(dimension, num_points))\n",
    "    directions = directions / np.linalg.norm(directions, axis=0)\n",
    "    # U ** (1/d) has density proportional to r ** (d - 1), i.e. to the surface area\n",
    "    # of the sphere of radius r, which makes the points uniform in volume.\n",
    "    radii = np.random.random(num_points) ** (1 / dimension)\n",
    "    points = radius * (directions * radii).T\n",
    "    return torch.from_numpy(points).float().to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### $d=5$ visualizations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate visualizations for the d=5 experiments whose results are stored in\n",
    "# '/results/rof_ours_denoising_vs_exact_results_100k_iters_d5'\n",
    "\n",
    "# Load results\n",
    "\n",
    "reg_param = 0.05\n",
    "results_dir = '/results/rof_ours_denoising_vs_exact_results_100k_iters_d5'\n",
    "\n",
    "avg_values_exact_nuc = np.load(f'{results_dir}/avg_values_exact_nuc_reg_param_{reg_param}.npy')\n",
    "avg_values_ours = np.load(f'{results_dir}/avg_values_our_nuc_reg_param_{reg_param}.npy')\n",
    "\n",
    "losses_exact_nuc = np.load(f'{results_dir}/exact_nuc_losses_reg_param_{reg_param}.npy')\n",
    "losses_ours = np.load(f'{results_dir}/our_nuc_losses_reg_param_{reg_param}.npy')\n",
    "\n",
    "abs_errors_exact_nuc = np.load(f'{results_dir}/abs_errors_exact_nuc_reg_param_{reg_param}.npy')\n",
    "abs_errors_ours = np.load(f'{results_dir}/abs_errors_our_nuc_reg_param_{reg_param}.npy')\n",
    "\n",
    "# Load models\n",
    "\n",
    "d = 5\n",
    "base_dims = d\n",
    "hidden_dims = 100\n",
    "\n",
    "# Each model gets its own InputMapping: load_state_dict copies the saved Fourier\n",
    "# matrix and mask into the module in place, so one shared instance would hold\n",
    "# whichever checkpoint was loaded last and model_exact_nuc would silently run\n",
    "# with the other model's frequencies.\n",
    "fourier_map_exact = InputMapping(d_in=base_dims, n_freq=500, sigma=1, incrementalMask=False).to(device)\n",
    "fourier_map_ours = InputMapping(d_in=base_dims, n_freq=500, sigma=1, incrementalMask=False).to(device)\n",
    "\n",
    "model_exact_nuc = nn.Sequential(MLPh(base_dims, hidden_dims, fourier_map_exact, residual=False), MLPg(hidden_dims, residual=False)).to(device)\n",
    "# NOTE(review): unlike every other file here this path is relative and not under\n",
    "# /results -- confirm where the exact-model checkpoint actually lives.\n",
    "model_exact_nuc.load_state_dict(torch.load(f'rof_ours_denoising_vs_exact_results_100k_iters_d5/exact_nuc_model_reg_param_{reg_param}.pt', map_location=device))\n",
    "\n",
    "g_model_our_nuc = MLPg(hidden_dims, residual=False).to(device)\n",
    "h_model_our_nuc = MLPh(base_dims, hidden_dims, fourier_map=fourier_map_ours, residual=False).to(device)\n",
    "g_model_our_nuc.load_state_dict(torch.load(f'{results_dir}/our_nuc_g_model_reg_param_{reg_param}.pt', map_location=device))\n",
    "h_model_our_nuc.load_state_dict(torch.load(f'{results_dir}/our_nuc_h_model_reg_param_{reg_param}.pt', map_location=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pad the exact-nuclear-norm curve to the length of ours by repeating its last\n",
    "# value. max(0, ...) keeps np.repeat from raising if the exact run is longer,\n",
    "# and the padded copy gets a new name so re-running this cell is harmless.\n",
    "\n",
    "n_pad = max(0, len(avg_values_ours) - len(avg_values_exact_nuc))\n",
    "avg_values_exact_nuc_padded = np.concatenate((avg_values_exact_nuc, np.repeat(avg_values_exact_nuc[-1], n_pad)))\n",
    "\n",
    "# Expected average value on the unit ball: 1 - d * reg_param\n",
    "\n",
    "correct_val = 1 - d*reg_param\n",
    "\n",
    "# Plot average values against the expected value (dashed line)\n",
    "\n",
    "plt.plot(avg_values_ours, label='Our Nuclear Norm', color='C1')\n",
    "plt.plot(avg_values_exact_nuc_padded, label='Exact Nuclear Norm', color='C0')\n",
    "plt.axhline(y=correct_val, color='r', linestyle='--', label='Expected value')\n",
    "plt.xlabel('Iteration')\n",
    "plt.ylabel('Average value of function on unit ball')\n",
    "plt.title(\"Comparison of average values on unit ball\")\n",
    "# Call legend() last so the dashed reference line is included.\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pad the exact-nuclear-norm errors to the length of ours by repeating the last\n",
    "# value (max(0, ...) keeps np.repeat from raising if the exact run is longer).\n",
    "\n",
    "n_pad = max(0, len(abs_errors_ours) - len(abs_errors_exact_nuc))\n",
    "abs_errors_exact_nuc_padded = np.concatenate((abs_errors_exact_nuc, np.repeat(abs_errors_exact_nuc[-1], n_pad)))\n",
    "\n",
    "# Smooth the errors with a moving average. The results get new names so that\n",
    "# re-running this cell does not smooth already-smoothed curves a second time.\n",
    "\n",
    "window_size = 10\n",
    "kernel = np.ones(window_size) / window_size\n",
    "abs_errors_ours_smoothed = np.convolve(abs_errors_ours, kernel, mode='valid')\n",
    "abs_errors_exact_nuc_smoothed = np.convolve(abs_errors_exact_nuc_padded, kernel, mode='valid')\n",
    "\n",
    "# Plot absolute errors\n",
    "\n",
    "plt.plot(abs_errors_ours_smoothed, label='Our regularizer', color='C1')\n",
    "plt.plot(abs_errors_exact_nuc_smoothed, label='Exact nuclear norm', color='C0')\n",
    "plt.xlabel('Iteration')\n",
    "plt.ylabel('Absolute error (smoothed)')\n",
    "plt.title(r\"Absolute errors, $\\eta = 0.05$\")\n",
    "plt.legend()\n",
    "\n",
    "plt.savefig('results/abs_errors_d5_eta_0p05.png', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pad the exact-nuclear-norm losses to the length of ours by repeating the last\n",
    "# value (max(0, ...) keeps np.repeat from raising if the exact run is longer).\n",
    "\n",
    "n_pad = max(0, len(losses_ours) - len(losses_exact_nuc))\n",
    "losses_exact_nuc_padded = np.concatenate((losses_exact_nuc, np.repeat(losses_exact_nuc[-1], n_pad)))\n",
    "\n",
    "# Smooth the losses with a moving average\n",
    "\n",
    "window_size = 10\n",
    "kernel = np.ones(window_size) / window_size\n",
    "losses_ours_smoothed = np.convolve(losses_ours, kernel, mode='valid')\n",
    "losses_exact_nuc_smoothed = np.convolve(losses_exact_nuc_padded, kernel, mode='valid')\n",
    "\n",
    "# Plot the log of the smoothed losses\n",
    "\n",
    "plt.plot(np.log(losses_ours_smoothed), label='Our regularizer', color='C1')\n",
    "plt.plot(np.log(losses_exact_nuc_smoothed), label='Exact nuclear norm', color='C0')\n",
    "plt.xlabel('Iteration')\n",
    "plt.ylabel('Log loss')\n",
    "plt.title(r'Log-losses, $\\eta=0.05$')\n",
    "plt.legend()\n",
    "\n",
    "plt.savefig('results/log_losses_d5_eta_0p05.png', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot model_exact_nuc on the (x1, x2) plane, with the remaining d - 2 input\n",
    "# coordinates set to zero.\n",
    "\n",
    "n_grid = 100\n",
    "x_grid = torch.linspace(-2, 2, n_grid)\n",
    "# indexing='xy' gives rows = x2, columns = x1 after reshaping, as pcolormesh expects.\n",
    "X_grid = torch.stack(torch.meshgrid(x_grid, x_grid, indexing='xy'), dim=-1).reshape(-1, 2).to(device)\n",
    "# Pad with zero columns to match the model's input dimension d.\n",
    "X_grid = torch.cat((X_grid, torch.zeros(X_grid.shape[0], d - 2, device=device)), dim=1)\n",
    "with torch.no_grad():\n",
    "    y_grid_exact = model_exact_nuc(X_grid).squeeze().cpu()\n",
    "\n",
    "# Heatmap\n",
    "\n",
    "plt.figure()\n",
    "plt.pcolormesh(x_grid, x_grid, y_grid_exact.reshape(n_grid, n_grid), cmap='coolwarm', vmin=0, vmax=1)\n",
    "plt.colorbar()\n",
    "plt.xlabel('x1')\n",
    "plt.ylabel('x2')\n",
    "plt.title('Model with Exact Nuclear Norm Regularization');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot model_our_nuc (g composed with h) on the (x1, x2) plane, with the\n",
    "# remaining d - 2 input coordinates set to zero.\n",
    "\n",
    "n_grid = 100\n",
    "x_grid = torch.linspace(-2, 2, n_grid)\n",
    "# indexing='xy' gives rows = x2, columns = x1 after reshaping, as pcolormesh expects.\n",
    "X_grid = torch.stack(torch.meshgrid(x_grid, x_grid, indexing='xy'), dim=-1).reshape(-1, 2).to(device)\n",
    "# Pad with zero columns to match the input dimension d of model_our_nuc.\n",
    "X_grid = torch.cat((X_grid, torch.zeros(X_grid.shape[0], d - 2, device=device)), dim=1)\n",
    "model_our_nuc = nn.Sequential(h_model_our_nuc, g_model_our_nuc)\n",
    "with torch.no_grad():\n",
    "    y_grid_ours = model_our_nuc(X_grid).squeeze().cpu()\n",
    "\n",
    "# Heatmap\n",
    "\n",
    "plt.figure()\n",
    "plt.pcolormesh(x_grid, x_grid, y_grid_ours.reshape(n_grid, n_grid), cmap='coolwarm', vmin=0, vmax=1)\n",
    "plt.colorbar()\n",
    "plt.xlabel('x1')\n",
    "plt.ylabel('x2')\n",
    "plt.title('Model with Our Nuclear Norm Regularization');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### $d=2$ visualizations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate visualizations for the d=2 experiments whose results are stored in\n",
    "# '/results/rof_ours_denoising_vs_exact_results_100k_iters'\n",
    "\n",
    "# Load results\n",
    "\n",
    "reg_param = 0.1\n",
    "results_dir = '/results/rof_ours_denoising_vs_exact_results_100k_iters'\n",
    "\n",
    "avg_values_exact_nuc = np.load(f'{results_dir}/avg_values_exact_nuc_reg_param_{reg_param}.npy')\n",
    "avg_values_ours = np.load(f'{results_dir}/avg_values_our_nuc_reg_param_{reg_param}.npy')\n",
    "\n",
    "losses_exact_nuc = np.load(f'{results_dir}/exact_nuc_losses_reg_param_{reg_param}.npy')\n",
    "losses_ours = np.load(f'{results_dir}/our_nuc_losses_reg_param_{reg_param}.npy')\n",
    "\n",
    "abs_errors_exact_nuc = np.load(f'{results_dir}/abs_errors_exact_nuc_reg_param_{reg_param}.npy')\n",
    "abs_errors_ours = np.load(f'{results_dir}/abs_errors_our_nuc_reg_param_{reg_param}.npy')\n",
    "\n",
    "# Load models\n",
    "\n",
    "base_dims = 2\n",
    "hidden_dims = 100\n",
    "\n",
    "# Each model gets its own InputMapping: load_state_dict copies the saved Fourier\n",
    "# matrix and mask into the module in place, so one shared instance would hold\n",
    "# whichever checkpoint was loaded last and model_exact_nuc would silently run\n",
    "# with the other model's frequencies.\n",
    "fourier_map_exact = InputMapping(d_in=base_dims, n_freq=500, sigma=1, incrementalMask=False).to(device)\n",
    "fourier_map_ours = InputMapping(d_in=base_dims, n_freq=500, sigma=1, incrementalMask=False).to(device)\n",
    "\n",
    "model_exact_nuc = nn.Sequential(MLPh(base_dims, hidden_dims, fourier_map_exact, residual=False), MLPg(hidden_dims, residual=False)).to(device)\n",
    "# NOTE(review): unlike every other file here this path is relative and not under\n",
    "# /results -- confirm where the exact-model checkpoint actually lives.\n",
    "model_exact_nuc.load_state_dict(torch.load(f'rof_ours_denoising_vs_exact_results_100k_iters/exact_nuc_model_reg_param_{reg_param}.pt', map_location=device))\n",
    "\n",
    "g_model_our_nuc = MLPg(hidden_dims, residual=False).to(device)\n",
    "h_model_our_nuc = MLPh(base_dims, hidden_dims, fourier_map=fourier_map_ours, residual=False).to(device)\n",
    "g_model_our_nuc.load_state_dict(torch.load(f'{results_dir}/our_nuc_g_model_reg_param_{reg_param}.pt', map_location=device))\n",
    "h_model_our_nuc.load_state_dict(torch.load(f'{results_dir}/our_nuc_h_model_reg_param_{reg_param}.pt', map_location=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pad the exact-nuclear-norm curve to the length of ours by repeating its last\n",
    "# value. max(0, ...) keeps np.repeat from raising if the exact run is longer,\n",
    "# and the padded copy gets a new name so re-running this cell is harmless.\n",
    "\n",
    "n_pad = max(0, len(avg_values_ours) - len(avg_values_exact_nuc))\n",
    "avg_values_exact_nuc_padded = np.concatenate((avg_values_exact_nuc, np.repeat(avg_values_exact_nuc[-1], n_pad)))\n",
    "\n",
    "# Expected average value on the unit disc: 1 - d * reg_param with d = base_dims\n",
    "\n",
    "correct_val = 1 - base_dims*reg_param\n",
    "\n",
    "# Plot average values against the expected value (dashed line)\n",
    "\n",
    "plt.plot(avg_values_ours, label='Our Nuclear Norm', color='C1')\n",
    "plt.plot(avg_values_exact_nuc_padded, label='Exact Nuclear Norm', color='C0')\n",
    "plt.axhline(y=correct_val, color='r', linestyle='--', label='Expected value')\n",
    "plt.xlabel('Iteration')\n",
    "plt.ylabel('Average value of function on unit disc')\n",
    "plt.title(\"Comparison of average values on unit disc\")\n",
    "# Call legend() last so the dashed reference line is included.\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pad the exact-nuclear-norm errors to the length of ours by repeating the last\n",
    "# value (max(0, ...) keeps np.repeat from raising if the exact run is longer).\n",
    "\n",
    "n_pad = max(0, len(abs_errors_ours) - len(abs_errors_exact_nuc))\n",
    "abs_errors_exact_nuc_padded = np.concatenate((abs_errors_exact_nuc, np.repeat(abs_errors_exact_nuc[-1], n_pad)))\n",
    "\n",
    "# Figure style: thicker lines and a larger, bold font for text, titles and\n",
    "# axis labels. NOTE: plt.rcParams is global, so this affects later figures too.\n",
    "plt.rcParams.update({\n",
    "    'lines.linewidth': 3,\n",
    "    'font.size': 14,\n",
    "    'font.weight': 'bold',\n",
    "    'axes.titleweight': 'bold',\n",
    "    'axes.labelweight': 'bold',\n",
    "})\n",
    "\n",
    "# Plot absolute errors\n",
    "\n",
    "plt.plot(abs_errors_ours, label='Our regularizer', color='C1')\n",
    "plt.plot(abs_errors_exact_nuc_padded, label='Exact nuclear norm', color='C0')\n",
    "plt.xlabel('Iteration')\n",
    "plt.ylabel('Absolute error')\n",
    "plt.title(r\"Absolute errors, $\\eta = 0.1$\")\n",
    "plt.legend()\n",
    "\n",
    "plt.savefig('results/abs_errors_d2_eta_0p1.png', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pad the exact-nuclear-norm losses to the length of ours by repeating the last\n",
    "# value (max(0, ...) keeps np.repeat from raising if the exact run is longer).\n",
    "\n",
    "n_pad = max(0, len(losses_ours) - len(losses_exact_nuc))\n",
    "losses_exact_nuc_padded = np.concatenate((losses_exact_nuc, np.repeat(losses_exact_nuc[-1], n_pad)))\n",
    "\n",
    "# Figure style: thicker lines and a larger, bold font for text, titles and\n",
    "# axis labels. NOTE: plt.rcParams is global, so this affects later figures too.\n",
    "plt.rcParams.update({\n",
    "    'lines.linewidth': 3,\n",
    "    'font.size': 14,\n",
    "    'font.weight': 'bold',\n",
    "    'axes.titleweight': 'bold',\n",
    "    'axes.labelweight': 'bold',\n",
    "})\n",
    "\n",
    "# Plot the log of the losses\n",
    "\n",
    "plt.plot(np.log(losses_ours), label='Our regularizer', color='C1')\n",
    "plt.plot(np.log(losses_exact_nuc_padded), label='Exact nuclear norm', color='C0')\n",
    "plt.xlabel('Iteration')\n",
    "plt.ylabel('Log loss')\n",
    "plt.title(r'Log-losses, $\\eta=0.1$')\n",
    "plt.legend()\n",
    "\n",
    "plt.savefig('results/log_losses_d2_eta_0p1.png', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot model_exact_nuc on the (x1, x2) grid.\n",
    "\n",
    "n_grid = 100\n",
    "x_grid = torch.linspace(-2, 2, n_grid)\n",
    "# indexing='xy' gives rows = x2, columns = x1 after reshaping, as pcolormesh expects.\n",
    "X_grid = torch.stack(torch.meshgrid(x_grid, x_grid, indexing='xy'), dim=-1).reshape(-1, 2).to(device)\n",
    "with torch.no_grad():\n",
    "    y_grid_exact = model_exact_nuc(X_grid).squeeze().cpu()\n",
    "\n",
    "# Heatmap style: larger, bold font for text, titles and axis labels.\n",
    "# NOTE: plt.rcParams is global, so this affects later figures too.\n",
    "plt.rcParams.update({\n",
    "    'font.size': 14,\n",
    "    'font.weight': 'bold',\n",
    "    'axes.titleweight': 'bold',\n",
    "    'axes.labelweight': 'bold',\n",
    "})\n",
    "\n",
    "plt.figure()\n",
    "plt.pcolormesh(x_grid, x_grid, y_grid_exact.reshape(n_grid, n_grid), cmap='coolwarm', vmin=0, vmax=1)\n",
    "plt.colorbar()\n",
    "plt.xlabel('x1')\n",
    "plt.ylabel('x2')\n",
    "plt.title(r'Exact nuclear norm, $\\eta = 0.1$')\n",
    "\n",
    "plt.savefig('results/exact_nuclear_norm_eta_0p1.png', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot model_our_nuc (g composed with h) on the (x1, x2) grid.\n",
    "\n",
    "n_grid = 100\n",
    "x_grid = torch.linspace(-2, 2, n_grid)\n",
    "# indexing='xy' gives rows = x2, columns = x1 after reshaping, as pcolormesh expects.\n",
    "X_grid = torch.stack(torch.meshgrid(x_grid, x_grid, indexing='xy'), dim=-1).reshape(-1, 2).to(device)\n",
    "model_our_nuc = nn.Sequential(h_model_our_nuc, g_model_our_nuc)\n",
    "with torch.no_grad():\n",
    "    y_grid_ours = model_our_nuc(X_grid).squeeze().cpu()\n",
    "\n",
    "# Heatmap style: larger, bold font for text, titles and axis labels.\n",
    "# NOTE: plt.rcParams is global, so this affects later figures too.\n",
    "plt.rcParams.update({\n",
    "    'font.size': 14,\n",
    "    'font.weight': 'bold',\n",
    "    'axes.titleweight': 'bold',\n",
    "    'axes.labelweight': 'bold',\n",
    "})\n",
    "\n",
    "# Axis labels are intentionally omitted on this figure.\n",
    "plt.figure()\n",
    "plt.pcolormesh(x_grid, x_grid, y_grid_ours.reshape(n_grid, n_grid), cmap='coolwarm', vmin=0, vmax=1)\n",
    "plt.colorbar()\n",
    "plt.title(r'Our regularizer, $\\eta = 0.1$')\n",
    "\n",
    "plt.savefig('results/our_nuclear_norm_eta_0p1.png', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the exact solution on the grid: target_fn * (1 - d * reg_param) with\n",
    "# d = base_dims (= 2 here), matching the expected average value used above.\n",
    "\n",
    "n_grid = 100\n",
    "x_grid = torch.linspace(-2, 2, n_grid)\n",
    "# indexing='xy' gives rows = x2, columns = x1 after reshaping, as pcolormesh expects.\n",
    "X_grid = torch.stack(torch.meshgrid(x_grid, x_grid, indexing='xy'), dim=-1).reshape(-1, 2).to(device)\n",
    "y_grid_target = target_fn(X_grid).cpu()\n",
    "y_grid_exact_solution = y_grid_target * (1 - base_dims*reg_param)\n",
    "\n",
    "# Heatmap style: larger, bold font for text, titles and axis labels.\n",
    "# NOTE: plt.rcParams is global, so this affects later figures too.\n",
    "plt.rcParams.update({\n",
    "    'font.size': 14,\n",
    "    'font.weight': 'bold',\n",
    "    'axes.titleweight': 'bold',\n",
    "    'axes.labelweight': 'bold',\n",
    "})\n",
    "\n",
    "plt.figure()\n",
    "plt.pcolormesh(x_grid, x_grid, y_grid_exact_solution.reshape(n_grid, n_grid), cmap='coolwarm', vmin=0, vmax=1)\n",
    "plt.colorbar()\n",
    "plt.title(r'Exact solution, $\\eta = 0.1$')\n",
    "\n",
    "plt.savefig('results/exact_solution_eta_0p1.png', dpi=300)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ldm6",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
