<!DOCTYPE html>

<html lang="en" data-content_root="../../">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>chicken.model — contilearn documentation</title>
    <link rel="stylesheet" type="text/css" href="../../_static/pygments.css?v=5ecbeea2" />
    <link rel="stylesheet" type="text/css" href="../../_static/basic.css?v=686e5160" />
    <link rel="stylesheet" type="text/css" href="../../_static/alabaster.css?v=27fed22d" />
    <script src="../../_static/documentation_options.js?v=5929fcd5"></script>
    <script src="../../_static/doctools.js?v=9bcbadda"></script>
    <script src="../../_static/sphinx_highlight.js?v=dc90522c"></script>
    <link rel="index" title="Index" href="../../genindex.html" />
    <link rel="search" title="Search" href="../../search.html" />
   
  <link rel="stylesheet" href="../../_static/custom.css" type="text/css" />
  

  
  

  </head><body>
  

    <div class="document">
      <div class="documentwrapper">
        <div class="bodywrapper">
          

          <div class="body" role="main">
            
  <h1>Source code for chicken.model</h1><div class="highlight"><pre>
<span></span><span class="kn">import</span><span class="w"> </span><span class="nn">types</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">nn</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">copy</span><span class="o">,</span><span class="w"> </span><span class="nn">os</span><span class="o">,</span><span class="w"> </span><span class="nn">random</span>

<span class="kn">from</span><span class="w"> </span><span class="nn">typing</span><span class="w"> </span><span class="kn">import</span> <span class="n">List</span>

<div class="viewcode-block" id="Chicken">
<a class="viewcode-back" href="../../index.html#chicken.model.Chicken">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">Chicken</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
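<span class="sd">    An incremental-learning wrapper around an existing torch.nn.Module that adds learnable masks over the singular values of its decomposed weight matrices.</span>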
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="c1"># optional cache so we don&#39;t recreate the same subclass over and over</span>
    <span class="n">_cls_cache</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">type</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">],</span> <span class="nb">type</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">]]</span> <span class="o">=</span> <span class="p">{}</span>

    <span class="c1"># -------- object construction --------</span>
    <span class="k">def</span><span class="w"> </span><span class="fm">__new__</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="k">if</span> <span class="bp">cls</span> <span class="ow">is</span> <span class="n">Chicken</span><span class="p">:</span>  <span class="c1"># only when user calls Chicken(…)</span>
            <span class="n">base</span> <span class="o">=</span> <span class="nb">type</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>

            <span class="c1"># reuse cached subclass if it exists</span>
            <span class="n">Wrapped</span> <span class="o">=</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_cls_cache</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">base</span><span class="p">)</span>
            <span class="k">if</span> <span class="n">Wrapped</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
                <span class="n">Wrapped</span> <span class="o">=</span> <span class="n">types</span><span class="o">.</span><span class="n">new_class</span><span class="p">(</span>
                    <span class="sa">f</span><span class="s2">&quot;Conti</span><span class="si">{</span><span class="n">base</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>        <span class="c1"># e.g. ContiVisionTransformer</span>
                    <span class="p">(</span><span class="n">Chicken</span><span class="p">,</span> <span class="n">base</span><span class="p">),</span>               <span class="c1"># MRO: Chicken → base model</span>
                    <span class="p">{}</span>
                <span class="p">)</span>
                <span class="bp">cls</span><span class="o">.</span><span class="n">_cls_cache</span><span class="p">[</span><span class="n">base</span><span class="p">]</span> <span class="o">=</span> <span class="n">Wrapped</span>

            <span class="c1"># allocate instance of the *new* subclass</span>
            <span class="n">inst</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__new__</span><span class="p">(</span><span class="n">Wrapped</span><span class="p">)</span>
            <span class="c1"># copy every weight / buffer / attribute</span>
            <span class="n">inst</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">inst</span>

        <span class="c1"># if somebody subclasses Chicken explicitly, honour normal behaviour</span>
        <span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__new__</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span>

<div class="viewcode-block" id="Chicken.__init__">
<a class="viewcode-back" href="../../index.html#chicken.model.Chicken.__init__">[docs]</a>
    <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">model</span><span class="p">,</span>
        <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span><span class="p">,</span>
        <span class="n">init_val</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
        <span class="n">max_mult</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span>
        <span class="n">matching_texts</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span><span class="o">=</span><span class="p">(</span><span class="s2">&quot;layernorm&quot;</span><span class="p">,</span> <span class="s2">&quot;bias&quot;</span><span class="p">,</span> <span class="s2">&quot;embeddings&quot;</span><span class="p">,</span> <span class="s2">&quot;layrnorm&quot;</span><span class="p">,</span> <span class="s2">&quot;layer_norm&quot;</span><span class="p">),</span>
        <span class="n">rank</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>  <span class="c1"># optional truncation</span>
    <span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
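<span class="sd">        model: torch.nn.Module, required</span>
<span class="sd">            The model to wrap; its state_dict() is deep-copied as the base parameters.</span>
<span class="sd">        device: string, optional</span>
<span class="sd">            Device on which the SVD factors and the mask parameters are created (default cpu).</span>
<span class="sd">        init_val: float, optional</span>
<span class="sd">            Scale of the initial mask values: each mask is initialised as init_val * N(0, 1) (default 0.1).</span>
<span class="sd">        max_mult: float, optional</span>
<span class="sd">            Maximum possible value the mask can take [0,max_mult] (default 1.0).</span>
<span class="sd">        matching_texts: List[str], optional</span>
<span class="sd">            Substrings of parameter names to exclude from decomposition and masking; any state-dict key containing one of them is skipped (default (&quot;layernorm&quot;, &quot;bias&quot;, &quot;embeddings&quot;, &quot;layrnorm&quot;, &quot;layer_norm&quot;)).</span>
<span class="sd">        rank: optional</span>
<span class="sd">            Intended as an optional truncation rank. It is stored as self.rank and forwarded to decompose, which currently ignores it, so the full SVD is always computed (default None).</span>
<span class="sd">        </span>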
<span class="sd">        Examples</span>
<span class="sd">        --------</span>
<span class="sd">        &gt;&gt;&gt; from transformers import ViTModel</span>
<span class="sd">        &gt;&gt;&gt; model = ViTModel.from_pretrained(&#39;google/vit-base-patch16-224-in21k&#39;)</span>
<span class="sd">        &gt;&gt;&gt; model = Chicken(model, device=&quot;cuda&quot;, init_val=0.05, max_mult=1.0)</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="c1"># super().__init__()  # DON&#39;T: dynamic subclass already has base attrs</span>
        <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
            <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;model is not an torch.nn.Module&quot;</span><span class="p">)</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">init_val</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="n">init_val</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">matching_texts</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">matching_texts</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">max_mult</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="n">max_mult</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">=</span> <span class="n">rank</span>  <span class="c1"># None = full SVD</span>

        <span class="c1"># snapshot of base params (on current device/dtype)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">base_params</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">())</span>
        <span class="c1"># precompute decomposition only for 2D weights we intend to adapt</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">decomposed_params</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decompose</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">base_params</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">matching_texts</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">class_policy_map</span> <span class="o">=</span> <span class="p">{}</span>

        <span class="c1"># register mask containers so they move with .to()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">learnable_params</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleDict</span><span class="p">()</span>   <span class="c1"># key: str(mask_idx) -&gt; ParameterDict</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_mask_param_lists</span> <span class="o">=</span> <span class="p">{}</span>               <span class="c1"># mask_idx (int) -&gt; list[Parameter]</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">num_params</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">enable_mask</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">new_mask_idx</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">selected_mask</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>  <span class="c1"># no mask selected</span></div>


    <span class="c1"># ---------- properties / helpers ----------</span>
    <span class="nd">@property</span>
    <span class="k">def</span><span class="w"> </span><span class="nf">class_map</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Returns a printable summary mapping each mask index to the class names assigned to it.</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        str</span>
<span class="sd">        </span>
<span class="sd">        Examples</span>
<span class="sd">        --------</span>
<span class="sd">        &gt;&gt;&gt; print(model.class_map)</span>
<span class="sd">        CLASS MAP</span>
<span class="sd">        ------------------</span>
<span class="sd">        0: cat, dog, horse, cow</span>
<span class="sd">        1: mouse, lion</span>
<span class="sd">        ------------------</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="c1"># inverse map</span>
        <span class="n">inverse_map</span> <span class="o">=</span> <span class="p">{}</span>
        <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">class_policy_map</span><span class="p">:</span>
            <span class="n">mask_idx</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">class_policy_map</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
            <span class="k">if</span> <span class="n">mask_idx</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">inverse_map</span><span class="p">:</span>
                <span class="n">inverse_map</span><span class="p">[</span><span class="n">mask_idx</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
            <span class="n">inverse_map</span><span class="p">[</span><span class="n">mask_idx</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>

        <span class="n">string</span> <span class="o">=</span> <span class="s2">&quot;CLASS MAP</span><span class="se">\n</span><span class="s2">&quot;</span>
        <span class="n">string</span> <span class="o">+=</span> <span class="s2">&quot;------------------</span><span class="se">\n</span><span class="s2">&quot;</span>
        <span class="k">for</span> <span class="n">mask_idx</span> <span class="ow">in</span> <span class="n">inverse_map</span><span class="p">:</span>
            <span class="n">string</span> <span class="o">+=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">mask_idx</span><span class="si">}</span><span class="s2">: </span><span class="si">{</span><span class="s1">&#39;, &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">inverse_map</span><span class="p">[</span><span class="n">mask_idx</span><span class="p">])</span><span class="si">}</span><span class="se">\n</span><span class="s2">&quot;</span>
        <span class="n">string</span> <span class="o">+=</span> <span class="s2">&quot;------------------</span><span class="se">\n</span><span class="s2">&quot;</span>
        <span class="k">return</span> <span class="n">string</span>

    <span class="nd">@property</span>
    <span class="k">def</span><span class="w"> </span><span class="nf">latest_mask_idx</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Returns the index of the most recently added mask (-1 if no mask has been added yet).</span>
<span class="sd">        </span>
<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        int</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_mask_idx</span> <span class="o">-</span> <span class="mi">1</span>

    <span class="nd">@staticmethod</span>
    <span class="k">def</span><span class="w"> </span><span class="nf">decompose</span><span class="p">(</span><span class="n">base_params</span><span class="p">,</span> <span class="n">skip_match_texts</span><span class="o">=</span><span class="p">(),</span> <span class="n">rank</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span><span class="p">):</span>
        <span class="n">decomposed_params</span> <span class="o">=</span> <span class="p">{}</span>
        <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">base_params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
            <span class="k">if</span> <span class="nb">any</span><span class="p">(</span><span class="n">text</span> <span class="ow">in</span> <span class="n">k</span> <span class="k">for</span> <span class="n">text</span> <span class="ow">in</span> <span class="n">skip_match_texts</span><span class="p">):</span>
                <span class="k">continue</span>  <span class="c1"># skip this param</span>
            <span class="n">W</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
            <span class="c1"># U: [m,r], S: [r], Vh: [r,n], r = min(m,n)</span>
            <span class="n">U</span><span class="p">,</span> <span class="n">S</span><span class="p">,</span> <span class="n">Vh</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">svd</span><span class="p">(</span><span class="n">W</span><span class="p">,</span> <span class="n">full_matrices</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
            <span class="n">decomposed_params</span><span class="p">[</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s2">::U&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">U</span>
            <span class="n">decomposed_params</span><span class="p">[</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s2">::S&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">S</span>
            <span class="n">decomposed_params</span><span class="p">[</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s2">::Vh&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">Vh</span>
        <span class="k">return</span> <span class="n">decomposed_params</span>
    
    <span class="k">def</span><span class="w"> </span><span class="nf">add_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Call this to add a new mask (creates a new mask vector per decomposed matrix)</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="n">mask_params</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ParameterDict</span><span class="p">()</span>
        <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">base_params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
            <span class="k">if</span> <span class="nb">any</span><span class="p">(</span><span class="n">text</span> <span class="ow">in</span> <span class="n">k</span> <span class="k">for</span> <span class="n">text</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">matching_texts</span><span class="p">):</span>
                <span class="k">continue</span>
            <span class="n">S</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decomposed_params</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s2">::S&quot;</span><span class="p">)</span>
            <span class="k">if</span> <span class="n">S</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
                <span class="k">continue</span>
            <span class="c1"># init small random so sigmoid ≈ 0.5 with small variance</span>
            <span class="n">m</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">S</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_val</span><span class="p">)</span>
            <span class="n">mask_params</span><span class="p">[</span><span class="n">k</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39;.&#39;</span><span class="p">,</span> <span class="s1">&#39;__&#39;</span><span class="p">)]</span> <span class="o">=</span> <span class="n">m</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">num_params</span> <span class="o">+=</span> <span class="n">m</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>

        <span class="n">key</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_mask_idx</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">learnable_params</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask_params</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_mask_param_lists</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">new_mask_idx</span><span class="p">]</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">mask_params</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">enable_mask</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">new_mask_idx</span> <span class="o">+=</span> <span class="mi">1</span>
        <span class="k">return</span> <span class="kc">True</span>

<div class="viewcode-block" id="Chicken.add_class">
<a class="viewcode-back" href="../../index.html#chicken.model.Chicken.add_class">[docs]</a>
    <span class="k">def</span><span class="w"> </span><span class="nf">add_class</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">class_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Call this to add a new set of classes (creates a new mask vector per decomposed matrix)</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        class_names: List[str], required</span>
<span class="sd">            Class names to associate with the newly created mask.</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        bool</span>
<span class="sd">            Always True in the current implementation (there is no failure path that returns False).</span>

<span class="sd">        Examples</span>
<span class="sd">        --------</span>
<span class="sd">        &gt;&gt;&gt; model.add_class([&quot;cat&quot;, &quot;dog&quot;])</span>
<span class="sd">        True</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">class_names</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">class_policy_map</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_mask_idx</span>

        <span class="n">mask_params</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ParameterDict</span><span class="p">()</span>
        <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">base_params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
            <span class="k">if</span> <span class="nb">any</span><span class="p">(</span><span class="n">text</span> <span class="ow">in</span> <span class="n">k</span> <span class="k">for</span> <span class="n">text</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">matching_texts</span><span class="p">):</span>
                <span class="k">continue</span>
            <span class="n">S</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decomposed_params</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s2">::S&quot;</span><span class="p">)</span>
            <span class="k">if</span> <span class="n">S</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
                <span class="k">continue</span>
            <span class="c1"># init small random so sigmoid ≈ 0.5 with small variance</span>
            <span class="n">m</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">S</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_val</span><span class="p">)</span>
            <span class="n">mask_params</span><span class="p">[</span><span class="n">k</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39;.&#39;</span><span class="p">,</span> <span class="s1">&#39;__&#39;</span><span class="p">)]</span> <span class="o">=</span> <span class="n">m</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">num_params</span> <span class="o">+=</span> <span class="n">m</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>

        <span class="n">key</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_mask_idx</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">learnable_params</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask_params</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_mask_param_lists</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">new_mask_idx</span><span class="p">]</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">mask_params</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">enable_mask</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">new_mask_idx</span> <span class="o">+=</span> <span class="mi">1</span>
        <span class="k">return</span> <span class="kc">True</span></div>


<div class="viewcode-block" id="Chicken.set_mask">
<a class="viewcode-back" href="../../index.html#chicken.model.Chicken.set_mask">[docs]</a>
    <span class="k">def</span><span class="w"> </span><span class="nf">set_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask_idx</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Choose the mask that later calls operate on by default.</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        mask_idx: int, optional</span>
<span class="sd">            Index of a registered mask, or -1 to select the unmasked</span>
<span class="sd">            base weights (default 0)</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        boolean</span>
<span class="sd">            True once the selection has been stored</span>

<span class="sd">        Raises</span>
<span class="sd">        ------</span>
<span class="sd">        IndexError</span>
<span class="sd">            If mask_idx does not address an existing mask</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="n">is_base</span> <span class="o">=</span> <span class="n">mask_idx</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span>  <span class="c1"># -1 means: use the base weights</span>
        <span class="k">if</span> <span class="ow">not</span> <span class="n">is_base</span><span class="p">:</span>
            <span class="c1"># probing the enable flags validates the index</span>
            <span class="k">try</span><span class="p">:</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">enable_mask</span><span class="p">[</span><span class="n">mask_idx</span><span class="p">]</span>
            <span class="k">except</span> <span class="ne">IndexError</span><span class="p">:</span>
                <span class="k">raise</span> <span class="ne">IndexError</span><span class="p">(</span><span class="s2">&quot;the mask number is out of range&quot;</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">selected_mask</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span> <span class="k">if</span> <span class="n">is_base</span> <span class="k">else</span> <span class="n">mask_idx</span>
        <span class="k">return</span> <span class="kc">True</span></div>



<div class="viewcode-block" id="Chicken.get_mask">
<a class="viewcode-back" href="../../index.html#chicken.model.Chicken.get_mask">[docs]</a>
    <span class="k">def</span><span class="w"> </span><span class="nf">get_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask_idx</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Returns the learnable parameters of the selected mask</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        mask_idx: int, optional</span>
<span class="sd">            The mask index; if not specified the latest mask is returned (default -1).</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        nn.ParameterDict</span>
<span class="sd">            the live (not copied) parameters of the selected mask</span>

<span class="sd">        Raises</span>
<span class="sd">        ------</span>
<span class="sd">        IndexError</span>
<span class="sd">            If mask_idx does not address an existing mask</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="c1"># resolve the -1 sentinel to the latest mask; learnable_params is keyed by</span>
        <span class="c1"># str(index), so looking up &quot;-1&quot; directly would raise KeyError</span>
        <span class="k">if</span> <span class="n">mask_idx</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
            <span class="n">mask_idx</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">latest_mask_idx</span>
        <span class="k">try</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">enable_mask</span><span class="p">[</span><span class="n">mask_idx</span><span class="p">]</span>
        <span class="k">except</span> <span class="ne">IndexError</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">IndexError</span><span class="p">(</span><span class="s2">&quot;the mask number is out of range&quot;</span><span class="p">)</span>

        <span class="c1"># return the ParameterDict itself (not a copy) for transparency</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">learnable_params</span><span class="p">[</span><span class="nb">str</span><span class="p">(</span><span class="n">mask_idx</span><span class="p">)]</span></div>


    <span class="k">def</span><span class="w"> </span><span class="nf">get_trainable_parameters</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask_idx</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Parameters an optimizer should update for one mask.</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        mask_idx: int, optional</span>
<span class="sd">            Mask to query; None means the currently selected mask (default None)</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        list</span>
<span class="sd">            the cached parameter list of the mask, or an empty list for the</span>
<span class="sd">            base weights (-1), which have nothing to train</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">selected_mask</span> <span class="k">if</span> <span class="n">mask_idx</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">mask_idx</span>
        <span class="k">return</span> <span class="p">[]</span> <span class="k">if</span> <span class="n">target</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_mask_param_lists</span><span class="p">[</span><span class="n">target</span><span class="p">]</span>

<div class="viewcode-block" id="Chicken.save_weights">
<a class="viewcode-back" href="../../index.html#chicken.model.Chicken.save_weights">[docs]</a>
    <span class="k">def</span><span class="w"> </span><span class="nf">save_weights</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Write every mask, plus the bookkeeping and hyperparameters stored</span>
<span class="sd">        alongside them, to a single file with torch.save</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        path: str, required</span>
<span class="sd">             destination file; should be a .pt file.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="c1"># detached CPU copies of each mask, keyed by mask index then parameter name</span>
        <span class="n">masks</span> <span class="o">=</span> <span class="p">{}</span>
        <span class="k">for</span> <span class="n">mask_key</span><span class="p">,</span> <span class="n">param_dict</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">learnable_params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
            <span class="n">masks</span><span class="p">[</span><span class="n">mask_key</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span>
                <span class="n">name</span><span class="p">:</span> <span class="n">tensor</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
                <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="n">param_dict</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
            <span class="p">}</span>

        <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span>
            <span class="p">{</span>
                <span class="s2">&quot;learnable_params&quot;</span><span class="p">:</span> <span class="n">masks</span><span class="p">,</span>
                <span class="s2">&quot;enable_mask&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_mask</span><span class="p">,</span>
                <span class="s2">&quot;new_mask_idx&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_mask_idx</span><span class="p">,</span>
                <span class="s2">&quot;class_policy_map&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">class_policy_map</span><span class="p">,</span>
                <span class="s2">&quot;rank&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span><span class="p">,</span>
                <span class="s2">&quot;matching_texts&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">matching_texts</span><span class="p">,</span>
                <span class="s2">&quot;init_val&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_val</span><span class="p">,</span>
                <span class="s2">&quot;max_mult&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_mult</span><span class="p">,</span>
            <span class="p">},</span>
            <span class="n">path</span><span class="p">,</span>
        <span class="p">)</span></div>


<div class="viewcode-block" id="Chicken.load_weights">
<a class="viewcode-back" href="../../index.html#chicken.model.Chicken.load_weights">[docs]</a>
    <span class="k">def</span><span class="w"> </span><span class="nf">load_weights</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Load masks written by save_weights, replacing all current masks</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        path: str, required</span>
<span class="sd">            location of the .pt file for the masks.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="c1"># NOTE(review): torch.load unpickles the file; only load trusted checkpoints</span>
        <span class="n">info</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">map_location</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">learnable_params</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleDict</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_mask_param_lists</span><span class="o">.</span><span class="n">clear</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">enable_mask</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">info</span><span class="p">[</span><span class="s2">&quot;enable_mask&quot;</span><span class="p">])</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">new_mask_idx</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">info</span><span class="p">[</span><span class="s2">&quot;new_mask_idx&quot;</span><span class="p">])</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">class_policy_map</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="n">info</span><span class="p">[</span><span class="s2">&quot;class_policy_map&quot;</span><span class="p">])</span>
        <span class="c1"># activate_mask scales by max_mult and new masks are seeded with init_val,</span>
        <span class="c1"># so restore the values the masks were saved with (older files: keep current)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">max_mult</span> <span class="o">=</span> <span class="n">info</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;max_mult&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_mult</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">init_val</span> <span class="o">=</span> <span class="n">info</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;init_val&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_val</span><span class="p">)</span>
        <span class="c1"># NOTE(review): rank and matching_texts are saved but not restored here; they are</span>
        <span class="c1"># presumably fixed when the model is built - confirm they match the checkpoint</span>
        <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">d</span> <span class="ow">in</span> <span class="n">info</span><span class="p">[</span><span class="s2">&quot;learnable_params&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
            <span class="n">pd</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ParameterDict</span><span class="p">({</span><span class="n">n</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">))</span> <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">d</span><span class="o">.</span><span class="n">items</span><span class="p">()})</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">learnable_params</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">pd</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">_mask_param_lists</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="n">idx</span><span class="p">)]</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">pd</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
        <span class="c1"># select the first mask, or the base weights when the file holds no masks</span>
        <span class="c1"># (set_mask(0) would raise IndexError on an empty enable_mask list)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">set_mask</span><span class="p">(</span><span class="mi">0</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_mask</span> <span class="k">else</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></div>


    <span class="k">def</span><span class="w"> </span><span class="nf">activate_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">mask_idx</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Turn raw mask values into float32 multipliers.</span>

<span class="sd">        Returns all ones for the base weights (-1) or a disabled mask,</span>
<span class="sd">        otherwise sigmoid(p) scaled by max_mult.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="c1"># short-circuit: -1 never indexes enable_mask (it would read the last flag)</span>
        <span class="k">if</span> <span class="n">mask_idx</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span> <span class="ow">or</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_mask</span><span class="p">[</span><span class="n">mask_idx</span><span class="p">]:</span>
            <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
        <span class="n">gate</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">p</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">gate</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_mult</span>

    <span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Run the parent class forward with all arguments passed through unchanged.&quot;&quot;&quot;</span>
        <span class="n">outputs</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">outputs</span>

    <span class="k">def</span><span class="w"> </span><span class="nf">compose_new_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">param_name</span><span class="p">,</span> <span class="n">mask_idx</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Rebuild one weight matrix from its stored U, S, Vh factors with the</span>
<span class="sd">        given mask applied to S.</span>

<span class="sd">        The product is multiplied by sum(S) / (sum(masked S) + eps), so the</span>
<span class="sd">        masked factors keep roughly the same total magnitude as the original S.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="n">left</span>  <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decomposed_params</span><span class="p">[</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">param_name</span><span class="si">}</span><span class="s2">::U&quot;</span><span class="p">]</span>   <span class="c1"># [m,r]</span>
        <span class="n">sigma</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decomposed_params</span><span class="p">[</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">param_name</span><span class="si">}</span><span class="s2">::S&quot;</span><span class="p">]</span>   <span class="c1"># [r]</span>
        <span class="n">right</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decomposed_params</span><span class="p">[</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">param_name</span><span class="si">}</span><span class="s2">::Vh&quot;</span><span class="p">]</span>  <span class="c1"># [r,n]</span>

        <span class="n">mask_key</span> <span class="o">=</span> <span class="n">param_name</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39;.&#39;</span><span class="p">,</span> <span class="s1">&#39;__&#39;</span><span class="p">)</span>
        <span class="n">logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">learnable_params</span><span class="p">[</span><span class="nb">str</span><span class="p">(</span><span class="n">mask_idx</span><span class="p">)][</span><span class="n">mask_key</span><span class="p">]</span>  <span class="c1"># [r]</span>
        <span class="n">sigma_masked</span> <span class="o">=</span> <span class="n">sigma</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">activate_mask</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">mask_idx</span><span class="p">)</span>

        <span class="c1"># renormalise; eps keeps the division finite if the masked sum is zero</span>
        <span class="n">tiny</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">finfo</span><span class="p">(</span><span class="n">sigma</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">eps</span>
        <span class="n">scale</span> <span class="o">=</span> <span class="n">sigma</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">/</span> <span class="p">(</span><span class="n">sigma_masked</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">+</span> <span class="n">tiny</span><span class="p">)</span>

        <span class="n">weighted_left</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;mr,r-&gt;mr&#39;</span><span class="p">,</span> <span class="n">left</span><span class="p">,</span> <span class="n">sigma_masked</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;mr,rn-&gt;mn&#39;</span><span class="p">,</span> <span class="n">weighted_left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span> <span class="o">*</span> <span class="n">scale</span>


<div class="viewcode-block" id="Chicken.toggle_mask">
<a class="viewcode-back" href="../../index.html#chicken.model.Chicken.toggle_mask">[docs]</a>
    <span class="k">def</span><span class="w"> </span><span class="nf">toggle_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask_value</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">mask_idx</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Turn a mask on or off, then write the resulting weights into the model</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        mask_value: bool, optional</span>
<span class="sd">            True enables the mask, False disables it (default True)</span>
<span class="sd">        mask_idx: int, optional</span>
<span class="sd">            Mask to change; None selects the latest mask (default None)</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">latest_mask_idx</span> <span class="k">if</span> <span class="n">mask_idx</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">mask_idx</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">enable_mask</span><span class="p">[</span><span class="n">target</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask_value</span>
        <span class="c1"># re-apply so the live weights reflect the new on/off state</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">apply_policy_to_model</span><span class="p">(</span><span class="n">target</span><span class="p">)</span></div>


<div class="viewcode-block" id="Chicken.update_backward">
<a class="viewcode-back" href="../../index.html#chicken.model.Chicken.update_backward">[docs]</a>
    <span class="k">def</span><span class="w"> </span><span class="nf">update_backward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask_idx</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Backpropagate through the learnable mask parameters using VJP.</span>
<span class="sd">        Requires that loss.backward() has populated dL/dW on base weights.</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        mask_idx: int, optional</span>
<span class="sd">            If None use the selected mask from set_mask (default None).</span>
<span class="sd">            -1 (base weights) has no learnable mask, so nothing is done.</span>

<span class="sd">        Raises</span>
<span class="sd">        ------</span>
<span class="sd">        RuntimeError</span>
<span class="sd">            If a composed base weight has no gradient.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="k">if</span> <span class="n">mask_idx</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">mask_idx</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">selected_mask</span>
        <span class="k">if</span> <span class="n">mask_idx</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
            <span class="c1"># base weights: no mask parameters exist (compose_new_params would</span>
            <span class="c1"># look up learnable_params[&quot;-1&quot;] and raise KeyError)</span>
            <span class="k">return</span>

        <span class="c1"># same selection as set_train: decomposed weights not excluded by matching_texts</span>
        <span class="n">keys</span> <span class="o">=</span> <span class="p">[</span><span class="n">k</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">base_params</span>
                <span class="k">if</span> <span class="nb">all</span><span class="p">(</span><span class="n">text</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">k</span> <span class="k">for</span> <span class="n">text</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">matching_texts</span><span class="p">)</span>
                <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">decomposed_params</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s2">::S&quot;</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">]</span>
        <span class="k">if</span> <span class="ow">not</span> <span class="n">keys</span><span class="p">:</span>
            <span class="k">return</span>
        <span class="n">last_pos</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">keys</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
        <span class="k">for</span> <span class="n">pos</span><span class="p">,</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">keys</span><span class="p">):</span>
            <span class="n">g</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_parameter</span><span class="p">(</span><span class="n">k</span><span class="p">)</span><span class="o">.</span><span class="n">grad</span>
            <span class="k">if</span> <span class="n">g</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
                <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;No grad for </span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s2">; call set_train() and loss.backward() first.&quot;</span><span class="p">)</span>
            <span class="c1"># VJP: push dL/dW back onto the mask parameters; the final call is found</span>
            <span class="c1"># by position rather than by string identity (`is`)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">compose_new_params</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">mask_idx</span><span class="p">)</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">g</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="p">(</span><span class="n">pos</span> <span class="o">!=</span> <span class="n">last_pos</span><span class="p">))</span></div>


<div class="viewcode-block" id="Chicken.set_train">
<a class="viewcode-back" href="../../index.html#chicken.model.Chicken.set_train">[docs]</a>
    <span class="k">def</span><span class="w"> </span><span class="nf">set_train</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask_idx</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Configure requires_grad for mask training: freeze every parameter, then</span>
<span class="sd">        re-enable gradients on the decomposed base weights and the chosen mask.</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        mask_idx: int, optional</span>
<span class="sd">            If None use the mask index from set_mask (default None)</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">selected_mask</span> <span class="k">if</span> <span class="n">mask_idx</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">mask_idx</span>

        <span class="c1"># start from a fully frozen model</span>
        <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
            <span class="n">param</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span>

        <span class="c1"># base weights with U/S/Vh factors that matching_texts does not exclude</span>
        <span class="c1"># need dL/dW so that update_backward can pull it back onto the mask</span>
        <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">base_params</span><span class="p">:</span>
            <span class="n">excluded</span> <span class="o">=</span> <span class="nb">any</span><span class="p">(</span><span class="n">s</span> <span class="ow">in</span> <span class="n">name</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">matching_texts</span><span class="p">)</span>
            <span class="k">if</span> <span class="n">excluded</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">decomposed_params</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">::S&quot;</span><span class="p">)</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
                <span class="k">continue</span>
            <span class="n">weight</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_parameter</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
            <span class="n">weight</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
            <span class="n">weight</span><span class="o">.</span><span class="n">retain_grad</span><span class="p">()</span>  <span class="c1"># a no-op for leaf tensors per the PyTorch docs; kept as-is</span>

        <span class="c1"># finally unfreeze the parameters of the chosen mask</span>
        <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_trainable_parameters</span><span class="p">(</span><span class="n">target</span><span class="p">):</span>
            <span class="n">param</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span></div>


<!--
  Highlighted rendering of the Python method Chicken.apply_policy_to_model, apparently
  produced by sphinx.ext.viewcode (classes "viewcode-block" / "viewcode-back"). The spans
  below mirror the chicken.model module source, so fixes belong in that Python module
  followed by a docs rebuild, not in this HTML.
  NOTE(review): the docstring documents "mask_idx: int, required", but the signature gives
  mask_idx a default of None and the body falls back to self.selected_mask when it is None,
  so the parameter is optional (the annotation would more accurately be Optional[int]).
  NOTE(review): the two early "continue" branches (key matches an entry of
  self.matching_texts, or self.decomposed_params has no "{k}::S" entry) perform the same
  copy of the base weight into the live parameter; they could share a single condition.
  NOTE(review): each ".to(param.dtype).to(self.device)" chain could be one .to() call
  taking both device and dtype.
--><div class="viewcode-block" id="Chicken.apply_policy_to_model">
<a class="viewcode-back" href="../../index.html#chicken.model.Chicken.apply_policy_to_model">[docs]</a>
    <span class="k">def</span><span class="w"> </span><span class="nf">apply_policy_to_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask_idx</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Compose &amp; write weights into the live model (fast in-place copy).</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        mask_idx: int, required</span>
<span class="sd">            index of the mask that should be applied to the model if None will choose based on set_mask or latest mask</span>
<span class="sd">        </span>
<span class="sd">        Examples</span>
<span class="sd">        --------</span>
<span class="sd">        &gt;&gt;&gt; model.apply_policy_to_model(1)</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="k">if</span> <span class="n">mask_idx</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">mask_idx</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">selected_mask</span>

        <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
            <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">base</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">base_params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
                <span class="n">param</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_parameter</span><span class="p">(</span><span class="n">k</span><span class="p">)</span>
                <span class="k">if</span> <span class="nb">any</span><span class="p">(</span><span class="n">skip</span> <span class="ow">in</span> <span class="n">k</span> <span class="k">for</span> <span class="n">skip</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">matching_texts</span><span class="p">):</span>
                    <span class="n">param</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">base</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
                    <span class="k">continue</span>
                <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">decomposed_params</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s2">::S&quot;</span><span class="p">)</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
                    <span class="n">param</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">base</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
                    <span class="k">continue</span>
                <span class="n">Wp</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compose_new_params</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">mask_idx</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
                <span class="n">param</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">Wp</span><span class="p">)</span></div>
</div>

</pre></div>

          </div>
          
        </div>
      </div>
      <div class="sphinxsidebar" role="navigation" aria-label="Main">
        <div class="sphinxsidebarwrapper">
<h1 class="logo"><a href="../../index.html">contilearn</a></h1>

<!-- Hidden by default and revealed by the script below, so the box only appears when
     JavaScript runs (presumably because the search page needs JavaScript). -->
<search id="searchbox" style="display: none" role="search">
    <div class="searchformwrapper">
    <form class="search" action="../../search.html" method="get">
      <!-- No element with id "searchlabel" exists on this page, so aria-labelledby left the
           field with only a placeholder; name it directly. type stays "text" because the
           sidebar CSS may select input[type="text"]. -->
      <input type="text" name="q" aria-label="Search" autocomplete="off" autocorrect="off" autocapitalize="off" spellcheck="false" placeholder="Search">
      <input type="submit" value="Go">
    </form>
    </div>
</search>
<script>document.getElementById('searchbox').style.display = "block"</script>

<div class="relations">
<h3>Related Topics</h3>
<ul>
  <li><a href="../../index.html">Documentation overview</a><ul>
  <li><a href="../index.html">Module code</a></li>
  </ul></li>
</ul>
</div>

        </div>
      </div>
      <div class="clearer"></div>
    </div>
    <!-- Page-level footer (direct child of body). Kept as a div, presumably because the theme
         stylesheet targets div.footer; role="contentinfo" supplies the landmark a <footer>
         element would have provided. -->
    <div class="footer" role="contentinfo">
      ©2025, Burhan Ul Tayyab.
      |
      Powered by <a href="https://www.sphinx-doc.org/">Sphinx 8.1.3</a>
      &amp; <a href="https://alabaster.readthedocs.io">Alabaster 1.0.0</a>
    </div>

    

    
  </body>
</html>