<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>pygho.backend.SpTensor &mdash; PyGHO 0.0.1 documentation</title>

  <!-- Stylesheets: Pygments token colours for the source listing, then the theme. -->
  <link rel="stylesheet" href="../../../_static/pygments.css">
  <link rel="stylesheet" href="../../../_static/css/theme.css">

  <!-- HTML5 element shim; the conditional comment below is only honoured by old Internet Explorer. -->
  <!--[if lt IE 9]>
    <script src="../../../_static/js/html5shiv.min.js"></script>
  <![endif]-->

  <!-- Classic (non-deferred) scripts execute in document order. Keep this order:
       later files may rely on earlier ones (e.g. jQuery) having already run. -->
  <script src="../../../_static/jquery.js?v=5d32c60e"></script>
  <script src="../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
  <script src="../../../_static/documentation_options.js?v=d45e8c67"></script>
  <script src="../../../_static/doctools.js?v=888ff710"></script>
  <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
  <script src="../../../_static/js/theme.js"></script>

  <!-- Site-level relation links: general index and search page. -->
  <link rel="index" title="Index" href="../../../genindex.html">
  <link rel="search" title="Search" href="../../../search.html">
</head>

<body class="wy-body-for-nav"> 
  <div class="wy-grid-for-nav">
    <nav class="wy-nav-side" data-toggle="wy-nav-shift">
      <!-- Left sidebar: project home link, search box and the global table of contents. -->
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search">
          <a class="icon icon-home" href="../../../index.html">
            PyGHO
          </a>
          <div role="search">
            <!-- GET form: the query is sent to search.html as ?q=...; the hidden
                 fields are submitted alongside it unchanged. -->
            <form class="wy-form" id="rtd-search-form" action="../../../search.html" method="get">
              <input type="text" name="q" placeholder="Search docs" aria-label="Search docs">
              <input type="hidden" name="check_keywords" value="yes">
              <input type="hidden" name="area" value="default">
            </form>
          </div>
        </div>
        <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
          <!-- Table of contents, one captioned list per section. -->
          <p class="caption" role="heading"><span class="caption-text">Notes</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/installation.html">Installation</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/miniexample.html">Minimal Example</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/datastructure.html">Refined Basic Data Structure</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/hodata.html">Efficient High Order Data Processing</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/operator.html">Operators</a></li>
          </ul>
          <p class="caption" role="heading"><span class="caption-text">Advanced Tutorial</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/multtensor.html">Multiple Tensor</a></li>
          </ul>
          <p class="caption" role="heading"><span class="caption-text">Package Reference</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../modules/backend.html">pygho.backend package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../modules/hodata.html">pygho.hodata package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../modules/honn.html">pygho.honn package</a></li>
          </ul>
        </div>
      </div>
    </nav>

    <section class="wy-nav-content-wrap" data-toggle="wy-nav-shift">
      <!-- Top bar for small screens: a menu icon (presumably wired up by theme.js
           through its data-toggle attribute; confirm) and a link back to the index. -->
      <nav class="wy-nav-top" aria-label="Mobile navigation menu">
        <i class="fa fa-bars" data-toggle="wy-nav-top"></i>
        <a href="../../../index.html">PyGHO</a>
      </nav>

      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="navigation" aria-label="Page navigation">
            <!-- Breadcrumb trail: Home > Module code > this module. The last item is
                 the page being viewed, so it carries aria-current="page" for assistive tech. -->
            <ul class="wy-breadcrumbs">
              <li><a href="../../../index.html" class="icon icon-home" aria-label="Home"></a></li>
              <li class="breadcrumb-item"><a href="../../index.html">Module code</a></li>
              <li class="breadcrumb-item active" aria-current="page">pygho.backend.SpTensor</li>
              <li class="wy-breadcrumbs-aside">
              </li>
            </ul>
            <hr>
          </div>
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
             
  <h1>Source code for pygho.backend.SpTensor</h1><div class="highlight"><pre>
<span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Callable</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">Tensor</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Iterable</span><span class="p">,</span> <span class="n">Union</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">.utils</span> <span class="kn">import</span> <span class="n">torch_scatter_reduce</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Final</span>


