<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <!-- Document metadata -->
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>pygho.backend.Spspmm &mdash; PyGHO 0.0.1 documentation</title>

  <!-- Stylesheets: the Pygments sheet first, then the theme sheet -->
  <link rel="stylesheet" href="../../../_static/pygments.css">
  <link rel="stylesheet" href="../../../_static/css/theme.css">

  <!-- Conditional comment: html5shiv is only requested by IE older than 9 -->
  <!--[if lt IE 9]>
    <script src="../../../_static/js/html5shiv.min.js"></script>
  <![endif]-->

  <!-- Scripts carry no async/defer, so they load and run in this exact order -->
  <script src="../../../_static/jquery.js?v=5d32c60e"></script>
  <script src="../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
  <script src="../../../_static/documentation_options.js?v=d45e8c67"></script>
  <script src="../../../_static/doctools.js?v=888ff710"></script>
  <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
  <script src="../../../_static/js/theme.js"></script>

  <!-- Link relations pointing at the general index and the search page -->
  <link rel="index" href="../../../genindex.html" title="Index">
  <link rel="search" href="../../../search.html" title="Search">
</head>

<body class="wy-body-for-nav"> 
  <div class="wy-grid-for-nav">
    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
      <div class="wy-side-scroll">
        <!-- Sidebar header: link back to the documentation root plus the search form -->
        <div class="wy-side-nav-search">
          <a href="../../../index.html" class="icon icon-home">
            PyGHO
          </a>
          <div role="search">
            <form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
              <input type="text" name="q" placeholder="Search docs" aria-label="Search docs">
              <input type="hidden" name="check_keywords" value="yes">
              <input type="hidden" name="area" value="default">
            </form>
          </div>
        </div>
        <!--
          Sidebar table of contents, split into three captioned groups.
          Each caption is a <p> exposed with role="heading"; that role requires
          an explicit aria-level, so level 2 is stated on every caption.
        -->
        <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
          <p class="caption" role="heading" aria-level="2"><span class="caption-text">Notes</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/installation.html">Installation</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/miniexample.html">Minimal Example</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/datastructure.html">Refined Basic Data Structure</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/hodata.html">Efficient High Order Data Processing</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/operator.html">Operators</a></li>
          </ul>
          <p class="caption" role="heading" aria-level="2"><span class="caption-text">Advanced Tutorial</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/multtensor.html">Multiple Tensor</a></li>
          </ul>
          <p class="caption" role="heading" aria-level="2"><span class="caption-text">Package Reference</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../modules/backend.html">pygho.backend package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../modules/hodata.html">pygho.hodata package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../modules/honn.html">pygho.honn package</a></li>
          </ul>
        </div>
      </div>
    </nav>

    <!-- Content wrapper opens here; its first child is the bar labelled "Mobile navigation menu". -->
    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
          <!-- NOTE(review): this <i> has no href, tabindex or accessible name, so it is not reachable
               by keyboard. Presumably theme.js binds the menu toggle to data-toggle="wy-nav-top";
               confirm before changing the element. -->
          <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
          <!-- Project name linking back to the documentation root. -->
          <a href="../../../index.html">PyGHO</a>
      </nav>

      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="navigation" aria-label="Page navigation">
            <!-- Breadcrumb trail: home icon link, the "Module code" index, then the current module -->
            <ul class="wy-breadcrumbs">
              <li><a class="icon icon-home" href="../../../index.html" aria-label="Home"></a></li>
              <li class="breadcrumb-item"><a href="../../index.html">Module code</a></li>
              <li class="breadcrumb-item active">pygho.backend.Spspmm</li>
              <!-- Aside slot; it holds no content on this page -->
              <li class="wy-breadcrumbs-aside">
              </li>
            </ul>
            <hr>
          </div>
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
             
  <h1>Source code for pygho.backend.Spspmm</h1><div class="highlight"><pre>
<span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">Tensor</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Tuple</span>
<span class="kn">from</span> <span class="nn">.SpTensor</span> <span class="kn">import</span> <span class="n">SparseTensor</span><span class="p">,</span> <span class="n">indicehash</span><span class="p">,</span> <span class="n">decodehash</span>
<span class="kn">import</span> <span class="nn">warnings</span>
<span class="kn">from</span> <span class="nn">.utils</span> <span class="kn">import</span> <span class="n">torch_scatter_reduce</span>

<div class="viewcode-block" id="ptr2batch">
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.Spspmm.ptr2batch">[docs]</a>
<span class="k">def</span> <span class="nf">ptr2batch</span><span class="p">(</span><span class="n">ptr</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">dim_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">LongTensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Converts a pointer tensor to a batch tensor. TODO: use torch_scatter gather instead?</span>

<span class="sd">    This function takes a pointer tensor `ptr` and a `dim_size` and converts it to a</span>
<span class="sd">    batch tensor where each element in the batch tensor corresponds to a range of</span>
<span class="sd">    indices in the original tensor.</span>

<span class="sd">    Args:</span>

<span class="sd">    - ptr (LongTensor): The pointer tensor, where `ptr[0] = 0` and `torch.all(diff(ptr) &gt;= 0)` is true.</span>
<span class="sd">    - dim_size (int): The size of the target dimension.</span>

<span class="sd">    Returns:</span>

<span class="sd">    - LongTensor: A batch tensor of shape `(dim_size,)` where `batch[ptr[i]:ptr[i+1]] = i`.</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">assert</span> <span class="n">ptr</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;ptr should be 1-d&quot;</span>
    <span class="k">assert</span> <span class="n">ptr</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">(</span>
        <span class="n">torch</span><span class="o">.</span><span class="n">diff</span><span class="p">(</span><span class="n">ptr</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">),</span> <span class="s2">&quot;should put in a ptr tensor&quot;</span>
    <span class="k">assert</span> <span class="n">ptr</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">dim_size</span><span class="p">,</span> <span class="s2">&quot;dim_size should match ptr&quot;</span>
    <span class="n">tmp</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">dim_size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">ptr</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">ptr</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
    <span class="n">ret</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">ptr</span><span class="p">,</span> <span class="n">tmp</span><span class="p">,</span> <span class="n">right</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
    <span class="k">return</span> <span class="n">ret</span></div>



<div class="viewcode-block" id="spspmm_ind">
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.Spspmm.spspmm_ind">[docs]</a>
<span class="k">def</span> <span class="nf">spspmm_ind</span><span class="p">(</span><span class="n">ind1</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
               <span class="n">dim1</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
               <span class="n">ind2</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
               <span class="n">dim2</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
               <span class="n">is_k2_sorted</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">LongTensor</span><span class="p">,</span> <span class="n">LongTensor</span><span class="p">]:</span>   
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Sparse-sparse matrix multiplication for indices.</span>

<span class="sd">    This function performs a sparse-sparse matrix multiplication for indices. </span>
<span class="sd">    Given two sets of indices `ind1` and `ind2`, this function eliminates `dim1` in `ind1` and `dim2` in `ind2`, and concatenates the remaining dimensions. </span>
<span class="sd">    </span>
<span class="sd">    The result represents the product of the input indices.</span>

<span class="sd">    Args:</span>

<span class="sd">    - ind1 (LongTensor): The indices of the first sparse tensor of shape `(sparsedim1, M1)`.</span>
<span class="sd">    - dim1 (int): The dimension to eliminate in `ind1`.</span>
<span class="sd">    - ind2 (LongTensor): The indices of the second sparse tensor of shape `(sparsedim2, M2)`.</span>
<span class="sd">    - dim2 (int): The dimension to eliminate in `ind2`.</span>
<span class="sd">    - is_k2_sorted (bool, optional): Whether `ind2` is sorted along `dim2`. Defaults to `False`.</span>

<span class="sd">    Returns:</span>
<span class="sd">    </span>
<span class="sd">    - tarind: LongTensor: The resulting indices after performing the sparse-sparse matrix multiplication.</span>
<span class="sd">    - bcd: LongTensor: In tensor perspective (\*i_1, k, \*i_2), (\*j_1, k, \*j_2) -&gt; (\*i_1, \*i_2, \*j_1, \*j_2).</span>
<span class="sd">      The returned index is of shape (3, nnz), with rows (b, c, d): c represents the index of \*i, d represents the index of \*j, and b represents the index of the output. For i=1,2,...,nnz, val1[c[i]] * val2[d[i]] will be added to the b[i]-th element of the output val.</span>

<span class="sd">    Example:</span>
<span class="sd">    </span>
<span class="sd">    ::</span>

<span class="sd">        ind1 = torch.tensor([[0, 1, 1, 2],</span>
<span class="sd">                            [2, 1, 0, 2]], dtype=torch.long)</span>
<span class="sd">        dim1 = 0</span>
<span class="sd">        ind2 = torch.tensor([[2, 1, 0, 1],</span>
<span class="sd">                            [1, 0, 2, 2]], dtype=torch.long)</span>
<span class="sd">        dim2 = 1</span>
<span class="sd">        result = spspmm_ind(ind1, dim1, ind2, dim2)</span>

<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">assert</span> <span class="mi">0</span> <span class="o">&lt;=</span> <span class="n">dim1</span> <span class="o">&lt;</span> <span class="n">ind1</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
        <span class="mi">0</span><span class="p">],</span> <span class="sa">f</span><span class="s2">&quot;ind1&#39;s reduced dim </span><span class="si">{</span><span class="n">dim1</span><span class="si">}</span><span class="s2"> is out of range&quot;</span>
    <span class="k">assert</span> <span class="mi">0</span> <span class="o">&lt;=</span> <span class="n">dim2</span> <span class="o">&lt;</span> <span class="n">ind2</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
        <span class="mi">0</span><span class="p">],</span> <span class="sa">f</span><span class="s2">&quot;ind2&#39;s reduced dim </span><span class="si">{</span><span class="n">dim2</span><span class="si">}</span><span class="s2"> is out of range&quot;</span>
    <span class="k">if</span> <span class="n">dim2</span> <span class="o">!=</span> <span class="mi">0</span> <span class="ow">and</span> <span class="ow">not</span> <span class="p">(</span><span class="n">is_k2_sorted</span><span class="p">):</span>
        <span class="n">perm</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">ind2</span><span class="p">[</span><span class="n">dim2</span><span class="p">])</span>
        <span class="n">tarind</span><span class="p">,</span> <span class="n">bcd</span> <span class="o">=</span> <span class="n">spspmm_ind</span><span class="p">(</span><span class="n">ind1</span><span class="p">,</span> <span class="n">dim1</span><span class="p">,</span> <span class="n">ind2</span><span class="p">[:,</span> <span class="n">perm</span><span class="p">],</span> <span class="n">dim2</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
        <span class="n">bcd</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">perm</span><span class="p">[</span><span class="n">bcd</span><span class="p">[</span><span class="mi">2</span><span class="p">]]</span>
        <span class="k">return</span> <span class="n">tarind</span><span class="p">,</span> <span class="n">bcd</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">nnz1</span><span class="p">,</span> <span class="n">nnz2</span><span class="p">,</span> <span class="n">sparsedim1</span><span class="p">,</span> <span class="n">sparsedim2</span> <span class="o">=</span> <span class="n">ind1</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">ind2</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
            <span class="mi">1</span><span class="p">],</span> <span class="n">ind1</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">ind2</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">k1</span><span class="p">,</span> <span class="n">k2</span> <span class="o">=</span> <span class="n">ind1</span><span class="p">[</span><span class="n">dim1</span><span class="p">],</span> <span class="n">ind2</span><span class="p">[</span><span class="n">dim2</span><span class="p">]</span>
        <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">diff</span><span class="p">(</span><span class="n">k2</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">),</span> <span class="s2">&quot;ind2[0] should be sorted&quot;</span>
        <span class="c1"># for each k in k1, it can match an interval of k2 as k2 is sorted</span>
        <span class="n">upperbound</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">k2</span><span class="p">,</span> <span class="n">k1</span><span class="p">,</span> <span class="n">right</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
        <span class="n">lowerbound</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">k2</span><span class="p">,</span> <span class="n">k1</span><span class="p">,</span> <span class="n">right</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
        <span class="n">matched_num</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp_min_</span><span class="p">(</span><span class="n">upperbound</span> <span class="o">-</span> <span class="n">lowerbound</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>

        <span class="c1"># ptr[i] provides the offset at which to place the pairs of ind1[:, i] and the matched ind2</span>
        <span class="n">retptr</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">nnz1</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
                             <span class="n">dtype</span><span class="o">=</span><span class="n">matched_num</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
                             <span class="n">device</span><span class="o">=</span><span class="n">matched_num</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
        <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">matched_num</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">retptr</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span>
        <span class="n">retsize</span> <span class="o">=</span> <span class="n">retptr</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>

        <span class="c1"># fill the output with ptr</span>
        <span class="n">ret</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="n">retsize</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">ind1</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">ind1</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
        <span class="n">ret</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">ptr2batch</span><span class="p">(</span><span class="n">retptr</span><span class="p">,</span> <span class="n">retsize</span><span class="p">)</span>
        <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">retsize</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">ret</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">ret</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">ret</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
        <span class="n">offset</span> <span class="o">=</span> <span class="p">(</span><span class="n">ret</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="n">retptr</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]]</span> <span class="o">-</span> <span class="n">lowerbound</span><span class="p">)[</span><span class="n">ret</span><span class="p">[</span><span class="mi">1</span><span class="p">]]</span>
        <span class="n">ret</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-=</span> <span class="n">offset</span>

        <span class="c1"># compute the ind pair index</span>

        <span class="n">combinedind</span> <span class="o">=</span> <span class="n">indicehash</span><span class="p">(</span>
            <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span>
                <span class="p">((</span><span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">((</span><span class="n">ind1</span><span class="p">[:</span><span class="n">dim1</span><span class="p">],</span> <span class="n">ind1</span><span class="p">[</span><span class="n">dim1</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:])))[:,</span> <span class="n">ret</span><span class="p">[</span><span class="mi">1</span><span class="p">]],</span>
                 <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">((</span><span class="n">ind2</span><span class="p">[:</span><span class="n">dim2</span><span class="p">],</span> <span class="n">ind2</span><span class="p">[</span><span class="n">dim2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:]))[:,</span> <span class="n">ret</span><span class="p">[</span><span class="mi">2</span><span class="p">]])))</span>
        <span class="n">combinedind</span><span class="p">,</span> <span class="n">taridx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="n">combinedind</span><span class="p">,</span>
                                           <span class="nb">sorted</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
                                           <span class="n">return_inverse</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
        <span class="n">tarind</span> <span class="o">=</span> <span class="n">decodehash</span><span class="p">(</span><span class="n">combinedind</span><span class="p">,</span> <span class="n">sparsedim1</span> <span class="o">+</span> <span class="n">sparsedim2</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">ret</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">taridx</span>

        <span class="n">sorted_idx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">ret</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>  <span class="c1"># sort is optional</span>
        <span class="k">return</span> <span class="n">tarind</span><span class="p">,</span> <span class="n">ret</span><span class="p">[:,</span> <span class="n">sorted_idx</span><span class="p">]</span></div>



<div class="viewcode-block" id="spsphadamard_ind">
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.Spspmm.spsphadamard_ind">[docs]</a>
<span class="k">def</span> <span class="nf">spsphadamard_ind</span><span class="p">(</span><span class="n">tar_ind</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">ind</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">LongTensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Auxiliary function for SparseTensor-SparseTensor Hadamard product.</span>

<span class="sd">    This function is an auxiliary function used in the Hadamard product of two sparse tensors. Given the indices `tar_ind` of sparse tensor A and the indices `ind` of sparse tensor B, this function returns an index array `b2a` of shape `(ind.shape[1],)` such that `ind[:, i]` matches `tar_ind[:, b2a[i]]` for each `i`. If `b2a[i]` is less than 0, it means `ind[:, i]` is not matched.</span>

<span class="sd">    Args:</span>

<span class="sd">    - tar_ind (LongTensor): The indices of sparse tensor A.</span>
<span class="sd">    - ind (LongTensor): The indices of sparse tensor B.</span>

<span class="sd">    Returns:</span>

<span class="sd">    - LongTensor: An index array `b2a` representing the matching indices between `tar_ind` and `ind`.</span>
<span class="sd">      b2a is of shape (ind.shape[1],). ind[:, i] matches tar_ind[:, b2a[i]]. If b2a[i]&lt;0, ind[:, i] is not matched.</span>
<span class="sd">    </span>
<span class="sd">    Example:</span>

<span class="sd">    ::</span>

<span class="sd">        tar_ind = torch.tensor([[0, 1, 1, 2],</span>
<span class="sd">                                [2, 1, 0, 2]], dtype=torch.long)</span>
<span class="sd">        ind = torch.tensor([[2, 1, 0, 1],</span>
<span class="sd">                            [1, 0, 2, 2]], dtype=torch.long)</span>
<span class="sd">        b2a = spsphadamard_ind(tar_ind, ind)</span>

<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">assert</span> <span class="n">tar_ind</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">ind</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">combine_tar_ind</span> <span class="o">=</span> <span class="n">indicehash</span><span class="p">(</span><span class="n">tar_ind</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">diff</span><span class="p">(</span><span class="n">combine_tar_ind</span><span class="p">)</span> <span class="o">&gt;</span>
                     <span class="mi">0</span><span class="p">),</span> <span class="s2">&quot;tar_ind should be sorted and coalesce&quot;</span>
    <span class="n">combine_ind</span> <span class="o">=</span> <span class="n">indicehash</span><span class="p">(</span><span class="n">ind</span><span class="p">)</span>

    <span class="n">b2a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp_min_</span><span class="p">(</span>
        <span class="n">torch</span><span class="o">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">combine_tar_ind</span><span class="p">,</span> <span class="n">combine_ind</span><span class="p">,</span> <span class="n">right</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
    <span class="n">notmatchmask</span> <span class="o">=</span> <span class="p">(</span><span class="n">combine_ind</span> <span class="o">!=</span> <span class="n">combine_tar_ind</span><span class="p">[</span><span class="n">b2a</span><span class="p">])</span>
    <span class="n">b2a</span><span class="p">[</span><span class="n">notmatchmask</span><span class="p">]</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
    <span class="k">return</span> <span class="n">b2a</span></div>



<div class="viewcode-block" id="filterind">
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.Spspmm.filterind">[docs]</a>
<span class="k">def</span> <span class="nf">filterind</span><span class="p">(</span><span class="n">tar_ind</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">ind</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
              <span class="n">bcd</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">LongTensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    A combination of Hadamard and Sparse Matrix Multiplication.</span>

<span class="sd">    Given the indices `tar_ind` of sparse tensor A, the indices `ind` of sparse tensor BC, and the index array `bcd`, this function returns an index array `acd`, where `(A ⊙ (BC)).val[a] = A.val[a] * scatter(B.val[c] * C.val[d], a)`.</span>

<span class="sd">    Args:</span>

<span class="sd">    - tar_ind (LongTensor): The indices of sparse tensor A.</span>
<span class="sd">    - ind (LongTensor): The indices of sparse tensor BC.</span>
<span class="sd">    - bcd (LongTensor): An index array representing `(BC).val`.</span>

<span class="sd">    Returns:</span>
<span class="sd">    </span>
<span class="sd">    - LongTensor: An index array `acd` representing the filtered indices.</span>

<span class="sd">    Example:</span>
<span class="sd">    </span>
<span class="sd">    ::</span>

<span class="sd">        tar_ind = torch.tensor([[0, 1, 1, 2],</span>
<span class="sd">                                [2, 1, 0, 2]], dtype=torch.long)</span>
<span class="sd">        ind = torch.tensor([[2, 1, 0, 1],</span>
<span class="sd">                            [1, 0, 2, 2]], dtype=torch.long)</span>
<span class="sd">        bcd = torch.tensor([[3, 2, 1, 0],</span>
<span class="sd">                            [6, 5, 4, 3],</span>
<span class="sd">                            [9, 8, 7, 6]], dtype=torch.long)</span>
<span class="sd">        acd = filterind(tar_ind, ind, bcd)</span>


<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">b2a</span> <span class="o">=</span> <span class="n">spsphadamard_ind</span><span class="p">(</span><span class="n">tar_ind</span><span class="p">,</span> <span class="n">ind</span><span class="p">)</span>
    <span class="n">a</span> <span class="o">=</span> <span class="n">b2a</span><span class="p">[</span><span class="n">bcd</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span>
    <span class="n">retmask</span> <span class="o">=</span> <span class="n">a</span> <span class="o">&gt;=</span> <span class="mi">0</span>
    <span class="n">acd</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">a</span><span class="p">[</span><span class="n">retmask</span><span class="p">],</span> <span class="n">bcd</span><span class="p">[</span><span class="mi">1</span><span class="p">][</span><span class="n">retmask</span><span class="p">],</span> <span class="n">bcd</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="n">retmask</span><span class="p">]))</span>
    <span class="k">return</span> <span class="n">acd</span></div>



<div class="viewcode-block" id="spsphadamard">
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.Spspmm.spsphadamard">[docs]</a>
<span class="k">def</span> <span class="nf">spsphadamard</span><span class="p">(</span><span class="n">A</span><span class="p">:</span> <span class="n">SparseTensor</span><span class="p">,</span>
                 <span class="n">B</span><span class="p">:</span> <span class="n">SparseTensor</span><span class="p">,</span>
                 <span class="n">b2a</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LongTensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">SparseTensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Element-wise Hadamard product between two SparseTensors.</span>

<span class="sd">    This function performs the element-wise Hadamard product between two SparseTensors, `A` and `B`. The `b2a` parameter is an optional auxiliary index produced by the `spsphadamard_ind` function.</span>

<span class="sd">    Args:</span>

<span class="sd">    - A (SparseTensor): The first SparseTensor.</span>
<span class="sd">    - B (SparseTensor): The second SparseTensor.</span>
<span class="sd">    - b2a (LongTensor, optional): An optional index array produced by `spsphadamard_ind`. If not provided, it will be computed.</span>

<span class="sd">    Returns:</span>

<span class="sd">    - SparseTensor: A SparseTensor containing the result of the Hadamard product.</span>


<span class="sd">    Notes:</span>

<span class="sd">    - Both `A` and `B` must be coalesced SparseTensors with the same sparse shape.</span>
<span class="sd">    - The dense shapes of `A` and `B` must be broadcastable.</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">assert</span> <span class="n">A</span><span class="o">.</span><span class="n">is_coalesced</span><span class="p">(),</span> <span class="s2">&quot;A should be coalesced&quot;</span>
    <span class="k">assert</span> <span class="n">B</span><span class="o">.</span><span class="n">is_coalesced</span><span class="p">(),</span> <span class="s2">&quot;B should be coalesced&quot;</span>
    <span class="k">assert</span> <span class="n">A</span><span class="o">.</span><span class="n">sparseshape</span> <span class="o">==</span> <span class="n">B</span><span class="o">.</span><span class="n">sparseshape</span><span class="p">,</span> <span class="s2">&quot;A, B should be of the same sparse shape&quot;</span>
    <span class="n">ind1</span><span class="p">,</span> <span class="n">val1</span> <span class="o">=</span> <span class="n">A</span><span class="o">.</span><span class="n">indices</span><span class="p">,</span> <span class="n">A</span><span class="o">.</span><span class="n">values</span>
    <span class="n">ind2</span><span class="p">,</span> <span class="n">val2</span> <span class="o">=</span> <span class="n">B</span><span class="o">.</span><span class="n">indices</span><span class="p">,</span> <span class="n">B</span><span class="o">.</span><span class="n">values</span>
    <span class="k">if</span> <span class="n">b2a</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">b2a</span> <span class="o">=</span> <span class="n">spsphadamard_ind</span><span class="p">(</span><span class="n">ind1</span><span class="p">,</span> <span class="n">ind2</span><span class="p">)</span>
    <span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">b2a</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">val1</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">retval</span> <span class="o">=</span> <span class="n">val2</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span>
    <span class="k">elif</span> <span class="n">val2</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">retval</span> <span class="o">=</span> <span class="n">val1</span><span class="p">[</span><span class="n">b2a</span><span class="p">[</span><span class="n">mask</span><span class="p">]]</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">retval</span> <span class="o">=</span> <span class="n">val1</span><span class="p">[</span><span class="n">b2a</span><span class="p">[</span><span class="n">mask</span><span class="p">]]</span> <span class="o">*</span> <span class="n">val2</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span>
    <span class="n">retind</span> <span class="o">=</span> <span class="n">ind2</span><span class="p">[:,</span> <span class="n">mask</span><span class="p">]</span>
    <span class="k">return</span> <span class="n">SparseTensor</span><span class="p">(</span><span class="n">retind</span><span class="p">,</span>
                        <span class="n">retval</span><span class="p">,</span>
                        <span class="n">shape</span><span class="o">=</span><span class="n">A</span><span class="o">.</span><span class="n">sparseshape</span> <span class="o">+</span> <span class="n">retval</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:],</span>
                        <span class="n">is_coalesced</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></div>



<div class="viewcode-block" id="spspmm">
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.Spspmm.spspmm">[docs]</a>
<span class="k">def</span> <span class="nf">spspmm</span><span class="p">(</span><span class="n">A</span><span class="p">:</span> <span class="n">SparseTensor</span><span class="p">,</span>
           <span class="n">dim1</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
           <span class="n">B</span><span class="p">:</span> <span class="n">SparseTensor</span><span class="p">,</span>
           <span class="n">dim2</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
           <span class="n">aggr</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;sum&quot;</span><span class="p">,</span>
           <span class="n">bcd</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LongTensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
           <span class="n">tar_ind</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LongTensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
           <span class="n">acd</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LongTensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">SparseTensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    SparseTensor SparseTensor matrix multiplication at a specified sparse dimension.</span>

<span class="sd">    This function performs matrix multiplication between two SparseTensors, `A` and `B`, at the specified sparse dimensions `dim1` and `dim2`. The result is a SparseTensor containing the result of the multiplication. The `aggr` parameter specifies the reduction operation used for merging the resulting values.</span>

<span class="sd">    Args:</span>

<span class="sd">    - A (SparseTensor): The first SparseTensor.</span>
<span class="sd">    - dim1 (int): The dimension along which `A` is multiplied.</span>
<span class="sd">    - B (SparseTensor): The second SparseTensor.</span>
<span class="sd">    - dim2 (int): The dimension along which `B` is multiplied.</span>
<span class="sd">    - aggr (str, optional): The reduction operation to use for merging edge features (&quot;sum&quot;, &quot;min&quot;, &quot;max&quot;, &quot;mean&quot;). Defaults to &quot;sum&quot;.</span>
<span class="sd">    - bcd (LongTensor, optional): An optional auxiliary index array produced by spspmm_ind.</span>
<span class="sd">    - tar_ind (LongTensor, optional): An optional target index array for the output. If not provided, it will be computed.</span>
<span class="sd">    - acd (LongTensor, optional): An optional auxiliary index array produced by filterind.</span>

<span class="sd">    Returns:</span>

<span class="sd">    - SparseTensor: A SparseTensor containing the result of the matrix multiplication.</span>

<span class="sd">    Notes:</span>

<span class="sd">    - Both `A` and `B` must be coalesced SparseTensors.</span>
<span class="sd">    - The dense shapes of `A` and `B` must be broadcastable.</span>
<span class="sd">    - This function allows for optional indices `bcd`, `tar_ind`, and `acd` for improved performance and control. If `acd` is provided, `tar_ind` must be provided as well.</span>

<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">assert</span> <span class="n">A</span><span class="o">.</span><span class="n">is_coalesced</span><span class="p">(),</span> <span class="s2">&quot;A should be coalesced&quot;</span>
    <span class="k">assert</span> <span class="n">B</span><span class="o">.</span><span class="n">is_coalesced</span><span class="p">(),</span> <span class="s2">&quot;B should be coalesced&quot;</span>
    <span class="k">if</span> <span class="n">acd</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
        <span class="k">assert</span> <span class="n">tar_ind</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
        <span class="k">if</span> <span class="n">A</span><span class="o">.</span><span class="n">values</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">mult</span> <span class="o">=</span> <span class="n">B</span><span class="o">.</span><span class="n">values</span><span class="p">[</span><span class="n">acd</span><span class="p">[</span><span class="mi">2</span><span class="p">]]</span>
        <span class="k">elif</span> <span class="n">B</span><span class="o">.</span><span class="n">values</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">mult</span> <span class="o">=</span> <span class="n">A</span><span class="o">.</span><span class="n">values</span><span class="p">[</span><span class="n">acd</span><span class="p">[</span><span class="mi">1</span><span class="p">]]</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">mult</span> <span class="o">=</span> <span class="n">A</span><span class="o">.</span><span class="n">values</span><span class="p">[</span><span class="n">acd</span><span class="p">[</span><span class="mi">1</span><span class="p">]]</span> <span class="o">*</span> <span class="n">B</span><span class="o">.</span><span class="n">values</span><span class="p">[</span><span class="n">acd</span><span class="p">[</span><span class="mi">2</span><span class="p">]]</span>
        <span class="n">retval</span> <span class="o">=</span> <span class="n">torch_scatter_reduce</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">mult</span><span class="p">,</span> <span class="n">acd</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">tar_ind</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">aggr</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">SparseTensor</span><span class="p">(</span><span class="n">tar_ind</span><span class="p">,</span>
                            <span class="n">retval</span><span class="p">,</span>
                            <span class="n">shape</span><span class="o">=</span><span class="n">A</span><span class="o">.</span><span class="n">sparseshape</span><span class="p">[:</span><span class="n">dim1</span><span class="p">]</span> <span class="o">+</span>
                            <span class="n">A</span><span class="o">.</span><span class="n">sparseshape</span><span class="p">[</span><span class="n">dim1</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:]</span> <span class="o">+</span> <span class="n">B</span><span class="o">.</span><span class="n">sparseshape</span><span class="p">[:</span><span class="n">dim2</span><span class="p">]</span> <span class="o">+</span>
                            <span class="n">B</span><span class="o">.</span><span class="n">sparseshape</span><span class="p">[</span><span class="n">dim2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:]</span> <span class="o">+</span> <span class="n">retval</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:],</span>
                            <span class="n">is_coalesced</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;acd is not found&quot;</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">bcd</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">ind</span><span class="p">,</span> <span class="n">bcd</span> <span class="o">=</span> <span class="n">spspmm_ind</span><span class="p">(</span><span class="n">A</span><span class="o">.</span><span class="n">indices</span><span class="p">,</span> <span class="n">dim1</span><span class="p">,</span> <span class="n">B</span><span class="o">.</span><span class="n">indices</span><span class="p">,</span> <span class="n">dim2</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">tar_ind</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">acd</span> <span class="o">=</span> <span class="n">filterind</span><span class="p">(</span><span class="n">tar_ind</span><span class="p">,</span> <span class="n">ind</span><span class="p">,</span> <span class="n">bcd</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">spspmm</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">dim1</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">dim2</span><span class="p">,</span> <span class="n">aggr</span><span class="p">,</span> <span class="n">acd</span><span class="o">=</span><span class="n">acd</span><span class="p">,</span> <span class="n">tar_ind</span><span class="o">=</span><span class="n">tar_ind</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;tar_ind is not found&quot;</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">spspmm</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">dim1</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">dim2</span><span class="p">,</span> <span class="n">aggr</span><span class="p">,</span> <span class="n">acd</span><span class="o">=</span><span class="n">bcd</span><span class="p">,</span> <span class="n">tar_ind</span><span class="o">=</span><span class="n">ind</span><span class="p">)</span></div>



<div class="viewcode-block" id="spspmpnn">
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.Spspmm.spspmpnn">[docs]</a>
<span class="k">def</span> <span class="nf">spspmpnn</span><span class="p">(</span><span class="n">A</span><span class="p">:</span> <span class="n">SparseTensor</span><span class="p">,</span>
             <span class="n">dim1</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
             <span class="n">B</span><span class="p">:</span> <span class="n">SparseTensor</span><span class="p">,</span>
             <span class="n">dim2</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
             <span class="n">C</span><span class="p">:</span> <span class="n">SparseTensor</span><span class="p">,</span>
             <span class="n">acd</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
             <span class="n">message_func</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">LongTensor</span><span class="p">],</span>
                                    <span class="n">Tensor</span><span class="p">],</span>
             <span class="n">aggr</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;sum&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">SparseTensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    SparseTensor SparseTensor matrix multiplication at a specified sparse dimension using a message function.</span>

<span class="sd">    This function extends matrix multiplication between two SparseTensors, `A` and `B`, at the specified sparse dimensions `dim1` and `dim2`, by using a message function `message_func` to compute messages from the matched values of `A`, `B`, and `C`. The result is a SparseTensor with the indices of `C` containing the aggregated messages. The `aggr` parameter specifies the reduction operation used for merging the resulting values.</span>

<span class="sd">    Args:</span>

<span class="sd">    - A (SparseTensor): The first SparseTensor.</span>
<span class="sd">    - dim1 (int): The dimension along which `A` is multiplied.</span>
<span class="sd">    - B (SparseTensor): The second SparseTensor.</span>
<span class="sd">    - dim2 (int): The dimension along which `B` is multiplied.</span>
<span class="sd">    - C (SparseTensor): The third SparseTensor.</span>
<span class="sd">    - acd (LongTensor): The auxiliary index array, e.g. as produced by `filterind`. `acd[0]` indexes the entries of `C`, `acd[1]` the entries of `A`, and `acd[2]` the entries of `B`.</span>
<span class="sd">    - message_func (Callable): A callable function that computes the messages between `A`, `B`, and `C`.</span>
<span class="sd">    - aggr (str, optional): The reduction operation to use for merging edge features (&quot;sum&quot;, &quot;min&quot;, &quot;max&quot;, &quot;mul&quot;, &quot;any&quot;). Defaults to &quot;sum&quot;.</span>

<span class="sd">    Returns:</span>

<span class="sd">    - SparseTensor: A SparseTensor containing the result of the matrix multiplication.</span>

<span class="sd">    Notes:</span>
<span class="sd">    </span>
<span class="sd">    - Both `A` and `B` must be coalesced SparseTensors.</span>
<span class="sd">    - The dense shapes of `A`, `B`, and `C` must be broadcastable.</span>
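<span class="sd">    - The `message_func` is called with four arguments: the values of `A`, `B`, and `C` gathered by `acd[1]`, `acd[2]`, and `acd[0]` respectively (each is None when the corresponding tensor has no values), and the index row `acd[0]`. It should return the messages to be aggregated.</span>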

<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">mult</span> <span class="o">=</span> <span class="n">message_func</span><span class="p">(</span><span class="kc">None</span> <span class="k">if</span> <span class="n">A</span><span class="o">.</span><span class="n">values</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">A</span><span class="o">.</span><span class="n">values</span><span class="p">[</span><span class="n">acd</span><span class="p">[</span><span class="mi">1</span><span class="p">]],</span>
                        <span class="kc">None</span> <span class="k">if</span> <span class="n">B</span><span class="o">.</span><span class="n">values</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">B</span><span class="o">.</span><span class="n">values</span><span class="p">[</span><span class="n">acd</span><span class="p">[</span><span class="mi">2</span><span class="p">]],</span>
                        <span class="kc">None</span> <span class="k">if</span> <span class="n">C</span><span class="o">.</span><span class="n">values</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">C</span><span class="o">.</span><span class="n">values</span><span class="p">[</span><span class="n">acd</span><span class="p">[</span><span class="mi">0</span><span class="p">]],</span> <span class="n">acd</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    <span class="n">tar_ind</span> <span class="o">=</span> <span class="n">C</span><span class="o">.</span><span class="n">indices</span>
    <span class="n">retval</span> <span class="o">=</span> <span class="n">torch_scatter_reduce</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">mult</span><span class="p">,</span> <span class="n">acd</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">tar_ind</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">aggr</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">SparseTensor</span><span class="p">(</span><span class="n">tar_ind</span><span class="p">,</span>
                        <span class="n">retval</span><span class="p">,</span>
                        <span class="n">shape</span><span class="o">=</span><span class="n">A</span><span class="o">.</span><span class="n">sparseshape</span><span class="p">[:</span><span class="n">dim1</span><span class="p">]</span> <span class="o">+</span> <span class="n">A</span><span class="o">.</span><span class="n">sparseshape</span><span class="p">[</span><span class="n">dim1</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:]</span> <span class="o">+</span>
                        <span class="n">B</span><span class="o">.</span><span class="n">sparseshape</span><span class="p">[:</span><span class="n">dim2</span><span class="p">]</span> <span class="o">+</span> <span class="n">B</span><span class="o">.</span><span class="n">sparseshape</span><span class="p">[</span><span class="n">dim2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:]</span> <span class="o">+</span>
                        <span class="n">retval</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:],</span>
                        <span class="n">is_coalesced</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></div>

</pre></div>

           </div>
          </div>
          <footer>

  <hr/>

  <div role="contentinfo">
    <p>&#169; Copyright 2023.</p>
  </div>

  Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    provided by <a href="https://readthedocs.org">Read the Docs</a>.
   

</footer>
        </div>
      </div>
    </section>
  </div>
  <script>
      jQuery(function () {
          SphinxRtdTheme.Navigation.enable(true);
      });
  </script> 

</body>
</html>