

<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <meta name="description" content="Source code listing for the deeprobust.image.utils module in the DeepRobust 0.1.1 documentation.">
  <title>deeprobust.image.utils &mdash; DeepRobust 0.1.1 documentation</title>

  <link rel="stylesheet" href="../../../_static/css/theme.css">
  <link rel="stylesheet" href="../../../_static/pygments.css">

  <!--[if lt IE 9]>
    <script src="../../../_static/js/html5shiv.min.js"></script>
  <![endif]-->

  <!--
    Scripts stay synchronous and in this order on purpose: the Sphinx helpers
    after jquery.js presumably depend on it, and markup outside this view may
    call them inline. Verify both before adding defer.
  -->
  <script id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
  <script src="../../../_static/jquery.js"></script>
  <script src="../../../_static/underscore.js"></script>
  <script src="../../../_static/doctools.js"></script>
  <script src="../../../_static/language_data.js"></script>
  <!--
    NOTE(review): MathJax 2.x "latest.js" loads the newest 2.x release from the
    CDN rather than the 2.7.5 named in the path; load MathJax.js instead to pin
    the version, and add an SRI integrity hash for this third-party script.
  -->
  <script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
  <script src="../../../_static/js/theme.js"></script>

  <link rel="index" title="Index" href="../../../genindex.html">
  <link rel="search" title="Search" href="../../../search.html">
</head>

<body class="wy-body-for-nav">

   
  <div class="wy-grid-for-nav">
    
    <nav data-toggle="wy-nav-shift" class="wy-nav-side" aria-label="main navigation">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search">
          <a href="../../../index.html" class="icon icon-home"> DeepRobust</a>

          <div role="search">
            <form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
              <!--
                A placeholder is not a label, so aria-label supplies the accessible name.
                type stays "text": the theme stylesheet presumably targets input[type=text] here.
              -->
              <input type="text" name="q" placeholder="Search docs" aria-label="Search docs">
              <input type="hidden" name="check_keywords" value="yes">
              <input type="hidden" name="area" value="default">
            </form>
          </div>
        </div>

        <!--
          The navigation landmark is the labelled <nav> above; this menu carries no
          role of its own so screen readers do not announce two nested landmarks.
          Captions presumably come from the Sphinx toctree; the API sections are
          suffixed "API" so they are distinguishable from the tutorial sections.
        -->
        <div class="wy-menu wy-menu-vertical" data-spy="affix">
          <p class="caption"><span class="caption-text">Installation</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/installation.html">Installation</a></li>
          </ul>
          <p class="caption"><span class="caption-text">Graph Package</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../graph/data.html">Graph Dataset</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../graph/attack.html">Introduction to Graph Attack with Examples</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../graph/defense.html">Introduction to Graph Defense with Examples</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../graph/pyg.html">Using PyTorch Geometric in DeepRobust</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../graph/node_embedding.html">Node Embedding Attack and Defense</a></li>
          </ul>
          <p class="caption"><span class="caption-text">Image Package</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../image/example.html">Image Attack and Defense</a></li>
          </ul>
          <p class="caption"><span class="caption-text">Image Package API</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../source/deeprobust.image.attack.html">deeprobust.image.attack package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../source/deeprobust.image.defense.html">deeprobust.image.defense package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../source/deeprobust.image.netmodels.html">deeprobust.image.netmodels package</a></li>
          </ul>
          <p class="caption"><span class="caption-text">Graph Package API</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../source/deeprobust.graph.global_attack.html">deeprobust.graph.global_attack package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../source/deeprobust.graph.targeted_attack.html">deeprobust.graph.targeted_attack package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../source/deeprobust.graph.defense.html">deeprobust.graph.defense package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../source/deeprobust.graph.data.html">deeprobust.graph.data package</a></li>
          </ul>
        </div>
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">

      
      <nav class="wy-nav-top" aria-label="top navigation">
        <!--
          Small-screen header: the bars icon toggles the sidebar (presumably bound by
          theme.js through its data-toggle attribute) and the link returns to the docs home.
          NOTE(review): an <i> is not keyboard-focusable and has no accessible name, so
          keyboard and screen-reader users cannot open the menu. A <button type="button">
          with an aria-label would fix this but needs matching theme CSS, so it is unchanged.
        -->
        <i class="fa fa-bars" data-toggle="wy-nav-top"></i>
        <a href="../../../index.html">DeepRobust</a>
      </nav>


      <div class="wy-nav-content">
        
        <div class="rst-content">
        
          















