

<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <!-- Document metadata -->
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>deeprobust.graph.data.dataset &mdash; DeepRobust 0.1.1 documentation</title>

  <!-- Theme and syntax-highlighting stylesheets -->
  <link rel="stylesheet" href="../../../../_static/css/theme.css">
  <link rel="stylesheet" href="../../../../_static/pygments.css">

  <!-- HTML5 element shim, served to Internet Explorer older than 9 only -->
  <!--[if lt IE 9]>
    <script src="../../../../_static/js/html5shiv.min.js"></script>
  <![endif]-->

  <!-- Scripts: load order is significant and is kept exactly as generated -->
  <script id="documentation_options" data-url_root="../../../../" src="../../../../_static/documentation_options.js"></script>
  <script src="../../../../_static/jquery.js"></script>
  <script src="../../../../_static/underscore.js"></script>
  <script src="../../../../_static/doctools.js"></script>
  <script src="../../../../_static/language_data.js"></script>
  <script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
  <script src="../../../../_static/js/theme.js"></script>

  <!-- Site-wide index and search pages -->
  <link rel="index" title="Index" href="../../../../genindex.html">
  <link rel="search" title="Search" href="../../../../search.html">
</head>

<body class="wy-body-for-nav">

   
  <div class="wy-grid-for-nav">
    
    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search">
          <!-- Project home link. `alt` is not a valid attribute on <a>; the hint
               is carried by `title` so the visible text stays the accessible name. -->
          <a href="../../../../index.html" class="icon icon-home" title="Documentation Home"> DeepRobust
          </a>

          <!-- Search box. The placeholder alone is not a label, so aria-label
               gives the text field an accessible name. -->
          <div role="search">
            <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
              <input type="text" name="q" placeholder="Search docs" aria-label="Search docs">
              <input type="hidden" name="check_keywords" value="yes">
              <input type="hidden" name="area" value="default">
            </form>
          </div>
        </div>

        <!-- Table-of-contents menu: one caption plus one link list per section -->
        <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
          <p class="caption"><span class="caption-text">Installation</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../../notes/installation.html">Installation</a></li>
          </ul>
          <p class="caption"><span class="caption-text">Graph Package</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../../graph/data.html">Graph Dataset</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../../graph/attack.html">Introduction to Graph Attack with Examples</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../../graph/defense.html">Introduction to Graph Defense with Examples</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../../graph/pyg.html">Using PyTorch Geometric in DeepRobust</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../../graph/node_embedding.html">Node Embedding Attack and Defense</a></li>
          </ul>
          <p class="caption"><span class="caption-text">Image Package</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../../image/example.html">Image Attack and Defense</a></li>
          </ul>
          <p class="caption"><span class="caption-text">Image Package</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../../source/deeprobust.image.attack.html">deeprobust.image.attack package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../../source/deeprobust.image.defense.html">deeprobust.image.defense package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../../source/deeprobust.image.netmodels.html">deeprobust.image.netmodels package</a></li>
          </ul>
          <p class="caption"><span class="caption-text">Graph Package</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../../source/deeprobust.graph.global_attack.html">deeprobust.graph.global_attack package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../../source/deeprobust.graph.targeted_attack.html">deeprobust.graph.targeted_attack package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../../source/deeprobust.graph.defense.html">deeprobust.graph.defense package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../../source/deeprobust.graph.data.html">deeprobust.graph.data package</a></li>
          </ul>
        </div>
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">

      
      <nav class="wy-nav-top" aria-label="top navigation">
        <!-- Top bar: icon carrying the `wy-nav-top` toggle hook, then a link to the documentation home -->
        <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
        <a href="../../../../index.html">DeepRobust</a>
      </nav>


      <div class="wy-nav-content">
        
        <div class="rst-content">
        
          















<div role="navigation" aria-label="breadcrumbs navigation">
  <ul class="wy-breadcrumbs">
    <!-- Icon-only home link: the anchor has no text content, so aria-label
         supplies its accessible name. -->
    <li><a href="../../../../index.html" class="icon icon-home" aria-label="Home"></a> &raquo;</li>
    <li><a href="../../../index.html">Module code</a> &raquo;</li>
    <!-- Current page: plain text, not a link -->
    <li>deeprobust.graph.data.dataset</li>
    <li class="wy-breadcrumbs-aside">
    </li>
  </ul>
  <hr>
</div>
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
            
  <h1>Source code for deeprobust.graph.data.dataset</h1><div class="highlight"><pre>
<span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">scipy.sparse</span> <span class="k">as</span> <span class="nn">sp</span>
<span class="kn">import</span> <span class="nn">os.path</span> <span class="k">as</span> <span class="nn">osp</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">urllib.request</span>
<span class="kn">import</span> <span class="nn">sys</span>
<span class="kn">import</span> <span class="nn">pickle</span> <span class="k">as</span> <span class="nn">pkl</span>
<span class="kn">import</span> <span class="nn">networkx</span> <span class="k">as</span> <span class="nn">nx</span>
<span class="kn">from</span> <span class="nn">deeprobust.graph.utils</span> <span class="kn">import</span> <span class="n">get_train_val_test</span><span class="p">,</span> <span class="n">get_train_val_test_gcn</span>
<span class="kn">import</span> <span class="nn">zipfile</span>
<span class="kn">import</span> <span class="nn">json</span>

