<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>pygho.backend.MaTensor — PyGHO 0.0.1 documentation</title>
  <!-- Pygments stylesheet for the highlighted listing, then the theme stylesheet. -->
  <link rel="stylesheet" href="../../../_static/pygments.css">
  <link rel="stylesheet" href="../../../_static/css/theme.css">
  <!--[if lt IE 9]>
    <script src="../../../_static/js/html5shiv.min.js"></script>
  <![endif]-->
  <!-- Keep this order: the later scripts presumably rely on jQuery being loaded first. -->
  <script src="../../../_static/jquery.js?v=5d32c60e"></script>
  <script src="../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
  <script src="../../../_static/documentation_options.js?v=d45e8c67"></script>
  <script src="../../../_static/doctools.js?v=888ff710"></script>
  <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
  <script src="../../../_static/js/theme.js"></script>
  <link rel="index" title="Index" href="../../../genindex.html">
  <link rel="search" title="Search" href="../../../search.html">
</head>

<body class="wy-body-for-nav"> 
  <div class="wy-grid-for-nav">
    <nav class="wy-nav-side" data-toggle="wy-nav-shift">
      <!-- Sidebar: project home link, docs search form (GET to search.html), table of contents. -->
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search">
          <a class="icon icon-home" href="../../../index.html">
            PyGHO
          </a>
          <div role="search">
            <form class="wy-form" id="rtd-search-form" action="../../../search.html" method="get">
              <input name="q" type="text" placeholder="Search docs" aria-label="Search docs">
              <input name="check_keywords" type="hidden" value="yes">
              <input name="area" type="hidden" value="default">
            </form>
          </div>
        </div>
        <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
          <p class="caption" role="heading"><span class="caption-text">Notes</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/installation.html">Installation</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/miniexample.html">Minimal Example</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/datastructure.html">Refined Basic Data Structure</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/hodata.html">Efficient High Order Data Processing</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/operator.html">Operators</a></li>
          </ul>
          <p class="caption" role="heading"><span class="caption-text">Advanced Tutorial</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../notes/multtensor.html">Multiple Tensor</a></li>
          </ul>
          <p class="caption" role="heading"><span class="caption-text">Package Reference</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../modules/backend.html">pygho.backend package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../modules/hodata.html">pygho.hodata package</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../modules/honn.html">pygho.honn package</a></li>
          </ul>
        </div>
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
          <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
          <a href="../../../index.html">PyGHO</a>
      </nav>

      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="navigation" aria-label="Page navigation">
            <!-- Breadcrumb trail: docs home, then Module code, then this page. -->
            <ul class="wy-breadcrumbs">
              <li><a class="icon icon-home" href="../../../index.html" aria-label="Home"></a></li>
              <li class="breadcrumb-item"><a href="../../index.html">Module code</a></li>
              <!-- aria-current tells assistive tech that this last crumb is the page being viewed. -->
              <li class="breadcrumb-item active" aria-current="page">pygho.backend.MaTensor</li>
              <li class="wy-breadcrumbs-aside">
              </li>
            </ul>
            <hr>
          </div>
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
             
  <h1>Source code for pygho.backend.MaTensor</h1><div class="highlight"><pre>
<span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">BoolTensor</span><span class="p">,</span> <span class="n">LongTensor</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Iterable</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
<span class="c1"># merge torch.nested or torch.masked API in the long run.</span>


<div class="viewcode-block" id="filterinf"><!--
  Review note: this is a rendered listing of the Python source (it appears to
  be generated by Sphinx viewcode). Keep the highlighted tokens verbatim; make
  changes in the Python source and rebuild the docs instead.
  filterinf(X, filled_value) returns a new tensor (torch.where) in which every
  entry equal to +inf or -inf is replaced by filled_value; other entries are
  copied from X. MaskedTensor.max and MaskedTensor.min call it with
  filled_value 0 to clear the infinite padding left by their reductions.
  NOTE(review): torch.isinf(X) would be an equivalent, simpler mask upstream.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.MaTensor.filterinf">[docs]</a>