<div class="viewcode-block" id="indicehash"><!--
  Highlighted listing of pygho.backend.SpTensor.indicehash. This markup appears to be
  generated from the Python module, so fix the .py source and rebuild instead of
  hand-editing the spans below.
  Behaviour shown: for sparse_dim == 1 it returns indice[0] unchanged. Otherwise it packs
  each column into one int64, giving every dimension interval = 63 // sparse_dim bits,
  with the last dimension in the lowest bits (so hash order follows lexicographic
  column order). It asserts max(indice) < 2**interval so the packing stays injective.
  NOTE(review): docstring wording in the source ("Hashes a indice", "there exists
  negative indice") should read "an index tensor" / "any index is negative".
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.indicehash">[docs]</a>
<span class="k">def</span> <span class="nf">indicehash</span><span class="p">(</span><span class="n">indice</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">LongTensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Hashes a indice of shape (sparse_dim, nnz) to a single LongTensor of shape (nnz). Keep lexicographic order.</span>

<span class="sd">    Parameters:</span>
<span class="sd">    - indice (LongTensor): The input indices tensor of shape (sparse_dim, nnz).</span>

<span class="sd">    Returns:</span>
<span class="sd">    - LongTensor: A single LongTensor representing the hashed values.</span>

<span class="sd">    Raises:</span>
<span class="sd">    - AssertionError: If the input tensor doesn&#39;t have the expected shape or if the indices are too large or if there exists negative indice.</span>

<span class="sd">    Example:</span>
<span class="sd">    </span>
<span class="sd">    ::</span>
<span class="sd">    </span>
<span class="sd">        indices = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long)</span>
<span class="sd">        hashed = indicehash(indices)</span>

<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">assert</span> <span class="n">indice</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">2</span>
    <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">indice</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">),</span> <span class="s2">&quot;indice cannot be negative&quot;</span>
    <span class="n">sparse_dim</span> <span class="o">=</span> <span class="n">indice</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    <span class="k">if</span> <span class="n">sparse_dim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">indice</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">interval</span> <span class="o">=</span> <span class="p">(</span><span class="mi">63</span> <span class="o">//</span> <span class="n">sparse_dim</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">indice</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">&lt;</span> <span class="p">(</span>
        <span class="mi">1</span> <span class="o">&lt;&lt;</span> <span class="n">interval</span><span class="p">),</span> <span class="s2">&quot;too large indice, hash is not injective&quot;</span>

    <span class="n">eihash</span> <span class="o">=</span> <span class="n">indice</span><span class="p">[</span><span class="n">sparse_dim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">sparse_dim</span><span class="p">):</span>
        <span class="n">eihash</span><span class="o">.</span><span class="n">bitwise_or_</span><span class="p">(</span><span class="n">indice</span><span class="p">[</span><span class="n">sparse_dim</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">bitwise_left_shift</span><span class="p">(</span>
            <span class="n">interval</span> <span class="o">*</span> <span class="n">i</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">eihash</span></div>



<div class="viewcode-block" id="decodehash"><!--
  Highlighted listing of pygho.backend.SpTensor.decodehash, the inverse of indicehash
  for the same sparse_dim. For sparse_dim == 1 it returns indhash.unsqueeze(0) before
  any shape check. Otherwise row k of the result is
  (indhash >> ((sparse_dim - 1 - k) * interval)) & mask, with interval = 63 // sparse_dim
  and mask = 2**interval - 1, so a 1D indhash decodes to shape (sparse_dim, nnz).
  NOTE(review), to fix in the Python source and then rebuild this page:
  * the docstring example calls decodehash(hashed) without the required sparse_dim
    argument, which raises TypeError; it should be decodehash(hashed, 2);
  * the docstring says "pairs of indices", but the result has sparse_dim rows;
  * mask is built with eval("0b" + "1" * interval); (1 << interval) - 1 gives the same
    value without eval.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.decodehash">[docs]</a>
<span class="k">def</span> <span class="nf">decodehash</span><span class="p">(</span><span class="n">indhash</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">sparse_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">LongTensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Decodes a hashed LongTensor into tuples of indices.</span>

<span class="sd">    This function takes a hashed LongTensor and decodes it into pairs of indices,</span>
<span class="sd">    which is commonly used in sparse tensor operations.</span>

<span class="sd">    Parameters:</span>

<span class="sd">    - indhash (LongTensor): The input hashed LongTensor of shape (nnz).</span>
<span class="sd">    - sparse_dim (int): The number of dimensions represented by the hash.</span>

<span class="sd">    Returns:</span>

<span class="sd">    - LongTensor: A LongTensor representing pairs of indices.</span>

<span class="sd">    Raises:</span>

<span class="sd">    - AssertionError: If the input tensor doesn&#39;t have the expected shape or</span>
<span class="sd">      if the sparse dimension is invalid.</span>

<span class="sd">    Example:</span>

<span class="sd">    ::</span>

<span class="sd">        indices = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long)</span>
<span class="sd">        hashed = indicehash(indices)</span>
<span class="sd">        indices = decodehash(hashed)</span>

<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">if</span> <span class="n">sparse_dim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">indhash</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">indhash</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;indhash should of shape (nnz) &quot;</span>
    <span class="n">interval</span> <span class="o">=</span> <span class="p">(</span><span class="mi">63</span> <span class="o">//</span> <span class="n">sparse_dim</span><span class="p">)</span>
    <span class="n">mask</span> <span class="o">=</span> <span class="nb">eval</span><span class="p">(</span><span class="s2">&quot;0b&quot;</span> <span class="o">+</span> <span class="s2">&quot;1&quot;</span> <span class="o">*</span> <span class="n">interval</span><span class="p">)</span>
    <span class="n">offset</span> <span class="o">=</span> <span class="p">(</span><span class="n">sparse_dim</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span>
        <span class="n">sparse_dim</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">indhash</span><span class="o">.</span><span class="n">device</span><span class="p">))</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">interval</span>
    <span class="n">ret</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">bitwise_right_shift</span><span class="p">(</span><span class="n">indhash</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
                                    <span class="n">offset</span><span class="p">)</span><span class="o">.</span><span class="n">bitwise_and_</span><span class="p">(</span><span class="n">mask</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">ret</span></div>



<div class="viewcode-block" id="indicehash_tight"><!--
  Highlighted listing of pygho.backend.SpTensor.indicehash_tight. For more than one
  sparse dimension it returns the row-major flat index of every column of indice
  inside a tensor of size dimsize: sum_i indice[i] * step[i], where
  step[i] = prod(dimsize[i+1:]) and step[-1] = 1 (built by flip / cumprod / flip).
  A single sparse dimension returns indice[0] unchanged.
  NOTE(review), for the Python source:
  * the assertion message "indice shoule be of shape" has a typo ("should");
  * torch.prod(dimsize) on an integer tensor may wrap silently on overflow before
    the comparison against 1 << 62, so that guard may not catch very large shapes
    (confirm).
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.indicehash_tight">[docs]</a>
<span class="k">def</span> <span class="nf">indicehash_tight</span><span class="p">(</span><span class="n">indice</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">dimsize</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">LongTensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Hashes a 2D LongTensor of indices tightly into a single LongTensor.</span>
<span class="sd">    Equivalently, it compute the indice of flattened sparse tensor with indice and dimsize</span>

<span class="sd">    Parameters:</span>
<span class="sd">    - indice (LongTensor): The input indices tensor of shape (sparse_dim, nnz).</span>
<span class="sd">    - dimsize (LongTensor): The sizes of each dimension in the sparse tensor of shape (sparse_dim).</span>

<span class="sd">    Returns:</span>
<span class="sd">    - LongTensor: A single LongTensor representing the tightly hashed values.</span>

<span class="sd">    Raises:</span>
<span class="sd">    - AssertionError: If the input tensors don&#39;t have the expected shapes or if the indices exceed the dimension sizes.</span>

<span class="sd">    Example:</span>
<span class="sd">    </span>
<span class="sd">    ::</span>

<span class="sd">        indices = torch.tensor([[1, 2, 0], [4, 1, 2]], dtype=torch.long)</span>
<span class="sd">        dim_sizes = torch.tensor([3, 5], dtype=torch.long)</span>
<span class="sd">        hashed = indicehash_tight(indices, dim_sizes)</span>

<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">assert</span> <span class="n">indice</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">&quot;indice shoule be of shape (sparse_dim, nnz) &quot;</span>
    <span class="k">assert</span> <span class="n">dimsize</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;dim size should be of shape (sparse_dim)&quot;</span>
    <span class="k">assert</span> <span class="n">dimsize</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">indice</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
        <span class="mi">0</span><span class="p">],</span> <span class="s2">&quot;indice dim and dim size not match&quot;</span>
    <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">indice</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">dimsize</span><span class="p">),</span> <span class="s2">&quot;indice exceeds dimsize&quot;</span>
    <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">dimsize</span><span class="p">)</span> <span class="o">&lt;</span> <span class="p">(</span>
        <span class="mi">1</span> <span class="o">&lt;&lt;</span> <span class="mi">62</span><span class="p">),</span> <span class="s2">&quot;total size exceeds the range that torch.long can express&quot;</span>
    <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">indice</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">),</span> <span class="s2">&quot;indice cannot be negative&quot;</span>
    <span class="k">if</span> <span class="n">indice</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">indice</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">step</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">dimsize</span><span class="p">)</span>
    <span class="n">step</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">flip</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cumprod</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">flip</span><span class="p">(</span><span class="n">dimsize</span><span class="p">[</span><span class="mi">1</span><span class="p">:],</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="p">)),</span> <span class="mi">0</span><span class="p">),</span>
                           <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="p">))</span>
    <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">step</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">indice</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></div>



