
<!DOCTYPE html>

<html lang="en">
  <head>
    <meta charset="utf-8" />
    <title>ReferentialGym.networks package &#8212; ReferentialGym documentation</title>
    <link rel="stylesheet" href="_static/classic.css" type="text/css" />
    <link rel="stylesheet" href="_static/pygments.css" type="text/css" />
    
    <script id="documentation_options" data-url_root="./" src="_static/documentation_options.js"></script>
    <script src="_static/jquery.js"></script>
    <script src="_static/underscore.js"></script>
    <script src="_static/doctools.js"></script>
    <script src="_static/language_data.js"></script>
    
    <link rel="index" title="Index" href="genindex.html" />
    <link rel="search" title="Search" href="search.html" />
    <link rel="next" title="ReferentialGym.utils package" href="ReferentialGym.utils.html" />
    <link rel="prev" title="ReferentialGym.modules package" href="ReferentialGym.modules.html" /> 
  </head><body>
    <div class="related" role="navigation" aria-label="related navigation">
      <h3>Navigation</h3>
      <ul>
        <li class="right" style="margin-right: 10px">
          <a href="genindex.html" title="General Index"
             accesskey="I">index</a></li>
        <li class="right" >
          <a href="py-modindex.html" title="Python Module Index"
             >modules</a> |</li>
        <li class="right" >
          <a href="ReferentialGym.utils.html" title="ReferentialGym.utils package"
             accesskey="N">next</a> |</li>
        <li class="right" >
          <a href="ReferentialGym.modules.html" title="ReferentialGym.modules package"
             accesskey="P">previous</a> |</li>
        <li class="nav-item nav-item-0"><a href="index.html">ReferentialGym documentation</a> &#187;</li>
          <li class="nav-item nav-item-1"><a href="modules.html" >ReferentialGym</a> &#187;</li>
          <li class="nav-item nav-item-2"><a href="ReferentialGym.html" accesskey="U">ReferentialGym package</a> &#187;</li> 
      </ul>
    </div>  

    <div class="document">
      <div class="documentwrapper">
        <div class="bodywrapper">
          <div class="body" role="main">
            
  <div class="section" id="referentialgym-networks-package">
