<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <meta charset="utf-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>FishLeg.fishleg &mdash; FishLeg  documentation</title>
      <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
      <link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
  <!--[if lt IE 9]>
    <script src="../../_static/js/html5shiv.min.js"></script>
  <![endif]-->
  
        <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
        <script src="../../_static/jquery.js"></script>
        <script src="../../_static/underscore.js"></script>
        <script src="../../_static/_sphinx_javascript_frameworks_compat.js"></script>
        <script src="../../_static/doctools.js"></script>
        <script src="../../_static/sphinx_highlight.js"></script>
    <script src="../../_static/js/theme.js"></script>
    <link rel="index" title="Index" href="../../genindex.html" />
    <link rel="search" title="Search" href="../../search.html" /> 
</head>

<body class="wy-body-for-nav"> 
  <div class="wy-grid-for-nav">
    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search" >
            <a href="../../index.html" class="icon icon-home"> FishLeg
          </a>
<div role="search">
  <!-- Sidebar search: sends the query as "q" to the Sphinx search page via GET. -->
  <form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
    <!-- The placeholder is not an accessible name; aria-label provides one. -->
    <input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
    <input type="hidden" name="check_keywords" value="yes" />
    <input type="hidden" name="area" value="default" />
  </form>
</div>
        </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
              <p class="caption" role="heading"><span class="caption-text">Contents:</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../modules.html">optim</a></li>
</ul>

        </div>
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
          <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
          <a href="../../index.html">FishLeg</a>
      </nav>

      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="navigation" aria-label="Page navigation">
  <ul class="wy-breadcrumbs">
      <!-- Icon-only home link: aria-label supplies the accessible name the glyph lacks. -->
      <li><a href="../../index.html" class="icon icon-home" aria-label="Home"></a></li>
          <li class="breadcrumb-item"><a href="../index.html">Module code</a></li>
      <li class="breadcrumb-item active">FishLeg.fishleg</li>
      <li class="wy-breadcrumbs-aside">
      </li>
  </ul>
  <hr/>
</div>
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
             
  <h1>Source code for FishLeg.fishleg</h1><div class="highlight"><pre>
<span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="kn">import</span> <span class="nn">copy</span>
<span class="kn">from</span> <span class="nn">torch.optim</span> <span class="kn">import</span> <span class="n">Optimizer</span><span class="p">,</span> <span class="n">Adam</span>

<span class="kn">from</span> <span class="nn">.fishleg_layers</span> <span class="kn">import</span> <span class="n">FISH_LAYERS</span>