<div class="viewcode-block" id="decodehash_tight"><!--
  Highlighted listing of pygho.backend.SpTensor.decodehash_tight, the inverse of
  indicehash_tight for the same dimsize. It first computes indhash // step[i] for
  every dimension (same row-major strides as indicehash_tight). It then subtracts
  ret[i-1] * dimsize[i], computed from the not-yet-updated rows, which leaves
  (indhash // step[i]) mod dimsize[i] in row i. The result has shape (sparse_dim, nnz).
  NOTE(review), for the Python source:
  * unlike indicehash_tight, it does not assert dimsize.ndim == 1 or that every hash
    lies in [0, prod(dimsize)); a hash >= prod(dimsize) silently decodes to a first
    coordinate >= dimsize[0];
  * the docstring says "pairs of indices", but the result has sparse_dim rows.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.decodehash_tight">[docs]</a>
<span class="k">def</span> <span class="nf">decodehash_tight</span><span class="p">(</span><span class="n">indhash</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">dimsize</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">LongTensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Decodes a tightly hashed LongTensor into pairs of indices considering dimension sizes.</span>

<span class="sd">    Parameters:</span>
<span class="sd">    - indhash (LongTensor): The input hashed LongTensor of shape (nnz).</span>
<span class="sd">    - dimsize (LongTensor): The sizes of each dimension in the sparse tensor of shape (sparse_dim).</span>

<span class="sd">    Returns:</span>
<span class="sd">    - LongTensor: A LongTensor representing pairs of indices.</span>

<span class="sd">    Raises:</span>
<span class="sd">    - AssertionError: If the input tensors don&#39;t have the expected shapes or if the total size exceeds the range that torch.long can express.</span>

<span class="sd">    Example:</span>

<span class="sd">    ::</span>

<span class="sd">        indices = torch.tensor([[1, 2, 0], [4, 1, 2]], dtype=torch.long)</span>
<span class="sd">        dim_sizes = torch.tensor([3, 5], dtype=torch.long)</span>
<span class="sd">        hashed = indicehash_tight(indices, dim_sizes)</span>
<span class="sd">        indices = decodehash_tight(hashed, dim_sizes)</span>

<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">assert</span> <span class="n">indhash</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;indhash should of shape (nnz) &quot;</span>
    <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">dimsize</span><span class="p">)</span> <span class="o">&lt;</span> <span class="p">(</span>
        <span class="mi">1</span> <span class="o">&lt;&lt;</span> <span class="mi">62</span><span class="p">),</span> <span class="s2">&quot;total size exceeds the range that torch.long can express&quot;</span>
    <span class="k">if</span> <span class="n">dimsize</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">indhash</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">step</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">dimsize</span><span class="p">)</span>
    <span class="n">step</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">flip</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cumprod</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">flip</span><span class="p">(</span><span class="n">dimsize</span><span class="p">[</span><span class="mi">1</span><span class="p">:],</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="p">)),</span> <span class="mi">0</span><span class="p">),</span>
                           <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="p">))</span>
    <span class="n">ret</span> <span class="o">=</span> <span class="n">indhash</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">step</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">ret</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">-=</span> <span class="n">ret</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">dimsize</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">ret</span></div>



<div class="viewcode-block" id="coalesce"><!--
  Highlighted listing of pygho.backend.SpTensor.coalesce. The steps are:
  * hash edge_index with indicehash;
  * deduplicate with torch.unique(..., return_inverse=True), which sorts by default,
    so the indices come back in lexicographic column order;
  * decode the unique hashes with decodehash;
  * when edge_attr is given, merge duplicate rows with
    torch_scatter_reduce(0, edge_attr, idx, n_unique, reduce), a helper from .utils
    whose exact semantics are not shown on this page.
  NOTE(review), for the Python source: the docstring allows edge_attr to be
  "Tensor or List[Tensor]", but the annotation is Optional[Tensor] and the value is
  passed straight to torch_scatter_reduce. Confirm whether lists are really supported
  or fix the docstring.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.coalesce">[docs]</a>