<h1>ReferentialGym.networks package<a class="headerlink" href="#referentialgym-networks-package" title="Permalink to this headline">¶</a></h1>
<div class="section" id="submodules">
<h2>Submodules<a class="headerlink" href="#submodules" title="Permalink to this headline">¶</a></h2>
</div>
<div class="section" id="module-ReferentialGym.networks.autoregressive_networks">
<span id="referentialgym-networks-autoregressive-networks-module"></span><h2>ReferentialGym.networks.autoregressive_networks module<a class="headerlink" href="#module-ReferentialGym.networks.autoregressive_networks" title="Permalink to this headline">¶</a></h2>
<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.Distribution">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">Distribution</code><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.Distribution" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">object</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.Distribution.sample">
<code class="sig-name descname">sample</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.Distribution.sample" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.Distribution.log_prob">
<code class="sig-name descname">log_prob</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">values</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.Distribution.log_prob" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.Bernoulli">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">Bernoulli</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">probs</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.Bernoulli" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <a class="reference internal" href="#ReferentialGym.networks.autoregressive_networks.Distribution" title="ReferentialGym.networks.autoregressive_networks.Distribution"><code class="xref py py-class docutils literal notranslate"><span class="pre">ReferentialGym.networks.autoregressive_networks.Distribution</span></code></a></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.Bernoulli.sample">
<code class="sig-name descname">sample</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.Bernoulli.sample" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.Bernoulli.log_prob">
<code class="sig-name descname">log_prob</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">values</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.Bernoulli.log_prob" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.Normal">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">Normal</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">mean</span></em>, <em class="sig-param"><span class="n">std</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.Normal" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <a class="reference internal" href="#ReferentialGym.networks.autoregressive_networks.Distribution" title="ReferentialGym.networks.autoregressive_networks.Distribution"><code class="xref py py-class docutils literal notranslate"><span class="pre">ReferentialGym.networks.autoregressive_networks.Distribution</span></code></a></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.Normal.sample">
<code class="sig-name descname">sample</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.Normal.sample" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.Normal.log_prob">
<code class="sig-name descname">log_prob</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">value</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.Normal.log_prob" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetEncoder">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">ResNetEncoder</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_shape</span></em>, <em class="sig-param"><span class="n">latent_dim</span><span class="o">=</span><span class="default_value">32</span></em>, <em class="sig-param"><span class="n">pretrained</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">nbr_layer</span><span class="o">=</span><span class="default_value">4</span></em>, <em class="sig-param"><span class="n">use_coordconv</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetEncoder" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <a class="reference internal" href="#ReferentialGym.networks.residual_networks.ModelResNet18" title="ReferentialGym.networks.residual_networks.ModelResNet18"><code class="xref py py-class docutils literal notranslate"><span class="pre">ReferentialGym.networks.residual_networks.ModelResNet18</span></code></a></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetEncoder.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetEncoder.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetEncoder.encode">
<code class="sig-name descname">encode</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetEncoder.encode" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetEncoder.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetEncoder.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetAvgPooledEncoder">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">ResNetAvgPooledEncoder</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_shape</span></em>, <em class="sig-param"><span class="n">latent_dim</span><span class="o">=</span><span class="default_value">32</span></em>, <em class="sig-param"><span class="n">pretrained</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">nbr_layer</span><span class="o">=</span><span class="default_value">4</span></em>, <em class="sig-param"><span class="n">use_coordconv</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetAvgPooledEncoder" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <a class="reference internal" href="#ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled" title="ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled"><code class="xref py py-class docutils literal notranslate"><span class="pre">ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled</span></code></a></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetAvgPooledEncoder.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetAvgPooledEncoder.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetAvgPooledEncoder.encode">
<code class="sig-name descname">encode</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetAvgPooledEncoder.encode" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetAvgPooledEncoder.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetAvgPooledEncoder.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetParallelAttentionEncoder">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">ResNetParallelAttentionEncoder</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_shape</span></em>, <em class="sig-param"><span class="n">latent_dim</span><span class="o">=</span><span class="default_value">10</span></em>, <em class="sig-param"><span class="n">nbr_attention_slot</span><span class="o">=</span><span class="default_value">10</span></em>, <em class="sig-param"><span class="n">pretrained</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">nbr_layer</span><span class="o">=</span><span class="default_value">4</span></em>, <em class="sig-param"><span class="n">use_coordconv</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetParallelAttentionEncoder" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <a class="reference internal" href="#ReferentialGym.networks.residual_networks.ModelResNet18" title="ReferentialGym.networks.residual_networks.ModelResNet18"><code class="xref py py-class docutils literal notranslate"><span class="pre">ReferentialGym.networks.residual_networks.ModelResNet18</span></code></a></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetParallelAttentionEncoder.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetParallelAttentionEncoder.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetParallelAttentionEncoder.encode">
<code class="sig-name descname">encode</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetParallelAttentionEncoder.encode" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetParallelAttentionEncoder.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetParallelAttentionEncoder.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.addXYSfeatures">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">addXYSfeatures</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">nbr_attention_slot</span><span class="o">=</span><span class="default_value">10</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.addXYSfeatures" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.addXYSfeatures.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.addXYSfeatures.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetPHDPAEncoder">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">ResNetPHDPAEncoder</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_shape</span></em>, <em class="sig-param"><span class="n">latent_dim</span><span class="o">=</span><span class="default_value">10</span></em>, <em class="sig-param"><span class="n">nbr_attention_slot</span><span class="o">=</span><span class="default_value">10</span></em>, <em class="sig-param"><span class="n">pretrained</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">nbr_layer</span><span class="o">=</span><span class="default_value">4</span></em>, <em class="sig-param"><span class="n">use_coordconv</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetPHDPAEncoder" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <a class="reference internal" href="#ReferentialGym.networks.residual_networks.ModelResNet18" title="ReferentialGym.networks.residual_networks.ModelResNet18"><code class="xref py py-class docutils literal notranslate"><span class="pre">ReferentialGym.networks.residual_networks.ModelResNet18</span></code></a></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetPHDPAEncoder.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetPHDPAEncoder.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetPHDPAEncoder.encode">
<code class="sig-name descname">encode</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetPHDPAEncoder.encode" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ResNetPHDPAEncoder.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ResNetPHDPAEncoder.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.Decoder">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">Decoder</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">output_shape</span><span class="o">=</span><span class="default_value">[3, 64, 64]</span></em>, <em class="sig-param"><span class="n">net_depth</span><span class="o">=</span><span class="default_value">3</span></em>, <em class="sig-param"><span class="n">latent_dim</span><span class="o">=</span><span class="default_value">32</span></em>, <em class="sig-param"><span class="n">conv_dim</span><span class="o">=</span><span class="default_value">64</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.Decoder" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.Decoder.decode">
<code class="sig-name descname">decode</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">z</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.Decoder.decode" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.Decoder.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">z</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.Decoder.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.BroadcastingDecoder">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">BroadcastingDecoder</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">output_shape</span><span class="o">=</span><span class="default_value">[3, 64, 64]</span></em>, <em class="sig-param"><span class="n">net_depth</span><span class="o">=</span><span class="default_value">3</span></em>, <em class="sig-param"><span class="n">kernel_size</span><span class="o">=</span><span class="default_value">3</span></em>, <em class="sig-param"><span class="n">stride</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">padding</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">latent_dim</span><span class="o">=</span><span class="default_value">32</span></em>, <em class="sig-param"><span class="n">conv_dim</span><span class="o">=</span><span class="default_value">64</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BroadcastingDecoder" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.BroadcastingDecoder.decode">
<code class="sig-name descname">decode</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">z</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BroadcastingDecoder.decode" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.BroadcastingDecoder.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">z</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BroadcastingDecoder.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.BroadcastingDeconvDecoder">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">BroadcastingDeconvDecoder</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">output_shape</span><span class="o">=</span><span class="default_value">[3, 64, 64]</span></em>, <em class="sig-param"><span class="n">net_depth</span><span class="o">=</span><span class="default_value">3</span></em>, <em class="sig-param"><span class="n">latent_dim</span><span class="o">=</span><span class="default_value">32</span></em>, <em class="sig-param"><span class="n">conv_dim</span><span class="o">=</span><span class="default_value">64</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BroadcastingDeconvDecoder" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.BroadcastingDeconvDecoder.decode">
<code class="sig-name descname">decode</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">z</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BroadcastingDeconvDecoder.decode" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.BroadcastingDeconvDecoder.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">z</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BroadcastingDeconvDecoder.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.ParallelAttentionBroadcastingDeconvDecoder">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">ParallelAttentionBroadcastingDeconvDecoder</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">output_shape</span><span class="o">=</span><span class="default_value">[3, 64, 64]</span></em>, <em class="sig-param"><span class="n">net_depth</span><span class="o">=</span><span class="default_value">3</span></em>, <em class="sig-param"><span class="n">latent_dim</span><span class="o">=</span><span class="default_value">32</span></em>, <em class="sig-param"><span class="n">nbr_attention_slot</span><span class="o">=</span><span class="default_value">10</span></em>, <em class="sig-param"><span class="n">conv_dim</span><span class="o">=</span><span class="default_value">64</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ParallelAttentionBroadcastingDeconvDecoder" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ParallelAttentionBroadcastingDeconvDecoder.decode">
<code class="sig-name descname">decode</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">z</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ParallelAttentionBroadcastingDeconvDecoder.decode" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ParallelAttentionBroadcastingDeconvDecoder.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">z</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ParallelAttentionBroadcastingDeconvDecoder.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.TotalCorrelationDiscriminator">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">TotalCorrelationDiscriminator</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">VAE</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.TotalCorrelationDiscriminator" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">object</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.TotalCorrelationDiscriminator.update">
<code class="sig-name descname">update</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">z</span></em>, <em class="sig-param"><span class="n">train</span><span class="o">=</span><span class="default_value">True</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.TotalCorrelationDiscriminator.update" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.TotalCorrelationDiscriminator.step">
<code class="sig-name descname">step</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.TotalCorrelationDiscriminator.step" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.TotalCorrelationDiscriminator.permutate_latents">
<code class="sig-name descname">permutate_latents</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">z</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.TotalCorrelationDiscriminator.permutate_latents" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.BetaVAE">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">BetaVAE</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">beta</span><span class="o">=</span><span class="default_value">10000.0</span></em>, <em class="sig-param"><span class="n">encoder</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">decoder</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">latent_dim</span><span class="o">=</span><span class="default_value">32</span></em>, <em class="sig-param"><span class="n">nbr_attention_slot</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">input_shape</span><span class="o">=</span><span class="default_value">[3, 64, 64]</span></em>, <em class="sig-param"><span class="n">NormalOutputDistribution</span><span class="o">=</span><span class="default_value">True</span></em>, <em class="sig-param"><span class="n">maxEncodingCapacity</span><span class="o">=</span><span class="default_value">1000</span></em>, <em class="sig-param"><span class="n">nbrEpochTillMaxEncodingCapacity</span><span class="o">=</span><span class="default_value">4</span></em>, <em class="sig-param"><span class="n">constrainedEncoding</span><span class="o">=</span><span class="default_value">True</span></em>, <em class="sig-param"><span class="n">observation_sigma</span><span class="o">=</span><span class="default_value">0.05</span></em>, <em class="sig-param"><span class="n">factor_vae_gamma</span><span class="o">=</span><span class="default_value">0.0</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BetaVAE" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.BetaVAE.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BetaVAE.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.BetaVAE._compute_feature_shape">
<code class="sig-name descname">_compute_feature_shape</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_dim</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">nbr_layer</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BetaVAE._compute_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.BetaVAE.reparameterize">
<code class="sig-name descname">reparameterize</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">mu</span></em>, <em class="sig-param"><span class="n">log_var</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BetaVAE.reparameterize" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.BetaVAE.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BetaVAE.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.BetaVAE.encode">
<code class="sig-name descname">encode</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BetaVAE.encode" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.BetaVAE.encodeZ">
<code class="sig-name descname">encodeZ</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BetaVAE.encodeZ" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.BetaVAE.decode">
<code class="sig-name descname">decode</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">z</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BetaVAE.decode" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.BetaVAE._forward">
<code class="sig-name descname">_forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">evaluation</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">fixed_latent</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">data</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BetaVAE._forward" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.BetaVAE.get_feat_map">
<code class="sig-name descname">get_feat_map</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BetaVAE.get_feat_map" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.BetaVAE.compute_loss">
<code class="sig-name descname">compute_loss</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">fixed_latent</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">data</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">evaluation</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">observation_sigma</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.BetaVAE.compute_loss" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.UNetBlock">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">UNetBlock</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">in_channel</span></em>, <em class="sig-param"><span class="n">out_channel</span></em>, <em class="sig-param"><span class="n">upsample</span><span class="o">=</span><span class="default_value">True</span></em>, <em class="sig-param"><span class="n">interpolate</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">interpolation_factor</span><span class="o">=</span><span class="default_value">2</span></em>, <em class="sig-param"><span class="n">batch_norm</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.UNetBlock" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.UNetBlock.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.UNetBlock.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.UNet">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">UNet</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_shape</span></em>, <em class="sig-param"><span class="n">in_channel</span></em>, <em class="sig-param"><span class="n">out_channel</span></em>, <em class="sig-param"><span class="n">basis_nbr_channel</span><span class="o">=</span><span class="default_value">32</span></em>, <em class="sig-param"><span class="n">block_depth</span><span class="o">=</span><span class="default_value">3</span></em>, <em class="sig-param"><span class="n">batch_norm</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.UNet" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.UNet.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.UNet.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.AttentionNetwork">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">AttentionNetwork</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_shape</span></em>, <em class="sig-param"><span class="n">in_channel</span></em>, <em class="sig-param"><span class="n">attention_basis_nbr_channel</span><span class="o">=</span><span class="default_value">32</span></em>, <em class="sig-param"><span class="n">attention_block_depth</span><span class="o">=</span><span class="default_value">3</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.AttentionNetwork" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py attribute">
<dt id="ReferentialGym.networks.autoregressive_networks.AttentionNetwork.in_channel">
<code class="sig-name descname">in_channel</code><em class="property"> = None</em><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.AttentionNetwork.in_channel" title="Permalink to this definition">¶</a></dt>
<dd><div class="highlight-default notranslate"><div class="highlight"><pre><span></span>self.unet = UNet(input_shape=self.input_shape,
                 in_channel=self.in_channel,
                 out_channel=1,
                 basis_nbr_channel=attention_basis_nbr_channel,
                 block_depth=attention_block_depth)
