

<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>deeprobust.graph.utils &mdash; DeepRobust 0.1.1 documentation</title>

  <!-- Theme layout styles, then Pygments syntax-highlighting styles. -->
  <link rel="stylesheet" href="../../../_static/css/theme.css">
  <link rel="stylesheet" href="../../../_static/pygments.css">

  <!--[if lt IE 9]>
    <script src="../../../_static/js/html5shiv.min.js"></script>
  <![endif]-->

  <!-- Sphinx runtime scripts, kept in their generated order; do not reorder
       without checking which script relies on globals from an earlier one. -->
  <script id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
  <script src="../../../_static/jquery.js"></script>
  <script src="../../../_static/underscore.js"></script>
  <script src="../../../_static/doctools.js"></script>
  <script src="../../../_static/language_data.js"></script>
  <!-- MathJax 2.x for math rendering; independent of the scripts above, hence async. -->
  <script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
  <script src="../../../_static/js/theme.js"></script>

  <link rel="index" title="Index" href="../../../genindex.html">
  <link rel="search" title="Search" href="../../../search.html">
</head>

<body class="wy-body-for-nav">

   
  <div class="wy-grid-for-nav">
    
    <nav class="wy-nav-side" data-toggle="wy-nav-shift" aria-label="main navigation">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search">
          <!-- No alt attribute here: alt is not valid on <a>, and the visible text already names the link. -->
          <a href="../../../index.html" class="icon icon-home"> DeepRobust</a>

          <div role="search">
            <form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
              <!-- A placeholder is not an accessible name, so the query field carries an explicit aria-label.
                   type="text" is kept because the theme stylesheet may target text inputs in this box. -->
              <input type="text" name="q" placeholder="Search docs" aria-label="Search docs">
              <input type="hidden" name="check_keywords" value="yes">
              <input type="hidden" name="area" value="default">
            </form>
          </div>
        </div>

        <!-- The landmark label sits on the enclosing <nav>; a second role="navigation" on this div would nest a duplicate landmark. -->
        <div class="wy-menu wy-menu-vertical" data-spy="affix">
          <p class="caption"><span class="caption-text">Installation</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/installation.html">Installation</a></li>
          </ul>
          <p class="caption"><span class="caption-text">Graph Package</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../graph/data.html">Graph Dataset</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../graph/attack.html">Introduction to Graph Attack with Examples</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../graph/defense.html">Introduction to Graph Defense with Examples</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../graph/pyg.html">Using PyTorch Geometric in DeepRobust</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../graph/node_embedding.html">Node Embedding Attack and Defense</a></li>
          </ul>
          <p class="caption"><span class="caption-text">Image Package</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../image/example.html">Image Attack and Defense</a></li>
          </ul>
          <!-- The API-reference sections get distinct captions so they are not confused with the same-named tutorial
               sections above. Mirror these captions in the docs' toctree source so a rebuild keeps them. -->
          <p class="caption"><span class="caption-text">Image Package API</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../source/deeprobust.image.attack.html">deeprobust.image.attack package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../source/deeprobust.image.defense.html">deeprobust.image.defense package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../source/deeprobust.image.netmodels.html">deeprobust.image.netmodels package</a></li>
          </ul>
          <p class="caption"><span class="caption-text">Graph Package API</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../source/deeprobust.graph.global_attack.html">deeprobust.graph.global_attack package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../source/deeprobust.graph.targeted_attack.html">deeprobust.graph.targeted_attack package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../source/deeprobust.graph.defense.html">deeprobust.graph.defense package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../source/deeprobust.graph.data.html">deeprobust.graph.data package</a></li>
          </ul>
        </div>
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">

      
      <nav class="wy-nav-top" aria-label="top navigation">
        <!-- Compact top bar: a menu-toggle icon plus a link back to the documentation index.
             NOTE(review): the <i> toggle below is identified only by data-toggle="wy-nav-top" (theme JS presumably
             binds its click handler there). As an <i> it has no accessible name and cannot take keyboard focus.
             A <button type="button" aria-label="Toggle navigation"> would fix both, but check the theme CSS
             (it may style ".wy-nav-top i") before swapping the element. -->
          <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
          <a href="../../../index.html">DeepRobust</a>
        
      </nav>


      <div class="wy-nav-content">
        
        <div class="rst-content">
        
          















<div role="navigation" aria-label="breadcrumbs navigation">
  <ul class="wy-breadcrumbs">
    <!-- The home link renders only a CSS icon, so aria-label gives it the accessible name it otherwise lacks. -->
    <li><a href="../../../index.html" class="icon icon-home" aria-label="Home"></a> &raquo;</li>
    <li><a href="../../index.html">Module code</a> &raquo;</li>
    <li>deeprobust.graph.utils</li>
    <li class="wy-breadcrumbs-aside">
    </li>
  </ul>
  <hr>
</div>
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
            
  <h1>Source code for deeprobust.graph.utils</h1><div class="highlight"><pre>
<span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">scipy.sparse</span> <span class="k">as</span> <span class="nn">sp</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>
<span class="kn">import</span> <span class="nn">torch.sparse</span> <span class="k">as</span> <span class="nn">ts</span>
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
<span class="kn">import</span> <span class="nn">warnings</span>

<!-- viewcode listing: encode_onehot. Builds one-hot rows by indexing np.eye(labels.max() + 1) with labels, so labels are presumably non-negative integer class ids (confirm). Text inside this pre renders verbatim: fix code or docstrings in deeprobust/graph/utils.py and regenerate. --><div class="viewcode-block" id="encode_onehot"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.encode_onehot">[docs]</a><span class="k">def</span> <span class="nf">encode_onehot</span><span class="p">(</span><span class="n">labels</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Convert label to onehot format.</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    labels : numpy.array</span>
<span class="sd">        node labels</span>

<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    numpy.array</span>
<span class="sd">        onehot labels</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">eye</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">labels</span><span class="o">.</span><span class="n">max</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">onehot_mx</span> <span class="o">=</span> <span class="n">eye</span><span class="p">[</span><span class="n">labels</span><span class="p">]</span>
    <span class="k">return</span> <span class="n">onehot_mx</span></div>

<!-- viewcode listing: tensor2onehot. The documented return type is torch.FloatTensor: torch.eye() without a dtype uses the default floating dtype, so eye[labels] is not a LongTensor. NOTE(review): eye is allocated on the default device and indexed with labels before .to(labels.device), so CUDA labels may fail at eye[labels]; confirm and fix in deeprobust/graph/utils.py, then regenerate. --><div class="viewcode-block" id="tensor2onehot"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.tensor2onehot">[docs]</a><span class="k">def</span> <span class="nf">tensor2onehot</span><span class="p">(</span><span class="n">labels</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Convert label tensor to label onehot tensor.</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    labels : torch.LongTensor</span>
<span class="sd">        node labels</span>

<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    torch.FloatTensor</span>
<span class="sd">        onehot labels tensor</span>

<span class="sd">    &quot;&quot;&quot;</span>

    <span class="n">eye</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">labels</span><span class="o">.</span><span class="n">max</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">onehot_mx</span> <span class="o">=</span> <span class="n">eye</span><span class="p">[</span><span class="n">labels</span><span class="p">]</span>
    <span class="k">return</span> <span class="n">onehot_mx</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">labels</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></div>

<!-- viewcode listing: preprocess. Optionally normalizes adj and features, then converts all three inputs to torch tensors moved to device. The dense branch calls .todense(), so adj and features are presumably scipy sparse matrices (confirm). The docstring has no Returns section; the function returns the tuple (adj, features, labels). --><div class="viewcode-block" id="preprocess"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.preprocess">[docs]</a><span class="k">def</span> <span class="nf">preprocess</span><span class="p">(</span><span class="n">adj</span><span class="p">,</span> <span class="n">features</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">preprocess_adj</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">preprocess_feature</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">sparse</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Convert adj, features, labels from array or sparse matrix to</span>
<span class="sd">    torch Tensor, and normalize the input data.</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    adj : scipy.sparse.csr_matrix</span>
<span class="sd">        the adjacency matrix.</span>
<span class="sd">    features : scipy.sparse.csr_matrix</span>
<span class="sd">        node features</span>
<span class="sd">    labels : numpy.array</span>
<span class="sd">        node labels</span>
<span class="sd">    preprocess_adj : bool</span>
<span class="sd">        whether to normalize the adjacency matrix</span>
<span class="sd">    preprocess_feature : bool</span>
<span class="sd">        whether to normalize the feature matrix</span>
<span class="sd">    sparse : bool</span>
<span class="sd">       whether to return sparse tensor</span>
<span class="sd">    device : str</span>
<span class="sd">        &#39;cpu&#39; or &#39;cuda&#39;</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">if</span> <span class="n">preprocess_adj</span><span class="p">:</span>
        <span class="n">adj</span> <span class="o">=</span> <span class="n">normalize_adj</span><span class="p">(</span><span class="n">adj</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">preprocess_feature</span><span class="p">:</span>
        <span class="n">features</span> <span class="o">=</span> <span class="n">normalize_feature</span><span class="p">(</span><span class="n">features</span><span class="p">)</span>

    <span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">sparse</span><span class="p">:</span>
        <span class="n">adj</span> <span class="o">=</span> <span class="n">sparse_mx_to_torch_sparse_tensor</span><span class="p">(</span><span class="n">adj</span><span class="p">)</span>
        <span class="n">features</span> <span class="o">=</span> <span class="n">sparse_mx_to_torch_sparse_tensor</span><span class="p">(</span><span class="n">features</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">features</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">features</span><span class="o">.</span><span class="n">todense</span><span class="p">()))</span>
        <span class="n">adj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">(</span><span class="n">adj</span><span class="o">.</span><span class="n">todense</span><span class="p">())</span>
    <span class="k">return</span> <span class="n">adj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">features</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">labels</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></div>