<span class="k">def</span> <span class="nf">coalesce</span><span class="p">(</span><span class="n">edge_index</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
             <span class="n">edge_attr</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
             <span class="n">reduce</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;sum&#39;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]]:</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Coalesces and reduces duplicate entries in edge indices and attributes.</span>
<span class="sd">    </span>
<span class="sd">    Args:</span>

<span class="sd">    - edge_index (LongTensor): The edge indices.</span>
<span class="sd">    - edge_attr (Tensor or List[Tensor], optional): Edge weights or multi-dimensional</span>
<span class="sd">      edge features. If given as a list, it will be reshuffled and duplicates will be</span>
<span class="sd">      removed for all entries. (default: None)</span>
<span class="sd">    - reduce (str, optional): The reduction operation to use for merging edge features.</span>
<span class="sd">      Options include &#39;sum&#39;, &#39;mean&#39;, &#39;min&#39;, &#39;max&#39;, &#39;mul&#39;. (default: &#39;sum&#39;)</span>

<span class="sd">    Returns:</span>

<span class="sd">    - Tuple[Tensor, Optional[Tensor]]: A tuple containing the coalesced edge indices</span>
<span class="sd">      and the coalesced and reduced edge attributes (if provided). If edge_attr is</span>
<span class="sd">      None, the second element will be None.</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">sparsedim</span> <span class="o">=</span> <span class="n">edge_index</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">eihash</span> <span class="o">=</span> <span class="n">indicehash</span><span class="p">(</span><span class="n">edge_index</span><span class="p">)</span>
    <span class="n">eihash</span><span class="p">,</span> <span class="n">idx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="n">eihash</span><span class="p">,</span> <span class="n">return_inverse</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    <span class="n">edge_index</span> <span class="o">=</span> <span class="n">decodehash</span><span class="p">(</span><span class="n">eihash</span><span class="p">,</span> <span class="n">sparsedim</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">edge_attr</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">edge_index</span><span class="p">,</span> <span class="kc">None</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">edge_attr</span> <span class="o">=</span> <span class="n">torch_scatter_reduce</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">edge_attr</span><span class="p">,</span> <span class="n">idx</span><span class="p">,</span> <span class="n">eihash</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
                                         <span class="n">reduce</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">edge_attr</span></div>



<div class="viewcode-block" id="SparseTensor">
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.SparseTensor">[docs]</a>
<span class="k">class</span> <span class="nc">SparseTensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Represents a sparse tensor in coo format.</span>

<span class="sd">    This class allows you to work with sparse tensors represented by indices and</span>
<span class="sd">    values. It provides various operations such as sum, max, mean, unpooling,</span>
<span class="sd">    diagonal extraction, and more.</span>

<span class="sd">    Parameters:</span>
<span class="sd">    - indices (LongTensor): The indices of the sparse tensor, of shape (#sparsedim, #nnz).</span>
<span class="sd">    - values (Optional[Tensor]): The values associated with the indices, of shape (#nnz,\*denseshapeshape). Should have the same number of nnz as indices. Defaults to None.</span>
<span class="sd">    - shape (Optional[List[int]]): The shape of the sparse tensor. If None, it is computed from the indices and values. Defaults to None.</span>
<span class="sd">    - is_coalesced (bool): Indicates whether the indices and values are coalesced. Defaults to False.</span>

<span class="sd">    Methods:</span>
<span class="sd">    - is_coalesced(self): Check if the tensor is coalesced.</span>
<span class="sd">    - to(self, device: torch.DeviceObjType, non_blocking: bool = False): Move the tensor to the specified device.</span>
<span class="sd">    - diag(self, dims: Optional[Iterable[int]], return_sparse: bool = False): Extract diagonal elements from the tensor. The dimensions in dims will be take diagonal and put at dims[0]</span>
<span class="sd">    - sum(self, dims: Union[int, Optional[Iterable[int]]], return_sparse: bool = False): Compute the sum of tensor values along specified dimensions. return_sparse=True will return a sparse tensor, otherwise return a dense tensor.</span>
<span class="sd">    - max(self, dims: Union[int, Optional[Iterable[int]]], return_sparse: bool = False): Compute the maximum of tensor values along specified dimensions. return_sparse=True will return a sparse tensor, otherwise return a dense tensor.</span>
<span class="sd">    - mean(self, dims: Union[int, Optional[Iterable[int]]], return_sparse: bool = False): Compute the mean of tensor values along specified dimensions. return_sparse=True will return a sparse tensor, otherwise return a dense tensor.</span>
<span class="sd">    - unpooling(self, dims: Union[int, Iterable[int]], tarX): Perform unpooling operation along specified dimensions.</span>
<span class="sd">    - tuplewiseapply(self, func: Callable[[Tensor], Tensor]): Apply a function to each element of the tensor.</span>
<span class="sd">    - diagonalapply(self, func: Callable[[Tensor, LongTensor], Tensor]): Apply a function to diagonal elements of the tensor.</span>
<span class="sd">    - add(self, tarX, samesparse: bool): Add two sparse tensors together. samesparse=True means that the two sparse tensors have the same indices, so their values can be added directly.</span>
<span class="sd">    - catvalue(self, tarX, samesparse: bool): Concatenate values of two sparse tensors. samesparse=True means that the two sparse tensors have the same indices, so their values can be concatenated along the first dimension directly.</span>
<span class="sd">    - from_torch_sparse_coo(cls, A: torch.Tensor): Create a SparseTensor from a torch sparse COO tensor.</span>
<span class="sd">    - to_torch_sparse_coo(self) -&gt; Tensor: Convert the SparseTensor to a torch sparse COO tensor.</span>

<span class="sd">    Attributes:</span>
<span class="sd">    - indices (LongTensor): The indices of the sparse tensor.</span>
<span class="sd">    - values (Tensor): The values associated with the indices.</span>
<span class="sd">    - sparse_dim (int): The number of dimensions represented by the indices.</span>
<span class="sd">    - nnz (int): The number of non-zero values.</span>
<span class="sd">    - shape (torch.Size): The shape of the tensor.</span>
<span class="sd">    - sparseshape (torch.Size): The shape of the tensor up to the sparse dimensions.</span>
<span class="sd">    - denseshape (torch.Size): The shape of the tensor after the sparse dimensions.</span>

<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span><!-- Builds a SparseTensor from a 2-D index tensor (#sparsedim, #nnz) and optional per-entry values. Unless is_coalesced is True, indices/values are passed through coalesce(indices, values, reduce); presumably reduce selects how duplicate indices are merged (confirm in coalesce). -->
                 <span class="n">indices</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
                 <span class="n">values</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
                 <span class="n">shape</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
                 <span class="n">is_coalesced</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
                 <span class="n">reduce</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;sum&quot;</span><span class="p">):</span>
        <span class="k">assert</span> <span class="n">indices</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">&quot;indice should of shape (#sparsedim, #nnz)&quot;</span>
        <span class="k">if</span> <span class="n">values</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="k">assert</span> <span class="n">indices</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">values</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
                <span class="mi">0</span><span class="p">],</span> <span class="s2">&quot;indices and values should have the same number of nnz&quot;</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">__sparse_dim</span> <span class="o">=</span> <span class="n">indices</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
        <span class="k">if</span> <span class="n">shape</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span><!-- Explicit shape: only the dense tail (shape after sparse_dim) is checked against values.shape[1:]; the sparse extents are not checked against the indices. -->
            <span class="bp">self</span><span class="o">.</span><span class="n">__shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
            <span class="c1"># print(self.shape, self.denseshape, self.sparseshape, values.shape)</span>
            <span class="k">if</span> <span class="n">values</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
                <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">denseshape</span> <span class="o">==</span> <span class="n">values</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
                    <span class="mi">1</span><span class="p">:],</span> <span class="s2">&quot;shape, value not match&quot;</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">__shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><!-- Inferred shape: (max index + 1) for each sparse dim, followed by values.shape[1:]. -->
                <span class="nb">list</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
                         <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">indices</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">tolist</span><span class="p">()))</span> <span class="o">+</span><!-- NOTE(review): torch.max(t, dim=1) returns a (values, indices) named tuple, which has no .tolist(), so this line appears to raise AttributeError. Likely intended: torch.max(indices, dim=1).values.tolist(). Fix in the pygho.backend.SpTensor Python source and regenerate this page. -->
                <span class="nb">list</span><span class="p">(</span><span class="n">values</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:]))</span><!-- NOTE(review): raises AttributeError when both shape and values are None, even though values is annotated Optional. -->
        <span class="k">if</span> <span class="n">is_coalesced</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">__indices</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">__values</span> <span class="o">=</span> <span class="n">indices</span><span class="p">,</span> <span class="n">values</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">__indices</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">__values</span> <span class="o">=</span> <span class="n">coalesce</span><span class="p">(</span><span class="n">indices</span><span class="p">,</span> <span class="n">values</span><span class="p">,</span> <span class="n">reduce</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">__nnz</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><!-- nnz is taken from the stored (post-coalesce) indices, not the input. -->

<div class="viewcode-block" id="SparseTensor.is_coalesced">
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.SparseTensor.is_coalesced">[docs]</a>
    <span class="k">def</span> <span class="nf">is_coalesced</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span><!-- Always True: __init__ coalesces its input unless the caller passed is_coalesced=True, in which case that claim is trusted without verification. -->
        <span class="k">return</span> <span class="kc">True</span></div>


<div class="viewcode-block" id="SparseTensor.to">
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.SparseTensor.to">[docs]</a>
    <span class="k">def</span> <span class="nf">to</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">DeviceObjType</span><span class="p">,</span> <span class="n">non_blocking</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span><!-- Moves indices and values to device by reassigning the private attributes, i.e. mutates this object in place, and returns self (not a copy, unlike torch.Tensor.to). -->
        <span class="bp">self</span><span class="o">.</span><span class="n">__indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__indices</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="n">non_blocking</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">__values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__values</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="n">non_blocking</span><span class="p">)</span><!-- NOTE(review): if values can be None (the constructor's values argument is Optional), this call raises AttributeError; confirm and guard in the Python source. -->
        <span class="k">return</span> <span class="bp">self</span></div>


    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">indices</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span><!-- Stored 2-D index tensor; shape[0] is sparse_dim and shape[1] is nnz. -->
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">__indices</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">values</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span><!-- Stored per-entry values aligned with the columns of indices. -->
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">__values</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">sparse_dim</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span><!-- Number of index rows, taken from indices.shape[0] in __init__. -->
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">__sparse_dim</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">nnz</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span><!-- Number of stored entries after coalescing (indices.shape[1]). -->
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">__nnz</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span><!-- Full shape, sparse dims then dense dims. Stored as a plain tuple by __init__, not a torch.Size as the class docstring says. -->
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">__shape</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">sparseshape</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span><!-- Leading sparse_dim entries of shape. -->
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">sparse_dim</span><span class="p">]</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">denseshape</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span><!-- Entries of shape after the sparse dims; __init__ makes this equal values.shape[1:] whenever values is given. -->
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">sparse_dim</span><span class="p">:]</span>

    <span class="k">def</span> <span class="nf">_diag_to_sparse</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">]):</span><!-- Keeps only entries whose index is equal across every dim in dims, drops dims[1:] (after sorting ascending) and returns a new SparseTensor, so the diagonal lands on the smallest dim in dims. -->
        <span class="k">assert</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span>
            <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">__sparse_dim</span>
        <span class="p">),</span> <span class="s2">&quot;please use tuplewiseapply for operation on dense dims&quot;</span>
        <span class="k">assert</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">),</span> <span class="s2">&quot;do not support negative dims&quot;</span>