<div class="viewcode-block" id="FishLeg"><a class="viewcode-back" href="../../FishLeg.html#FishLeg.fishleg.FishLeg">[docs]</a><span class="k">class</span> <span class="nc">FishLeg</span><span class="p">(</span><span class="n">Optimizer</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Implement FishLeg algorithm.</span>

<span class="sd">    :param torch.nn.Module model: a pytorch neural network module, </span>
<span class="sd">                can be nested in a tree structure</span>
<span class="sd">    :param float lr: learning rate,</span>
<span class="sd">                for the parameters of the input model using FishLeg (default: 1e-2)</span>
<span class="sd">    :param float eps: a small scalar, to evaluate the auxiliary loss </span>
<span class="sd">                in the direction of gradient of model parameters (default: 1e-4)</span>
<span class="sd">    :param int aux_K: number of samples to evaluate the entropy (default: 5)</span>

<span class="sd">    :param int update_aux_every: number of iterations after which an auxiliary </span>
<span class="sd">                update is executed, if negative, then run -update_aux_every auxiliary </span>
<span class="sd">                updates in each outer iteration. (default: -3)</span>
<span class="sd">    :param float aux_lr: learning rate for the auxiliary parameters, </span>
<span class="sd">                using Adam (default: 1e-3)</span>
<span class="sd">    :param Tuple[float, float] aux_betas: coefficients used for computing</span>
<span class="sd">                running averages of gradient and its square for auxiliary parameters</span>
<span class="sd">                (default: (0.9, 0.999))</span>
<span class="sd">    :param float aux_eps: term added to the denominator to improve</span>
<span class="sd">                numerical stability for auxiliary parameters (default: 1e-8)</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span>
        <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-2</span><span class="p">,</span>
        <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-4</span><span class="p">,</span>
        <span class="n">aux_K</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span>
        <span class="n">update_aux_every</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">3</span><span class="p">,</span>
        <span class="n">aux_lr</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-3</span><span class="p">,</span>
        <span class="n">aux_betas</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">),</span>
        <span class="n">aux_eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-8</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">plus_model</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">minus_model</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_model_aux</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>

        <span class="c1"># partition by modules</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">aux_param</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">param</span>
            <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">param</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">()</span>
            <span class="k">if</span> <span class="s2">&quot;fishleg_aux&quot;</span> <span class="ow">in</span> <span class="n">name</span>
        <span class="p">]</span>

        <span class="n">param_groups</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="k">for</span> <span class="n">module_name</span><span class="p">,</span> <span class="n">module</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">named_modules</span><span class="p">():</span>
            <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="s2">&quot;fishleg_aux&quot;</span><span class="p">):</span>
                <span class="n">params</span> <span class="o">=</span> <span class="p">{</span>
                    <span class="n">name</span><span class="p">:</span> <span class="n">param</span>
                    <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">param</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span>
                        <span class="n">module_name</span>
                    <span class="p">]</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">()</span>
                    <span class="k">if</span> <span class="s2">&quot;fishleg_aux&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">name</span>
                <span class="p">}</span>
                <span class="n">g</span> <span class="o">=</span> <span class="p">{</span>
                    <span class="s2">&quot;params&quot;</span><span class="p">:</span> <span class="p">[</span><span class="n">params</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">order</span><span class="p">],</span>
                    <span class="s2">&quot;aux_params&quot;</span><span class="p">:</span> <span class="p">{</span>
                        <span class="n">name</span><span class="p">:</span> <span class="n">param</span>
                        <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">()</span>
                        <span class="k">if</span> <span class="s2">&quot;fishleg_aux&quot;</span> <span class="ow">in</span> <span class="n">name</span>
                    <span class="p">},</span>
                    <span class="s2">&quot;Qv&quot;</span><span class="p">:</span> <span class="n">module</span><span class="o">.</span><span class="n">Qv</span><span class="p">,</span>
                    <span class="s2">&quot;order&quot;</span><span class="p">:</span> <span class="n">module</span><span class="o">.</span><span class="n">order</span><span class="p">,</span>
                    <span class="s2">&quot;name&quot;</span><span class="p">:</span> <span class="n">module_name</span><span class="p">,</span>
                <span class="p">}</span>
                <span class="n">param_groups</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">g</span><span class="p">)</span>
        <span class="c1"># TODO: add param_group for modules without aux</span>
        <span class="n">defaults</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">)</span>

        <span class="nb">super</span><span class="p">(</span><span class="n">FishLeg</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">param_groups</span><span class="p">,</span> <span class="n">defaults</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">aux_opt</span> <span class="o">=</span> <span class="n">Adam</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">aux_param</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="n">aux_lr</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="n">aux_betas</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">aux_eps</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">aux_K</span> <span class="o">=</span> <span class="n">aux_K</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">update_aux_every</span> <span class="o">=</span> <span class="n">update_aux_every</span> 
        <span class="bp">self</span><span class="o">.</span><span class="n">aux_lr</span> <span class="o">=</span> <span class="n">aux_lr</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">aux_betas</span> <span class="o">=</span> <span class="n">aux_betas</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">aux_eps</span> <span class="o">=</span> <span class="n">aux_eps</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">step_t</span> <span class="o">=</span> <span class="mi">0</span>

<div class="viewcode-block" id="FishLeg.init_model_aux"><a class="viewcode-back" href="../../FishLeg.html#FishLeg.fishleg.FishLeg.init_model_aux">[docs]</a>    <span class="k">def</span> <span class="nf">init_model_aux</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">:</span>
        <span class="sd">&quot;&quot;&quot;Given a model to optimize, parameters can be divided into</span>
<span class="sd">        </span>
<span class="sd">        #. those fixed as pre-trained.</span>
<span class="sd">        #. those required to optimize using FishLeg.</span>

<span class="sd">        Replace modules in the second group with FishLeg modules.</span>

<span class="sd">        Args:</span>
<span class="sd">            model (:class:`torch.nn.Module`, required): </span>
<span class="sd">                A model containing modules to replace with FishLeg modules </span>
<span class="sd">                containing extra functionality related to FishLeg algorithm.</span>
<span class="sd">        Returns:</span>
<span class="sd">            :class:`torch.nn.Module`, the replaced model.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">module</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">named_modules</span><span class="p">():</span>
            <span class="k">try</span><span class="p">:</span>
                <span class="n">replace</span> <span class="o">=</span> <span class="n">FISH_LAYERS</span><span class="p">[</span><span class="nb">type</span><span class="p">(</span><span class="n">module</span><span class="p">)</span><span class="o">.</span><span class="vm">__name__</span><span class="o">.</span><span class="n">lower</span><span class="p">()](</span>
                    <span class="n">module</span><span class="o">.</span><span class="n">in_features</span><span class="p">,</span> 
                    <span class="n">module</span><span class="o">.</span><span class="n">out_features</span><span class="p">,</span> 
                    <span class="n">module</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
                <span class="p">)</span>
                <span class="n">replace</span> <span class="o">=</span> <span class="n">update_dict</span><span class="p">(</span><span class="n">replace</span><span class="p">,</span> <span class="n">module</span><span class="p">)</span>
                <span class="n">model</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">replace</span>
            <span class="k">except</span> <span class="ne">KeyError</span><span class="p">:</span>
                <span class="k">pass</span>

        <span class="c1"># TODO: The above may not be a very &quot;correct&quot; way to do this, so please feel free to change, for example, we may want to check the name is in the fish_layer keys before attempting what is in the try statement.</span>
        <span class="c1"># TODO: Error checking to check that model includes some auxiliary arguments.</span>

        <span class="k">return</span> <span class="n">model</span></div>
    
<div class="viewcode-block" id="FishLeg.update_aux"><a class="viewcode-back" href="../../FishLeg.html#FishLeg.fishleg.FishLeg.update_aux">[docs]</a>    <span class="k">def</span> <span class="nf">update_aux</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
        <span class="sd">&quot;&quot;&quot;Performs a single auxiliary parameter update</span>
<span class="sd">        using Adam, by minimizing the following objective:</span>

<span class="sd">        .. math::</span>
<span class="sd">            nll(model, \\theta + \epsilon Q(\lambda)g) + nll(model, \\theta - \epsilon Q(\lambda)g) - 2\epsilon^2g^T Q(\lambda)g </span>

<span class="sd">        where :math:`\\theta` are the parameters of the model, :math:`\lambda` are the</span>
<span class="sd">        auxiliary parameters.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">aux_opt</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
        <span class="c1"># Draw aux_K samples from the model without tracking gradients.</span>
        <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
            <span class="n">data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">aux_K</span><span class="p">)</span>

        <span class="n">aux_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
        <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
            <span class="n">name</span> <span class="o">=</span> <span class="n">group</span><span class="p">[</span><span class="s2">&quot;name&quot;</span><span class="p">]</span>

            <span class="n">grad</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s2">&quot;params&quot;</span><span class="p">]]</span>
            <span class="n">qg</span> <span class="o">=</span> <span class="n">group</span><span class="p">[</span><span class="s2">&quot;Qv&quot;</span><span class="p">](</span><span class="n">group</span><span class="p">[</span><span class="s2">&quot;aux_params&quot;</span><span class="p">],</span> <span class="n">grad</span><span class="p">)</span>

            <span class="c1"># Shift the plus/minus copies by +eps and -eps times Qv(aux_params, grad).</span>
            <span class="k">for</span> <span class="n">g</span><span class="p">,</span> <span class="n">d_p</span><span class="p">,</span> <span class="n">para_name</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">grad</span><span class="p">,</span> <span class="n">qg</span><span class="p">,</span> <span class="n">group</span><span class="p">[</span><span class="s2">&quot;order&quot;</span><span class="p">]):</span>
                <span class="n">param_plus</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">plus_model</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">_parameters</span><span class="p">[</span><span class="n">para_name</span><span class="p">]</span>
                <span class="n">param_plus</span> <span class="o">=</span> <span class="n">param_plus</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
                <span class="n">param_minus</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">minus_model</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">_parameters</span><span class="p">[</span><span class="n">para_name</span><span class="p">]</span>
                <span class="n">param_minus</span> <span class="o">=</span> <span class="n">param_minus</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>

                <span class="n">param_plus</span><span class="o">.</span><span class="n">add_</span><span class="p">(</span><span class="n">d_p</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span>
                <span class="n">param_minus</span><span class="o">.</span><span class="n">add_</span><span class="p">(</span><span class="n">d_p</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=-</span><span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span>
                <span class="n">aux_loss</span> <span class="o">-=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">g</span> <span class="o">*</span> <span class="n">d_p</span><span class="p">)</span>

        <span class="c1"># Same objective as in the docstring, divided by eps**2.</span>
        <span class="n">h_plus</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">plus_model</span><span class="o">.</span><span class="n">nll</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
        <span class="n">h_minus</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">minus_model</span><span class="o">.</span><span class="n">nll</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
        <span class="n">aux_loss</span> <span class="o">+=</span> <span class="p">(</span><span class="n">h_plus</span> <span class="o">+</span> <span class="n">h_minus</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span>

        <span class="n">aux_loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">aux_opt</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>

        <span class="c1"># Reset the perturbed copies to the current parameters. The module name is</span>
        <span class="c1"># looked up per group, and the values are cloned so the copies (which are</span>
        <span class="c1"># perturbed in place above) never share storage with the live parameters.</span>
        <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
            <span class="n">name</span> <span class="o">=</span> <span class="n">group</span><span class="p">[</span><span class="s2">&quot;name&quot;</span><span class="p">]</span>
            <span class="k">for</span> <span class="n">p</span><span class="p">,</span> <span class="n">para_name</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s2">&quot;params&quot;</span><span class="p">],</span> <span class="n">group</span><span class="p">[</span><span class="s2">&quot;order&quot;</span><span class="p">]):</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">plus_model</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">_parameters</span><span class="p">[</span><span class="n">para_name</span><span class="p">]</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">minus_model</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">_parameters</span><span class="p">[</span><span class="n">para_name</span><span class="p">]</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span></div>

<div class="viewcode-block" id="FishLeg.step"><a class="viewcode-back" href="../../FishLeg.html#FishLeg.fishleg.FishLeg.step">[docs]</a>    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
        <span class="sd">&quot;&quot;&quot;Performs a single optimization step of FishLeg.</span>

<span class="sd">        First runs the auxiliary update(s) scheduled by ``update_aux_every``</span>
<span class="sd">        (once every ``update_aux_every`` steps if positive, ``-update_aux_every``</span>
<span class="sd">        times per step if negative, never if zero), then moves each parameter</span>
<span class="sd">        by ``-lr * Qv(aux_params, grad)`` and copies the new values into</span>
<span class="sd">        ``plus_model`` and ``minus_model``.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">step_t</span> <span class="o">+=</span> <span class="mi">1</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">update_aux_every</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">step_t</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">update_aux_every</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">update_aux</span><span class="p">()</span>
        <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">update_aux_every</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">update_aux_every</span><span class="p">):</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">update_aux</span><span class="p">()</span>

        <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
            <span class="n">lr</span> <span class="o">=</span> <span class="n">group</span><span class="p">[</span><span class="s2">&quot;lr&quot;</span><span class="p">]</span>
            <span class="n">order</span> <span class="o">=</span> <span class="n">group</span><span class="p">[</span><span class="s2">&quot;order&quot;</span><span class="p">]</span>
            <span class="n">name</span> <span class="o">=</span> <span class="n">group</span><span class="p">[</span><span class="s2">&quot;name&quot;</span><span class="p">]</span>

            <span class="k">if</span> <span class="s2">&quot;aux_params&quot;</span> <span class="ow">in</span> <span class="n">group</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
                <span class="n">grad</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s2">&quot;params&quot;</span><span class="p">]]</span>
                <span class="n">qg</span> <span class="o">=</span> <span class="n">group</span><span class="p">[</span><span class="s2">&quot;Qv&quot;</span><span class="p">](</span><span class="n">group</span><span class="p">[</span><span class="s2">&quot;aux_params&quot;</span><span class="p">],</span> <span class="n">grad</span><span class="p">)</span>

                <span class="k">for</span> <span class="n">p</span><span class="p">,</span> <span class="n">d_p</span><span class="p">,</span> <span class="n">para_name</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s2">&quot;params&quot;</span><span class="p">],</span> <span class="n">qg</span><span class="p">,</span> <span class="n">order</span><span class="p">):</span>
                    <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">add_</span><span class="p">(</span><span class="n">d_p</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=-</span><span class="n">lr</span><span class="p">)</span>
                    <span class="c1"># Clone so the plus/minus copies keep their own storage: update_aux</span>
                    <span class="c1"># perturbs them in place, which must not touch the live parameter.</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">plus_model</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">_parameters</span><span class="p">[</span><span class="n">para_name</span><span class="p">]</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">minus_model</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">_parameters</span><span class="p">[</span><span class="n">para_name</span><span class="p">]</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span></div></div>


<div class="viewcode-block" id="update_dict"><a class="viewcode-back" href="../../FishLeg.html#FishLeg.fishleg.update_dict">[docs]</a><span class="k">def</span> <span class="nf">update_dict</span><span class="p">(</span><span class="n">replace</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">:</span>
        <span class="n">replace_dict</span> <span class="o">=</span> <span class="n">replace</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
        <span class="n">pretrained_dict</span> <span class="o">=</span> <span class="p">{</span>
            <span class="n">k</span><span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span><span class="o">.</span><span class="n">items</span><span class="p">()</span> <span class="k">if</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">replace_dict</span>
        <span class="p">}</span>
        <span class="n">replace_dict</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">pretrained_dict</span><span class="p">)</span>
        <span class="n">replace</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">replace_dict</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">replace</span></div>
</pre></div>

           </div>
          </div>
          <footer>

  <hr/>

  <div role="contentinfo">
    <p>&#169; Copyright 2023, MTKResearch.</p>
  </div>

  Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    provided by <a href="https://readthedocs.org">Read the Docs</a>.
   

</footer>
        </div>
      </div>
    </section>
  </div>
  <script>
      // When jQuery signals that the DOM is ready, switch on the theme's navigation.
      jQuery(function enableThemeNavigation() {
          SphinxRtdTheme.Navigation.enable(true);
      });
  </script>

</body>
</html>