<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1" />
<meta name="generator" content="pdoc 0.10.0" />
<title>fathom.algorithms.fathom_fedavg API documentation</title>
<meta name="description" content="" />
<link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/sanitize.min.css" integrity="sha256-PK9q560IAAa6WVRRh76LtCaI8pjTJ2z11v0miyNNjrs=" crossorigin>
<link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/typography.min.css" integrity="sha256-7l/o7C8jubJiy74VsKTidCy1yBkRtiUGbVkYBylBqUg=" crossorigin>
<link rel="stylesheet preload" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/styles/github.min.css" crossorigin>
<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:30px;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:1em 0 .50em 0}h3{font-size:1.4em;margin:25px 0 10px 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .3s ease-in-out}a:hover{color:#e82}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900}pre code{background:#f8f8f8;font-size:.8em;line-height:1.4em}code{background:#f2f2f1;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{background:#f8f8f8;border:0;border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0;padding:1ex}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-weight:bold;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target .name{background:var(--highlight-color)}.name > 
span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em .5em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.item .name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul{padding-left:1.5em}.toc > ul > li{margin-top:.5em}}</style>
<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/highlight.min.js" integrity="sha256-Uv3H6lx7dJmRfRvH8TH6kJD1TSK1aFcwgx+mdg3epi8=" crossorigin></script>
<script>window.addEventListener('DOMContentLoaded', () => hljs.initHighlighting())</script>
</head>
<body>
<main>
<article id="content">
<header>
<h1 class="title">Module <code>fathom.algorithms.fathom_fedavg</code></h1>
</header>
<section id="section-intro">
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python"># Copyright 2022 FATHOM Authors
#
# Licensed under the Apache License, Version 2.0 (the &#34;License&#34;);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an &#34;AS IS&#34; BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Mapping, Sequence, Tuple, Union, Optional

from fedjax.core import client_datasets
from fedjax.core import dataclasses
from fedjax.core import federated_algorithm
from fedjax.core import federated_data
from fedjax.core import for_each_client
from fedjax.core import optimizers
from fedjax.core import tree_util
from fedjax.core.typing import BatchExample
from fedjax.core.typing import Params, PyTree
from fedjax.core.typing import PRNGKey
from fedjax.core.typing import OptState, BatchExample
import jax
import jax.numpy as jnp
import fathom

Grads = Params  # Gradient pytrees are typed the same as the model Params.

def create_train_for_each_client(
        grad_fn: Callable[[Params, BatchExample, PRNGKey], Grads],
        client_optimizer: optimizers.Optimizer):
    &#34;&#34;&#34;Builds client_init, client_step, client_final for for_each_client.

    Each client starts from shared_input[&#39;params&#39;] and runs client_optimizer
    with its learning rate overwritten by shared_input[&#39;eta_c&#39;]. After every
    local step it computes the cosine similarity between the step&#39;s gradient
    and the accumulated local update (params0 - params), using 0 when either
    norm is 0, and keeps the running minimum as &#39;min_hypergrad&#39;.

    The returned function produces, for every client, the tuple
    (delta_params, min_hypergrad, num_steps), where delta_params is the server
    params minus the client&#39;s final params.
    &#34;&#34;&#34;
    # 1: report hypergrads as cosine similarities; 0: report raw dot products.
    normalize_hypergrad = 1

    def client_init(shared_input, client_rng):
        opt_state = client_optimizer.init(shared_input[&#39;params&#39;])
        # shared_input[&#39;eta_c&#39;] is used as-is as the local learning rate.
        opt_state.hyperparams[&#39;learning_rate&#39;] = shared_input[&#39;eta_c&#39;]
        client_step_state = {
            &#39;params&#39;: shared_input[&#39;params&#39;],
            &#39;params0&#39;: shared_input[&#39;params&#39;],  # Start of this round&#39;s local update.
            &#39;opt_state&#39;: opt_state,
            &#39;rng&#39;: client_rng,
            &#39;step_idx&#39;: 0,
            # Initial upper bound for the running minimum. Normalized hypergrads
            # are cosine similarities (at most 1), so the bound is kept &gt;= 1.
            &#39;min_hypergrad&#39;: jnp.maximum(tree_util.tree_l2_norm(shared_input[&#39;params&#39;]), 1.),
            }
        return client_step_state

    def client_step(client_step_state, batch):
        # Keep the 3-way split so the RNG stream is unchanged; the third key is unused.
        rng, use_rng, _ = jax.random.split(client_step_state[&#39;rng&#39;], num=3)
        grad_opt = grad_fn(client_step_state[&#39;params&#39;], batch, use_rng)
        opt_state, params = client_optimizer.apply(grad_opt, client_step_state[&#39;opt_state&#39;], client_step_state[&#39;params&#39;])
        # Accumulated local update since the start of the round.
        delta_params = jax.tree_util.tree_map(jnp.subtract, client_step_state[&#39;params0&#39;], params)
        grad_opt_norm = tree_util.tree_l2_norm(grad_opt)
        delta_params_norm = tree_util.tree_l2_norm(delta_params)
        grad_dot_delta = fathom.core.tree_util.tree_dot(grad_opt, delta_params)
        # jnp.where evaluates every branch, so the division may yield inf/nan when a
        # norm is 0; the outer conditions select 0 in exactly those cases.
        hypergrad = jnp.where(grad_opt_norm == 0., jnp.array(0),
            jnp.where(delta_params_norm == 0, jnp.array(0),
                jnp.where(normalize_hypergrad == 0,
                    grad_dot_delta,
                    grad_dot_delta / grad_opt_norm / delta_params_norm)
            )
        )
        min_hypergrad = jnp.where(hypergrad &gt; client_step_state[&#39;min_hypergrad&#39;],
            client_step_state[&#39;min_hypergrad&#39;], hypergrad
        )
        next_client_step_state = {
            &#39;params&#39;: params,
            &#39;params0&#39;: client_step_state[&#39;params0&#39;],
            &#39;opt_state&#39;: opt_state,
            &#39;rng&#39;: rng,
            &#39;step_idx&#39;: client_step_state[&#39;step_idx&#39;] + 1,
            &#39;min_hypergrad&#39;: min_hypergrad,
        }
        return next_client_step_state

    def client_final(shared_input, client_step_state) -&gt; Tuple[Params, jnp.ndarray, jnp.ndarray]:
        # Server params minus final client params, i.e. this client&#39;s update.
        delta_params = jax.tree_util.tree_map(jnp.subtract, shared_input[&#39;params&#39;], client_step_state[&#39;params&#39;])
        return delta_params, client_step_state[&#39;min_hypergrad&#39;], client_step_state[&#39;step_idx&#39;]

    return for_each_client.for_each_client(client_init, client_step, client_final)


@dataclasses.dataclass
class HyperParams:
    &#34;&#34;&#34;FATHOM hyperparameters; Ep, eta_c and bs are adapted every round.&#34;&#34;&#34;
    eta_c: float                # Local learning rate or step size
    Ep: float                   # Local epochs: for bs &gt; 0, local steps = max(ceil(Ep * local_samples / bs), 1).
    bs: float                   # Local batch size, rounded to the nearest int (at least 1); bs &lt;= 0 requests full batch.
    alpha: float                # Momentum for glob grad estimation: weight kept on the previous estimate.
    eta_h: jnp.ndarray          # Hyper optimizer learning rates for Ep, eta_c, and bs, respectively
    hparam_ub: jnp.ndarray      # Upperbound vals for Ep, eta_c, and bs, respectively.


@dataclasses.dataclass
class HyperState:
    &#34;&#34;&#34;FATHOM hyper-optimization state carried between rounds.&#34;&#34;&#34;
    hyperparams: HyperParams        # Current hyperparameter values (linear scale).
    init_hparams: HyperParams       # Hyperparameters the state was created with.
    opt_state: optimizers.OptState  # hyper_optimizer state for opt_param.
    opt_param: jnp.ndarray          # log([Ep, eta_c, bs]), the values hyper_optimizer updates.
    hypergrad_glob: float           # Most recent normalized global hypergradient.
    hypergrad_local: float          # Most recent example-weighted mean of client min hypergradients.


@dataclasses.dataclass
class ServerState:
    &#34;&#34;&#34;State of server passed between rounds.

    Attributes:
        params: A pytree representing the server model parameters.
        params_bak: The parameters the state was created with; server_update
            carries it forward unchanged rather than updating it.
        opt_state: A pytree representing the server optimizer state.
        round_index: Round counter; server_reset sets it to 1 and every
            server_update increments it.
        hyper_state: FATHOM hyperparameters and hyper-optimizer state.
        grad_glob: Exponential moving average of the mean client update, used
            as the global gradient estimate (weight alpha on the old value).
    &#34;&#34;&#34;
    params: Params
    params_bak: Params
    opt_state: optimizers.OptState
    round_index: int
    hyper_state: HyperState
    grad_glob: Params


@jax.jit
def estimate_grad_glob(server_state: ServerState, mean_delta_params: Params) -&gt; Params:
    &#34;&#34;&#34;Updates the moving-average estimate of the global gradient.

    Returns alpha * server_state.grad_glob + (1 - alpha) * mean_delta_params,
    where alpha is server_state.hyper_state.hyperparams.alpha. No bias
    correction by (1 - alpha ** round_index) is applied.
    &#34;&#34;&#34;
    alpha = server_state.hyper_state.hyperparams.alpha
    # Blend leaf by leaf: the old estimate keeps weight alpha, the new mean
    # client update gets weight 1 - alpha.
    return jax.tree_util.tree_map(
        lambda previous, latest: previous * alpha + latest * (1. - alpha),
        server_state.grad_glob,
        mean_delta_params,
    )


def federated_averaging(
        grad_fn: Callable[[Params, BatchExample, PRNGKey], Grads],
        client_optimizer: optimizers.Optimizer,
        server_optimizer: optimizers.Optimizer,
        hyper_optimizer: optimizers.Optimizer,
        client_batch_hparams: client_datasets.ShuffleRepeatBatchHParams,
        fathom_init_hparams: HyperParams,
) -&gt; federated_algorithm.FederatedAlgorithm:
    &#34;&#34;&#34;Builds federated averaging with FATHOM hyperparameter adaptation.

    In addition to the usual FedAvg server update, every round adapts the
    local epochs (Ep), local learning rate (eta_c) and local batch size (bs)
    by applying hyper_optimizer to their log-scale values.

    Args:
        grad_fn: A function from (params, batch_example, rng) to gradients.
            This can be created with :func:`fedjax.core.model.model_grad`.
        client_optimizer: Optimizer for local client training. Its learning
            rate is overwritten every round with the current eta_c.
        server_optimizer: Optimizer for server update.
        hyper_optimizer: Optimizer applied to log([Ep, eta_c, bs]). The
            per-hyperparameter rates HyperParams.eta_h are multiplied into the
            hypergradient, which assumes an SGD-like hyper_optimizer.
        client_batch_hparams: Hyperparameters for batching client dataset for
            train. Only drop_remainder, seed and skip_shuffle are used; batch
            size and number of steps come from the current bs and Ep.
        fathom_init_hparams: Initial FATHOM hyperparameters.

    Returns:
        FederatedAlgorithm
    &#34;&#34;&#34;
    train_for_each_client = create_train_for_each_client(grad_fn, client_optimizer)

    def server_reset(params: Params, init_hparams: HyperParams) -&gt; ServerState:
        &#34;&#34;&#34;Creates a fresh ServerState from params and init_hparams.&#34;&#34;&#34;
        opt_state_server = server_optimizer.init(params)
        # The hyper optimizer works on the log of [Ep, eta_c, bs].
        opt_param_hyper = jnp.log(jnp.array([init_hparams.Ep, init_hparams.eta_c, init_hparams.bs]))
        opt_state_hyper = hyper_optimizer.init(opt_param_hyper)
        hyper_state = HyperState(
            hyperparams = init_hparams,
            init_hparams = init_hparams,
            opt_state = opt_state_hyper,
            opt_param = opt_param_hyper,
            hypergrad_glob = 0.,
            hypergrad_local = 0.,
        )
        # Need to initialize round_index to 1 for bias comp
        return ServerState(
            params = params,
            params_bak = params,
            opt_state = opt_state_server,
            round_index = 1,
            grad_glob = tree_util.tree_zeros_like(params),
            hyper_state = hyper_state,
        )

    def init(params: Params) -&gt; ServerState:
        return server_reset(params, fathom_init_hparams)

    def local_batch_hparams(num_examples: int, Ep: float, bs: int) -&gt; client_datasets.ShuffleRepeatBatchHParams:
        &#34;&#34;&#34;Returns the batching hparams for a client holding num_examples examples.

        bs &lt;= 0 selects full-batch training. The number of local steps is
        ceil(Ep * num_examples / batch_size) for the effective batch size, and
        at least 1, so full-batch training runs about Ep steps (rounded up).
        &#34;&#34;&#34;
        # max(..., 1) keeps the batch size positive for an empty client dataset.
        batch_size = bs if bs &gt; 0 else max(num_examples, 1)
        num_steps = max(int(jnp.ceil(Ep * num_examples / batch_size)), 1)
        return client_datasets.ShuffleRepeatBatchHParams(
            batch_size = batch_size,
            num_steps = num_steps,
            num_epochs = None, # This is required.  See ShuffleRepeatBatchView implementation in fedjax.core.client_datasets.py.
            drop_remainder = client_batch_hparams.drop_remainder,
            seed = client_batch_hparams.seed,
            skip_shuffle = client_batch_hparams.skip_shuffle,
        )

    def apply(
        server_state: ServerState,
        clients: Sequence[Tuple[
            federated_data.ClientId,
            client_datasets.ClientDataset,
            PRNGKey
        ]],
    ) -&gt; Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
        &#34;&#34;&#34;Runs one round: local training on clients, then server_update.&#34;&#34;&#34;
        client_num_examples = {cid: len(cds) for cid, cds, _ in clients}
        hyperparams = server_state.hyper_state.hyperparams
        Ep: float = hyperparams.Ep
        # Round bs to the nearest integer (at least 1); bs &lt;= 0 requests full batch (-1).
        bs: int = max(int(hyperparams.bs + 0.5), 1) if hyperparams.bs &gt; 0 else -1
        eta_c: float = hyperparams.eta_c
        batch_clients = [
            (cid, cds.shuffle_repeat_batch(local_batch_hparams(client_num_examples[cid], Ep, bs)), crng)
            for cid, cds, crng in clients
        ]
        shared_input = {&#39;params&#39;: server_state.params, &#39;eta_c&#39;: eta_c, &#39;Ep&#39;: Ep, &#39;bs&#39;: bs}
        client_diagnostics = {}
        # Running weighted mean of client updates. We do this iteratively to avoid
        # loading all the client outputs into memory since they can be prohibitively
        # large depending on the model parameters size.
        delta_params_sum = tree_util.tree_zeros_like(server_state.params)
        hypergrad_local_sum = jnp.array(0.0)
        num_examples_sum = 0.
        for client_id, (delta_params, min_hypergrad, _num_steps) in train_for_each_client(shared_input, batch_clients):
            # Weight each client&#39;s update and hypergradient by its example count.
            num_examples = client_num_examples[client_id]
            weighted_delta_params = tree_util.tree_weight(delta_params, num_examples)
            delta_params_sum = tree_util.tree_add(delta_params_sum, weighted_delta_params)
            hypergrad_local_sum = hypergrad_local_sum + min_hypergrad * num_examples
            num_examples_sum += num_examples
            # We record the l2 norm of client updates as an example, but it is not
            # required for the algorithm.
            client_diagnostics[client_id] = {
                &#39;delta_l2_norm&#39;: tree_util.tree_l2_norm(delta_params),
            }
        mean_delta_params = fathom.core.tree_util.tree_inverse_weight(delta_params_sum, num_examples_sum)
        mean_hypergrad_local = hypergrad_local_sum / num_examples_sum
        server_state = server_update(
            server_state = server_state,
            mean_delta_params = mean_delta_params,
            hypergrad_local = mean_hypergrad_local,
        )
        return server_state, client_diagnostics

    def server_update(
        server_state: ServerState,
        mean_delta_params: Params,
        hypergrad_local: jnp.ndarray,
    ) -&gt; ServerState:
        &#34;&#34;&#34;Applies the server optimizer, the grad_glob EMA and the hyper update.&#34;&#34;&#34;
        opt_state_server, params = server_optimizer.apply(
            mean_delta_params,
            server_state.opt_state,
            server_state.params,
        )
        grad_glob: Params = estimate_grad_glob(server_state, mean_delta_params)
        hyper_state: HyperState = hyper_update(
            server_state = server_state,
            params = params,
            delta_params = mean_delta_params,
            hypergrad_local = hypergrad_local,
        )
        return ServerState(
            params = params,
            params_bak = server_state.params_bak, # Keep initial params rather than update
            opt_state = opt_state_server,
            round_index = server_state.round_index + 1,
            hyper_state = hyper_state,
            grad_glob = grad_glob,
        )

    def hyper_update(
        server_state: ServerState,
        params: Params,
        delta_params: Params,
        hypergrad_local: jnp.ndarray,
    ) -&gt; HyperState:
        &#34;&#34;&#34;Takes one hyper_optimizer step on log([Ep, eta_c, bs]).

        params is currently unused.
        &#34;&#34;&#34;
        opt_param, opt_state = server_state.hyper_state.opt_param, server_state.hyper_state.opt_state
        # Do not use the most current grad_glob as the result will bias positive
        hypergrad_glob: float = fathom.core.tree_util.tree_dot(server_state.grad_glob, delta_params)
        # Normalizing hypergrad_global here
        # hypergrad_local already normalized from local calculations
        grad_glob_norm = tree_util.tree_l2_norm(server_state.grad_glob)
        delta_params_norm = tree_util.tree_l2_norm(delta_params)
        # jnp.logical_and (not Python `and`) keeps this valid for traced arrays.
        hypergrad_glob = jnp.where(
            jnp.logical_and(grad_glob_norm &gt; 0., delta_params_norm &gt; 0.),
            hypergrad_glob / grad_glob_norm / delta_params_norm,
            0
        )
        hypergrad = - jnp.array([
            hypergrad_glob + hypergrad_local,   # Ep
            hypergrad_glob,                     # eta_c
            -hypergrad_local,                   # bs
        ])
        # This is where individual learning rates are applied, assuming opt is SGD.
        # With any other opt, individual learning rates need to be set at opt instantiation.
        hypergrad = hypergrad * server_state.hyper_state.hyperparams.eta_h

        # EGN gradients are already normalized
        opt_state, opt_param = hyper_optimizer.apply(hypergrad, opt_state, opt_param)
        hparams_vals = jnp.exp(opt_param) # Convert back to linear from log scale
        # Cap at the upper bounds; same as jnp.clip with only a_max, without the
        # a_max keyword that newer JAX versions deprecate.
        hparams_vals = jnp.minimum(hparams_vals, server_state.hyper_state.hyperparams.hparam_ub)

        Ep, eta_c, bs = hparams_vals[0], hparams_vals[1], hparams_vals[2]
        hyperparams = HyperParams(
            Ep = Ep,
            eta_c = eta_c,
            bs = bs,
            eta_h = server_state.hyper_state.hyperparams.eta_h,
            alpha = server_state.hyper_state.hyperparams.alpha,
            hparam_ub = server_state.hyper_state.hyperparams.hparam_ub,
        )
        hyper_state = HyperState(
            hyperparams = hyperparams,
            init_hparams = server_state.hyper_state.init_hparams,
            opt_state = opt_state,
            opt_param = opt_param,
            hypergrad_glob = hypergrad_glob,
            hypergrad_local = hypergrad_local,
        )
        return hyper_state

    return federated_algorithm.FederatedAlgorithm(init, apply)</code></pre>
</details>
</section>
<section>
</section>
<section>
</section>
<section>
<h2 class="section-title" id="header-functions">Functions</h2>
<dl>
<dt id="fathom.algorithms.fathom_fedavg.create_train_for_each_client"><code class="name flex">
<span>def <span class="ident">create_train_for_each_client</span></span>(<span>grad_fn: Callable[[Any, Mapping[str, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray], Any], client_optimizer: fedjax.core.optimizers.Optimizer)</span>
</code></dt>
<dd>
<div class="desc"><p>Builds client_init, client_step, client_final for for_each_client.</p>
<p>Each client starts from shared_input['params'] and runs client_optimizer
with its learning rate overwritten by shared_input['eta_c']. After every
local step it computes the cosine similarity between the step's gradient
and the accumulated local update (params0 - params), using 0 when either
norm is 0, and keeps the running minimum as 'min_hypergrad'.</p>
<p>The returned function produces, for every client, the tuple
(delta_params, min_hypergrad, num_steps), where delta_params is the server
params minus the client's final params.</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">def create_train_for_each_client(
        grad_fn: Callable[[Params, BatchExample, PRNGKey], Grads],
        client_optimizer: optimizers.Optimizer):
    &#34;&#34;&#34;Builds client_init, client_step, client_final for for_each_client.

    Each client starts from shared_input[&#39;params&#39;] and runs client_optimizer
    with its learning rate overwritten by shared_input[&#39;eta_c&#39;]. After every
    local step it computes the cosine similarity between the step&#39;s gradient
    and the accumulated local update (params0 - params), using 0 when either
    norm is 0, and keeps the running minimum as &#39;min_hypergrad&#39;.

    The returned function produces, for every client, the tuple
    (delta_params, min_hypergrad, num_steps), where delta_params is the server
    params minus the client&#39;s final params.
    &#34;&#34;&#34;
    # 1: report hypergrads as cosine similarities; 0: report raw dot products.
    normalize_hypergrad = 1

    def client_init(shared_input, client_rng):
        opt_state = client_optimizer.init(shared_input[&#39;params&#39;])
        # shared_input[&#39;eta_c&#39;] is used as-is as the local learning rate.
        opt_state.hyperparams[&#39;learning_rate&#39;] = shared_input[&#39;eta_c&#39;]
        client_step_state = {
            &#39;params&#39;: shared_input[&#39;params&#39;],
            &#39;params0&#39;: shared_input[&#39;params&#39;],  # Start of this round&#39;s local update.
            &#39;opt_state&#39;: opt_state,
            &#39;rng&#39;: client_rng,
            &#39;step_idx&#39;: 0,
            # Initial upper bound for the running minimum. Normalized hypergrads
            # are cosine similarities (at most 1), so the bound is kept &gt;= 1.
            &#39;min_hypergrad&#39;: jnp.maximum(tree_util.tree_l2_norm(shared_input[&#39;params&#39;]), 1.),
            }
        return client_step_state

    def client_step(client_step_state, batch):
        # Keep the 3-way split so the RNG stream is unchanged; the third key is unused.
        rng, use_rng, _ = jax.random.split(client_step_state[&#39;rng&#39;], num=3)
        grad_opt = grad_fn(client_step_state[&#39;params&#39;], batch, use_rng)
        opt_state, params = client_optimizer.apply(grad_opt, client_step_state[&#39;opt_state&#39;], client_step_state[&#39;params&#39;])
        # Accumulated local update since the start of the round.
        delta_params = jax.tree_util.tree_map(jnp.subtract, client_step_state[&#39;params0&#39;], params)
        grad_opt_norm = tree_util.tree_l2_norm(grad_opt)
        delta_params_norm = tree_util.tree_l2_norm(delta_params)
        grad_dot_delta = fathom.core.tree_util.tree_dot(grad_opt, delta_params)
        # jnp.where evaluates every branch, so the division may yield inf/nan when a
        # norm is 0; the outer conditions select 0 in exactly those cases.
        hypergrad = jnp.where(grad_opt_norm == 0., jnp.array(0),
            jnp.where(delta_params_norm == 0, jnp.array(0),
                jnp.where(normalize_hypergrad == 0,
                    grad_dot_delta,
                    grad_dot_delta / grad_opt_norm / delta_params_norm)
            )
        )
        min_hypergrad = jnp.where(hypergrad &gt; client_step_state[&#39;min_hypergrad&#39;],
            client_step_state[&#39;min_hypergrad&#39;], hypergrad
        )
        next_client_step_state = {
            &#39;params&#39;: params,
            &#39;params0&#39;: client_step_state[&#39;params0&#39;],
            &#39;opt_state&#39;: opt_state,
            &#39;rng&#39;: rng,
            &#39;step_idx&#39;: client_step_state[&#39;step_idx&#39;] + 1,
            &#39;min_hypergrad&#39;: min_hypergrad,
        }
        return next_client_step_state

    def client_final(shared_input, client_step_state) -&gt; Tuple[Params, jnp.ndarray, jnp.ndarray]:
        # Server params minus final client params, i.e. this client&#39;s update.
        delta_params = jax.tree_util.tree_map(jnp.subtract, shared_input[&#39;params&#39;], client_step_state[&#39;params&#39;])
        return delta_params, client_step_state[&#39;min_hypergrad&#39;], client_step_state[&#39;step_idx&#39;]

    return for_each_client.for_each_client(client_init, client_step, client_final)</code></pre>
</details>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.estimate_grad_glob"><code class="name flex">
<span>def <span class="ident">estimate_grad_glob</span></span>(<span>server_state: <a title="fathom.algorithms.fathom_fedavg.ServerState" href="#fathom.algorithms.fathom_fedavg.ServerState">ServerState</a>, mean_delta_params: Any) ‑> Any</span>
</code></dt>
<dd>
<div class="desc"><p>Updates the moving-average estimate of the global gradient.</p>
<p>Returns <code>alpha * server_state.grad_glob + (1 - alpha) * mean_delta_params</code>,
where <code>alpha</code> is <code>server_state.hyper_state.hyperparams.alpha</code>. No bias
correction by <code>(1 - alpha ** round_index)</code> is applied.</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">@jax.jit
def estimate_grad_glob(server_state: ServerState, mean_delta_params: Params) -&gt; Params:
    &#34;&#34;&#34;Updates the moving-average estimate of the global gradient.

    Returns alpha * server_state.grad_glob + (1 - alpha) * mean_delta_params,
    where alpha is server_state.hyper_state.hyperparams.alpha. No bias
    correction by (1 - alpha ** round_index) is applied.
    &#34;&#34;&#34;
    alpha = server_state.hyper_state.hyperparams.alpha
    # Blend leaf by leaf: the old estimate keeps weight alpha, the new mean
    # client update gets weight 1 - alpha.
    return jax.tree_util.tree_map(
        lambda previous, latest: previous * alpha + latest * (1. - alpha),
        server_state.grad_glob,
        mean_delta_params,
    )</code></pre>
</details>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.federated_averaging"><code class="name flex">
<span>def <span class="ident">federated_averaging</span></span>(<span>grad_fn: Callable[[Any, Mapping[str, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray], Any], client_optimizer: fedjax.core.optimizers.Optimizer, server_optimizer: fedjax.core.optimizers.Optimizer, hyper_optimizer: fedjax.core.optimizers.Optimizer, client_batch_hparams: fedjax.core.client_datasets.ShuffleRepeatBatchHParams, fathom_init_hparams: <a title="fathom.algorithms.fathom_fedavg.HyperParams" href="#fathom.algorithms.fathom_fedavg.HyperParams">HyperParams</a>) ‑> fedjax.core.federated_algorithm.FederatedAlgorithm</span>
</code></dt>
<dd>
<div class="desc"><p>Builds federated averaging with FATHOM hyperparameter adaptation.</p>
<h2 id="args">Args</h2>
<dl>
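<dt><strong><code>grad_fn</code></strong></dt>
<dd>A function from (params, batch_example, rng) to gradients.
This can be created with <code>fedjax.core.model.model_grad</code>.</dd>
<dt><strong><code>client_optimizer</code></strong></dt>
<dd>Optimizer for local client training.</dd>
<dt><strong><code>server_optimizer</code></strong></dt>
<dd>Optimizer for server update.</dd>
<dt><strong><code>hyper_optimizer</code></strong></dt>
<dd>Optimizer applied each round to the log-scale values of <code>Ep</code>, <code>eta_c</code> and <code>bs</code>.</dd>
<dt><strong><code>client_batch_hparams</code></strong></dt>
<dd>Hyperparameters for batching client dataset for train. Only <code>drop_remainder</code>, <code>seed</code> and <code>skip_shuffle</code> are used; batch size and number of steps come from the adapted <code>bs</code> and <code>Ep</code>.</dd>
<dt><strong><code>fathom_init_hparams</code></strong></dt>
<dd>Initial FATHOM hyperparameters.</dd>
</dl>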
<h2 id="returns">Returns</h2>
<p>FederatedAlgorithm</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">def federated_averaging(
        grad_fn: Callable[[Params, BatchExample, PRNGKey], Grads],
        client_optimizer: optimizers.Optimizer,
        server_optimizer: optimizers.Optimizer,
        hyper_optimizer: optimizers.Optimizer,
        client_batch_hparams: client_datasets.ShuffleRepeatBatchHParams,
        fathom_init_hparams: HyperParams,
) -&gt; federated_algorithm.FederatedAlgorithm:
    &#34;&#34;&#34;Builds federated averaging.

    Args:
        grad_fn: A function from (params, batch_example, rng) to gradients.
            This can be created with :func:`fedjax.core.model.model_grad`.
        client_optimizer: Optimizer for local client training.
        server_optimizer: Optimizer for server update.
        client_batch_hparams: Hyperparameters for batching client dataset for train.

    Returns:
        FederatedAlgorithm
    &#34;&#34;&#34;
    train_for_each_client = create_train_for_each_client(grad_fn, client_optimizer)

    def server_reset(params: Params, init_hparams: HyperParams) -&gt; ServerState:
        opt_state_server = server_optimizer.init(params)
        opt_param_hyper = jnp.log(jnp.array([init_hparams.Ep, init_hparams.eta_c, init_hparams.bs]))
        opt_state_hyper = hyper_optimizer.init(opt_param_hyper)
        hyper_state = HyperState(
            hyperparams = init_hparams,
            init_hparams = init_hparams,
            opt_state = opt_state_hyper,
            opt_param = opt_param_hyper,
            hypergrad_glob = 0.,
            hypergrad_local = 0.,
        )
        # Need to initialize round_index to 1 for bias comp
        return ServerState(
            params = params, 
            params_bak = params,
            opt_state = opt_state_server, 
            round_index = 1, 
            grad_glob = tree_util.tree_zeros_like(params),
            hyper_state = hyper_state,
        )        

    def init(params: Params) -&gt; ServerState:
        return server_reset(params, fathom_init_hparams)

    def apply(
        server_state: ServerState,
        clients: Sequence[Tuple[
            federated_data.ClientId, 
            client_datasets.ClientDataset, 
            PRNGKey
        ]],
    ) -&gt; Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
        client_num_examples = {cid: len(cds) for cid, cds, _ in clients}
        Ep: float = server_state.hyper_state.hyperparams.Ep
        bs: int = max(int(server_state.hyper_state.hyperparams.bs + 0.5), 1) if server_state.hyper_state.hyperparams.bs &gt; 0 else -1
        eta_c: float = server_state.hyper_state.hyperparams.eta_c
        batch_clients = [(cid, cds.shuffle_repeat_batch(
            client_datasets.ShuffleRepeatBatchHParams(
                batch_size = bs if bs &gt; 0 else len(cds), 
                num_steps = max(int(jnp.ceil(Ep * client_num_examples[cid] / bs)), 1),
                num_epochs = None, # This is required.  See ShuffleRepeatBatchView implementation in fedjax.core.client_datasets.py.
                drop_remainder = client_batch_hparams.drop_remainder,
                seed = client_batch_hparams.seed,
                skip_shuffle = client_batch_hparams.skip_shuffle,
            )
        ), crng) for cid, cds, crng in clients]
        shared_input = {&#39;params&#39;: server_state.params, &#39;eta_c&#39;: eta_c, &#39;Ep&#39;: Ep, &#39;bs&#39;: bs}
        client_diagnostics = {}
        # Running weighted mean of client updates. We do this iteratively to avoid
        # loading all the client outputs into memory since they can be prohibitively
        # large depending on the model parameters size.
        delta_params_sum = tree_util.tree_zeros_like(server_state.params)
        hypergrad_local_sum = jnp.array(0.0)
        num_examples_sum = 0.
        for client_id, (delta_params, min_hypergrad, num_steps) in train_for_each_client(shared_input, batch_clients):
            # Server collecting stats before sending them to server_update for updating params and metrics.
            num_examples = client_num_examples[client_id]
            weighted_delta_params =  tree_util.tree_weight(delta_params, num_examples)
            delta_params_sum = tree_util.tree_add(delta_params_sum, weighted_delta_params)
            hypergrad_local_sum = hypergrad_local_sum + min_hypergrad * num_examples
            num_examples_sum += num_examples
            # We record the l2 norm of client updates as an example, but it is not
            # required for the algorithm.
            client_diagnostics[client_id] = {
                &#39;delta_l2_norm&#39;: tree_util.tree_l2_norm(delta_params),
            }
        mean_delta_params = fathom.core.tree_util.tree_inverse_weight(delta_params_sum, num_examples_sum)
        mean_hypergrad_local = hypergrad_local_sum / num_examples_sum
        server_state = server_update(
            server_state = server_state, 
            mean_delta_params = mean_delta_params, 
            hypergrad_local = mean_hypergrad_local, 
        )
        return server_state, client_diagnostics

    def server_update(
        server_state: ServerState, 
        mean_delta_params: Params, 
        hypergrad_local: jnp.ndarray, 
    ) -&gt; ServerState:
        opt_state_server, params = server_optimizer.apply(
            mean_delta_params, 
            server_state.opt_state, 
            server_state.params,
        )
        grad_glob: Params = estimate_grad_glob(server_state, mean_delta_params)
        hyper_state: HyperState = hyper_update(
            server_state = server_state, 
            params = params, 
            delta_params = mean_delta_params,
            hypergrad_local = hypergrad_local,
        )
        return ServerState(
            params = params,
            params_bak = server_state.params_bak, # Keep initial params rather than update
            opt_state = opt_state_server,
            round_index = server_state.round_index + 1,
            hyper_state = hyper_state,
            grad_glob = grad_glob,
        )

    def hyper_update(
        server_state: ServerState,
        params: Params,
        delta_params: Params,
        hypergrad_local: jnp.ndarray,
    ) -&gt; HyperState:

        opt_param, opt_state = server_state.hyper_state.opt_param, server_state.hyper_state.opt_state
        # Do not use the most current grad_glob as the result will bias positive
        hypergrad_glob: float = fathom.core.tree_util.tree_dot(server_state.grad_glob, delta_params)
        # Normalizing hypergrad_global here
        # hypergrad_local already normalized from local calculations
        grad_glob_norm = tree_util.tree_l2_norm(server_state.grad_glob)
        delta_params_norm = tree_util.tree_l2_norm(delta_params)
        hypergrad_glob = jnp.where((grad_glob_norm &gt; 0. and delta_params_norm &gt; 0.),
            hypergrad_glob / grad_glob_norm / delta_params_norm,
            0
        )
        hypergrad = - jnp.array([
            hypergrad_glob + hypergrad_local,   # Ep
            hypergrad_glob,                     # eta_c
            -hypergrad_local,                   # bs
        ]) 
        # This is where individual learning rates are applied, assuming opt is SGD.
        # With any other opt, individual learning rates need to be set at opt instantiation.
        hypergrad = hypergrad * server_state.hyper_state.hyperparams.eta_h

        # EGN gradients are already normalized
        opt_state, opt_param = hyper_optimizer.apply(hypergrad, opt_state, opt_param)
        hparams_vals = jnp.exp(opt_param) # Convert back to linear from log scale
        hparams_vals = jnp.clip(hparams_vals, a_max = server_state.hyper_state.hyperparams.hparam_ub)

        Ep, eta_c, bs = hparams_vals[0], hparams_vals[1], hparams_vals[2]
        hyperparams = HyperParams(
            Ep = Ep,
            eta_c = eta_c,
            bs = bs,
            eta_h = server_state.hyper_state.hyperparams.eta_h,
            alpha = server_state.hyper_state.hyperparams.alpha, 
            hparam_ub = server_state.hyper_state.hyperparams.hparam_ub, 
        )
        hyper_state = HyperState(
            hyperparams = hyperparams,
            init_hparams = server_state.hyper_state.init_hparams,
            opt_state = opt_state,
            opt_param = opt_param,
            hypergrad_glob = hypergrad_glob,
            hypergrad_local = hypergrad_local,
        )        
        return hyper_state

    return federated_algorithm.FederatedAlgorithm(init, apply)</code></pre>
</details>
</dd>
</dl>
</section>
<section>
<h2 class="section-title" id="header-classes">Classes</h2>
<dl>
<dt id="fathom.algorithms.fathom_fedavg.HyperParams"><code class="flex name class">
<span>class <span class="ident">HyperParams</span></span>
<span>(</span><span>eta_c: float, Ep: float, bs: float, alpha: float, eta_h: jax._src.numpy.lax_numpy.ndarray, hparam_ub: jax._src.numpy.lax_numpy.ndarray)</span>
</code></dt>
<dd>
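<div class="desc"><p>Hyperparameters used by the FATHOM server. Each round, the server updates <code>Ep</code>, <code>eta_c</code> and <code>bs</code>; <code>alpha</code>, <code>eta_h</code> and <code>hparam_ub</code> are carried over unchanged.</p>
<dl>
<dt><strong><code>eta_c</code></strong></dt>
<dd>Local (client) learning rate, i.e. step size.</dd>
<dt><strong><code>Ep</code></strong></dt>
<dd>Roughly the number of epochs' worth of local data processed per round. The number of local steps is <code>ceil(local_samples / batch_size) * Ep</code>.</dd>
<dt><strong><code>bs</code></strong></dt>
<dd>Local batch size.</dd>
<dt><strong><code>alpha</code></strong></dt>
<dd>Momentum for the global gradient estimate.</dd>
<dt><strong><code>eta_h</code></strong></dt>
<dd>Hyper-optimizer learning rates for <code>Ep</code>, <code>eta_c</code> and <code>bs</code>, respectively. They are multiplied into the hypergradient before the hyper-optimizer step.</dd>
<dt><strong><code>hparam_ub</code></strong></dt>
<dd>Upper bounds for <code>Ep</code>, <code>eta_c</code> and <code>bs</code>, respectively. They are applied after the values are converted back from log scale.</dd>
</dl></div>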
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">class HyperParams:
    eta_c: float                # Local learning rate or step size
    Ep: float                  # Number of local steps = ceil(local_samples / batch_size) * Ep, where Ep ~ num epochs worth of data.
    bs: float                   # Local batch size
    alpha: float                # Momentum for glob grad estimation
    eta_h: jnp.ndarray          # Hyper optimizer learning rates for Ep, eta_c, and bs, respectively
    hparam_ub: jnp.ndarray      # Upperbound vals for Ep, eta_c, and bs, respectively.</code></pre>
</details>
<h3>Class variables</h3>
<dl>
<dt id="fathom.algorithms.fathom_fedavg.HyperParams.Ep"><code class="name">var <span class="ident">Ep</span> : float</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.HyperParams.alpha"><code class="name">var <span class="ident">alpha</span> : float</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.HyperParams.bs"><code class="name">var <span class="ident">bs</span> : float</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.HyperParams.eta_c"><code class="name">var <span class="ident">eta_c</span> : float</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.HyperParams.eta_h"><code class="name">var <span class="ident">eta_h</span> : jax._src.numpy.lax_numpy.ndarray</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.HyperParams.hparam_ub"><code class="name">var <span class="ident">hparam_ub</span> : jax._src.numpy.lax_numpy.ndarray</code></dt>
<dd>
<div class="desc"></div>
</dd>
</dl>
<h3>Methods</h3>
<dl>
<dt id="fathom.algorithms.fathom_fedavg.HyperParams.replace"><code class="name flex">
<span>def <span class="ident">replace</span></span>(<span>self, **updates)</span>
</code></dt>
<dd>
<div class="desc"><p>Returns a new object replacing the specified fields with new values.</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">def replace(self, **updates):
  &#34;&#34;&#34;&#34;Returns a new object replacing the specified fields with new values.&#34;&#34;&#34;
  return dataclasses.replace(self, **updates)</code></pre>
</details>
</dd>
</dl>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.HyperState"><code class="flex name class">
<span>class <span class="ident">HyperState</span></span>
<span>(</span><span>hyperparams: <a title="fathom.algorithms.fathom_fedavg.HyperParams" href="#fathom.algorithms.fathom_fedavg.HyperParams">HyperParams</a>, init_hparams: <a title="fathom.algorithms.fathom_fedavg.HyperParams" href="#fathom.algorithms.fathom_fedavg.HyperParams">HyperParams</a>, opt_state: Union[jax._src.numpy.lax_numpy.ndarray, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]], opt_param: jax._src.numpy.lax_numpy.ndarray, hypergrad_glob: float, hypergrad_local: float)</span>
</code></dt>
<dd>
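<div class="desc"><p>Hyperparameter-optimization state kept by the server between rounds.</p>
<dl>
<dt><strong><code>hyperparams</code></strong></dt>
<dd>The current <a title="fathom.algorithms.fathom_fedavg.HyperParams" href="#fathom.algorithms.fathom_fedavg.HyperParams">HyperParams</a> values.</dd>
<dt><strong><code>init_hparams</code></strong></dt>
<dd>The initial <a title="fathom.algorithms.fathom_fedavg.HyperParams" href="#fathom.algorithms.fathom_fedavg.HyperParams">HyperParams</a>. They are carried over unchanged on every update.</dd>
<dt><strong><code>opt_state</code></strong></dt>
<dd>State of the hyper-optimizer.</dd>
<dt><strong><code>opt_param</code></strong></dt>
<dd>Log-scale values of <code>Ep</code>, <code>eta_c</code> and <code>bs</code> (in that order) that the hyper-optimizer updates.</dd>
<dt><strong><code>hypergrad_glob</code></strong></dt>
<dd>Normalized global hypergradient computed in the most recent server update.</dd>
<dt><strong><code>hypergrad_local</code></strong></dt>
<dd>Mean of the clients' local hypergradients from the most recent round, weighted by number of examples.</dd>
</dl></div>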
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">class HyperState:
    hyperparams: HyperParams 
    init_hparams: HyperParams
    opt_state: optimizers.OptState
    opt_param: jnp.ndarray
    hypergrad_glob: float
    hypergrad_local: float</code></pre>
</details>
<h3>Class variables</h3>
<dl>
<dt id="fathom.algorithms.fathom_fedavg.HyperState.hypergrad_glob"><code class="name">var <span class="ident">hypergrad_glob</span> : float</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.HyperState.hypergrad_local"><code class="name">var <span class="ident">hypergrad_local</span> : float</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.HyperState.hyperparams"><code class="name">var <span class="ident">hyperparams</span> : <a title="fathom.algorithms.fathom_fedavg.HyperParams" href="#fathom.algorithms.fathom_fedavg.HyperParams">HyperParams</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.HyperState.init_hparams"><code class="name">var <span class="ident">init_hparams</span> : <a title="fathom.algorithms.fathom_fedavg.HyperParams" href="#fathom.algorithms.fathom_fedavg.HyperParams">HyperParams</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.HyperState.opt_param"><code class="name">var <span class="ident">opt_param</span> : jax._src.numpy.lax_numpy.ndarray</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.HyperState.opt_state"><code class="name">var <span class="ident">opt_state</span> : Union[jax._src.numpy.lax_numpy.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]</code></dt>
<dd>
<div class="desc"></div>
</dd>
</dl>
<h3>Methods</h3>
<dl>
<dt id="fathom.algorithms.fathom_fedavg.HyperState.replace"><code class="name flex">
<span>def <span class="ident">replace</span></span>(<span>self, **updates)</span>
</code></dt>
<dd>
<div class="desc"><p>Returns a new object replacing the specified fields with new values.</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">def replace(self, **updates):
  &#34;&#34;&#34;&#34;Returns a new object replacing the specified fields with new values.&#34;&#34;&#34;
  return dataclasses.replace(self, **updates)</code></pre>
</details>
</dd>
</dl>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.ServerState"><code class="flex name class">
<span>class <span class="ident">ServerState</span></span>
<span>(</span><span>params: Any, params_bak: Any, opt_state: Union[jax._src.numpy.lax_numpy.ndarray, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]], round_index: int, hyper_state: <a title="fathom.algorithms.fathom_fedavg.HyperState" href="#fathom.algorithms.fathom_fedavg.HyperState">HyperState</a>, grad_glob: Any)</span>
</code></dt>
<dd>
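<div class="desc"><p>State of server passed between rounds.</p>
<h2 id="attributes">Attributes</h2>
<dl>
<dt><strong><code>params</code></strong></dt>
<dd>A pytree representing the server model parameters.</dd>
<dt><strong><code>params_bak</code></strong></dt>
<dd>A pytree holding a backup of the initial model parameters. The server update carries it over unchanged.</dd>
<dt><strong><code>opt_state</code></strong></dt>
<dd>A pytree representing the server optimizer state.</dd>
<dt><strong><code>round_index</code></strong></dt>
<dd>Index of the current round. It is incremented by one on every server update.</dd>
<dt><strong><code>hyper_state</code></strong></dt>
<dd>The <a title="fathom.algorithms.fathom_fedavg.HyperState" href="#fathom.algorithms.fathom_fedavg.HyperState">HyperState</a> used to adapt the local hyperparameters between rounds.</dd>
<dt><strong><code>grad_glob</code></strong></dt>
<dd>A pytree estimate of the global gradient. It is refreshed every round by <code><a title="fathom.algorithms.fathom_fedavg.estimate_grad_glob" href="#fathom.algorithms.fathom_fedavg.estimate_grad_glob">estimate_grad_glob()</a></code>.</dd>
</dl></div>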
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">class ServerState:
    &#34;&#34;&#34;State of server passed between rounds.

    Attributes:
        params: A pytree representing the server model parameters.
        opt_state: A pytree representing the server optimizer state.
    &#34;&#34;&#34;
    params: Params
    params_bak: Params
    opt_state: optimizers.OptState
    round_index: int
    hyper_state: HyperState
    grad_glob: Params</code></pre>
</details>
<h3>Class variables</h3>
<dl>
<dt id="fathom.algorithms.fathom_fedavg.ServerState.grad_glob"><code class="name">var <span class="ident">grad_glob</span> : Any</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.ServerState.hyper_state"><code class="name">var <span class="ident">hyper_state</span> : <a title="fathom.algorithms.fathom_fedavg.HyperState" href="#fathom.algorithms.fathom_fedavg.HyperState">HyperState</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.ServerState.opt_state"><code class="name">var <span class="ident">opt_state</span> : Union[jax._src.numpy.lax_numpy.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.ServerState.params"><code class="name">var <span class="ident">params</span> : Any</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.ServerState.params_bak"><code class="name">var <span class="ident">params_bak</span> : Any</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="fathom.algorithms.fathom_fedavg.ServerState.round_index"><code class="name">var <span class="ident">round_index</span> : int</code></dt>
<dd>
<div class="desc"></div>
</dd>
</dl>
<h3>Methods</h3>
<dl>
<dt id="fathom.algorithms.fathom_fedavg.ServerState.replace"><code class="name flex">
<span>def <span class="ident">replace</span></span>(<span>self, **updates)</span>
</code></dt>
<dd>
<div class="desc"><p>Returns a new object replacing the specified fields with new values.</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">def replace(self, **updates):
  &#34;&#34;&#34;&#34;Returns a new object replacing the specified fields with new values.&#34;&#34;&#34;
  return dataclasses.replace(self, **updates)</code></pre>
</details>
</dd>
</dl>
</dd>
</dl>
</section>
</article>
<nav id="sidebar" aria-label="Module index">
<h1>Index</h1>
<div class="toc">
<ul></ul>
</div>
<ul id="index">
<li><h3>Super-module</h3>
<ul>
<li><code><a title="fathom.algorithms" href="index.html">fathom.algorithms</a></code></li>
</ul>
</li>
<li><h3><a href="#header-functions">Functions</a></h3>
<ul class="">
<li><code><a title="fathom.algorithms.fathom_fedavg.create_train_for_each_client" href="#fathom.algorithms.fathom_fedavg.create_train_for_each_client">create_train_for_each_client</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.estimate_grad_glob" href="#fathom.algorithms.fathom_fedavg.estimate_grad_glob">estimate_grad_glob</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.federated_averaging" href="#fathom.algorithms.fathom_fedavg.federated_averaging">federated_averaging</a></code></li>
</ul>
</li>
<li><h3><a href="#header-classes">Classes</a></h3>
<ul>
<li>
<h4><code><a title="fathom.algorithms.fathom_fedavg.HyperParams" href="#fathom.algorithms.fathom_fedavg.HyperParams">HyperParams</a></code></h4>
<ul class="two-column">
<li><code><a title="fathom.algorithms.fathom_fedavg.HyperParams.Ep" href="#fathom.algorithms.fathom_fedavg.HyperParams.Ep">Ep</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.HyperParams.alpha" href="#fathom.algorithms.fathom_fedavg.HyperParams.alpha">alpha</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.HyperParams.bs" href="#fathom.algorithms.fathom_fedavg.HyperParams.bs">bs</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.HyperParams.eta_c" href="#fathom.algorithms.fathom_fedavg.HyperParams.eta_c">eta_c</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.HyperParams.eta_h" href="#fathom.algorithms.fathom_fedavg.HyperParams.eta_h">eta_h</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.HyperParams.hparam_ub" href="#fathom.algorithms.fathom_fedavg.HyperParams.hparam_ub">hparam_ub</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.HyperParams.replace" href="#fathom.algorithms.fathom_fedavg.HyperParams.replace">replace</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="fathom.algorithms.fathom_fedavg.HyperState" href="#fathom.algorithms.fathom_fedavg.HyperState">HyperState</a></code></h4>
<ul class="two-column">
<li><code><a title="fathom.algorithms.fathom_fedavg.HyperState.hypergrad_glob" href="#fathom.algorithms.fathom_fedavg.HyperState.hypergrad_glob">hypergrad_glob</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.HyperState.hypergrad_local" href="#fathom.algorithms.fathom_fedavg.HyperState.hypergrad_local">hypergrad_local</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.HyperState.hyperparams" href="#fathom.algorithms.fathom_fedavg.HyperState.hyperparams">hyperparams</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.HyperState.init_hparams" href="#fathom.algorithms.fathom_fedavg.HyperState.init_hparams">init_hparams</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.HyperState.opt_param" href="#fathom.algorithms.fathom_fedavg.HyperState.opt_param">opt_param</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.HyperState.opt_state" href="#fathom.algorithms.fathom_fedavg.HyperState.opt_state">opt_state</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.HyperState.replace" href="#fathom.algorithms.fathom_fedavg.HyperState.replace">replace</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="fathom.algorithms.fathom_fedavg.ServerState" href="#fathom.algorithms.fathom_fedavg.ServerState">ServerState</a></code></h4>
<ul class="two-column">
<li><code><a title="fathom.algorithms.fathom_fedavg.ServerState.grad_glob" href="#fathom.algorithms.fathom_fedavg.ServerState.grad_glob">grad_glob</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.ServerState.hyper_state" href="#fathom.algorithms.fathom_fedavg.ServerState.hyper_state">hyper_state</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.ServerState.opt_state" href="#fathom.algorithms.fathom_fedavg.ServerState.opt_state">opt_state</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.ServerState.params" href="#fathom.algorithms.fathom_fedavg.ServerState.params">params</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.ServerState.params_bak" href="#fathom.algorithms.fathom_fedavg.ServerState.params_bak">params_bak</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.ServerState.replace" href="#fathom.algorithms.fathom_fedavg.ServerState.replace">replace</a></code></li>
<li><code><a title="fathom.algorithms.fathom_fedavg.ServerState.round_index" href="#fathom.algorithms.fathom_fedavg.ServerState.round_index">round_index</a></code></li>
</ul>
</li>
</ul>
</li>
</ul>
</nav>
</main>
<footer id="footer">
<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.10.0</a>.</p>
</footer>
</body>
</html>