<span class="w">        </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd">        diag dims is then put at the first dims in dims list.</span>
<span class="sd">        &#39;&#39;&#39;</span><!-- NOTE(review): this string follows the asserts, so Python evaluates it as a bare expression rather than using it as the method docstring; move it to the top of the body. -->
        <span class="n">dims</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">dims</span><span class="p">))</span>
        <span class="n">mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">[</span><span class="n">dims</span><span class="p">]</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">[[</span><span class="n">dims</span><span class="p">[</span><span class="mi">0</span><span class="p">]]])</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span>
                         <span class="n">dims</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><!-- NOTE(review): torch.all accepts the keyword dim, not dims, so this call appears to raise TypeError; should be dim=0. -->
        <span class="n">idx</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sparse_dim</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">dims</span><span class="p">[</span><span class="mi">1</span><span class="p">:]]</span>
        <span class="n">other_shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">idx</span><span class="p">])</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">denseshape</span>
        <span class="k">return</span> <span class="n">SparseTensor</span><span class="p">(</span><span class="n">indices</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">[</span><span class="n">idx</span><span class="p">][:,</span> <span class="n">mask</span><span class="p">],</span>
                            <span class="n">values</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">[</span><span class="n">mask</span><span class="p">],</span>
                            <span class="n">shape</span><span class="o">=</span><span class="n">other_shape</span><span class="p">,</span>
                            <span class="n">is_coalesced</span><span class="o">=</span><span class="p">(</span><span class="n">idx</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span><!-- Skip re-coalescing only when the kept rows are exactly 0..k-1; presumably this relies on coalesce leaving self.indices lexicographically sorted (confirm). -->
                            <span class="ow">and</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">diff</span><span class="p">(</span><span class="n">idx</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">))</span>

    <span class="k">def</span> <span class="nf">_diag_to_dense</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span><!-- Dense counterpart of _diag_to_sparse: writes the entries whose index is equal across dims into a zero tensor of shape (kept sparse extents) + denseshape. -->
<span class="w">        </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd">        diag dims is then put at the first dims in dims list.</span>
<span class="sd">        &#39;&#39;&#39;</span>
        <span class="k">assert</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span>
            <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">__sparse_dim</span>
        <span class="p">),</span> <span class="s2">&quot;please use tuplewiseapply for operation on dense dims&quot;</span>
        <span class="k">assert</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">),</span> <span class="s2">&quot;do not support negative dims&quot;</span>
        <span class="n">dims</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">dims</span><span class="p">))</span>
        <span class="n">mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">[</span><span class="n">dims</span><span class="p">]</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">[[</span><span class="n">dims</span><span class="p">[</span><span class="mi">0</span><span class="p">]]])</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span>
                         <span class="n">dims</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><!-- NOTE(review): torch.all accepts the keyword dim, not dims, so this call appears to raise TypeError; should be dim=0. -->
        <span class="n">idx</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sparse_dim</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">dims</span><span class="p">[</span><span class="mi">1</span><span class="p">:]]</span>
        <span class="n">nsparse_shape</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">idx</span><span class="p">]</span>
        <span class="n">nsparse_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">nsparse_shape</span><span class="p">)</span>

        <span class="n">thash</span> <span class="o">=</span> <span class="n">indicehash_tight</span><span class="p">(</span><!-- Presumably maps each kept index tuple to a flat offset in [0, nsparse_size) consistent with the unflatten below; confirm in indicehash_tight. -->
            <span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">[</span><span class="n">idx</span><span class="p">][:,</span> <span class="n">mask</span><span class="p">],</span>
            <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">nsparse_shape</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
        <span class="n">ret</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">nsparse_size</span><span class="p">,</span> <span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">denseshape</span><span class="p">,</span>
                          <span class="n">device</span><span class="o">=</span><span class="n">thash</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
                          <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
        <span class="n">ret</span><span class="p">[</span><span class="n">thash</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span>
        <span class="n">ret</span> <span class="o">=</span> <span class="n">ret</span><span class="o">.</span><span class="n">unflatten</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">nsparse_shape</span><span class="p">)</span><!-- Expand the flat first axis back into the kept sparse extents. -->
        <span class="k">return</span> <span class="n">ret</span>

<div class="viewcode-block" id="SparseTensor.diag">
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.SparseTensor.diag">[docs]</a>
    <span class="k">def</span> <span class="nf">diag</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span> <span class="n">return_sparse</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span><!-- Dispatches to _diag_to_sparse (return_sparse=True) or _diag_to_dense; dims=None means all sparse dims, and an int dims raises NotImplementedError. -->
<span class="w">        </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd">        TODO: unit test ??</span>
<span class="sd">        &#39;&#39;&#39;</span>
        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
            <span class="k">raise</span> <span class="ne">NotImplementedError</span>
        <span class="k">if</span> <span class="n">dims</span> <span class="o">==</span> <span class="kc">None</span><span class="p">:</span><!-- NOTE(review): prefer "dims is None"; == can be overloaded elementwise if an array or tensor is passed as dims. -->
            <span class="n">dims</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sparse_dim</span><span class="p">))</span>
        <span class="k">if</span> <span class="n">return_sparse</span><span class="p">:</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_diag_to_sparse</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_diag_to_dense</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span></div>


    <span class="k">def</span> <span class="nf">_reduce_to_sparse</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">reduce</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span><!-- Drops the sparse dims listed in dims and rebuilds a SparseTensor with is_coalesced=False, so entries whose remaining indices now collide are merged by the constructor's coalesce using reduce. -->
        <span class="k">assert</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span>
            <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">__sparse_dim</span>
        <span class="p">),</span> <span class="s2">&quot;please use tuplewiseapply for operation on dense dims&quot;</span>
        <span class="k">assert</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">),</span> <span class="s2">&quot;do not support negative dims&quot;</span>
        <span class="n">idx</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sparse_dim</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">not</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="n">dims</span><span class="p">)]</span><!-- NOTE(review): list(dims) raises TypeError for an int dims; callers presumably normalize ints to a list first (confirm). -->
        <span class="n">other_ind</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
        <span class="n">other_shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">idx</span><span class="p">])</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">denseshape</span>
        <span class="k">return</span> <span class="n">SparseTensor</span><span class="p">(</span><span class="n">indices</span><span class="o">=</span><span class="n">other_ind</span><span class="p">,</span>
                            <span class="n">values</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">,</span>
                            <span class="n">shape</span><span class="o">=</span><span class="n">other_shape</span><span class="p">,</span>
                            <span class="n">is_coalesced</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
                            <span class="n">reduce</span><span class="o">=</span><span class="n">reduce</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">_reduce_to_dense</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">reduce</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
        <span class="k">assert</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span>
            <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">__sparse_dim</span>
        <span class="p">),</span> <span class="s2">&quot;please use tuplewiseapply for operation on dense dims&quot;</span>
        <span class="k">assert</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">),</span> <span class="s2">&quot;do not support negative dims&quot;</span>
        <span class="n">idx</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sparse_dim</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">not</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="n">dims</span><span class="p">)]</span>
        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">idx</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">idx</span> <span class="o">=</span> <span class="n">idx</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
            <span class="n">other_ind</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
            <span class="n">nsparse_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
            <span class="n">ret</span> <span class="o">=</span> <span class="n">torch_scatter_reduce</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">,</span> <span class="n">other_ind</span><span class="p">,</span> <span class="n">nsparse_size</span><span class="p">,</span>
                                       <span class="n">reduce</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">ret</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">other_ind</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
            <span class="n">other_shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">idx</span><span class="p">)</span>
            <span class="n">nsparse_shape</span> <span class="o">=</span> <span class="n">other_shape</span>
            <span class="n">nsparse_size</span> <span class="o">=</span> <span class="mi">1</span>
            <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">nsparse_shape</span><span class="p">:</span>
                <span class="n">nsparse_size</span> <span class="o">*=</span> <span class="n">_</span>

            <span class="n">thash</span> <span class="o">=</span> <span class="n">indicehash_tight</span><span class="p">(</span>
                <span class="n">other_ind</span><span class="p">,</span>
                <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">nsparse_shape</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">other_ind</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
            <span class="n">ret</span> <span class="o">=</span> <span class="n">torch_scatter_reduce</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">,</span> <span class="n">thash</span><span class="p">,</span> <span class="n">nsparse_size</span><span class="p">,</span>
                                       <span class="n">reduce</span><span class="p">)</span>
            <span class="n">ret</span> <span class="o">=</span> <span class="n">ret</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">nsparse_shape</span> <span class="o">+</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">ret</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:]))</span>
            <span class="k">return</span> <span class="n">ret</span>