</pre></div>
</div>
</dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.AttentionNetwork.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em>, <em class="sig-param"><span class="n">logscope</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.AttentionNetwork.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.ParallelAttentionNetwork">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">ParallelAttentionNetwork</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_shape</span></em>, <em class="sig-param"><span class="n">in_channel</span></em>, <em class="sig-param"><span class="n">nbr_attention_slot</span><span class="o">=</span><span class="default_value">10</span></em>, <em class="sig-param"><span class="n">attention_basis_nbr_channel</span><span class="o">=</span><span class="default_value">32</span></em>, <em class="sig-param"><span class="n">attention_block_depth</span><span class="o">=</span><span class="default_value">3</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ParallelAttentionNetwork" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ParallelAttentionNetwork.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em>, <em class="sig-param"><span class="n">logscope</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ParallelAttentionNetwork.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.MONet">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">MONet</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">gamma</span><span class="o">=</span><span class="default_value">0.5</span></em>, <em class="sig-param"><span class="n">input_shape</span><span class="o">=</span><span class="default_value">[3, 64, 64]</span></em>, <em class="sig-param"><span class="n">nbr_attention_slot</span><span class="o">=</span><span class="default_value">10</span></em>, <em class="sig-param"><span class="n">anet_basis_nbr_channel</span><span class="o">=</span><span class="default_value">32</span></em>, <em class="sig-param"><span class="n">anet_block_depth</span><span class="o">=</span><span class="default_value">3</span></em>, <em class="sig-param"><span class="n">cvae_beta</span><span class="o">=</span><span class="default_value">0.5</span></em>, <em class="sig-param"><span class="n">cvae_latent_dim</span><span class="o">=</span><span class="default_value">10</span></em>, <em class="sig-param"><span class="n">cvae_decoder_conv_dim</span><span class="o">=</span><span class="default_value">32</span></em>, <em class="sig-param"><span class="n">cvae_pretrained</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">cvae_resnet_encoder</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">cvae_resnet_nbr_layer</span><span class="o">=</span><span class="default_value">2</span></em>, <em class="sig-param"><span class="n">cvae_decoder_nbr_layer</span><span class="o">=</span><span class="default_value">3</span></em>, <em class="sig-param"><span class="n">cvae_EncodingCapacityStep</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">cvae_maxEncodingCapacity</span><span 
class="o">=</span><span class="default_value">100</span></em>, <em class="sig-param"><span class="n">cvae_nbrEpochTillMaxEncodingCapacity</span><span class="o">=</span><span class="default_value">4</span></em>, <em class="sig-param"><span class="n">cvae_constrainedEncoding</span><span class="o">=</span><span class="default_value">True</span></em>, <em class="sig-param"><span class="n">cvae_observation_sigma</span><span class="o">=</span><span class="default_value">0.05</span></em>, <em class="sig-param"><span class="n">compactness_factor</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.MONet" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <a class="reference internal" href="#ReferentialGym.networks.autoregressive_networks.BetaVAE" title="ReferentialGym.networks.autoregressive_networks.BetaVAE"><code class="xref py py-class docutils literal notranslate"><span class="pre">ReferentialGym.networks.autoregressive_networks.BetaVAE</span></code></a></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.MONet.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.MONet.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.MONet.encodeZ">
<code class="sig-name descname">encodeZ</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.MONet.encodeZ" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.MONet.decode">
<code class="sig-name descname">decode</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">z</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.MONet.decode" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.MONet.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em>, <em class="sig-param"><span class="n">observation_sigma</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">compute_loss</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.MONet.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.MONet.compute_loss">
<code class="sig-name descname">compute_loss</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">observation_sigma</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.MONet.compute_loss" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.autoregressive_networks.ParallelMONet">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.autoregressive_networks.</code><code class="sig-name descname">ParallelMONet</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">gamma</span><span class="o">=</span><span class="default_value">0.5</span></em>, <em class="sig-param"><span class="n">input_shape</span><span class="o">=</span><span class="default_value">[3, 64, 64]</span></em>, <em class="sig-param"><span class="n">nbr_attention_slot</span><span class="o">=</span><span class="default_value">10</span></em>, <em class="sig-param"><span class="n">anet_basis_nbr_channel</span><span class="o">=</span><span class="default_value">32</span></em>, <em class="sig-param"><span class="n">anet_block_depth</span><span class="o">=</span><span class="default_value">3</span></em>, <em class="sig-param"><span class="n">cvae_beta</span><span class="o">=</span><span class="default_value">0.5</span></em>, <em class="sig-param"><span class="n">cvae_latent_dim</span><span class="o">=</span><span class="default_value">10</span></em>, <em class="sig-param"><span class="n">cvae_decoder_conv_dim</span><span class="o">=</span><span class="default_value">32</span></em>, <em class="sig-param"><span class="n">cvae_pretrained</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">cvae_resnet_encoder</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">cvae_resnet_nbr_layer</span><span class="o">=</span><span class="default_value">2</span></em>, <em class="sig-param"><span class="n">cvae_decoder_nbr_layer</span><span class="o">=</span><span class="default_value">3</span></em>, <em class="sig-param"><span class="n">cvae_EncodingCapacityStep</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">cvae_maxEncodingCapacity</span><span 
class="o">=</span><span class="default_value">100</span></em>, <em class="sig-param"><span class="n">cvae_nbrEpochTillMaxEncodingCapacity</span><span class="o">=</span><span class="default_value">4</span></em>, <em class="sig-param"><span class="n">cvae_constrainedEncoding</span><span class="o">=</span><span class="default_value">True</span></em>, <em class="sig-param"><span class="n">cvae_observation_sigma</span><span class="o">=</span><span class="default_value">0.05</span></em>, <em class="sig-param"><span class="n">compactness_factor</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ParallelMONet" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <a class="reference internal" href="#ReferentialGym.networks.autoregressive_networks.BetaVAE" title="ReferentialGym.networks.autoregressive_networks.BetaVAE"><code class="xref py py-class docutils literal notranslate"><span class="pre">ReferentialGym.networks.autoregressive_networks.BetaVAE</span></code></a></p>
<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ParallelMONet.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ParallelMONet.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ParallelMONet.encodeZ">
<code class="sig-name descname">encodeZ</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ParallelMONet.encodeZ" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ParallelMONet.decode">
<code class="sig-name descname">decode</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">z</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ParallelMONet.decode" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ParallelMONet.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em>, <em class="sig-param"><span class="n">observation_sigma</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">compute_loss</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ParallelMONet.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.autoregressive_networks.ParallelMONet.compute_loss">
<code class="sig-name descname">compute_loss</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">observation_sigma</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.autoregressive_networks.ParallelMONet.compute_loss" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