<!-- viewcode listing: to_tensor. Sparse inputs go through sparse_mx_to_torch_sparse_tensor and dense inputs become FloatTensors. labels become a LongTensor only when given, so the result is a 2-tuple (adj, features) or a 3-tuple (adj, features, labels), all moved to device. --><div class="viewcode-block" id="to_tensor"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.to_tensor">[docs]</a><span class="k">def</span> <span class="nf">to_tensor</span><span class="p">(</span><span class="n">adj</span><span class="p">,</span> <span class="n">features</span><span class="p">,</span> <span class="n">labels</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Convert adj, features, labels from array or sparse matrix to</span>
<span class="sd">    torch Tensor.</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    adj : scipy.sparse.csr_matrix</span>
<span class="sd">        the adjacency matrix.</span>
<span class="sd">    features : scipy.sparse.csr_matrix</span>
<span class="sd">        node features</span>
<span class="sd">    labels : numpy.array</span>
<span class="sd">        node labels</span>
<span class="sd">    device : str</span>
<span class="sd">        &#39;cpu&#39; or &#39;cuda&#39;</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">if</span> <span class="n">sp</span><span class="o">.</span><span class="n">issparse</span><span class="p">(</span><span class="n">adj</span><span class="p">):</span>
        <span class="n">adj</span> <span class="o">=</span> <span class="n">sparse_mx_to_torch_sparse_tensor</span><span class="p">(</span><span class="n">adj</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">adj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">(</span><span class="n">adj</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">sp</span><span class="o">.</span><span class="n">issparse</span><span class="p">(</span><span class="n">features</span><span class="p">):</span>
        <span class="n">features</span> <span class="o">=</span> <span class="n">sparse_mx_to_torch_sparse_tensor</span><span class="p">(</span><span class="n">features</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">features</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">features</span><span class="p">))</span>

    <span class="k">if</span> <span class="n">labels</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">adj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">features</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">adj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">features</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">labels</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></div>

<!-- viewcode listing: normalize_feature. Scales each row by 1 / rowsum; rows whose sum is zero get a scale of 0 (the inf from the reciprocal is replaced by 0). NOTE(review): sp.lil is a private SciPy submodule (deprecated since SciPy 1.8, verify), and isinstance(mx, sp.lil_matrix) is the public spelling. The documented lil_matrix return type also looks doubtful because r_mat_inv.dot(mx) builds a new matrix; confirm and fix in deeprobust/graph/utils.py. --><div class="viewcode-block" id="normalize_feature"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.normalize_feature">[docs]</a><span class="k">def</span> <span class="nf">normalize_feature</span><span class="p">(</span><span class="n">mx</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Row-normalize sparse matrix or dense matrix</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    mx : scipy.sparse.csr_matrix or numpy.array</span>
<span class="sd">        matrix to be normalized</span>

<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    scipy.sparse.lil_matrix</span>
<span class="sd">        normalized matrix</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">mx</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">sp</span><span class="o">.</span><span class="n">lil</span><span class="o">.</span><span class="n">lil_matrix</span><span class="p">:</span>
        <span class="k">try</span><span class="p">:</span>
            <span class="n">mx</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">tolil</span><span class="p">()</span>
        <span class="k">except</span> <span class="ne">AttributeError</span><span class="p">:</span>
            <span class="k">pass</span>
    <span class="n">rowsum</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">1</span><span class="p">))</span>
    <span class="n">r_inv</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">power</span><span class="p">(</span><span class="n">rowsum</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
    <span class="n">r_inv</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">isinf</span><span class="p">(</span><span class="n">r_inv</span><span class="p">)]</span> <span class="o">=</span> <span class="mf">0.</span>
    <span class="n">r_mat_inv</span> <span class="o">=</span> <span class="n">sp</span><span class="o">.</span><span class="n">diags</span><span class="p">(</span><span class="n">r_inv</span><span class="p">)</span>
    <span class="n">mx</span> <span class="o">=</span> <span class="n">r_mat_inv</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">mx</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">mx</span></div>

<!-- viewcode listing: normalize_adj. The docstring now says "Symmetrically" because the code scales by rowsum^(-1/2) on both sides (r_mat_inv . mx . r_mat_inv), which is not row-normalization. NOTE(review): the identity is added only when mx[0, 0] == 0, so self-loops are assumed to be all present or all absent; confirm. sp.lil is a private SciPy submodule (deprecated since SciPy 1.8, verify); fix in deeprobust/graph/utils.py and regenerate. --><div class="viewcode-block" id="normalize_adj"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.normalize_adj">[docs]</a><span class="k">def</span> <span class="nf">normalize_adj</span><span class="p">(</span><span class="n">mx</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Normalize sparse adjacency matrix,</span>
<span class="sd">    A&#39; = (D + I)^-1/2 * ( A + I ) * (D + I)^-1/2</span>
<span class="sd">    Symmetrically normalize sparse matrix</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    mx : scipy.sparse.csr_matrix</span>
<span class="sd">        matrix to be normalized</span>

<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    scipy.sparse.lil_matrix</span>
<span class="sd">        normalized matrix</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="c1"># TODO: maybe using coo format would be better?</span>
    <span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">mx</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">sp</span><span class="o">.</span><span class="n">lil</span><span class="o">.</span><span class="n">lil_matrix</span><span class="p">:</span>
        <span class="n">mx</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">tolil</span><span class="p">()</span>
    <span class="k">if</span> <span class="n">mx</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span> <span class="p">:</span>
        <span class="n">mx</span> <span class="o">=</span> <span class="n">mx</span> <span class="o">+</span> <span class="n">sp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    <span class="n">rowsum</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">1</span><span class="p">))</span>
    <span class="n">r_inv</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">power</span><span class="p">(</span><span class="n">rowsum</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="o">/</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
    <span class="n">r_inv</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">isinf</span><span class="p">(</span><span class="n">r_inv</span><span class="p">)]</span> <span class="o">=</span> <span class="mf">0.</span>
    <span class="n">r_mat_inv</span> <span class="o">=</span> <span class="n">sp</span><span class="o">.</span><span class="n">diags</span><span class="p">(</span><span class="n">r_inv</span><span class="p">)</span>
    <span class="n">mx</span> <span class="o">=</span> <span class="n">r_mat_inv</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">mx</span><span class="p">)</span>
    <span class="n">mx</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">r_mat_inv</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">mx</span></div>

<!-- viewcode listing: normalize_sparse_tensor. Adds self-loops with weight fill_value, then scales each edge weight by deg^(-1/2) of both endpoints (zero-degree entries set to 0). torch_scatter is imported inside the function, so it is only needed when this is called. NOTE(review): torch.sparse.FloatTensor is a legacy constructor; torch.sparse_coo_tensor is the current API. --><div class="viewcode-block" id="normalize_sparse_tensor"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.normalize_sparse_tensor">[docs]</a><span class="k">def</span> <span class="nf">normalize_sparse_tensor</span><span class="p">(</span><span class="n">adj</span><span class="p">,</span> <span class="n">fill_value</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Normalize sparse tensor. Need to import torch_scatter</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">edge_index</span> <span class="o">=</span> <span class="n">adj</span><span class="o">.</span><span class="n">_indices</span><span class="p">()</span>
    <span class="n">edge_weight</span> <span class="o">=</span> <span class="n">adj</span><span class="o">.</span><span class="n">_values</span><span class="p">()</span>
    <span class="n">num_nodes</span><span class="o">=</span> <span class="n">adj</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">edge_index</span><span class="p">,</span> <span class="n">edge_weight</span> <span class="o">=</span> <span class="n">add_self_loops</span><span class="p">(</span>
	<span class="n">edge_index</span><span class="p">,</span> <span class="n">edge_weight</span><span class="p">,</span> <span class="n">fill_value</span><span class="p">,</span> <span class="n">num_nodes</span><span class="p">)</span>

    <span class="n">row</span><span class="p">,</span> <span class="n">col</span> <span class="o">=</span> <span class="n">edge_index</span>
    <span class="kn">from</span> <span class="nn">torch_scatter</span> <span class="kn">import</span> <span class="n">scatter_add</span>
    <span class="n">deg</span> <span class="o">=</span> <span class="n">scatter_add</span><span class="p">(</span><span class="n">edge_weight</span><span class="p">,</span> <span class="n">row</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim_size</span><span class="o">=</span><span class="n">num_nodes</span><span class="p">)</span>
    <span class="n">deg_inv_sqrt</span> <span class="o">=</span> <span class="n">deg</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="o">-</span><span class="mf">0.5</span><span class="p">)</span>
    <span class="n">deg_inv_sqrt</span><span class="p">[</span><span class="n">deg_inv_sqrt</span> <span class="o">==</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;inf&#39;</span><span class="p">)]</span> <span class="o">=</span> <span class="mi">0</span>

    <span class="n">values</span> <span class="o">=</span> <span class="n">deg_inv_sqrt</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">*</span> <span class="n">edge_weight</span> <span class="o">*</span> <span class="n">deg_inv_sqrt</span><span class="p">[</span><span class="n">col</span><span class="p">]</span>

    <span class="n">shape</span> <span class="o">=</span> <span class="n">adj</span><span class="o">.</span><span class="n">shape</span>
    <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sparse</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">values</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span></div>

<!-- viewcode listing: add_self_loops (undocumented helper, no [docs] anchor). Appends one (i, i) column per node to edge_index and, when edge_weight is given, matching fill_value weights; with edge_weight=None the weight is returned as None. NOTE(review): num_nodes has no fallback (the maybe_num_nodes call is commented out), so callers presumably must pass it. --><span class="k">def</span> <span class="nf">add_self_loops</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">edge_weight</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">fill_value</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_nodes</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    <span class="c1"># num_nodes = maybe_num_nodes(edge_index, num_nodes)</span>

    <span class="n">loop_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">num_nodes</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">,</span>
                              <span class="n">device</span><span class="o">=</span><span class="n">edge_index</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
    <span class="n">loop_index</span> <span class="o">=</span> <span class="n">loop_index</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">edge_weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
        <span class="k">assert</span> <span class="n">edge_weight</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">==</span> <span class="n">edge_index</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">loop_weight</span> <span class="o">=</span> <span class="n">edge_weight</span><span class="o">.</span><span class="n">new_full</span><span class="p">((</span><span class="n">num_nodes</span><span class="p">,</span> <span class="p">),</span> <span class="n">fill_value</span><span class="p">)</span>
        <span class="n">edge_weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">edge_weight</span><span class="p">,</span> <span class="n">loop_weight</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>

    <span class="n">edge_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">loop_index</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">edge_weight</span>

<!-- normalize_adj_tensor(adj, sparse=False)
     Rendered copy of Python source. This page appears to be generated by the Sphinx
     viewcode extension, so fix the code in deeprobust/graph/utils.py, not here.
     Purpose: normalize an adjacency matrix held as a torch tensor.
       sparse=True : converts adj to scipy via to_scipy, delegates to normalize_adj
                     (defined elsewhere in this module), converts back with
                     sparse_mx_to_torch_sparse_tensor and moves it to the chosen device.
       sparse=False: adds the identity, then returns D^(-1/2) (A + I) D^(-1/2), where D
                     holds the row sums of A + I; infinite scales (zero row sums) are
                     replaced by 0.
     NOTE(review): the device is rebuilt as torch.device("cuda") from adj.is_cuda, which
     drops the GPU index. With adj on a non-current GPU, torch.eye(...).to(device) lands on
     a different device than adj. Using adj.device would avoid that.
     NOTE(review): the commented-out warning cites "line 207"; that number may be stale.
--><div class="viewcode-block" id="normalize_adj_tensor"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.normalize_adj_tensor">[docs]</a><span class="k">def</span> <span class="nf">normalize_adj_tensor</span><span class="p">(</span><span class="n">adj</span><span class="p">,</span> <span class="n">sparse</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Normalize adjacency tensor matrix.</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cuda&quot;</span> <span class="k">if</span> <span class="n">adj</span><span class="o">.</span><span class="n">is_cuda</span> <span class="k">else</span> <span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">sparse</span><span class="p">:</span>
        <span class="c1"># warnings.warn(&#39;If you find the training process is too slow, you can uncomment line 207 in deeprobust/graph/utils.py. Note that you need to install torch_sparse&#39;)</span>
        <span class="c1"># TODO if this is too slow, uncomment the following code,</span>
        <span class="c1"># but you need to install torch_scatter</span>
        <span class="c1"># return normalize_sparse_tensor(adj)</span>
        <span class="n">adj</span> <span class="o">=</span> <span class="n">to_scipy</span><span class="p">(</span><span class="n">adj</span><span class="p">)</span>
        <span class="n">mx</span> <span class="o">=</span> <span class="n">normalize_adj</span><span class="p">(</span><span class="n">adj</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">sparse_mx_to_torch_sparse_tensor</span><span class="p">(</span><span class="n">mx</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">mx</span> <span class="o">=</span> <span class="n">adj</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">adj</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
        <span class="n">rowsum</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">r_inv</span> <span class="o">=</span> <span class="n">rowsum</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="o">/</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
        <span class="n">r_inv</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">isinf</span><span class="p">(</span><span class="n">r_inv</span><span class="p">)]</span> <span class="o">=</span> <span class="mf">0.</span>
        <span class="n">r_mat_inv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">diag</span><span class="p">(</span><span class="n">r_inv</span><span class="p">)</span>
        <span class="n">mx</span> <span class="o">=</span> <span class="n">r_mat_inv</span> <span class="o">@</span> <span class="n">mx</span>
        <span class="n">mx</span> <span class="o">=</span> <span class="n">mx</span> <span class="o">@</span> <span class="n">r_mat_inv</span>
    <span class="k">return</span> <span class="n">mx</span></div>

<!-- degree_normalize_adj(mx)
     Purpose: row-normalize a scipy sparse adjacency matrix and return D^(-1) M. Here M is
     the (possibly self-looped) input and D holds its row sums. Rows summing to zero stay
     zero instead of producing inf.
     Self-loops are added (M = A + I) only when mx[0, 0] == 0.
     NOTE(review): that check inspects only the first diagonal entry. If node 0 has a
     self-loop, no loops are added anywhere. If node 0 lacks one, every diagonal entry gets
     +1, including nodes that already had a loop. The dense path of
     degree_normalize_adj_tensor always adds I.
     The commented-out mx.dot(r_mat_inv) line would be the column-normalized variant.
--><div class="viewcode-block" id="degree_normalize_adj"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.degree_normalize_adj">[docs]</a><span class="k">def</span> <span class="nf">degree_normalize_adj</span><span class="p">(</span><span class="n">mx</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Row-normalize sparse matrix&quot;&quot;&quot;</span>
    <span class="n">mx</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">tolil</span><span class="p">()</span>
    <span class="k">if</span> <span class="n">mx</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span> <span class="p">:</span>
        <span class="n">mx</span> <span class="o">=</span> <span class="n">mx</span> <span class="o">+</span> <span class="n">sp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    <span class="n">rowsum</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">1</span><span class="p">))</span>
    <span class="n">r_inv</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">power</span><span class="p">(</span><span class="n">rowsum</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
    <span class="n">r_inv</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">isinf</span><span class="p">(</span><span class="n">r_inv</span><span class="p">)]</span> <span class="o">=</span> <span class="mf">0.</span>
    <span class="n">r_mat_inv</span> <span class="o">=</span> <span class="n">sp</span><span class="o">.</span><span class="n">diags</span><span class="p">(</span><span class="n">r_inv</span><span class="p">)</span>
    <span class="c1"># mx = mx.dot(r_mat_inv)</span>
    <span class="n">mx</span> <span class="o">=</span> <span class="n">r_mat_inv</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">mx</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">mx</span></div>

<!-- degree_normalize_sparse_tensor(adj, fill_value=1)
     Purpose: row-normalize a torch sparse adjacency tensor without converting to scipy.
     Steps:
       1. Append a self-loop of weight fill_value for every node (add_self_loops).
       2. Sum each row's edge weights with torch_scatter.scatter_add.
       3. Scale every edge weight by 1 / degree of its source row.
     Requires the optional torch_scatter package, which is imported inside the function.
     NOTE(review): deg_inv_sqrt actually holds deg^(-1), not deg^(-1/2); the name is
     misleading.
     NOTE(review): add_self_loops concatenates loops unconditionally. Existing self-loops
     therefore become duplicate (row, row) entries in the uncoalesced result.
     NOTE(review): torch.sparse.FloatTensor is a legacy constructor;
     torch.sparse_coo_tensor is the current API.
     Reads adj._indices() and adj._values() directly, so adj is presumably a COO tensor.
     Confirm whether callers coalesce it first.
--><div class="viewcode-block" id="degree_normalize_sparse_tensor"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.degree_normalize_sparse_tensor">[docs]</a><span class="k">def</span> <span class="nf">degree_normalize_sparse_tensor</span><span class="p">(</span><span class="n">adj</span><span class="p">,</span> <span class="n">fill_value</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;degree_normalize_sparse_tensor.</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">edge_index</span> <span class="o">=</span> <span class="n">adj</span><span class="o">.</span><span class="n">_indices</span><span class="p">()</span>
    <span class="n">edge_weight</span> <span class="o">=</span> <span class="n">adj</span><span class="o">.</span><span class="n">_values</span><span class="p">()</span>
    <span class="n">num_nodes</span><span class="o">=</span> <span class="n">adj</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

    <span class="n">edge_index</span><span class="p">,</span> <span class="n">edge_weight</span> <span class="o">=</span> <span class="n">add_self_loops</span><span class="p">(</span>
	<span class="n">edge_index</span><span class="p">,</span> <span class="n">edge_weight</span><span class="p">,</span> <span class="n">fill_value</span><span class="p">,</span> <span class="n">num_nodes</span><span class="p">)</span>

    <span class="n">row</span><span class="p">,</span> <span class="n">col</span> <span class="o">=</span> <span class="n">edge_index</span>
    <span class="kn">from</span> <span class="nn">torch_scatter</span> <span class="kn">import</span> <span class="n">scatter_add</span>
    <span class="n">deg</span> <span class="o">=</span> <span class="n">scatter_add</span><span class="p">(</span><span class="n">edge_weight</span><span class="p">,</span> <span class="n">row</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim_size</span><span class="o">=</span><span class="n">num_nodes</span><span class="p">)</span>
    <span class="n">deg_inv_sqrt</span> <span class="o">=</span> <span class="n">deg</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">deg_inv_sqrt</span><span class="p">[</span><span class="n">deg_inv_sqrt</span> <span class="o">==</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;inf&#39;</span><span class="p">)]</span> <span class="o">=</span> <span class="mi">0</span>

    <span class="n">values</span> <span class="o">=</span> <span class="n">deg_inv_sqrt</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">*</span> <span class="n">edge_weight</span>
    <span class="n">shape</span> <span class="o">=</span> <span class="n">adj</span><span class="o">.</span><span class="n">shape</span>
    <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sparse</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">values</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span></div>

<!-- degree_normalize_adj_tensor(adj, sparse=True)
     Purpose: row-normalize an adjacency tensor, D^(-1) M. On the dense path M = A + I.
       sparse=True : round-trips through scipy (to_scipy, degree_normalize_adj,
                     sparse_mx_to_torch_sparse_tensor) and moves the result to the
                     chosen device. The torch_scatter path
                     (degree_normalize_sparse_tensor) is left commented out.
       sparse=False: adds the identity, left-multiplies by diag(1 / rowsum) and zeroes
                     the infinite scales of rows whose sum is 0.
     NOTE(review): sparse defaults to True here but to False in normalize_adj_tensor.
     NOTE(review): the two paths disagree on self-loops. The dense path always adds I,
     while degree_normalize_adj adds it only when the [0, 0] entry is 0.
     NOTE(review): torch.device("cuda") drops the GPU index of adj; adj.device would
     keep it.
--><div class="viewcode-block" id="degree_normalize_adj_tensor"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.degree_normalize_adj_tensor">[docs]</a><span class="k">def</span> <span class="nf">degree_normalize_adj_tensor</span><span class="p">(</span><span class="n">adj</span><span class="p">,</span> <span class="n">sparse</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;degree_normalize_adj_tensor.</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cuda&quot;</span> <span class="k">if</span> <span class="n">adj</span><span class="o">.</span><span class="n">is_cuda</span> <span class="k">else</span> <span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">sparse</span><span class="p">:</span>
        <span class="c1"># return  degree_normalize_sparse_tensor(adj)</span>
        <span class="n">adj</span> <span class="o">=</span> <span class="n">to_scipy</span><span class="p">(</span><span class="n">adj</span><span class="p">)</span>
        <span class="n">mx</span> <span class="o">=</span> <span class="n">degree_normalize_adj</span><span class="p">(</span><span class="n">adj</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">sparse_mx_to_torch_sparse_tensor</span><span class="p">(</span><span class="n">mx</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">mx</span> <span class="o">=</span> <span class="n">adj</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">adj</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
        <span class="n">rowsum</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">r_inv</span> <span class="o">=</span> <span class="n">rowsum</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
        <span class="n">r_inv</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">isinf</span><span class="p">(</span><span class="n">r_inv</span><span class="p">)]</span> <span class="o">=</span> <span class="mf">0.</span>
        <span class="n">r_mat_inv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">diag</span><span class="p">(</span><span class="n">r_inv</span><span class="p">)</span>
        <span class="n">mx</span> <span class="o">=</span> <span class="n">r_mat_inv</span> <span class="o">@</span> <span class="n">mx</span>
    <span class="k">return</span> <span class="n">mx</span></div>

<!-- accuracy(output, labels)
     Purpose: fraction of rows whose argmax over dim 1 of output equals labels.
     A label without __len__ (a bare scalar) is wrapped in a list. Any non-Tensor labels
     are converted with torch.LongTensor, which creates a CPU tensor.
     NOTE(review): the docstring promises a float, but correct / len(labels) is a 0-dim
     double tensor. Callers needing a Python number must call .item().
     NOTE(review): output is presumably (num_nodes, num_classes) scores or
     log-probabilities; confirm against callers.
--><div class="viewcode-block" id="accuracy"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.accuracy">[docs]</a><span class="k">def</span> <span class="nf">accuracy</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Return accuracy of output compared to labels.</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    output : torch.Tensor</span>
<span class="sd">        output from model</span>
<span class="sd">    labels : torch.Tensor or numpy.array</span>
<span class="sd">        node labels</span>

<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    float</span>
<span class="sd">        accuracy</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">if</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="s1">&#39;__len__&#39;</span><span class="p">):</span>
        <span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">labels</span><span class="p">]</span>
    <span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
        <span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
    <span class="n">preds</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="mi">1</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">type_as</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
    <span class="n">correct</span> <span class="o">=</span> <span class="n">preds</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">double</span><span class="p">()</span>
    <span class="n">correct</span> <span class="o">=</span> <span class="n">correct</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">correct</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span></div>

<!-- loss_acc(output, labels, targets, avg_loss=True)
     Rendered without a viewcode block or [docs] link. This is presumably because the
     function has no docstring and autodoc skipped it.
     Purpose: NLL loss and accuracy restricted to the nodes selected by targets.
       avg_loss=True : returns (mean nll_loss, correct.sum() / len(targets)).
       avg_loss=False: returns (per-node nll_loss, per-node 0/1 correctness).
     F is assumed to be torch.nn.functional, imported elsewhere in the module. output is
     presumably log-probabilities, as F.nll_loss expects.
     NOTE(review): correct is already indexed by targets, so dividing by len(targets) is
     only right when targets is an index list. For a boolean mask it divides by the total
     node count. correct.mean() would handle both.
--><span class="k">def</span> <span class="nf">loss_acc</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">avg_loss</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
    <span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
        <span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
    <span class="n">preds</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="mi">1</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">type_as</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
    <span class="n">correct</span> <span class="o">=</span> <span class="n">preds</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">double</span><span class="p">()[</span><span class="n">targets</span><span class="p">]</span>
    <span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">output</span><span class="p">[</span><span class="n">targets</span><span class="p">],</span> <span class="n">labels</span><span class="p">[</span><span class="n">targets</span><span class="p">],</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;mean&#39;</span> <span class="k">if</span> <span class="n">avg_loss</span> <span class="k">else</span> <span class="s1">&#39;none&#39;</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">avg_loss</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">correct</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">correct</span>
    <span class="c1"># correct = correct.sum()</span>
    <span class="c1"># return loss, correct / len(labels)</span>

<!-- classification_margin(output, true_label)
     Purpose: margin between the probability of true_label and the best other class.
     It is computed as exp(output)[true_label] minus the largest remaining entry.
     Applying exp implies output holds log-probabilities (e.g. from log_softmax). This is
     an assumption; confirm against callers.
     probs is a fresh tensor from torch.exp, so zeroing probs[true_label] does not modify
     output.
     NOTE(review): the docstring says the return is a list, but .item() returns a Python
     float.
     NOTE(review): output is documented as 1 dimension. argmax() without dim on a
     multi-dimensional tensor would return a flattened index.
--><div class="viewcode-block" id="classification_margin"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.classification_margin">[docs]</a><span class="k">def</span> <span class="nf">classification_margin</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">true_label</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Calculate classification margin for outputs.</span>
<span class="sd">    `probs_true_label - probs_best_second_class`</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    output: torch.Tensor</span>
<span class="sd">        output vector (1 dimension)</span>
<span class="sd">    true_label: int</span>
<span class="sd">        true label for this node</span>

<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    list</span>
<span class="sd">        classification margin for this node</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="n">probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>
    <span class="n">probs_true_label</span> <span class="o">=</span> <span class="n">probs</span><span class="p">[</span><span class="n">true_label</span><span class="p">]</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
    <span class="n">probs</span><span class="p">[</span><span class="n">true_label</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">probs_best_second_class</span> <span class="o">=</span> <span class="n">probs</span><span class="p">[</span><span class="n">probs</span><span class="o">.</span><span class="n">argmax</span><span class="p">()]</span>
    <span class="k">return</span> <span class="p">(</span><span class="n">probs_true_label</span> <span class="o">-</span> <span class="n">probs_best_second_class</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span></div>

<!-- sparse_mx_to_torch_sparse_tensor(sparse_mx)
     Purpose: convert any scipy sparse matrix to a float32 torch sparse COO tensor on the
     CPU. Row and col indices are stacked into a 2 x nnz LongTensor, and the data becomes
     a FloatTensor.
     NOTE(review): torch.sparse.FloatTensor is a legacy constructor;
     torch.sparse_coo_tensor(indices, values, size) is the current API.
     The commented-out "slower version" that follows the block is an equivalent
     numpy-based construction kept for reference.
--><div class="viewcode-block" id="sparse_mx_to_torch_sparse_tensor"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.sparse_mx_to_torch_sparse_tensor">[docs]</a><span class="k">def</span> <span class="nf">sparse_mx_to_torch_sparse_tensor</span><span class="p">(</span><span class="n">sparse_mx</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Convert a scipy sparse matrix to a torch sparse tensor.&quot;&quot;&quot;</span>
    <span class="n">sparse_mx</span> <span class="o">=</span> <span class="n">sparse_mx</span><span class="o">.</span><span class="n">tocoo</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
    <span class="n">sparserow</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">sparse_mx</span><span class="o">.</span><span class="n">row</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">sparsecol</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">sparse_mx</span><span class="o">.</span><span class="n">col</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">sparseconcat</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">sparserow</span><span class="p">,</span> <span class="n">sparsecol</span><span class="p">),</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">sparsedata</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">(</span><span class="n">sparse_mx</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sparse</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">(</span><span class="n">sparseconcat</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span><span class="n">sparsedata</span><span class="p">,</span><span class="n">torch</span><span class="o">.</span><span class="n">Size</span><span class="p">(</span><span class="n">sparse_mx</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span></div>

	<span class="c1"># slower version....</span>
    <span class="c1"># sparse_mx = sparse_mx.tocoo().astype(np.float32)</span>
    <span class="c1"># indices = torch.from_numpy(</span>
    <span class="c1">#     np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))</span>
    <span class="c1"># values = torch.from_numpy(sparse_mx.data)</span>
    <span class="c1"># shape = torch.Size(sparse_mx.shape)</span>
    <span class="c1"># return torch.sparse.FloatTensor(indices, values, shape)</span>



<!-- to_scipy(tensor)
     Purpose: copy a torch tensor (sparse COO or dense) to a scipy csr_matrix on the host,
     keeping tensor.shape.
       sparse: uses the raw _indices() and _values(); csr_matrix sums any duplicate
               coordinates of an uncoalesced tensor.
       dense : keeps only the nonzero entries, found via tensor.nonzero().
     NOTE(review): values are not detached. A tensor with requires_grad=True would raise
     at .numpy(), so some callers may need detach().
--><div class="viewcode-block" id="to_scipy"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.to_scipy">[docs]</a><span class="k">def</span> <span class="nf">to_scipy</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Convert a dense/sparse tensor to scipy matrix&quot;&quot;&quot;</span>
    <span class="k">if</span> <span class="n">is_sparse_tensor</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
        <span class="n">values</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">_values</span><span class="p">()</span>
        <span class="n">indices</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">_indices</span><span class="p">()</span>
        <span class="k">return</span> <span class="n">sp</span><span class="o">.</span><span class="n">csr_matrix</span><span class="p">((</span><span class="n">values</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">indices</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()),</span> <span class="n">shape</span><span class="o">=</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">indices</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">nonzero</span><span class="p">()</span><span class="o">.</span><span class="n">t</span><span class="p">()</span>
        <span class="n">values</span> <span class="o">=</span> <span class="n">tensor</span><span class="p">[</span><span class="n">indices</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">indices</span><span class="p">[</span><span class="mi">1</span><span class="p">]]</span>
        <span class="k">return</span> <span class="n">sp</span><span class="o">.</span><span class="n">csr_matrix</span><span class="p">((</span><span class="n">values</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">indices</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()),</span> <span class="n">shape</span><span class="o">=</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></div>

<!-- is_sparse_tensor(tensor)
     Purpose: True exactly when tensor.layout is torch.sparse_coo.
     NOTE(review): other sparse layouts (e.g. torch.sparse_csr) report False. The if/else
     could also collapse to a single return of the comparison.
--><div class="viewcode-block" id="is_sparse_tensor"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.is_sparse_tensor">[docs]</a><span class="k">def</span> <span class="nf">is_sparse_tensor</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Check if a tensor is sparse tensor.</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    tensor : torch.Tensor</span>
<span class="sd">        given tensor</span>

<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    bool</span>
<span class="sd">        whether a tensor is sparse tensor</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="c1"># if hasattr(tensor, &#39;nnz&#39;):</span>
    <span class="k">if</span> <span class="n">tensor</span><span class="o">.</span><span class="n">layout</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">sparse_coo</span><span class="p">:</span>
        <span class="k">return</span> <span class="kc">True</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="k">return</span> <span class="kc">False</span></div>

<div class="viewcode-block" id="get_train_val_test"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.get_train_val_test">[docs]</a><span class="k">def</span> <span class="nf">get_train_val_test</span><span class="p">(</span><span class="n">nnodes</span><span class="p">,</span> <span class="n">val_size</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> <span class="n">stratify</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;This setting follows nettack/mettack, where we split the nodes</span>
<span class="sd">    into 10% training, 10% validation and 80% testing data</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    nnodes : int</span>
<span class="sd">        number of nodes in total</span>
<span class="sd">    val_size : float</span>
<span class="sd">        size of validation set</span>
<span class="sd">    test_size : float</span>
<span class="sd">        size of test set</span>
<span class="sd">    stratify :</span>
<span class="sd">        data is expected to split in a stratified fashion. So stratify should be labels.</span>
<span class="sd">    seed : int or None</span>
<span class="sd">        random seed</span>

<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    idx_train :</span>
<span class="sd">        node training indices</span>
<span class="sd">    idx_val :</span>
<span class="sd">        node validation indices</span>
<span class="sd">    idx_test :</span>
<span class="sd">        node test indices</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">assert</span> <span class="n">stratify</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;stratify cannot be None!&#39;</span>

    <span class="k">if</span> <span class="n">seed</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>

    <span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">nnodes</span><span class="p">)</span>
    <span class="n">train_size</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">val_size</span> <span class="o">-</span> <span class="n">test_size</span>
    <span class="n">idx_train_and_val</span><span class="p">,</span> <span class="n">idx_test</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span><span class="n">idx</span><span class="p">,</span>
                                                   <span class="n">random_state</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                                                   <span class="n">train_size</span><span class="o">=</span><span class="n">train_size</span> <span class="o">+</span> <span class="n">val_size</span><span class="p">,</span>
                                                   <span class="n">test_size</span><span class="o">=</span><span class="n">test_size</span><span class="p">,</span>
                                                   <span class="n">stratify</span><span class="o">=</span><span class="n">stratify</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">stratify</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">stratify</span> <span class="o">=</span> <span class="n">stratify</span><span class="p">[</span><span class="n">idx_train_and_val</span><span class="p">]</span>

    <span class="n">idx_train</span><span class="p">,</span> <span class="n">idx_val</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span><span class="n">idx_train_and_val</span><span class="p">,</span>
                                          <span class="n">random_state</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                                          <span class="n">train_size</span><span class="o">=</span><span class="p">(</span><span class="n">train_size</span> <span class="o">/</span> <span class="p">(</span><span class="n">train_size</span> <span class="o">+</span> <span class="n">val_size</span><span class="p">)),</span>
                                          <span class="n">test_size</span><span class="o">=</span><span class="p">(</span><span class="n">val_size</span> <span class="o">/</span> <span class="p">(</span><span class="n">train_size</span> <span class="o">+</span> <span class="n">val_size</span><span class="p">)),</span>
                                          <span class="n">stratify</span><span class="o">=</span><span class="n">stratify</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">idx_train</span><span class="p">,</span> <span class="n">idx_val</span><span class="p">,</span> <span class="n">idx_test</span></div>

<!--
  Listing of deeprobust.graph.utils.get_train_test(nnodes, test_size=0.8, stratify=None, seed=None).
  Splits np.arange(nnodes) into (idx_train, idx_test) with one sklearn
  train_test_split call, using train_size = 1 - test_size (test_size is a fraction).
  Although stratify defaults to None, the leading assert makes it required.
  NOTE(review): assert-based argument validation disappears under python -O.
  When seed is given only the global NumPy RNG is reseeded; random_state=None is
  passed to train_test_split, so reproducibility presumably relies on sklearn
  falling back to the global NumPy RNG (confirm against the sklearn docs).
--><div class="viewcode-block" id="get_train_test"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.get_train_test">[docs]</a><span class="k">def</span> <span class="nf">get_train_test</span><span class="p">(</span><span class="n">nnodes</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> <span class="n">stratify</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;This function returns training and test set without validation.</span>
<span class="sd">    It can be used for settings of different label rates.</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    nnodes : int</span>
<span class="sd">        number of nodes in total</span>
<span class="sd">    test_size : float</span>
<span class="sd">        size of test set</span>
<span class="sd">    stratify :</span>
<span class="sd">        data is expected to split in a stratified fashion. So stratify should be labels.</span>
<span class="sd">    seed : int or None</span>
<span class="sd">        random seed</span>

<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    idx_train :</span>
<span class="sd">        node training indices</span>
<span class="sd">    idx_test :</span>
<span class="sd">        node test indices</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">assert</span> <span class="n">stratify</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;stratify cannot be None!&#39;</span>

    <span class="k">if</span> <span class="n">seed</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>

    <span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">nnodes</span><span class="p">)</span>
    <span class="n">train_size</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">test_size</span>
    <span class="n">idx_train</span><span class="p">,</span> <span class="n">idx_test</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span><span class="n">idx</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                                                <span class="n">train_size</span><span class="o">=</span><span class="n">train_size</span><span class="p">,</span>
                                                <span class="n">test_size</span><span class="o">=</span><span class="n">test_size</span><span class="p">,</span>
                                                <span class="n">stratify</span><span class="o">=</span><span class="n">stratify</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">idx_train</span><span class="p">,</span> <span class="n">idx_test</span></div>

<div class="viewcode-block" id="get_train_val_test_gcn"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.get_train_val_test_gcn">[docs]</a><span class="k">def</span> <span class="nf">get_train_val_test_gcn</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;This setting follows GCN: we randomly sample 20 instances from each class</span>
<span class="sd">    as training data, 500 of the remaining instances as validation data and the</span>
<span class="sd">    next 1000 as test data. Note that these are not fixed splits: when the random</span>
<span class="sd">    seed changes, the splits change too.</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    labels : numpy.array</span>
<span class="sd">        integer node labels; classes are assumed to be ``0 .. labels.max()``</span>
<span class="sd">    seed : int or None</span>
<span class="sd">        if not None, passed to ``np.random.seed`` (reseeds the global NumPy RNG)</span>

<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    idx_train :</span>
<span class="sd">        node training indices (20 per class, or all nodes of a smaller class)</span>
<span class="sd">    idx_val :</span>
<span class="sd">        node validation indices (up to 500, drawn from the non-training nodes)</span>
<span class="sd">    idx_test :</span>
<span class="sd">        node test indices (up to 1000, disjoint from training and validation)</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">if</span> <span class="n">seed</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>

    <span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">labels</span><span class="p">))</span>
    <span class="n">nclass</span> <span class="o">=</span> <span class="n">labels</span><span class="o">.</span><span class="n">max</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
    <span class="n">idx_train</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">idx_unlabeled</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nclass</span><span class="p">):</span>
        <span class="n">labels_i</span> <span class="o">=</span> <span class="n">idx</span><span class="p">[</span><span class="n">labels</span><span class="o">==</span><span class="n">i</span><span class="p">]</span>
        <span class="n">labels_i</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">labels_i</span><span class="p">)</span>
        <span class="c1"># hstack onto the initial empty list upcasts to float, hence astype(int);</span>
        <span class="c1"># builtin int replaces the np.int alias, which NumPy 1.24 removed.</span>
        <span class="n">idx_train</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">((</span><span class="n">idx_train</span><span class="p">,</span> <span class="n">labels_i</span><span class="p">[:</span> <span class="mi">20</span><span class="p">]))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>
        <span class="n">idx_unlabeled</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">((</span><span class="n">idx_unlabeled</span><span class="p">,</span> <span class="n">labels_i</span><span class="p">[</span><span class="mi">20</span><span class="p">:</span> <span class="p">]))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>

    <span class="n">idx_unlabeled</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">idx_unlabeled</span><span class="p">)</span>
    <span class="n">idx_val</span> <span class="o">=</span> <span class="n">idx_unlabeled</span><span class="p">[:</span> <span class="mi">500</span><span class="p">]</span>
    <span class="n">idx_test</span> <span class="o">=</span> <span class="n">idx_unlabeled</span><span class="p">[</span><span class="mi">500</span><span class="p">:</span> <span class="mi">1500</span><span class="p">]</span>
    <span class="k">return</span> <span class="n">idx_train</span><span class="p">,</span> <span class="n">idx_val</span><span class="p">,</span> <span class="n">idx_test</span></div>

<!--
  Listing of deeprobust.graph.utils.get_train_test_labelrate(labels, label_rate).
  The per-class training size int(round(len(labels) * label_rate / nclass)) is
  printed to stdout, then the split is delegated to get_splits_each_class.
  The validation split returned by that helper is discarded, so idx_test excludes
  those nodes and is NOT the complement of idx_train. There is no seed argument;
  randomness comes from the global NumPy RNG used inside get_splits_each_class.
--><div class="viewcode-block" id="get_train_test_labelrate"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.get_train_test_labelrate">[docs]</a><span class="k">def</span> <span class="nf">get_train_test_labelrate</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">label_rate</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Get train test according to given label rate.</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">nclass</span> <span class="o">=</span> <span class="n">labels</span><span class="o">.</span><span class="n">max</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
    <span class="n">train_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span> <span class="o">*</span> <span class="n">label_rate</span> <span class="o">/</span> <span class="n">nclass</span><span class="p">))</span>
    <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;=== train_size = </span><span class="si">%s</span><span class="s2"> ===&quot;</span> <span class="o">%</span> <span class="n">train_size</span><span class="p">)</span>
    <span class="n">idx_train</span><span class="p">,</span> <span class="n">idx_val</span><span class="p">,</span> <span class="n">idx_test</span> <span class="o">=</span> <span class="n">get_splits_each_class</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">train_size</span><span class="o">=</span><span class="n">train_size</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">idx_train</span><span class="p">,</span> <span class="n">idx_test</span></div>

<div class="viewcode-block" id="get_splits_each_class"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.get_splits_each_class">[docs]</a><span class="k">def</span> <span class="nf">get_splits_each_class</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">train_size</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;We randomly sample n instances for each class, where n = train_size.</span>

<span class="sd">    For each class the node indices are shuffled; the first ``train_size`` go to</span>
<span class="sd">    training, the next ``train_size`` to validation and the rest to test. The</span>
<span class="sd">    global NumPy RNG is used (there is no seed argument).</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    labels : numpy.array</span>
<span class="sd">        integer node labels; classes are assumed to be ``0 .. labels.max()``</span>
<span class="sd">    train_size : int</span>
<span class="sd">        number of training (and of validation) nodes sampled per class</span>

<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    idx_train, idx_val, idx_test :</span>
<span class="sd">        shuffled integer node index arrays</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">labels</span><span class="p">))</span>
    <span class="n">nclass</span> <span class="o">=</span> <span class="n">labels</span><span class="o">.</span><span class="n">max</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
    <span class="n">idx_train</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">idx_val</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">idx_test</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nclass</span><span class="p">):</span>
        <span class="n">labels_i</span> <span class="o">=</span> <span class="n">idx</span><span class="p">[</span><span class="n">labels</span><span class="o">==</span><span class="n">i</span><span class="p">]</span>
        <span class="n">labels_i</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">labels_i</span><span class="p">)</span>
        <span class="c1"># hstack onto the initial empty lists upcasts to float, hence astype(int);</span>
        <span class="c1"># builtin int replaces the np.int alias, which NumPy 1.24 removed.</span>
        <span class="n">idx_train</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">((</span><span class="n">idx_train</span><span class="p">,</span> <span class="n">labels_i</span><span class="p">[:</span> <span class="n">train_size</span><span class="p">]))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>
        <span class="n">idx_val</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">((</span><span class="n">idx_val</span><span class="p">,</span> <span class="n">labels_i</span><span class="p">[</span><span class="n">train_size</span><span class="p">:</span> <span class="mi">2</span><span class="o">*</span><span class="n">train_size</span><span class="p">]))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>
        <span class="n">idx_test</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">((</span><span class="n">idx_test</span><span class="p">,</span> <span class="n">labels_i</span><span class="p">[</span><span class="mi">2</span><span class="o">*</span><span class="n">train_size</span><span class="p">:</span> <span class="p">]))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">idx_train</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">idx_val</span><span class="p">),</span> \
           <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">idx_test</span><span class="p">)</span></div>


<!--
  Listing of unravel_index(index, array_shape): converts flat row-major index(es)
  into (rows, cols) via floor division and modulo by array_shape[1]; only the
  column count in array_shape is read. Presumably works for scalars, numpy arrays
  and torch tensors alike since it only uses // and % (verify with callers).
--><span class="k">def</span> <span class="nf">unravel_index</span><span class="p">(</span><span class="n">index</span><span class="p">,</span> <span class="n">array_shape</span><span class="p">):</span>
    <span class="n">rows</span> <span class="o">=</span> <span class="n">index</span> <span class="o">//</span> <span class="n">array_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    <span class="n">cols</span> <span class="o">=</span> <span class="n">index</span> <span class="o">%</span> <span class="n">array_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    <span class="k">return</span> <span class="n">rows</span><span class="p">,</span> <span class="n">cols</span>


<!--
  Listing of get_degree_squence(adj) (the misspelled name is the public name).
  Returns adj.sum(0); if that raises anything, falls back to
  ts.sum(adj, dim=1).to_dense(), where ts is imported elsewhere in the module
  (presumably torch.sparse; confirm). NOTE(review): the primary path sums over
  dim 0 but the fallback sums over dim 1, which agree only for symmetric matrices,
  and the bare except also hides unrelated errors. In this excerpt the function
  is referenced only by commented-out lines in likelihood_ratio_filter.
--><span class="k">def</span> <span class="nf">get_degree_squence</span><span class="p">(</span><span class="n">adj</span><span class="p">):</span>
    <span class="k">try</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">adj</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    <span class="k">except</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">ts</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">adj</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">to_dense</span><span class="p">()</span>

<div class="viewcode-block" id="likelihood_ratio_filter"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.likelihood_ratio_filter">[docs]</a><span class="k">def</span> <span class="nf">likelihood_ratio_filter</span><span class="p">(</span><span class="n">node_pairs</span><span class="p">,</span> <span class="n">modified_adjacency</span><span class="p">,</span> <span class="n">original_adjacency</span><span class="p">,</span> <span class="n">d_min</span><span class="p">,</span> <span class="n">threshold</span><span class="o">=</span><span class="mf">0.004</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Filter the input node pairs based on the likelihood ratio test proposed by Zügner et al. 2018, see</span>
<span class="sd">    https://dl.acm.org/citation.cfm?id=3220078. In essence, a node pair is allowed if adding/removing the edge</span>
<span class="sd">    between the two nodes does not violate the unnoticeability constraint. Assumes unweighted and undirected graphs.</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    node_pairs :</span>
<span class="sd">        candidate node pairs, presumably a numpy array of shape (num_pairs, 2)</span>
<span class="sd">    modified_adjacency : torch.Tensor</span>
<span class="sd">        current (perturbed) adjacency matrix</span>
<span class="sd">    original_adjacency : torch.Tensor</span>
<span class="sd">        clean adjacency matrix</span>
<span class="sd">    d_min :</span>
<span class="sd">        minimum degree considered by the power-law fit; must support ``.item()`` (e.g. a 0-d tensor)</span>
<span class="sd">    threshold : float</span>
<span class="sd">        a pair is allowed only if its new likelihood ratio is strictly below this value</span>

<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    allowed_mask : torch.Tensor</span>
<span class="sd">        created with ``torch.zeros`` (default device) in the shape of ``modified_adjacency``;</span>
<span class="sd">        allowed pairs are set to 1, then the mask is symmetrized by adding its transpose</span>
<span class="sd">    current_ratio :</span>
<span class="sd">        likelihood ratio statistic of the current modified graph</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="n">N</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">modified_adjacency</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    <span class="c1"># original_degree_sequence = get_degree_squence(original_adjacency)</span>
    <span class="c1"># current_degree_sequence = get_degree_squence(modified_adjacency)</span>
    <span class="n">original_degree_sequence</span> <span class="o">=</span> <span class="n">original_adjacency</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">current_degree_sequence</span> <span class="o">=</span> <span class="n">modified_adjacency</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

    <span class="n">concat_degree_sequence</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">current_degree_sequence</span><span class="p">,</span> <span class="n">original_degree_sequence</span><span class="p">))</span>

    <span class="c1"># Compute the log likelihood values of the original, modified, and combined degree sequences.</span>
    <span class="n">ll_orig</span><span class="p">,</span> <span class="n">alpha_orig</span><span class="p">,</span> <span class="n">n_orig</span><span class="p">,</span> <span class="n">sum_log_degrees_original</span> <span class="o">=</span> <span class="n">degree_sequence_log_likelihood</span><span class="p">(</span><span class="n">original_degree_sequence</span><span class="p">,</span> <span class="n">d_min</span><span class="p">)</span>
    <span class="n">ll_current</span><span class="p">,</span> <span class="n">alpha_current</span><span class="p">,</span> <span class="n">n_current</span><span class="p">,</span> <span class="n">sum_log_degrees_current</span> <span class="o">=</span> <span class="n">degree_sequence_log_likelihood</span><span class="p">(</span><span class="n">current_degree_sequence</span><span class="p">,</span> <span class="n">d_min</span><span class="p">)</span>

    <span class="n">ll_comb</span><span class="p">,</span> <span class="n">alpha_comb</span><span class="p">,</span> <span class="n">n_comb</span><span class="p">,</span> <span class="n">sum_log_degrees_combined</span> <span class="o">=</span> <span class="n">degree_sequence_log_likelihood</span><span class="p">(</span><span class="n">concat_degree_sequence</span><span class="p">,</span> <span class="n">d_min</span><span class="p">)</span>

    <span class="c1"># Compute the log likelihood ratio</span>
    <span class="n">current_ratio</span> <span class="o">=</span> <span class="o">-</span><span class="mi">2</span> <span class="o">*</span> <span class="n">ll_comb</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="p">(</span><span class="n">ll_orig</span> <span class="o">+</span> <span class="n">ll_current</span><span class="p">)</span>

    <span class="c1"># Compute new log likelihood values that would arise if we add/remove the edges corresponding to each node pair.</span>
    <span class="n">new_lls</span><span class="p">,</span> <span class="n">new_alphas</span><span class="p">,</span> <span class="n">new_ns</span><span class="p">,</span> <span class="n">new_sum_log_degrees</span> <span class="o">=</span> <span class="n">updated_log_likelihood_for_edge_changes</span><span class="p">(</span><span class="n">node_pairs</span><span class="p">,</span>
                                                                                               <span class="n">modified_adjacency</span><span class="p">,</span> <span class="n">d_min</span><span class="p">)</span>

    <span class="c1"># Combination of the original degree distribution with the distributions corresponding to each node pair.</span>
    <span class="n">n_combined</span> <span class="o">=</span> <span class="n">n_orig</span> <span class="o">+</span> <span class="n">new_ns</span>
    <span class="n">new_sum_log_degrees_combined</span> <span class="o">=</span> <span class="n">sum_log_degrees_original</span> <span class="o">+</span> <span class="n">new_sum_log_degrees</span>
    <span class="n">alpha_combined</span> <span class="o">=</span> <span class="n">compute_alpha</span><span class="p">(</span><span class="n">n_combined</span><span class="p">,</span> <span class="n">new_sum_log_degrees_combined</span><span class="p">,</span> <span class="n">d_min</span><span class="p">)</span>
    <span class="n">new_ll_combined</span> <span class="o">=</span> <span class="n">compute_log_likelihood</span><span class="p">(</span><span class="n">n_combined</span><span class="p">,</span> <span class="n">alpha_combined</span><span class="p">,</span> <span class="n">new_sum_log_degrees_combined</span><span class="p">,</span> <span class="n">d_min</span><span class="p">)</span>
    <span class="n">new_ratios</span> <span class="o">=</span> <span class="o">-</span><span class="mi">2</span> <span class="o">*</span> <span class="n">new_ll_combined</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="p">(</span><span class="n">new_lls</span> <span class="o">+</span> <span class="n">ll_orig</span><span class="p">)</span>

    <span class="c1"># Allowed edges are only those for which the resulting likelihood ratio measure is &lt; than the threshold</span>
    <span class="n">allowed_edges</span> <span class="o">=</span> <span class="n">new_ratios</span> <span class="o">&lt;</span> <span class="n">threshold</span>

    <span class="c1"># .cpu() returns the tensor itself when it is already on the CPU, so one path covers both devices;</span>
    <span class="c1"># builtin bool replaces the np.bool alias, which NumPy 1.24 removed.</span>
    <span class="n">filtered_edges</span> <span class="o">=</span> <span class="n">node_pairs</span><span class="p">[</span><span class="n">allowed_edges</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">bool</span><span class="p">)]</span>

    <span class="n">allowed_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">modified_adjacency</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
    <span class="n">allowed_mask</span><span class="p">[</span><span class="n">filtered_edges</span><span class="o">.</span><span class="n">T</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
    <span class="n">allowed_mask</span> <span class="o">+=</span> <span class="n">allowed_mask</span><span class="o">.</span><span class="n">t</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">allowed_mask</span><span class="p">,</span> <span class="n">current_ratio</span></div>


<!--
  Listing of degree_sequence_log_likelihood(degree_sequence, d_min).
  Keeps the degrees that are at least d_min.item() (so d_min must provide .item()),
  sums their logs with torch.log and falls back to np.log if that raises anything
  (bare except, presumably meant for numpy input). It then fits the power law via
  compute_alpha and compute_log_likelihood, both defined elsewhere in the module.
  Returns the tuple (ll, alpha, n, sum_log_degrees), where n is the number of
  degrees kept; likelihood_ratio_filter unpacks the tuple in this order.
--><div class="viewcode-block" id="degree_sequence_log_likelihood"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.degree_sequence_log_likelihood">[docs]</a><span class="k">def</span> <span class="nf">degree_sequence_log_likelihood</span><span class="p">(</span><span class="n">degree_sequence</span><span class="p">,</span> <span class="n">d_min</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Compute the (maximum) log likelihood of the Powerlaw distribution fit on a degree distribution.</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="c1"># Determine which degrees are to be considered, i.e. &gt;= d_min.</span>
    <span class="n">D_G</span> <span class="o">=</span> <span class="n">degree_sequence</span><span class="p">[(</span><span class="n">degree_sequence</span> <span class="o">&gt;=</span> <span class="n">d_min</span><span class="o">.</span><span class="n">item</span><span class="p">())]</span>
    <span class="k">try</span><span class="p">:</span>
        <span class="n">sum_log_degrees</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">D_G</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
    <span class="k">except</span><span class="p">:</span>
        <span class="n">sum_log_degrees</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">D_G</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
    <span class="n">n</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">D_G</span><span class="p">)</span>

    <span class="n">alpha</span> <span class="o">=</span> <span class="n">compute_alpha</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">sum_log_degrees</span><span class="p">,</span> <span class="n">d_min</span><span class="p">)</span>
    <span class="n">ll</span> <span class="o">=</span> <span class="n">compute_log_likelihood</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">sum_log_degrees</span><span class="p">,</span> <span class="n">d_min</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">ll</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">sum_log_degrees</span></div>

<div class="viewcode-block" id="updated_log_likelihood_for_edge_changes"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.updated_log_likelihood_for_edge_changes">[docs]</a><span class="k">def</span> <span class="nf">updated_log_likelihood_for_edge_changes</span><span class="p">(</span><span class="n">node_pairs</span><span class="p">,</span> <span class="n">adjacency_matrix</span><span class="p">,</span> <span class="n">d_min</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot; Adopted from https://github.com/danielzuegner/nettack</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="c1"># For each node pair find out whether there is an edge or not in the input adjacency matrix.</span>

    <span class="n">edge_entries_before</span> <span class="o">=</span> <span class="n">adjacency_matrix</span><span class="p">[</span><span class="n">node_pairs</span><span class="o">.</span><span class="n">T</span><span class="p">]</span>
    <span class="n">degree_sequence</span> <span class="o">=</span> <span class="n">adjacency_matrix</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">D_G</span> <span class="o">=</span> <span class="n">degree_sequence</span><span class="p">[</span><span class="n">degree_sequence</span> <span class="o">&gt;=</span> <span class="n">d_min</span><span class="o">.</span><span class="n">item</span><span class="p">()]</span>
    <span class="n">sum_log_degrees</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">D_G</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
    <span class="n">n</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">D_G</span><span class="p">)</span>
    <span class="n">deltas</span> <span class="o">=</span> <span class="o">-</span><span class="mi">2</span> <span class="o">*</span> <span class="n">edge_entries_before</span> <span class="o">+</span> <span class="mi">1</span>
    <span class="n">d_edges_before</span> <span class="o">=</span> <span class="n">degree_sequence</span><span class="p">[</span><span class="n">node_pairs</span><span class="p">]</span>

    <span class="n">d_edges_after</span> <span class="o">=</span> <span class="n">degree_sequence</span><span class="p">[</span><span class="n">node_pairs</span><span class="p">]</span> <span class="o">+</span> <span class="n">deltas</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>

    <span class="c1"># Sum the log of the degrees after the potential changes which are &gt;= d_min</span>
    <span class="n">sum_log_degrees_after</span><span class="p">,</span> <span class="n">new_n</span> <span class="o">=</span> <span class="n">update_sum_log_degrees</span><span class="p">(</span><span class="n">sum_log_degrees</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">d_edges_before</span><span class="p">,</span> <span class="n">d_edges_after</span><span class="p">,</span> <span class="n">d_min</span><span class="p">)</span>
    <span class="c1"># Updated estimates of the Powerlaw exponents</span>
    <span class="n">new_alpha</span> <span class="o">=</span> <span class="n">compute_alpha</span><span class="p">(</span><span class="n">new_n</span><span class="p">,</span> <span class="n">sum_log_degrees_after</span><span class="p">,</span> <span class="n">d_min</span><span class="p">)</span>
    <span class="c1"># Updated log likelihood values for the Powerlaw distributions</span>
    <span class="n">new_ll</span> <span class="o">=</span> <span class="n">compute_log_likelihood</span><span class="p">(</span><span class="n">new_n</span><span class="p">,</span> <span class="n">new_alpha</span><span class="p">,</span> <span class="n">sum_log_degrees_after</span><span class="p">,</span> <span class="n">d_min</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">new_ll</span><span class="p">,</span> <span class="n">new_alpha</span><span class="p">,</span> <span class="n">new_n</span><span class="p">,</span> <span class="n">sum_log_degrees_after</span></div>


<span class="k">def</span> <span class="nf">update_sum_log_degrees</span><span class="p">(</span><span class="n">sum_log_degrees_before</span><span class="p">,</span> <span class="n">n_old</span><span class="p">,</span> <span class="n">d_old</span><span class="p">,</span> <span class="n">d_new</span><span class="p">,</span> <span class="n">d_min</span><span class="p">):</span>
    <span class="c1"># Find out whether the degrees before and after the change are above the threshold d_min.</span>
    <span class="n">old_in_range</span> <span class="o">=</span> <span class="n">d_old</span> <span class="o">&gt;=</span> <span class="n">d_min</span>
    <span class="n">new_in_range</span> <span class="o">=</span> <span class="n">d_new</span> <span class="o">&gt;=</span> <span class="n">d_min</span>
    <span class="n">d_old_in_range</span> <span class="o">=</span> <span class="n">d_old</span> <span class="o">*</span> <span class="n">old_in_range</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
    <span class="n">d_new_in_range</span> <span class="o">=</span> <span class="n">d_new</span> <span class="o">*</span> <span class="n">new_in_range</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>

    <span class="c1"># Update the sum by subtracting the old values and then adding the updated logs of the degrees.</span>
    <span class="n">sum_log_degrees_after</span> <span class="o">=</span> <span class="n">sum_log_degrees_before</span> <span class="o">-</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="n">d_old_in_range</span><span class="p">,</span> <span class="nb">min</span><span class="o">=</span><span class="mi">1</span><span class="p">)))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> \
                                 <span class="o">+</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="n">d_new_in_range</span><span class="p">,</span> <span class="nb">min</span><span class="o">=</span><span class="mi">1</span><span class="p">)))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>

    <span class="c1"># Update the number of degrees &gt;= d_min</span>

    <span class="n">new_n</span> <span class="o">=</span> <span class="n">n_old</span> <span class="o">-</span> <span class="p">(</span><span class="n">old_in_range</span><span class="o">!=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">new_in_range</span><span class="o">!=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">new_n</span> <span class="o">=</span> <span class="n">new_n</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">sum_log_degrees_after</span><span class="p">,</span> <span class="n">new_n</span>

<span class="k">def</span> <span class="nf">compute_alpha</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">sum_log_degrees</span><span class="p">,</span> <span class="n">d_min</span><span class="p">):</span>
    <span class="k">try</span><span class="p">:</span>
        <span class="n">alpha</span> <span class="o">=</span>  <span class="mi">1</span> <span class="o">+</span> <span class="n">n</span> <span class="o">/</span> <span class="p">(</span><span class="n">sum_log_degrees</span> <span class="o">-</span> <span class="n">n</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">d_min</span> <span class="o">-</span> <span class="mf">0.5</span><span class="p">))</span>
    <span class="k">except</span><span class="p">:</span>
        <span class="n">alpha</span> <span class="o">=</span>  <span class="mi">1</span> <span class="o">+</span> <span class="n">n</span> <span class="o">/</span> <span class="p">(</span><span class="n">sum_log_degrees</span> <span class="o">-</span> <span class="n">n</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">d_min</span> <span class="o">-</span> <span class="mf">0.5</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">alpha</span>

<span class="k">def</span> <span class="nf">compute_log_likelihood</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">sum_log_degrees</span><span class="p">,</span> <span class="n">d_min</span><span class="p">):</span>
    <span class="c1"># Log likelihood under alpha</span>
    <span class="k">try</span><span class="p">:</span>
        <span class="n">ll</span> <span class="o">=</span> <span class="n">n</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">alpha</span><span class="p">)</span> <span class="o">+</span> <span class="n">n</span> <span class="o">*</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">d_min</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">alpha</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">sum_log_degrees</span>
    <span class="k">except</span><span class="p">:</span>
        <span class="n">ll</span> <span class="o">=</span> <span class="n">n</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">alpha</span><span class="p">)</span> <span class="o">+</span> <span class="n">n</span> <span class="o">*</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">d_min</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">alpha</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">sum_log_degrees</span>

    <span class="k">return</span> <span class="n">ll</span>

<div class="viewcode-block" id="ravel_multiple_indices"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.ravel_multiple_indices">[docs]</a><span class="k">def</span> <span class="nf">ravel_multiple_indices</span><span class="p">(</span><span class="n">ixs</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">reverse</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    &quot;Flattens&quot; multiple 2D input indices into indices on the flattened matrix, similar to np.ravel_multi_index.</span>
<span class="sd">    Does the same as ravel_index but for multiple indices at once.</span>
<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    ixs: array of ints shape (n, 2)</span>
<span class="sd">        The array of n indices that will be flattened.</span>

<span class="sd">    shape: list or tuple of ints of length 2</span>
<span class="sd">        The shape of the corresponding matrix.</span>

<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    array of n ints between 0 and shape[0]*shape[1]-1</span>
<span class="sd">        The indices on the flattened matrix corresponding to the 2D input indices.</span>

<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">if</span> <span class="n">reverse</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">ixs</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">ixs</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span>

    <span class="k">return</span> <span class="n">ixs</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">ixs</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span></div>

<div class="viewcode-block" id="visualize"><a class="viewcode-back" href="../../../source/deeprobust.graph.html#deeprobust.graph.utils.visualize">[docs]</a><span class="k">def</span> <span class="nf">visualize</span><span class="p">(</span><span class="n">your_var</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;visualize computation graph&quot;&quot;&quot;</span>
    <span class="kn">from</span> <span class="nn">graphviz</span> <span class="kn">import</span> <span class="n">Digraph</span>
    <span class="kn">import</span> <span class="nn">torch</span>
    <span class="kn">from</span> <span class="nn">torch.autograd</span> <span class="kn">import</span> <span class="n">Variable</span>
    <span class="kn">from</span> <span class="nn">torchviz</span> <span class="kn">import</span> <span class="n">make_dot</span>
    <span class="n">make_dot</span><span class="p">(</span><span class="n">your_var</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">()</span></div>

<span class="k">def</span> <span class="nf">reshape_mx</span><span class="p">(</span><span class="n">mx</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
    <span class="n">indices</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nonzero</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">sp</span><span class="o">.</span><span class="n">csr_matrix</span><span class="p">((</span><span class="n">mx</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="p">(</span><span class="n">indices</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">indices</span><span class="p">[</span><span class="mi">1</span><span class="p">])),</span> <span class="n">shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)</span>

<span class="c1"># def check_path(file_path):</span>
<span class="c1">#     if not osp.exists(file_path):</span>
<span class="c1">#         os.system(f&#39;mkdir -p {file_path}&#39;)</span>

</pre></div>

           </div>
           
          </div>
          <footer>
            <!-- Page footer emitted by the sphinx_rtd_theme layout. -->
            <hr>

            <div role="contentinfo">
              <p>
                <!-- NOTE(review): the copyright holder/year renders empty; presumably the
                     Sphinx `copyright` config value is unset in conf.py — confirm and set it
                     there rather than editing this generated file. -->
                © Copyright
              </p>
            </div>

            Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
            <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
            provided by <a href="https://readthedocs.org">Read the Docs</a>.
          </footer>

        </div>
      </div>

    </section>

  </div>
  

  <script>
    // Once the DOM is ready, turn on the theme's sidebar navigation behaviour.
    jQuery(function enableSphinxNavigation() {
      SphinxRtdTheme.Navigation.enable(true);
    });
  </script>

  
  
    
   

</body>
</html>