<div class="viewcode-block" id="SparseTensor.sum"><!--
  Review notes for SparseTensor.sum (generated viewcode listing: apply fixes in
  the pygho.backend.SpTensor module source and rebuild, not on this page).
  Purpose: sum-reduce over the sparse dims listed in `dims`. An int is wrapped
  in a list. dims=None collapses every stored entry by reducing self.values
  along axis 0 and returns a dense tensor. Otherwise the call dispatches to
  _reduce_to_sparse (return_sparse=True) or _reduce_to_dense, whose visible
  part asserts the dims are non-negative.
  BUG: torch.sum takes the keyword `dim`, not `dims`, so
  torch.sum(self.values, dims=0) raises TypeError and the dims=None path can
  never succeed. It should be torch.sum(self.values, dim=0).
  NOTE: `dims == None` should be `dims is None`. `dims` has no default, so
  callers must pass None explicitly to reduce everything.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.SparseTensor.sum">[docs]</a>
    <span class="k">def</span> <span class="nf">sum</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
            <span class="n">dims</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">]]],</span>
            <span class="n">return_sparse</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
            <span class="n">dims</span> <span class="o">=</span> <span class="p">[</span><span class="n">dims</span><span class="p">]</span>
        <span class="k">if</span> <span class="n">dims</span> <span class="o">==</span> <span class="kc">None</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">,</span> <span class="n">dims</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="k">elif</span> <span class="n">return_sparse</span><span class="p">:</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_reduce_to_sparse</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="s2">&quot;sum&quot;</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_reduce_to_dense</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="s2">&quot;sum&quot;</span><span class="p">)</span></div>


<div class="viewcode-block" id="SparseTensor.max"><!--
  Review notes for SparseTensor.max (generated viewcode listing: apply fixes in
  the pygho.backend.SpTensor module source and rebuild, not on this page).
  Purpose: max-reduce over the sparse dims listed in `dims`, with the same
  dispatch as sum(): an int is wrapped in a list, None reduces all stored
  entries, and return_sparse selects _reduce_to_sparse over _reduce_to_dense.
  BUG: torch.max takes the keyword `dim`, not `dims`, so
  torch.max(self.values, dims=0) raises TypeError.
  CAVEAT: even with dim=0, torch.max returns a (values, indices) named tuple.
  The dense path returns a plain tensor (see _reduce_to_dense), so the None
  branch should return torch.max(self.values, dim=0).values for consistency.
  NOTE: prefer `dims is None` over `dims == None`.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.SparseTensor.max">[docs]</a>
    <span class="k">def</span> <span class="nf">max</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
            <span class="n">dims</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">]]],</span>
            <span class="n">return_sparse</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
            <span class="n">dims</span> <span class="o">=</span> <span class="p">[</span><span class="n">dims</span><span class="p">]</span>
        <span class="k">if</span> <span class="n">dims</span> <span class="o">==</span> <span class="kc">None</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">,</span> <span class="n">dims</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="k">elif</span> <span class="n">return_sparse</span><span class="p">:</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_reduce_to_sparse</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="s2">&quot;max&quot;</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_reduce_to_dense</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="s2">&quot;max&quot;</span><span class="p">)</span></div>


<div class="viewcode-block" id="SparseTensor.mean"><!--
  Review notes for SparseTensor.mean (generated viewcode listing: apply fixes in
  the pygho.backend.SpTensor module source and rebuild, not on this page).
  Purpose: mean-reduce over the sparse dims listed in `dims`, with the same
  dispatch as sum() and max(). The average is taken over stored entries only,
  because the reduction runs over self.values.
  BUG: torch.mean takes the keyword `dim`, not `dims`, so
  torch.mean(self.values, dims=0) raises TypeError.
  NOTE: torch.mean rejects integer dtypes, so integer-valued tensors also fail
  on the None path. Prefer `dims is None` over `dims == None`.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.SparseTensor.mean">[docs]</a>
    <span class="k">def</span> <span class="nf">mean</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
             <span class="n">dims</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">]]],</span>
             <span class="n">return_sparse</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
            <span class="n">dims</span> <span class="o">=</span> <span class="p">[</span><span class="n">dims</span><span class="p">]</span>
        <span class="k">if</span> <span class="n">dims</span> <span class="o">==</span> <span class="kc">None</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">,</span> <span class="n">dims</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="k">elif</span> <span class="n">return_sparse</span><span class="p">:</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_reduce_to_sparse</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="s2">&quot;mean&quot;</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_reduce_to_dense</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="s2">&quot;mean&quot;</span><span class="p">)</span></div>


