{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "SEED = 123456\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# CHOOSE EXAMPLE\n",
    "FULL_BATCH = False  # True: one fixed batch reused at every step (full batch); False: a fresh mini-batch per step\n",
    "EXAMPLE = \"approximate_optimum\"  # one of: \"switching\", \"selection\", \"approximate_optimum\", \"quadratic-deep\", \"all\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_dim = 20\n",
    "hidden_dim = 512\n",
    "output_dim = 7  # the loss cell overrides this to 100 for \"quadratic-deep\" and \"all\"\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Two-layer MLP: Linear(in, hidden) -> ReLU -> Linear(hidden, out).\n",
    "\n",
    "    Any dimension left as None falls back to the module-level input_dim / hidden_dim /\n",
    "    output_dim, read when the network is constructed, so a later change to output_dim\n",
    "    is picked up by networks built afterwards.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_features=None, hidden_features=None, out_features=None):\n",
    "        super().__init__()\n",
    "        in_features = input_dim if in_features is None else in_features\n",
    "        hidden_features = hidden_dim if hidden_features is None else hidden_features\n",
    "        out_features = output_dim if out_features is None else out_features\n",
    "        self.fc1 = nn.Linear(in_features, hidden_features)\n",
    "        self.fc2 = nn.Linear(hidden_features, out_features)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map inputs of shape [..., in_features] to [..., out_features].\"\"\"\n",
    "        return self.fc2(torch.relu(self.fc1(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LOSSES\n",
    "\n",
    "def poly_loss(outputs, y_tensor, H, alpha, epsilon=0.0):\n",
    "    \"\"\"\n",
    "    Batch mean of the powered quadratic form (d^T H d) ** alpha, with d = outputs - y_tensor + epsilon.\n",
    "    Args:\n",
    "        outputs (torch.tensor): Model outputs, [BATCH, OUTDIM]\n",
    "        y_tensor (torch.tensor): Target tensor, same shape as outputs\n",
    "        H (torch.tensor): Hessian matrix of the quadratic form, [OUTDIM, OUTDIM]\n",
    "        alpha (float): Exponent applied to d^T H d; the loss has degree 2 * alpha in d\n",
    "        epsilon (float): Constant added to every residual entry; for positive-definite H it\n",
    "            moves the minimiser to outputs = y_tensor - epsilon\n",
    "    Returns:\n",
    "        loss (torch.tensor): Scalar, mean over the batch\n",
    "    \"\"\"\n",
    "    # Shifted residual\n",
    "    diff = outputs - y_tensor + epsilon  # [BATCH, OUTDIM]\n",
    "    # Per-sample quadratic form d^T H d\n",
    "    quad_form = torch.sum(torch.matmul(diff, H) * diff, dim=1)  # [BATCH]\n",
    "    # Power alpha -> effective polynomial degree 2 * alpha\n",
    "    loss = quad_form ** alpha\n",
    "    return loss.mean()\n",
    "\n",
    "\n",
    "def _make_criterion(H, alpha, epsilon=0.0):\n",
    "    \"\"\"Return criterion(outputs, y_tensor) with H, alpha and epsilon bound now rather than looked up at call time.\"\"\"\n",
    "    return lambda outputs, y_tensor: poly_loss(outputs, y_tensor, H, alpha, epsilon)\n",
    "\n",
    "\n",
    "def _ill_conditioned_hessian(dim=100, n_large=90):\n",
    "    \"\"\"Diagonal H: first n_large entries drawn from [0.5, 1), the remaining ones from [0, 0.001).\n",
    "\n",
    "    One torch.rand(1) value is drawn per diagonal entry, in index order.\n",
    "    \"\"\"\n",
    "    H = torch.eye(dim)\n",
    "    for i in range(n_large):\n",
    "        H[i, i] = torch.rand(1).item() * 0.5 + 0.5  # random value in [0.5, 1)\n",
    "    for i in range(n_large, dim):\n",
    "        H[i, i] = torch.rand(1).item() * 0.001  # random value in [0, 0.001): near-flat directions\n",
    "    return H\n",
    "\n",
    "\n",
    "if EXAMPLE == \"switching\":\n",
    "    # switching example: the same quadratic form raised to the powers 1, 1.5 and 2\n",
    "    H = torch.eye(output_dim)\n",
    "    criterion1 = _make_criterion(H, alpha=1.0)\n",
    "    criterion2 = _make_criterion(H, alpha=1.5)\n",
    "    criterion3 = _make_criterion(H, alpha=2.0)\n",
    "\n",
    "elif EXAMPLE == \"selection\":\n",
    "    # selection example: quadratic losses; H2 / H3 keep weight 1 on the first output and\n",
    "    # scale every other output by 0.01 / 0.0001\n",
    "    alpha = 1.\n",
    "    H1 = torch.eye(output_dim)\n",
    "    criterion1 = _make_criterion(H1, alpha)\n",
    "    H2 = 0.01 * torch.eye(output_dim)\n",
    "    H2[0, 0] = 1.\n",
    "    criterion2 = _make_criterion(H2, alpha)\n",
    "    H3 = 0.0001 * torch.eye(output_dim)\n",
    "    H3[0, 0] = 1.\n",
    "    criterion3 = _make_criterion(H3, alpha)\n",
    "\n",
    "elif EXAMPLE == \"approximate_optimum\":\n",
    "    # not the same optimum exactly: residual shifts epsilon = 0, +0.05, -0.05\n",
    "    H = torch.eye(output_dim)\n",
    "    criterion1 = _make_criterion(H, alpha=1.0, epsilon=0.0)\n",
    "    criterion2 = _make_criterion(H, alpha=1.5, epsilon=0.05)\n",
    "    criterion3 = _make_criterion(H, alpha=2.0, epsilon=-0.05)\n",
    "\n",
    "elif EXAMPLE == \"quadratic-deep\":\n",
    "    # ill-conditioned problem: 90 eigenvalues in [0.5, 1), 10 in [0, 0.001)\n",
    "    output_dim = 100\n",
    "    H = _ill_conditioned_hessian(output_dim)\n",
    "    criterion1 = _make_criterion(H, alpha=1.0)\n",
    "    criterion2 = _make_criterion(H, alpha=1.5)\n",
    "    criterion3 = _make_criterion(H, alpha=2.0)\n",
    "\n",
    "elif EXAMPLE == \"all\":\n",
    "    # ill-conditioned H combined with different powers and shifted optima\n",
    "    output_dim = 100\n",
    "    H = _ill_conditioned_hessian(output_dim)\n",
    "    criterion1 = _make_criterion(H, alpha=1.0, epsilon=0.0)\n",
    "    criterion2 = _make_criterion(H, alpha=1.5, epsilon=0.01)\n",
    "    criterion3 = _make_criterion(H, alpha=2.0, epsilon=-0.01)\n",
    "\n",
    "else:\n",
    "    raise ValueError(f\"Unknown EXAMPLE {EXAMPLE!r}; expected one of \"\n",
    "                     \"'switching', 'selection', 'approximate_optimum', 'quadratic-deep', 'all'\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GENERATE DATA\n",
    "# X holds one batch per gradient step: [n_gradient_steps, batch_size, input_dim].\n",
    "\n",
    "if FULL_BATCH:  # full-batch gradient descent: the same batch is reused at every step\n",
    "    n_gradient_steps = 200\n",
    "    batch_size = 100\n",
    "    X = np.repeat(np.random.uniform(-1, 1, (1, batch_size, input_dim)), repeats=n_gradient_steps, axis=0)\n",
    "else:  # mini-batch gradient descent: a fresh batch at every step\n",
    "    n_gradient_steps = 1000\n",
    "    batch_size = 1000\n",
    "    X = np.random.uniform(-1, 1, (n_gradient_steps, batch_size, input_dim))\n",
    "\n",
    "X_tensors = torch.from_numpy(X).float()\n",
    "net = Net()  # randomly initialised network whose outputs define the targets\n",
    "\n",
    "# The targets are constants: compute them without recording an autograd graph\n",
    "# (the forward pass covers every step's batch at once, which is large in the mini-batch case).\n",
    "with torch.no_grad():\n",
    "    if EXAMPLE in (\"quadratic-deep\", \"all\"):\n",
    "        y_tensors = net(X_tensors) / output_dim\n",
    "    else:\n",
    "        # make outputs larger to show the weight switch in the switching example\n",
    "        y_tensors = net(X_tensors) * 10."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pamoo(outputs, loss1, loss2, loss3, w, lr=3e-3, n_iters=1000, ridge=1e-4):\n",
    "    \"\"\"PAMOO weight update via projected gradient ascent on the loss weights.\n",
    "\n",
    "    Runs n_iters steps of w <- max(w + lr * (2 * F - 2 * A w), 1e-6), warm-started from w,\n",
    "    i.e. ascent on 2 F^T w - w^T A w over w > 0, whose unconstrained optimum solves A w = F.\n",
    "    F stacks the three loss values and A = G G^T + ridge * I, where row i of G is the\n",
    "    gradient of loss_i w.r.t. outputs, summed over the batch.\n",
    "\n",
    "    Args:\n",
    "        outputs: network outputs the losses were computed from; their graph must still exist.\n",
    "        loss1, loss2, loss3: scalar losses.\n",
    "        w: current weights, [NUM_LOSSES].\n",
    "        lr: step size of the inner loop.\n",
    "        n_iters: number of inner iterations.\n",
    "        ridge: value added to the diagonal of A.\n",
    "    Returns:\n",
    "        New weights, [NUM_LOSSES], every entry >= 1e-6.\n",
    "    \"\"\"\n",
    "    losses = [loss1, loss2, loss3]\n",
    "    # retain_graph keeps the graph alive for a later backward pass through it\n",
    "    # (run_experiment calls weighted_loss.backward() afterwards); create_graph is\n",
    "    # unnecessary because these gradients are never differentiated.\n",
    "    per_loss_grads = [\n",
    "        torch.autograd.grad(loss, outputs, retain_graph=True)[0].sum(dim=0)  # [OUTDIM]\n",
    "        for loss in losses\n",
    "    ]\n",
    "    G = torch.stack(per_loss_grads, dim=0)  # [NUM_LOSSES, OUTDIM]\n",
    "    loss_values = torch.stack(losses, dim=0)  # [NUM_LOSSES]\n",
    "\n",
    "    A = G @ G.t()  # [NUM_LOSSES, NUM_LOSSES]\n",
    "    A += ridge * torch.eye(A.shape[0])\n",
    "\n",
    "    for _ in range(n_iters):\n",
    "        gradient = 2 * loss_values - 2 * torch.matmul(A, w)\n",
    "        w = w + lr * gradient\n",
    "        w = torch.clamp(w, min=1e-6)  # projection onto w > 0\n",
    "    return w\n",
    "\n",
    "def mg_pamoo(outputs, loss1, loss2, loss3, w, use_eta=True, beta=0.0, lower_bounds=(0, 0, 0)):\n",
    "    \"\"\"Max-gap weighting: all weight goes to the loss with the largest gap loss_i - lower_bounds[i].\n",
    "\n",
    "    With use_eta the selected weight is the Polyak-type step eta = gap / (2 * ||g||^2),\n",
    "    clamped to [1e-6, 1e3], where g is that loss's gradient w.r.t. outputs summed over the\n",
    "    batch; without it the selected weight is 1. The new weights are blended with the\n",
    "    previous ones as beta * w + (1 - beta) * w_new (momentum).\n",
    "\n",
    "    Args:\n",
    "        outputs: network outputs the losses were computed from; their graph must still exist.\n",
    "        loss1, loss2, loss3: scalar losses.\n",
    "        w: current weights, [NUM_LOSSES].\n",
    "        use_eta: scale the selected weight by the Polyak-type step.\n",
    "        beta: momentum coefficient; 0 discards the previous weights.\n",
    "        lower_bounds: per-loss lower bounds F_i*; 0 for every loss in these examples.\n",
    "    Returns:\n",
    "        New weights, [NUM_LOSSES].\n",
    "    \"\"\"\n",
    "    w_old = w\n",
    "    losses = [loss1, loss2, loss3]\n",
    "\n",
    "    gaps = torch.stack([loss - lb for loss, lb in zip(losses, lower_bounds)], dim=0)  # [NUM_LOSSES]\n",
    "    I = torch.argmax(gaps)\n",
    "    w = torch.zeros(len(losses))\n",
    "    w[I] = 1.0\n",
    "\n",
    "    if not use_eta:\n",
    "        return beta * w_old + (1. - beta) * w  # momentum\n",
    "\n",
    "    # Polyak-type step size for the selected loss\n",
    "    grads = torch.autograd.grad(losses[I], outputs, retain_graph=True)  # [BATCH, OUTDIM]\n",
    "    jacobian = grads[0].sum(dim=0)  # [OUTDIM]\n",
    "    jacobian_norm = torch.norm(jacobian)**2  # scalar ||g||^2\n",
    "    eta = gaps[I] / jacobian_norm / 2.\n",
    "    # A zero gap with a zero gradient gives 0/0 = NaN, which clamp would pass through.\n",
    "    eta = torch.nan_to_num(eta, nan=1e-6)\n",
    "    eta = torch.clamp(eta, min=1e-6, max=1e3)\n",
    "\n",
    "    w_new = w * eta\n",
    "    return beta * w_old + (1. - beta) * w_new  # momentum\n",
    "\n",
    "def mg_amoo(outputs, loss1, loss2, loss3, w):\n",
    "    \"\"\"Max-gap weighting with a unit weight on the selected loss (no Polyak step, no momentum).\"\"\"\n",
    "    return mg_pamoo(outputs, loss1, loss2, loss3, w, use_eta=False)\n",
    "\n",
    "def mg_pamoo_with_momentum(outputs, loss1, loss2, loss3, w):\n",
    "    \"\"\"mg_pamoo (Polyak-type step) with momentum beta = 0.95.\"\"\"\n",
    "    return mg_pamoo(outputs, loss1, loss2, loss3, w, beta=0.95)\n",
    "\n",
    "def mg_amoo_with_momentum(outputs, loss1, loss2, loss3, w):\n",
    "    \"\"\"mg_amoo (unit weight on the selected loss) with momentum beta = 0.95.\"\"\"\n",
    "    return mg_pamoo(outputs, loss1, loss2, loss3, w, use_eta=False, beta=0.95)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def run_experiment(method=pamoo, optimizer=torch.optim.SGD, lr=0.005):\n",
    "    \"\"\"Train a freshly initialised Net on criterion1..3 and log every step.\n",
    "\n",
    "    Args:\n",
    "        method: weight-update rule called as method(outputs, loss1, loss2, loss3, w) under\n",
    "            torch.no_grad() before each step; None keeps the initial equal weights (EW).\n",
    "        optimizer: optimizer class, instantiated on the new network's parameters.\n",
    "        lr: learning rate passed to the optimizer.\n",
    "    Returns:\n",
    "        dict with per-step lists: 'weighted_losses', 'max_gap',\n",
    "        'losses' -> {'loss1', 'loss2', 'loss3'} and 'weights' -> {'w1', 'w2', 'w3'}.\n",
    "    \"\"\"\n",
    "    model = Net()\n",
    "    weights = torch.ones(3)  # start from equal weighting\n",
    "    opt = optimizer(model.parameters(), lr=lr)\n",
    "\n",
    "    loss_log = {'loss1': [], 'loss2': [], 'loss3': []}\n",
    "    weight_log = {'w1': [], 'w2': [], 'w3': []}\n",
    "    results = {\n",
    "        'weighted_losses': [],\n",
    "        'max_gap': [],\n",
    "        'losses': loss_log,\n",
    "        'weights': weight_log,\n",
    "    }\n",
    "\n",
    "    label = method.__name__ if method else \"EW\"\n",
    "    batches = zip(X_tensors, y_tensors)\n",
    "    for X_batch, y_batch in tqdm(batches, total=len(X_tensors), desc=label):\n",
    "        outputs = model(X_batch)\n",
    "        losses = (\n",
    "            criterion1(outputs, y_batch),\n",
    "            criterion2(outputs, y_batch),\n",
    "            criterion3(outputs, y_batch),\n",
    "        )\n",
    "        max_gap = torch.max(torch.stack(losses))  # the optimum F* is taken to be 0\n",
    "\n",
    "        if method:\n",
    "            with torch.no_grad():\n",
    "                weights = method(outputs, *losses, weights)\n",
    "\n",
    "        weighted_loss = weights[0] * losses[0] + weights[1] * losses[1] + weights[2] * losses[2]\n",
    "        opt.zero_grad()\n",
    "        weighted_loss.backward()\n",
    "        opt.step()\n",
    "\n",
    "        results['weighted_losses'].append(weighted_loss.item())\n",
    "        results['max_gap'].append(max_gap.item())\n",
    "        for key, value in zip(('loss1', 'loss2', 'loss3'), losses):\n",
    "            loss_log[key].append(value.item())\n",
    "        for key, value in zip(('w1', 'w2', 'w3'), weights):\n",
    "            weight_log[key].append(value.item())\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train every weighting method with SGD and with Adam.\n",
    "# Runs stay in this fixed order: each run_experiment call builds a fresh Net(),\n",
    "# drawing its initialisation from the torch RNG.\n",
    "SGD_LR = 0.0005\n",
    "ADAM_LR = 0.002\n",
    "\n",
    "# SGD results\n",
    "sgd_ew = run_experiment(method=None, optimizer=torch.optim.SGD, lr=SGD_LR)  # equal weighting\n",
    "sgd_pamoo = run_experiment(method=pamoo, optimizer=torch.optim.SGD, lr=SGD_LR)\n",
    "sgd_mg_amoo = run_experiment(method=mg_amoo, optimizer=torch.optim.SGD, lr=SGD_LR)\n",
    "sgd_mg_pamoo = run_experiment(method=mg_pamoo, optimizer=torch.optim.SGD, lr=SGD_LR)\n",
    "sgd_mg_amoo_with_momentum = run_experiment(method=mg_amoo_with_momentum, optimizer=torch.optim.SGD, lr=SGD_LR)\n",
    "sgd_mg_pamoo_with_momentum = run_experiment(method=mg_pamoo_with_momentum, optimizer=torch.optim.SGD, lr=SGD_LR)\n",
    "\n",
    "# Adam results\n",
    "adam_ew = run_experiment(method=None, optimizer=torch.optim.Adam, lr=ADAM_LR)  # equal weighting\n",
    "adam_pamoo = run_experiment(method=pamoo, optimizer=torch.optim.Adam, lr=ADAM_LR)\n",
    "adam_mg_amoo = run_experiment(method=mg_amoo, optimizer=torch.optim.Adam, lr=ADAM_LR)\n",
    "adam_mg_pamoo = run_experiment(method=mg_pamoo, optimizer=torch.optim.Adam, lr=ADAM_LR)\n",
    "adam_mg_amoo_with_momentum = run_experiment(method=mg_amoo_with_momentum, optimizer=torch.optim.Adam, lr=ADAM_LR)\n",
    "adam_mg_pamoo_with_momentum = run_experiment(method=mg_pamoo_with_momentum, optimizer=torch.optim.Adam, lr=ADAM_LR)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Legend labels for every run. mg_pamoo scales the selected loss by the Polyak-type step\n",
    "# gap / (2 * ||g||^2), while mg_amoo gives it a plain unit weight (a GD step).\n",
    "results_dict_sgd = {\n",
    "    'Equal Weighting': sgd_ew,\n",
    "    'PAMOO (Efroni et al. 2025)': sgd_pamoo,\n",
    "    'MG-AMOO w/ polyak': sgd_mg_pamoo,\n",
    "    'MG-AMOO w/ polyak and momentum': sgd_mg_pamoo_with_momentum,\n",
    "    'MG-AMOO w/ GD': sgd_mg_amoo,\n",
    "    'MG-AMOO w/ GD and momentum': sgd_mg_amoo_with_momentum,\n",
    "}\n",
    "results_dict_adam = {\n",
    "    'Equal Weighting': adam_ew,\n",
    "    'PAMOO (Efroni et al. 2025)': adam_pamoo,\n",
    "    'MG-AMOO w/ polyak': adam_mg_pamoo,\n",
    "    'MG-AMOO w/ polyak and momentum': adam_mg_pamoo_with_momentum,\n",
    "    'MG-AMOO w/ GD': adam_mg_amoo,\n",
    "    'MG-AMOO w/ GD and momentum': adam_mg_amoo_with_momentum,\n",
    "}\n",
    "\n",
    "# Combined view: SGD runs keep their labels, Adam runs get an \" (Adam)\" suffix.\n",
    "results_dict = {\n",
    "    **results_dict_sgd,\n",
    "    **{name + \" (Adam)\": results for name, results in results_dict_adam.items()},\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_results(results_dict, filename):\n",
    "    \"\"\"Plot the three criteria and the weighted loss per gradient step on a 2x2 log-scale grid.\n",
    "\n",
    "    Args:\n",
    "        results_dict: maps a legend label to the dict returned by run_experiment.\n",
    "        filename: target passed to savefig.\n",
    "    \"\"\"\n",
    "    # Titles (f_1, f_2, f_3) per example, matching the alpha / epsilon of criterion1..3.\n",
    "    powers = (r'$f_1(x)=((y-t)^T H (y-t))^1$',\n",
    "              r'$f_2(x)=((y-t)^T H (y-t))^{1.5}$',\n",
    "              r'$f_3(x)=((y-t)^T H (y-t))^2$')\n",
    "    shifted = (r'$f_1(x)=(y-t)^T H (y-t)$',\n",
    "               r'$f_2(x)=((y-t+\\epsilon)^T H (y-t+\\epsilon))^{1.5}$',\n",
    "               r'$f_3(x)=((y-t-\\epsilon)^T H (y-t-\\epsilon))^2$')\n",
    "    loss_titles = {\n",
    "        'switching': powers,\n",
    "        'quadratic-deep': powers,\n",
    "        'selection': (r'$f_1(x)=(y-t)^T H_{1} (y-t)$',\n",
    "                      r'$f_2(x)=(y-t)^T H_{0.01} (y-t)$',\n",
    "                      r'$f_3(x)=(y-t)^T H_{0.0001} (y-t)$'),\n",
    "        'approximate_optimum': shifted,\n",
    "        'all': shifted,\n",
    "    }\n",
    "\n",
    "    fig, axs = plt.subplots(2, 2, figsize=(12, 8))\n",
    "    loss_axes = (axs[0, 0], axs[0, 1], axs[1, 0])\n",
    "    for method_name, results in results_dict.items():\n",
    "        for ax, key in zip(loss_axes, ('loss1', 'loss2', 'loss3')):\n",
    "            ax.plot(results['losses'][key], label=method_name)\n",
    "        axs[1, 1].plot(results['weighted_losses'], label=method_name)\n",
    "\n",
    "    for ax, title in zip(loss_axes, loss_titles.get(EXAMPLE, ())):\n",
    "        ax.set_title(title)\n",
    "    axs[1, 1].set_title(r'$f(x)= \\sum w_i f_i(x)$')\n",
    "\n",
    "    for ax in axs.flat:\n",
    "        ax.set_yscale('log')\n",
    "        ax.set_xlabel('Gradient Steps')\n",
    "        ax.legend()\n",
    "\n",
    "    fig.tight_layout()  # keep subplots from overlapping\n",
    "    fig.savefig(filename, bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_results(results_dict_sgd, EXAMPLE + '_all_plots_sgd')\n",
    "plot_results(results_dict_adam, EXAMPLE + '_all_plots_adam')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_max_gap(results_dict, filename):\n",
    "    \"\"\"Plot the logged max gap max_i |F_i(x) - F*| per gradient step for every method (log y-axis).\n",
    "\n",
    "    Args:\n",
    "        results_dict: maps a legend label to the dict returned by run_experiment.\n",
    "        filename: target passed to savefig.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(6, 4))\n",
    "    for method_name, results in results_dict.items():\n",
    "        ax.plot(results['max_gap'], label=method_name)\n",
    "\n",
    "    ax.set_ylabel(r'$\\max_i |F_i(x) - F^*|$')\n",
    "    ax.set_xlabel('Gradient Steps')\n",
    "    ax.set_yscale('log')\n",
    "    ax.legend()\n",
    "\n",
    "    fig.tight_layout()  # keep labels from being clipped\n",
    "    fig.savefig(filename, bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_max_gap(results_dict_sgd, EXAMPLE + '_max_gap_plot_sgd')\n",
    "plot_max_gap(results_dict_adam, EXAMPLE + '_max_gap_plot_adam')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def plot_weights(results, filename):\n",
    "    \"\"\"Plot one run's max gap (left axis, blue) against its three loss weights (right axis, reds).\n",
    "\n",
    "    Args:\n",
    "        results: dict returned by run_experiment.\n",
    "        filename: target passed to savefig.\n",
    "    \"\"\"\n",
    "    # Formula for each loss per example, matching the alpha / epsilon of criterion1..3.\n",
    "    powers = (r'$f_1(x)=((y-t)^T H (y-t))$',\n",
    "              r'$f_2(x)=((y-t)^T H (y-t))^{1.5}$',\n",
    "              r'$f_3(x)=((y-t)^T H (y-t))^2$')\n",
    "    shifted = (r'$f_1(x)=((y-t)^T H (y-t))$',\n",
    "               r'$f_2(x)=((y-t+\\epsilon)^T H (y-t+\\epsilon))^{1.5}$',\n",
    "               r'$f_3(x)=((y-t-\\epsilon)^T H (y-t-\\epsilon))^2$')\n",
    "    loss_formulas = {\n",
    "        'switching': powers,\n",
    "        'quadratic-deep': powers,\n",
    "        'selection': (r'$f_1(x)=(y-t)^T H_{1} (y-t)$',\n",
    "                      r'$f_2(x)=(y-t)^T H_{0.01} (y-t)$',\n",
    "                      r'$f_3(x)=(y-t)^T H_{0.0001} (y-t)$'),\n",
    "        'approximate_optimum': shifted,\n",
    "        'all': shifted,\n",
    "    }\n",
    "\n",
    "    fig, ax1 = plt.subplots()\n",
    "    # Left axis: the logged max gap (the largest of the three losses at each step)\n",
    "    ax1.plot(results['max_gap'], color='blue')\n",
    "    ax1.set_xlabel('Gradient Steps')\n",
    "    ax1.set_ylabel(r'$\\max_i |F_i(x) - F^*|$', color='blue')\n",
    "    ax1.tick_params(axis='y', labelcolor='blue')\n",
    "    ax1.grid(False)\n",
    "\n",
    "    # Right axis: the weight assigned to each loss\n",
    "    ax2 = ax1.twinx()\n",
    "    colors = plt.cm.Reds(np.linspace(0.5, 1, 3))  # 3 different red tones\n",
    "    formulas = loss_formulas.get(EXAMPLE, ())\n",
    "    for k, (key, formula) in enumerate(zip(('w1', 'w2', 'w3'), formulas), start=1):\n",
    "        ax2.plot(results['weights'][key], color=colors[k - 1], label=f'$w_{k}$  scales {formula}')\n",
    "    ax2.set_ylabel('Weight for respective Loss Function', color='red')\n",
    "    ax2.tick_params(axis='y', labelcolor='red')\n",
    "    ax2.grid(False)\n",
    "    ax2.legend(loc='upper right')\n",
    "\n",
    "    fig.savefig(filename, bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_weights(adam_pamoo, EXAMPLE + '_weights_PAMOO-ADAM')\n",
    "plot_weights(adam_mg_pamoo_with_momentum, EXAMPLE + '_weights_POLYAK-ADAM-with-momentum')\n",
    "plot_weights(adam_mg_pamoo, EXAMPLE + '_weights_POLYAK-ADAM')\n",
    "\n",
    "plot_weights(sgd_pamoo, EXAMPLE + '_weights_PAMOO-sgd')\n",
    "plot_weights(sgd_mg_pamoo_with_momentum, EXAMPLE + '_weights_mg-pamoo-sgd-with-momentum')\n",
    "plot_weights(sgd_mg_pamoo, EXAMPLE + '_weights_mg-pamoo-sgd')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "moo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