<span class="k">def</span> <span class="nf">filterinf</span><span class="p">(</span><span class="n">X</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">filled_value</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Replaces positive and negative infinity values in a tensor with a specified value.</span>

<span class="sd">    Args:</span>

<span class="sd">    - X (Tensor): The input tensor.</span>
<span class="sd">    - filled_value (float, optional): The value to replace positive and negative</span>
<span class="sd">      infinity values with (default: 0).</span>

<span class="sd">    Returns:</span>

<span class="sd">    - Tensor: A tensor with positive and negative infinity values replaced by the</span>
<span class="sd">      specified `filled_value`.</span>

<span class="sd">    Example:</span>
<span class="sd">    </span>
<span class="sd">    ::</span>
<span class="sd">    </span>
<span class="sd">        input_tensor = torch.tensor([1.0, 2.0, torch.inf, -torch.inf, 3.0])</span>
<span class="sd">        result = filterinf(input_tensor, filled_value=999.0)</span>

<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">logical_or</span><span class="p">(</span><span class="n">X</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span> <span class="n">X</span> <span class="o">==</span> <span class="o">-</span><span class="n">torch</span><span class="o">.</span><span class="n">inf</span><span class="p">),</span>
                       <span class="n">filled_value</span><span class="p">,</span> <span class="n">X</span><span class="p">)</span></div>



<div class="viewcode-block" id="MaskedTensor">
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.MaTensor.MaskedTensor">[docs]</a>
<span class="k">class</span> <span class="nc">MaskedTensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Represents a masked tensor with optional padding values.</span>
<span class="sd">    This class allows you to work with tensors that have a mask indicating valid and</span>
<span class="sd">    invalid values. You can perform various operations on the masked tensor, such as</span>
<span class="sd">    filling masked values, computing sums, means, maximums, minimums, and more.</span>