<div class="viewcode-block" id="SparseTensor.unpooling"><!--
  Review notes for SparseTensor.unpooling (generated viewcode listing: apply
  fixes in the pygho.backend.SpTensor module source and rebuild).
  Purpose: broadcast self's values onto tarX's sparsity pattern. `dims` names
  the sparse dims of tarX that self lacks. Each tarX entry whose remaining
  coordinates match a stored entry of self receives that entry's values;
  unmatched entries receive zeros. The result reuses tarX's indices.
  Matching: the tarX coordinates outside `dims` are hashed with indicehash and
  looked up among self's hashes via searchsorted(right=True) minus 1, clamped
  at 0. Each candidate is then confirmed by equality (matchmask).
  Assumes indicehash maps equal coordinate tuples to equal hashes regardless
  of which tensor they come from. TODO confirm.
  CAVEAT: searchsorted requires self_hash sorted ascending. The assert only
  checks that adjacent hashes differ (a negative diff also passes), so it does
  not prove ordering. It presumably relies on indicehash preserving the order
  of coalesced indices. TODO confirm.
  CAVEAT: if self stores no entries, self_hash is empty and self_hash[b2a]
  indexes out of bounds. Guard nnz == 0 by returning zeros.
  NOTE: the docstring "unpooling to of tarX indice" is garbled. Suggested text:
  "Expand self onto the indices of tarX; dims are the dims of tarX to expand."
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.SparseTensor.unpooling">[docs]</a>
    <span class="k">def</span> <span class="nf">unpooling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span> <span class="n">tarX</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd">        unpooling to of tarX indice</span>
<span class="sd">        dims: of tarX</span>
<span class="sd">        &#39;&#39;&#39;</span>
        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
            <span class="n">dims</span> <span class="o">=</span> <span class="p">[</span><span class="n">dims</span><span class="p">]</span>
        <span class="n">self_hash</span> <span class="o">=</span> <span class="n">indicehash</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">)</span>
        <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">diff</span><span class="p">(</span><span class="n">self_hash</span><span class="p">)),</span> <span class="s2">&quot;self is not coalesced&quot;</span>
        <span class="n">tarX</span><span class="p">:</span> <span class="n">SparseTensor</span> <span class="o">=</span> <span class="n">tarX</span>
        <span class="n">taridx</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tarX</span><span class="o">.</span><span class="n">sparse_dim</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">not</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="n">dims</span><span class="p">)]</span>
        <span class="n">tar_hash</span> <span class="o">=</span> <span class="n">indicehash</span><span class="p">(</span><span class="n">tarX</span><span class="o">.</span><span class="n">indices</span><span class="p">[</span><span class="n">taridx</span><span class="p">])</span>

        <span class="n">b2a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp_min_</span><span class="p">(</span>
            <span class="n">torch</span><span class="o">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">self_hash</span><span class="p">,</span> <span class="n">tar_hash</span><span class="p">,</span> <span class="n">right</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>

        <span class="n">matchmask</span> <span class="o">=</span> <span class="p">(</span><span class="n">self_hash</span><span class="p">[</span><span class="n">b2a</span><span class="p">]</span> <span class="o">==</span> <span class="n">tar_hash</span><span class="p">)</span>
        <span class="n">ret</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">tar_hash</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">denseshape</span><span class="p">,</span>
                          <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
                          <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
        <span class="n">ret</span><span class="p">[</span><span class="n">matchmask</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">[</span><span class="n">b2a</span><span class="p">[</span><span class="n">matchmask</span><span class="p">]]</span>
        <span class="k">return</span> <span class="n">tarX</span><span class="o">.</span><span class="n">tuplewiseapply</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">ret</span><span class="p">)</span></div>


<div class="viewcode-block" id="SparseTensor.unpooling_fromdense1dim"><!--
  Review notes for SparseTensor.unpooling_fromdense1dim (generated viewcode
  listing: apply fixes in the pygho.backend.SpTensor module source).
  Purpose: gather rows of the dense tensor X by each stored entry's coordinate
  along sparse dim `dims`. The result has self's sparsity pattern, with values
  X[self.indices[dims]]. Requires X.shape[0] == self.shape[dims].
  CAVEAT: the assert only bounds `dims` from above, so a negative value passes.
  self.shape[dims] then counts from the end of the full shape, which includes
  dense dims, while self.indices[dims] counts from the last sparse dim. The two
  disagree, so negative dims should be rejected or normalised.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.SparseTensor.unpooling_fromdense1dim">[docs]</a>
    <span class="k">def</span> <span class="nf">unpooling_fromdense1dim</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">X</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd">        unpooling to of self shape. Note the dims is for self to maintain, and expand other dims</span>
<span class="sd">        &#39;&#39;&#39;</span>
        <span class="k">assert</span> <span class="n">dims</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">sparse_dim</span><span class="p">,</span> <span class="s2">&quot;only unpooling sparse dims&quot;</span>
        <span class="k">assert</span> <span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">dims</span><span class="p">],</span> <span class="s2">&quot;shape not match&quot;</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">tuplewiseapply</span><span class="p">(</span><span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="n">X</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">[</span><span class="n">dims</span><span class="p">]])</span></div>


<div class="viewcode-block" id="SparseTensor.from_torch_sparse_coo"><!--
  Review notes for SparseTensor.from_torch_sparse_coo.
  Purpose: wrap a torch sparse COO tensor. The wrapper reuses its raw indices,
  values and shape, and propagates A.is_coalesced() as the fourth constructor
  argument.
  The private A._indices() and A._values() accessors are used deliberately.
  The public A.indices() and A.values() refuse uncoalesced tensors, whereas the
  private ones return the stored data as is.
  A.is_sparse is True only for the sparse COO layout, so CSR/CSC inputs fail
  the assert.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.SparseTensor.from_torch_sparse_coo">[docs]</a>
    <span class="nd">@classmethod</span>
    <span class="k">def</span> <span class="nf">from_torch_sparse_coo</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">A</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
        <span class="k">assert</span> <span class="n">A</span><span class="o">.</span><span class="n">is_sparse</span><span class="p">,</span> <span class="s2">&quot;from_torch_sparse_coo converts a torch.sparse_coo_tensor to SparseTensor&quot;</span>
        <span class="n">ret</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span><span class="n">A</span><span class="o">.</span><span class="n">_indices</span><span class="p">(),</span> <span class="n">A</span><span class="o">.</span><span class="n">_values</span><span class="p">(),</span> <span class="n">A</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">A</span><span class="o">.</span><span class="n">is_coalesced</span><span class="p">())</span>
        <span class="k">return</span> <span class="n">ret</span></div>


<div class="viewcode-block" id="SparseTensor.to_torch_sparse_coo"><!--
  Review notes for SparseTensor.to_torch_sparse_coo.
  Purpose: build a torch.sparse_coo_tensor with the same indices, values and
  shape. The private Tensor._coalesced_ then sets its coalesced flag to mirror
  self.is_coalesced().
  CAVEAT: _coalesced_ sets the flag without verifying it. If self wrongly claims
  to be coalesced, torch will trust unsorted or duplicate indices; note that
  tuplewiseapply always constructs results with is_coalesced=True.
  Recent PyTorch releases accept an is_coalesced argument in
  torch.sparse_coo_tensor directly. TODO check the minimum supported version
  before switching to it.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.SparseTensor.to_torch_sparse_coo">[docs]</a>
    <span class="k">def</span> <span class="nf">to_torch_sparse_coo</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
        <span class="n">ret</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sparse_coo_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">,</span>
                                      <span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">,</span>
                                      <span class="n">size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
        <span class="n">ret</span> <span class="o">=</span> <span class="n">ret</span><span class="o">.</span><span class="n">_coalesced_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">is_coalesced</span><span class="p">())</span>
        <span class="k">return</span> <span class="n">ret</span></div>


<div class="viewcode-block" id="SparseTensor.tuplewiseapply"><!--
  Review notes for SparseTensor.tuplewiseapply.
  Purpose: apply func to the value tensor and return a new SparseTensor that
  shares self.indices. Its shape is self.sparseshape followed by the trailing
  dims of the new values, so func must keep one leading row per stored entry.
  This is not checked.
  CAVEAT: the result is always built with is_coalesced=True, even when self is
  not coalesced (for example the output of add with samesparse=False).
  Passing self.is_coalesced() would propagate the real state.
  NOTE: it constructs SparseTensor directly rather than type(self), so
  subclasses are not preserved.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.SparseTensor.tuplewiseapply">[docs]</a>
    <span class="k">def</span> <span class="nf">tuplewiseapply</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">func</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">Tensor</span><span class="p">]):</span>
        <span class="n">nvalues</span> <span class="o">=</span> <span class="n">func</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">SparseTensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">,</span>
                            <span class="n">nvalues</span><span class="p">,</span>
                            <span class="bp">self</span><span class="o">.</span><span class="n">sparseshape</span> <span class="o">+</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">nvalues</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:]),</span>
                            <span class="n">is_coalesced</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></div>


<div class="viewcode-block" id="SparseTensor.diagonalapply"><!--
  Review notes for SparseTensor.diagonalapply.
  Purpose: for a tensor with exactly two sparse dims, call
  func(values, is_diag), then wrap the returned values with self's indices.
  is_diag is a long tensor holding 1 for stored entries with
  indices[0] == indices[1] and 0 otherwise.
  CAVEAT: like tuplewiseapply, the result is always marked is_coalesced=True,
  regardless of self's actual state.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.SparseTensor.diagonalapply">[docs]</a>
    <span class="k">def</span> <span class="nf">diagonalapply</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">func</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">LongTensor</span><span class="p">],</span> <span class="n">Tensor</span><span class="p">]):</span>
        <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">sparse_dim</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">&quot;only implemented for 2D&quot;</span>
        <span class="n">nvalues</span> <span class="o">=</span> <span class="n">func</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">,</span>
                       <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">SparseTensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">,</span>
                            <span class="n">nvalues</span><span class="p">,</span>
                            <span class="bp">self</span><span class="o">.</span><span class="n">sparseshape</span> <span class="o">+</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">nvalues</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:]),</span>
                            <span class="n">is_coalesced</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></div>


