<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <link rel="stylesheet" href="Gpt.css">
    <link rel="icon" type="image/png" href="Gpt/favicon3.png">
    <title>Learning to Learn with Generative Models of Neural Network Checkpoints</title>
</head>
<body>
<div id="title_slide">
    <div class="gradient-bg"></div>
    <div class="title_left">
        <h1>Learning to Learn with Generative Models<br> of Neural Network Checkpoints</h1>
        <div id="abstract" class="grid-container">
            <p>
                We explore a data-driven approach for learning to optimize neural networks. We construct a dataset of
                neural network checkpoints and train a generative model on the parameters. In particular, our model is
                a conditional diffusion transformer that, given an initial input parameter vector and a prompted loss,
                error, or return, predicts the distribution over parameter updates that achieve the desired metric. At
                test time, it can optimize neural networks with unseen parameters for downstream tasks in just one
                update. We apply our method to different neural network architectures and tasks in supervised and
                reinforcement learning.
            </p>
        </div>
    </div>
</div>
<hr class="rounded">
<div id="overview">
    <h1>Generative Models of Checkpoints</h1>
    <p>
        Over the last decade, the deep learning community has generated a massive number of neural network checkpoints. They contain
        a wealth of information: diverse parameter configurations and rich metrics such as test losses, classification errors and
        reinforcement learning returns that describe the quality of the checkpoint. We pre-train a generative model on millions of
        checkpoints. At test time, we use it to generate parameters for a neural network
        that solves a downstream task.
    </p>
    <video loop autoplay muted playsinline preload="metadata">
        <source src="Gpt/website_model.mp4" type="video/mp4">
    </video>
    <p>
        We refer to our model as G.pt (G and .pt refer to generative models and checkpoint extensions, respectively).
        G.pt is trained as a transformer-based diffusion model of neural network parameters. It takes as input a starting
        (potentially randomly-initialized) parameter vector, the starting loss/error/return and a user's <i>desired</i> loss/error/return.
        Once trained, we can sample a parameter update from the model that matches the user's prompted metric. The backbone of
        our diffusion model is a transformer that operates on sequences of neural net parameters. Similar to ViTs, our
        G.pt models rely on relatively few domain-specific inductive biases. In this paper, we train
        a separate G.pt model for each dataset, metric and architecture
        (e.g., a loss-conditional CIFAR-10 CNN model, an error-conditional MNIST MLP model, etc.).
    </p>
    <p>
        We will release our pre-training dataset containing 23M neural network checkpoints across
        100,000+ training runs. The dataset contains trained MLPs and CNNs for vision tasks (MNIST, CIFAR-10) and
        continuous control tasks (Cartpole).
    </p>
    <div>
        <h1>Prompting for Losses, Errors and Returns</h1>
        <p>
            In contrast to hand-designed optimizers (and prior work on learned optimizers), G.pt takes as input a user's
            desired loss, error, return, etc. Conditioning on these metrics lets us train on checkpoints with good and bad performance alike.
            After training, you can prompt G.pt with whatever value you want. Below, we prompt it with a
            large range of test losses (MNIST model) and returns (Cartpole model) over the course of training. By the end of training,
            our method successfully generates networks that achieve nearly the full range of prompts.
        </p>
        <div class="prompting">
            <embed type="text/html" class="prompt_left"  src="Gpt/prompting_mnist_loss.html">
            <embed type="text/html" class="prompt_right" src="Gpt/prompting_cartpole.html">
        </div>
    </div>


    <h1>Training in One Step</h1>
    <p>
        By prompting it for minimal loss or maximal return, G.pt can optimize unseen parameter vectors in one step.
        For example, a Cartpole model with randomly-initialized parameters can be rapidly trained with just one update.
    </p>
    <div class="cartpole">
        <video autoplay muted playsinline loop preload="metadata">
            <source src="Gpt/cartpole-init.mp4" type="video/mp4">
        </video>
        <img src="Gpt/arrow_website.png" alt="One G.pt update.">
        <video autoplay muted playsinline loop preload="metadata">
            <source src="Gpt/cartpole-evolved.mp4" type="video/mp4">
        </video>
    </div>
    <p>
        Compared to iterative gradient-based optimizers like Adam or Nesterov momentum, our approach needs only one parameter update
        to achieve good results. We show examples for two G.pts, one conditioned on MNIST classification error and the
        other on Cartpole return.
    </p>
    <div class="zero_shot">
        <video autoplay muted playsinline preload="metadata">
            <source src="Gpt/mnist_error_zero.mp4" type="video/mp4">
        </video>
        <video autoplay muted playsinline preload="metadata">
            <source src="Gpt/cartpole_zero.mp4" type="video/mp4">
        </video>
    </div>

    <p>
        Another advantage of our method over traditional optimizers is that it can directly optimize non-differentiable metrics,
        like classification errors or RL returns. This is because G.pt is trained as a diffusion model in parameter
        space, and so we never need to backpropagate through these metrics in order to train our model.
    </p>
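    <p>
        To illustrate why no gradient of the metric is needed, here is a minimal sketch of how one training example for a
        parameter-space diffusion model might be assembled. It is not the authors' implementation: the names
        <code>noise_parameters</code> and <code>make_training_example</code>, the dictionary keys, the pairing of two checkpoints
        from one run, and the choice to regress the clean parameters are all assumptions made for illustration. The point is
        only that the prompted metric enters as an ordinary input number.
    </p>
<pre>
```python
import math
import random

def noise_parameters(target_params, alpha_bar, rng):
    """Forward diffusion step: mix the target parameters with Gaussian noise."""
    signal = math.sqrt(alpha_bar)
    sigma = math.sqrt(1.0 - alpha_bar)
    noise = [rng.gauss(0.0, 1.0) for _ in target_params]
    noised = [signal * p + sigma * e for p, e in zip(target_params, noise)]
    return noised, noise

def make_training_example(start, target, alpha_bar, rng):
    """Build one (inputs, regression target) pair from two checkpoints.

    `start` and `target` are dicts with hypothetical keys "params" and "metric".
    """
    noised, _ = noise_parameters(target["params"], alpha_bar, rng)
    inputs = {
        "noised_params": noised,
        "start_params": start["params"],
        "start_metric": start["metric"],
        "prompt_metric": target["metric"],  # plain number, never differentiated
        "alpha_bar": alpha_bar,
    }
    # Assumption: the model regresses the clean target parameters.
    return inputs, target["params"]

# Toy usage with made-up three-parameter "checkpoints".
rng = random.Random(0)
start = {"params": [0.1, -0.2, 0.3], "metric": 0.90}
target = {"params": [0.5, 0.4, -0.1], "metric": 0.05}
inputs, regression_target = make_training_example(start, target, 0.5, rng)
print(inputs["prompt_metric"], regression_target)
```
</pre>
    <p>
        In this sketch the training loss would compare the model output with <code>regression_target</code>, so the metric
        (here an error rate) only ever appears as a conditioning value recorded alongside each checkpoint.
    </p>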

    <h1>Capturing Parameter Multimodality</h1>
    <p>
        Many different neural net parameter vectors can yield the same loss. This creates ambiguity: if we ask G.pt for
        a small loss, which parameter solution should it choose? Since G.pt is a diffusion model, we can sample
        different solutions. Below, we show the test error landscape of MNIST MLPs as a function of the top two parameter-space
        PCA directions. We plot several sampled solutions (prompting for low error) in this landscape. Note that the samples
        cover distinct local minima in the error landscape. You can zoom and pan to explore the landscape.
    </p>
    <div class="error_surface">
        <embed type="text/html" src="Gpt/error_surface.html">
    </div>
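    <p>
        A two-dimensional view like the landscape above requires projecting each flattened parameter vector onto the top two
        PCA directions. The sketch below shows one way to do this in plain Python using power iteration with deflation. It is
        an illustration under assumptions, not the code used for the figure: the function names and the toy data are
        hypothetical, the data is assumed to have at least two directions of nonzero variance, and in practice a linear
        algebra library would typically be used instead.
    </p>
<pre>
```python
import math
import random

def top_two_pca_directions(rows, iters=200, seed=0):
    """Estimate the top two PCA directions of equal-length parameter vectors."""
    n, dim = len(rows), len(rows[0])
    mean = [sum(r[j] for r in rows) / n for j in range(dim)]
    centered = [[r[j] - mean[j] for j in range(dim)] for r in rows]
    rng = random.Random(seed)
    directions = []
    for _ in range(2):
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        for _ in range(iters):
            # Apply the unnormalized covariance X^T (X v) via per-row scores.
            scores = [sum(c[j] * v[j] for j in range(dim)) for c in centered]
            v = [sum(s * c[j] for s, c in zip(scores, centered)) for j in range(dim)]
            # Deflate against directions already found, then renormalize.
            for d in directions:
                dot = sum(a * b for a, b in zip(v, d))
                v = [a - dot * b for a, b in zip(v, d)]
            norm = math.sqrt(sum(a * a for a in v))
            v = [a / norm for a in v]
        directions.append(v)
    return mean, directions

def project_2d(row, mean, directions):
    """Coordinates of one parameter vector in the plane spanned by the two directions."""
    return [sum((x - m) * d for x, m, d in zip(row, mean, direction))
            for direction in directions]

# Toy stand-in for flattened checkpoints: most variance along the first axis.
toy = [[-10.0, 0.0, 0.0], [10.0, 0.0, 0.0], [0.0, -2.0, 0.0],
       [0.0, 2.0, 0.0], [0.0, 0.0, 0.1], [0.0, 0.0, -0.1]]
mean, directions = top_two_pca_directions(toy)
coords = [project_2d(row, mean, directions) for row in toy]
```
</pre>
    <p>
        Each entry of <code>coords</code> gives a position in the PCA plane. Evaluating the test error on a grid of points in that
        plane would then give a surface, and sampled solutions can be projected the same way. The sign of each direction is
        arbitrary.
    </p>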
    <p> <i>A latent walk through parameter space.</i> We can also visualize different parameter samples directly. For example,
        we can visualize generated conv1 weights for our CIFAR-10 error-conditional model. We prompt for low test
        error and fix all other inputs except the input latent code. This lets us take a latent walk through parameter space.
        Below are interpolations for 16 conv1 filters predicted by G.pt. The latents control subtle variations in each
        filter as well as their ordering.
    </p>

    <video class="latent_walk" loop autoplay muted playsinline preload="metadata">
        <source src="Gpt/interp_plain_cifar10_error.mp4" type="video/mp4">
    </video>
</div>
<script type="text/javascript">
    /* https://stackoverflow.com/questions/3027707/how-to-change-the-playing-speed-of-videos-in-html5 */
    // Slow down only the first video on the page (the model overview animation).
    var firstVideo = document.querySelector('video');
    firstVideo.defaultPlaybackRate = 0.65;
    firstVideo.playbackRate = 0.65;
    firstVideo.play();
</script>
<script>
    /* https://stackoverflow.com/questions/21163756/html5-and-javascript-to-play-videos-only-when-visible */
    var videos = document.getElementsByTagName("video");

    function checkScroll() {
        var fraction = 0.5; // Play when more than 50% of the player is visible.

        for(var i = 0; i < 1; i++) {  // only apply to the first video

            var video = videos[i];

            var x = video.offsetLeft, y = video.offsetTop, w = video.offsetWidth, h = video.offsetHeight, r = x + w, //right
                b = y + h, //bottom
                visibleX, visibleY, visible;

            visibleX = Math.max(0, Math.min(w, window.pageXOffset + window.innerWidth - x, r - window.pageXOffset));
            visibleY = Math.max(0, Math.min(h, window.pageYOffset + window.innerHeight - y, b - window.pageYOffset));

            visible = visibleX * visibleY / (w * h);

            if (visible > fraction) {
                video.play();
            } else {
                video.pause();
            }

        }

    }
    window.addEventListener('scroll', checkScroll, false);
    window.addEventListener('resize', checkScroll, false);
</script>
</body>
</html>
