<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1" />
<meta name="generator" content="pdoc 0.10.0" />
<title>fathom.training.structured_flags API documentation</title>
<meta name="description" content="Addendum to fedjax.training.structured_flags …" />
<link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/sanitize.min.css" integrity="sha256-PK9q560IAAa6WVRRh76LtCaI8pjTJ2z11v0miyNNjrs=" crossorigin>
<link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/typography.min.css" integrity="sha256-7l/o7C8jubJiy74VsKTidCy1yBkRtiUGbVkYBylBqUg=" crossorigin>
<link rel="stylesheet preload" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/styles/github.min.css" crossorigin>
<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:30px;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:1em 0 .50em 0}h3{font-size:1.4em;margin:25px 0 10px 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .3s ease-in-out}a:hover{color:#e82}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900}pre code{background:#f8f8f8;font-size:.8em;line-height:1.4em}code{background:#f2f2f1;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{background:#f8f8f8;border:0;border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0;padding:1ex}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-weight:bold;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target .name{background:var(--highlight-color)}.name > 
span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em .5em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.item .name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul{padding-left:1.5em}.toc > ul > li{margin-top:.5em}}</style>
<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/highlight.min.js" integrity="sha256-Uv3H6lx7dJmRfRvH8TH6kJD1TSK1aFcwgx+mdg3epi8=" crossorigin></script>
<script>window.addEventListener('DOMContentLoaded', () => hljs.initHighlighting())</script>
</head>
<body>
<main>
<article id="content">
<header>
<h1 class="title">Module <code>fathom.training.structured_flags</code></h1>
</header>
<section id="section-intro">
<p>Addendum to fedjax.training.structured_flags.</p>
<p>Structured flags commonly used in experiment binaries.</p>
<p>Structured flags are often used to construct complex structures from multiple
simple flags (e.g. an optimizer can be created by controlling the learning rate and
other hyperparameters).</p>
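<p>Here is a rough illustration of the pattern as a hypothetical, stdlib-only sketch. It is not the FedJAX <code>NamedFlags</code> implementation. A plain dict stands in for <code>absl.flags.FLAGS</code>. The sketch assumes that a group created with a name such as <code>server</code> prefixes each of its flags with <code>server_</code>, as the <code>--{self._prefix}optimizer</code> error message in <code>OptimizerFlags.get</code> suggests. The default values are arbitrary.</p>
<pre><code class="python">from typing import Dict, Optional


class SketchNamedFlags:
    """Hypothetical, dict-backed stand-in for a named group of flags."""

    def __init__(self, values: Dict[str, float], name: Optional[str] = None):
        # A named group prefixes every flag it owns, e.g. "server" gives "server_".
        self._prefix = name + "_" if name else ""
        self._values = values

    def _float(self, flag: str, default: float):
        # "Registers" a flag: keep a value that was already parsed, else use the default.
        self._values.setdefault(self._prefix + flag, default)

    def _get_flag(self, flag: str):
        return self._values[self._prefix + flag]


class SketchOptimizerFlags(SketchNamedFlags):
    """Builds one structured value (a dict here) from two simple flags."""

    def __init__(self, values: Dict[str, float], name: Optional[str] = None):
        super().__init__(values, name)
        self._float("learning_rate", 0.01)
        self._float("momentum", 0.0)

    def get(self):
        return {
            "learning_rate": self._get_flag("learning_rate"),
            "momentum": self._get_flag("momentum"),
        }


# Simulates parsing --server_learning_rate=0.1 with every other flag left unset.
parsed = {"server_learning_rate": 0.1}
server_opt = SketchOptimizerFlags(parsed, name="server").get()
client_opt = SketchOptimizerFlags(parsed, name="client").get()</code></pre>
<p>Because of the prefix, one class can be instantiated several times in the same binary without flag-name collisions, for example once for a server optimizer and once for a client optimizer.</p>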
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python"># Copyright 2022 FATHOM Authors
#
# Licensed under the Apache License, Version 2.0 (the &#34;License&#34;);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an &#34;AS IS&#34; BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
&#34;&#34;&#34;Addendum to fedjax.training.structured_flags.

Structured flags commonly used in experiment binaries.

Structured flags are often used to construct complex structures from multiple
simple flags (e.g. an optimizer can be created by controlling the learning rate and
other hyperparameters).
&#34;&#34;&#34;

from typing import Optional, Tuple

from absl import flags
import jax
import jax.numpy as jnp
import fedjax
from fedjax.core import optimizers
from fedjax.core import client_datasets
from fedjax.training.structured_flags import NamedFlags
import fathom
from fathom.algorithms.fathom_fedavg import HyperParams

FLAGS = flags.FLAGS


class OptimizerFlags(fedjax.training.structured_flags.OptimizerFlags):
    &#34;&#34;&#34;Constructs an optimizer from flags using fathom.core.optimizers.

    Note: OptimizerFlags is redefined here because the FedJAX version
    instantiates optimizers from fedjax.core.optimizers, whereas FATHOM needs
    them instantiated from fathom.core.optimizers.
    &#34;&#34;&#34;

    def get(self) -&gt; optimizers.Optimizer:
        &#34;&#34;&#34;Gets the specified optimizer.&#34;&#34;&#34;
        optimizer_name = self._get_flag(&#39;optimizer&#39;)
        learning_rate = self._get_flag(&#39;learning_rate&#39;)
        if optimizer_name == &#39;sgd&#39;:
            return fathom.core.optimizers.sgd(learning_rate)
        elif optimizer_name == &#39;momentum&#39;:
            return fathom.core.optimizers.sgd(learning_rate, self._get_flag(&#39;momentum&#39;))
        elif optimizer_name == &#39;adam&#39;:
            return fathom.core.optimizers.adam(learning_rate, self._get_flag(&#39;adam_beta1&#39;),
                self._get_flag(&#39;adam_beta2&#39;),
                self._get_flag(&#39;adam_epsilon&#39;),
            )
        elif optimizer_name == &#39;rmsprop&#39;:
            return fathom.core.optimizers.rmsprop(learning_rate, 
                self._get_flag(&#39;rmsprop_decay&#39;),
                self._get_flag(&#39;rmsprop_epsilon&#39;),
            )
        elif optimizer_name == &#39;adagrad&#39;:
            return fathom.core.optimizers.adagrad(learning_rate, eps=self._get_flag(&#39;adagrad_epsilon&#39;))
        elif optimizer_name == &#39;yogi&#39;:
            return fathom.core.optimizers.yogi(learning_rate, 
                self._get_flag(&#39;yogi_beta1&#39;),
                self._get_flag(&#39;yogi_beta2&#39;),
                self._get_flag(&#39;yogi_epsilon&#39;)
            )
        else:
            raise ValueError(f&#39;Unsupported optimizer {optimizer_name!r} from &#39;
                f&#39;--{self._prefix}optimizer.&#39;)

class ShuffleRepeatBatchHParamsFlags(NamedFlags):
    &#34;&#34;&#34;Constructs ShuffleRepeatBatchHParams from flags.&#34;&#34;&#34;

    def __init__(self, name: Optional[str] = None, default_batch_size: int = 128, default_batch_seed: int = 123):
        super().__init__(name)
        defaults = client_datasets.ShuffleRepeatBatchHParams(batch_size=-1)
        # TODO(wuke): Support other fields.
        self._integer(&#39;batch_size&#39;, default_batch_size, &#39;Batch size&#39;)
        self._integer(&#39;batch_seed&#39;, default_batch_seed, &#39;Batch seed&#39;)

    def get(self):
        return client_datasets.ShuffleRepeatBatchHParams(
            batch_size=self._get_flag(&#39;batch_size&#39;),
            seed=jax.random.PRNGKey(self._get_flag(&#39;batch_seed&#39;))
        )


class FathomFlags(NamedFlags):
    &#34;&#34;&#34;Constructs an SGD optimizer, HyperParams and client ShuffleRepeatBatchHParams from flags.&#34;&#34;&#34;

    def __init__(self, name: Optional[str] = None,
        default_learning_rate: float = 0.1, default_epochs: float = 1.0, default_batch_size: float = 16.0,
        default_alpha: float = 0.5, default_eta_h: float = 1.0,
        default_eta_h012: Tuple[float, float, float] = (0.01, 0.01, 0.1),
        default_ub: Tuple[float, float, float] = (10.0, 0.5, 5096),
    ):
        super().__init__(name)
        self._float(&#39;initial_learning_rate&#39;, default_learning_rate, &#39;Initial learning rate&#39;)
        self._float(&#39;initial_epochs&#39;, default_epochs, &#39;Initial epochs&#39;)
        self._float(&#39;initial_batch_size&#39;, default_batch_size, &#39;Initial batch size&#39;)
        self._float(&#39;alpha&#39;, default_alpha, &#39;Fathom alpha&#39;)
        self._float(&#39;eta_h&#39;, default_eta_h, &#39;Fathom eta_h&#39;)
        # default_eta_h012 holds the defaults for eta_h0, eta_h1 and eta_h2.
        self._float(&#39;eta_h0&#39;, default_eta_h012[0], &#39;Fathom eta_h0&#39;)
        self._float(&#39;eta_h1&#39;, default_eta_h012[1], &#39;Fathom eta_h1&#39;)
        self._float(&#39;eta_h2&#39;, default_eta_h012[2], &#39;Fathom eta_h2&#39;)
        # default_ub holds the upper bounds in hparam_ub order: (Ep, eta_c, bs).
        self._float(&#39;Ep_ub&#39;, default_ub[0], &#39;Fathom Ep_ub&#39;)
        self._float(&#39;eta_c_ub&#39;, default_ub[1], &#39;Fathom eta_c_ub&#39;)
        self._float(&#39;bs_ub&#39;, default_ub[2], &#39;Fathom bs_ub&#39;)

    def get(self) -&gt; Tuple[optimizers.Optimizer, HyperParams, client_datasets.ShuffleRepeatBatchHParams]:
        &#34;&#34;&#34;Returns the SGD optimizer, the initial HyperParams and the client batching hparams.&#34;&#34;&#34;
        # SGD optimizer whose learning rate is the --eta_h flag.
        fathom_opt = fathom.core.optimizers.sgd(learning_rate=self._get_flag(&#39;eta_h&#39;))
        fathom_hparams = HyperParams(
            eta_c=float(self._get_flag(&#39;initial_learning_rate&#39;)),
            Ep=float(self._get_flag(&#39;initial_epochs&#39;)),  # Defaults to 1 epoch&#39;s worth of data.
            bs=float(self._get_flag(&#39;initial_batch_size&#39;)),
            alpha=float(self._get_flag(&#39;alpha&#39;)),
            eta_h=jnp.array([
                self._get_flag(&#39;eta_h0&#39;),
                self._get_flag(&#39;eta_h1&#39;),
                self._get_flag(&#39;eta_h2&#39;),
            ]),
            # Upper bounds, ordered (Ep, eta_c, bs).
            hparam_ub=jnp.array([
                self._get_flag(&#39;Ep_ub&#39;),
                self._get_flag(&#39;eta_c_ub&#39;),
                self._get_flag(&#39;bs_ub&#39;),
            ]),
        )
        sr_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            # The batch size flag is a float; round() converts it to an int.
            batch_size=round(self._get_flag(&#39;initial_batch_size&#39;)),
            # Reads the global, unprefixed --client_batch_seed flag, which this
            # class does not define; the binary must define it elsewhere.
            seed=jax.random.PRNGKey(FLAGS.client_batch_seed),
        )
        return fathom_opt, fathom_hparams, sr_batch_hparams</code></pre>
</details>
</section>
<section>
<h2 class="section-title" id="header-classes">Classes</h2>
<dl>
<dt id="fathom.training.structured_flags.FathomFlags"><code class="flex name class">
<span>class <span class="ident">FathomFlags</span></span>
<span>(</span><span>name: Optional[str] = None, default_learning_rate: float = 0.1, default_epochs: float = 1.0, default_batch_size: float = 16.0, default_alpha: float = 0.5, default_eta_h: float = 1.0, default_eta_h012: Tuple[float, float, float] = (0.01, 0.01, 0.1), default_ub: Tuple[float, float, float] = (10.0, 0.5, 5096))</span>
</code></dt>
<dd>
<div class="desc"><p>Constructs HyperParams and a fathom.optimizer from flags.</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">class FathomFlags(NamedFlags):
    &#34;&#34;&#34;Constructs an SGD optimizer, HyperParams and client ShuffleRepeatBatchHParams from flags.&#34;&#34;&#34;

    def __init__(self, name: Optional[str] = None,
        default_learning_rate: float = 0.1, default_epochs: float = 1.0, default_batch_size: float = 16.0,
        default_alpha: float = 0.5, default_eta_h: float = 1.0,
        default_eta_h012: Tuple[float, float, float] = (0.01, 0.01, 0.1),
        default_ub: Tuple[float, float, float] = (10.0, 0.5, 5096),
    ):
        super().__init__(name)
        self._float(&#39;initial_learning_rate&#39;, default_learning_rate, &#39;Initial learning rate&#39;)
        self._float(&#39;initial_epochs&#39;, default_epochs, &#39;Initial epochs&#39;)
        self._float(&#39;initial_batch_size&#39;, default_batch_size, &#39;Initial batch size&#39;)
        self._float(&#39;alpha&#39;, default_alpha, &#39;Fathom alpha&#39;)
        self._float(&#39;eta_h&#39;, default_eta_h, &#39;Fathom eta_h&#39;)
        # default_eta_h012 holds the defaults for eta_h0, eta_h1 and eta_h2.
        self._float(&#39;eta_h0&#39;, default_eta_h012[0], &#39;Fathom eta_h0&#39;)
        self._float(&#39;eta_h1&#39;, default_eta_h012[1], &#39;Fathom eta_h1&#39;)
        self._float(&#39;eta_h2&#39;, default_eta_h012[2], &#39;Fathom eta_h2&#39;)
        # default_ub holds the upper bounds in hparam_ub order: (Ep, eta_c, bs).
        self._float(&#39;Ep_ub&#39;, default_ub[0], &#39;Fathom Ep_ub&#39;)
        self._float(&#39;eta_c_ub&#39;, default_ub[1], &#39;Fathom eta_c_ub&#39;)
        self._float(&#39;bs_ub&#39;, default_ub[2], &#39;Fathom bs_ub&#39;)

    def get(self) -&gt; Tuple[optimizers.Optimizer, HyperParams, client_datasets.ShuffleRepeatBatchHParams]:
        &#34;&#34;&#34;Returns the SGD optimizer, the initial HyperParams and the client batching hparams.&#34;&#34;&#34;
        # SGD optimizer whose learning rate is the --eta_h flag.
        fathom_opt = fathom.core.optimizers.sgd(learning_rate=self._get_flag(&#39;eta_h&#39;))
        fathom_hparams = HyperParams(
            eta_c=float(self._get_flag(&#39;initial_learning_rate&#39;)),
            Ep=float(self._get_flag(&#39;initial_epochs&#39;)),  # Defaults to 1 epoch&#39;s worth of data.
            bs=float(self._get_flag(&#39;initial_batch_size&#39;)),
            alpha=float(self._get_flag(&#39;alpha&#39;)),
            eta_h=jnp.array([
                self._get_flag(&#39;eta_h0&#39;),
                self._get_flag(&#39;eta_h1&#39;),
                self._get_flag(&#39;eta_h2&#39;),
            ]),
            # Upper bounds, ordered (Ep, eta_c, bs).
            hparam_ub=jnp.array([
                self._get_flag(&#39;Ep_ub&#39;),
                self._get_flag(&#39;eta_c_ub&#39;),
                self._get_flag(&#39;bs_ub&#39;),
            ]),
        )
        sr_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            # The batch size flag is a float; round() converts it to an int.
            batch_size=round(self._get_flag(&#39;initial_batch_size&#39;)),
            # Reads the global, unprefixed --client_batch_seed flag, which this
            # class does not define; the binary must define it elsewhere.
            seed=jax.random.PRNGKey(FLAGS.client_batch_seed),
        )
        return fathom_opt, fathom_hparams, sr_batch_hparams</code></pre>
</details>
<h3>Ancestors</h3>
<ul class="hlist">
<li>fedjax.training.structured_flags.NamedFlags</li>
</ul>
<h3>Methods</h3>
<dl>
<dt id="fathom.training.structured_flags.FathomFlags.get"><code class="name flex">
<span>def <span class="ident">get</span></span>(<span>self) ‑> Tuple[fedjax.core.optimizers.Optimizer, <a title="fathom.algorithms.fathom_fedavg.HyperParams" href="../algorithms/fathom_fedavg.html#fathom.algorithms.fathom_fedavg.HyperParams">HyperParams</a>, fedjax.core.client_datasets.ShuffleRepeatBatchHParams]</span>
</code></dt>
<dd>
<div class="desc"></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">def get(self) -&gt; Tuple[optimizers.Optimizer, HyperParams, client_datasets.ShuffleRepeatBatchHParams]:
    fathom_opt = fathom.core.optimizers.sgd(learning_rate = self._get_flag(&#39;eta_h&#39;))
    fathom_hparams = HyperParams(
        eta_c = float(self._get_flag(&#39;initial_learning_rate&#39;)),
        Ep = float(self._get_flag(&#39;initial_epochs&#39;)), # Initialize with 1 epoch&#39;s worth of data
        bs = float(self._get_flag(&#39;initial_batch_size&#39;)),
        alpha = float(self._get_flag(&#39;alpha&#39;)),
        eta_h = jnp.array([
            self._get_flag(&#39;eta_h0&#39;), 
            self._get_flag(&#39;eta_h1&#39;), 
            self._get_flag(&#39;eta_h2&#39;)
        ]),
        hparam_ub = jnp.array([
            self._get_flag(&#39;Ep_ub&#39;),
            self._get_flag(&#39;eta_c_ub&#39;),
            self._get_flag(&#39;bs_ub&#39;),
        ]),
    )
    SRBatchHParams = client_datasets.ShuffleRepeatBatchHParams(
        batch_size = round(self._get_flag(&#39;initial_batch_size&#39;)),
        seed = jax.random.PRNGKey(FLAGS.client_batch_seed),
    )
    return fathom_opt, fathom_hparams, SRBatchHParams</code></pre>
</details>
</dd>
</dl>
</dd>
<dt id="fathom.training.structured_flags.OptimizerFlags"><code class="flex name class">
<span>class <span class="ident">OptimizerFlags</span></span>
<span>(</span><span>name: Optional[str] = None, default_optimizer: str = 'sgd')</span>
</code></dt>
<dd>
<div class="desc"><p>Constructs a fathom.core.Optimizer from flags.</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">class OptimizerFlags(fedjax.training.structured_flags.OptimizerFlags):
    &#34;&#34;&#34;Constructs a fathom.core.Optimizer from flags.&#34;&#34;&#34;
    &#34;&#34;&#34;Note: OptimizerFlags is being re-defined because the one from FedJax uses fedjax.core.optimizers,&#34;&#34;&#34;
    &#34;&#34;&#34;but we need to instantiate with fathom.core.optimizers.&#34;&#34;&#34;

    def get(self) -&gt; optimizers.Optimizer:
        &#34;&#34;&#34;Gets the specified optimizer.&#34;&#34;&#34;
        optimizer_name = self._get_flag(&#39;optimizer&#39;)
        learning_rate = self._get_flag(&#39;learning_rate&#39;)
        if optimizer_name == &#39;sgd&#39;:
            return fathom.core.optimizers.sgd(learning_rate)
        elif optimizer_name == &#39;momentum&#39;:
            return fathom.core.optimizers.sgd(learning_rate, self._get_flag(&#39;momentum&#39;))
        elif optimizer_name == &#39;adam&#39;:
            return fathom.core.optimizers.adam(learning_rate, self._get_flag(&#39;adam_beta1&#39;),
                self._get_flag(&#39;adam_beta2&#39;),
                self._get_flag(&#39;adam_epsilon&#39;),
            )
        elif optimizer_name == &#39;rmsprop&#39;:
            return fathom.core.optimizers.rmsprop(learning_rate, 
                self._get_flag(&#39;rmsprop_decay&#39;),
                self._get_flag(&#39;rmsprop_epsilon&#39;),
            )
        elif optimizer_name == &#39;adagrad&#39;:
            return fathom.core.optimizers.adagrad(learning_rate, eps=self._get_flag(&#39;adagrad_epsilon&#39;))
        elif optimizer_name == &#39;yogi&#39;:
            return fathom.core.optimizers.yogi(learning_rate, 
                self._get_flag(&#39;yogi_beta1&#39;),
                self._get_flag(&#39;yogi_beta2&#39;),
                self._get_flag(&#39;yogi_epsilon&#39;)
            )
        else:
            raise ValueError(f&#39;Unsupported optimizer {optimizer_name!r} from &#39;
                f&#39;--{self._prefix}optimizer.&#39;)</code></pre>
</details>
<h3>Ancestors</h3>
<ul class="hlist">
<li>fedjax.training.structured_flags.OptimizerFlags</li>
<li>fedjax.training.structured_flags.NamedFlags</li>
</ul>
<h3>Methods</h3>
<dl>
<dt id="fathom.training.structured_flags.OptimizerFlags.get"><code class="name flex">
<span>def <span class="ident">get</span></span>(<span>self) ‑> fedjax.core.optimizers.Optimizer</span>
</code></dt>
<dd>
<div class="desc"><p>Gets the specified optimizer.</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">def get(self) -&gt; optimizers.Optimizer:
    &#34;&#34;&#34;Gets the specified optimizer.&#34;&#34;&#34;
    optimizer_name = self._get_flag(&#39;optimizer&#39;)
    learning_rate = self._get_flag(&#39;learning_rate&#39;)
    if optimizer_name == &#39;sgd&#39;:
        return fathom.core.optimizers.sgd(learning_rate)
    elif optimizer_name == &#39;momentum&#39;:
        return fathom.core.optimizers.sgd(learning_rate, self._get_flag(&#39;momentum&#39;))
    elif optimizer_name == &#39;adam&#39;:
        return fathom.core.optimizers.adam(learning_rate, self._get_flag(&#39;adam_beta1&#39;),
            self._get_flag(&#39;adam_beta2&#39;),
            self._get_flag(&#39;adam_epsilon&#39;),
        )
    elif optimizer_name == &#39;rmsprop&#39;:
        return fathom.core.optimizers.rmsprop(learning_rate, 
            self._get_flag(&#39;rmsprop_decay&#39;),
            self._get_flag(&#39;rmsprop_epsilon&#39;),
        )
    elif optimizer_name == &#39;adagrad&#39;:
        return fathom.core.optimizers.adagrad(learning_rate, eps=self._get_flag(&#39;adagrad_epsilon&#39;))
    elif optimizer_name == &#39;yogi&#39;:
        return fathom.core.optimizers.yogi(learning_rate, 
            self._get_flag(&#39;yogi_beta1&#39;),
            self._get_flag(&#39;yogi_beta2&#39;),
            self._get_flag(&#39;yogi_epsilon&#39;)
        )
    else:
        raise ValueError(f&#39;Unsupported optimizer {optimizer_name!r} from &#39;
            f&#39;--{self._prefix}optimizer.&#39;)</code></pre>
</details>
</dd>
</dl>
</dd>
<dt id="fathom.training.structured_flags.ShuffleRepeatBatchHParamsFlags"><code class="flex name class">
<span>class <span class="ident">ShuffleRepeatBatchHParamsFlags</span></span>
<span>(</span><span>name: Optional[str] = None, default_batch_size: int = 128, default_batch_seed: int = 123)</span>
</code></dt>
<dd>
<div class="desc"><p>Constructs ShuffleRepeatBatchHParams from flags.</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">class ShuffleRepeatBatchHParamsFlags(NamedFlags):
    &#34;&#34;&#34;Constructs ShuffleRepeatBatchHParams from flags.&#34;&#34;&#34;

    def __init__(self, name: Optional[str] = None, default_batch_size: int = 128, default_batch_seed: int = 123):
        super().__init__(name)
        defaults = client_datasets.ShuffleRepeatBatchHParams(batch_size=-1)
        # TODO(wuke): Support other fields.
        self._integer(&#39;batch_size&#39;, default_batch_size, &#39;Batch size&#39;)
        self._integer(&#39;batch_seed&#39;, default_batch_seed, &#39;Batch seed&#39;)

    def get(self):
        return client_datasets.ShuffleRepeatBatchHParams(
            batch_size=self._get_flag(&#39;batch_size&#39;),
            seed=jax.random.PRNGKey(self._get_flag(&#39;batch_seed&#39;))
        )</code></pre>
</details>
<h3>Ancestors</h3>
<ul class="hlist">
<li>fedjax.training.structured_flags.NamedFlags</li>
</ul>
<h3>Methods</h3>
<dl>
<dt id="fathom.training.structured_flags.ShuffleRepeatBatchHParamsFlags.get"><code class="name flex">
<span>def <span class="ident">get</span></span>(<span>self)</span>
</code></dt>
<dd>
<div class="desc"></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">def get(self):
    return client_datasets.ShuffleRepeatBatchHParams(
        batch_size=self._get_flag(&#39;batch_size&#39;),
        seed=jax.random.PRNGKey(self._get_flag(&#39;batch_seed&#39;))
    )</code></pre>
</details>
</dd>
</dl>
</dd>
</dl>
</section>
</article>
<nav id="sidebar">
<h1>Index</h1>
<div class="toc">
<ul></ul>
</div>
<ul id="index">
<li><h3>Super-module</h3>
<ul>
<li><code><a title="fathom.training" href="index.html">fathom.training</a></code></li>
</ul>
</li>
<li><h3><a href="#header-classes">Classes</a></h3>
<ul>
<li>
<h4><code><a title="fathom.training.structured_flags.FathomFlags" href="#fathom.training.structured_flags.FathomFlags">FathomFlags</a></code></h4>
<ul class="">
<li><code><a title="fathom.training.structured_flags.FathomFlags.get" href="#fathom.training.structured_flags.FathomFlags.get">get</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="fathom.training.structured_flags.OptimizerFlags" href="#fathom.training.structured_flags.OptimizerFlags">OptimizerFlags</a></code></h4>
<ul class="">
<li><code><a title="fathom.training.structured_flags.OptimizerFlags.get" href="#fathom.training.structured_flags.OptimizerFlags.get">get</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="fathom.training.structured_flags.ShuffleRepeatBatchHParamsFlags" href="#fathom.training.structured_flags.ShuffleRepeatBatchHParamsFlags">ShuffleRepeatBatchHParamsFlags</a></code></h4>
<ul class="">
<li><code><a title="fathom.training.structured_flags.ShuffleRepeatBatchHParamsFlags.get" href="#fathom.training.structured_flags.ShuffleRepeatBatchHParamsFlags.get">get</a></code></li>
</ul>
</li>
</ul>
</li>
</ul>
</nav>
</main>
<footer id="footer">
<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.10.0</a>.</p>
</footer>
</body>
</html>