<div class="viewcode-block" id="SparseTensor.add"><!--
  Review notes for SparseTensor.add.
  Purpose: elementwise addition with another SparseTensor tarX.
  samesparse=False: concatenates both index lists and both value lists and
  returns a tensor flagged uncoalesced, with self.shape. Coordinates present in
  both operands appear twice; presumably they are summed when the result is
  coalesced. TODO confirm. tarX.shape is not checked against self.shape.
  samesparse=True: adds tarX.values to self.values position by position, so
  the caller must guarantee identical indices in identical order. This is not
  verified.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.SparseTensor.add">[docs]</a>
    <span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tarX</span><span class="p">,</span> <span class="n">samesparse</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
        <span class="k">if</span> <span class="ow">not</span> <span class="n">samesparse</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">SparseTensor</span><span class="p">(</span>
                <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">indices</span><span class="p">,</span> <span class="n">tarX</span><span class="o">.</span><span class="n">indices</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span>
                <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">,</span> <span class="n">tarX</span><span class="o">.</span><span class="n">values</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
                <span class="kc">False</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">tuplewiseapply</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">tarX</span><span class="o">.</span><span class="n">values</span><span class="p">)</span></div>


<div class="viewcode-block" id="SparseTensor.catvalue"><!--
  Review notes for SparseTensor.catvalue.
  Purpose: concatenate value tensors along the last dim for operands that share
  one sparsity pattern. tarX may be a single SparseTensor or an iterable of
  them; anything else raises NotImplementedError.
  samesparse must be True (asserted). Identical indices in identical order are
  assumed, not checked.
  NOTE: the lambdas ignore their argument and read self.values directly. This
  is equivalent, because tuplewiseapply passes self.values. In the iterable
  branch, the comprehension variable _ shadows the lambda parameter, which is
  legal but hard to read.
  NOTE: the assert message misspells "sparsity" as "sparcity".
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.SpTensor.SparseTensor.catvalue">[docs]</a>
    <span class="k">def</span> <span class="nf">catvalue</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tarX</span><span class="p">,</span> <span class="n">samesparse</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
        <span class="k">assert</span> <span class="n">samesparse</span> <span class="o">==</span> <span class="kc">True</span><span class="p">,</span> <span class="s2">&quot;must have the same sparcity to concat value&quot;</span>
        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">tarX</span><span class="p">,</span> <span class="n">SparseTensor</span><span class="p">):</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">tuplewiseapply</span><span class="p">(</span><span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span>
                <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">,</span> <span class="n">tarX</span><span class="o">.</span><span class="n">values</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span>
        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">tarX</span><span class="p">,</span> <span class="n">Iterable</span><span class="p">):</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">tuplewiseapply</span><span class="p">(</span><span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span>
                <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="n">_</span><span class="o">.</span><span class="n">values</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">tarX</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">NotImplementedError</span></div>


    <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span><!-- Review note: SparseTensor.__repr__ prints a compact summary (shape, sparse_dim, nnz) without dumping index or value contents. The trailing closing div on the next line ends the class-level viewcode block opened earlier on the page. -->
        <span class="k">return</span> <span class="sa">f</span><span class="s1">&#39;SparseTensor(shape=</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1">, sparse_dim=</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">sparse_dim</span><span class="si">}</span><span class="s1">, nnz=</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">nnz</span><span class="si">}</span><span class="s1">)&#39;</span></div>

</pre></div>

           </div>
          </div>
          <footer>
            <!-- Page footer: copyright line followed by the build attribution. -->
            <hr>
            <div role="contentinfo">
              <p>© Copyright 2023.</p>
            </div>
            Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
            <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
            provided by <a href="https://readthedocs.org">Read the Docs</a>.
          </footer>
        </div>
      </div>
    </section>
  </div>
  <script>
    /* Switch on the sphinx_rtd_theme sidebar navigation once the DOM is ready. */
    jQuery(function onDocumentReady() {
      SphinxRtdTheme.Navigation.enable(true);
    });
  </script>

</body>
</html>