<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <meta charset="utf-8" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />

  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>Add a new model &mdash; anonymous-toolkit 1.0 documentation</title>
      <link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
      <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
  <!--[if lt IE 9]>
    <script src="../_static/js/html5shiv.min.js"></script>
  <![endif]-->
  
        <script src="../_static/jquery.js"></script>
        <script src="../_static/_sphinx_javascript_frameworks_compat.js"></script>
        <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
        <script src="../_static/doctools.js"></script>
        <script src="../_static/sphinx_highlight.js"></script>
    <script src="../_static/js/theme.js"></script>
    <link rel="index" title="Index" href="../genindex.html" />
    <link rel="search" title="Search" href="../search.html" />
    <link rel="next" title="Add a new dataset" href="adddataset.html" />
    <link rel="prev" title="Multimodal VAE Comparison Toolkit" href="../index.html" /> 
</head>

<body class="wy-body-for-nav"> 
  <div class="wy-grid-for-nav">
    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search" >

          
          
          <a href="../index.html" class="icon icon-home">
            anonymous-toolkit
          </a>
<div role="search">
  <form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
    <input type="hidden" name="check_keywords" value="yes" />
    <input type="hidden" name="area" value="default" />
  </form>
</div>
        </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
              <p class="caption" role="heading"><span class="caption-text">Tutorials</span></p>
<ul class="current">
<li class="toctree-l1 current"><a class="current reference internal" href="#">Add a new model</a><ul>
<li class="toctree-l2"><a class="reference internal" href="#general-requirements">General requirements</a></li>
<li class="toctree-l2"><a class="reference internal" href="#adding-a-new-model">Adding a new model</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="adddataset.html">Add a new dataset</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Code documentation</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../code/trainer.html">MultimodalVAE class</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/mmvae_base.html">Multimodal VAE Base Class</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/mmvae_models.html">Multimodal VAE models</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/encoders.html">Encoders</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/decoders.html">Decoders</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/vae.html">VAE class</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/objectives.html">Objectives</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/dataloader.html">DataLoader</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/datasets.html">Dataset Classes</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/infer.html">Inference module</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/eval_cdsprites.html">Evaluate on CdSprites+ dataset</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/config_cls.html">Config class</a></li>
</ul>

        </div>
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
          <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
          <a href="../index.html">anonymous-toolkit</a>
      </nav>

      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
             
  <section id="add-a-new-model">
<span id="addmodel"></span><h1>Add a new model<a class="headerlink" href="#add-a-new-model" title="Permalink to this heading"></a></h1>
<p>We encourage authors to implement their own multimodal VAE models in our toolkit. This tutorial describes how to do so.</p>
<section id="general-requirements">
<h2>General requirements<a class="headerlink" href="#general-requirements" title="Permalink to this heading"></a></h2>
<p>The toolkit is written in PyTorch using the <a class="reference external" href="https://www.pytorchlightning.ai/">PyTorch Lightning</a> framework, and we expect new models to use this framework as well. Currently, it is
possible to implement unimodal VAEs and any multimodal VAE that uses a dedicated VAE instance for each modality.
You can also add new objectives, encoder/decoder networks, and any other supporting modules your model needs.</p>
<p>Below we show a step-by-step tutorial on how to add a new model.</p>
</section>
<section id="adding-a-new-model">
<h2>Adding a new model<a class="headerlink" href="#adding-a-new-model" title="Permalink to this heading"></a></h2>
<p>First, define the model in <code class="docutils literal notranslate"><span class="pre">mmvae_models.py</span></code>. The model needs a name and should inherit from the TorchMMVAE
class defined in <code class="docutils literal notranslate"><span class="pre">mmvae_base.py</span></code>.
The <code class="docutils literal notranslate"><span class="pre">self.modelName</span></code> attribute is used to select the model from the config file.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">POE</span><span class="p">(</span><span class="n">TorchMMVAE</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">vaes</span><span class="p">:</span><span class="nb">dict</span><span class="p">,</span> <span class="n">n_latents</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">obj_config</span><span class="p">:</span><span class="nb">dict</span><span class="p">,</span> <span class="n">model_config</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Multimodal Variational Autoencoder with Product of Experts https://github.com/mhw32/multimodal-vae-public</span>

<span class="sd">        :param vaes: modality-specific VAE objects keyed by modality name (e.g. &quot;mod_1&quot;)</span>
<span class="sd">        :type vaes: dict</span>
<span class="sd">        :param n_latents: dimensionality of the (shared) latent space</span>
<span class="sd">        :type n_latents: int</span>
<span class="sd">        :param obj_config: config with objective-specific parameters (objective name, beta)</span>
<span class="sd">        :type obj_config: dict</span>
<span class="sd">        :param model_config: config with model-specific parameters</span>
<span class="sd">        :type model_config: dict</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">n_latents</span><span class="p">,</span> <span class="o">**</span><span class="n">obj_config</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">vaes</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleDict</span><span class="p">(</span><span class="n">vaes</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">model_config</span> <span class="o">=</span> <span class="n">model_config</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">modelName</span> <span class="o">=</span> <span class="s1">&#39;poe&#39;</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">pz</span> <span class="o">=</span> <span class="n">dist</span><span class="o">.</span><span class="n">Normal</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">prior_dist</span> <span class="o">=</span> <span class="n">dist</span><span class="o">.</span><span class="n">Normal</span>
</pre></div>
</div>
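<p>Since <code class="docutils literal notranslate"><span class="pre">self.vaes</span></code> is built with <code class="docutils literal notranslate"><span class="pre">nn.ModuleDict(vaes)</span></code>, the <code class="docutils literal notranslate"><span class="pre">vaes</span></code> argument has to be a mapping from modality names to modules or an iterable of <code class="docutils literal notranslate"><span class="pre">(name,</span> <span class="pre">module)</span></code> pairs; a bare list of modules raises a <code class="docutils literal notranslate"><span class="pre">TypeError</span></code>. The sketch below illustrates this with toy <code class="docutils literal notranslate"><span class="pre">nn.Linear</span></code> layers standing in for the modality-specific VAEs (placeholders for illustration only, not the toolkit's VAE class), using the <code class="docutils literal notranslate"><span class="pre">mod_1</span></code>, <code class="docutils literal notranslate"><span class="pre">mod_2</span></code> naming that appears elsewhere in this tutorial.</p>

```python
from torch import nn

# Toy stand-ins for modality-specific VAEs (hypothetical placeholders; the
# toolkit would pass its own VAE objects with encoder/decoder networks).
vae_mod_1 = nn.Linear(4, 2)
vae_mod_2 = nn.Linear(6, 2)

# Option 1: a dict keyed by modality name.
vaes_from_dict = nn.ModuleDict({"mod_1": vae_mod_1, "mod_2": vae_mod_2})

# Option 2: an iterable of (name, module) pairs builds the same container.
vaes_from_pairs = nn.ModuleDict([("mod_1", vae_mod_1), ("mod_2", vae_mod_2)])

# A bare list of modules is rejected: each element must be a (name, module) pair.
try:
    nn.ModuleDict([vae_mod_1, vae_mod_2])
    bare_list_accepted = True
except TypeError:
    bare_list_accepted = False

# Submodules are registered, so their parameters are trained with the model:
# (4*2 + 2) + (6*2 + 2) = 24 parameters.
n_params = sum(p.numel() for p in vaes_from_dict.parameters())

# Iteration follows insertion order, which is the order in which forward()
# and modality_mixing() visit the modalities.
for name, vae in vaes_from_dict.items():
    print(name, vae.in_features)
```

<p>Keying the container by modality name is what lets <code class="docutils literal notranslate"><span class="pre">forward()</span></code> and <code class="docutils literal notranslate"><span class="pre">objective()</span></code> look up a modality's VAE directly, e.g. <code class="docutils literal notranslate"><span class="pre">self.vaes[mod].llik_scaling</span></code>.</p>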
<p>The TorchMMVAE class provides the bare functional minimum for a multimodal VAE, i.e. the forward pass, the encode and decode functions, and the modality_mixing function.
A new model can override these methods, or keep them as they are and implement only its own modality_mixing method. Here we add the <code class="docutils literal notranslate"><span class="pre">forward()</span></code> pass and all methods needed for multimodal data integration. The first input parameter
of <code class="docutils literal notranslate"><span class="pre">forward()</span></code> is the multimodal data: a dictionary whose keys name the modalities and whose values hold the data under <code class="docutils literal notranslate"><span class="pre">&quot;data&quot;</span></code> and, where applicable, the masks under <code class="docutils literal notranslate"><span class="pre">&quot;masks&quot;</span></code>.</p>
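<p>To make this input layout concrete, the sketch below builds a toy batch in which one modality is missing. The tensor shapes and the use of <code class="docutils literal notranslate"><span class="pre">None</span></code> for masks are assumptions for illustration; <code class="docutils literal notranslate"><span class="pre">batch_size_of()</span></code> is a hypothetical stand-in for the toolkit's <code class="docutils literal notranslate"><span class="pre">find_out_batch_size()</span></code>, whose actual implementation may differ.</p>

```python
import torch


def make_inputs(batch_size=8):
    """Build a toy multimodal batch in the layout read by forward():
    {modality_name: {"data": tensor_or_None, "masks": tensor_or_None}}.
    Shapes and None masks are illustrative assumptions."""
    return {
        "mod_1": {"data": torch.randn(batch_size, 3, 32, 32), "masks": None},
        # A missing modality keeps its key (forward() reads its "masks"),
        # but its data is None, so modality_mixing() skips its encoder.
        "mod_2": {"data": None, "masks": None},
    }


def batch_size_of(inputs):
    """Hypothetical helper: batch size of the first available modality."""
    for mod in inputs.values():
        if mod["data"] is not None:
            return mod["data"].shape[0]
    raise ValueError("all modalities are missing")


inputs = make_inputs()
present = [name for name, value in inputs.items() if value["data"] is not None]
print(present, batch_size_of(inputs))
```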
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="hll"> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">K</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
</span><span class="w">     </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">     Forward pass that takes input data and outputs posteriors, reconstructions and latent samples</span>
<span class="sd">     :param inputs: input data, a dict of modalities where the data of missing modalities is None</span>
<span class="sd">     :type inputs: dict</span>
<span class="sd">     :param K: number of samples to draw from the posterior</span>
<span class="sd">     :type K: int</span>
<span class="sd">     :return: VAEOutput object with per-modality posteriors, reconstructions and latent samples</span>
<span class="sd">     :rtype: VAEOutput</span>
<span class="sd">     &quot;&quot;&quot;</span>
     <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">,</span> <span class="n">single_params</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">modality_mixing</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">K</span><span class="p">)</span>
     <span class="n">qz_x</span> <span class="o">=</span> <span class="n">dist</span><span class="o">.</span><span class="n">Normal</span><span class="p">(</span><span class="o">*</span><span class="p">[</span><span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">])</span>
     <span class="n">z</span> <span class="o">=</span> <span class="n">qz_x</span><span class="o">.</span><span class="n">rsample</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Size</span><span class="p">([</span><span class="mi">1</span><span class="p">]))</span>
     <span class="n">qz_d</span><span class="p">,</span> <span class="n">px_d</span><span class="p">,</span> <span class="n">z_d</span> <span class="o">=</span> <span class="p">{},</span> <span class="p">{},</span> <span class="p">{}</span>
     <span class="k">for</span> <span class="n">mod</span><span class="p">,</span> <span class="n">vae</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">vaes</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
         <span class="n">px_d</span><span class="p">[</span><span class="n">mod</span><span class="p">]</span> <span class="o">=</span> <span class="n">vae</span><span class="o">.</span><span class="n">px_z</span><span class="p">(</span><span class="o">*</span><span class="n">vae</span><span class="o">.</span><span class="n">dec</span><span class="p">({</span><span class="s2">&quot;latents&quot;</span><span class="p">:</span> <span class="n">z</span><span class="p">,</span> <span class="s2">&quot;masks&quot;</span><span class="p">:</span> <span class="n">inputs</span><span class="p">[</span><span class="n">mod</span><span class="p">][</span><span class="s2">&quot;masks&quot;</span><span class="p">]}))</span>
     <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">inputs</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
         <span class="n">qz_d</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">qz_x</span>
         <span class="n">z_d</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;latents&quot;</span><span class="p">:</span> <span class="n">z</span><span class="p">,</span> <span class="s2">&quot;masks&quot;</span><span class="p">:</span> <span class="n">inputs</span><span class="p">[</span><span class="n">key</span><span class="p">][</span><span class="s2">&quot;masks&quot;</span><span class="p">]}</span>
     <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_output_dict</span><span class="p">(</span><span class="n">single_params</span><span class="p">,</span> <span class="n">px_d</span><span class="p">,</span> <span class="n">z_d</span><span class="p">,</span> <span class="n">joint_dist</span><span class="o">=</span><span class="n">qz_d</span><span class="p">)</span>

 <span class="k">def</span> <span class="nf">modality_mixing</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">K</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
<span class="w">     </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">     Inference module, calculates the joint posterior</span>
<span class="sd">     :param x: input data, a dict of modalities where the data of missing modalities is None</span>
<span class="sd">     :type x: dict</span>
<span class="sd">     :param K: number of samples to draw from the posterior</span>
<span class="sd">     :type K: int</span>
<span class="sd">     :return: mean and log-variance of the joint posterior, and a dict of unimodal posteriors</span>
<span class="sd">     :rtype: tuple(torch.Tensor, torch.Tensor, dict)</span>
<span class="sd">     &quot;&quot;&quot;</span>
     <span class="n">batch_size</span> <span class="o">=</span> <span class="n">find_out_batch_size</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
     <span class="c1"># initialize the universal prior expert</span>
     <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prior_expert</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_latents</span><span class="p">),</span> <span class="n">use_cuda</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
     <span class="n">single_params</span> <span class="o">=</span> <span class="p">{}</span>
     <span class="k">for</span> <span class="n">m</span><span class="p">,</span> <span class="n">vae</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">vaes</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
         <span class="k">if</span> <span class="n">x</span><span class="p">[</span><span class="n">m</span><span class="p">][</span><span class="s2">&quot;data&quot;</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
             <span class="n">mod_mu</span><span class="p">,</span> <span class="n">mod_logvar</span> <span class="o">=</span> <span class="n">vae</span><span class="o">.</span><span class="n">enc</span><span class="p">(</span><span class="n">x</span><span class="p">[</span><span class="n">m</span><span class="p">])</span>
             <span class="n">single_params</span><span class="p">[</span><span class="n">m</span><span class="p">]</span> <span class="o">=</span> <span class="n">dist</span><span class="o">.</span><span class="n">Normal</span><span class="p">(</span><span class="o">*</span><span class="p">[</span><span class="n">mod_mu</span><span class="p">,</span> <span class="n">mod_logvar</span><span class="p">])</span>
             <span class="n">mu</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">mu</span><span class="p">,</span> <span class="n">mod_mu</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
             <span class="n">logvar</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">logvar</span><span class="p">,</span> <span class="n">mod_logvar</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
     <span class="c1"># product of experts to combine gaussians</span>
     <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">POE</span><span class="p">,</span> <span class="n">POE</span><span class="p">)</span><span class="o">.</span><span class="n">product_of_experts</span><span class="p">(</span><span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">)</span>
     <span class="k">return</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">,</span> <span class="n">single_params</span>


 <span class="k">def</span> <span class="nf">prior_expert</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">use_cuda</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="w">     </span><span class="sd">&quot;&quot;&quot;Universal prior expert. Here we use a spherical</span>
<span class="sd">     Gaussian: N(0, 1).</span>
<span class="sd">     @param size: tuple</span>
<span class="sd">                  shape of the Gaussian parameters, e.g. (1, batch_size, n_latents)</span>
<span class="sd">     @param use_cuda: boolean [default: False]</span>
<span class="sd">                      cast CUDA on variables</span>
<span class="sd">     &quot;&quot;&quot;</span>
     <span class="n">mu</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="p">)</span>
     <span class="n">logvar</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">size</span><span class="p">))</span>
     <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
         <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span> <span class="o">=</span> <span class="n">mu</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cuda&quot;</span><span class="p">),</span> <span class="n">logvar</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cuda&quot;</span><span class="p">)</span>
     <span class="k">return</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span>
</pre></div>
</div>
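<p>The <code class="docutils literal notranslate"><span class="pre">product_of_experts()</span></code> helper inherited from TorchMMVAE is not shown above. For Gaussian experts, the product is again Gaussian: its precision is the sum of the experts' precisions, and its mean is the precision-weighted average of their means. The sketch below follows the reference PoE implementation linked in the docstring (mhw32/multimodal-vae-public); the toolkit's own version may differ in details such as the <code class="docutils literal notranslate"><span class="pre">eps</span></code> term. As in <code class="docutils literal notranslate"><span class="pre">modality_mixing()</span></code>, experts are stacked along dimension 0 with the N(0, 1) prior expert first.</p>

```python
import torch


def prior_expert(size, device="cpu"):
    """Spherical N(0, 1) prior expert: zero mean and zero log-variance."""
    mu = torch.zeros(size, device=device)
    logvar = torch.zeros(size, device=device)  # log(1) = 0
    return mu, logvar


def product_of_experts(mu, logvar, eps=1e-8):
    """Combine Gaussian experts stacked along dim 0.

    mu, logvar: tensors of shape (n_experts, batch_size, n_latents).
    Returns the mean and log-variance of the product distribution.
    """
    var = torch.exp(logvar) + eps
    precision = 1.0 / var
    joint_var = 1.0 / precision.sum(dim=0)
    joint_mu = (mu * precision).sum(dim=0) * joint_var
    return joint_mu, torch.log(joint_var)


# Toy example: the prior expert N(0, 1) plus one modality expert N(2, 1).
batch_size, n_latents = 4, 3
mu, logvar = prior_expert((1, batch_size, n_latents))
mod_mu = torch.full((batch_size, n_latents), 2.0)
mod_logvar = torch.zeros(batch_size, n_latents)
mu = torch.cat((mu, mod_mu.unsqueeze(0)), dim=0)
logvar = torch.cat((logvar, mod_logvar.unsqueeze(0)), dim=0)

# Two unit-variance experts: joint mean 1.0, joint variance 0.5.
joint_mu, joint_logvar = product_of_experts(mu, logvar)
```

<p>Passing a <code class="docutils literal notranslate"><span class="pre">device</span></code> (for example, the device of the input data) instead of hard-coding <code class="docutils literal notranslate"><span class="pre">use_cuda=True</span></code> lets the same code run on CPU as well as GPU.</p>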
<p>The <code class="docutils literal notranslate"><span class="pre">forward()</span></code> method must return a VAEOutput object, defined in output_storage.py. TorchMMVAE places the outputs inside this object automatically, so you can simply call
<code class="docutils literal notranslate"><span class="pre">self.make_output_dict(encoder_dist=None,</span> <span class="pre">decoder_dist=None,</span> <span class="pre">latent_samples=None,</span> <span class="pre">joint_dist=None,</span> <span class="pre">enc_dist_private=None,</span> <span class="pre">dec_dist_private=None,</span> <span class="pre">joint_decoder_dist=None,</span> <span class="pre">cross_decoder_dist=None)</span></code>. All these arguments are optional
(which ones you need depends on your objective function) and must be dictionaries with modality names as keys (e.g. <code class="docutils literal notranslate"><span class="pre">{&quot;mod_1&quot;:</span> <span class="pre">data1,</span> <span class="pre">&quot;mod_2&quot;:</span> <span class="pre">data2}</span></code>).</p>
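<p>The sketch below shows arguments in this shape, mirroring <code class="docutils literal notranslate"><span class="pre">POE.forward()</span></code>, where every modality shares the same joint posterior and latent sample. <code class="docutils literal notranslate"><span class="pre">check_output_args()</span></code> is a hypothetical helper written only to spell out the rule; it is not part of TorchMMVAE, and the keyword names are taken from the <code class="docutils literal notranslate"><span class="pre">make_output_dict()</span></code> signature above.</p>

```python
import torch
import torch.distributions as dist


def check_output_args(modalities, **kwargs):
    """Hypothetical check: every argument is None or a dict keyed by modality."""
    for arg_name, value in kwargs.items():
        if value is None:
            continue
        if not isinstance(value, dict):
            raise TypeError(f"{arg_name} must be a dict, got {type(value).__name__}")
        unknown = set(value) - set(modalities)
        if unknown:
            raise KeyError(f"{arg_name} has unknown modality keys: {sorted(unknown)}")
    return True


modalities = ["mod_1", "mod_2"]
qz_x = dist.Normal(torch.zeros(1, 4, 3), torch.ones(1, 4, 3))
z = qz_x.rsample()

# As in POE.forward(): one joint posterior and one latent sample per modality key.
joint_dist = {mod: qz_x for mod in modalities}
latent_samples = {mod: {"latents": z, "masks": None} for mod in modalities}
ok = check_output_args(modalities, encoder_dist=None,
                       latent_samples=latent_samples, joint_dist=joint_dist)

# A plain list is not keyed by modality, so it is rejected.
try:
    check_output_args(modalities, joint_dist=[qz_x])
    list_rejected = False
except TypeError:
    list_rejected = True
```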
<p>Next, we define the model's <code class="docutils literal notranslate"><span class="pre">objective()</span></code> method, which determines the training procedure.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span> <span class="k">def</span> <span class="nf">objective</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mods</span><span class="p">):</span>
<span class="w">     </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">     Objective function for PoE</span>

<span class="sd">     :param mods: input data with modalities as keys</span>
<span class="sd">     :type mods: dict</span>
<span class="sd">     :return obj: dictionary with the obligatory &quot;loss&quot; key on which the model is optimized, plus any other keys that you wish to log</span>
<span class="sd">     :rtype obj: dict</span>
<span class="sd">     &quot;&quot;&quot;</span>
     <span class="n">lpx_zs</span><span class="p">,</span> <span class="n">klds</span><span class="p">,</span> <span class="n">losses</span> <span class="o">=</span> <span class="p">[[]</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">mods</span><span class="o">.</span><span class="n">keys</span><span class="p">()))],</span> <span class="p">[],</span> <span class="p">[]</span>
     <span class="n">mods_inputs</span> <span class="o">=</span> <span class="n">subsample_input_modalities</span><span class="p">(</span><span class="n">mods</span><span class="p">)</span>
     <span class="k">for</span> <span class="n">m</span><span class="p">,</span> <span class="n">mods_input</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">mods_inputs</span><span class="p">):</span>
<span class="hll">         <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">mods_input</span><span class="p">)</span>
</span>         <span class="n">output_dic</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">unpack_values</span><span class="p">()</span>
         <span class="n">kld</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">obj_fn</span><span class="o">.</span><span class="n">calc_kld</span><span class="p">(</span><span class="n">output_dic</span><span class="p">[</span><span class="s2">&quot;joint_dist&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">pz</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">pz_params</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cuda&quot;</span><span class="p">)))</span>
         <span class="n">klds</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">kld</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
         <span class="n">loc_lpx_z</span> <span class="o">=</span> <span class="p">[]</span>
         <span class="k">for</span> <span class="n">mod</span> <span class="ow">in</span> <span class="n">output</span><span class="o">.</span><span class="n">mods</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
             <span class="n">px_z</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">mods</span><span class="p">[</span><span class="n">mod</span><span class="p">]</span><span class="o">.</span><span class="n">decoder_dist</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">obj_fn</span><span class="o">.</span><span class="n">set_ltype</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">vaes</span><span class="p">[</span><span class="n">mod</span><span class="p">]</span><span class="o">.</span><span class="n">ltype</span><span class="p">)</span>
             <span class="n">lpx_z</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">obj_fn</span><span class="o">.</span><span class="n">recon_loss_fn</span><span class="p">(</span><span class="n">px_z</span><span class="p">,</span> <span class="n">mods</span><span class="p">[</span><span class="n">mod</span><span class="p">])</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">vaes</span><span class="p">[</span><span class="n">mod</span><span class="p">]</span><span class="o">.</span><span class="n">llik_scaling</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
             <span class="n">loc_lpx_z</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">lpx_z</span><span class="p">)</span>
             <span class="k">if</span> <span class="n">mod</span> <span class="o">==</span> <span class="s2">&quot;mod_</span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
                 <span class="n">lpx_zs</span><span class="p">[</span><span class="n">m</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">lpx_z</span><span class="p">)</span>
<span class="hll">         <span class="n">d</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;lpx_z&quot;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">loc_lpx_z</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="s2">&quot;kld&quot;</span><span class="p">:</span> <span class="n">kld</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="s2">&quot;qz_x&quot;</span><span class="p">:</span> <span class="n">output_dic</span><span class="p">[</span><span class="s2">&quot;encoder_dist&quot;</span><span class="p">],</span> <span class="s2">&quot;zs&quot;</span><span class="p">:</span> <span class="n">output_dic</span><span class="p">[</span><span class="s2">&quot;latent_samples&quot;</span><span class="p">],</span> <span class="s2">&quot;pz&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">pz</span><span class="p">,</span> <span class="s2">&quot;pz_params&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">pz_params</span><span class="p">}</span>
</span><span class="hll">         <span class="n">losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">obj_fn</span><span class="o">.</span><span class="n">calculate_loss</span><span class="p">(</span><span class="n">d</span><span class="p">)[</span><span class="s2">&quot;loss&quot;</span><span class="p">])</span>
</span>     <span class="n">ind_losses</span> <span class="o">=</span> <span class="p">[</span><span class="o">-</span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">m</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">vaes</span><span class="p">[</span><span class="s2">&quot;mod_</span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">)]</span><span class="o">.</span><span class="n">llik_scaling</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">lpx_zs</span><span class="p">)]</span>
     <span class="n">obj</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;loss&quot;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(),</span> <span class="s2">&quot;reconstruction_loss&quot;</span><span class="p">:</span> <span class="n">ind_losses</span><span class="p">,</span> <span class="s2">&quot;kld&quot;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">klds</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()}</span>
     <span class="k">return</span> <span class="n">obj</span>
</pre></div>
</div>
<p>In this case, we use the subsampling strategy. We retrieve outputs from the model (line 13) and calculate the reconstruction losses and KL divergences. To calculate the ELBO (or any other objective),
use <code class="docutils literal notranslate"><span class="pre">self.obj_fn</span></code>, which is an instance of <code class="docutils literal notranslate"><span class="pre">MultimodalObjective</span></code> in <code class="docutils literal notranslate"><span class="pre">objectives.py</span></code>. It contains all reconstruction loss terms and objectives such as ELBO or IWAE (more to be added). Using these functions helps
unify the code that should be shared among models.</p>
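<p>The division of labour described above, where each model computes its own terms and a shared objective object turns them into a loss, can be sketched in plain Python. This is a minimal sketch under stated assumptions: <code class="docutils literal notranslate"><span class="pre">ToyElboObjective</span></code>, <code class="docutils literal notranslate"><span class="pre">ToyModel</span></code> and the dictionary keys “lpx_z” and “kld” are hypothetical stand-ins, not the toolkit's <code class="docutils literal notranslate"><span class="pre">MultimodalObjective</span></code> API; only the <code class="docutils literal notranslate"><span class="pre">calculate_loss(d)["loss"]</span></code> call pattern is taken from the code above. Plain floats stand in for tensors, and the β value of 1.5 simply mirrors the example configuration used later on this page.</p>

```python
class ToyElboObjective:
    """Hypothetical stand-in for a shared objective (not the toolkit's MultimodalObjective).

    Assumes each model passes a dict with a log-likelihood estimate "lpx_z"
    and a KL divergence "kld"; the real keys may differ.
    """

    def __init__(self, beta=1.0):
        self.beta = beta

    def calculate_loss(self, d):
        # beta-weighted ELBO = E_q[log p(x|z)] - beta * KL(q(z|x) || p(z));
        # the optimizer minimises its negative.
        elbo = d["lpx_z"] - self.beta * d["kld"]
        return {"loss": -elbo}


class ToyModel:
    """Hypothetical model: computes its own terms, delegates the objective."""

    def __init__(self, obj_fn):
        self.obj_fn = obj_fn

    def objective(self, terms_per_subset):
        # One loss per subset of modalities (as in a subsampling strategy),
        # each computed by the same shared objective object.
        losses = [self.obj_fn.calculate_loss(d)["loss"] for d in terms_per_subset]
        return {"loss": sum(losses)}


model = ToyModel(ToyElboObjective(beta=1.5))
out = model.objective([{"lpx_z": -10.0, "kld": 2.0}, {"lpx_z": -4.0, "kld": 1.0}])
print(out["loss"])
```

<p>In this sketch, switching to a different objective only means passing a different object to <code class="docutils literal notranslate"><span class="pre">ToyModel</span></code>; the model code itself stays unchanged.</p>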
<p>The <code class="docutils literal notranslate"><span class="pre">objective()</span></code> function must return a dictionary that includes the “loss” key, whose value is a scalar torch.Tensor holding the computed loss. This is the value
that the optimizer minimizes. You can also add other arbitrary keys; their values will be logged to TensorBoard automatically.</p>
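<p>The contract on the returned dictionary can be illustrated with a small helper that a training loop might use to separate the value to minimise from the values to log. The helper <code class="docutils literal notranslate"><span class="pre">split_objective_output</span></code> and the “key/mod_N” naming for per-modality entries are hypothetical; the toolkit's own training and logging code may handle this differently. Plain floats are used so the sketch runs without PyTorch; <code class="docutils literal notranslate"><span class="pre">float()</span></code> also accepts single-element tensors.</p>

```python
def split_objective_output(obj):
    """Separate the optimised loss from the extra values to be logged.

    Hypothetical helper: the toolkit's own training loop may differ.
    Works with plain floats and with single-element torch tensors,
    since float() accepts both.
    """
    if "loss" not in obj:
        raise KeyError('objective() must return a dict with a "loss" key')
    loss = obj["loss"]  # kept as-is so a tensor can still be backpropagated
    logs = {}
    for key, value in obj.items():
        if key == "loss":
            continue
        if isinstance(value, (list, tuple)):
            # One value per modality, e.g. "reconstruction_loss/mod_1", ...
            for idx, item in enumerate(value):
                logs["{}/mod_{}".format(key, idx + 1)] = float(item)
        else:
            logs[key] = float(value)
    return loss, logs


loss, logs = split_objective_output(
    {"loss": 18.5, "reconstruction_loss": [12.0, 4.0], "kld": 2.5}
)
print(loss, logs)
```

<p>With PyTorch, one would then call <code class="docutils literal notranslate"><span class="pre">loss.backward()</span></code> before the optimizer step and pass each entry of <code class="docutils literal notranslate"><span class="pre">logs</span></code> to <code class="docutils literal notranslate"><span class="pre">torch.utils.tensorboard.SummaryWriter.add_scalar(tag,</span> <span class="pre">value,</span> <span class="pre">step)</span></code>.</p>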
<p>Finally, we need to add our model to the list of all models in <code class="docutils literal notranslate"><span class="pre">__init__.py</span></code> located in the <code class="docutils literal notranslate"><span class="pre">models</span></code> directory:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">.mmvae_models</span> <span class="kn">import</span> <span class="n">MOE</span> <span class="k">as</span> <span class="n">moe</span>
<span class="hll"><span class="kn">from</span> <span class="nn">.mmvae_models</span> <span class="kn">import</span> <span class="n">POE</span> <span class="k">as</span> <span class="n">poe</span>
</span><span class="kn">from</span> <span class="nn">.mmvae_models</span> <span class="kn">import</span> <span class="n">MoPOE</span> <span class="k">as</span> <span class="n">mopoe</span>
<span class="kn">from</span> <span class="nn">.mmvae_models</span> <span class="kn">import</span> <span class="n">DMVAE</span> <span class="k">as</span> <span class="n">dmvae</span>
<span class="kn">from</span> <span class="nn">.vae</span> <span class="kn">import</span> <span class="n">VAE</span>
<span class="hll"><span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span><span class="n">dmvae</span><span class="p">,</span> <span class="n">moe</span><span class="p">,</span> <span class="n">poe</span><span class="p">,</span> <span class="n">mopoe</span><span class="p">,</span> <span class="n">VAE</span><span class="p">]</span>
</span></pre></div>
</div>
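<p>The lowercase aliases (<code class="docutils literal notranslate"><span class="pre">poe</span></code>, <code class="docutils literal notranslate"><span class="pre">moe</span></code>, and so on) appear to be the names referred to by the <code class="docutils literal notranslate"><span class="pre">mixing</span></code> field of the configuration used later on this page. Assuming the training code resolves that field by name on the <code class="docutils literal notranslate"><span class="pre">models</span></code> package (for example with <code class="docutils literal notranslate"><span class="pre">getattr</span></code>; this has not been checked against the toolkit's source), the lookup amounts to the following sketch. <code class="docutils literal notranslate"><span class="pre">resolve_model_class</span></code> is a hypothetical helper, and a <code class="docutils literal notranslate"><span class="pre">types.SimpleNamespace</span></code> stands in for the real package.</p>

```python
import types


def resolve_model_class(models_module, mixing):
    """Look up a model class by the name used in the config's "mixing" field.

    Assumption: the toolkit resolves names this way; check the actual
    training code before relying on it.
    """
    model_cls = getattr(models_module, mixing, None)
    if model_cls is None:
        # List the public names so a typo or a missing import is easy to spot.
        available = sorted(
            name for name in vars(models_module) if not name.startswith("_")
        )
        raise ValueError(
            "Unknown mixing '{}'; available: {}".format(mixing, ", ".join(available))
        )
    return model_cls


# Stand-ins for the real classes; the lowercase keys mirror the import aliases.
class POE:
    pass


class MOE:
    pass


fake_models = types.SimpleNamespace(poe=POE, moe=MOE)
print(resolve_model_class(fake_models, "poe").__name__)
```

<p>Under this assumption, forgetting to import a new class in <code class="docutils literal notranslate"><span class="pre">__init__.py</span></code> would show up as an “unknown mixing” error at start-up rather than somewhere deep inside training.</p>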
<p>If needed, we can also define model-specific encoder and decoder networks, although the existing ones can be reused.</p>
<p>We can now train with this model. To do so, we create a <code class="docutils literal notranslate"><span class="pre">config.yml</span></code> file in the <code class="docutils literal notranslate"><span class="pre">configs</span></code> directory as follows:</p>
<div class="highlight-yaml notranslate"><div class="highlight"><pre><span></span><span class="nt">batch_size</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">32</span>
<span class="nt">epochs</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">700</span>
<span class="nt">exp_name</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">poe_exp</span>
<span class="nt">labels</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">./data/mnist_svhn/labels.pkl</span>
<span class="nt">lr</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">1e-3</span>
<span class="nt">beta</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">1.5</span>
<span class="hll"><span class="nt">mixing</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">poe</span>
</span><span class="nt">n_latents</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">10</span>
<span class="nt">obj</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">elbo</span>
<span class="nt">optimizer</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">adam</span>
<span class="nt">pre_trained</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">null</span>
<span class="nt">seed</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">2</span>
<span class="nt">viz_freq</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">20</span>
<span class="nt">test_split</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">0.1</span>
<span class="nt">dataset_name</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">mnist_svhn</span>
<span class="nt">modality_1</span><span class="p">:</span>
<span class="w">   </span><span class="nt">decoder</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">MNIST</span>
<span class="w">   </span><span class="nt">encoder</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">MNIST</span>
<span class="w">   </span><span class="nt">mod_type</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">image</span>
<span class="w">   </span><span class="nt">recon_loss</span><span class="p">:</span><span class="w">  </span><span class="l l-Scalar l-Scalar-Plain">bce</span>
<span class="w">   </span><span class="nt">path</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">./data/mnist_svhn/mnist</span>
<span class="nt">modality_2</span><span class="p">:</span>
<span class="w">   </span><span class="nt">decoder</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">SVHN</span>
<span class="w">   </span><span class="nt">encoder</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">SVHN</span>
<span class="w">   </span><span class="nt">recon_loss</span><span class="p">:</span><span class="w">  </span><span class="l l-Scalar l-Scalar-Plain">bce</span>
<span class="w">   </span><span class="nt">mod_type</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">image</span>
<span class="w">   </span><span class="nt">path</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">./data/mnist_svhn/svhn</span>
</pre></div>
</div>
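<p>Before launching a long run, it can help to sanity-check the parsed configuration. The sketch below assumes the file has already been loaded into a dictionary (with PyYAML, for example, by calling <code class="docutils literal notranslate"><span class="pre">yaml.safe_load</span></code> on the open file). <code class="docutils literal notranslate"><span class="pre">check_config</span></code> is a hypothetical helper, and the lists of required keys are taken from the example above rather than from the toolkit, which may require more or fewer.</p>

```python
# Keys taken from the example config; which ones the toolkit actually
# requires is an assumption.
REQUIRED_TOP_LEVEL = ["batch_size", "epochs", "exp_name", "mixing",
                      "n_latents", "obj", "dataset_name"]
REQUIRED_PER_MODALITY = ["encoder", "decoder", "mod_type", "recon_loss", "path"]


def check_config(cfg):
    """Return a list of problems found in a parsed config dict (empty if none)."""
    problems = ["missing top-level key: " + k for k in REQUIRED_TOP_LEVEL if k not in cfg]
    modalities = sorted(k for k in cfg if k.startswith("modality_"))
    if not modalities:
        problems.append("no modality_N sections found")
    for mod in modalities:
        section = cfg[mod]
        if not isinstance(section, dict):
            problems.append("{}: expected a mapping".format(mod))
            continue
        for k in REQUIRED_PER_MODALITY:
            if k not in section:
                problems.append("{}: missing key {}".format(mod, k))
    return problems


example = {
    "batch_size": 32, "epochs": 700, "exp_name": "poe_exp", "mixing": "poe",
    "n_latents": 10, "obj": "elbo", "dataset_name": "mnist_svhn",
    "modality_1": {"encoder": "MNIST", "decoder": "MNIST", "mod_type": "image",
                   "recon_loss": "bce", "path": "./data/mnist_svhn/mnist"},
    # "path" deliberately left out to show a reported problem
    "modality_2": {"encoder": "SVHN", "decoder": "SVHN", "mod_type": "image",
                   "recon_loss": "bce"},
}
print(check_config(example))
```

<p>An empty list only means that every key listed in the sketch is present; it does not guarantee that the toolkit accepts the values.</p>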
<p>Note that the <code class="docutils literal notranslate"><span class="pre">mixing</span></code> field is set to “poe”, which selects our new multimodal mixing model. After configuring the experiment, we can start training:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>cd multimodal-compare
python main.py --cfg ./configs/config.yml
</pre></div>
</div>
</section>
</section>


           </div>
          </div>
          <footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
        <a href="../index.html" class="btn btn-neutral float-left" title="Multimodal VAE Comparison Toolkit" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
        <a href="adddataset.html" class="btn btn-neutral float-right" title="Add a new dataset" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
    </div>

  <hr/>

  <div role="contentinfo">
    <p>&#169; Copyright 2022, Anonymous.</p>
  </div>

   

</footer>
        </div>
      </div>
    </section>
  </div>
  <script>
      jQuery(function () {
          SphinxRtdTheme.Navigation.enable(true);
      });
  </script> 

</body>
</html>