<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <meta charset="utf-8" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />

  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>Add a new dataset &mdash; anonymous-toolkit 1.0 documentation</title>
      <link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
      <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
  <!--[if lt IE 9]>
    <script src="../_static/js/html5shiv.min.js"></script>
  <![endif]-->
  
        <script src="../_static/jquery.js"></script>
        <script src="../_static/_sphinx_javascript_frameworks_compat.js"></script>
        <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
        <script src="../_static/doctools.js"></script>
        <script src="../_static/sphinx_highlight.js"></script>
    <script src="../_static/js/theme.js"></script>
    <link rel="index" title="Index" href="../genindex.html" />
    <link rel="search" title="Search" href="../search.html" />
    <link rel="next" title="MultimodalVAE class" href="../code/trainer.html" />
    <link rel="prev" title="Add a new model" href="addmodel.html" /> 
</head>

<body class="wy-body-for-nav"> 
  <div class="wy-grid-for-nav">
    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search" >

          
          
          <a href="../index.html" class="icon icon-home">
            anonymous-toolkit
          </a>
        </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
              <p class="caption" role="heading"><span class="caption-text">Tutorials</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="addmodel.html">Add a new model</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Add a new dataset</a><ul>
<li class="toctree-l2"><a class="reference internal" href="#supported-data-formats-config">Supported data formats, config</a></li>
<li class="toctree-l2"><a class="reference internal" href="#adding-a-new-dataset-class">Adding a new dataset class</a></li>
<li class="toctree-l2"><a class="reference internal" href="#different-data-formats">Different data formats</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Code documentation</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../code/trainer.html">MultimodalVAE class</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/mmvae_base.html">Multimodal VAE Base Class</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/mmvae_models.html">Multimodal VAE models</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/encoders.html">Encoders</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/decoders.html">Decoders</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/vae.html">VAE class</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/objectives.html">Objectives</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/dataloader.html">DataLoader</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/datasets.html">Dataset Classes</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/infer.html">Inference module</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/eval_cdsprites.html">Evaluate on CdSprites+ dataset</a></li>
<li class="toctree-l1"><a class="reference internal" href="../code/config_cls.html">Config class</a></li>
</ul>

        </div>
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
          <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
          <a href="../index.html">anonymous-toolkit</a>
      </nav>

      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
             
  <section id="add-a-new-dataset">
<h1>Add a new dataset<a class="headerlink" href="#add-a-new-dataset" title="Permalink to this heading"></a></h1>
<p>By default, we support the proposed CdSprites+ dataset as well as the MNIST-SVHN, CelebA, SPRITES, PolyMNIST, FashionMNIST and Caltech-UCSD Birds (CUB) datasets. Here we describe how
to train the models on your own data.</p>
<section id="supported-data-formats-config">
<h2>Supported data formats, config<a class="headerlink" href="#supported-data-formats-config" title="Permalink to this heading"></a></h2>
<p>In general, the preferred data formats (supported by default) are:</p>
<ul class="simple">
<li><p>pickle (<code class="docutils literal notranslate"><span class="pre">.pkl</span></code>)</p></li>
<li><p>PyTorch format (<code class="docutils literal notranslate"><span class="pre">.pth</span></code>, <code class="docutils literal notranslate"><span class="pre">.pt</span></code>)</p></li>
<li><p>NumPy format (<code class="docutils literal notranslate"><span class="pre">.npy</span></code>)</p></li>
<li><p>HDF5 format (<code class="docutils literal notranslate"><span class="pre">.h5</span></code>)</p></li>
<li><p>a directory containing <code class="docutils literal notranslate"><span class="pre">.png</span></code> or <code class="docutils literal notranslate"><span class="pre">.jpg</span></code> images</p></li>
</ul>
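<p>As a minimal sketch of preparing your own files in two of these formats, the snippet below stores a toy image array as <code class="docutils literal notranslate"><span class="pre">.npy</span></code> and a list of captions as <code class="docutils literal notranslate"><span class="pre">.pkl</span></code>, then reloads both. The array shape, the captions and the file names are hypothetical placeholders; the exact array layout expected for your dataset depends on your dataset class.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import pickle
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical toy data: 8 RGB images of 64x64 pixels and 8 matching captions
images = np.zeros((8, 64, 64, 3), dtype=np.uint8)
captions = ["caption number {}".format(i) for i in range(8)]

out_dir = Path(tempfile.mkdtemp())  # stand-in for your own data folder

# NumPy format (.npy) for the image modality
np.save(out_dir / "images.npy", images)

# pickle (.pkl) for the text modality
with open(out_dir / "captions.pkl", "wb") as f:
    pickle.dump(captions, f)

# Reload both files to check that they round-trip unchanged
loaded_images = np.load(out_dir / "images.npy")
with open(out_dir / "captions.pkl", "rb") as f:
    loaded_captions = pickle.load(f)

print(loaded_images.shape, len(loaded_captions))
</pre></div>
</div>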
<p>To train with any of these, specify the path to your data in the <code class="docutils literal notranslate"><span class="pre">config.yml</span></code>:</p>
<div class="highlight-yaml notranslate"><div class="highlight"><pre><span></span><span class="nt">batch_size</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">16</span>
<span class="nt">epochs</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">600</span>
<span class="nt">exp_name</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">cub</span>
<span class="nt">labels</span><span class="p">:</span>
<span class="nt">beta</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">1</span>
<span class="nt">lr</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">1e-3</span>
<span class="nt">mixing</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">moe</span>
<span class="nt">n_latents</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">16</span>
<span class="nt">obj</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">elbo</span>
<span class="nt">optimizer</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">adam</span>
<span class="nt">pre_trained</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">null</span>
<span class="nt">seed</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">2</span>
<span class="nt">viz_freq</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">1</span>
<span class="nt">test_split</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">0.1</span>
<span class="hll"><span class="nt">dataset_name</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">cub</span>
</span><span class="nt">modality_1</span><span class="p">:</span>
<span class="hll"><span class="w">  </span><span class="nt">decoder</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">CNN</span>
</span><span class="hll"><span class="w">  </span><span class="nt">encoder</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">CNN</span>
</span><span class="hll"><span class="w">  </span><span class="nt">mod_type</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">image</span>
</span><span class="hll"><span class="w">  </span><span class="nt">recon_loss</span><span class="p">:</span><span class="w">  </span><span class="l l-Scalar l-Scalar-Plain">bce</span>
</span><span class="hll"><span class="w">  </span><span class="nt">path</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">./data/cub/images</span>
</span><span class="nt">modality_2</span><span class="p">:</span>
<span class="hll"><span class="w">  </span><span class="nt">decoder</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">TxtTransformer</span>
</span><span class="hll"><span class="w">  </span><span class="nt">encoder</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">TxtTransformer</span>
</span><span class="hll"><span class="w">  </span><span class="nt">mod_type</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">text</span>
</span><span class="hll"><span class="w">  </span><span class="nt">recon_loss</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">category_ce</span>
</span><span class="hll"><span class="w">  </span><span class="nt">path</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">./data/cub/cub_captions.pkl</span>
</span></pre></div>
</div>
<p>The above is an example config file for the CUB dataset (see our
README for download instructions).</p>
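<p>Before starting a long training run, it can help to check that every modality section is complete. The following is a rough sketch that works on a plain dict, as a YAML parser would return it for the modality part of the config above. The helper names and the list of required keys are assumptions made for illustration; they are not part of the toolkit.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span># Dict mirroring the modality sections of the CUB config shown above
config = {
    "dataset_name": "cub",
    "modality_1": {"decoder": "CNN", "encoder": "CNN", "mod_type": "image",
                   "recon_loss": "bce", "path": "./data/cub/images"},
    "modality_2": {"decoder": "TxtTransformer", "encoder": "TxtTransformer",
                   "mod_type": "text", "recon_loss": "category_ce",
                   "path": "./data/cub/cub_captions.pkl"},
}

# Assumption: the keys every modality section should provide
REQUIRED_KEYS = ("decoder", "encoder", "mod_type", "recon_loss", "path")


def modality_sections(cfg):
    """Return the modality_N sections of a parsed config, sorted by N."""
    names = [key for key in cfg if key.startswith("modality_")]
    names.sort(key=lambda name: int(name.split("_")[1]))
    return {name: cfg[name] for name in names}


def missing_keys(cfg):
    """Map each incomplete modality section to the keys it lacks."""
    report = {}
    for name, section in modality_sections(cfg).items():
        missing = [key for key in REQUIRED_KEYS if key not in section]
        if missing:
            report[name] = missing
    return report


print(missing_keys(config))
</pre></div>
</div>
<p>An empty result means that no modality section lacks any of the listed keys.</p>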
<p>Here the image modality points to an image folder (<code class="docutils literal notranslate"><span class="pre">./data/cub/images</span></code>) and the text modality to the pickled captions (<code class="docutils literal notranslate"><span class="pre">./data/cub/cub_captions.pkl</span></code>). The samples of both
modalities must be stored in the same order so that they can be semantically matched into pairs (e.g. the first image must match the first caption).</p>
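<p>Because pairing relies purely on sample order, a mismatch in the number of samples is the first thing worth ruling out. The sketch below uses a hypothetical helper, <code class="docutils literal notranslate"><span class="pre">check_paired</span></code>, and toy data in which file names stand in for loaded images. It only compares lengths; it cannot tell whether the i-th caption really describes the i-th image.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>def check_paired(modalities):
    """Raise ValueError unless all modalities hold the same number of samples."""
    lengths = {name: len(samples) for name, samples in modalities.items()}
    if len(set(lengths.values())) != 1:
        raise ValueError("Modalities differ in length: {}".format(lengths))
    return next(iter(lengths.values()))


# Hypothetical toy data, listed in matching order
images = ["bird_000.jpg", "bird_001.jpg", "bird_002.jpg"]
captions = ["a small yellow bird", "a grey bird with a long beak", "a red bird on a branch"]

n_pairs = check_paired({"image": images, "text": captions})

# Print the pairs so that the order can be inspected by eye
for image, caption in zip(images, captions):
    print(image, "|", caption)
</pre></div>
</div>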
</section>
<section id="adding-a-new-dataset-class">
<h2>Adding a new dataset class<a class="headerlink" href="#adding-a-new-dataset-class" title="Permalink to this heading"></a></h2>
<p>To train on your own data, create a custom dataset class in <code class="docutils literal notranslate"><span class="pre">datasets.py</span></code>. Every new dataset class must inherit
from <code class="docutils literal notranslate"><span class="pre">BaseDataset</span></code>, which provides the common methods used by the DataModule.</p>
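<p>The general shape of such a class can be sketched as follows. So that the snippet runs on its own, it uses a hypothetical stand-in base class instead of the real <code class="docutils literal notranslate"><span class="pre">BaseDataset</span></code>; the <code class="docutils literal notranslate"><span class="pre">get_data</span></code> method, the attribute names and the placeholder shapes are assumptions for illustration. Only the constructor arguments and the <code class="docutils literal notranslate"><span class="pre">_mod_specific_loaders</span></code> pattern are taken from the CUB class on this page.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>class StandInBaseDataset:
    """Hypothetical stand-in for BaseDataset so that this sketch runs on its own."""

    def __init__(self, pth, testpth, mod_type):
        self.path = pth
        self.testdata = testpth
        self.mod_type = mod_type

    def _mod_specific_loaders(self):
        # Subclasses return a dict mapping a modality type to a loading method
        raise NotImplementedError

    def get_data(self):
        # Hypothetical common method: dispatch to the loader of this modality type
        return self._mod_specific_loaders()[self.mod_type]()


class MyDataset(StandInBaseDataset):
    """Hypothetical dataset with an image and a text modality."""
    feature_dims = {"image": [32, 32, 3], "text": [10, 27, 1]}  # placeholder shapes

    def _preprocess_images(self):
        # A real implementation would load and reshape the files found at self.path
        return "images from {}".format(self.path)

    def _preprocess_text(self):
        # A real implementation would encode and pad the text found at self.path
        return "text from {}".format(self.path)

    def _mod_specific_loaders(self):
        return {"image": self._preprocess_images, "text": self._preprocess_text}


dataset = MyDataset("./data/mydata/images", None, "image")
print(dataset.get_data())
</pre></div>
</div>
<p>Each modality gets its own instance, and the <code class="docutils literal notranslate"><span class="pre">mod_type</span></code> passed to the constructor selects which loader is used.</p>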
<p>For CUB, the class in <code class="docutils literal notranslate"><span class="pre">datasets.py</span></code> looks like this:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="linenos"> 1</span> <span class="k">class</span> <span class="nc">CUB</span><span class="p">(</span><span class="n">BaseDataset</span><span class="p">):</span>
<span class="linenos"> 2</span><span class="w">     </span><span class="sd">&quot;&quot;&quot;Dataset class for our processed version of Caltech-UCSD birds dataset. We use the original images and text</span>
<span class="linenos"> 3</span><span class="sd">     represented as sequences of one-hot-encodings for each character (incl. spaces)&quot;&quot;&quot;</span>
<span class="linenos"> 4</span>     <span class="n">feature_dims</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="p">[</span><span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span>
<span class="linenos"> 5</span>                     <span class="s2">&quot;text&quot;</span><span class="p">:</span> <span class="p">[</span><span class="mi">246</span><span class="p">,</span> <span class="mi">27</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span>
<span class="linenos"> 6</span>                     <span class="p">}</span>  <span class="c1"># these feature_dims are also used by the encoder and decoder networks</span>
<span class="linenos"> 7</span>
<span class="linenos"> 8</span>     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pth</span><span class="p">,</span> <span class="n">testpth</span><span class="p">,</span> <span class="n">mod_type</span><span class="p">):</span>
<span class="linenos"> 9</span>         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">pth</span><span class="p">,</span> <span class="n">testpth</span><span class="p">,</span> <span class="n">mod_type</span><span class="p">)</span>
<span class="linenos">10</span>         <span class="bp">self</span><span class="o">.</span><span class="n">mod_type</span> <span class="o">=</span> <span class="n">mod_type</span>
<span class="linenos">11</span>         <span class="bp">self</span><span class="o">.</span><span class="n">text2img_size</span> <span class="o">=</span> <span class="p">(</span><span class="mi">64</span><span class="p">,</span><span class="mi">380</span><span class="p">,</span><span class="mi">3</span><span class="p">)</span>
<span class="linenos">12</span>
<span class="linenos">13</span>     <span class="k">def</span> <span class="nf">_preprocess_text_onehot</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="linenos">14</span><span class="w">         </span><span class="sd">&quot;&quot;&quot;</span>
<span class="linenos">15</span><span class="sd">         General function for loading text strings and preparing them as torch one-hot encodings</span>
<span class="linenos">16</span>
<span class="linenos">17</span><span class="sd">         :return: torch with text encodings and masks</span>
<span class="linenos">18</span><span class="sd">         :rtype: torch.tensor</span>
<span class="linenos">19</span><span class="sd">         &quot;&quot;&quot;</span>
<span class="linenos">20</span>         <span class="bp">self</span><span class="o">.</span><span class="n">has_masks</span> <span class="o">=</span> <span class="kc">True</span>
<span class="linenos">21</span>         <span class="bp">self</span><span class="o">.</span><span class="n">categorical</span> <span class="o">=</span> <span class="kc">True</span>
<span class="linenos">22</span>         <span class="n">data</span> <span class="o">=</span> <span class="p">[</span><span class="n">one_hot_encode</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">f</span><span class="p">),</span> <span class="n">f</span><span class="p">)</span> <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_data_raw</span><span class="p">()]</span>
<span class="linenos">23</span>         <span class="n">data</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">data</span><span class="p">]</span>
<span class="linenos">24</span>         <span class="n">masks</span> <span class="o">=</span> <span class="n">lengths_to_mask</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">([</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">data</span><span class="p">])))</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="linenos">25</span>         <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">pad_sequence</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">padding_value</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="linenos">26</span>         <span class="n">data_and_masks</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">data</span><span class="p">,</span> <span class="n">masks</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="linenos">27</span>         <span class="k">return</span> <span class="n">data_and_masks</span>
<span class="linenos">28</span>
<span class="linenos">29</span>     <span class="k">def</span> <span class="nf">_postprocess_text</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
<span class="linenos">30</span>         <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
<span class="linenos">31</span>             <span class="n">masks</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s2">&quot;masks&quot;</span><span class="p">]</span>
<span class="linenos">32</span>             <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s2">&quot;data&quot;</span><span class="p">]</span>
<span class="linenos">33</span>             <span class="n">text</span> <span class="o">=</span> <span class="n">output_onehot2text</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="linenos">34</span>             <span class="k">if</span> <span class="n">masks</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="linenos">35</span>                 <span class="n">masks</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">count_nonzero</span><span class="p">(</span><span class="n">masks</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="linenos">36</span>                 <span class="n">text</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">[:</span><span class="n">masks</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">text</span><span class="p">)]</span>
<span class="linenos">37</span>         <span class="k">else</span><span class="p">:</span>
<span class="linenos">38</span>             <span class="n">text</span> <span class="o">=</span> <span class="n">output_onehot2text</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="linenos">39</span>         <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">phrase</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">text</span><span class="p">):</span>
<span class="linenos">40</span>             <span class="n">phr</span> <span class="o">=</span> <span class="n">phrase</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot; &quot;</span><span class="p">)</span>
<span class="linenos">41</span>             <span class="n">newphr</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">phr</span><span class="p">)</span>
<span class="linenos">42</span>             <span class="n">stringcount</span> <span class="o">=</span> <span class="mi">0</span>
<span class="linenos">43</span>             <span class="k">for</span> <span class="n">x</span><span class="p">,</span> <span class="n">w</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">phr</span><span class="p">):</span>
<span class="linenos">44</span>                 <span class="n">stringcount</span> <span class="o">+=</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">w</span><span class="p">))</span><span class="o">+</span><span class="mi">1</span>
<span class="linenos">45</span>                 <span class="k">if</span> <span class="n">stringcount</span> <span class="o">&gt;</span> <span class="mi">40</span><span class="p">:</span>
<span class="linenos">46</span>                     <span class="n">newphr</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="linenos">47</span>                     <span class="n">stringcount</span> <span class="o">=</span> <span class="mi">0</span>
<span class="linenos">48</span>             <span class="n">text</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="s2">&quot; &quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">newphr</span><span class="p">))</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">  &quot;</span><span class="p">,</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2"> &quot;</span><span class="p">)</span>
<span class="linenos">49</span>         <span class="k">return</span> <span class="n">text</span>
<span class="linenos">50</span>
<span class="linenos">51</span>     <span class="k">def</span> <span class="nf">labels</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="linenos">52</span><span class="w">         </span><span class="sd">&quot;&quot;&quot;</span>
<span class="linenos">53</span><span class="sd">         No labels for T-SNE available</span>
<span class="linenos">54</span><span class="sd">         &quot;&quot;&quot;</span>
<span class="linenos">55</span>         <span class="k">return</span> <span class="kc">None</span>
<span class="linenos">56</span>
<span class="linenos">57</span>     <span class="k">def</span> <span class="nf">_preprocess_text</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="linenos">58</span>         <span class="n">d</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_data_raw</span><span class="p">()</span>
<span class="linenos">59</span>         <span class="bp">self</span><span class="o">.</span><span class="n">has_masks</span> <span class="o">=</span> <span class="kc">True</span>
<span class="linenos">60</span>         <span class="bp">self</span><span class="o">.</span><span class="n">categorical</span> <span class="o">=</span> <span class="kc">True</span>
<span class="linenos">61</span>         <span class="n">data</span> <span class="o">=</span> <span class="p">[</span><span class="n">one_hot_encode</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">f</span><span class="p">),</span> <span class="n">f</span><span class="p">)</span> <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="n">d</span><span class="p">]</span>
<span class="linenos">62</span>         <span class="n">data</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">data</span><span class="p">]</span>
<span class="linenos">63</span>         <span class="n">masks</span> <span class="o">=</span> <span class="n">lengths_to_mask</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">([</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">data</span><span class="p">])))</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="linenos">64</span>         <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">pad_sequence</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">padding_value</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="linenos">65</span>         <span class="n">data_and_masks</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">data</span><span class="p">,</span> <span class="n">masks</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="linenos">66</span>         <span class="k">return</span> <span class="n">data_and_masks</span>
<span class="linenos">67</span>
<span class="linenos">68</span>     <span class="k">def</span> <span class="nf">_preprocess_images</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="linenos">69</span>         <span class="n">d</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_data_raw</span><span class="p">()</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">*</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">feature_dims</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">][</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">2</span><span class="p">,</span><span class="mi">0</span><span class="p">,</span><span class="mi">1</span><span class="p">]])</span>
<span class="linenos">70</span>         <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">d</span><span class="p">)</span>
<span class="linenos">71</span>         <span class="k">return</span> <span class="n">data</span>
<span class="linenos">72</span>
<span class="linenos">73</span>     <span class="k">def</span> <span class="nf">_mod_specific_loaders</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="linenos">74</span>         <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_preprocess_images</span><span class="p">,</span> <span class="s2">&quot;text&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_preprocess_text</span><span class="p">}</span>
<span class="linenos">75</span>
<span class="linenos">76</span>     <span class="k">def</span> <span class="nf">_mod_specific_savers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="linenos">77</span>         <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_postprocess_images</span><span class="p">,</span> <span class="s2">&quot;text&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_postprocess_text</span><span class="p">}</span>
<span class="linenos">78</span>
<span class="linenos">79</span>     <span class="k">def</span> <span class="nf">save_recons</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">recons</span><span class="p">,</span> <span class="n">path</span><span class="p">,</span> <span class="n">mod_names</span><span class="p">):</span>
<span class="linenos">80</span>         <span class="n">output_processed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_postprocess_all2img</span><span class="p">(</span><span class="n">recons</span><span class="p">)</span>
<span class="linenos">81</span>         <span class="n">outs</span> <span class="o">=</span> <span class="n">add_recon_title</span><span class="p">(</span><span class="n">output_processed</span><span class="p">,</span> <span class="s2">&quot;output</span><span class="se">\n</span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mod_type</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">170</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="linenos">82</span>         <span class="n">input_processed</span> <span class="o">=</span> <span class="p">[]</span>
<span class="linenos">83</span>         <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">d</span> <span class="ow">in</span> <span class="n">data</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="linenos">84</span>             <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_mod_specific_savers</span><span class="p">()[</span><span class="n">mod_names</span><span class="p">[</span><span class="n">key</span><span class="p">]](</span><span class="n">d</span><span class="p">)</span>
<span class="linenos">85</span>             <span class="n">images</span> <span class="o">=</span> <span class="n">turn_text2image</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">img_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">text2img_size</span><span class="p">)</span> <span class="k">if</span> <span class="n">mod_names</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;text&quot;</span> \
<span class="linenos">86</span>                 <span class="k">else</span> <span class="n">np</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">output</span><span class="p">,(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">feature_dims</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]))</span>
<span class="linenos">87</span>             <span class="n">images</span> <span class="o">=</span> <span class="n">add_recon_title</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="s2">&quot;input</span><span class="se">\n</span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">mod_names</span><span class="p">[</span><span class="n">key</span><span class="p">]),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">))</span>
<span class="linenos">88</span>             <span class="n">input_processed</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">vstack</span><span class="p">(</span><span class="n">images</span><span class="p">))</span>
<span class="linenos">89</span>             <span class="n">input_processed</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">vstack</span><span class="p">(</span><span class="n">images</span><span class="p">)</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span><span class="o">*</span><span class="mi">125</span><span class="p">)</span>
<span class="linenos">90</span>         <span class="n">inputs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">(</span><span class="n">input_processed</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;uint8&quot;</span><span class="p">)</span>
<span class="linenos">91</span>         <span class="n">final</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">((</span><span class="n">inputs</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">vstack</span><span class="p">(</span><span class="n">outs</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;uint8&quot;</span><span class="p">)))</span>
<span class="linenos">92</span>         <span class="n">cv2</span><span class="o">.</span><span class="n">imwrite</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">cv2</span><span class="o">.</span><span class="n">cvtColor</span><span class="p">(</span><span class="n">final</span><span class="p">,</span> <span class="n">cv2</span><span class="o">.</span><span class="n">COLOR_BGR2RGB</span><span class="p">))</span>
</pre></div>
</div>
<p>Even though the dataset is multimodal, a new instance of the class is created for each modality. The constructor therefore
takes the path to the data of the given modality (str), optionally the path to the test data (used for evaluation after training), and the modality type (str). The modality type is any string
that you assign to the given modality to distinguish it from the others. For CUB we chose “image” for images and “text” for text; for MNIST_SVHN
we use “mnist” and “svhn”. You specify mod_type in the config.
You also need to specify the expected shape of the data in the class attribute “feature_dims”. The dataset class uses it to postprocess the data (i.e. the reconstructions produced by the model), and the encoder and decoder networks use it to adjust the sizes of their layers.</p>
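<p>As a rough illustration of how <code class="docutils literal notranslate"><span class="pre">feature_dims</span></code> is used during postprocessing, the sketch below reshapes flat outputs back into images in the same way as the <code class="docutils literal notranslate"><span class="pre">np.reshape</span></code> call in <code class="docutils literal notranslate"><span class="pre">save_recons()</span></code>. The 64x64x3 shape and the batch of eight samples are assumed values chosen only for this example.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
import numpy as np

# Assumed shape for the "image" modality (height, width, channels).
feature_dims = {"image": [64, 64, 3]}

# Eight flattened samples, standing in for reconstructions produced by a model.
flat_outputs = np.zeros((8, 64 * 64 * 3))

# -1 lets NumPy infer the batch size; the remaining axes come from feature_dims.
images = np.reshape(flat_outputs, (-1, *feature_dims["image"]))
print(images.shape)
</pre></div>
</div>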
<p>Next, you need methods that prepare each modality for training (<code class="docutils literal notranslate"><span class="pre">_preprocess_text</span></code> and <code class="docutils literal notranslate"><span class="pre">_preprocess_images</span></code>). Here we reuse <code class="docutils literal notranslate"><span class="pre">_preprocess_images</span></code> from CdSprites+, since the image format is the same, and only rewrite <code class="docutils literal notranslate"><span class="pre">_preprocess_text</span></code>. Data loading is handled automatically by BaseDataset, so these methods
only need to reshape the data, convert it to tensors etc., so that they return tensors of the same length.
Note: For sequential data (like the text here), we create boolean masks and concatenate them to the last dimension of the text data. The collate function then handles them automatically.</p>
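<p>The masking idea from the note can be sketched as follows. This is a minimal illustration, not the toolkit's implementation: the function name <code class="docutils literal notranslate"><span class="pre">one_hot_with_mask</span></code>, the character-level one-hot encoding and the three-letter alphabet are assumptions, and NumPy arrays are used instead of tensors to keep the example small.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
import numpy as np


def one_hot_with_mask(sentences, alphabet):
    """Pad one-hot encoded sentences and append a validity mask as the last feature."""
    index = {ch: i for i, ch in enumerate(alphabet)}
    max_len = max(len(s) for s in sentences)
    data = np.zeros((len(sentences), max_len, len(alphabet)), dtype=np.float32)
    mask = np.zeros((len(sentences), max_len, 1), dtype=np.float32)
    for row, sentence in enumerate(sentences):
        for pos, ch in enumerate(sentence):
            data[row, pos, index[ch]] = 1.0
        # Mark real (non-padding) positions with 1.
        mask[row, :len(sentence), 0] = 1.0
    # The mask becomes one extra column in the last dimension.
    return np.concatenate([data, mask], axis=-1)


batch = one_hot_with_mask(["ab", "abc"], alphabet="abc")
print(batch.shape)
</pre></div>
</div>
<p>In this sketch the last column of every time step is 1 for real characters and 0 for padding, so code that receives the batch can separate the data from the mask by slicing the last dimension.</p>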
<p>Another thing we need to do is map the data processing functions to the modality types, i.e. define <code class="docutils literal notranslate"><span class="pre">_mod_specific_loaders()</span></code> and <code class="docutils literal notranslate"><span class="pre">_mod_specific_savers()</span></code>:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">_mod_specific_loaders</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_preprocess_images</span><span class="p">,</span> <span class="s2">&quot;text&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_preprocess_text</span><span class="p">}</span>

<span class="k">def</span> <span class="nf">_mod_specific_savers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_postprocess_images</span><span class="p">,</span> <span class="s2">&quot;text&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_postprocess_text</span><span class="p">}</span>
</pre></div>
</div>
<p>Here we just assign the above-mentioned methods to the selected mod_types. Once this is done, the dataset class should be ready and you can launch training.</p>
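<p>The lookup itself can be sketched in isolation. <code class="docutils literal notranslate"><span class="pre">ToyDataset</span></code>, its <code class="docutils literal notranslate"><span class="pre">get_data()</span></code> method and the two toy preprocessing functions are hypothetical stand-ins, not part of the toolkit; only the dictionary lookup by modality type mirrors the call used in <code class="docutils literal notranslate"><span class="pre">save_recons()</span></code>.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
class ToyDataset:
    """Hypothetical stand-in for a dataset class; not the toolkit's BaseDataset."""

    def __init__(self, mod_type):
        self.mod_type = mod_type

    def _preprocess_images(self, data):
        # Toy preprocessing: scale pixel values to the range 0..1.
        return [float(value) / 255 for value in data]

    def _preprocess_text(self, data):
        # Toy preprocessing: lowercase every sentence.
        return [sentence.lower() for sentence in data]

    def _mod_specific_loaders(self):
        return {"image": self._preprocess_images, "text": self._preprocess_text}

    def get_data(self, raw):
        # Pick the function registered for this instance's modality type.
        return self._mod_specific_loaders()[self.mod_type](raw)


text_data = ToyDataset("text").get_data(["A Bird", "RED wings"])
image_data = ToyDataset("image").get_data([0, 255])
</pre></div>
</div>
<p>Because each instance handles one modality, the string passed as the modality type must match one of the dictionary keys; an unknown string would raise a <code class="docutils literal notranslate"><span class="pre">KeyError</span></code> in this sketch.</p>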
<p>Finally, we can configure how the outputs are saved for visualization. This can be data-dependent; the <code class="docutils literal notranslate"><span class="pre">save_recons()</span></code> method shown in the example is suited
for placing images and text next to each other in one image. The <code class="docutils literal notranslate"><span class="pre">_postprocess_all2img()</span></code> method renders the text as an image of the size <code class="docutils literal notranslate"><span class="pre">self.text2img_size</span></code>
(defined in __init__, see Line 11).</p>
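<p>The layout that <code class="docutils literal notranslate"><span class="pre">save_recons()</span></code> builds (samples stacked vertically in one column per modality, columns joined horizontally with a thin gray separator) can be sketched with plain NumPy. The 4x4 image size and the three samples per column are assumed values for illustration.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
import numpy as np

# Three hypothetical 4x4 RGB samples per column: white inputs, black outputs.
input_images = [np.full((4, 4, 3), 255, dtype=np.uint8) for _ in range(3)]
output_images = [np.zeros((4, 4, 3), dtype=np.uint8) for _ in range(3)]

# Stack the samples of each column on top of each other.
input_column = np.vstack(input_images)
output_column = np.vstack(output_images)

# Gray separator, two pixels wide and as tall as a column.
separator = np.full((input_column.shape[0], 2, 3), 125, dtype=np.uint8)

# Join the columns side by side into one image.
grid = np.hstack((input_column, separator, output_column))
print(grid.shape)
</pre></div>
</div>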
</section>
<section id="different-data-formats">
<h2>Different data formats<a class="headerlink" href="#different-data-formats" title="Permalink to this heading"></a></h2>
<p>If you want to train on an unsupported data format, you can file an issue on our GitHub repository.
Alternatively, you can try to incorporate it yourself, as it is only a matter of adjusting one function in <code class="docutils literal notranslate"><span class="pre">utils.py</span></code>:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">load_data</span><span class="p">(</span><span class="n">path</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Returns loaded data based on path suffix</span>
<span class="sd">    :param path: Path to data</span>
<span class="sd">    :type path: str</span>
<span class="sd">    :return: loaded data</span>
<span class="sd">    :rtype: object</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">if</span> <span class="n">path</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">&#39;.&#39;</span><span class="p">):</span>
        <span class="n">path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">get_root_folder</span><span class="p">(),</span> <span class="n">path</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">path</span><span class="p">),</span> <span class="s2">&quot;Path does not exist: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isdir</span><span class="p">(</span><span class="n">path</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">load_images</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="n">path</span><span class="p">)</span><span class="o">.</span><span class="n">suffix</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;.pt&quot;</span><span class="p">,</span><span class="s2">&quot;.pth&quot;</span><span class="p">]:</span>
        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="n">path</span><span class="p">)</span><span class="o">.</span><span class="n">suffix</span> <span class="o">==</span> <span class="s2">&quot;.pkl&quot;</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">load_pickle</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="n">path</span><span class="p">)</span><span class="o">.</span><span class="n">suffix</span> <span class="o">==</span> <span class="s2">&quot;.h5&quot;</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">h5py</span><span class="o">.</span><span class="n">File</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s1">&#39;r&#39;</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="n">path</span><span class="p">)</span><span class="o">.</span><span class="n">suffix</span> <span class="o">==</span> <span class="s2">&quot;.npy&quot;</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
    <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s2">&quot;Unrecognized dataset format. Supported types are: .pt, .pth, .pkl, .h5, .npy or a directory with images&quot;</span><span class="p">)</span>
</pre></div>
</div>
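<p>Adding a format then amounts to one more suffix check. The sketch below shows the idea for JSON files as a standalone example; <code class="docutils literal notranslate"><span class="pre">load_json</span></code>, <code class="docutils literal notranslate"><span class="pre">EXTRA_LOADERS</span></code> and <code class="docutils literal notranslate"><span class="pre">load_extra</span></code> are hypothetical names that do not exist in the toolkit. In <code class="docutils literal notranslate"><span class="pre">load_data()</span></code> itself, the equivalent change would presumably be another suffix branch placed before the final <code class="docutils literal notranslate"><span class="pre">raise</span></code>.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
import json
import pathlib
import tempfile


def load_json(path):
    """Hypothetical loader for a format that load_data() does not handle."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


# Hypothetical suffix-to-loader table; a new format becomes one more entry.
EXTRA_LOADERS = {".json": load_json}


def load_extra(path):
    """Return loaded data based on the path suffix, or fail for unknown suffixes."""
    suffix = pathlib.Path(path).suffix
    if suffix not in EXTRA_LOADERS:
        raise ValueError("Unrecognized dataset format: {}".format(suffix))
    return EXTRA_LOADERS[suffix](path)


# Write a small sample file and load it back.
with tempfile.TemporaryDirectory() as folder:
    sample = pathlib.Path(folder) / "captions.json"
    sample.write_text(json.dumps(["a red bird", "a blue bird"]), encoding="utf-8")
    captions = load_extra(str(sample))
</pre></div>
</div>
<p>Whatever the loader returns still has to be turned into tensors by the modality-specific preprocessing method of your dataset class.</p>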
<p>Please note that by default, we have incorporated encoders and decoders for images (preferably in 32x32x3 or 64x64x3 resolution, or 28x28x1 pixels for MNIST),
text data (arbitrary strings, which we encode at the character level) and sequential data (e.g. actions suitable for a Transformer network). If you add a new data structure or image resolution,
you will also need to add or adjust the encoder and decoder networks; you can then specify these in the config file.</p>
</section>
</section>


           </div>
          </div>
          <footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
        <a href="addmodel.html" class="btn btn-neutral float-left" title="Add a new model" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
        <a href="../code/trainer.html" class="btn btn-neutral float-right" title="MultimodalVAE class" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
    </div>

  <hr/>

  <div role="contentinfo">
    <p>&#169; Copyright 2022, Anonymous.</p>
  </div>

  Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    provided by <a href="https://readthedocs.org">Read the Docs</a>.
   

</footer>
        </div>
      </div>
    </section>
  </div>
  <script>
      jQuery(function () {
          SphinxRtdTheme.Navigation.enable(true);
      });
  </script> 

</body>
</html>