<div role="navigation" aria-label="breadcrumbs navigation">
  <ul class="wy-breadcrumbs">
    <!--
      The home link has no text content (its icon comes from CSS classes), so it needs
      an aria-label. The separators are decorative and hidden from assistive technology.
    -->
    <li><a href="../../../index.html" class="icon icon-home" aria-label="Documentation Home"></a> <span aria-hidden="true">&raquo;</span></li>
    <li><a href="../../index.html">Module code</a> <span aria-hidden="true">&raquo;</span></li>
    <li aria-current="page">deeprobust.image.utils</li>
    <li class="wy-breadcrumbs-aside"></li>
  </ul>
  <hr>
</div>
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
            
  <h1>Source code for deeprobust.image.utils</h1><div class="highlight"><pre>
<span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torchvision</span>
<span class="kn">import</span> <span class="nn">torchvision.transforms</span> <span class="k">as</span> <span class="nn">transforms</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">urllib.request</span>

<span class="kn">import</span> <span class="nn">os</span>

<!-- REVIEW (fix in the deeprobust.image.utils module; this page only mirrors its source): despite the docstring "Create different training dataset", this always builds the MNIST training split (train=True) with only ToTensor() and returns a shuffled DataLoader with num_workers=2. download=True lets torchvision fetch MNIST into root. --><div class="viewcode-block" id="create_train_dataset"><a class="viewcode-back" href="../../../source/deeprobust.image.html#deeprobust.image.utils.create_train_dataset">[docs]</a><span class="k">def</span> <span class="nf">create_train_dataset</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">128</span><span class="p">,</span> <span class="n">root</span> <span class="o">=</span> <span class="s1">&#39;../data&#39;</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Create different training dataset</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="n">transform_train</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
    <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
    <span class="p">])</span>
    <span class="n">trainset</span> <span class="o">=</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="n">root</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transform_train</span><span class="p">)</span>
    <span class="n">trainloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">trainset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">trainloader</span></div>

<!-- REVIEW: builds the MNIST test split (train=False) with only ToTensor() and returns an unshuffled DataLoader with num_workers=2. It has no docstring, which is presumably why it has no [docs] link, unlike create_train_dataset. --><span class="k">def</span> <span class="nf">create_test_dataset</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">128</span><span class="p">,</span> <span class="n">root</span> <span class="o">=</span> <span class="s1">&#39;../data&#39;</span><span class="p">):</span>
    <span class="n">transform_test</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
    <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
    <span class="p">])</span>
    <span class="n">testset</span> <span class="o">=</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="n">root</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transform_test</span><span class="p">)</span>
    <span class="n">testloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">testset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">testloader</span>

<!-- REVIEW: downloads url to the local path file with urllib.request.urlretrieve. The bare "except:" also catches KeyboardInterrupt/SystemExit and discards the original error; prefer "except Exception as e: raise Exception(...) from e". The log message misspells "Downloading" as "Dowloading". --><span class="k">def</span> <span class="nf">download_model</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">file</span><span class="p">):</span>
    <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Dowloading from </span><span class="si">{}</span><span class="s1"> to </span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">file</span><span class="p">))</span>
    <span class="k">try</span><span class="p">:</span>
        <span class="n">urllib</span><span class="o">.</span><span class="n">request</span><span class="o">.</span><span class="n">urlretrieve</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">file</span><span class="p">)</span>
    <span class="k">except</span><span class="p">:</span>
        <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s2">&quot;Download failed! Make sure you have stable Internet connection and enter the right name&quot;</span><span class="p">)</span>