<span class="sd">    Parameters:</span>
<span class="sd">    </span>
<span class="sd">    - data (Tensor): The underlying data tensor of shape (\*maskedshape, \*denseshape)</span>
<span class="sd">    - mask (BoolTensor): The mask tensor of shape (\*maskedshape) </span>
<span class="sd">      where `True` represents valid values, and False` represents invalid values.</span>
<span class="sd">    - padvalue (float, optional): The value to use for padding. Defaults to 0.</span>
<span class="sd">    - is_filled (bool, optional): Indicates whether the invalid values have already</span>
<span class="sd">      been filled to the padvalue. Defaults to False.</span>

<span class="sd">    Attributes:</span>
<span class="sd">    </span>
<span class="sd">    - data (Tensor): The underlying data tensor.</span>
<span class="sd">    - mask (BoolTensor): The mask tensor.</span>
<span class="sd">    - fullmask (BoolTensor): The mask tensor after broadcasting to match the data&#39;s</span>
<span class="sd">      dimensions.</span>
<span class="sd">    - padvalue (float): The padding value.</span>
<span class="sd">    - shape (torch.Size): The shape of the data tensor.</span>
<span class="sd">    - masked_dim (int): The number of dimensions in maskedshape.</span>
<span class="sd">    - dense_dim (int): The number of dimensions in denseshape.</span>
<span class="sd">    - maskedshape (torch.Size): The shape of the tensor up to the masked dimensions.</span>
<span class="sd">    - denseshape (torch.Size): The shape of the tensor after the masked dimensions.</span>

<span class="sd">    Methods:</span>
<span class="sd">    </span>
<span class="sd">    - fill_masked_(self, val: float = 0) -&gt; None: In-place fill of masked values.</span>
<span class="sd">    - fill_masked(self, val: float = 0) -&gt; Tensor: Return a tensor with masked values</span>
<span class="sd">      filled with the specified value.</span>
<span class="sd">    - to(self, device: torch.DeviceObjType, non_blocking: bool = True): Move the</span>
<span class="sd">      tensor to the specified device.</span>
<span class="sd">    - sum(self, dims: Union[Iterable[int], int], keepdim: bool = False): Compute the</span>
<span class="sd">      sum of masked values along specified dimensions.</span>
<span class="sd">    - mean(self, dims: Union[Iterable[int], int], keepdim: bool = False): Compute</span>
<span class="sd">      the mean of masked values along specified dimensions.</span>
<span class="sd">    - max(self, dims: Union[Iterable[int], int], keepdim: bool = False): Compute the</span>
<span class="sd">      maximum of masked values along specified dimensions.</span>
<span class="sd">    - min(self, dims: Union[Iterable[int], int], keepdim: bool = False): Compute the</span>
<span class="sd">      minimum of masked values along specified dimensions.</span>
<span class="sd">    - diag(self, dims: Iterable[int]): Extract diagonals from the tensor. </span>
<span class="sd">      The dimensions in dims will be take diagonal and put at dims[0]</span>
<span class="sd">    - unpooling(self, dims: Union[int, Iterable[int]], tarX): Perform unpooling</span>
<span class="sd">      operation along specified dimensions.</span>
<span class="sd">    - tuplewiseapply(self, func: Callable[[Tensor], Tensor]): Apply a function to</span>
<span class="sd">      each element of the masked tensor.</span>
<span class="sd">    - diagonalapply(self, func: Callable[[Tensor, LongTensor], Tensor]): Apply a</span>
<span class="sd">      function to diagonal elements of the masked tensor.</span>
<span class="sd">    - add(self, tarX, samesparse: bool): Add two masked tensors together.</span>
<span class="sd">    - catvalue(self, tarX, samesparse: bool): Concatenate values of two masked</span>
<span class="sd">      tensors.</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <!--
    Review note (rendered listing; edit the Python source, not these tokens).
    __init__ asserts that data has at least as many dims as mask and that
    data.shape[:mask.ndim] == mask.shape. The assertion message says
    "larger", but the check is data.ndim >= mask.ndim, so equal ndim is allowed.
    It stores data and mask, records masked_dim = mask.ndim, and builds
    fullmask by unsqueezing mask on the right until it has data.ndim dims.
    Padding: when is_filled is False, __padvalue is first set to a sentinel
    (+inf, or -inf when padvalue is +inf). The sentinel always differs from
    padvalue, so fill_masked_(padvalue) skips its early return and writes
    padvalue into the invalid positions. When is_filled is True, the caller
    vouches that the invalid positions already hold padvalue.
    --><span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                 <span class="n">data</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
                 <span class="n">mask</span><span class="p">:</span> <span class="n">BoolTensor</span><span class="p">,</span>
                 <span class="n">padvalue</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
                 <span class="n">is_filled</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
        <span class="c1"># mask: True for valid value, False for invalid value</span>
        <span class="k">assert</span> <span class="n">data</span><span class="o">.</span><span class="n">ndim</span> <span class="o">&gt;=</span> <span class="n">mask</span><span class="o">.</span><span class="n">ndim</span><span class="p">,</span> <span class="s2">&quot;data&#39;s #dim should be larger than mask &quot;</span>
        <span class="k">assert</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="n">mask</span><span class="o">.</span>
                          <span class="n">ndim</span><span class="p">]</span> <span class="o">==</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="s2">&quot;data and mask&#39;s first dimensions should match&quot;</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">__data</span> <span class="o">=</span> <span class="n">data</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">__mask</span> <span class="o">=</span> <span class="n">mask</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">__masked_dim</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">ndim</span>
        <span class="k">while</span> <span class="n">mask</span><span class="o">.</span><span class="n">ndim</span> <span class="o">&lt;</span> <span class="n">data</span><span class="o">.</span><span class="n">ndim</span><span class="p">:</span>
            <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">__fullmask</span> <span class="o">=</span> <span class="n">mask</span>
        <span class="k">if</span> <span class="ow">not</span> <span class="n">is_filled</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">__padvalue</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">inf</span> <span class="k">if</span> <span class="n">padvalue</span> <span class="o">!=</span> <span class="n">torch</span><span class="o">.</span><span class="n">inf</span> <span class="k">else</span> <span class="o">-</span><span class="n">torch</span><span class="o">.</span><span class="n">inf</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">fill_masked_</span><span class="p">(</span><span class="n">padvalue</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">__padvalue</span> <span class="o">=</span> <span class="n">padvalue</span>

<div class="viewcode-block" id="MaskedTensor.fill_masked_"><!--
  Review note (rendered listing; edit the Python source, not these tokens).
  fill_masked_(val) is "in place" at the object level only. It returns early
  when padvalue already equals val. Otherwise it records val as the new
  padvalue and rebinds self.__data to torch.where(fullmask, data, val), which
  is a new tensor; the previous data tensor's storage is not written to.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.MaTensor.MaskedTensor.fill_masked_">[docs]</a>
    <span class="k">def</span> <span class="nf">fill_masked_</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        inplace fill the masked values</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">padvalue</span> <span class="o">==</span> <span class="n">val</span><span class="p">:</span>
            <span class="k">return</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">__padvalue</span> <span class="o">=</span> <span class="n">val</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">__data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fullmask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">val</span><span class="p">)</span></div>


<div class="viewcode-block" id="MaskedTensor.fill_masked"><!--
  Review note (rendered listing; edit the Python source, not these tokens).
  fill_masked(val) leaves the object unchanged. When padvalue already equals
  val it returns self.data itself, not a copy, so a caller that mutates the
  result also mutates this MaskedTensor. Otherwise it returns a new tensor
  from torch.where(fullmask, data, val).
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.MaTensor.MaskedTensor.fill_masked">[docs]</a>
    <span class="k">def</span> <span class="nf">fill_masked</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        return a tensor with masked values filled with val.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">__padvalue</span> <span class="o">==</span> <span class="n">val</span><span class="p">:</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span>
        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fullmask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">val</span><span class="p">)</span></div>


<div class="viewcode-block" id="MaskedTensor.to"><!--
  Review note (rendered listing; edit the Python source, not these tokens).
  to(device, non_blocking=True) moves data, mask and fullmask with
  Tensor.to(device, non_blocking=non_blocking), rebinds them on this object
  and returns self. It therefore mutates the receiver; padvalue and
  masked_dim are left unchanged. Note that non_blocking defaults to True.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.MaTensor.MaskedTensor.to">[docs]</a>
    <span class="k">def</span> <span class="nf">to</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">DeviceObjType</span><span class="p">,</span> <span class="n">non_blocking</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        move data to some device</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">__data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="n">non_blocking</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">__mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__mask</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="n">non_blocking</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">__fullmask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__fullmask</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="n">non_blocking</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span></div>


    <!--
    Review note (rendered listing; edit the Python source, not these tokens).
    Read-only accessors; no setters are defined in this listing.
    padvalue, data, mask and fullmask return the private attributes assigned in
    __init__, fill_masked_ and to. shape is data.shape. masked_dim is mask.ndim
    as recorded in __init__. maskedshape is shape[:masked_dim], denseshape is
    shape[masked_dim:], and dense_dim is len(denseshape).
    --><span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">padvalue</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">__padvalue</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">data</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">__data</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">mask</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">BoolTensor</span><span class="p">:</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">__mask</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">fullmask</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">BoolTensor</span><span class="p">:</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">__fullmask</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Size</span><span class="p">:</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">__data</span><span class="o">.</span><span class="n">shape</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">masked_dim</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">__masked_dim</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">dense_dim</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">denseshape</span><span class="p">)</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">maskedshape</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">masked_dim</span><span class="p">]</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">denseshape</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">masked_dim</span><span class="p">:]</span>

<div class="viewcode-block" id="MaskedTensor.sum"><!--
  Review note (rendered listing; edit the Python source, not these tokens).
  sum(dims, keepdim) fills invalid entries with 0 and sums data over dims with
  torch.sum. It reduces mask over the same dims with torch.amax, which for a
  bool mask marks an output position valid if any input along dims was valid.
  It returns a new MaskedTensor with padvalue 0 and is_filled=True.
  NOTE(review): dims is applied to both data and mask, so it presumably may
  only name masked dimensions (indices below masked_dim); confirm with callers.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.MaTensor.MaskedTensor.sum">[docs]</a>
    <span class="k">def</span> <span class="nf">sum</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">],</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">MaskedTensor</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fill_masked</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
                                      <span class="n">dim</span><span class="o">=</span><span class="n">dims</span><span class="p">,</span>
                                      <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">),</span>
                            <span class="n">torch</span><span class="o">.</span><span class="n">amax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">),</span>
                            <span class="n">padvalue</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
                            <span class="n">is_filled</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></div>


<div class="viewcode-block" id="MaskedTensor.mean"><!--
  Review note (rendered listing; edit the Python source, not these tokens).
  mean(dims, keepdim) divides the masked sum by the number of valid entries
  along dims. The count sums fullmask over dims and is clamped to at least 1,
  so a fully masked slice yields 0 / 1 = 0 instead of dividing by zero.
  The result reuses the mask and padvalue (0) of the intermediate sum.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.MaTensor.MaskedTensor.mean">[docs]</a>
    <span class="k">def</span> <span class="nf">mean</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">],</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
        <span class="n">count</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp_min_</span><span class="p">(</span>
            <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fullmask</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dims</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>
        <span class="n">valsum</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">MaskedTensor</span><span class="p">(</span><span class="n">valsum</span><span class="o">.</span><span class="n">data</span> <span class="o">/</span> <span class="n">count</span><span class="p">,</span>
                            <span class="n">valsum</span><span class="o">.</span><span class="n">mask</span><span class="p">,</span>
                            <span class="n">padvalue</span><span class="o">=</span><span class="n">valsum</span><span class="o">.</span><span class="n">padvalue</span><span class="p">,</span>
                            <span class="n">is_filled</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></div>


<div class="viewcode-block" id="MaskedTensor.max"><!--
  Review note (rendered listing; edit the Python source, not these tokens).
  max(dims, keepdim) fills invalid entries with -inf and takes torch.amax over
  dims. filterinf(..., 0) then turns slices with no valid entry (result -inf)
  into 0. The mask is reduced with torch.amax as in sum, and the result has
  padvalue 0 and is_filled=True.
  NOTE(review): filterinf replaces every +inf and -inf, so a genuine infinite
  value among the valid entries is also reported as 0.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.MaTensor.MaskedTensor.max">[docs]</a>
    <span class="k">def</span> <span class="nf">max</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">],</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
        <span class="n">tmp</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fill_masked</span><span class="p">(</span><span class="o">-</span><span class="n">torch</span><span class="o">.</span><span class="n">inf</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">MaskedTensor</span><span class="p">(</span><span class="n">filterinf</span><span class="p">(</span>
            <span class="n">torch</span><span class="o">.</span><span class="n">amax</span><span class="p">(</span><span class="n">tmp</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dims</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">),</span> <span class="mi">0</span><span class="p">),</span>
                            <span class="n">torch</span><span class="o">.</span><span class="n">amax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">),</span>
                            <span class="n">padvalue</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
                            <span class="n">is_filled</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></div>


<div class="viewcode-block" id="MaskedTensor.min"><!--
  BUG(review): min fills invalid entries with +inf but then reduces the data
  with torch.amax (first argument of filterinf in the return statement), not
  torch.amin. As written, any slice containing an invalid entry yields +inf,
  which filterinf turns into 0, and a fully valid slice returns its maximum
  instead of its minimum. Fix it in the Python source by replacing
  torch.amax(tmp, ...) with torch.amin(tmp, ...), then rebuild the docs. This
  listing is left verbatim so that it keeps matching the shipped code.
  The mask reduction (torch.amax over the bool mask) is correct as is.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.MaTensor.MaskedTensor.min">[docs]</a>
    <span class="k">def</span> <span class="nf">min</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">],</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
        <span class="n">tmp</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fill_masked</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">inf</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">MaskedTensor</span><span class="p">(</span><span class="n">filterinf</span><span class="p">(</span>
            <span class="n">torch</span><span class="o">.</span><span class="n">amax</span><span class="p">(</span><span class="n">tmp</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dims</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">),</span> <span class="mi">0</span><span class="p">),</span>
                            <span class="n">torch</span><span class="o">.</span><span class="n">amax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">),</span>
                            <span class="n">padvalue</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
                            <span class="n">is_filled</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></div>


<div class="viewcode-block" id="MaskedTensor.diag">
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.MaTensor.MaskedTensor.diag">[docs]</a>
    <span class="k">def</span> <span class="nf">diag</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">]):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Take the joint diagonal of data and mask over all of ``dims``</span>
<span class="sd">        (at least two) and put the reduced output to dims[0], the smallest</span>
<span class="sd">        of the given dims. The result keeps ``self.padvalue``.</span>
<span class="sd">        NOTE(review): dims are assumed to be non-negative indices.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="c1"># sort before len() so one-shot iterables (generators) also work</span>
        <span class="n">dims</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span>
        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">&quot;must diag several dims&quot;</span>
        <span class="n">tdata</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span>
        <span class="n">tmask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span>
        <span class="c1"># torch.diagonal drops both dims and appends the diagonal as the last dim</span>
        <span class="n">tdata</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">diagonal</span><span class="p">(</span><span class="n">tdata</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">dims</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dims</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
        <span class="n">tmask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">diagonal</span><span class="p">(</span><span class="n">tmask</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">dims</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dims</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">dims</span><span class="p">)):</span>
            <span class="c1"># the i dims consumed so far all precede dims[i] (dims is sorted),</span>
            <span class="c1"># so that original dim now sits at index dims[i] - i; the running</span>
            <span class="c1"># diagonal is always the last dim (-1).</span>
            <span class="n">tdata</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">diagonal</span><span class="p">(</span><span class="n">tdata</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">dims</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="n">i</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
            <span class="n">tmask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">diagonal</span><span class="p">(</span><span class="n">tmask</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">dims</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="n">i</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">tdata</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">movedim</span><span class="p">(</span><span class="n">tdata</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">dims</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
        <span class="n">tmask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">movedim</span><span class="p">(</span><span class="n">tmask</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">dims</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
        <span class="k">return</span> <span class="n">MaskedTensor</span><span class="p">(</span><span class="n">tdata</span><span class="p">,</span> <span class="n">tmask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">padvalue</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span></div>


<div class="viewcode-block" id="MaskedTensor.unpooling"><!--
  unpooling(dims, tarX): broadcast this tensor back over the axes a reduction
  removed. dims (int or iterable) is sorted and a size-1 axis is inserted at
  each of them in ascending order, so every index refers to the final layout.
  expand() then stretches those axes to tarX.shape (a view, no copy) while all
  other axes keep their current size (-1).
  Returns a MaskedTensor with tarX.mask and self.padvalue; the 4th constructor
  argument is False (presumably is_filled, i.e. data not claimed to be filled).
  NOTE(review): assumes tarX matches self.data on every non-pooled axis.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.MaTensor.MaskedTensor.unpooling">[docs]</a>
    <span class="k">def</span> <span class="nf">unpooling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span> <span class="n">tarX</span><span class="p">):</span>
        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
            <span class="n">dims</span> <span class="o">=</span> <span class="p">[</span><span class="n">dims</span><span class="p">]</span>
        <span class="n">dims</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">dims</span><span class="p">))</span>
        <span class="n">tdata</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span>
        <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">dims</span><span class="p">:</span>
            <span class="n">tdata</span> <span class="o">=</span> <span class="n">tdata</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">_</span><span class="p">)</span>
        <span class="n">tdata</span> <span class="o">=</span> <span class="n">tdata</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="o">*</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">dims</span> <span class="k">else</span> <span class="n">tarX</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
                               <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tdata</span><span class="o">.</span><span class="n">ndim</span><span class="p">)))</span>
        <span class="k">return</span> <span class="n">MaskedTensor</span><span class="p">(</span><span class="n">tdata</span><span class="p">,</span> <span class="n">tarX</span><span class="o">.</span><span class="n">mask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">padvalue</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span></div>


<div class="viewcode-block" id="MaskedTensor.tuplewiseapply"><!--
  tuplewiseapply(func): apply func to the data with entries outside the mask
  filled with 0 (fill_masked(0)) and wrap the output with the unchanged
  self.mask. padvalue is not forwarded, so the result gets whatever default
  the MaskedTensor constructor uses. The in-code comment warns this may still
  cause NaN gradients that keep amp from updating.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.MaTensor.MaskedTensor.tuplewiseapply">[docs]</a>
    <span class="k">def</span> <span class="nf">tuplewiseapply</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">func</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">Tensor</span><span class="p">]):</span>
        <span class="c1"># it may cause nan in gradient and makes amp unable to update</span>
        <span class="n">ndata</span> <span class="o">=</span> <span class="n">func</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fill_masked</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">MaskedTensor</span><span class="p">(</span><span class="n">ndata</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">)</span></div>


<div class="viewcode-block" id="MaskedTensor.diagonalapply"><!--
  diagonalapply(func): only for masked_dim == 3 (assert message: "only
  implemented for 2D"). Builds a long tensor from torch.eye(shape[1], shape[2])
  (1 on the diagonal, 0 elsewhere), broadcasts it to self.mask's shape and
  calls func(self.data, diagonaltype); the output is wrapped with self.mask.
  Unlike tuplewiseapply, func receives the raw data rather than a zero-filled
  copy, and padvalue is not forwarded (constructor default applies).
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.MaTensor.MaskedTensor.diagonalapply">[docs]</a>
    <span class="k">def</span> <span class="nf">diagonalapply</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">func</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">LongTensor</span><span class="p">],</span> <span class="n">Tensor</span><span class="p">]):</span>
        <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">masked_dim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">,</span> <span class="s2">&quot;only implemented for 2D&quot;</span>
        <span class="n">diagonaltype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span>
                                 <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span>
                                 <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">,</span>
                                 <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
        <span class="n">diagonaltype</span> <span class="o">=</span> <span class="n">diagonaltype</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand_as</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">)</span>
        <span class="n">ndata</span> <span class="o">=</span> <span class="n">func</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">diagonaltype</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">MaskedTensor</span><span class="p">(</span><span class="n">ndata</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">)</span></div>


<div class="viewcode-block" id="MaskedTensor.add"><!--
  add(tarX, samesparse): elementwise sum of two MaskedTensors.
  samesparse=True: presumably the caller guarantees tarX shares self's mask.
  Raw data is added and the result keeps self.mask and self.padvalue, with
  is_filled set only when both padvalues are equal.
  NOTE(review): that flag is exact only when p + p == p for the shared pad p
  (0, +inf, -inf) and both inputs were themselves filled; confirm.
  samesparse=False: masked entries of both are treated as 0, the new mask is
  the union (logical_or) of both masks, padvalue 0, passed as already filled.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.MaTensor.MaskedTensor.add">[docs]</a>
    <span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tarX</span><span class="p">,</span> <span class="n">samesparse</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">samesparse</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">MaskedTensor</span><span class="p">(</span><span class="n">tarX</span><span class="o">.</span><span class="n">data</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">,</span>
                                <span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">,</span>
                                <span class="bp">self</span><span class="o">.</span><span class="n">padvalue</span><span class="p">,</span>
                                <span class="n">is_filled</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">padvalue</span> <span class="o">==</span> <span class="n">tarX</span><span class="o">.</span><span class="n">padvalue</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">MaskedTensor</span><span class="p">(</span>
                <span class="n">tarX</span><span class="o">.</span><span class="n">fill_masked</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">fill_masked</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
                <span class="n">torch</span><span class="o">.</span><span class="n">logical_or</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">,</span> <span class="n">tarX</span><span class="o">.</span><span class="n">mask</span><span class="p">),</span> <span class="mi">0</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span></div>


<div class="viewcode-block" id="MaskedTensor.catvalue"><!--
  catvalue(tarX, samesparse): concatenate values along the last dim (dim=-1).
  samesparse must be True. tarX is a MaskedTensor or an iterable of objects
  with a .data attribute; anything else raises NotImplementedError.
  The result keeps self.mask (via tuplewiseapply). Note the lambdas ignore
  the zero-filled tensor tuplewiseapply passes in and concatenate the raw
  self.data, so masked entries are not zeroed here.
-->
<a class="viewcode-back" href="../../../modules/backend.html#pygho.backend.MaTensor.MaskedTensor.catvalue">[docs]</a>
    <span class="k">def</span> <span class="nf">catvalue</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tarX</span><span class="p">,</span> <span class="n">samesparse</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
        <span class="k">assert</span> <span class="n">samesparse</span> <span class="o">==</span> <span class="kc">True</span><span class="p">,</span> <span class="s2">&quot;must have the same sparcity to concat value&quot;</span>
        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">tarX</span><span class="p">,</span> <span class="n">MaskedTensor</span><span class="p">):</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">tuplewiseapply</span><span class="p">(</span><span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span>
                <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">tarX</span><span class="o">.</span><span class="n">data</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span>
        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">tarX</span><span class="p">,</span> <span class="n">Iterable</span><span class="p">):</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">tuplewiseapply</span><span class="p">(</span><span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span>
                <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="n">_</span><span class="o">.</span><span class="n">data</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">tarX</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">NotImplementedError</span></div>
</div>

</pre></div>

           </div>
          </div>
          <footer>
            <!-- Page footer: copyright notice followed by the generator/theme credit. -->
            <hr>
            <div role="contentinfo">
              <p>© Copyright 2023.</p>
            </div>
            Built with <a href="https://www.sphinx-doc.org/">Sphinx</a>
            using a <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
            provided by <a href="https://readthedocs.org">Read the Docs</a>.
          </footer>
        </div>
      </div>
    </section>
  </div>
  <script>
    // Turn on sphinx_rtd_theme's navigation behaviour once the DOM is ready.
    // jQuery(document).ready(fn) is the long form of jQuery(fn); kept as ES5
    // because the page still ships IE shims.
    jQuery(document).ready(function () {
      SphinxRtdTheme.Navigation.enable(true);
    });
  </script>

</body>
</html>