</div>
<div class="section" id="module-ReferentialGym.networks.homoscedastic_multitask_loss">
<span id="referentialgym-networks-homoscedastic-multitask-loss-module"></span><h2>ReferentialGym.networks.homoscedastic_multitask_loss module<a class="headerlink" href="#module-ReferentialGym.networks.homoscedastic_multitask_loss" title="Permalink to this headline">¶</a></h2>
<dl class="py class">
<dt id="ReferentialGym.networks.homoscedastic_multitask_loss.HomoscedasticMultiTasksLoss">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.homoscedastic_multitask_loss.</code><code class="sig-name descname">HomoscedasticMultiTasksLoss</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">nbr_tasks</span><span class="o">=</span><span class="default_value">2</span></em>, <em class="sig-param"><span class="n">use_cuda</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.homoscedastic_multitask_loss.HomoscedasticMultiTasksLoss" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.homoscedastic_multitask_loss.HomoscedasticMultiTasksLoss.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">loss_dict</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.homoscedastic_multitask_loss.HomoscedasticMultiTasksLoss.forward" title="Permalink to this definition">¶</a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>loss_dict</strong> – Dict[str, Tuple[float, torch.Tensor]] that maps each loss name
to a pair (linear coefficient, loss), where the loss
has batched shape (batch_size, 1).</p>
</dd>
</dl>
</dd></dl>

</dd></dl>