<!-- REVIEW: torch.save()s a dict with keys epoch, state_dict, optimizer_state_dict and lr_scheduler_state_dict to file_name, printing a notice if the path already exists. lr_scheduler is mandatory because its state_dict() is always called. The trailing commented-out make_symlink lines are dead code. --><span class="k">def</span> <span class="nf">save_checkpoint</span><span class="p">(</span><span class="n">now_epoch</span><span class="p">,</span> <span class="n">net</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">lr_scheduler</span><span class="p">,</span> <span class="n">file_name</span><span class="p">):</span>
    <span class="n">checkpoint</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;epoch&#39;</span><span class="p">:</span> <span class="n">now_epoch</span><span class="p">,</span>
                  <span class="s1">&#39;state_dict&#39;</span><span class="p">:</span> <span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span>
                  <span class="s1">&#39;optimizer_state_dict&#39;</span><span class="p">:</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span>
                  <span class="s1">&#39;lr_scheduler_state_dict&#39;</span><span class="p">:</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()}</span>
    <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">file_name</span><span class="p">):</span>
        <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Overwriting </span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">file_name</span><span class="p">))</span>
    <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">checkpoint</span><span class="p">,</span> <span class="n">file_name</span><span class="p">)</span>
    <span class="c1"># link_name = os.path.join(*file_name.split(os.path.sep)[:-1], &#39;last.checkpoint&#39;)</span>
    <span class="c1"># #print(link_name)</span>
    <span class="c1"># make_symlink(source = file_name, link_name=link_name)</span>

<!-- REVIEW: counterpart of save_checkpoint. It loads file_name, restores only the objects passed in (net, optimizer, lr_scheduler) from the matching keys, and returns the saved epoch. A missing file only prints a message and implicitly returns None, so callers must check for None. torch.load is called without map_location, so a checkpoint containing CUDA tensors may fail to load on a machine without a GPU. --><span class="k">def</span> <span class="nf">load_checkpoint</span><span class="p">(</span><span class="n">file_name</span><span class="p">,</span> <span class="n">net</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">lr_scheduler</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isfile</span><span class="p">(</span><span class="n">file_name</span><span class="p">):</span>
        <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;=&gt; loading checkpoint &#39;</span><span class="si">{}</span><span class="s2">&#39;&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">file_name</span><span class="p">))</span>
        <span class="n">check_point</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">file_name</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">net</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Loading network state dict&#39;</span><span class="p">)</span>
            <span class="n">net</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">check_point</span><span class="p">[</span><span class="s1">&#39;state_dict&#39;</span><span class="p">])</span>
        <span class="k">if</span> <span class="n">optimizer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Loading optimizer state dict&#39;</span><span class="p">)</span>
            <span class="n">optimizer</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">check_point</span><span class="p">[</span><span class="s1">&#39;optimizer_state_dict&#39;</span><span class="p">])</span>
        <span class="k">if</span> <span class="n">lr_scheduler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Loading lr_scheduler state dict&#39;</span><span class="p">)</span>
            <span class="n">lr_scheduler</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">check_point</span><span class="p">[</span><span class="s1">&#39;lr_scheduler_state_dict&#39;</span><span class="p">])</span>

        <span class="k">return</span> <span class="n">check_point</span><span class="p">[</span><span class="s1">&#39;epoch&#39;</span><span class="p">]</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;=&gt; no checkpoint found at &#39;</span><span class="si">{}</span><span class="s2">&#39;&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">file_name</span><span class="p">))</span>

<!-- REVIEW: deletes whatever exists at link_name, then symlinks link_name to source if source exists. Edge cases: (1) os.path.exists is False for a dangling symlink, so a broken link at link_name is not removed and os.symlink then raises FileExistsError (os.path.lexists would catch it); (2) os.remove fails if link_name is a directory; (3) when source is missing the old link has already been deleted and nothing replaces it. --><div class="viewcode-block" id="make_symlink"><a class="viewcode-back" href="../../../source/deeprobust.image.html#deeprobust.image.utils.make_symlink">[docs]</a><span class="k">def</span> <span class="nf">make_symlink</span><span class="p">(</span><span class="n">source</span><span class="p">,</span> <span class="n">link_name</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Note: overwriting enabled!</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">link_name</span><span class="p">):</span>
        <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Link name already exist! Removing &#39;</span><span class="si">{}</span><span class="s2">&#39; and overwriting&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">link_name</span><span class="p">))</span>
        <span class="n">os</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">link_name</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">source</span><span class="p">):</span>
        <span class="n">os</span><span class="o">.</span><span class="n">symlink</span><span class="p">(</span><span class="n">source</span><span class="p">,</span> <span class="n">link_name</span><span class="p">)</span>
        <span class="k">return</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Source path not exists&#39;</span><span class="p">)</span></div>

<!-- REVIEW: this module level import makes the third-party texttable package a hard requirement for importing deeprobust.image.utils, although only tab_printer uses it; consider importing it inside the function. tab_printer calls vars(args), so args must have a __dict__ (presumably an argparse Namespace; confirm with callers). It prints a Parameter/Value table with keys sorted, underscores replaced by spaces, and the first letter capitalized. --><span class="kn">from</span> <span class="nn">texttable</span> <span class="kn">import</span> <span class="n">Texttable</span>
<div class="viewcode-block" id="tab_printer"><a class="viewcode-back" href="../../../source/deeprobust.image.html#deeprobust.image.utils.tab_printer">[docs]</a><span class="k">def</span> <span class="nf">tab_printer</span><span class="p">(</span><span class="n">args</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Function to print the logs in a nice tabular format.</span>
<span class="sd">    input:</span>
<span class="sd">        param args: Parameters used for the model.</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">args</span> <span class="o">=</span> <span class="nb">vars</span><span class="p">(</span><span class="n">args</span><span class="p">)</span>
    <span class="n">keys</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">args</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
    <span class="n">t</span> <span class="o">=</span> <span class="n">Texttable</span><span class="p">()</span>
    <span class="n">t</span><span class="o">.</span><span class="n">add_rows</span><span class="p">([[</span><span class="s2">&quot;Parameter&quot;</span><span class="p">,</span> <span class="s2">&quot;Value&quot;</span><span class="p">]]</span> <span class="o">+</span>  <span class="p">[[</span><span class="n">k</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;_&quot;</span><span class="p">,</span><span class="s2">&quot; &quot;</span><span class="p">)</span><span class="o">.</span><span class="n">capitalize</span><span class="p">(),</span> <span class="n">args</span><span class="p">[</span><span class="n">k</span><span class="p">]]</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">keys</span><span class="p">])</span>
    <span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">draw</span><span class="p">())</span></div>

<!-- REVIEW: returns np.zeros_like(a) with x[index] = value, so the result keeps a's shape and dtype; a float value written into an integer array is truncated. For an N-D a, an int index fills a whole sub-array rather than one element, so pass a tuple index for a true one-hot entry. The "#TODO: change the note here." line is a stale leftover. --><div class="viewcode-block" id="onehot_like"><a class="viewcode-back" href="../../../source/deeprobust.image.html#deeprobust.image.utils.onehot_like">[docs]</a><span class="k">def</span> <span class="nf">onehot_like</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">index</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Creates an array like a, with all values</span>
<span class="sd">    set to 0 except one.</span>
<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    a : array_like</span>
<span class="sd">        The returned one-hot array will have the same shape</span>
<span class="sd">        and dtype as this array</span>
<span class="sd">    index : int</span>
<span class="sd">        The index that should be set to `value`</span>
<span class="sd">    value : single value compatible with a.dtype</span>
<span class="sd">        The value to set at the given index</span>
<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    `numpy.ndarray`</span>
<span class="sd">        One-hot array with the given value at the given</span>
<span class="sd">        location and zeros everywhere else.</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="c1">#TODO: change the note here.</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
    <span class="n">x</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span>
    <span class="k">return</span> <span class="n">x</span></div>

<!-- REVIEW: sums x over every dimension except dim 0, forwarding keepdim. The dims are visited from last to first so that, with keepdim=False, dropping a dim does not shift the indices still to be reduced. Inputs with dim() of 0 or 1 are returned unchanged. --><span class="k">def</span> <span class="nf">reduce_sum</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
    <span class="c1"># silly PyTorch, when will you get proper reducing sums/means?</span>
    <span class="k">for</span> <span class="n">a</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">dim</span><span class="p">())):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">x</span>

<!-- REVIEW: "x *= (1. - eps)" scales the argument in place, so an array or tensor passed by the caller is silently modified, and an integer numpy array raises a casting error; use "x = x * (1. - eps)" instead. The (1 - eps) shrink keeps inputs from [-1, 1] strictly inside (-1, 1) so the log stays finite. np.log is used, so this presumably targets numpy arrays; confirm before passing torch tensors. --><div class="viewcode-block" id="arctanh"><a class="viewcode-back" href="../../../source/deeprobust.image.html#deeprobust.image.utils.arctanh">[docs]</a><span class="k">def</span> <span class="nf">arctanh</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Calculate arctanh(x)</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">x</span> <span class="o">*=</span> <span class="p">(</span><span class="mf">1.</span> <span class="o">-</span> <span class="n">eps</span><span class="p">)</span>
    <span class="k">return</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">((</span><span class="mi">1</span> <span class="o">+</span> <span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">x</span><span class="p">)))</span> <span class="o">*</span> <span class="mf">0.5</span></div>

<!-- REVIEW: Euclidean distance sqrt(sum((x - y)**2) + eps), summed via reduce_sum over all dims except dim 0 (presumably the batch dim). eps keeps the sqrt gradient finite when x equals y. --><span class="k">def</span> <span class="nf">l2r_dist</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-8</span><span class="p">):</span>
    <span class="n">d</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">y</span><span class="p">)</span><span class="o">**</span><span class="mi">2</span>
    <span class="n">d</span> <span class="o">=</span> <span class="n">reduce_sum</span><span class="p">(</span><span class="n">d</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span>
    <span class="n">d</span> <span class="o">+=</span> <span class="n">eps</span>  <span class="c1"># to prevent infinite gradient at 0</span>
    <span class="k">return</span> <span class="n">d</span><span class="o">.</span><span class="n">sqrt</span><span class="p">()</span>


<!-- REVIEW: despite its name this returns the SQUARED L2 distance (no sqrt), summed over all dims except dim 0; l2r_dist is the rooted version. --><span class="k">def</span> <span class="nf">l2_dist</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
    <span class="n">d</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">y</span><span class="p">)</span><span class="o">**</span><span class="mi">2</span>
    <span class="k">return</span> <span class="n">reduce_sum</span><span class="p">(</span><span class="n">d</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span>


<!-- REVIEW: L1 distance, the sum of torch.abs(x - y) over all dims except dim 0 (reduce_sum keeps dim 0). --><span class="k">def</span> <span class="nf">l1_dist</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
    <span class="n">d</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">y</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">reduce_sum</span><span class="p">(</span><span class="n">d</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span>


<!-- REVIEW: L2 norm, sqrt of the sum of x*x over all dims except dim 0. Unlike l2r_dist, no eps is added before sqrt, so an all-zero row may produce NaN gradients. --><span class="k">def</span> <span class="nf">l2_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
    <span class="n">norm</span> <span class="o">=</span> <span class="n">reduce_sum</span><span class="p">(</span><span class="n">x</span><span class="o">*</span><span class="n">x</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">norm</span><span class="o">.</span><span class="n">sqrt</span><span class="p">()</span>


<!-- REVIEW: L1 norm, the sum of x.abs() over all dims except dim 0, with keepdim forwarded to reduce_sum. --><span class="k">def</span> <span class="nf">l1_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">reduce_sum</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">abs</span><span class="p">(),</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span>

<!-- REVIEW: hard-coded step schedule. lr equals learning_rate before epoch 55, then learning_rate times 0.1 from epoch 55, times 0.01 from 75 and times 0.001 from 90. It overwrites "lr" in every param group and returns the same optimizer object. The milestones and factors could become keyword parameters. --><div class="viewcode-block" id="adjust_learning_rate"><a class="viewcode-back" href="../../../source/deeprobust.image.html#deeprobust.image.utils.adjust_learning_rate">[docs]</a><span class="k">def</span> <span class="nf">adjust_learning_rate</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">learning_rate</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;decrease the learning rate&quot;&quot;&quot;</span>
    <span class="n">lr</span> <span class="o">=</span> <span class="n">learning_rate</span>
    <span class="k">if</span> <span class="n">epoch</span> <span class="o">&gt;=</span> <span class="mi">55</span><span class="p">:</span>
        <span class="n">lr</span> <span class="o">=</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="mf">0.1</span>
    <span class="k">if</span> <span class="n">epoch</span> <span class="o">&gt;=</span> <span class="mi">75</span><span class="p">:</span>
        <span class="n">lr</span> <span class="o">=</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="mf">0.01</span>
    <span class="k">if</span> <span class="n">epoch</span> <span class="o">&gt;=</span> <span class="mi">90</span><span class="p">:</span>
        <span class="n">lr</span> <span class="o">=</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="mf">0.001</span>
    <span class="k">for</span> <span class="n">param_group</span> <span class="ow">in</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
        <span class="n">param_group</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">lr</span>

    <span class="k">return</span> <span class="n">optimizer</span></div>

<!--
  Pygments-highlighted source of deeprobust.image.utils.progress_bar. Unlike
  the function above it has no viewcode-block wrapper, so no "[docs]" link is
  rendered. Keep it identical to the Python module rather than editing here.
  Behaviour shown: writes a "[===>....]" bar sized by TOTAL_BAR_LENGTH, then
  the step time and total elapsed time plus the optional msg, pads with spaces
  up to term_width, backspaces to the bar centre to print "current+1/total",
  and finishes with a carriage return (newline on the last step) and a flush.
  NOTE(review): TOTAL_BAR_LENGTH, term_width, last_time and format_time are
  module-level names defined outside this view; last_time is read before it
  is assigned here, so it must be initialised before the first call.
  current == 0 resets begin_time but not last_time, so the first "Step" time
  includes any idle time since the previous call. total == 0 would raise
  ZeroDivisionError when computing cur_len.
--><span class="k">def</span> <span class="nf">progress_bar</span><span class="p">(</span><span class="n">current</span><span class="p">,</span> <span class="n">total</span><span class="p">,</span> <span class="n">msg</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    <span class="k">global</span> <span class="n">last_time</span><span class="p">,</span> <span class="n">begin_time</span>
    <span class="k">if</span> <span class="n">current</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
        <span class="n">begin_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>  <span class="c1"># Reset for new bar.</span>

    <span class="n">cur_len</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">TOTAL_BAR_LENGTH</span><span class="o">*</span><span class="n">current</span><span class="o">/</span><span class="n">total</span><span class="p">)</span>
    <span class="n">rest_len</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">TOTAL_BAR_LENGTH</span> <span class="o">-</span> <span class="n">cur_len</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>

    <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39; [&#39;</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">cur_len</span><span class="p">):</span>
        <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;=&#39;</span><span class="p">)</span>
    <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;&gt;&#39;</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">rest_len</span><span class="p">):</span>
        <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;.&#39;</span><span class="p">)</span>
    <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;]&#39;</span><span class="p">)</span>

    <span class="n">cur_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
    <span class="n">step_time</span> <span class="o">=</span> <span class="n">cur_time</span> <span class="o">-</span> <span class="n">last_time</span>
    <span class="n">last_time</span> <span class="o">=</span> <span class="n">cur_time</span>
    <span class="n">tot_time</span> <span class="o">=</span> <span class="n">cur_time</span> <span class="o">-</span> <span class="n">begin_time</span>

    <span class="n">L</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">L</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s1">&#39;  Step: </span><span class="si">%s</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">format_time</span><span class="p">(</span><span class="n">step_time</span><span class="p">))</span>
    <span class="n">L</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s1">&#39; | Tot: </span><span class="si">%s</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">format_time</span><span class="p">(</span><span class="n">tot_time</span><span class="p">))</span>
    <span class="k">if</span> <span class="n">msg</span><span class="p">:</span>
        <span class="n">L</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s1">&#39; | &#39;</span> <span class="o">+</span> <span class="n">msg</span><span class="p">)</span>

    <span class="n">msg</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">L</span><span class="p">)</span>
    <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">msg</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">term_width</span><span class="o">-</span><span class="nb">int</span><span class="p">(</span><span class="n">TOTAL_BAR_LENGTH</span><span class="p">)</span><span class="o">-</span><span class="nb">len</span><span class="p">(</span><span class="n">msg</span><span class="p">)</span><span class="o">-</span><span class="mi">3</span><span class="p">):</span>
        <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39; &#39;</span><span class="p">)</span>

    <span class="c1"># Go back to the center of the bar.</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">term_width</span><span class="o">-</span><span class="nb">int</span><span class="p">(</span><span class="n">TOTAL_BAR_LENGTH</span><span class="o">/</span><span class="mi">2</span><span class="p">)</span><span class="o">+</span><span class="mi">2</span><span class="p">):</span>
        <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\b</span><span class="s1">&#39;</span><span class="p">)</span>
    <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39; </span><span class="si">%d</span><span class="s1">/</span><span class="si">%d</span><span class="s1"> &#39;</span> <span class="o">%</span> <span class="p">(</span><span class="n">current</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">total</span><span class="p">))</span>

    <span class="k">if</span> <span class="n">current</span> <span class="o">&lt;</span> <span class="n">total</span><span class="o">-</span><span class="mi">1</span><span class="p">:</span>
        <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\r</span><span class="s1">&#39;</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
    <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>

</pre></div>

           </div>
           
          </div>
          <footer>
            <hr>
            <div role="contentinfo">
              <!-- NOTE(review): the copyright holder is empty. It is presumably
                   filled from the Sphinx "copyright" setting, so set it in the
                   project's conf.py and rebuild instead of editing this page. -->
              <p>&copy; Copyright</p>
            </div>
            Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
            <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
            provided by <a href="https://readthedocs.org">Read the Docs</a>.
          </footer>

        </div>
      </div>

    </section>

  </div>
  

  <!-- Once the DOM is ready, initialise the Read the Docs theme's sidebar
       navigation. SphinxRtdTheme comes from _static/js/theme.js and jQuery
       from _static/jquery.js, both loaded in the document head. -->
  <script>
      jQuery(function () {
          // NOTE(review): `true` is passed exactly as in the theme's stock
          // template; presumably it enables sticky navigation, confirm in theme.js.
          SphinxRtdTheme.Navigation.enable(true);
      });
  </script>

  
  
    
   

</body>
</html>