<div class="viewcode-block" id="Dataset"><a class="viewcode-back" href="../../../../source/deeprobust.graph.data.html#deeprobust.graph.data.dataset.Dataset">[docs]</a><span class="k">class</span> <span class="nc">Dataset</span><span class="p">():</span>
    <span class="sd">&quot;&quot;&quot;Dataset class contains four citation network datasets &quot;cora&quot;, &quot;cora-ml&quot;, &quot;citeseer&quot; and &quot;pubmed&quot;,</span>
<span class="sd">    and one blog dataset &quot;Polblogs&quot;. Datasets &quot;ACM&quot;, &quot;BlogCatalog&quot;, &quot;Flickr&quot; and &quot;UAI&quot;</span>
<span class="sd">    are also available. See more details in https://github.com/DSE-MSU/DeepRobust/tree/master/deeprobust/graph#supported-datasets.</span>
<span class="sd">    The &#39;cora&#39;, &#39;cora-ml&#39;, &#39;polblogs&#39; and &#39;citeseer&#39; are downloaded from https://github.com/danielzuegner/gnn-meta-attack/tree/master/data, and &#39;pubmed&#39; is from https://github.com/tkipf/gcn/tree/master/gcn/data.</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    root : string</span>
<span class="sd">        root directory where the dataset should be saved.</span>
<span class="sd">    name : string</span>
<span class="sd">        dataset name, it can be chosen from [&#39;cora&#39;, &#39;citeseer&#39;, &#39;cora_ml&#39;, &#39;polblogs&#39;,</span>
<span class="sd">        &#39;pubmed&#39;, &#39;acm&#39;, &#39;blogcatalog&#39;, &#39;uai&#39;, &#39;flickr&#39;]</span>
<span class="sd">    setting : string</span>
<span class="sd">        there are three data split settings. It can be chosen from [&#39;nettack&#39;, &#39;gcn&#39;, &#39;prognn&#39;]</span>
<span class="sd">        The &#39;nettack&#39; setting follows nettack paper where they select the largest connected</span>
<span class="sd">        components of the graph and use 10%/10%/80% nodes for training/validation/test .</span>
<span class="sd">        The &#39;gcn&#39; setting follows gcn paper where they use the full graph and 20 samples</span>
<span class="sd">        in each class for training, 500 nodes for validation, and 1000</span>
<span class="sd">        nodes for test. (Note here &#39;nettack&#39; and &#39;gcn&#39; setting do not provide fixed split, i.e.,</span>
<span class="sd">        different random seed would return different data splits)</span>
<span class="sd">    seed : int</span>
<span class="sd">        random seed for splitting training/validation/test.</span>
<span class="sd">    require_mask : bool</span>
<span class="sd">        setting require_mask True to get training, validation and test mask</span>
<span class="sd">        (self.train_mask, self.val_mask, self.test_mask)</span>

<span class="sd">    Examples</span>
<span class="sd">    --------</span>
<span class="sd">	We can first create an instance of the Dataset class and then take out its attributes.</span>

<span class="sd">	&gt;&gt;&gt; from deeprobust.graph.data import Dataset</span>
<span class="sd">	&gt;&gt;&gt; data = Dataset(root=&#39;/tmp/&#39;, name=&#39;cora&#39;, seed=15)</span>
<span class="sd">	&gt;&gt;&gt; adj, features, labels = data.adj, data.features, data.labels</span>
<span class="sd">	&gt;&gt;&gt; idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">root</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">setting</span><span class="o">=</span><span class="s1">&#39;nettack&#39;</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">require_mask</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">setting</span> <span class="o">=</span> <span class="n">setting</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span>

        <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;cora&#39;</span><span class="p">,</span> <span class="s1">&#39;citeseer&#39;</span><span class="p">,</span> <span class="s1">&#39;cora_ml&#39;</span><span class="p">,</span> <span class="s1">&#39;polblogs&#39;</span><span class="p">,</span>
                <span class="s1">&#39;pubmed&#39;</span><span class="p">,</span> <span class="s1">&#39;acm&#39;</span><span class="p">,</span> <span class="s1">&#39;blogcatalog&#39;</span><span class="p">,</span> <span class="s1">&#39;uai&#39;</span><span class="p">,</span> <span class="s1">&#39;flickr&#39;</span><span class="p">],</span> \
                <span class="s1">&#39;Currently only support cora, citeseer, cora_ml, &#39;</span> <span class="o">+</span> \
                <span class="s1">&#39;polblogs, pubmed, acm, blogcatalog, flickr&#39;</span>
        <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">setting</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;gcn&#39;</span><span class="p">,</span> <span class="s1">&#39;nettack&#39;</span><span class="p">,</span> <span class="s1">&#39;prognn&#39;</span><span class="p">],</span> <span class="s2">&quot;Settings should be&quot;</span> <span class="o">+</span> \
                        <span class="s2">&quot; choosen from [&#39;gcn&#39;, &#39;nettack&#39;, &#39;prognn&#39;]&quot;</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">seed</span> <span class="o">=</span> <span class="n">seed</span>
        <span class="c1"># self.url =  &#39;https://raw.githubusercontent.com/danielzuegner/nettack/master/data/%s.npz&#39; % self.name</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">url</span> <span class="o">=</span>  <span class="s1">&#39;https://raw.githubusercontent.com/danielzuegner/gnn-meta-attack/master/data/</span><span class="si">%s</span><span class="s1">.npz&#39;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">root</span> <span class="o">=</span> <span class="n">osp</span><span class="o">.</span><span class="n">expanduser</span><span class="p">(</span><span class="n">osp</span><span class="o">.</span><span class="n">normpath</span><span class="p">(</span><span class="n">root</span><span class="p">))</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">data_folder</span> <span class="o">=</span> <span class="n">osp</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">data_filename</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_folder</span> <span class="o">+</span> <span class="s1">&#39;.npz&#39;</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">require_mask</span> <span class="o">=</span> <span class="n">require_mask</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">require_lcc</span> <span class="o">=</span> <span class="kc">False</span> <span class="k">if</span> <span class="n">setting</span> <span class="o">==</span> <span class="s1">&#39;gcn&#39;</span> <span class="k">else</span> <span class="kc">True</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">adj</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">labels</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span>

        <span class="k">if</span> <span class="n">setting</span> <span class="o">==</span> <span class="s1">&#39;prognn&#39;</span><span class="p">:</span>
            <span class="k">assert</span> <span class="n">name</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;cora&#39;</span><span class="p">,</span> <span class="s1">&#39;citeseer&#39;</span><span class="p">,</span> <span class="s1">&#39;pubmed&#39;</span><span class="p">,</span> <span class="s1">&#39;cora_ml&#39;</span><span class="p">],</span> <span class="s2">&quot;ProGNN splits only &quot;</span> <span class="o">+</span> \
                        <span class="s2">&quot;cora, citeseer, pubmed, cora_ml&quot;</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">idx_train</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">idx_val</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">idx_test</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_prognn_splits</span><span class="p">()</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">idx_train</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">idx_val</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">idx_test</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_train_val_test</span><span class="p">()</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">require_mask</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">get_mask</span><span class="p">()</span>

<div class="viewcode-block" id="Dataset.get_train_val_test"><a class="viewcode-back" href="../../../../source/deeprobust.graph.data.html#deeprobust.graph.data.dataset.Dataset.get_train_val_test">[docs]</a>    <span class="k">def</span> <span class="nf">get_train_val_test</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="sd">&quot;&quot;&quot;Get training, validation, test splits according to self.setting (either &#39;nettack&#39; or &#39;gcn&#39;).</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">setting</span> <span class="o">==</span> <span class="s1">&#39;nettack&#39;</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">get_train_val_test</span><span class="p">(</span><span class="n">nnodes</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">adj</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">val_size</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> <span class="n">stratify</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">labels</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">setting</span> <span class="o">==</span> <span class="s1">&#39;gcn&#39;</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">get_train_val_test_gcn</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">labels</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span></div>

<div class="viewcode-block" id="Dataset.get_prognn_splits"><a class="viewcode-back" href="../../../../source/deeprobust.graph.data.html#deeprobust.graph.data.dataset.Dataset.get_prognn_splits">[docs]</a>    <span class="k">def</span> <span class="nf">get_prognn_splits</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="sd">&quot;&quot;&quot;Get target nodes incides, which is the nodes with degree &gt; 10 in the test set.&quot;&quot;&quot;</span>
        <span class="n">url</span> <span class="o">=</span> <span class="s1">&#39;https://raw.githubusercontent.com/ChandlerBang/Pro-GNN/&#39;</span> <span class="o">+</span> \
                     <span class="s1">&#39;master/splits/</span><span class="si">{}</span><span class="s1">_prognn_splits.json&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">)</span>
        <span class="n">json_file</span> <span class="o">=</span> <span class="n">osp</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span>
                <span class="s1">&#39;</span><span class="si">{}</span><span class="s1">_prognn_splits.json&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">))</span>

        <span class="k">if</span> <span class="ow">not</span> <span class="n">osp</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">json_file</span><span class="p">):</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">download_file</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">json_file</span><span class="p">)</span>
        <span class="c1"># with open(f&#39;/mnt/home/jinwei2/Projects/nettack/{dataset}_nettacked_nodes.json&#39;, &#39;r&#39;) as f:</span>
        <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">json_file</span><span class="p">,</span> <span class="s1">&#39;r&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
            <span class="n">idx</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">loads</span><span class="p">(</span><span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">())</span>
        <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">idx</span><span class="p">[</span><span class="s1">&#39;idx_train&#39;</span><span class="p">]),</span> \
               <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">idx</span><span class="p">[</span><span class="s1">&#39;idx_val&#39;</span><span class="p">]),</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">idx</span><span class="p">[</span><span class="s1">&#39;idx_test&#39;</span><span class="p">])</span></div>

    <span class="k">def</span> <span class="nf">load_data</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Loading </span><span class="si">{}</span><span class="s1"> dataset...&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">))</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">==</span> <span class="s1">&#39;pubmed&#39;</span><span class="p">:</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_pubmed</span><span class="p">()</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;acm&#39;</span><span class="p">,</span> <span class="s1">&#39;blogcatalog&#39;</span><span class="p">,</span> <span class="s1">&#39;uai&#39;</span><span class="p">,</span> <span class="s1">&#39;flickr&#39;</span><span class="p">]:</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_zip</span><span class="p">()</span>

        <span class="k">if</span> <span class="ow">not</span> <span class="n">osp</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data_filename</span><span class="p">):</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">download_npz</span><span class="p">()</span>

        <span class="n">adj</span><span class="p">,</span> <span class="n">features</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_adj</span><span class="p">()</span>
        <span class="k">return</span> <span class="n">adj</span><span class="p">,</span> <span class="n">features</span><span class="p">,</span> <span class="n">labels</span>

    <span class="k">def</span> <span class="nf">download_file</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">url</span><span class="p">,</span> <span class="n">file</span><span class="p">):</span>
        <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Dowloading from </span><span class="si">{}</span><span class="s1"> to </span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">file</span><span class="p">))</span>
        <span class="k">try</span><span class="p">:</span>
            <span class="n">urllib</span><span class="o">.</span><span class="n">request</span><span class="o">.</span><span class="n">urlretrieve</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">file</span><span class="p">)</span>
        <span class="k">except</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s2">&quot;Download failed! Make sure you have </span><span class="se">\</span>
<span class="s2">                    stable Internet connection and enter the right name&quot;</span><span class="p">)</span>

<div class="viewcode-block" id="Dataset.download_npz"><a class="viewcode-back" href="../../../../source/deeprobust.graph.data.html#deeprobust.graph.data.dataset.Dataset.download_npz">[docs]</a>    <span class="k">def</span> <span class="nf">download_npz</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="sd">&quot;&quot;&quot;Download adjacen matrix npz file from self.url.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Downloading from </span><span class="si">{}</span><span class="s1"> to </span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">url</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_filename</span><span class="p">))</span>
        <span class="k">try</span><span class="p">:</span>
            <span class="n">urllib</span><span class="o">.</span><span class="n">request</span><span class="o">.</span><span class="n">urlretrieve</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">url</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_filename</span><span class="p">)</span>
            <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Done!&#39;</span><span class="p">)</span>
        <span class="k">except</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;&#39;&#39;Download failed! Make sure you have stable Internet connection and enter the right name&#39;&#39;&#39;</span><span class="p">)</span></div>

    <span class="k">def</span> <span class="nf">download_pubmed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
        <span class="n">url</span> <span class="o">=</span> <span class="s1">&#39;https://raw.githubusercontent.com/tkipf/gcn/master/gcn/data/&#39;</span>
        <span class="k">try</span><span class="p">:</span>
            <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Downloading&#39;</span><span class="p">,</span> <span class="n">url</span><span class="p">)</span>
            <span class="n">urllib</span><span class="o">.</span><span class="n">request</span><span class="o">.</span><span class="n">urlretrieve</span><span class="p">(</span><span class="n">url</span> <span class="o">+</span> <span class="n">name</span><span class="p">,</span> <span class="n">osp</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="n">name</span><span class="p">))</span>
            <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Done!&#39;</span><span class="p">)</span>
        <span class="k">except</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;&#39;&#39;Download failed! Make sure you have stable Internet connection and enter the right name&#39;&#39;&#39;</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">download_zip</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
        <span class="n">url</span> <span class="o">=</span> <span class="s1">&#39;https://raw.githubusercontent.com/ChandlerBang/Pro-GNN/master/other_datasets/</span><span class="si">{}</span><span class="s1">.zip&#39;</span><span class="o">.</span>\
                <span class="nb">format</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
        <span class="k">try</span><span class="p">:</span>
            <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Downlading&#39;</span><span class="p">,</span> <span class="n">url</span><span class="p">)</span>
            <span class="n">urllib</span><span class="o">.</span><span class="n">request</span><span class="o">.</span><span class="n">urlretrieve</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">osp</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="n">name</span><span class="o">+</span><span class="s1">&#39;.zip&#39;</span><span class="p">))</span>
            <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Done!&#39;</span><span class="p">)</span>
        <span class="k">except</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;&#39;&#39;Download failed! Make sure you have stable Internet connection and enter the right name&#39;&#39;&#39;</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">load_zip</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="n">data_filename</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_folder</span> <span class="o">+</span> <span class="s1">&#39;.zip&#39;</span>
        <span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span>
        <span class="k">if</span> <span class="ow">not</span> <span class="n">osp</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">data_filename</span><span class="p">):</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">download_zip</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
            <span class="k">with</span> <span class="n">zipfile</span><span class="o">.</span><span class="n">ZipFile</span><span class="p">(</span><span class="n">data_filename</span><span class="p">,</span> <span class="s1">&#39;r&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">zip_ref</span><span class="p">:</span>
                <span class="n">zip_ref</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">)</span>

        <span class="n">feature_path</span> <span class="o">=</span> <span class="n">osp</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data_folder</span><span class="p">,</span> <span class="s1">&#39;</span><span class="si">{0}</span><span class="s1">.feature&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
        <span class="n">label_path</span> <span class="o">=</span> <span class="n">osp</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data_folder</span><span class="p">,</span> <span class="s1">&#39;</span><span class="si">{0}</span><span class="s1">.label&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
        <span class="n">graph_path</span> <span class="o">=</span> <span class="n">osp</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data_folder</span><span class="p">,</span> <span class="s1">&#39;</span><span class="si">{0}</span><span class="s1">.edge&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>

        <span class="n">f</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">loadtxt</span><span class="p">(</span><span class="n">feature_path</span><span class="p">,</span> <span class="n">dtype</span> <span class="o">=</span> <span class="nb">float</span><span class="p">)</span>
        <span class="n">l</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">loadtxt</span><span class="p">(</span><span class="n">label_path</span><span class="p">,</span> <span class="n">dtype</span> <span class="o">=</span> <span class="nb">int</span><span class="p">)</span>
        <span class="n">features</span> <span class="o">=</span> <span class="n">sp</span><span class="o">.</span><span class="n">csr_matrix</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
        <span class="c1"># features = torch.FloatTensor(np.array(features.todense()))</span>
        <span class="n">struct_edges</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">genfromtxt</span><span class="p">(</span><span class="n">graph_path</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
        <span class="n">sedges</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">struct_edges</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">struct_edges</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
        <span class="n">n</span> <span class="o">=</span> <span class="n">features</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">sadj</span> <span class="o">=</span> <span class="n">sp</span><span class="o">.</span><span class="n">coo_matrix</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">sedges</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="p">(</span><span class="n">sedges</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">sedges</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])),</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
        <span class="n">sadj</span> <span class="o">=</span> <span class="n">sadj</span> <span class="o">+</span> <span class="n">sadj</span><span class="o">.</span><span class="n">T</span><span class="o">.</span><span class="n">multiply</span><span class="p">(</span><span class="n">sadj</span><span class="o">.</span><span class="n">T</span> <span class="o">&gt;</span> <span class="n">sadj</span><span class="p">)</span> <span class="o">-</span> <span class="n">sadj</span><span class="o">.</span><span class="n">multiply</span><span class="p">(</span><span class="n">sadj</span><span class="o">.</span><span class="n">T</span> <span class="o">&gt;</span> <span class="n">sadj</span><span class="p">)</span>
        <span class="n">label</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">l</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">sadj</span><span class="p">,</span> <span class="n">features</span><span class="p">,</span> <span class="n">label</span>

    <span class="k">def</span> <span class="nf">load_pubmed</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="n">dataset</span> <span class="o">=</span> <span class="s1">&#39;pubmed&#39;</span>
        <span class="n">names</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;x&#39;</span><span class="p">,</span> <span class="s1">&#39;y&#39;</span><span class="p">,</span> <span class="s1">&#39;tx&#39;</span><span class="p">,</span> <span class="s1">&#39;ty&#39;</span><span class="p">,</span> <span class="s1">&#39;allx&#39;</span><span class="p">,</span> <span class="s1">&#39;ally&#39;</span><span class="p">,</span> <span class="s1">&#39;graph&#39;</span><span class="p">]</span>
        <span class="n">objects</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">names</span><span class="p">)):</span>
            <span class="n">name</span> <span class="o">=</span> <span class="s2">&quot;ind.</span><span class="si">{}</span><span class="s2">.</span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">names</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
            <span class="n">data_filename</span> <span class="o">=</span> <span class="n">osp</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="n">name</span><span class="p">)</span>

            <span class="k">if</span> <span class="ow">not</span> <span class="n">osp</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">data_filename</span><span class="p">):</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">download_pubmed</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>

            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">data_filename</span><span class="p">,</span> <span class="s1">&#39;rb&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
                <span class="k">if</span> <span class="n">sys</span><span class="o">.</span><span class="n">version_info</span> <span class="o">&gt;</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">0</span><span class="p">):</span>
                    <span class="n">objects</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s1">&#39;latin1&#39;</span><span class="p">))</span>
                <span class="k">else</span><span class="p">:</span>
                    <span class="n">objects</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">))</span>

        <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">tx</span><span class="p">,</span> <span class="n">ty</span><span class="p">,</span> <span class="n">allx</span><span class="p">,</span> <span class="n">ally</span><span class="p">,</span> <span class="n">graph</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">objects</span><span class="p">)</span>


        <span class="n">test_idx_file</span> <span class="o">=</span> <span class="s2">&quot;ind.</span><span class="si">{}</span><span class="s2">.test.index&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
        <span class="k">if</span> <span class="ow">not</span> <span class="n">osp</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">osp</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="n">test_idx_file</span><span class="p">)):</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">download_pubmed</span><span class="p">(</span><span class="n">test_idx_file</span><span class="p">)</span>

        <span class="n">test_idx_reorder</span> <span class="o">=</span> <span class="n">parse_index_file</span><span class="p">(</span><span class="n">osp</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="n">test_idx_file</span><span class="p">))</span>
        <span class="n">test_idx_range</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">test_idx_reorder</span><span class="p">)</span>

        <span class="n">features</span> <span class="o">=</span> <span class="n">sp</span><span class="o">.</span><span class="n">vstack</span><span class="p">((</span><span class="n">allx</span><span class="p">,</span> <span class="n">tx</span><span class="p">))</span><span class="o">.</span><span class="n">tolil</span><span class="p">()</span>
        <span class="n">features</span><span class="p">[</span><span class="n">test_idx_reorder</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">features</span><span class="p">[</span><span class="n">test_idx_range</span><span class="p">,</span> <span class="p">:]</span>
        <span class="n">adj</span> <span class="o">=</span> <span class="n">nx</span><span class="o">.</span><span class="n">adjacency_matrix</span><span class="p">(</span><span class="n">nx</span><span class="o">.</span><span class="n">from_dict_of_lists</span><span class="p">(</span><span class="n">graph</span><span class="p">))</span>
        <span class="n">labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">vstack</span><span class="p">((</span><span class="n">ally</span><span class="p">,</span> <span class="n">ty</span><span class="p">))</span>
        <span class="n">labels</span><span class="p">[</span><span class="n">test_idx_reorder</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[</span><span class="n">test_idx_range</span><span class="p">,</span> <span class="p">:]</span>
        <span class="n">labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">labels</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span>
        <span class="k">return</span> <span class="n">adj</span><span class="p">,</span> <span class="n">features</span><span class="p">,</span> <span class="n">labels</span>

    <span class="k">def</span> <span class="nf">get_adj</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="n">adj</span><span class="p">,</span> <span class="n">features</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_npz</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data_filename</span><span class="p">)</span>
        <span class="n">adj</span> <span class="o">=</span> <span class="n">adj</span> <span class="o">+</span> <span class="n">adj</span><span class="o">.</span><span class="n">T</span>
        <span class="n">adj</span> <span class="o">=</span> <span class="n">adj</span><span class="o">.</span><span class="n">tolil</span><span class="p">()</span>
        <span class="n">adj</span><span class="p">[</span><span class="n">adj</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">require_lcc</span><span class="p">:</span>
            <span class="n">lcc</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">largest_connected_components</span><span class="p">(</span><span class="n">adj</span><span class="p">)</span>
            <span class="n">adj</span> <span class="o">=</span> <span class="n">adj</span><span class="p">[</span><span class="n">lcc</span><span class="p">][:,</span> <span class="n">lcc</span><span class="p">]</span>
            <span class="n">features</span> <span class="o">=</span> <span class="n">features</span><span class="p">[</span><span class="n">lcc</span><span class="p">]</span>
            <span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[</span><span class="n">lcc</span><span class="p">]</span>
            <span class="k">assert</span> <span class="n">adj</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">A1</span><span class="o">.</span><span class="n">min</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;Graph contains singleton nodes&quot;</span>

        <span class="c1"># whether to set diag=0?</span>
        <span class="n">adj</span><span class="o">.</span><span class="n">setdiag</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
        <span class="n">adj</span> <span class="o">=</span> <span class="n">adj</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;float32&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">tocsr</span><span class="p">()</span>
        <span class="n">adj</span><span class="o">.</span><span class="n">eliminate_zeros</span><span class="p">()</span>

        <span class="k">assert</span> <span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">adj</span> <span class="o">-</span> <span class="n">adj</span><span class="o">.</span><span class="n">T</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;Input graph is not symmetric&quot;</span>
        <span class="k">assert</span> <span class="n">adj</span><span class="o">.</span><span class="n">max</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="n">adj</span><span class="p">[</span><span class="n">adj</span><span class="o">.</span><span class="n">nonzero</span><span class="p">()]</span><span class="o">.</span><span class="n">A1</span><span class="p">))</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;Graph must be unweighted&quot;</span>

        <span class="k">return</span> <span class="n">adj</span><span class="p">,</span> <span class="n">features</span><span class="p">,</span> <span class="n">labels</span>

    <span class="k">def</span> <span class="nf">load_npz</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">file_name</span><span class="p">,</span> <span class="n">is_sparse</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
        <span class="k">with</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">file_name</span><span class="p">)</span> <span class="k">as</span> <span class="n">loader</span><span class="p">:</span>
            <span class="c1"># loader = dict(loader)</span>
            <span class="k">if</span> <span class="n">is_sparse</span><span class="p">:</span>
                <span class="n">adj</span> <span class="o">=</span> <span class="n">sp</span><span class="o">.</span><span class="n">csr_matrix</span><span class="p">((</span><span class="n">loader</span><span class="p">[</span><span class="s1">&#39;adj_data&#39;</span><span class="p">],</span> <span class="n">loader</span><span class="p">[</span><span class="s1">&#39;adj_indices&#39;</span><span class="p">],</span>
                                            <span class="n">loader</span><span class="p">[</span><span class="s1">&#39;adj_indptr&#39;</span><span class="p">]),</span> <span class="n">shape</span><span class="o">=</span><span class="n">loader</span><span class="p">[</span><span class="s1">&#39;adj_shape&#39;</span><span class="p">])</span>
                <span class="k">if</span> <span class="s1">&#39;attr_data&#39;</span> <span class="ow">in</span> <span class="n">loader</span><span class="p">:</span>
                    <span class="n">features</span> <span class="o">=</span> <span class="n">sp</span><span class="o">.</span><span class="n">csr_matrix</span><span class="p">((</span><span class="n">loader</span><span class="p">[</span><span class="s1">&#39;attr_data&#39;</span><span class="p">],</span> <span class="n">loader</span><span class="p">[</span><span class="s1">&#39;attr_indices&#39;</span><span class="p">],</span>
                                                 <span class="n">loader</span><span class="p">[</span><span class="s1">&#39;attr_indptr&#39;</span><span class="p">]),</span> <span class="n">shape</span><span class="o">=</span><span class="n">loader</span><span class="p">[</span><span class="s1">&#39;attr_shape&#39;</span><span class="p">])</span>
                <span class="k">else</span><span class="p">:</span>
                    <span class="n">features</span> <span class="o">=</span> <span class="kc">None</span>
                <span class="n">labels</span> <span class="o">=</span> <span class="n">loader</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;labels&#39;</span><span class="p">)</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">adj</span> <span class="o">=</span> <span class="n">loader</span><span class="p">[</span><span class="s1">&#39;adj_data&#39;</span><span class="p">]</span>
                <span class="k">if</span> <span class="s1">&#39;attr_data&#39;</span> <span class="ow">in</span> <span class="n">loader</span><span class="p">:</span>
                    <span class="n">features</span> <span class="o">=</span> <span class="n">loader</span><span class="p">[</span><span class="s1">&#39;attr_data&#39;</span><span class="p">]</span>
                <span class="k">else</span><span class="p">:</span>
                    <span class="n">features</span> <span class="o">=</span> <span class="kc">None</span>
                <span class="n">labels</span> <span class="o">=</span> <span class="n">loader</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;labels&#39;</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">features</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">features</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">adj</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
        <span class="n">features</span> <span class="o">=</span> <span class="n">sp</span><span class="o">.</span><span class="n">csr_matrix</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">adj</span><span class="p">,</span> <span class="n">features</span><span class="p">,</span> <span class="n">labels</span>

<div class="viewcode-block" id="Dataset.largest_connected_components"><a class="viewcode-back" href="../../../../source/deeprobust.graph.data.html#deeprobust.graph.data.dataset.Dataset.largest_connected_components">[docs]</a>    <span class="k">def</span> <span class="nf">largest_connected_components</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">adj</span><span class="p">,</span> <span class="n">n_components</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
        <span class="sd">&quot;&quot;&quot;Select k largest connected components.</span>

<span class="sd">		Parameters</span>
<span class="sd">		----------</span>
<span class="sd">		adj : scipy.sparse.csr_matrix</span>
<span class="sd">			input adjacency matrix</span>
<span class="sd">		n_components : int</span>
<span class="sd">			n largest connected components we want to select</span>
<span class="sd">		&quot;&quot;&quot;</span>

        <span class="n">_</span><span class="p">,</span> <span class="n">component_indices</span> <span class="o">=</span> <span class="n">sp</span><span class="o">.</span><span class="n">csgraph</span><span class="o">.</span><span class="n">connected_components</span><span class="p">(</span><span class="n">adj</span><span class="p">)</span>
        <span class="n">component_sizes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">bincount</span><span class="p">(</span><span class="n">component_indices</span><span class="p">)</span>
        <span class="n">components_to_keep</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">component_sizes</span><span class="p">)[::</span><span class="o">-</span><span class="mi">1</span><span class="p">][:</span><span class="n">n_components</span><span class="p">]</span>  <span class="c1"># reverse order to sort descending</span>
        <span class="n">nodes_to_keep</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">idx</span> <span class="k">for</span> <span class="p">(</span><span class="n">idx</span><span class="p">,</span> <span class="n">component</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">component_indices</span><span class="p">)</span> <span class="k">if</span> <span class="n">component</span> <span class="ow">in</span> <span class="n">components_to_keep</span><span class="p">]</span>
        <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Selecting </span><span class="si">{0}</span><span class="s2"> largest connected components&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">n_components</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">nodes_to_keep</span></div>

    <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="s1">&#39;</span><span class="si">{0}</span><span class="s1">(adj_shape=</span><span class="si">{1}</span><span class="s1">, feature_shape=</span><span class="si">{2}</span><span class="s1">)&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">adj</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">get_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="n">idx_train</span><span class="p">,</span> <span class="n">idx_val</span><span class="p">,</span> <span class="n">idx_test</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">idx_train</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">idx_val</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">idx_test</span>
        <span class="n">labels</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">onehot</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">labels</span><span class="p">)</span>

        <span class="k">def</span> <span class="nf">get_mask</span><span class="p">(</span><span class="n">idx</span><span class="p">):</span>
            <span class="n">mask</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">labels</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
            <span class="n">mask</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
            <span class="k">return</span> <span class="n">mask</span>

        <span class="k">def</span> <span class="nf">get_y</span><span class="p">(</span><span class="n">idx</span><span class="p">):</span>
            <span class="n">mx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">labels</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
            <span class="n">mx</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
            <span class="k">return</span> <span class="n">mx</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">train_mask</span> <span class="o">=</span> <span class="n">get_mask</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">idx_train</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">val_mask</span> <span class="o">=</span> <span class="n">get_mask</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">idx_val</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">test_mask</span> <span class="o">=</span> <span class="n">get_mask</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">idx_test</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">y_train</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">y_val</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">y_test</span> <span class="o">=</span> <span class="n">get_y</span><span class="p">(</span><span class="n">idx_train</span><span class="p">),</span> <span class="n">get_y</span><span class="p">(</span><span class="n">idx_val</span><span class="p">),</span> <span class="n">get_y</span><span class="p">(</span><span class="n">idx_test</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">onehot</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span>
        <span class="n">eye</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">identity</span><span class="p">(</span><span class="n">labels</span><span class="o">.</span><span class="n">max</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
        <span class="n">onehot_mx</span> <span class="o">=</span> <span class="n">eye</span><span class="p">[</span><span class="n">labels</span><span class="p">]</span>
        <span class="k">return</span> <span class="n">onehot_mx</span></div>

<span class="k">def</span> <span class="nf">parse_index_file</span><span class="p">(</span><span class="n">filename</span><span class="p">):</span>
    <span class="n">index</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">line</span> <span class="ow">in</span> <span class="nb">open</span><span class="p">(</span><span class="n">filename</span><span class="p">):</span>
        <span class="n">index</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">line</span><span class="o">.</span><span class="n">strip</span><span class="p">()))</span>
    <span class="k">return</span> <span class="n">index</span>


<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
    <span class="kn">from</span> <span class="nn">deeprobust.graph.data</span> <span class="kn">import</span> <span class="n">Dataset</span>
    <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;cora&#39;</span><span class="p">,</span> <span class="s1">&#39;citeseer&#39;</span><span class="p">,</span> <span class="s1">&#39;pubmed&#39;</span><span class="p">,</span> <span class="s1">&#39;cora_ml&#39;</span><span class="p">]:</span>
        <span class="n">data</span> <span class="o">=</span> <span class="n">Dataset</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="s1">&#39;/tmp/&#39;</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span> <span class="n">setting</span><span class="o">=</span><span class="s2">&quot;prognn&quot;</span><span class="p">)</span>
        <span class="n">idx_train</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">idx_train</span>
        <span class="n">data2</span> <span class="o">=</span> <span class="n">Dataset</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="s1">&#39;/tmp/&#39;</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span> <span class="n">setting</span><span class="o">=</span><span class="s2">&quot;nettack&quot;</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">15</span><span class="p">)</span>
        <span class="n">idx_train2</span> <span class="o">=</span> <span class="n">data2</span><span class="o">.</span><span class="n">idx_train</span>
        <span class="k">assert</span> <span class="p">(</span><span class="n">idx_train</span> <span class="o">!=</span> <span class="n">idx_train2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span>

    <span class="n">data</span> <span class="o">=</span> <span class="n">Dataset</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="s1">&#39;/tmp/&#39;</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">&#39;flickr&#39;</span><span class="p">)</span>
    <span class="n">adj</span><span class="p">,</span> <span class="n">features</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">adj</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">features</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">labels</span>
    <span class="n">idx_train</span><span class="p">,</span> <span class="n">idx_val</span><span class="p">,</span> <span class="n">idx_test</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">idx_train</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">idx_val</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">idx_test</span>

</pre></div>

           </div>
           
          </div>
          <footer>
  <!-- Page footer: copyright notice followed by the Sphinx / Read the Docs build attribution. -->

  <hr>

  <div role="contentinfo">
    <p>
        &copy; Copyright
    </p>
  </div>

    Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    provided by <a href="https://readthedocs.org">Read the Docs</a>.

</footer>

        </div>
      </div>

    </section>

  </div>
  

  <script type="text/javascript">
      // Passing a function to jQuery() registers it as a DOM-ready callback,
      // so the theme's navigation is switched on only after the page has been parsed.
      // NOTE(review): the `true` argument is presumably the theme's sticky-navigation
      // flag — confirm against _static/js/theme.js.
      jQuery(function () {
          SphinxRtdTheme.Navigation.enable(true);
      });
  </script>

  
  
    
   

</body>
</html>