</div>
<div class="section" id="module-ReferentialGym.networks.networks">
<span id="referentialgym-networks-networks-module"></span><h2>ReferentialGym.networks.networks module<a class="headerlink" href="#module-ReferentialGym.networks.networks" title="Permalink to this headline">¶</a></h2>
<dl class="py function">
<dt id="ReferentialGym.networks.networks.retrieve_output_shape">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">retrieve_output_shape</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input</span></em>, <em class="sig-param"><span class="n">model</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.retrieve_output_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks.hasnan">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">hasnan</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">tensor</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.hasnan" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks.handle_nan">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">handle_nan</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">layer</span></em>, <em class="sig-param"><span class="n">verbose</span><span class="o">=</span><span class="default_value">True</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.handle_nan" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks.layer_init">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">layer_init</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">layer</span></em>, <em class="sig-param"><span class="n">w_scale</span><span class="o">=</span><span class="default_value">1.0</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.layer_init" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.networks.addXYfeatures">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">addXYfeatures</code><a class="headerlink" href="#ReferentialGym.networks.networks.addXYfeatures" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.networks.addXYfeatures.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em>, <em class="sig-param"><span class="n">outputFsizes</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.addXYfeatures.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.networks.addXYRhoThetaFeatures">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">addXYRhoThetaFeatures</code><a class="headerlink" href="#ReferentialGym.networks.networks.addXYRhoThetaFeatures" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.networks.addXYRhoThetaFeatures.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em>, <em class="sig-param"><span class="n">outputFsizes</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.addXYRhoThetaFeatures.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks.conv">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">conv</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">sin</span></em>, <em class="sig-param"><span class="n">sout</span></em>, <em class="sig-param"><span class="n">k</span></em>, <em class="sig-param"><span class="n">stride</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">padding</span><span class="o">=</span><span class="default_value">0</span></em>, <em class="sig-param"><span class="n">batchNorm</span><span class="o">=</span><span class="default_value">True</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.conv" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks.conv3x3">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">conv3x3</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">in_planes</span></em>, <em class="sig-param"><span class="n">out_planes</span></em>, <em class="sig-param"><span class="n">stride</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">groups</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">dilation</span><span class="o">=</span><span class="default_value">1</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.conv3x3" title="Permalink to this definition">¶</a></dt>
<dd><p>3x3 convolution with padding</p>
</dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks.conv1x1">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">conv1x1</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">in_planes</span></em>, <em class="sig-param"><span class="n">out_planes</span></em>, <em class="sig-param"><span class="n">stride</span><span class="o">=</span><span class="default_value">1</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.conv1x1" title="Permalink to this definition">¶</a></dt>
<dd><p>1x1 convolution</p>
</dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks.deconv">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">deconv</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">sin</span></em>, <em class="sig-param"><span class="n">sout</span></em>, <em class="sig-param"><span class="n">k</span></em>, <em class="sig-param"><span class="n">stride</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">padding</span><span class="o">=</span><span class="default_value">0</span></em>, <em class="sig-param"><span class="n">batchNorm</span><span class="o">=</span><span class="default_value">True</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.deconv" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks.coordconv">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">coordconv</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">sin</span></em>, <em class="sig-param"><span class="n">sout</span></em>, <em class="sig-param"><span class="n">kernel_size</span></em>, <em class="sig-param"><span class="n">stride</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">padding</span><span class="o">=</span><span class="default_value">0</span></em>, <em class="sig-param"><span class="n">batchNorm</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">bias</span><span class="o">=</span><span class="default_value">True</span></em>, <em class="sig-param"><span class="n">groups</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">dilation</span><span class="o">=</span><span class="default_value">1</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.coordconv" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks.coordconv3x3">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">coordconv3x3</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">in_planes</span></em>, <em class="sig-param"><span class="n">out_planes</span></em>, <em class="sig-param"><span class="n">stride</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">groups</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">dilation</span><span class="o">=</span><span class="default_value">1</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.coordconv3x3" title="Permalink to this definition">¶</a></dt>
<dd><p>3x3 coord convolution with padding</p>
</dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks.coordconv1x1">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">coordconv1x1</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">in_planes</span></em>, <em class="sig-param"><span class="n">out_planes</span></em>, <em class="sig-param"><span class="n">stride</span><span class="o">=</span><span class="default_value">1</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.coordconv1x1" title="Permalink to this definition">¶</a></dt>
<dd><p>1x1 coord convolution</p>
</dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks.coorddeconv">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">coorddeconv</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">sin</span></em>, <em class="sig-param"><span class="n">sout</span></em>, <em class="sig-param"><span class="n">kernel_size</span></em>, <em class="sig-param"><span class="n">stride</span><span class="o">=</span><span class="default_value">2</span></em>, <em class="sig-param"><span class="n">padding</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">batchNorm</span><span class="o">=</span><span class="default_value">True</span></em>, <em class="sig-param"><span class="n">bias</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.coorddeconv" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks.coord4conv">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">coord4conv</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">sin</span></em>, <em class="sig-param"><span class="n">sout</span></em>, <em class="sig-param"><span class="n">kernel_size</span></em>, <em class="sig-param"><span class="n">stride</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">padding</span><span class="o">=</span><span class="default_value">0</span></em>, <em class="sig-param"><span class="n">batchNorm</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">bias</span><span class="o">=</span><span class="default_value">True</span></em>, <em class="sig-param"><span class="n">groups</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">dilation</span><span class="o">=</span><span class="default_value">1</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.coord4conv" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks.coord4conv3x3">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">coord4conv3x3</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">in_planes</span></em>, <em class="sig-param"><span class="n">out_planes</span></em>, <em class="sig-param"><span class="n">stride</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">groups</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">dilation</span><span class="o">=</span><span class="default_value">1</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.coord4conv3x3" title="Permalink to this definition">¶</a></dt>
<dd><p>3x3 coord convolution with padding</p>
</dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks.coord4conv1x1">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">coord4conv1x1</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">in_planes</span></em>, <em class="sig-param"><span class="n">out_planes</span></em>, <em class="sig-param"><span class="n">stride</span><span class="o">=</span><span class="default_value">1</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.coord4conv1x1" title="Permalink to this definition">¶</a></dt>
<dd><p>1x1 coord convolution</p>
</dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks.coord4deconv">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">coord4deconv</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">sin</span></em>, <em class="sig-param"><span class="n">sout</span></em>, <em class="sig-param"><span class="n">kernel_size</span></em>, <em class="sig-param"><span class="n">stride</span><span class="o">=</span><span class="default_value">2</span></em>, <em class="sig-param"><span class="n">padding</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">batchNorm</span><span class="o">=</span><span class="default_value">True</span></em>, <em class="sig-param"><span class="n">bias</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.coord4deconv" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.networks.FCBody">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">FCBody</code><span class="sig-paren">(</span><em class="sig-param">state_dim</em>, <em class="sig-param">hidden_units=(64, 64)</em>, <em class="sig-param">gate=&lt;function relu&gt;</em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.FCBody" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.networks.FCBody.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.FCBody.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.FCBody.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.FCBody.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.networks.ConvolutionalBody">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">ConvolutionalBody</code><span class="sig-paren">(</span><em class="sig-param">input_shape, feature_dim=256, channels=[3, 3], kernel_sizes=[1], strides=[1], paddings=[0], fc_hidden_units=None, dropout=0.0, non_linearities=[&lt;class 'torch.nn.modules.activation.LeakyReLU'&gt;], use_coordconv=None</em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalBody" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.networks.ConvolutionalBody._compute_feat_map">
<code class="sig-name descname">_compute_feat_map</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalBody._compute_feat_map" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.ConvolutionalBody.get_feat_map">
<code class="sig-name descname">get_feat_map</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalBody.get_feat_map" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.ConvolutionalBody.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em>, <em class="sig-param"><span class="n">non_lin_output</span><span class="o">=</span><span class="default_value">True</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalBody.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.ConvolutionalBody.get_input_shape">
<code class="sig-name descname">get_input_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalBody.get_input_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.ConvolutionalBody.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalBody.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.ConvolutionalBody._compute_feature_shape">
<code class="sig-name descname">_compute_feature_shape</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_dim</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">nbr_layer</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalBody._compute_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.networks.EntityPrioredConvolutionalBody">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">EntityPrioredConvolutionalBody</code><span class="sig-paren">(</span><em class="sig-param">input_shape, feature_dim=256, channels=[3, 3], kernel_sizes=[1], strides=[1], paddings=[0], fc_hidden_units=None, dropout=0.0, non_linearities=[&lt;class 'torch.nn.modules.activation.LeakyReLU'&gt;], use_coordconv=None</em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.EntityPrioredConvolutionalBody" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <a class="reference internal" href="#ReferentialGym.networks.networks.ConvolutionalBody" title="ReferentialGym.networks.networks.ConvolutionalBody"><code class="xref py py-class docutils literal notranslate"><span class="pre">ReferentialGym.networks.networks.ConvolutionalBody</span></code></a></p>
<dl class="py method">
<dt id="ReferentialGym.networks.networks.EntityPrioredConvolutionalBody._compute_feat_map">
<code class="sig-name descname">_compute_feat_map</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.EntityPrioredConvolutionalBody._compute_feat_map" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.networks.ConvolutionalLstmBody">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">ConvolutionalLstmBody</code><span class="sig-paren">(</span><em class="sig-param">input_shape, feature_dim=256, channels=[3, 3], kernel_sizes=[1], strides=[1], paddings=[0], fc_hidden_units=None, rnn_hidden_units=(256,), dropout=0.0, non_linearities=[&lt;class 'torch.nn.modules.activation.ReLU'&gt;], gate=&lt;function relu&gt;, use_coordconv=None</em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalLstmBody" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <a class="reference internal" href="#ReferentialGym.networks.networks.ConvolutionalBody" title="ReferentialGym.networks.networks.ConvolutionalBody"><code class="xref py py-class docutils literal notranslate"><span class="pre">ReferentialGym.networks.networks.ConvolutionalBody</span></code></a></p>
<dl class="py method">
<dt id="ReferentialGym.networks.networks.ConvolutionalLstmBody.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">inputs</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalLstmBody.forward" title="Permalink to this definition">¶</a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>inputs</strong> – input to LSTM cells. Structured as (feed_forward_input, {hidden: hidden_states, cell: cell_states}).</p>
</dd>
</dl>
<p>hidden_states: list of hidden_state(s), one for each layer in self.layers.
cell_states: list of cell_state(s), one for each layer in self.layers.</p>
</dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.ConvolutionalLstmBody.get_reset_states">
<code class="sig-name descname">get_reset_states</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">cuda</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">repeat</span><span class="o">=</span><span class="default_value">1</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalLstmBody.get_reset_states" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.ConvolutionalLstmBody.get_input_shape">
<code class="sig-name descname">get_input_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalLstmBody.get_input_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.ConvolutionalLstmBody.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalLstmBody.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.networks.ConvolutionalGruBody">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">ConvolutionalGruBody</code><span class="sig-paren">(</span><em class="sig-param">input_shape, feature_dim=256, channels=[3, 3], kernel_sizes=[1], strides=[1], paddings=[0], fc_hidden_units=None, rnn_hidden_units=(256,), dropout=0.0, non_linearities=[&lt;class 'torch.nn.modules.activation.ReLU'&gt;], gate=&lt;function relu&gt;, use_coordconv=None</em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalGruBody" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <a class="reference internal" href="#ReferentialGym.networks.networks.ConvolutionalBody" title="ReferentialGym.networks.networks.ConvolutionalBody"><code class="xref py py-class docutils literal notranslate"><span class="pre">ReferentialGym.networks.networks.ConvolutionalBody</span></code></a></p>
<dl class="py method">
<dt id="ReferentialGym.networks.networks.ConvolutionalGruBody.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">inputs</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalGruBody.forward" title="Permalink to this definition">¶</a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>inputs</strong> – input to GRU cells. Structured as (feed_forward_input, {hidden: hidden_states, cell: cell_states}).</p>
</dd>
</dl>
<ul class="simple">
<li><p>hidden_states: list of hidden_state(s), one for each of self.layers.</p></li>
<li><p>cell_states: list of cell_state(s), one for each of self.layers.</p></li>
</ul>
</dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.ConvolutionalGruBody.get_reset_states">
<code class="sig-name descname">get_reset_states</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">cuda</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">repeat</span><span class="o">=</span><span class="default_value">1</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalGruBody.get_reset_states" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.ConvolutionalGruBody.get_input_shape">
<code class="sig-name descname">get_input_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalGruBody.get_input_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.ConvolutionalGruBody.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalGruBody.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.networks.LSTMBody">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">LSTMBody</code><span class="sig-paren">(</span><em class="sig-param">state_dim</em>, <em class="sig-param">rnn_hidden_units=256</em>, <em class="sig-param">gate=&lt;function relu&gt;</em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.LSTMBody" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.networks.LSTMBody.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">inputs</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.LSTMBody.forward" title="Permalink to this definition">¶</a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>inputs</strong> – input to LSTM cells. Structured as (feed_forward_input, {hidden: hidden_states, cell: cell_states}).</p>
</dd>
</dl>
<ul class="simple">
<li><p>hidden_states: list of hidden_state(s), one for each of self.layers.</p></li>
<li><p>cell_states: list of cell_state(s), one for each of self.layers.</p></li>
</ul>
</dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.LSTMBody.get_reset_states">
<code class="sig-name descname">get_reset_states</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">cuda</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">repeat</span><span class="o">=</span><span class="default_value">1</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.LSTMBody.get_reset_states" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.LSTMBody.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.LSTMBody.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.networks.GRUBody">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">GRUBody</code><span class="sig-paren">(</span><em class="sig-param">state_dim</em>, <em class="sig-param">rnn_hidden_units=256</em>, <em class="sig-param">gate=&lt;function relu&gt;</em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.GRUBody" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.networks.GRUBody.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">inputs</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.GRUBody.forward" title="Permalink to this definition">¶</a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>inputs</strong> – input to LSTM cells. Structured as (feed_forward_input, {hidden: hidden_states, cell: cell_states}).</p>
</dd>
</dl>
<ul class="simple">
<li><p>hidden_states: list of hidden_state(s), one for each of self.layers.</p></li>
<li><p>cell_states: list of cell_state(s), one for each of self.layers.</p></li>
</ul>
</dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.GRUBody.get_reset_states">
<code class="sig-name descname">get_reset_states</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">cuda</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">repeat</span><span class="o">=</span><span class="default_value">1</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.GRUBody.get_reset_states" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.GRUBody.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.GRUBody.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.networks.MHDPA">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">MHDPA</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">depth_dim</span><span class="o">=</span><span class="default_value">37</span></em>, <em class="sig-param"><span class="n">interactions_dim</span><span class="o">=</span><span class="default_value">64</span></em>, <em class="sig-param"><span class="n">hidden_size</span><span class="o">=</span><span class="default_value">256</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.MHDPA" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.networks.MHDPA.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em>, <em class="sig-param"><span class="n">usef</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.MHDPA.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.MHDPA.save">
<code class="sig-name descname">save</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">path</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.MHDPA.save" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.MHDPA.load">
<code class="sig-name descname">load</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">path</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.MHDPA.load" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.networks.MHDPA_RN">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">MHDPA_RN</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">depth_dim</span><span class="o">=</span><span class="default_value">35</span></em>, <em class="sig-param"><span class="n">nbrHead</span><span class="o">=</span><span class="default_value">3</span></em>, <em class="sig-param"><span class="n">nbrRecurrentSharedLayers</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">nbrEntity</span><span class="o">=</span><span class="default_value">7</span></em>, <em class="sig-param"><span class="n">units_per_MLP_layer</span><span class="o">=</span><span class="default_value">256</span></em>, <em class="sig-param"><span class="n">interactions_dim</span><span class="o">=</span><span class="default_value">128</span></em>, <em class="sig-param"><span class="n">output_dim</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">dropout_prob</span><span class="o">=</span><span class="default_value">0.0</span></em>, <em class="sig-param"><span class="n">use_coord4</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.MHDPA_RN" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.networks.MHDPA_RN.forwardScaledDPAhead">
<code class="sig-name descname">forwardScaledDPAhead</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em>, <em class="sig-param"><span class="n">head</span></em>, <em class="sig-param"><span class="n">reset_hidden_states</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.MHDPA_RN.forwardScaledDPAhead" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.MHDPA_RN.forwardStackedMHDPA">
<code class="sig-name descname">forwardStackedMHDPA</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">augx</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.MHDPA_RN.forwardStackedMHDPA" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.MHDPA_RN.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">augx</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.MHDPA_RN.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.networks.ConvolutionalMHDPABody">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">ConvolutionalMHDPABody</code><span class="sig-paren">(</span><em class="sig-param">input_shape, feature_dim=256, channels=[3, 3], kernel_sizes=[1], strides=[1], paddings=[0], fc_hidden_units=None, dropout=0.0, non_linearities=[&lt;class 'torch.nn.modules.activation.LeakyReLU'&gt;], use_coordconv=None, nbrHead=4, nbrRecurrentSharedLayers=1, units_per_MLP_layer=512, interaction_dim=128, use_coord4=False</em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalMHDPABody" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <a class="reference internal" href="#ReferentialGym.networks.networks.ConvolutionalBody" title="ReferentialGym.networks.networks.ConvolutionalBody"><code class="xref py py-class docutils literal notranslate"><span class="pre">ReferentialGym.networks.networks.ConvolutionalBody</span></code></a></p>
<dl class="py method">
<dt id="ReferentialGym.networks.networks.ConvolutionalMHDPABody.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ConvolutionalMHDPABody.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.networks.VGG">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">VGG</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">features</span></em>, <em class="sig-param"><span class="n">num_classes</span><span class="o">=</span><span class="default_value">1000</span></em>, <em class="sig-param"><span class="n">init_weights</span><span class="o">=</span><span class="default_value">True</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.VGG" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<p>Makes the VGG architecture usable as a convolutional feature extractor
without its classification layers, so that it can be selected like any other
convolutional architecture.</p>
<dl class="py method">
<dt id="ReferentialGym.networks.networks.VGG.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.VGG.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.VGG._initialize_weights">
<code class="sig-name descname">_initialize_weights</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.VGG._initialize_weights" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py function">
<dt id="ReferentialGym.networks.networks._vgg">
<code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">_vgg</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">arch</span></em>, <em class="sig-param"><span class="n">cfg</span></em>, <em class="sig-param"><span class="n">batch_norm</span></em>, <em class="sig-param"><span class="n">pretrained</span></em>, <em class="sig-param"><span class="n">progress</span></em>, <em class="sig-param"><span class="o">**</span><span class="n">kwargs</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks._vgg" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.networks.ModelVGG16">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">ModelVGG16</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_shape</span></em>, <em class="sig-param"><span class="n">feature_dim</span><span class="o">=</span><span class="default_value">512</span></em>, <em class="sig-param"><span class="n">pretrained</span><span class="o">=</span><span class="default_value">True</span></em>, <em class="sig-param"><span class="n">final_layer_idx</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ModelVGG16" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.networks.ModelVGG16.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ModelVGG16.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.networks.ModelVGG16.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ModelVGG16.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.networks.ExtractorVGG16">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.networks.</code><code class="sig-name descname">ExtractorVGG16</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_shape</span></em>, <em class="sig-param"><span class="n">final_layer_idx</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">pretrained</span><span class="o">=</span><span class="default_value">True</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ExtractorVGG16" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.networks.ExtractorVGG16.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.networks.ExtractorVGG16.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

</div>
<div class="section" id="module-ReferentialGym.networks.residual_networks">
<span id="referentialgym-networks-residual-networks-module"></span><h2>ReferentialGym.networks.residual_networks module<a class="headerlink" href="#module-ReferentialGym.networks.residual_networks" title="Permalink to this headline">¶</a></h2>
<dl class="py class">
<dt id="ReferentialGym.networks.residual_networks.ResNet">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.residual_networks.</code><code class="sig-name descname">ResNet</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">block</span></em>, <em class="sig-param"><span class="n">layers</span></em>, <em class="sig-param"><span class="n">num_classes</span><span class="o">=</span><span class="default_value">1000</span></em>, <em class="sig-param"><span class="n">zero_init_residual</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">groups</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">width_per_group</span><span class="o">=</span><span class="default_value">64</span></em>, <em class="sig-param"><span class="n">replace_stride_with_dilation</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">norm_layer</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ResNet" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ResNet._make_layer">
<code class="sig-name descname">_make_layer</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">block</span></em>, <em class="sig-param"><span class="n">planes</span></em>, <em class="sig-param"><span class="n">blocks</span></em>, <em class="sig-param"><span class="n">stride</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">dilate</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ResNet._make_layer" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ResNet._forward_impl">
<code class="sig-name descname">_forward_impl</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ResNet._forward_impl" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ResNet.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ResNet.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.residual_networks.CoordResNet">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.residual_networks.</code><code class="sig-name descname">CoordResNet</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">block</span></em>, <em class="sig-param"><span class="n">layers</span></em>, <em class="sig-param"><span class="n">num_classes</span><span class="o">=</span><span class="default_value">1000</span></em>, <em class="sig-param"><span class="n">zero_init_residual</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">groups</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">width_per_group</span><span class="o">=</span><span class="default_value">64</span></em>, <em class="sig-param"><span class="n">replace_stride_with_dilation</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">norm_layer</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.CoordResNet" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.CoordResNet._make_layer">
<code class="sig-name descname">_make_layer</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">block</span></em>, <em class="sig-param"><span class="n">planes</span></em>, <em class="sig-param"><span class="n">blocks</span></em>, <em class="sig-param"><span class="n">stride</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">dilate</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.CoordResNet._make_layer" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.CoordResNet._forward_impl">
<code class="sig-name descname">_forward_impl</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.CoordResNet._forward_impl" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.CoordResNet.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.CoordResNet.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.residual_networks.ModelResNet18">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.residual_networks.</code><code class="sig-name descname">ModelResNet18</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_shape</span></em>, <em class="sig-param"><span class="n">feature_dim</span><span class="o">=</span><span class="default_value">256</span></em>, <em class="sig-param"><span class="n">nbr_layer</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">pretrained</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">use_coordconv</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ModelResNet18" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torchvision.models.resnet.ResNet</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ModelResNet18._compute_feature_shape">
<code class="sig-name descname">_compute_feature_shape</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_dim</span></em>, <em class="sig-param"><span class="n">nbr_layer</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ModelResNet18._compute_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ModelResNet18._compute_feat_map">
<code class="sig-name descname">_compute_feat_map</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ModelResNet18._compute_feat_map" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ModelResNet18.get_feat_map">
<code class="sig-name descname">get_feat_map</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ModelResNet18.get_feat_map" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ModelResNet18._compute_features">
<code class="sig-name descname">_compute_features</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">features_map</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ModelResNet18._compute_features" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ModelResNet18.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ModelResNet18.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for the forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance
rather than calling this method directly, since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ModelResNet18.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ModelResNet18.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.residual_networks.</code><code class="sig-name descname">ModelResNet18AvgPooled</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_shape</span></em>, <em class="sig-param"><span class="n">feature_dim</span><span class="o">=</span><span class="default_value">256</span></em>, <em class="sig-param"><span class="n">nbr_layer</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">pretrained</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">detach_conv_maps</span><span class="o">=</span><span class="default_value">False</span></em>, <em class="sig-param"><span class="n">use_coordconv</span><span class="o">=</span><span class="default_value">False</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torchvision.models.resnet.ResNet</span></code></p>
<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled._compute_feature_shape">
<code class="sig-name descname">_compute_feature_shape</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_dim</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">nbr_layer</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled._compute_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled._compute_feat_map">
<code class="sig-name descname">_compute_feat_map</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled._compute_feat_map" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled._compute_features">
<code class="sig-name descname">_compute_features</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">features_map</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled._compute_features" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for the forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance
rather than calling this method directly, since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled.get_feat_map">
<code class="sig-name descname">get_feat_map</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled.get_feat_map" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled.get_feature_shape">
<code class="sig-name descname">get_feature_shape</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled.get_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.residual_networks.ResNet18MHDPA">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.residual_networks.</code><code class="sig-name descname">ResNet18MHDPA</code><span class="sig-paren">(</span><em class="sig-param">input_shape, feature_dim=256, nbr_layer=None, pretrained=False, use_coordconv=False, dropout=0.0, non_linearities=[&lt;class 'torch.nn.modules.activation.LeakyReLU'&gt;], nbrHead=4, nbrRecurrentSharedLayers=1, units_per_MLP_layer=512, interaction_dim=128</em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ResNet18MHDPA" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <a class="reference internal" href="#ReferentialGym.networks.residual_networks.ModelResNet18" title="ReferentialGym.networks.residual_networks.ModelResNet18"><code class="xref py py-class docutils literal notranslate"><span class="pre">ReferentialGym.networks.residual_networks.ModelResNet18</span></code></a></p>
<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ResNet18MHDPA.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ResNet18MHDPA.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for the forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance
rather than calling this method directly, since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.residual_networks.ResNet18AvgPooledMHDPA">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.residual_networks.</code><code class="sig-name descname">ResNet18AvgPooledMHDPA</code><span class="sig-paren">(</span><em class="sig-param">input_shape, feature_dim=256, nbr_layer=None, pretrained=False, detach_conv_maps=False, use_coordconv=False, dropout=0.0, non_linearities=[&lt;class 'torch.nn.modules.activation.LeakyReLU'&gt;], nbrHead=4, nbrRecurrentSharedLayers=1, units_per_MLP_layer=512, interaction_dim=128</em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ResNet18AvgPooledMHDPA" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <a class="reference internal" href="#ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled" title="ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled"><code class="xref py py-class docutils literal notranslate"><span class="pre">ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled</span></code></a></p>
<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ResNet18AvgPooledMHDPA._compute_feature_shape">
<code class="sig-name descname">_compute_feature_shape</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_dim</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">nbr_layer</span><span class="o">=</span><span class="default_value">None</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ResNet18AvgPooledMHDPA._compute_feature_shape" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ResNet18AvgPooledMHDPA.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ResNet18AvgPooledMHDPA.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for the forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance
rather than calling this method directly, since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt id="ReferentialGym.networks.residual_networks.ExtractorResNet18">
<em class="property">class </em><code class="sig-prename descclassname">ReferentialGym.networks.residual_networks.</code><code class="sig-name descname">ExtractorResNet18</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">input_shape</span></em>, <em class="sig-param"><span class="n">final_layer_idx</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">pretrained</span><span class="o">=</span><span class="default_value">True</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ExtractorResNet18" title="Permalink to this definition">¶</a></dt>
<dd><p>Bases: <a class="reference internal" href="#ReferentialGym.networks.residual_networks.ModelResNet18" title="ReferentialGym.networks.residual_networks.ModelResNet18"><code class="xref py py-class docutils literal notranslate"><span class="pre">ReferentialGym.networks.residual_networks.ModelResNet18</span></code></a></p>
<dl class="py method">
<dt id="ReferentialGym.networks.residual_networks.ExtractorResNet18.forward">
<code class="sig-name descname">forward</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">x</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.residual_networks.ExtractorResNet18.forward" title="Permalink to this definition">¶</a></dt>
<dd><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for the forward pass needs to be defined within
this function, one should call the <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> instance
rather than calling this method directly, since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div>
</dd></dl>

</dd></dl>

</div>
<div class="section" id="module-ReferentialGym.networks">
<span id="module-contents"></span><h2>Module contents<a class="headerlink" href="#module-ReferentialGym.networks" title="Permalink to this headline">¶</a></h2>
<dl class="py function">
<dt id="ReferentialGym.networks.choose_architecture">
<code class="sig-prename descclassname">ReferentialGym.networks.</code><code class="sig-name descname">choose_architecture</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">architecture</span></em>, <em class="sig-param"><span class="n">kwargs</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">fc_hidden_units_list</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">rnn_hidden_units_list</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">input_shape</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">feature_dim</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">nbr_channels_list</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">kernels</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">strides</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">paddings</span><span class="o">=</span><span class="default_value">None</span></em>, <em class="sig-param"><span class="n">dropout</span><span class="o">=</span><span class="default_value">0.0</span></em>, <em class="sig-param"><span class="n">MHDPANbrHead</span><span class="o">=</span><span class="default_value">4</span></em>, <em class="sig-param"><span class="n">MHDPANbrRecUpdate</span><span class="o">=</span><span class="default_value">1</span></em>, <em class="sig-param"><span class="n">MHDPANbrMLPUnit</span><span class="o">=</span><span class="default_value">512</span></em>, <em class="sig-param"><span class="n">MHDPAInteractionDim</span><span class="o">=</span><span 
class="default_value">128</span></em><span class="sig-paren">)</span><a class="headerlink" href="#ReferentialGym.networks.choose_architecture" title="Permalink to this definition">¶</a></dt>
<dd></dd></dl>

</div>
</div>


          </div>
        </div>
      </div>
      <div class="sphinxsidebar" role="navigation" aria-label="main navigation">
        <div class="sphinxsidebarwrapper">
  <h3><a href="index.html">Table of Contents</a></h3>
  <ul>
<li><a class="reference internal" href="#">ReferentialGym.networks package</a><ul>
<li><a class="reference internal" href="#submodules">Submodules</a></li>
<li><a class="reference internal" href="#module-ReferentialGym.networks.autoregressive_networks">ReferentialGym.networks.autoregressive_networks module</a></li>
<li><a class="reference internal" href="#module-ReferentialGym.networks.homoscedastic_multitask_loss">ReferentialGym.networks.homoscedastic_multitask_loss module</a></li>
<li><a class="reference internal" href="#module-ReferentialGym.networks.networks">ReferentialGym.networks.networks module</a></li>
<li><a class="reference internal" href="#module-ReferentialGym.networks.residual_networks">ReferentialGym.networks.residual_networks module</a></li>
<li><a class="reference internal" href="#module-ReferentialGym.networks">Module contents</a></li>
</ul>
</li>
</ul>

  <h4>Previous topic</h4>
  <p class="topless"><a href="ReferentialGym.modules.html"
                        title="previous chapter">ReferentialGym.modules package</a></p>
  <h4>Next topic</h4>
  <p class="topless"><a href="ReferentialGym.utils.html"
                        title="next chapter">ReferentialGym.utils package</a></p>
  <div role="note" aria-label="source link">
    <h3>This Page</h3>
    <ul class="this-page-menu">
      <li><a href="_sources/ReferentialGym.networks.rst.txt"
            rel="nofollow">Show Source</a></li>
    </ul>
   </div>
<div id="searchbox" style="display: none" role="search">
  <h3 id="searchlabel">Quick search</h3>
    <div class="searchformwrapper">
    <form class="search" action="search.html" method="get">
      <input type="text" name="q" aria-labelledby="searchlabel" />
      <input type="submit" value="Go" />
    </form>
    </div>
</div>
<script>$('#searchbox').show(0);</script>
        </div>
      </div>
      <div class="clearer"></div>
    </div>
    <div class="related" role="navigation" aria-label="related navigation">
      <h3>Navigation</h3>
      <ul>
        <li class="right" style="margin-right: 10px">
          <a href="genindex.html" title="General Index"
             >index</a></li>
        <li class="right" >
          <a href="py-modindex.html" title="Python Module Index"
             >modules</a> |</li>
        <li class="right" >
          <a href="ReferentialGym.utils.html" title="ReferentialGym.utils package"
             >next</a> |</li>
        <li class="right" >
          <a href="ReferentialGym.modules.html" title="ReferentialGym.modules package"
             >previous</a> |</li>
        <li class="nav-item nav-item-0"><a href="index.html">ReferentialGym  documentation</a> &#187;</li>
          <li class="nav-item nav-item-1"><a href="modules.html" >ReferentialGym</a> &#187;</li>
          <li class="nav-item nav-item-2"><a href="ReferentialGym.html" >ReferentialGym package</a> &#187;</li> 
      </ul>
    </div>
    <div class="footer" role="contentinfo">
        &#169; Copyright 2019, Kevin Denamganaï.
      Created using <a href="http://sphinx-doc.org/">Sphinx</a> 3.0.4.
    </div>
  </body>
</html>