{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from torch.profiler import profile, record_function, ProfilerActivity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/home/mmccabe/venvs/well_venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from functools import partial\n",
    "\n",
    "# from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\n",
    "# from timm.models.helpers import load_pretrained\n",
    "from timm.models.layers import DropPath, to_2tuple\n",
    "# from timm.models.registry import register_model\n",
    "from einops import rearrange\n",
    "import numpy as np\n",
    "\n",
    "from torch.nn.attention.flex_attention import (create_block_mask, \n",
    "                                               flex_attention\n",
    ")\n",
    "\n",
    "torch.set_float32_matmul_precision('high')\n",
    "\n",
    "# flex_attention = torch.com\n",
    "\n",
    "# flex_attention = torch.compile(flex_attention, dynamic=False)\n",
    "# from torch.nn.attention.flex_attention import flex_attention\n",
    "\n",
    "def _cfg(url='', **kwargs):\n",
    "    return {\n",
    "        'url': url,\n",
    "        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n",
    "        'crop_pct': .9, 'interpolation': 'bicubic',\n",
    "        # 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,\n",
    "        'first_conv': 'patch_embed.proj', 'classifier': 'head',\n",
    "        **kwargs\n",
    "    }\n",
    "\n",
    "\n",
    "default_cfgs = {\n",
    "    'cswin_224': _cfg(),\n",
    "    'cswin_384': _cfg(\n",
    "        crop_pct=1.0\n",
    "    ),\n",
    "\n",
    "}\n",
    "\n",
    "\n",
    "class Mlp(nn.Module):\n",
    "    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n",
    "        super().__init__()\n",
    "        out_features = out_features or in_features\n",
    "        hidden_features = hidden_features or in_features\n",
    "        self.fc1 = nn.Linear(in_features, hidden_features)\n",
    "        self.act = act_layer()\n",
    "        self.fc2 = nn.Linear(hidden_features, out_features)\n",
    "        self.drop = nn.Dropout(drop)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = self.act(x)\n",
    "        x = self.drop(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.drop(x)\n",
    "        return x\n",
    "\n",
    "class LePEAttention(nn.Module):\n",
    "    def __init__(self, dim, resolution, idx, split_size=7, dim_out=None, num_heads=8, attn_drop=0., proj_drop=0., qk_scale=None):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "        self.dim_out = dim_out or dim\n",
    "        self.resolution = resolution\n",
    "        self.split_size = split_size\n",
    "        self.num_heads = num_heads\n",
    "        head_dim = dim // num_heads\n",
    "        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n",
    "        self.scale = qk_scale or head_dim ** -0.5\n",
    "        if idx == -1:\n",
    "            H_sp, W_sp = self.resolution, self.resolution\n",
    "        elif idx == 0:\n",
    "            H_sp, W_sp = self.resolution, self.split_size\n",
    "        elif idx == 1:\n",
    "            W_sp, H_sp = self.resolution, self.split_size\n",
    "        else:\n",
    "            print (\"ERROR MODE\", idx)\n",
    "            exit(0)\n",
    "        self.H_sp = H_sp\n",
    "        self.W_sp = W_sp\n",
    "        stride = 1\n",
    "        self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim)\n",
    "\n",
    "        self.attn_drop = nn.Dropout(attn_drop)\n",
    "\n",
    "    def im2cswin(self, x):\n",
    "        B, N, C = x.shape\n",
    "        H = W = int(np.sqrt(N))\n",
    "        x = x.transpose(-2,-1).contiguous().view(B, C, H, W)\n",
    "        x = img2windows(x, self.H_sp, self.W_sp)\n",
    "        x = x.reshape(-1, self.H_sp* self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()\n",
    "        return x\n",
    "\n",
    "    def get_lepe(self, x, func):\n",
    "        B, N, C = x.shape\n",
    "        H = W = int(np.sqrt(N))\n",
    "        x = x.transpose(-2,-1).contiguous().view(B, C, H, W)\n",
    "\n",
    "        H_sp, W_sp = self.H_sp, self.W_sp\n",
    "        x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)\n",
    "        x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp, W_sp) ### B', C, H', W'\n",
    "        lepe = func(x) ### B', C, H', W'\n",
    "        lepe = lepe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3, 2).contiguous()\n",
    "\n",
    "        x = x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp* self.W_sp).permute(0, 1, 3, 2).contiguous()\n",
    "        return x, lepe\n",
    "\n",
    "    def forward(self, qkv):\n",
    "        \"\"\"\n",
    "        x: B L C\n",
    "        \"\"\"\n",
    "        q,k,v = qkv[0], qkv[1], qkv[2]\n",
    "\n",
    "        ### Img2Window\n",
    "        H = W = self.resolution\n",
    "        B, L, C = q.shape\n",
    "        assert L == H * W, \"flatten img_tokens has wrong size\"\n",
    "        \n",
    "        q = self.im2cswin(q)\n",
    "        k = self.im2cswin(k)\n",
    "        v, lepe = self.get_lepe(v, self.get_v)\n",
    "\n",
    "        x = F.scaled_dot_product_attention(q, k, v) + lepe\n",
    "        # q = q * self.scale\n",
    "        # attn = (q @ k.transpose(-2, -1))  # B head N C @ B head C N --> B head N N\n",
    "        # attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype)\n",
    "        # attn = self.attn_drop(attn)\n",
    "        # print out shapes of major tensors\n",
    "        # x = (attn @ v) + lepe\n",
    "        x = x.transpose(1, 2).reshape(-1, self.H_sp* self.W_sp, C)  # B head N N @ B head N C\n",
    "\n",
    "        ### Window2Img\n",
    "        x = windows2img(x, self.H_sp, self.W_sp, H, W).view(B, -1, C)  # B H' W' C\n",
    "\n",
    "        return x\n",
    "    \n",
    "def build_cswin_op(H, H_sp, W, W_sp, D=1, D_sp=1):\n",
    "    ## Building this assuming 3D is a possibility\n",
    "    ## Almost certainly a more efficient approach here, but don't want to spend forever\n",
    "    def block_mask(b, h, q_idx, kv_idx):\n",
    "        # Get x, y, z coordinates for q_idx and kv_idx, assuming stride is H, W, D\n",
    "        z_q, z_kv = q_idx % D, kv_idx % D\n",
    "        next_q, next_kv = q_idx // D, kv_idx // D\n",
    "        x_q, x_kv = next_q % W, next_kv % W\n",
    "        y_q, y_kv = next_q // W, next_kv // W\n",
    "        # Now make sure each coord is in same window\n",
    "        z_block_mask = (z_q // D_sp) == (z_kv // D_sp)\n",
    "        x_block_mask = (x_q // W_sp) == (x_kv // W_sp)\n",
    "        y_block_mask = (y_q // H_sp) == (y_kv // H_sp)\n",
    "        return z_block_mask & x_block_mask  & y_block_mask\n",
    "    return block_mask\n",
    "    \n",
    "class LePEAttentionFlex(nn.Module):\n",
    "    def __init__(self, dim, resolution, idx, split_size=7, dim_out=None, num_heads=8, attn_drop=0., proj_drop=0., qk_scale=None):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "        self.dim_out = dim_out or dim\n",
    "        self.resolution = resolution\n",
    "        self.split_size = split_size\n",
    "        self.num_heads = num_heads\n",
    "        head_dim = dim // num_heads\n",
    "        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n",
    "        self.scale = qk_scale or head_dim ** -0.5\n",
    "        if idx == -1:\n",
    "            H_sp, W_sp = self.resolution, self.resolution\n",
    "        elif idx == 0:\n",
    "            H_sp, W_sp = self.resolution, self.split_size\n",
    "        elif idx == 1:\n",
    "            W_sp, H_sp = self.resolution, self.split_size\n",
    "        else:\n",
    "            print (\"ERROR MODE\", idx)\n",
    "            exit(0)\n",
    "        self.H_sp = H_sp\n",
    "        self.W_sp = W_sp\n",
    "        self.block_mask_func = create_block_mask(build_cswin_op(self.resolution, self.H_sp, self.resolution, self.W_sp, \n",
    "                                                                1, 1), # Trailing 1s placeholder for 3D\n",
    "                                                 B=None, H=None, Q_LEN=resolution**2,  KV_LEN=resolution**2,\n",
    "                                                 # Block size is tricky - need to work through math to get better answer\n",
    "                                                 # For 2D, we know that we have one direction contiguous and the other at stride\n",
    "                                                 # \"Resolution\", so block size needs to be at least < resolution to get any \n",
    "                                                 # sparsity in that direction. But larger blocks translate to better \n",
    "                                                 # performance. So we'll start with resolution//4, but this is probably\n",
    "                                                 # too small.\n",
    "                                                #  BLOCK_SIZE=resolution, # Not working on nightly currently - fix later\n",
    "                                                 _compile=True\n",
    "                                                 ) # The 1 is because this is a 2D mask\n",
    "        stride = 1\n",
    "        self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim)\n",
    "\n",
    "        self.attn_drop = nn.Dropout(attn_drop)\n",
    "\n",
    "    def im2cswin(self, x):\n",
    "        B, N, C = x.shape\n",
    "        H = W = int(np.sqrt(N))\n",
    "        x = x.transpose(-2,-1).contiguous().view(B, C, H, W)\n",
    "        x = img2windows(x, self.H_sp, self.W_sp)\n",
    "        x = x.reshape(-1, self.H_sp* self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()\n",
    "        return x\n",
    "\n",
    "    def get_lepe(self, x, func):\n",
    "        B, N, C = x.shape\n",
    "        H = W = int(np.sqrt(N))\n",
    "        # x = x.transpose(-2,-1).contiguous().view(B, C, H, W)\n",
    "\n",
    "        H_sp, W_sp = self.H_sp, self.W_sp\n",
    "        x = rearrange(x, 'B (H Hsp W Wsp) C -> (B H W) C Hsp Wsp', Hsp=H_sp, Wsp=W_sp, H=H//H_sp)\n",
    "        lepe = func(x) ### B', C, H', W'\n",
    "        lepe = rearrange(lepe,'(B H W) C Hsp Wsp -> B (H Hsp W Wsp) C', B=B, H=H//H_sp)\n",
    "        return lepe\n",
    "\n",
    "    def forward(self, qkv):\n",
    "        \"\"\"\n",
    "        x: B L C\n",
    "        \"\"\"\n",
    "        q,k,v = qkv[0], qkv[1], qkv[2]\n",
    "\n",
    "        ### Img2Window\n",
    "        H = W = self.resolution\n",
    "        B, L, C = q.shape\n",
    "        assert L == H * W, \"flatten img_tokens has wrong size\"\n",
    "        \n",
    "        lepe = self.get_lepe(v, self.get_v)\n",
    "\n",
    "        # Print shapes of all major tensors\n",
    "        q, k, v, lepe = map(lambda t: rearrange(t, 'B H (he C) -> B he H C', he=self.num_heads).contiguous(), (q, k, v, lepe))\n",
    "        x = flex_attention(q, k, v, block_mask=self.block_mask_func) + lepe\n",
    "        x = rearrange(x, 'B he H C -> B H (he C)')\n",
    "        return x\n",
    "\n",
    "class CSWinBlockFlex(nn.Module):\n",
    "    def __init__(self, dim, reso, num_heads,\n",
    "                 split_size=7, mlp_ratio=4., qkv_bias=False, qk_scale=None,\n",
    "                 drop=0., attn_drop=0., drop_path=0.,\n",
    "                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n",
    "                 last_stage=False):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "        self.num_heads = num_heads\n",
    "        self.patches_resolution = reso\n",
    "        self.split_size = split_size\n",
    "        self.mlp_ratio = mlp_ratio\n",
    "        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n",
    "        self.norm1 = norm_layer(dim)\n",
    "\n",
    "        if self.patches_resolution == split_size:\n",
    "            last_stage = True\n",
    "        if last_stage:\n",
    "            self.branch_num = 1\n",
    "        else:\n",
    "            self.branch_num = 2\n",
    "        self.proj = nn.Linear(dim, dim)\n",
    "        self.proj_drop = nn.Dropout(drop)\n",
    "        \n",
    "        if last_stage:\n",
    "            self.attns = nn.ModuleList([\n",
    "                LePEAttentionFlex(\n",
    "                    dim, resolution=self.patches_resolution, idx = -1,\n",
    "                    split_size=split_size, num_heads=num_heads, dim_out=dim,\n",
    "                    qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n",
    "                for i in range(self.branch_num)])\n",
    "        else:\n",
    "            self.attns = nn.ModuleList([\n",
    "                LePEAttentionFlex(\n",
    "                    dim//2, resolution=self.patches_resolution, idx = i,\n",
    "                    split_size=split_size, num_heads=num_heads//2, dim_out=dim//2,\n",
    "                    qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n",
    "                for i in range(self.branch_num)])\n",
    "        \n",
    "\n",
    "        mlp_hidden_dim = int(dim * mlp_ratio)\n",
    "\n",
    "        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n",
    "        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer, drop=drop)\n",
    "        self.norm2 = norm_layer(dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        x: B, H*W, C\n",
    "        \"\"\"\n",
    "\n",
    "        H = W = self.patches_resolution\n",
    "        B, L, C = x.shape\n",
    "        assert L == H * W, \"flatten img_tokens has wrong size\"\n",
    "        img = self.norm1(x)\n",
    "        qkv = self.qkv(img).reshape(B, -1, 3, C).permute(2, 0, 1, 3)\n",
    "        \n",
    "        if self.branch_num == 2:\n",
    "            x1 = self.attns[0](qkv[:,:,:,:C//2])\n",
    "            x2 = self.attns[1](qkv[:,:,:,C//2:])\n",
    "            attened_x = torch.cat([x1,x2], dim=2)\n",
    "        else:\n",
    "            attened_x = self.attns[0](qkv)\n",
    "        attened_x = self.proj(attened_x)\n",
    "        x = x + self.drop_path(attened_x)\n",
    "        x = x + self.drop_path(self.mlp(self.norm2(x)))\n",
    "\n",
    "        return x\n",
    "\n",
    "class CSWinBlock(nn.Module):\n",
    "\n",
    "    def __init__(self, dim, reso, num_heads,\n",
    "                 split_size=7, mlp_ratio=4., qkv_bias=False, qk_scale=None,\n",
    "                 drop=0., attn_drop=0., drop_path=0.,\n",
    "                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n",
    "                 last_stage=False):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "        self.num_heads = num_heads\n",
    "        self.patches_resolution = reso\n",
    "        self.split_size = split_size\n",
    "        self.mlp_ratio = mlp_ratio\n",
    "        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n",
    "        self.norm1 = norm_layer(dim)\n",
    "\n",
    "        if self.patches_resolution == split_size:\n",
    "            last_stage = True\n",
    "        if last_stage:\n",
    "            self.branch_num = 1\n",
    "        else:\n",
    "            self.branch_num = 2\n",
    "        self.proj = nn.Linear(dim, dim)\n",
    "        self.proj_drop = nn.Dropout(drop)\n",
    "        \n",
    "        if last_stage:\n",
    "            self.attns = nn.ModuleList([\n",
    "                LePEAttention(\n",
    "                    dim, resolution=self.patches_resolution, idx = -1,\n",
    "                    split_size=split_size, num_heads=num_heads, dim_out=dim,\n",
    "                    qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n",
    "                for i in range(self.branch_num)])\n",
    "        else:\n",
    "            self.attns = nn.ModuleList([\n",
    "                LePEAttention(\n",
    "                    dim//2, resolution=self.patches_resolution, idx = i,\n",
    "                    split_size=split_size, num_heads=num_heads//2, dim_out=dim//2,\n",
    "                    qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n",
    "                for i in range(self.branch_num)])\n",
    "        \n",
    "\n",
    "        mlp_hidden_dim = int(dim * mlp_ratio)\n",
    "\n",
    "        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n",
    "        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer, drop=drop)\n",
    "        self.norm2 = norm_layer(dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        x: B, H*W, C\n",
    "        \"\"\"\n",
    "\n",
    "        H = W = self.patches_resolution\n",
    "        B, L, C = x.shape\n",
    "        assert L == H * W, \"flatten img_tokens has wrong size\"\n",
    "        img = self.norm1(x)\n",
    "        qkv = self.qkv(img).reshape(B, -1, 3, C).permute(2, 0, 1, 3)\n",
    "        \n",
    "        if self.branch_num == 2:\n",
    "            x1 = self.attns[0](qkv[:,:,:,:C//2])\n",
    "            x2 = self.attns[1](qkv[:,:,:,C//2:])\n",
    "            attened_x = torch.cat([x1,x2], dim=2)\n",
    "        else:\n",
    "            attened_x = self.attns[0](qkv)\n",
    "        attened_x = self.proj(attened_x)\n",
    "        x = x + self.drop_path(attened_x)\n",
    "        x = x + self.drop_path(self.mlp(self.norm2(x)))\n",
    "\n",
    "        return x\n",
    "\n",
    "def img2windows(img, H_sp, W_sp):\n",
    "    \"\"\"\n",
    "    img: B C H W\n",
    "    \"\"\"\n",
    "    B, C, H, W = img.shape\n",
    "    img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)\n",
    "    img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp* W_sp, C)\n",
    "    return img_perm\n",
    "\n",
    "def windows2img(img_splits_hw, H_sp, W_sp, H, W):\n",
    "    \"\"\"\n",
    "    img_splits_hw: B' H W C\n",
    "    \"\"\"\n",
    "    B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))\n",
    "\n",
    "    img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1)\n",
    "    img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n",
    "    return img\n",
    "\n",
    "class Merge_Block(nn.Module):\n",
    "    def __init__(self, dim, dim_out, norm_layer=nn.LayerNorm):\n",
    "        super().__init__()\n",
    "        self.conv = nn.Conv2d(dim, dim_out, 3, 2, 1)\n",
    "        self.norm = norm_layer(dim_out)\n",
    "\n",
    "    def forward(self, x):\n",
    "        B, new_HW, C = x.shape\n",
    "        H = W = int(np.sqrt(new_HW))\n",
    "        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)\n",
    "        x = self.conv(x)\n",
    "        B, C = x.shape[:2]\n",
    "        x = x.view(B, C, -1).transpose(-2, -1).contiguous()\n",
    "        x = self.norm(x)\n",
    "        \n",
    "        return x\n",
    "\n",
    "def _conv_filter(state_dict, patch_size=16):\n",
    "    \"\"\" convert patch embedding weight from manual patchify + linear proj to conv\"\"\"\n",
    "    out_dict = {}\n",
    "    for k, v in state_dict.items():\n",
    "        if 'patch_embed.proj.weight' in k:\n",
    "            v = v.reshape((v.shape[0], 3, patch_size, patch_size))\n",
    "        out_dict[k] = v\n",
    "    return out_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E0826 18:49:06.853000 1520393 torch/_inductor/select_algorithm.py:1300] [1/0_1] Exception out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/s7/cs7nyhfqv4kzrcqg7rf7ktr2h6bdjx37y552xl4vqmavokp7mc6w.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "E0826 18:49:06.866000 1520393 torch/_inductor/select_algorithm.py:1300] [1/0_1] Exception out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/6y/c6yl5mjb2oalrpsah3z6o74gwqd76owob3jmzasoiqpji4omzzg6.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "E0826 18:49:07.112000 1520393 torch/_inductor/select_algorithm.py:1300] [1/0_1] Exception out of resource: shared memory, Required: 131072, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/zs/czsilylzmwvqf6sms3vjgscf5oousk42gtwldlc4phq4vi4dgfe2.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4)\n",
      "E0826 18:49:07.278000 1520393 torch/_inductor/select_algorithm.py:1300] [1/0_1] Exception out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/q4/cq42ikwxillt527igew4fsad3w3ddfb3lnqnug2wocebfgvdaesd.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8)\n",
      "E0826 18:49:07.491000 1520393 torch/_inductor/select_algorithm.py:1300] [1/0_1] Exception out of resource: shared memory, Required: 294912, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/vd/cvdkiay3vtuuwkpdx4q4nzj6zateretscu2lufnp5iiiuh76gvgt.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4)\n",
      "W0826 18:49:08.456000 1520393 torch/_inductor/select_algorithm.py:1517] [1/0_1] out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:08.955000 1520393 torch/_inductor/select_algorithm.py:1517] [1/0_1] out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:09.431000 1520393 torch/_inductor/select_algorithm.py:1517] [1/0_1] out of resource: shared memory, Required: 294912, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:10.057000 1520393 torch/_inductor/select_algorithm.py:1517] [1/0_1] out of resource: shared memory, Required: 131072, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:10.060000 1520393 torch/_inductor/select_algorithm.py:1517] [1/0_1] out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "AUTOTUNE mm(65536x768, 768x2304)\n",
      "  mm 4.4073 ms 100.0% \n",
      "  triton_mm_16 4.4749 ms 98.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_13 5.3545 ms 82.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_9 5.3893 ms 81.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_14 5.4804 ms 80.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8\n",
      "  triton_mm_10 5.6095 ms 78.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8\n",
      "  triton_mm_15 5.7078 ms 77.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8\n",
      "  triton_mm_11 6.2372 ms 70.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_6 6.6918 ms 65.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4\n",
      "  triton_mm_5 6.8659 ms 64.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4\n",
      "SingleProcess AUTOTUNE benchmarking takes 2.5684 seconds and 1.1020 seconds precompiling\n",
      "W0826 18:49:14.706000 1520393 torch/_dynamo/utils.py:914] [3/0_1] ChromiumEventLogger: Detected overlapping events, fixing stack\n",
      "AUTOTUNE flex_attention(1x6x65536x64, 1x6x65536x64, 1x6x65536x64, 1x6x65536, 1x1x512, 1x1x512x512, 1x1x512, 1x1x512x512)\n",
      "  triton_flex_attention_19 10.2226 ms 100.0% BLOCK_M=128, BLOCK_N=32, GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=64, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, num_stages=3, num_warps=4\n",
      "  triton_flex_attention_21 10.2226 ms 100.0% BLOCK_M=128, BLOCK_N=32, GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=64, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, num_stages=3, num_warps=4\n",
      "  triton_flex_attention_22 10.2236 ms 100.0% BLOCK_M=128, BLOCK_N=32, GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=64, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, num_stages=3, num_warps=4\n",
      "  triton_flex_attention_23 10.2236 ms 100.0% BLOCK_M=128, BLOCK_N=32, GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=64, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, num_stages=3, num_warps=4\n",
      "  triton_flex_attention_20 10.2359 ms 99.9% BLOCK_M=128, BLOCK_N=32, GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=64, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, num_stages=3, num_warps=4\n",
      "SingleProcess AUTOTUNE benchmarking takes 0.9036 seconds and 0.0036 seconds precompiling\n",
      "AUTOTUNE flex_attention(1x6x65536x64, 1x6x65536x64, 1x6x65536x64, 1x6x65536, 1x1x512, 1x1x512x512, 1x1x512, 1x1x512x512)\n",
      "  triton_flex_attention_24 9.9717 ms 100.0% BLOCK_M=128, BLOCK_N=32, GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=64, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, num_stages=3, num_warps=4\n",
      "  triton_flex_attention_26 10.0547 ms 99.2% BLOCK_M=128, BLOCK_N=32, GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=64, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, num_stages=3, num_warps=4\n",
      "  triton_flex_attention_27 10.0547 ms 99.2% BLOCK_M=128, BLOCK_N=32, GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=64, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, num_stages=3, num_warps=4\n",
      "  triton_flex_attention_28 10.0557 ms 99.2% BLOCK_M=128, BLOCK_N=32, GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=64, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, num_stages=3, num_warps=4\n",
      "  triton_flex_attention_25 10.0680 ms 99.0% BLOCK_M=128, BLOCK_N=32, GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=64, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, num_stages=3, num_warps=4\n",
      "SingleProcess AUTOTUNE benchmarking takes 0.8882 seconds and 0.0036 seconds precompiling\n",
      "E0826 18:49:44.025000 1520393 torch/_inductor/select_algorithm.py:1300] [7/0] Exception out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/fo/cfooc7zjm6hxr5mrhbjqjdvbduaeqwkijky6dbw5lfbi7hejaboq.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "E0826 18:49:44.032000 1520393 torch/_inductor/select_algorithm.py:1300] [7/0] Exception out of resource: shared memory, Required: 294912, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/jw/cjwwv54u4yh72sstmnlhw63ndw64aknyrcyhhbosfsf4efg4baox.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4)\n",
      "E0826 18:49:44.037000 1520393 torch/_inductor/select_algorithm.py:1300] [7/0] Exception out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/6j/c6jymuy6ato6tsj3p4n43c4q7sougpt44ptxaapudas4c3d2lj26.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8)\n",
      "E0826 18:49:44.043000 1520393 torch/_inductor/select_algorithm.py:1300] [7/0] Exception out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/54/c54an7akgd4m6xeyvebfpm3s2e7anpmx5qwfp6urslapkm7wvsyl.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "E0826 18:49:44.045000 1520393 torch/_inductor/select_algorithm.py:1300] [7/0] Exception out of resource: shared memory, Required: 131072, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/vc/cvckam3ywcanustc5qbl5qcngf4x3nyizwprp2yx4yq5h3mcgkaj.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4)\n",
      "W0826 18:49:45.240000 1520393 torch/_inductor/select_algorithm.py:1517] [7/0] out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:45.769000 1520393 torch/_inductor/select_algorithm.py:1517] [7/0] out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:46.283000 1520393 torch/_inductor/select_algorithm.py:1517] [7/0] out of resource: shared memory, Required: 294912, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:46.937000 1520393 torch/_inductor/select_algorithm.py:1517] [7/0] out of resource: shared memory, Required: 131072, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:46.940000 1520393 torch/_inductor/select_algorithm.py:1517] [7/0] out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "AUTOTUNE addmm(65536x3072, 65536x768, 768x3072)\n",
      "  triton_mm_64 5.9771 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  bias_addmm 6.0017 ms 99.6% \n",
      "  addmm 6.8833 ms 86.8% \n",
      "  triton_mm_61 7.0441 ms 84.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_57 7.3134 ms 81.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_62 7.4332 ms 80.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8\n",
      "  triton_mm_58 7.6585 ms 78.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8\n",
      "  triton_mm_63 7.6913 ms 77.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8\n",
      "  triton_mm_59 8.3251 ms 71.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_54 9.1146 ms 65.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4\n",
      "SingleProcess AUTOTUNE benchmarking takes 2.8826 seconds and 0.0345 seconds precompiling\n",
      "E0826 18:49:49.051000 1520393 torch/_inductor/select_algorithm.py:1300] [7/0] Exception out of resource: shared memory, Required: 131072, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/ui/cuitfrh7c4edjeavccazyczobdndcjy7jvqczspo5hrekk7pm3gy.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4)\n",
      "E0826 18:49:49.052000 1520393 torch/_inductor/select_algorithm.py:1300] [7/0] Exception out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/4a/c4aijfdihx4icu5hmfiem5qes7cvpdoshfhcrebnzaebaqfckvwf.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "E0826 18:49:49.054000 1520393 torch/_inductor/select_algorithm.py:1300] [7/0] Exception out of resource: shared memory, Required: 294912, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/pi/cpifx4zaujxgliocwnynnwjswcams6mraupq37fwlypal2mryxfe.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4)\n",
      "E0826 18:49:49.055000 1520393 torch/_inductor/select_algorithm.py:1300] [7/0] Exception out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/6b/c6bno5vt5y7yejhhhz2rwe36b7ep7rcodzdb4wz6bevphtj5nxts.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "E0826 18:49:49.056000 1520393 torch/_inductor/select_algorithm.py:1300] [7/0] Exception out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/3c/c3cv2pqgcjh53ey7bzyo7cdzrf4ii4wo6exm52rj5fopcmfc2qyk.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8)\n",
      "W0826 18:49:50.134000 1520393 torch/_inductor/select_algorithm.py:1517] [7/0] out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:50.674000 1520393 torch/_inductor/select_algorithm.py:1517] [7/0] out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:51.182000 1520393 torch/_inductor/select_algorithm.py:1517] [7/0] out of resource: shared memory, Required: 294912, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:51.876000 1520393 torch/_inductor/select_algorithm.py:1517] [7/0] out of resource: shared memory, Required: 131072, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:51.879000 1520393 torch/_inductor/select_algorithm.py:1517] [7/0] out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "AUTOTUNE mm(65536x3072, 3072x768)\n",
      "  mm 5.3207 ms 100.0% \n",
      "  triton_mm_83 5.6044 ms 94.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_82 7.0892 ms 75.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8\n",
      "  triton_mm_76 7.4281 ms 71.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_77 7.6462 ms 69.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8\n",
      "  triton_mm_78 7.8049 ms 68.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_73 9.4771 ms 56.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4\n",
      "  triton_mm_72 9.6860 ms 54.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4\n",
      "  triton_mm_74 10.1345 ms 52.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8\n",
      "  triton_mm_81 11.0694 ms 48.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8\n",
      "SingleProcess AUTOTUNE benchmarking takes 2.8209 seconds and 0.0096 seconds precompiling\n",
      "E0826 18:49:53.237000 1520393 torch/_inductor/select_algorithm.py:1300] [7/0] Exception out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/dr/cdrer6yicww65k47h42o6pzlswqkrdqfcvqxhwmtlg2f6dau4r3v.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8)\n",
      "E0826 18:49:53.238000 1520393 torch/_inductor/select_algorithm.py:1300] [7/0] Exception out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/77/c77pfwwttcl24ugvglkguvusjvrazeetmsh23xcjgcqam3yapema.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "E0826 18:49:53.239000 1520393 torch/_inductor/select_algorithm.py:1300] [7/0] Exception out of resource: shared memory, Required: 294912, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/ry/crypa2f6yklicy2qww4r5ow6wlqifgoxf3lchkvxm57oh6s73rqh.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4)\n",
      "E0826 18:49:53.240000 1520393 torch/_inductor/select_algorithm.py:1300] [7/0] Exception out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/wy/cwyovwzpshlsfb6k5rlqiwgzauvgyn6do46l6digbaj3eexuoayh.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "E0826 18:49:53.241000 1520393 torch/_inductor/select_algorithm.py:1300] [7/0] Exception out of resource: shared memory, Required: 131072, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/lk/clkpkdumk55v43wsu4wq4cx6pfghmqvakqllgopowzfyundvzqar.py, ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4)\n",
      "W0826 18:49:53.965000 1520393 torch/_inductor/select_algorithm.py:1517] [7/0] out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:54.374000 1520393 torch/_inductor/select_algorithm.py:1517] [7/0] out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:54.781000 1520393 torch/_inductor/select_algorithm.py:1517] [7/0] out of resource: shared memory, Required: 294912, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:55.318000 1520393 torch/_inductor/select_algorithm.py:1517] [7/0] out of resource: shared memory, Required: 131072, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0826 18:49:55.320000 1520393 torch/_inductor/select_algorithm.py:1517] [7/0] out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "AUTOTUNE mm(65536x768, 768x768)\n",
      "  mm 1.4664 ms 100.0% \n",
      "  triton_mm_45 1.4805 ms 99.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_42 1.7623 ms 83.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_38 1.7992 ms 81.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_39 1.8135 ms 80.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8\n",
      "  triton_mm_43 1.8207 ms 80.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8\n",
      "  triton_mm_44 1.9180 ms 76.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8\n",
      "  triton_mm_40 2.0070 ms 73.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_35 2.2026 ms 66.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4\n",
      "  triton_mm_34 2.2200 ms 66.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4\n",
      "SingleProcess AUTOTUNE benchmarking takes 2.0779 seconds and 0.0078 seconds precompiling\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[[-1.3196e-04, -4.1336e-05, -8.9407e-05,  ..., -2.5988e-05,\n",
       "           2.2542e-04,  1.3888e-05],\n",
       "         [-3.1590e-06,  2.7418e-05, -1.4663e-05,  ..., -3.3379e-06,\n",
       "           7.8082e-05, -1.9848e-05],\n",
       "         [-9.5963e-05,  1.0452e-04, -5.5134e-06,  ..., -1.1876e-05,\n",
       "          -5.0664e-06,  1.7166e-05],\n",
       "         ...,\n",
       "         [ 2.7031e-05,  9.2626e-05,  8.0585e-05,  ..., -5.5075e-05,\n",
       "          -6.2168e-05,  4.1842e-05],\n",
       "         [-2.5988e-05,  4.2677e-05, -5.3279e-05,  ...,  7.5996e-06,\n",
       "           1.5914e-05, -6.3479e-05],\n",
       "         [-1.8358e-05, -6.7115e-05,  1.5050e-05,  ..., -8.2254e-05,\n",
       "           4.7386e-05, -1.6004e-05]]], device='cuda:0', grad_fn=<SubBackward0>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res = 256\n",
    "torch.manual_seed(0)\n",
    "m = torch.compile(CSWinBlockFlex(768, res, 12, 8).cuda(), dynamic=False, mode=\"max-autotune-no-cudagraphs\")\n",
    "torch.manual_seed(0)\n",
    "m2 = torch.compile(CSWinBlock(768, res, 12, 8).cuda(), dynamic=False, mode=\"max-autotune-no-cudagraphs\")\n",
    "# mc = torch.compile(m)\n",
    "\n",
    "x = torch.randn(1, res**2, 768).cuda()\n",
    "m(x) - m2(x)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Max mem 11.330207824707031 GB\n",
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  \n",
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                                  Torch-Compiled Region         0.65%       3.119ms         9.03%      42.994ms     977.127us       0.000us         0.00%        1.422s      32.313ms           0 b           0 b      43.55 Gb      -1.50 Gb            44  \n",
      "                                        model_inference         0.00%       0.000us         0.00%       0.000us       0.000us     474.389ms       100.02%     474.389ms     474.389ms           0 b           0 b           0 b           0 b             1  \n",
      "                                       CompiledFunction         1.29%       6.123ms         2.17%      10.323ms     430.119us       0.000us         0.00%     474.313ms      19.763ms           0 b           0 b      16.51 Gb      15.76 Gb            24  \n",
      "                                        model_inference         0.15%     692.818us         3.28%      15.621ms      15.621ms       0.000us         0.00%     474.313ms     474.313ms           0 b           0 b           0 b     -15.01 Gb             1  \n",
      "                                                triton_         0.00%       0.000us         0.00%       0.000us       0.000us     430.867ms        90.84%     430.867ms       5.669ms           0 b           0 b           0 b           0 b            76  \n",
      "                                     triton_tem_fused_1         0.03%     130.131us         0.04%     185.394us      23.174us     371.960ms        78.42%     371.960ms      46.495ms           0 b           0 b           0 b           0 b             8  \n",
      "                                               aten::mm         0.08%     380.528us         0.10%     480.170us      40.014us      39.866ms         8.41%      39.866ms       3.322ms           0 b           0 b           0 b           0 b            12  \n",
      "void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_...         0.00%       0.000us         0.00%       0.000us       0.000us      39.866ms         8.41%      39.866ms       3.322ms           0 b           0 b           0 b           0 b            12  \n",
      "                               triton_tem_fused_addmm_2         0.01%      55.813us         0.02%      83.215us      20.804us      21.174ms         4.46%      21.174ms       5.294ms           0 b           0 b           0 b           0 b             4  \n",
      "                                triton_poi_fused_gelu_3         0.01%      46.497us         0.02%      74.163us      18.541us       9.430ms         1.99%       9.430ms       2.358ms           0 b           0 b           0 b           0 b             4  \n",
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "Self CPU time total: 476.319ms\n",
      "Self CUDA time total: 474.313ms\n",
      "\n"
     ]
    }
   ],
   "source": [
    "torch.cuda.reset_peak_memory_stats()\n",
    "with profile(activities=[\n",
    "        ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof:\n",
    "    with record_function(\"model_inference\"):\n",
    "        m(x)\n",
    "        m(x)\n",
    "        m(x)\n",
    "        m(x)\n",
    "        mem = torch.cuda.max_memory_allocated() / 1024**3\n",
    "\n",
    "print('Max mem', mem, 'GB')\n",
    "print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Max mem 11.517707824707031 GB\n",
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  \n",
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                                  Torch-Compiled Region         1.78%       4.909ms        23.45%      64.547ms     701.603us       0.000us         0.00%        1.248s      13.563ms         704 b           0 b      53.69 Gb      -3.75 Gb            92  \n",
      "                                        model_inference         0.00%       0.000us         0.00%       0.000us       0.000us     272.896ms       100.03%     272.896ms     272.896ms           0 b           0 b           0 b           0 b             1  \n",
      "                                        model_inference         0.29%     794.006us         6.34%      17.454ms      17.454ms       0.000us         0.00%     272.810ms     272.810ms           0 b        -128 b           0 b     -13.51 Gb             1  \n",
      "                                       CompiledFunction         2.52%       6.928ms         3.80%      10.461ms     261.533us       0.000us         0.00%     272.810ms       6.820ms         128 b           0 b      17.26 Gb      15.75 Gb            40  \n",
      "          aten::_scaled_dot_product_efficient_attention         0.04%     100.032us         0.21%     571.545us      71.443us       0.000us         0.00%     164.692ms      20.586ms         128 b           0 b     780.00 Mb           0 b             8  \n",
      "                     aten::_efficient_attention_forward         0.05%     137.660us         0.14%     375.178us      46.897us     164.692ms        60.37%     164.692ms      20.586ms         128 b           0 b     780.00 Mb           0 b             8  \n",
      "fmha_cutlassF_f32_aligned_64x64_rf_sm80(PyTorchMemEf...         0.00%       0.000us         0.00%       0.000us       0.000us     164.692ms        60.37%     164.692ms      20.586ms           0 b           0 b           0 b           0 b             8  \n",
      "                                                triton_         0.00%       0.000us         0.00%       0.000us       0.000us      63.499ms        23.28%      63.499ms     835.513us           0 b           0 b           0 b           0 b            76  \n",
      "                                               aten::mm         0.12%     326.589us         0.15%     423.684us      35.307us      42.232ms        15.48%      42.232ms       3.519ms           0 b           0 b           0 b           0 b            12  \n",
      "void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_...         0.00%       0.000us         0.00%       0.000us       0.000us      42.232ms        15.48%      42.232ms       3.519ms           0 b           0 b           0 b           0 b            12  \n",
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "Self CPU time total: 275.224ms\n",
      "Self CUDA time total: 272.810ms\n",
      "\n"
     ]
    }
   ],
   "source": [
    "torch.cuda.reset_peak_memory_stats()\n",
    "with profile(activities=[\n",
    "        ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof:\n",
    "    with record_function(\"model_inference\"):\n",
    "        m2(x)\n",
    "        m2(x)\n",
    "        m2(x)\n",
    "        m2(x)\n",
    "        mem = torch.cuda.max_memory_allocated() / 1024**3\n",
    "\n",
    "print('Max mem', mem, 'GB')\n",
    "print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 11.9 ms, sys: 2.93 ms, total: 14.9 ms\n",
      "Wall time: 13.5 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.9490,  0.4259,  2.0703,  ..., -2.0693,  0.1600, -0.6366],\n",
       "         [ 0.2641, -1.0232, -1.1603,  ...,  1.3648,  0.6310, -0.6931],\n",
       "         [-1.2044, -0.4342, -0.1973,  ...,  0.2379,  0.1587,  1.2159],\n",
       "         ...,\n",
       "         [-0.3595,  0.6229,  0.8600,  ...,  0.9993,  0.4613, -1.5493],\n",
       "         [ 0.9457, -1.5176, -0.1061,  ...,  0.4901, -0.8905, -0.6047],\n",
       "         [-1.3625, -1.8760, -0.1431,  ...,  1.8538,  0.9562, -0.4032]]],\n",
       "       device='cuda:0', grad_fn=<CompiledFunctionBackward>)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%time m(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "ename": "OutOfMemoryError",
     "evalue": "CUDA out of memory. Tried to allocate 96.00 MiB. GPU 0 has a total capacity of 47.50 GiB of which 82.50 MiB is free. Including non-PyTorch memory, this process has 47.06 GiB memory in use. Of the allocated memory 46.09 GiB is allocated by PyTorch, and 667.75 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mOutOfMemoryError\u001b[0m                          Traceback (most recent call last)",
      "File \u001b[0;32m<timed eval>:1\u001b[0m\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1749\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1750\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:469\u001b[0m, in \u001b[0;36m_TorchDynamoContext.__call__.<locals>._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    464\u001b[0m saved_dynamic_layer_stack_depth \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m    465\u001b[0m     torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_functorch\u001b[38;5;241m.\u001b[39mget_dynamic_layer_stack_depth()\n\u001b[1;32m    466\u001b[0m )\n\u001b[1;32m    468\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 469\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    470\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m    471\u001b[0m     \u001b[38;5;66;03m# Restore the dynamic layer stack depth if necessary.\u001b[39;00m\n\u001b[1;32m    472\u001b[0m     torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_functorch\u001b[38;5;241m.\u001b[39mpop_dynamic_layer_stack_and_undo_to_depth(\n\u001b[1;32m    473\u001b[0m         saved_dynamic_layer_stack_depth\n\u001b[1;32m    474\u001b[0m     )\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1749\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1750\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "Cell \u001b[0;32mIn[2], line 362\u001b[0m, in \u001b[0;36mCSWinBlock.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    359\u001b[0m qkv \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mqkv(img)\u001b[38;5;241m.\u001b[39mreshape(B, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m3\u001b[39m, C)\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m3\u001b[39m)\n\u001b[1;32m    361\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbranch_num \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[0;32m--> 362\u001b[0m     x1 \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattns[\u001b[38;5;241m0\u001b[39m](qkv[:,:,:,:C\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m2\u001b[39m])\n\u001b[1;32m    363\u001b[0m     x2 \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattns[\u001b[38;5;241m1\u001b[39m](qkv[:,:,:,C\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m2\u001b[39m:])\n\u001b[1;32m    364\u001b[0m     attened_x \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat([x1,x2], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n",
      "Cell \u001b[0;32mIn[2], line 363\u001b[0m, in \u001b[0;36mtorch_dynamo_resume_in_forward_at_362\u001b[0;34m(___stack0, self, x, C, qkv)\u001b[0m\n\u001b[1;32m    361\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbranch_num \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[1;32m    362\u001b[0m     x1 \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattns[\u001b[38;5;241m0\u001b[39m](qkv[:,:,:,:C\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m2\u001b[39m])\n\u001b[0;32m--> 363\u001b[0m     x2 \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattns[\u001b[38;5;241m1\u001b[39m](qkv[:,:,:,C\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m2\u001b[39m:])\n\u001b[1;32m    364\u001b[0m     attened_x \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat([x1,x2], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m    365\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1749\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1750\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "Cell \u001b[0;32mIn[2], line 123\u001b[0m, in \u001b[0;36mLePEAttention.forward\u001b[0;34m(self, qkv)\u001b[0m\n\u001b[1;32m    120\u001b[0m B, L, C \u001b[38;5;241m=\u001b[39m q\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m    121\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m L \u001b[38;5;241m==\u001b[39m H \u001b[38;5;241m*\u001b[39m W, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflatten img_tokens has wrong size\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 123\u001b[0m q \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mim2cswin(q)\n\u001b[1;32m    124\u001b[0m k \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mim2cswin(k)\n\u001b[1;32m    125\u001b[0m v, lepe \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_lepe(v, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_v)\n",
      "Cell \u001b[0;32mIn[2], line 124\u001b[0m, in \u001b[0;36mtorch_dynamo_resume_in_forward_at_123\u001b[0;34m(___stack0, self, k, v, H, W, B, C)\u001b[0m\n\u001b[1;32m    121\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m L \u001b[38;5;241m==\u001b[39m H \u001b[38;5;241m*\u001b[39m W, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflatten img_tokens has wrong size\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    123\u001b[0m q \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mim2cswin(q)\n\u001b[0;32m--> 124\u001b[0m k \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mim2cswin(k)\n\u001b[1;32m    125\u001b[0m v, lepe \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_lepe(v, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_v)\n\u001b[1;32m    127\u001b[0m x \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mscaled_dot_product_attention(q, k, v) \u001b[38;5;241m+\u001b[39m lepe\n",
      "Cell \u001b[0;32mIn[2], line 125\u001b[0m, in \u001b[0;36mtorch_dynamo_resume_in_forward_at_124\u001b[0;34m(___stack0, self, v, H, W, B, C, q)\u001b[0m\n\u001b[1;32m    123\u001b[0m q \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mim2cswin(q)\n\u001b[1;32m    124\u001b[0m k \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mim2cswin(k)\n\u001b[0;32m--> 125\u001b[0m v, lepe \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_lepe(v, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_v)\n\u001b[1;32m    127\u001b[0m x \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mscaled_dot_product_attention(q, k, v) \u001b[38;5;241m+\u001b[39m lepe\n\u001b[1;32m    128\u001b[0m \u001b[38;5;66;03m# q = q * self.scale\u001b[39;00m\n\u001b[1;32m    129\u001b[0m \u001b[38;5;66;03m# attn = (q @ k.transpose(-2, -1))  # B head N C @ B head C N --> B head N N\u001b[39;00m\n\u001b[1;32m    130\u001b[0m \u001b[38;5;66;03m# attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype)\u001b[39;00m\n\u001b[1;32m    131\u001b[0m \u001b[38;5;66;03m# attn = self.attn_drop(attn)\u001b[39;00m\n\u001b[1;32m    132\u001b[0m \u001b[38;5;66;03m# print out shapes of major tensors\u001b[39;00m\n\u001b[1;32m    133\u001b[0m \u001b[38;5;66;03m# x = (attn @ v) + lepe\u001b[39;00m\n",
      "Cell \u001b[0;32mIn[2], line 100\u001b[0m, in \u001b[0;36mLePEAttention.get_lepe\u001b[0;34m(self, x, func)\u001b[0m\n\u001b[1;32m     98\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_lepe\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, func):\n\u001b[1;32m     99\u001b[0m     B, N, C \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mshape\n\u001b[0;32m--> 100\u001b[0m     H \u001b[38;5;241m=\u001b[39m W \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(np\u001b[38;5;241m.\u001b[39msqrt(N))\n\u001b[1;32m    101\u001b[0m     x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m,\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mcontiguous()\u001b[38;5;241m.\u001b[39mview(B, C, H, W)\n\u001b[1;32m    103\u001b[0m     H_sp, W_sp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mH_sp, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mW_sp\n",
      "Cell \u001b[0;32mIn[2], line 100\u001b[0m, in \u001b[0;36mtorch_dynamo_resume_in_get_lepe_at_100\u001b[0;34m(___stack0, self, x, func, B, C)\u001b[0m\n\u001b[1;32m     98\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_lepe\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, func):\n\u001b[1;32m     99\u001b[0m     B, N, C \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mshape\n\u001b[0;32m--> 100\u001b[0m     H \u001b[38;5;241m=\u001b[39m W \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(np\u001b[38;5;241m.\u001b[39msqrt(N))\n\u001b[1;32m    101\u001b[0m     x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m,\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mcontiguous()\u001b[38;5;241m.\u001b[39mview(B, C, H, W)\n\u001b[1;32m    103\u001b[0m     H_sp, W_sp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mH_sp, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mW_sp\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:636\u001b[0m, in \u001b[0;36mDisableContext.__call__.<locals>._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    634\u001b[0m prior \u001b[38;5;241m=\u001b[39m _maybe_set_eval_frame(callback)\n\u001b[1;32m    635\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 636\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    637\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m    638\u001b[0m     _maybe_set_eval_frame(prior)\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1100\u001b[0m, in \u001b[0;36maot_module_simplified.<locals>.forward\u001b[0;34m(*runtime_args)\u001b[0m\n\u001b[1;32m   1098\u001b[0m full_args\u001b[38;5;241m.\u001b[39mextend(params_flat)\n\u001b[1;32m   1099\u001b[0m full_args\u001b[38;5;241m.\u001b[39mextend(runtime_args)\n\u001b[0;32m-> 1100\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcompiled_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfull_args\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:308\u001b[0m, in \u001b[0;36m_create_runtime_wrapper.<locals>.runtime_wrapper\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m    302\u001b[0m     \u001b[38;5;66;03m# It's possible to have trace_joint inside user specified with no_grad() region,\u001b[39;00m\n\u001b[1;32m    303\u001b[0m     \u001b[38;5;66;03m# if there is a nested with enable_grad(), that forces some outputs to require gradients.\u001b[39;00m\n\u001b[1;32m    304\u001b[0m     \u001b[38;5;66;03m# Therefore, we unconditionally turn on enable_grad() for compiled_fn execution.\u001b[39;00m\n\u001b[1;32m    305\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39m_force_original_view_tracking(\n\u001b[1;32m    306\u001b[0m         \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m    307\u001b[0m     ), torch\u001b[38;5;241m.\u001b[39menable_grad():\n\u001b[0;32m--> 308\u001b[0m         all_outs \u001b[38;5;241m=\u001b[39m \u001b[43mcall_func_at_runtime_with_args\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    309\u001b[0m \u001b[43m            \u001b[49m\u001b[43mcompiled_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdisable_amp\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdisable_amp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msteal_args\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\n\u001b[1;32m    310\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    311\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    312\u001b[0m     \u001b[38;5;66;03m# When we have an inference graph, we run with grad disabled.\u001b[39;00m\n\u001b[1;32m    313\u001b[0m     \u001b[38;5;66;03m# It's possible to get an inference graph with inputs that require grad,\u001b[39;00m\n\u001b[1;32m    314\u001b[0m     \u001b[38;5;66;03m# in which case we want to make sure autograd is disabled\u001b[39;00m\n\u001b[1;32m    315\u001b[0m     \u001b[38;5;66;03m# (since e.g., inductor will generate aten.addmm.out calls which autograd will complain on)\u001b[39;00m\n\u001b[1;32m    316\u001b[0m     \u001b[38;5;66;03m# NOTE: We use _set_grad_enabled directly to reduce runtime overhead\u001b[39;00m\n\u001b[1;32m    317\u001b[0m     grad_enabled \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mis_grad_enabled()\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:124\u001b[0m, in \u001b[0;36mcall_func_at_runtime_with_args\u001b[0;34m(f, args, steal_args, disable_amp)\u001b[0m\n\u001b[1;32m    122\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[1;32m    123\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(f, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_boxed_call\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 124\u001b[0m         out \u001b[38;5;241m=\u001b[39m normalize_as_list(\u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m    125\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    126\u001b[0m         \u001b[38;5;66;03m# TODO: Please remove soon\u001b[39;00m\n\u001b[1;32m    127\u001b[0m         \u001b[38;5;66;03m# https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670\u001b[39;00m\n\u001b[1;32m    128\u001b[0m         warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m    129\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYour compiler for AOTAutograd is returning a function that doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt take boxed arguments. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    130\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    131\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSee https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    132\u001b[0m         )\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:98\u001b[0m, in \u001b[0;36mmake_boxed_func.<locals>.g\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m     97\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mg\u001b[39m(args):\n\u001b[0;32m---> 98\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/autograd/function.py:575\u001b[0m, in \u001b[0;36mFunction.apply\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m    572\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_are_functorch_transforms_active():\n\u001b[1;32m    573\u001b[0m     \u001b[38;5;66;03m# See NOTE: [functorch vjp and autograd interaction]\u001b[39;00m\n\u001b[1;32m    574\u001b[0m     args \u001b[38;5;241m=\u001b[39m _functorch\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39munwrap_dead_wrappers(args)\n\u001b[0;32m--> 575\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m    577\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_setup_ctx_defined:\n\u001b[1;32m    578\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m    579\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIn order to use an autograd.Function with functorch transforms \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    580\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m(vmap, grad, jvp, jacrev, ...), it must override the setup_context \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    581\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstaticmethod. For more details, please see \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    582\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttps://pytorch.org/docs/main/notes/extending.func.html\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    583\u001b[0m     )\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:1525\u001b[0m, in \u001b[0;36mAOTDispatchAutograd.post_compile.<locals>.CompiledFunction.forward\u001b[0;34m(ctx, *deduped_flat_tensor_args)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     ctx\u001b[38;5;241m.\u001b[39m_compiled_autograd_backward_state \u001b[38;5;241m=\u001b[39m bw_state\n\u001b[1;32m   1518\u001b[0m \u001b[38;5;66;03m# There is a pretty complicated calling convention around what the compiled fw returns.\u001b[39;00m\n\u001b[1;32m   1519\u001b[0m \u001b[38;5;66;03m# The full list of outputs and their relative order is:\u001b[39;00m\n\u001b[1;32m   1520\u001b[0m \u001b[38;5;66;03m# (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# - Note that donated buffer logic requires (*saved_tensors, *saved_symints) showing up last\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;66;03m#   in the fw output order.\u001b[39;00m\n\u001b[0;32m-> 1525\u001b[0m fw_outs \u001b[38;5;241m=\u001b[39m \u001b[43mcall_func_at_runtime_with_args\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1526\u001b[0m \u001b[43m    \u001b[49m\u001b[43mCompiledFunction\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompiled_fw\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1527\u001b[0m \u001b[43m    \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1528\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdisable_amp\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdisable_amp\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1531\u001b[0m num_outputs \u001b[38;5;241m=\u001b[39m CompiledFunction\u001b[38;5;241m.\u001b[39mmetadata\u001b[38;5;241m.\u001b[39mnum_outputs\n\u001b[1;32m   1532\u001b[0m num_outputs_aliased \u001b[38;5;241m=\u001b[39m CompiledFunction\u001b[38;5;241m.\u001b[39mmetadata\u001b[38;5;241m.\u001b[39mnum_outputs_aliased\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:124\u001b[0m, in \u001b[0;36mcall_func_at_runtime_with_args\u001b[0;34m(f, args, steal_args, disable_amp)\u001b[0m\n\u001b[1;32m    122\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[1;32m    123\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(f, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_boxed_call\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 124\u001b[0m         out \u001b[38;5;241m=\u001b[39m normalize_as_list(\u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m    125\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    126\u001b[0m         \u001b[38;5;66;03m# TODO: Please remove soon\u001b[39;00m\n\u001b[1;32m    127\u001b[0m         \u001b[38;5;66;03m# https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670\u001b[39;00m\n\u001b[1;32m    128\u001b[0m         warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m    129\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYour compiler for AOTAutograd is returning a function that doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt take boxed arguments. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    130\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    131\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSee https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    132\u001b[0m         )\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:488\u001b[0m, in \u001b[0;36mFunctionalizedRngRuntimeWrapper.post_compile.<locals>.wrapper\u001b[0;34m(runtime_args)\u001b[0m\n\u001b[1;32m    481\u001b[0m     out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_functionalized_rng_runtime_epilogue(\n\u001b[1;32m    482\u001b[0m         runtime_metadata,\n\u001b[1;32m    483\u001b[0m         out,\n\u001b[1;32m    484\u001b[0m         \u001b[38;5;66;03m# TODO: this won't be right for the backward when we convert the call_compiled_backward to use the wrapper\u001b[39;00m\n\u001b[1;32m    485\u001b[0m         runtime_metadata\u001b[38;5;241m.\u001b[39mnum_forward_returns,\n\u001b[1;32m    486\u001b[0m     )\n\u001b[1;32m    487\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m out\n\u001b[0;32m--> 488\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcompiled_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mruntime_args\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:667\u001b[0m, in \u001b[0;36mEffectTokensWrapper.post_compile.<locals>.inner_fn\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m    664\u001b[0m     args \u001b[38;5;241m=\u001b[39m [\u001b[38;5;241m*\u001b[39m([\u001b[38;5;28;01mNone\u001b[39;00m] \u001b[38;5;241m*\u001b[39m num_tokens), \u001b[38;5;241m*\u001b[39margs]\n\u001b[1;32m    665\u001b[0m     old_args\u001b[38;5;241m.\u001b[39mclear()\n\u001b[0;32m--> 667\u001b[0m outs \u001b[38;5;241m=\u001b[39m \u001b[43mcompiled_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    669\u001b[0m \u001b[38;5;66;03m# Inductor cache DummyModule can return None\u001b[39;00m\n\u001b[1;32m    670\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m outs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/_inductor/codecache.py:1464\u001b[0m, in \u001b[0;36mCompiledFxGraph.__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m   1462\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, inputs: List[Any]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m   1463\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcurrent_callable \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 1464\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_callable\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/venvs/well_venv/lib/python3.10/site-packages/torch/_inductor/utils.py:1954\u001b[0m, in \u001b[0;36malign_inputs_from_check_idxs.<locals>.run\u001b[0;34m(new_inputs)\u001b[0m\n\u001b[1;32m   1952\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrun\u001b[39m(new_inputs: List[InputType]):\n\u001b[1;32m   1953\u001b[0m     copy_misaligned_inputs(new_inputs, inputs_to_check)\n\u001b[0;32m-> 1954\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnew_inputs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/tmp/torchinductor_mmccabe/yb/cybcverktskeatr3di4oirkis2nfnrtnfjqfgidpojgvz2fvtylj.py:174\u001b[0m, in \u001b[0;36mcall\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m    172\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m primals_1\n\u001b[1;32m    173\u001b[0m \u001b[38;5;66;03m# Topologically Sorted Source Nodes: [lepe], Original ATen: [aten.convolution]\u001b[39;00m\n\u001b[0;32m--> 174\u001b[0m buf2 \u001b[38;5;241m=\u001b[39m \u001b[43mextern_kernels\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconvolution\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbuf1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprimals_2\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstride\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdilation\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtransposed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_padding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgroups\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m384\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m    175\u001b[0m assert_size_stride(buf2, (\u001b[38;5;241m32\u001b[39m, \u001b[38;5;241m384\u001b[39m, \u001b[38;5;241m8\u001b[39m, \u001b[38;5;241m256\u001b[39m), (\u001b[38;5;241m786432\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m98304\u001b[39m, \u001b[38;5;241m384\u001b[39m))\n\u001b[1;32m    176\u001b[0m buf3 \u001b[38;5;241m=\u001b[39m empty_strided_cuda((\u001b[38;5;241m32\u001b[39m, \u001b[38;5;241m6\u001b[39m, \u001b[38;5;241m2048\u001b[39m, \u001b[38;5;241m64\u001b[39m), (\u001b[38;5;241m786432\u001b[39m, \u001b[38;5;241m131072\u001b[39m, \u001b[38;5;241m64\u001b[39m, \u001b[38;5;241m1\u001b[39m), torch\u001b[38;5;241m.\u001b[39mfloat32)\n",
      "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 96.00 MiB. GPU 0 has a total capacity of 47.50 GiB of which 82.50 MiB is free. Including non-PyTorch memory, this process has 47.06 GiB memory in use. Of the allocated memory 46.09 GiB is allocated by PyTorch, and 667.75 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)"
     ]
    }
   ],
   "source": [
    "%time m2(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SWIN\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "os.environ[\"TORCHDYNAMO_VERBOSE\"] = \"1\"\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# from timm.models.registry import register_model\n",
    "\n",
    "from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb\n",
    "from torch.profiler import profile, record_function, ProfilerActivity\n",
    "\n",
    "from functools import lru_cache\n",
    "\n",
    "\n",
    "flex_attention = torch.compile(flex_attention, dynamic=False)\n",
    "\n",
    "def build_space_block(params):\n",
    "    #input_resolution = (\n",
    "    #    params.img_size[0] // params.patch_size[0],\n",
    "    #    params.img_size[1] // params.patch_size[0],\n",
    "    #)\n",
    "\n",
    "    if params.space_type == \"swin\":\n",
    "\n",
    "        return partial(\n",
    "            SwinTransformerBlock,\n",
    "            dim=params.embed_dim,\n",
    "            num_heads=params.num_heads,\n",
    "            #input_resolution=input_resolution,\n",
    "            window_size=params.window_size,\n",
    "        )\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "\n",
    "class Mlp(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_features,\n",
    "        hidden_features=None,\n",
    "        out_features=None,\n",
    "        act_layer=nn.GELU,\n",
    "        drop=0.0,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        out_features = out_features or in_features\n",
    "        hidden_features = hidden_features or in_features\n",
    "        self.fc1 = nn.Linear(in_features, hidden_features)\n",
    "        self.act = act_layer()\n",
    "        self.fc2 = nn.Linear(hidden_features, out_features)\n",
    "        self.drop = nn.Dropout(drop)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.drop(self.act(self.fc1(x)))\n",
    "        return self.drop(self.fc2(x))\n",
    "\n",
    "\n",
    "def window_partition(x, window_size):\n",
    "    B, H, W, C = x.shape\n",
    "    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n",
    "    #return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)\n",
    "    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n",
    "    return windows\n",
    "\n",
    "def window_partition_flex(x, window_size):\n",
    "    B, H, W, C = x.shape\n",
    "    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n",
    "    #return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)\n",
    "    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, window_size, window_size, C)\n",
    "    return windows\n",
    "\n",
    "\n",
    "def window_reverse(windows, window_size, H, W):\n",
    "    B = windows.shape[0] // (H * W // window_size // window_size)\n",
    "    x = windows.view(\n",
    "        B, H // window_size, W // window_size, window_size, window_size, -1\n",
    "    )\n",
    "    #return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)\n",
    "    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n",
    "    return x\n",
    "\n",
    "def window_reverse_flex(windows, window_size, H, W):\n",
    "    \n",
    "    B = windows.shape[0] #// (H * W // window_size // window_size)\n",
    "    x = windows.view(\n",
    "        B, H // window_size, W // window_size, window_size, window_size, -1\n",
    "    )\n",
    "    #return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)\n",
    "    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n",
    "    return x\n",
    "\n",
    "\n",
    "class WindowAttention(nn.Module):\n",
    "    r\"\"\"Window based multi-head self attention (W-MSA) module with relative position bias.\n",
    "    It supports both of shifted and non-shifted window.\n",
    "\n",
    "    Args:\n",
    "        dim (int): Number of input channels.\n",
    "        window_size (tuple[int]): The height and width of the window.\n",
    "        num_heads (int): Number of attention heads.\n",
    "        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n",
    "        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n",
    "        proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n",
    "        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dim,\n",
    "        window_size,\n",
    "        num_heads,\n",
    "        qkv_bias=True,\n",
    "        attn_drop=0.0,\n",
    "        proj_drop=0.0,\n",
    "        pretrained_window_size=[0, 0],\n",
    "    ):\n",
    "\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "        self.window_size = window_size  # Wh, Ww\n",
    "        self.pretrained_window_size = pretrained_window_size\n",
    "        self.num_heads = num_heads\n",
    "\n",
    "        self.logit_scale = nn.Parameter(\n",
    "            torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True\n",
    "        )\n",
    "\n",
    "        # mlp to generate continuous relative position bias\n",
    "        self.cpb_mlp = nn.Sequential(\n",
    "            nn.Linear(2, 512, bias=True),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(512, num_heads, bias=False),\n",
    "        )\n",
    "\n",
    "        # get relative_coords_table\n",
    "        relative_coords_h = torch.arange(\n",
    "            -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32\n",
    "        )\n",
    "        relative_coords_w = torch.arange(\n",
    "            -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32\n",
    "        )\n",
    "        relative_coords_table = (\n",
    "            torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))\n",
    "            .permute(1, 2, 0)\n",
    "            .contiguous()\n",
    "            .unsqueeze(0)\n",
    "        )  # 1, 2*Wh-1, 2*Ww-1, 2\n",
    "        if pretrained_window_size[0] > 0:\n",
    "            relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1\n",
    "            relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1\n",
    "        else:\n",
    "            relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1\n",
    "            relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1\n",
    "        relative_coords_table *= 8  # normalize to -8, 8\n",
    "        relative_coords_table = (\n",
    "            torch.sign(relative_coords_table)\n",
    "            * torch.log2(torch.abs(relative_coords_table) + 1.0)\n",
    "            / np.log2(8)\n",
    "        )\n",
    "\n",
    "        self.register_buffer(\"relative_coords_table\", relative_coords_table)\n",
    "        # print(\"RPE init\", relative_coords_table)\n",
    "\n",
    "        # get pair-wise relative position index for each token inside the window\n",
    "        coords_h = torch.arange(self.window_size[0])\n",
    "        coords_w = torch.arange(self.window_size[1])\n",
    "        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n",
    "        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n",
    "        relative_coords = (\n",
    "            coords_flatten[:, :, None] - coords_flatten[:, None, :]\n",
    "        )  # 2, Wh*Ww, Wh*Ww\n",
    "        relative_coords = relative_coords.permute(\n",
    "            1, 2, 0\n",
    "        ).contiguous()  # Wh*Ww, Wh*Ww, 2\n",
    "        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n",
    "        relative_coords[:, :, 1] += self.window_size[1] - 1\n",
    "        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n",
    "        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n",
    "        self.register_buffer(\"relative_position_index\", relative_position_index)\n",
    "\n",
    "        self.qkv = nn.Linear(dim, dim * 3, bias=False)\n",
    "        if qkv_bias:\n",
    "            self.q_bias = nn.Parameter(torch.zeros(dim))\n",
    "            self.v_bias = nn.Parameter(torch.zeros(dim))\n",
    "        else:\n",
    "            self.q_bias = None\n",
    "            self.v_bias = None\n",
    "        self.attn_drop = nn.Dropout(attn_drop)\n",
    "        self.proj = nn.Linear(dim, dim)\n",
    "        self.proj_drop = nn.Dropout(proj_drop)\n",
    "        self.softmax = nn.Softmax(dim=-1)\n",
    "\n",
    "    def forward(self, x, mask=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            x: input features with shape of (num_windows*B, N, C)\n",
    "            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n",
    "        \"\"\"\n",
    "        B_, N, C = x.shape\n",
    "        qkv_bias = None\n",
    "        if self.q_bias is not None:\n",
    "            qkv_bias = torch.cat(\n",
    "                (\n",
    "                    self.q_bias,\n",
    "                    torch.zeros_like(self.v_bias, requires_grad=False),\n",
    "                    self.v_bias,\n",
    "                )\n",
    "            )\n",
    "        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\n",
    "        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n",
    "        q, k, v = qkv[0], qkv[1], qkv[2]\n",
    "\n",
    "        # cosine attention\n",
    "        attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)\n",
    "        # logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()\n",
    "        max_value = torch.log(torch.tensor(1.0 / 0.01, device=self.logit_scale.device))\n",
    "        logit_scale = torch.clamp(self.logit_scale, max=max_value).exp()\n",
    "        attn = attn * logit_scale\n",
    "\n",
    "        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table)\n",
    "        # print(relative_position_bias_table.shape)\n",
    "        relative_position_bias_table = relative_position_bias_table.view(\n",
    "            -1, self.num_heads\n",
    "        )\n",
    "        # print(\"INITIAL RPE\", self.relative_coords_table.shape, relative_position_bias_table.shape, self.relative_position_index.shape)\n",
    "        relative_position_bias = relative_position_bias_table[\n",
    "            self.relative_position_index.view(-1)\n",
    "        ].view(\n",
    "            self.window_size[0] * self.window_size[1],\n",
    "            self.window_size[0] * self.window_size[1],\n",
    "            -1,\n",
    "        )  # Wh*Ww,Wh*Ww,nH\n",
    "        relative_position_bias = relative_position_bias.permute(\n",
    "            2, 0, 1\n",
    "        ).contiguous()  # nH, Wh*Ww, Wh*Ww\n",
    "        # print(\"FINAL RPE\", relative_position_bias.shape, attn.shape)\n",
    "        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)\n",
    "        attn = attn + relative_position_bias.unsqueeze(0)\n",
    "\n",
    "        if mask is not None:\n",
    "            nW = mask.shape[0]\n",
    "            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(\n",
    "                1\n",
    "            ).unsqueeze(0)\n",
    "            attn = attn.view(-1, self.num_heads, N, N)\n",
    "            attn = self.softmax(attn)\n",
    "        else:\n",
    "            attn = self.softmax(attn)\n",
    "\n",
    "        attn = self.attn_drop(attn)\n",
    "\n",
    "        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n",
    "        x = self.proj(x)\n",
    "        x = self.proj_drop(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class SwinTransformerBlock(nn.Module):\n",
    "    r\"\"\"Swin Transformer Block.\n",
    "\n",
    "    Args:\n",
    "        dim (int): Number of input channels.\n",
    "        num_heads (int): Number of attention heads.\n",
    "        input_resolution (tuple[int]): Input resulotion.\n",
    "        window_size (int): Window size.\n",
    "        shift_size (int): Shift size for SW-MSA.\n",
    "        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n",
    "        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n",
    "        drop (float, optional): Dropout rate. Default: 0.0\n",
    "        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n",
    "        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n",
    "        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n",
    "        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n",
    "        pretrained_window_size (int): Window size in pre-training.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dim,\n",
    "        num_heads,\n",
    "        #input_resolution=224,\n",
    "        window_size=7,\n",
    "        shift_size=0,\n",
    "        mlp_ratio=4.0,\n",
    "        qkv_bias=True,\n",
    "        drop=0.0,\n",
    "        attn_drop=0.0,\n",
    "        drop_path=0.0,\n",
    "        act_layer=nn.GELU,\n",
    "        norm_layer=nn.LayerNorm,\n",
    "        pretrained_window_size=0,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "        #self.input_resolution = to_2tuple(input_resolution)\n",
    "        self.num_heads = num_heads\n",
    "        self.window_size = window_size\n",
    "        self.shift_size = shift_size\n",
    "        self.mlp_ratio = mlp_ratio\n",
    "        self.norm1 = norm_layer(dim)\n",
    "        self.attn = WindowAttention(\n",
    "            dim,\n",
    "            window_size=to_2tuple(self.window_size),\n",
    "            num_heads=num_heads,\n",
    "            qkv_bias=qkv_bias,\n",
    "            attn_drop=attn_drop,\n",
    "            proj_drop=drop,\n",
    "            pretrained_window_size=to_2tuple(pretrained_window_size),\n",
    "        )\n",
    "\n",
    "        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n",
    "        self.norm2 = norm_layer(dim)\n",
    "        mlp_hidden_dim = int(dim * mlp_ratio)\n",
    "        self.mlp = Mlp(\n",
    "            in_features=dim,\n",
    "            hidden_features=mlp_hidden_dim,\n",
    "            act_layer=act_layer,\n",
    "            drop=drop,\n",
    "        )\n",
    "\n",
    "    def create_attn_mask(self, H, W, window_size, shift_size, bcs, device):\n",
    "\n",
    "        img_mask = torch.zeros((1, H, W, 1), device=device)  # 1 H W 1\n",
    "\n",
    "        if shift_size > 0:\n",
    "            h_slices = (\n",
    "                slice(0, -window_size),\n",
    "                slice(-window_size, -shift_size),\n",
    "                slice(-shift_size, None),\n",
    "            )\n",
    "            w_slices = (\n",
    "                slice(0, -window_size),\n",
    "                slice(-window_size, -shift_size),\n",
    "                slice(-shift_size, None),\n",
    "            )\n",
    "            cnt = 0\n",
    "            if (bcs[0, 0] == 0) and (bcs[0, 1] == 0):\n",
    "                for h in h_slices:\n",
    "                    for w in w_slices:\n",
    "                        img_mask[:, h, w, :] = cnt\n",
    "                        cnt += 1\n",
    "            elif (bcs[0, 0] == 1) and (bcs[0, 1] == 0):\n",
    "                for w in w_slices:\n",
    "                    img_mask[:, :, w, :] = cnt\n",
    "                    cnt += 1\n",
    "            elif (bcs[0, 0] == 0) and (bcs[0, 1] == 1):\n",
    "                for h in h_slices:\n",
    "                    img_mask[:, h, :, :] = cnt\n",
    "                    cnt += 1\n",
    "            else:\n",
    "                pass\n",
    "\n",
    "            mask_windows = window_partition(\n",
    "                img_mask, window_size\n",
    "            )  # nW, window_size, window_size, 1\n",
    "            mask_windows = mask_windows.reshape(-1, window_size * window_size)\n",
    "            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n",
    "            attn_mask = attn_mask.masked_fill(\n",
    "                attn_mask != 0, float(-100.0)\n",
    "            ).masked_fill(attn_mask == 0, float(0.0))\n",
    "        else:\n",
    "            attn_mask = None\n",
    "\n",
    "        return attn_mask\n",
    "\n",
    "    def forward(self, x, bcs, resolution):\n",
    "\n",
    "        #H, W = self.input_resolution\n",
    "        H, W = resolution\n",
    "        B, L, C = x.shape\n",
    "\n",
    "        assert L == H * W, \"input feature has wrong size\"\n",
    "        device = x.device\n",
    "        attn_mask = self.create_attn_mask(\n",
    "            H, W, self.window_size, self.shift_size, bcs, device\n",
    "        )\n",
    "\n",
    "        shortcut = x\n",
    "        # x = rearrange(x, 'b (h w) c -> b h w c', h=H)\n",
    "        x = x.view(B, H, W, C)\n",
    "\n",
    "        # cyclic shift\n",
    "        if self.shift_size > 0:\n",
    "            shifted_x = torch.roll(\n",
    "                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)\n",
    "            )\n",
    "        else:\n",
    "            shifted_x = x\n",
    "\n",
    "        # partition windows\n",
    "        x_windows = window_partition(\n",
    "            shifted_x, self.window_size\n",
    "        )  # nW*B, window_size, window_size, C\n",
    "        x_windows = x_windows.view(\n",
    "            -1, self.window_size * self.window_size, C\n",
    "        )  # nW*B, window_size*window_size, C\n",
    "\n",
    "        # W-MSA/SW-MSA\n",
    "        attn_windows = self.attn(\n",
    "            x_windows, mask=attn_mask\n",
    "        )  # nW*B, window_size*window_size, C\n",
    "\n",
    "        # merge windows\n",
    "        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n",
    "        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n",
    "\n",
    "        # reverse cyclic shift\n",
    "        if self.shift_size > 0:\n",
    "            x = torch.roll(\n",
    "                shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)\n",
    "            )\n",
    "        else:\n",
    "            x = shifted_x\n",
    "\n",
    "        # x = rearrange(x, 'b h w c -> b (h w) c')\n",
    "        x = x.view(B, H * W, C)\n",
    "        x = shortcut + self.drop_path(self.norm1(x))\n",
    "\n",
    "        # FFN\n",
    "        x = x + self.drop_path(self.norm2(self.mlp(x)))\n",
    "\n",
    "        return x\n",
    "    \n",
    "# def build_swin_mask_and_window_block()\n",
    "LN_100 = np.log(100)\n",
    "\n",
    "# def create_swin_score_mod(H, W):\n",
    "#     def swin_score_mod(score, b, h, q_idx, kv_idx):\n",
    "#         qx, qy = q_idx % W, q_idx // W\n",
    "#         kvx, kvy = kv_idx % W, kv_idx // W\n",
    "#         # scale = torch.clamp(logit_scale[h], max=LN_100).exp()\n",
    "#         return score + 16 * torch.sigmoid(relative_position_scores[kvy-qy + H//2, kvx-qx+W//2, h])\n",
    "#     return swin_score_mod\n",
    "\n",
    "@lru_cache\n",
    "def create_block_mask_cached(score_mod, B, H, M, N, device=\"cuda\"):\n",
    "    block_mask = create_block_mask(score_mod, B, H, M, N, device=device)\n",
    "    return block_mask\n",
    "\n",
    "\n",
    "def create_swindow_mask(H, W, window_size, shift_size, bcs, device=\"cuda\"):\n",
    "    # @lru_cache()\n",
    "    def swindow_mask(b, h, q_idx, kv_idx):\n",
    "        mask_size = window_size**2\n",
    "\n",
    "        q_window = q_idx // mask_size\n",
    "        kv_window = kv_idx // mask_size\n",
    "        # return q_window == kv_window\n",
    "        qlocal_idx = q_idx % mask_size\n",
    "        kvlocal_idx = kv_idx % mask_size\n",
    "        \n",
    "        windows_per_row = W // window_size\n",
    "        windows_per_col = H // window_size\n",
    "        # Get 2D block coordinates\n",
    "        q_block_x = q_window % windows_per_row\n",
    "        q_block_y = q_window // windows_per_row\n",
    "        kv_block_x = kv_window % windows_per_row\n",
    "        kv_block_y = kv_window // windows_per_row\n",
    "        # Now get the local coordinates within the block\n",
    "        qx = q_block_x * window_size + qlocal_idx % window_size\n",
    "        qy = q_block_y * window_size + qlocal_idx // window_size\n",
    "        kvx = kv_block_x * window_size + kvlocal_idx % window_size\n",
    "        kvy = kv_block_y * window_size + kvlocal_idx // window_size\n",
    "\n",
    "        # print(q_idx, 'block_id:', q_window, ', block x:', q_block_x, ', x:', qx, ', block y:', q_block_y, ', y:', qy)\n",
    "        # qx, qy = q_idx % W, q_idx // W\n",
    "        # kvx, kvy = kv_idx % W, kv_idx // W\n",
    "        x_mask = qx // window_size == kvx // window_size\n",
    "        y_mask = qy // window_size == kvy // window_size\n",
    "        # mask = q_idx // mask_size == kv_idx // mask_size\n",
    "        if shift_size > 0:\n",
    "            if bcs[0, 0] == 0:\n",
    "                y_mask &= kvx < H - shift_size\n",
    "                y_mask &= qx < H - shift_size\n",
    "            elif bcs[0, 1] == 0:\n",
    "                x_mask &= kvx < W - shift_size\n",
    "                y_mask &= qx < H - shift_size\n",
    "        return x_mask & y_mask\n",
    "    return swindow_mask\n",
    "\n",
    "class WindowAttentionFlex(nn.Module):\n",
    "    r\"\"\"Window based multi-head self attention (W-MSA) module with relative position bias.\n",
    "    It supports both of shifted and non-shifted window.\n",
    "\n",
    "    Args:\n",
    "        dim (int): Number of input channels.\n",
    "        window_size (tuple[int]): The height and width of the window.\n",
    "        num_heads (int): Number of attention heads.\n",
    "        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n",
    "        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n",
    "        proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n",
    "        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dim,\n",
    "        window_size,\n",
    "        num_heads,\n",
    "        qkv_bias=True,\n",
    "        attn_drop=0.0,\n",
    "        proj_drop=0.0,\n",
    "        pretrained_window_size=[0, 0],\n",
    "    ):\n",
    "\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "        self.window_size = window_size  # Wh, Ww\n",
    "        self.pretrained_window_size = pretrained_window_size\n",
    "        self.num_heads = num_heads\n",
    "\n",
    "        self.logit_scale = nn.Parameter(\n",
    "            torch.log(10 * torch.ones((num_heads, 1, 1, 1))), requires_grad=True\n",
    "        )\n",
    "        # swindow_mask = create_swindow_mask(64, 64, 8, 0, torch.tensor([0, 0]))\n",
    "        # self.swin_mask = create_block_mask(swindow_mask, None, None, 256**2, 256**2, BLOCK_SIZE = 256,\n",
    "        #                                    device=\"cuda\")\n",
    "        # mlp to generate continuous relative position bias\n",
    "        self.cpb_mlp = nn.Sequential(\n",
    "            nn.Linear(2, 512, bias=True),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(512, num_heads, bias=False),\n",
    "        )\n",
    "        # get relative_coords_table\n",
    "        # relative_coords_h = torch.arange(\n",
    "        #     -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32\n",
    "        # )\n",
    "        # relative_coords_w = torch.arange(\n",
    "        #     -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32\n",
    "        # )\n",
    "        # relative_coords_table = (\n",
    "        #     torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))\n",
    "        #     .permute(1, 2, 0)\n",
    "        #     .contiguous()\n",
    "        #     .unsqueeze(0)\n",
    "        # )  # 1, 2*Wh-1, 2*Ww-1, 2\n",
    "        # if pretrained_window_size[0] > 0:\n",
    "        #     relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1\n",
    "        #     relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1\n",
    "        # else:\n",
    "        #     relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1\n",
    "        #     relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1\n",
    "        # relative_coords_table *= 8  # normalize to -8, 8\n",
    "        # relative_coords_table = (\n",
    "        #     torch.sign(relative_coords_table)\n",
    "        #     * torch.log2(torch.abs(relative_coords_table) + 1.0)\n",
    "        #     / np.log2(8)\n",
    "        # )\n",
    "\n",
    "        # self.register_buffer(\"relative_coords_table\", relative_coords_table)\n",
    "\n",
    "        # # get pair-wise relative position index for each token inside the window\n",
    "        # coords_h = torch.arange(self.window_size[0])\n",
    "        # coords_w = torch.arange(self.window_size[1])\n",
    "        # coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n",
    "        # coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n",
    "        # relative_coords = (\n",
    "        #     coords_flatten[:, :, None] - coords_flatten[:, None, :]\n",
    "        # )  # 2, Wh*Ww, Wh*Ww\n",
    "        # relative_coords = relative_coords.permute(\n",
    "        #     1, 2, 0\n",
    "        # ).contiguous()  # Wh*Ww, Wh*Ww, 2\n",
    "        # relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n",
    "        # relative_coords[:, :, 1] += self.window_size[1] - 1\n",
    "        # relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n",
    "        # relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n",
    "        # self.register_buffer(\"relative_position_index\", relative_position_index)\n",
    "        self.rope = RotaryEmbedding(dim // (num_heads * len(self.window_size)))\n",
    "        self.qkv = nn.Linear(dim, dim * 3, bias=False)\n",
    "        if qkv_bias:\n",
    "            self.q_bias = nn.Parameter(torch.zeros(dim))\n",
    "            self.v_bias = nn.Parameter(torch.zeros(dim))\n",
    "        else:\n",
    "            self.q_bias = None\n",
    "            self.v_bias = None\n",
    "        self.attn_drop = nn.Dropout(attn_drop)\n",
    "        self.proj = nn.Linear(dim, dim)\n",
    "        self.proj_drop = nn.Dropout(proj_drop)\n",
    "        self.softmax = nn.Softmax(dim=-1)\n",
    "\n",
    "    def forward(self, x, mask=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            x: input features with shape of (num_windows*B, N, C)\n",
    "            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n",
    "        \"\"\"\n",
    "        # print(\"starting flex\", x.shape)\n",
    "        B_, Nw, N, C = x.shape\n",
    "        qkv_bias = None\n",
    "        if self.q_bias is not None:\n",
    "            qkv_bias = torch.cat(\n",
    "                (\n",
    "                    self.q_bias,\n",
    "                    torch.zeros_like(self.v_bias, requires_grad=False),\n",
    "                    self.v_bias,\n",
    "                )\n",
    "            )\n",
    "        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\n",
    "        qkv = qkv.reshape(B_, Nw, N, 3, self.num_heads, -1).permute(3, 0, 4, 1, 2, 5)\n",
    "        q, k, v = qkv[0], qkv[1], qkv[2]\n",
    "\n",
    "\n",
    "        # cosine attention\n",
    "        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)\n",
    "        q = torch.clamp(self.logit_scale, LN_100).exp() * q\n",
    "        \n",
    "        freqs = self.rope.get_axial_freqs(*self.window_size)\n",
    "\n",
    "        q =  q.view(B_, self.num_heads, Nw, *self.window_size, -1)\n",
    "        k = k.view(B_, self.num_heads, Nw, *self.window_size, -1)\n",
    "        q, k = apply_rotary_emb(freqs, q), apply_rotary_emb(freqs, k)\n",
    "        q, k = q.view(B_, self.num_heads,  Nw*N, -1), k.view(B_, self.num_heads,  Nw*N, -1)\n",
    "        v = v.view(B_, self.num_heads, Nw*N, -1)\n",
    "\n",
    "\n",
    "        # attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)\n",
    "        # logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()\n",
    "        # max_value = torch.log(torch.tensor(1.0 / 0.01, device=self.logit_scale.device))\n",
    "        # logit_scale = torch.clamp(self.logit_scale, max=max_value).exp()\n",
    "        # attn = attn * logit_scale\n",
    "        \n",
    "\n",
    "        # relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).squeeze()\n",
    "        # score_mod = create_swin_score_mod(relative_position_bias_table, self.logit_scale, self.window_size[0], self.window_size[1])\n",
    "        \n",
    "        #.view(\n",
    "        #    -1, self.num_heads\n",
    "        #)\n",
    "        # print('RPE stats', relative_position_bias_table.shape, self.relative_coords_table.shape)\n",
    "        # relative_position_bias = relative_position_bias_table[\n",
    "        #     self.relative_position_index.view(-1)\n",
    "        # ].view(\n",
    "        #     self.window_size[0] * self.window_size[1],\n",
    "        #     self.window_size[0] * self.window_size[1],\n",
    "        #     -1,\n",
    "        # )  # Wh*Ww,Wh*Ww,nH\n",
    "        # relative_position_bias = relative_position_bias.permute(\n",
    "        #     2, 0, 1\n",
    "        # ).contiguous()  # nH, Wh*Ww, Wh*Ww\n",
    "        # relative_position_bias = 16 * torch.sigmoid(relative_position_bias)\n",
    "        # print(\"Final RPE\", relative_position_bias.shape, attn.shape)\n",
    "        # attn = attn + relative_position_bias.unsqueeze(0)\n",
    "\n",
    "        # if mask is not None:\n",
    "        #     nW = mask.shape[0]\n",
    "        #     attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(\n",
    "        #         1\n",
    "        #     ).unsqueeze(0)\n",
    "        #     attn = attn.view(-1, self.num_heads, N, N)\n",
    "        #     attn = self.softmax(attn)\n",
    "        # else:\n",
    "        #     attn = self.softmax(attn)\n",
    "\n",
    "        # attn = self.attn_drop(attn)\n",
    "        # print shapes of all major tensors\n",
    "        # print(\"q\", q.shape, \"k\", k.shape, \"v\", v.shape)\n",
    "        x = flex_attention(q, k, v, block_mask=mask).reshape(B_, Nw, N, C)\n",
    "        # x = F.scaled_dot_product_attention(q, k, v).transpose(-1, -2).reshape(B_, Nw, N, C)\n",
    "\n",
    "        # x = (attn @ v)\n",
    "        x = self.proj(x)\n",
    "        x = self.proj_drop(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class SwinTransformerBlockFlex(nn.Module):\n",
    "    r\"\"\"Swin Transformer Block.\n",
    "\n",
    "    Args:\n",
    "        dim (int): Number of input channels.\n",
    "        num_heads (int): Number of attention heads.\n",
    "        input_resolution (tuple[int]): Input resulotion.\n",
    "        window_size (int): Window size.\n",
    "        shift_size (int): Shift size for SW-MSA.\n",
    "        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n",
    "        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n",
    "        drop (float, optional): Dropout rate. Default: 0.0\n",
    "        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n",
    "        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n",
    "        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n",
    "        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n",
    "        pretrained_window_size (int): Window size in pre-training.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dim,\n",
    "        num_heads,\n",
    "        #input_resolution=224,\n",
    "        window_size=7,\n",
    "        shift_size=0,\n",
    "        mlp_ratio=4.0,\n",
    "        qkv_bias=True,\n",
    "        drop=0.0,\n",
    "        attn_drop=0.0,\n",
    "        drop_path=0.0,\n",
    "        act_layer=nn.GELU,\n",
    "        norm_layer=nn.LayerNorm,\n",
    "        pretrained_window_size=0,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "        #self.input_resolution = to_2tuple(input_resolution)\n",
    "        self.num_heads = num_heads\n",
    "        self.window_size = window_size\n",
    "        self.shift_size = shift_size\n",
    "        self.mlp_ratio = mlp_ratio\n",
    "        self.norm1 = norm_layer(dim)\n",
    "        self.attn = WindowAttentionFlex(\n",
    "            dim,\n",
    "            window_size=to_2tuple(self.window_size),\n",
    "            num_heads=num_heads,\n",
    "            qkv_bias=qkv_bias,\n",
    "            attn_drop=attn_drop,\n",
    "            proj_drop=drop,\n",
    "            pretrained_window_size=to_2tuple(pretrained_window_size),\n",
    "        )\n",
    "\n",
    "        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n",
    "        self.norm2 = norm_layer(dim)\n",
    "        mlp_hidden_dim = int(dim * mlp_ratio)\n",
    "        self.mlp = Mlp(\n",
    "            in_features=dim,\n",
    "            hidden_features=mlp_hidden_dim,\n",
    "            act_layer=act_layer,\n",
    "            drop=drop,\n",
    "        )\n",
    "\n",
    "    def create_attn_mask(self, H, W, window_size, shift_size, bcs, device):\n",
    "\n",
    "        img_mask = torch.zeros((1, H, W, 1), device=device)  # 1 H W 1\n",
    "\n",
    "        if shift_size > 0:\n",
    "            h_slices = (\n",
    "                slice(0, -window_size),\n",
    "                slice(-window_size, -shift_size),\n",
    "                slice(-shift_size, None),\n",
    "            )\n",
    "            w_slices = (\n",
    "                slice(0, -window_size),\n",
    "                slice(-window_size, -shift_size),\n",
    "                slice(-shift_size, None),\n",
    "            )\n",
    "            cnt = 0\n",
    "            if (bcs[0, 0] == 0) and (bcs[0, 1] == 0):\n",
    "                for h in h_slices:\n",
    "                    for w in w_slices:\n",
    "                        img_mask[:, h, w, :] = cnt\n",
    "                        cnt += 1\n",
    "            elif (bcs[0, 0] == 1) and (bcs[0, 1] == 0):\n",
    "                for w in w_slices:\n",
    "                    img_mask[:, :, w, :] = cnt\n",
    "                    cnt += 1\n",
    "            elif (bcs[0, 0] == 0) and (bcs[0, 1] == 1):\n",
    "                for h in h_slices:\n",
    "                    img_mask[:, h, :, :] = cnt\n",
    "                    cnt += 1\n",
    "            else:\n",
    "                pass\n",
    "\n",
    "            mask_windows = window_partition(\n",
    "                img_mask, window_size\n",
    "            )  # nW, window_size, window_size, 1\n",
    "            mask_windows = mask_windows.reshape(-1, window_size * window_size)\n",
    "            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n",
    "            attn_mask = attn_mask.masked_fill(\n",
    "                attn_mask != 0, float(-100.0)\n",
    "            ).masked_fill(attn_mask == 0, float(0.0))\n",
    "        else:\n",
    "            attn_mask = None\n",
    "\n",
    "        return attn_mask\n",
    "\n",
    "    def forward(self, x, bcs, resolution, mask=None):\n",
    "\n",
    "        #H, W = self.input_resolution\n",
    "        H, W = resolution\n",
    "        B, L, C = x.shape\n",
    "\n",
    "        assert L == H * W, \"input feature has wrong size\"\n",
    "        device = x.device\n",
    "        # attn_mask = self.create_attn_mask(\n",
    "        #     H, W, self.window_size, self.shift_size, bcs, device\n",
    "        # )\n",
    "        # swin_mask = create_swindow_mask(H, W, self.window_size, self.shift_size, bcs, device)\n",
    "        # swin_mask = create_block_mask(swin_mask, None, None, H*W, H*W, BLOCK_SIZE = self.window_size**2,\n",
    "        #                      device=device)\n",
    "        shortcut = x\n",
    "        # x = rearrange(x, 'b (h w) c -> b h w c', h=H)\n",
    "        x = x.view(B, H, W, C)\n",
    "\n",
    "        # cyclic shift\n",
    "        if self.shift_size > 0:\n",
    "            shifted_x = torch.roll(\n",
    "                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)\n",
    "            )\n",
    "        else:\n",
    "            shifted_x = x\n",
    "\n",
    "        # partition windows\n",
    "        x_windows = window_partition_flex(\n",
    "            shifted_x, self.window_size\n",
    "        )  # B, nW, window_size, window_size, C\n",
    "        x_windows = x_windows.view(\n",
    "            B, -1, self.window_size * self.window_size, C\n",
    "        )  # B, nW, window_size*window_size, C\n",
    "\n",
    "        # W-MSA/SW-MSA\n",
    "        attn_windows = self.attn(\n",
    "            x_windows, mask=mask\n",
    "        )  # B, nW, window_size*window_size, C\n",
    "\n",
    "        # merge windows\n",
    "        attn_windows = attn_windows.view(B, -1, self.window_size, self.window_size, C)\n",
    "        shifted_x = window_reverse_flex(attn_windows, self.window_size, H, W)  # B H' W' C\n",
    "\n",
    "        # reverse cyclic shift\n",
    "        if self.shift_size > 0:\n",
    "            x = torch.roll(\n",
    "                shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)\n",
    "            )\n",
    "        else:\n",
    "            x = shifted_x\n",
    "\n",
    "        # x = rearrange(x, 'b h w c -> b (h w) c')\n",
    "        x = x.view(B, H * W, C)\n",
    "        x = shortcut + self.drop_path(self.norm1(x))\n",
    "\n",
    "        # FFN\n",
    "        x = x + self.drop_path(self.norm2(self.mlp(x)))\n",
    "\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/home/mmccabe/venvs/well_venv/lib/python3.10/site-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "UserFunctionVariable()\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(torch.Size([1, 4096, 768]), torch.Size([1, 4096, 768]))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "resolution = 64\n",
    "window_size = 8\n",
    "m = SwinTransformerBlockFlex(768, 12, window_size).cuda()\n",
    "m2 = SwinTransformerBlock(768, 12, window_size).cuda()\n",
    "x = torch.randn(1, resolution**2, 768).cuda()\n",
    "BCs = torch.tensor([[0, 0]]).cuda()\n",
    "swindow_mask = create_swindow_mask(resolution, resolution, window_size, 0, BCs, device=\"cuda\")\n",
    "\n",
    "mask = create_block_mask(swindow_mask, None, None, resolution**2, resolution**2, BLOCK_SIZE = window_size**2,\n",
    "                                           device=\"cuda\", _compile=True)\n",
    "\n",
    "\n",
    "m(x, BCs, (resolution, resolution), mask).shape, m2(x, BCs, (resolution, resolution)).shape\n",
    "\n",
    "# m(x, BCs, (resolution, resolution))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BlockMask(shape=(1, 1, 8192, 8192), sparsity=98.44%, \n",
      "(0, 0)\n",
      "░░                              \n",
      "  ░░                            \n",
      "    ░░                          \n",
      "      ░░                        \n",
      "        ░░                      \n",
      "          ░░                    \n",
      "            ░░                  \n",
      "              ░░                \n",
      "                ░░              \n",
      "                  ░░            \n",
      "                    ░░          \n",
      "                      ░░        \n",
      "                        ░░      \n",
      "                          ░░    \n",
      "                            ░░  \n",
      "                              ░░\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "!export TORCHDYNAMO_VERBOSE=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 11.1 ms, sys: 1.04 ms, total: 12.1 ms\n",
      "Wall time: 11.2 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.8911, -2.0237,  0.0874,  ...,  1.8481,  1.9463, -1.6693],\n",
       "         [ 2.3655, -1.3664,  2.8995,  ...,  2.5285,  0.8015, -1.1482],\n",
       "         [ 1.0134, -1.1104, -1.7366,  ..., -0.5531,  5.1501,  0.4722],\n",
       "         ...,\n",
       "         [ 2.4366, -0.6986,  0.0683,  ...,  1.6818,  1.6937, -2.3127],\n",
       "         [ 1.4351,  2.2723,  1.6191,  ..., -1.1244,  2.6846, -1.7992],\n",
       "         [ 2.8857,  1.1299, -0.7580,  ..., -1.2703,  1.9062, -0.0733]]],\n",
       "       device='cuda:0', grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%time m(x, BCs, (resolution, resolution))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 9.78 ms, sys: 1.21 ms, total: 11 ms\n",
      "Wall time: 10.1 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.7518, -0.1979, -0.2501,  ...,  2.0128, -0.7517, -2.0941],\n",
       "         [-0.0382,  1.1054, -0.1532,  ...,  0.9081, -1.3076, -0.3362],\n",
       "         [ 0.4059,  2.5880, -1.6139,  ..., -1.6082,  0.2722,  0.8511],\n",
       "         ...,\n",
       "         [ 1.7862,  2.9913,  0.8791,  ...,  2.2209, -2.2367, -1.7994],\n",
       "         [-1.8505,  3.8987, -0.4162,  ..., -0.6142, -1.0749, -2.2390],\n",
       "         [-0.6067,  3.0438,  0.8687,  ..., -0.4416, -0.7304, -1.5081]]],\n",
       "       device='cuda:0', grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time \n",
    "\n",
    "m2(x, BCs, (resolution, resolution))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "CUDA error: an illegal memory access was encountered\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[18], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresolution\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m768\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcuda\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m      2\u001b[0m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mreset_peak_memory_stats()\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m profile(activities\u001b[38;5;241m=\u001b[39m[\n\u001b[1;32m      4\u001b[0m         ProfilerActivity\u001b[38;5;241m.\u001b[39mCPU, ProfilerActivity\u001b[38;5;241m.\u001b[39mCUDA], record_shapes\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, profile_memory\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m prof:\n",
      "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: an illegal memory access was encountered\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(1, resolution**2, 768).cuda()\n",
    "torch.cuda.reset_peak_memory_stats()\n",
    "with profile(activities=[\n",
    "        ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof:\n",
    "    with record_function(\"model_inference\"):\n",
    "        x = m(x, BCs, (resolution, resolution), mask)\n",
    "        x = m(x, BCs, (resolution, resolution), mask)\n",
    "        x = m(x, BCs, (resolution, resolution), mask)\n",
    "        y = m(x, BCs, (resolution, resolution), mask)\n",
    "        z = y.sum()\n",
    "        z.backward()\n",
    "        mem = torch.cuda.max_memory_allocated() / 1024**3\n",
    "\n",
    "print('Max mem', mem, 'GB')\n",
    "print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))\n",
    "m.zero_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "AUTOTUNE mm(961x12, 12x512)\n",
      "  triton_mm_973 0.0075 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=32, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8\n",
      "  triton_mm_977 0.0075 ms 99.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8\n",
      "  triton_mm_971 0.0075 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=1, num_warps=2\n",
      "  triton_mm_972 0.0076 ms 97.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4\n",
      "  triton_mm_974 0.0078 ms 95.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8\n",
      "  triton_mm_975 0.0078 ms 95.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4\n",
      "  triton_mm_976 0.0078 ms 95.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4\n",
      "  triton_mm_978 0.0078 ms 95.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4\n",
      "  triton_mm_979 0.0080 ms 92.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_981 0.0080 ms 92.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4\n",
      "SingleProcess AUTOTUNE benchmarking takes 1.7913 seconds and 0.0012 seconds precompiling\n",
      "E0827 13:17:18.495000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/jy/cjyxef7lhippi4wkfxyrmpgyq7syazzr75tzea2yeanlqzdrhx6q.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4)\n",
      "E0827 13:17:18.497000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/id/ciddsid3k4s7l4cbsd4gtccampcwbtvoutzfklyk3t6rkngdah5l.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8)\n",
      "E0827 13:17:18.497000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/bw/cbwio7bfs22xugiznqcltyv44x3c5kwc2iq6e3gltimbjkl6slna.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "E0827 13:17:18.498000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/pp/cppqcqej2siimx24jyoys2rsc5nc55imwggt3dvuxvvgnvc36sqd.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "W0827 13:17:19.202000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0827 13:17:19.595000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0827 13:17:19.990000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0827 13:17:20.521000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "AUTOTUNE bmm(3072x256x256, 3072x256x64)\n",
      "  triton_bmm_926 1.7371 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_bmm_922 1.7391 ms 99.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4\n",
      "  triton_bmm_930 1.7447 ms 99.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_bmm_923 1.7508 ms 99.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4\n",
      "  triton_bmm_927 1.7714 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8\n",
      "  triton_bmm_932 1.7944 ms 96.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8\n",
      "  triton_bmm_931 1.8046 ms 96.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8\n",
      "  triton_bmm_933 1.9247 ms 90.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_bmm_917 2.0164 ms 86.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=1, num_warps=2\n",
      "  triton_bmm_924 2.0279 ms 85.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8\n",
      "SingleProcess AUTOTUNE benchmarking takes 2.0227 seconds and 0.0069 seconds precompiling\n",
      "AUTOTUNE bmm(3072x256x64, 3072x64x256)\n",
      "  bmm 2.1030 ms 100.0% \n",
      "  triton_bmm_948 2.5389 ms 82.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_bmm_944 2.5424 ms 82.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_bmm_946 2.5775 ms 81.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_bmm_947 2.5778 ms 81.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4\n",
      "  triton_bmm_950 2.6528 ms 79.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8\n",
      "  triton_bmm_951 2.7452 ms 76.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_bmm_953 2.7855 ms 75.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8\n",
      "  triton_bmm_943 3.4857 ms 60.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4\n",
      "  triton_bmm_941 3.4865 ms 60.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4\n",
      "SingleProcess AUTOTUNE benchmarking takes 2.9984 seconds and 0.0025 seconds precompiling\n",
      "E0827 13:17:23.526000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 221184, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/qa/cqaa3bmgy6keuaxzqhxzpc6c6yf23gv4jwknmtlkfafz64qbjuxp.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4)\n",
      "E0827 13:17:23.527000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 163840, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/gs/cgsi2q76mqnty5jyojkmtdjtdasqqsksa3qw2dqacv36xc5z5l4f.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "E0827 13:17:23.528000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 147456, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/ly/cly57yoyrehstmslskvgoudve4ihme4yyeueifqfyjoqmvz5w6rh.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8)\n",
      "W0827 13:17:24.488000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 163840, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0827 13:17:24.808000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 221184, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0827 13:17:25.128000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 147456, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "AUTOTUNE mm(12x961, 961x512)\n",
      "  mm 0.0150 ms 100.0% \n",
      "  triton_mm_957 0.0180 ms 83.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=2\n",
      "  triton_mm_958 0.0204 ms 73.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=2\n",
      "  triton_mm_961 0.0220 ms 68.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_955 0.0224 ms 67.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=2\n",
      "  triton_mm_965 0.0252 ms 59.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_968 0.0265 ms 56.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4\n",
      "  triton_mm_967 0.0266 ms 56.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_956 0.0271 ms 55.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4\n",
      "  triton_mm_963 0.0302 ms 49.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4\n",
      "SingleProcess AUTOTUNE benchmarking takes 1.6003 seconds and 0.0051 seconds precompiling\n",
      "E0827 13:17:25.132000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 163840, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/j2/cj2wnvw4ubrajfwacrbe6zjhyajj7id4i5goy5j7k4qg67ywhcxs.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "E0827 13:17:25.133000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 147456, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/bh/cbhx54hlbvi6gdxttppfaytozuvegybgytqdiyb37orj67s7xxag.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8)\n",
      "E0827 13:17:25.134000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 122880, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/ls/clssc6qgpp2nsyzbbzbp3r5hpr63nqf5ryempqy4stnvg2gsaw6l.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4)\n",
      "W0827 13:17:25.673000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 163840, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0827 13:17:26.204000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 122880, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0827 13:17:26.632000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 147456, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "AUTOTUNE mm(512x961, 961x2)\n",
      "  triton_mm_989 0.0185 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=2\n",
      "  mm 0.0226 ms 81.9% \n",
      "  triton_mm_994 0.0239 ms 77.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_998 0.0262 ms 70.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_1001 0.0279 ms 66.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_995 0.0326 ms 56.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_mm_996 0.0330 ms 56.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4\n",
      "  triton_mm_990 0.0335 ms 55.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4\n",
      "  triton_mm_988 0.0364 ms 50.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=2\n",
      "  triton_mm_999 0.0385 ms 48.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8\n",
      "SingleProcess AUTOTUNE benchmarking takes 1.4987 seconds and 0.0048 seconds precompiling\n",
      "E0827 13:17:26.636000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/cg/ccg4idhit34nw23e7jujg4ciriktpdhyrlgxsbtq32553a424yqw.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "E0827 13:17:26.637000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/4j/c4jkvxhz3ownrb3mpk3hekzkzpogeqoxuc3f3tvbz7dgk7nks4tx.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8)\n",
      "E0827 13:17:26.638000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/nt/cnt6xn5sqn6oqdlbcojy3irfmqhj4vcrjgyrygdsu5f4tmhqzwdf.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "E0827 13:17:26.639000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 294912, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/fh/cfhdwicglg2caxr3hv4d7wxjz42pz72oenrnirq2zowsazgpe33h.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4)\n",
      "W0827 13:17:27.350000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0827 13:17:27.748000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0827 13:17:28.145000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 294912, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0827 13:17:28.542000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "AUTOTUNE bmm(3072x64x256, 3072x256x256)\n",
      "  triton_bmm_1016 1.7368 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_bmm_1013 1.7391 ms 99.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8\n",
      "  triton_bmm_1008 1.7406 ms 99.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4\n",
      "  triton_bmm_1012 1.7442 ms 99.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_bmm_1009 1.7492 ms 99.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4\n",
      "  triton_bmm_1018 1.7566 ms 98.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8\n",
      "  triton_bmm_1017 1.7715 ms 98.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8\n",
      "  triton_bmm_1014 1.9222 ms 90.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_bmm_1010 2.0398 ms 85.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8\n",
      "  triton_bmm_1003 2.0652 ms 84.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=1, num_warps=2\n",
      "SingleProcess AUTOTUNE benchmarking takes 1.9038 seconds and 0.0062 seconds precompiling\n",
      "E0827 13:17:28.546000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/6q/c6qhvg4i5mqlvefxmsznafunbluwke4ykflsho4njflcbmi6i7no.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "E0827 13:17:28.547000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/yo/cyob7v6qlcexc33vmmzbmm7q5sgqbunx25cxn46vx45yvawgcvmi.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4)\n",
      "E0827 13:17:28.548000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/dt/cdtupthjbcbyrnjym2uzandekupkleqxxcnh7upurmlsl2hpz2ve.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4)\n",
      "E0827 13:17:28.549000 1574237 torch/_inductor/select_algorithm.py:1300] [1/1] Exception out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_mmccabe/mu/cmupvwavvscwet3foxpwntcrimd2wltkjcskwvruez3i6k5ldog2.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8)\n",
      "W0827 13:17:29.328000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0827 13:17:29.771000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 262144, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0827 13:17:30.214000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "W0827 13:17:30.773000 1574237 torch/_inductor/select_algorithm.py:1517] [1/1] out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.\n",
      "AUTOTUNE bmm(3072x256x256, 3072x256x64)\n",
      "  triton_bmm_1033 2.2677 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  bmm 2.8046 ms 80.9% \n",
      "  triton_bmm_1036 2.8353 ms 80.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_bmm_1035 3.3034 ms 68.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8\n",
      "  triton_bmm_1029 3.3206 ms 68.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_bmm_1026 3.3347 ms 68.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4\n",
      "  triton_bmm_1025 3.4478 ms 65.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4\n",
      "  triton_bmm_1034 3.5195 ms 64.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8\n",
      "  triton_bmm_1031 3.6855 ms 61.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4\n",
      "  triton_bmm_1030 5.6966 ms 39.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8\n",
      "SingleProcess AUTOTUNE benchmarking takes 2.2249 seconds and 0.0061 seconds precompiling\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Max mem 35.06066131591797 GB\n",
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  \n",
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "autograd::engine::evaluate_function: CompiledFunctio...         0.00%     328.753us        49.33%       98.048s       24.512s       0.000us         0.00%       18.096s        4.524s           0 b           0 b     -20.26 Gb    -657.20 Mb             4  \n",
      "                               CompiledFunctionBackward         0.00%       7.705ms        49.33%       98.048s       24.512s        2.714s        14.79%       18.096s        4.524s           0 b           0 b     -19.62 Gb     -19.80 Gb             4  \n",
      "                      Scheduler.__init__ (dynamo_timed)         0.00%       0.000us         0.00%       0.000us       0.000us       15.069s        82.09%       15.069s       15.069s           0 b           0 b           0 b           0 b             1  \n",
      "                        compile_fx_inner (dynamo_timed)         0.23%     461.182ms        48.30%       96.009s       32.003s       0.000us         0.00%       14.199s        4.733s           0 b     -14.68 Kb        -512 b           0 b             3  \n",
      "         GraphLowering.compile_to_module (dynamo_timed)         0.64%        1.282s         9.96%       19.796s        6.599s       0.000us         0.00%       14.199s        4.733s           0 b           0 b           0 b           0 b             3  \n",
      "                      Scheduler.__init__ (dynamo_timed)         1.76%        3.490s         8.59%       17.072s        5.691s       0.000us         0.00%       14.199s        4.733s           0 b    -247.95 Kb           0 b     -32.93 Gb             3  \n",
      "         compile_fx.<locals>.bw_compiler (dynamo_timed)         0.00%     943.782us        47.46%       94.331s       47.165s       0.000us         0.00%       14.199s        7.099s           0 b           0 b        -512 b           0 b             2  \n",
      "                           Placeholder.DESCRIPTIVE_NAME         0.07%     132.450ms         0.11%     220.757ms      12.595us        7.939s        43.25%        7.939s     452.905us           0 b           0 b           0 b           0 b         17528  \n",
      "                                             triton_bmm         0.00%       0.000us         0.00%       0.000us       0.000us        7.648s        41.67%        7.648s       2.835ms           0 b           0 b           0 b           0 b          2698  \n",
      "                                            aten::fill_         0.04%      77.676ms         0.09%     171.745ms       9.574us        6.424s        35.00%        6.424s     358.138us           0 b           0 b           0 b           0 b         17938  \n",
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "Self CPU time total: 198.762s\n",
      "Self CUDA time total: 18.355s\n",
      "\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(1, resolution**2, 768).cuda()\n",
    "torch.cuda.reset_peak_memory_stats()\n",
    "with profile(activities=[\n",
    "        ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof:\n",
    "    with record_function(\"model_inference\"):\n",
    "        x = m2(x, BCs, (resolution, resolution))\n",
    "        x = m2(x, BCs, (resolution, resolution))\n",
    "        x = m2(x, BCs, (resolution, resolution))\n",
    "        y = m2(x, BCs, (resolution, resolution))\n",
    "        z = y.sum()\n",
    "        z.backward()\n",
    "        mem = torch.cuda.max_memory_allocated() / 1024**3\n",
    "\n",
    "print('Max mem', mem, 'GB')\n",
    "print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))\n",
    "m2.zero_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.126898129799549\n",
      "tensor([9.7728e-01, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05, 4.4369e-05,\n",
      "        4.4369e-05, 4.4369e-05, 4.4369e-05])\n",
      "tensor([9.7894e-01, 2.0385e-07, 2.0385e-07,  ..., 2.0385e-07, 2.0385e-07,\n",
      "        2.0385e-07])\n"
     ]
    }
   ],
   "source": [
    "orig = 512\n",
    "k = 200\n",
    "extrap = k*orig\n",
    "\n",
    "\n",
    "orig_logits = [10] + [0]*orig\n",
    "new_logits = [10] + [0]*extrap\n",
    "\n",
    "t = -orig_logits[0] / (np.log(k)-orig_logits[0])\n",
    "print(t)\n",
    "\n",
    "print(torch.softmax(torch.Tensor(orig_logits), dim=-1))\n",
    "print(torch.softmax(torch.Tensor(new_logits)/.65, dim=-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0],\n",
      "        [ 1],\n",
      "        [ 2],\n",
      "        [ 3],\n",
      "        [ 4],\n",
      "        [ 5],\n",
      "        [ 6],\n",
      "        [ 7],\n",
      "        [ 8],\n",
      "        [ 9],\n",
      "        [10],\n",
      "        [11],\n",
      "        [12],\n",
      "        [13],\n",
      "        [14],\n",
      "        [15]]) block_id: tensor([[0],\n",
      "        [0],\n",
      "        [0],\n",
      "        [0],\n",
      "        [1],\n",
      "        [1],\n",
      "        [1],\n",
      "        [1],\n",
      "        [2],\n",
      "        [2],\n",
      "        [2],\n",
      "        [2],\n",
      "        [3],\n",
      "        [3],\n",
      "        [3],\n",
      "        [3]]) , block x: tensor([[0],\n",
      "        [0],\n",
      "        [0],\n",
      "        [0],\n",
      "        [1],\n",
      "        [1],\n",
      "        [1],\n",
      "        [1],\n",
      "        [0],\n",
      "        [0],\n",
      "        [0],\n",
      "        [0],\n",
      "        [1],\n",
      "        [1],\n",
      "        [1],\n",
      "        [1]]) , x: tensor([[0],\n",
      "        [1],\n",
      "        [0],\n",
      "        [1],\n",
      "        [2],\n",
      "        [3],\n",
      "        [2],\n",
      "        [3],\n",
      "        [0],\n",
      "        [1],\n",
      "        [0],\n",
      "        [1],\n",
      "        [2],\n",
      "        [3],\n",
      "        [2],\n",
      "        [3]]) , block y: tensor([[0],\n",
      "        [0],\n",
      "        [0],\n",
      "        [0],\n",
      "        [0],\n",
      "        [0],\n",
      "        [0],\n",
      "        [0],\n",
      "        [1],\n",
      "        [1],\n",
      "        [1],\n",
      "        [1],\n",
      "        [1],\n",
      "        [1],\n",
      "        [1],\n",
      "        [1]]) , y: tensor([[0],\n",
      "        [0],\n",
      "        [1],\n",
      "        [1],\n",
      "        [0],\n",
      "        [0],\n",
      "        [1],\n",
      "        [1],\n",
      "        [2],\n",
      "        [2],\n",
      "        [3],\n",
      "        [3],\n",
      "        [2],\n",
      "        [2],\n",
      "        [3],\n",
      "        [3]])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f6c2842f3d0>"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbEElEQVR4nO3df2zUhf3H8dfRo0fH2tPW0fZGC53hKwoV1ALBmg1CY9MgyndRp8HaYLLNrQq1hpVuK25BPHEbqygpYjJhifjjD0FHpoRVBIn87Fkn2caP2NXOpnQmegclnLX9fP/4frnvKv1B8XN935XnI7k/7u7Tfl5ByzOf9rh6HMdxBADACBtjPQAAcHkiQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwITXesBX9fb2qr29Xenp6fJ4PNZzAADD5DiOTp8+rUAgoDFjBr7OSbgAtbe3Ky8vz3oGAOBramtr08SJEwd8PuEClJ6eLklqDU1Wxjf5DuFI+O//KrSeAGAU+VLd2qc/x/4+H0jCBej8t90yvjlGGekEaCR4PWOtJwAYTf7vHUaH+jEKf8MDAEwQIACACQIEADBBgAAAJuIWoA0bNmjy5MkaN26c5syZo0OHDsXrVACAJBSXAL3yyiuqrq7WY489plAopBkzZqi0tFSdnZ3xOB0AIAnFJUDr1q3TD3/4Qy1dulTXXXedNm7cqG984xv6wx/+EI/TAQCSkOsB+uKLL9TU1KSSkpL/P8mYMSopKdH+/fsvOD4ajSoSifS5AQBGP9cD9Omnn6qnp0fZ2dl9Hs/OzlZHR8cFxweDQfn9/tiNt+EBgMuD+avgamtrFQ6HY7e2tjbrSQCAEeD6W/FcddVVSklJ0alTp/o8furUKeXk5FxwvM/nk8/nc3sGACDBuX4FlJqaqptuukmNjY2xx3p7e9XY2Ki5c+e6fToAQJKKy5uRVldXq6KiQkVFRZo9e7bq6+vV1dWlpUuXxuN0AIAkFJcA/eAHP9C///1vrVq1Sh0dHZo5c6beeuutC16YAAC4fHkcx3GsR/ynSCQiv9+vz45/h1/HMEJKAzOtJwAYRb50uvWOXlc4HFZGRsaAx/E3PADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABOuBygYDGrWrFlKT0/XhAkTtHjxYh07dszt0wAAkpzrAdqzZ48qKyt14MAB7dq1S93d3br11lvV1dXl9qkAAEnM6/YnfOutt/rc37x5syZMmKCmpiZ997vfdft0AIAk5XqAviocDkuSMjMz+30+Go0qGo3G7kcikXhPAgAkgLi+CKG3t1dVVVUqLi7W9OnT+z0mGAzK7/fHbnl5efGcBABIEHENUGVlpY4ePaqXX355wGNqa2sVDodjt7a2tnhOAgAkiLh9C+6hhx7Sjh07tHfvXk2cOHHA43w+n3w+X7xmAAASlOsBchxHDz/8sLZt26Z33nlHBQUFbp8CADAKuB6gyspKbd26Va+//rrS09PV0dEhSfL7/UpLS3P7dACAJOX6z4AaGhoUDoc1b9485ebmxm6vvPKK26cCACSxuHwLDgCAofBecAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATHitBwzkv/+rUF7PWOsZl4Wd7c1x/fylgZlx+9zx3B7P3QC4AgIAGCFAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACbiHqAnn3xSHo9HVVVV8T4VACCJxDVAhw8f1nPPPafrr78+nqcBACShuAXozJkzWrJkiZ5//nldeeWV8ToNACBJxS1AlZWVWrhwoUpKSuJ1CgBAEovLe8G9/PLLCoVCOnz48JDHRqNRRaPR2P1IJBKPSQCABOP6FVBbW5uWL1+uF198UePGjRvy+GAwKL/fH7vl5eW5PQkAkIBcD1BTU5M6Ozt14403yuv1yuv1as+ePVq/fr28Xq96enr6HF9bW6twOBy7tbW1uT0JAJCAXP8W3IIFC/Thhx/2eWzp0qWaOnWqampqlJKS0uc5n88nn8/n9gwAQIJzPUDp6emaPn16n8fGjx+vrKysCx4HAFy+eCcEAICJEfmNqO+8885InAYAkES4AgIAmCBAAAATBAgAYIIAAQBMECAAgIkReRUcEltpYKb1hEuWzNuByx1XQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACa81gNgb2d7c1w/f2lgZtw+dzy3x3M3AK6AAABGCBAAwAQBAgCYIEAAABMECABgggABAEwQIACAibgE6JNPPtF9992nrKwspaWlqbCwUEeOHInHqQAAScr1f4j62Wefqbi4WPPnz9ebb76pb33rWzpx4oSuvPJKt08FAEhirgdo7dq1ysvL0wsvvBB7rKCgwO3TAACSnOvfgnvjjTdUVFSku+66SxMmTNANN9yg559/fsDjo9GoIpFInxsAYPRzPUAfffSRGhoaNGXKFO3cuVM/+clPtGzZMm3ZsqXf44PBoPx+f+yWl5fn9iQAQALyOI7juPkJU1NTVVRUpPfeey/22LJly3T48GHt37//guOj0aii0WjsfiQSUV5enubpDnk9Y92chgHwZqT9481IgUvzpdOtd/S6wuGwMjIyBjzO9Sug3NxcXXfddX0eu/baa/Xxxx/3e7zP51NGRkafGwBg9HM9QMXFxTp27Fifx44fP65Jkya5fSoAQBJzPUCPPPKIDhw4oCeeeEInT57U1q1btWnTJlVWVrp9KgBAEnM9QLNmzdK2bdv00ksvafr06Vq9erXq6+u1ZMkSt08FAEhicfmNqLfddptuu+22eHxqAMAowXvBAQBMECAAgAkCBAAwQYAAACbi8iIEJJdk/hf/ybwduNxxBQQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGDCaz0AuBztbG+2nnBZKQ3MtJ6AfnAFBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATLgeoJ6eHtXV1amgoEBpaWm6+uqrtXr1ajmO4/apAABJzPV/iLp27Vo1NDRoy5YtmjZtmo4cOaKlS5fK7/dr2bJlbp8OAJCkXA/Qe++9pzvuuEMLFy6UJE2ePFkvvfSSDh065PapAABJzPVvwd18881qbGzU8ePHJUkffPCB9u3bp7Kysn6Pj0ajikQifW4AgNHP9SuglStXKhKJaOrUqUpJSVFPT4/WrFmjJUuW9Ht8MBjUr3/9a7dnAAASnOtXQK+++qpefPFFbd26VaFQSFu2bNFvf/tbbdmypd/ja2trFQ6HY7e2tja3JwEAEpDrV0ArVqzQypUrdc8990iSCgsL1draqmAwqIqKiguO9/l88vl8bs8AACQ416+Azp49qzFj+n7alJQU9fb2un0qAEASc/0KaNGiRVqzZo3y8/M1bdo0vf/++1q3bp0eeOABt08FAEhirgfomWeeUV1dnX7605+qs7NTgUBAP/7xj7Vq1Sq3TwUASGKuByg9PV319fWqr693+1MDAEYR3gsOAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACAiWEHaO/evVq0aJECgYA8Ho+2b9/e53nHcbRq1Srl5uYqLS1NJSUlOnHihFt7AQCjxLAD1NXVpRkzZmjDhg39Pv/UU09p/fr12rhxow4ePKjx48ertLRU586d+9pjAQCjh3e4H1BWVqaysrJ+n3McR/X19frlL3+pO+64Q5L0xz/+UdnZ2dq+fbvuueeer7cWADBquPozoJaWFnV0dKikpCT2mN/v15w5c7R///5+PyYajSoSifS5AQBGP1cD1NHRIUnKzs7u83h2dnbsua8KBoPy+/2xW15enpuTAAAJyvxVcLW1tQqHw7FbW1ub9SQAwAhwNUA5OTmSpFOnTvV5/NSpU7Hnvsrn8ykjI6PPDQAw+rkaoIKCAuXk5KixsTH2WCQS0cGDBzV37lw3TwUASHLDfhXcmTNndPLkydj9lpYWNTc3KzMzU/n5+aqqqtLjjz+uKVOmqKCgQHV1dQoEAlq8eLGbuwEASW7YATpy5Ijmz58fu19dXS1Jqqio0ObNm/Wzn/1MXV1d+tGPfqTPP/9ct9xyi9566y2NGzfOvdUAgKTncRzHsR7xnyKRiPx+v+bpDnk9Y63nAHGxs73ZesJlpTQw03rCZeVLp1vv6HWFw+FBf65v/io4AMDliQABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACa/1AOByVBqYaT3hsrKzvTmunz+e/z3jud36/0OugAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgIlhB2jv3r1atGiRAoGAPB6Ptm/fHnuuu7tbNTU1Kiws1Pjx4xUIBHT//fervb3dzc0AgFFg2AHq6urSjBkztGHDhgueO3v2rEKhkOrq6hQKhfTaa6/p2LFjuv32210ZCwAYPYb9TghlZWUqKyvr9zm/369du3b1eezZZ5/V7Nmz9fHHHys/P//SVgIARp24vxVPOByWx+PRFVdc0e/z0WhU0Wg0dj8SicR7EgAgAcT1RQjnzp1TTU2N7r33XmVkZPR7TDAYlN/vj93y8vLiOQkAkCDiFqDu7m7dfffdchxHDQ0NAx5XW1urcDgcu7W1tcVrEgAggcTlW3Dn49Pa2qq33357wKsfSfL5fPL5fPGYAQBIYK4H6Hx8Tpw4od27dysrK8vtUwAARoFhB+jMmTM6efJk7H5LS4uam5uVmZmp3Nxc3XnnnQqFQtqxY4d6enrU0dEhScrMzFRqaqp7ywEASW3YATpy5Ijmz58fu19dXS1Jqqio0K9+9Su98cYbkqSZM2f2+bjdu3dr3rx5l74UADCqDDtA8+bNk+M4Az4/2HMAAJzHe8EBAEwQIACACQIEADBBgAAAJggQAMBE3N+MFACslQZmWk+4ZMm8fShcAQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJjwWg8AgHjb2d4c189fGpgZt88dz+3x3H0xuAICAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMDHsAO3du1eLFi1SIBCQx+PR9u3bBzz2wQcflMfjUX19/deYCAAYjYYdoK6uLs2YMUMbNmwY9Lht27bpwIEDCgQClzwOADB6DfsfopaVlamsrGzQYz755BM9/PDD2rlzpxYuXHjJ4wAAo5frPwPq7e1VeXm5VqxYoWnTprn96QEAo4Trb8Wzdu1aeb1eLVu27KKOj0ajikajsfuRSMTtSQCABOTqFVBTU5Oefvppbd68WR6P56I+JhgMyu/3x255eXluTgIAJChXA/Tuu++qs7NT+fn58nq98nq9am1t1aOPPqrJkyf3+zG1tbUKh8OxW1tbm5uTAAAJytVvwZWXl6ukpKTPY6WlpSovL9fSpUv7/Rifzyefz+fmDABAEhh2gM6cOaOTJ0/G7re0tKi5uVmZmZnKz89XVlZWn+PHjh2rnJwcXXPNNV9/LQBg1Bh2gI4cOaL58+fH7ldXV0uSKioqtHnzZteGAQBGt2EHaN68eXIc56KP/+c//zncUwAALgO8FxwAwAQBAgCYIEAAABMECABgggABAEy4/l5wAJBoSgMzrSdcsmTePhSugAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgwms94Kscx5EkfaluyTEeAwAYti/VLen//z4fSMIF6PTp05Kkffqz8RIAwNdx+vRp+f3+AZ/3OEMlaoT19vaqvb1d6enp8ng8Qx4fiUSUl5entrY2ZWRkjMBCd7B7ZCXrbil5t7N7ZCXSbsdxdPr0aQUCAY0ZM/BPehLuCmjMmDGaOHHisD8uIyPD/A/9UrB7ZCXrbil5t7N7ZCXK7sGufM7jRQgAABMECABgIukD5PP59Nhjj8nn81lPGRZ2j6xk3S0l73Z2j6xk3J1wL0IAAFwekv4KCACQnAgQAMAEAQIAmCBAAAATSR2gDRs2aPLkyRo3bpzmzJmjQ4cOWU8aUjAY1KxZs5Senq4JEyZo8eLFOnbsmPWsYXvyySfl8XhUVVVlPWVIn3zyie677z5lZWUpLS1NhYWFOnLkiPWsQfX09Kiurk4FBQVKS0vT1VdfrdWrVw/53loW9u7dq0WLFikQCMjj8Wj79u19nnccR6tWrVJubq7S0tJUUlKiEydO2Iz9D4Pt7u7uVk1NjQoLCzV+/HgFAgHdf//9am9vtxv8f4b68/5PDz74oDwej+rr60ds33AkbYBeeeUVVVdX67HHHlMoFNKMGTNUWlqqzs5O62mD2rNnjyorK3XgwAHt2rVL3d3duvXWW9XV1WU97aIdPnxYzz33nK6//nrrKUP67LPPVFxcrLFjx+rNN9/U3/72N/3ud7/TlVdeaT1tUGvXrlVDQ4OeffZZ/f3vf9fatWv11FNP6ZlnnrGedoGuri7NmDFDGzZs6Pf5p556SuvXr9fGjRt18OBBjR8/XqWlpTp37twIL+1rsN1nz55VKBRSXV2dQqGQXnvtNR07dky33367wdK+hvrzPm/btm06cOCAAoHACC27BE6Smj17tlNZWRm739PT4wQCAScYDBquGr7Ozk5HkrNnzx7rKRfl9OnTzpQpU5xdu3Y53/ve95zly5dbTxpUTU2Nc8stt1jPGLaFCxc6DzzwQJ/Hvv/97ztLliwxWnRxJDnbtm2L3e/t7XVycnKc3/zmN7HHPv/8c8fn8zkvvfSSwcL+fXV3fw4dOuRIclpbW0dm1EUYaPe//vUv59vf/rZz9OhRZ9KkSc7vf//7Ed92MZLyCuiLL75QU1OTSkpKYo+NGTNGJSUl2r9/v+Gy4QuHw5KkzMxM4yUXp7KyUgsXLuzzZ5/I3njjDRUVFemuu+7ShAkTdMMNN+j555+3njWkm2++WY2NjTp+/Lgk6YMPPtC+fftUVlZmvGx4Wlpa1NHR0ef/F7/frzlz5iTl16rH49EVV1xhPWVQvb29Ki8v14oVKzRt2jTrOYNKuDcjvRiffvqpenp6lJ2d3efx7Oxs/eMf/zBaNXy9vb2qqqpScXGxpk+fbj1nSC+//LJCoZAOHz5sPeWiffTRR2poaFB1dbV+/vOf6/Dhw1q2bJlSU1NVUVFhPW9AK1euVCQS0dSpU5WSkqKenh6tWbNGS5YssZ42LB0dHZLU79fq+eeSwblz51RTU6N77703Id7oczBr166V1+vVsmXLrKcMKSkDNFpUVlbq6NGj2rdvn/WUIbW1tWn58uXatWuXxo0bZz3novX29qqoqEhPPPGEJOmGG27Q0aNHtXHjxoQO0KuvvqoXX3xRW7du1bRp09Tc3KyqqioFAoGE3j0adXd36+6775bjOGpoaLCeM6impiY9/fTTCoVCF/XrbKwl5bfgrrrqKqWkpOjUqVN9Hj916pRycnKMVg3PQw89pB07dmj37t2X9OsnRlpTU5M6Ozt14403yuv1yuv1as+ePVq/fr28Xq96enqsJ/YrNzdX1113XZ/Hrr32Wn388cdGiy7OihUrtHLlSt1zzz0qLCxUeXm5HnnkEQWDQetpw3L+6zFZv1bPx6e1tVW7du1K+Kufd999V52dncrPz499nba2turRRx/V5MmTreddICkDlJqaqptuukmNjY2xx3p7e9XY2Ki5c+caLhua4zh66KGHtG3bNr399tsqKCiwnnRRFixYoA8//FDNzc2xW1FRkZYsWaLm5malpKRYT+xXcXHxBS9zP378uCZNmmS06OKcPXv2gl/klZKSot7eXqNFl6agoEA5OTl9vlYjkYgOHjyY8F+r5+Nz4sQJ/eUvf1FWVpb1pCGVl5frr3/9a5+v00AgoBUrVmjnzp3W8y6QtN+Cq66uVkVFhYqKijR79mzV19erq6tLS5cutZ42qMrKSm3dulWvv/660tPTY98H9/v9SktLM143sPT09At+TjV+/HhlZWUl9M+vHnnkEd1888164okndPfdd+vQoUPatGmTNm3aZD1tUIsWLdKaNWuUn5+vadOm6f3339e6dev0wAMPWE+7wJkzZ3Ty5MnY/ZaWFjU3NyszM1P5+fmqqqrS448/rilTpqigoEB1dXUKBAJavHix3WgNvjs3N1d33nmnQqGQduzYoZ6entjXamZmplJTU61mD/nn/dVQjh07Vjk5ObrmmmtGeurQrF+G93U888wzTn5+vpOamurMnj3bOXDggPWkIUnq9/bCCy9YTxu2ZHgZtuM4zp/+9Cdn+vTpjs/nc6ZOneps2rTJetKQIpGIs3z5cic/P98ZN26c853vfMf5xS9+4USjUetpF9i9e3e//09XVFQ4jvO/L8Wuq6tzsrOzHZ/P5yxYsMA5duyY7Whn8N0tLS0Dfq3u3r07YXf3J5Ffhs2vYwAAmEjKnwEBAJIfAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGDifwCFT8sB69aS6AAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def create_swindow_mask(H, W, window_size, shift_size, bcs, device=\"cuda\"):\n",
    "    # @lru_cache()\n",
    "    def swindow_mask(b, h, q_idx, kv_idx):\n",
    "        mask_size = window_size**2\n",
    "\n",
    "        q_window = q_idx // mask_size\n",
    "        kv_window = kv_idx // mask_size\n",
    "        # return q_window == kv_window\n",
    "        qlocal_idx = q_idx % mask_size\n",
    "        kvlocal_idx = kv_idx % mask_size\n",
    "        \n",
    "        windows_per_row = W // window_size\n",
    "        windows_per_col = H // window_size\n",
    "        # Get 2D block coordinates\n",
    "        q_block_x = q_window % windows_per_row\n",
    "        q_block_y = q_window // windows_per_row\n",
    "        kv_block_x = kv_window % windows_per_row\n",
    "        kv_block_y = kv_window // windows_per_row\n",
    "        # Now get the local coordinates within the block\n",
    "        qx = q_block_x * window_size + qlocal_idx % window_size\n",
    "        qy = q_block_y * window_size + qlocal_idx // window_size\n",
    "        kvx = kv_block_x * window_size + kvlocal_idx % window_size\n",
    "        kvy = kv_block_y * window_size + kvlocal_idx // window_size\n",
    "\n",
    "        # print(q_idx, 'block_id:', q_window, ', block x:', q_block_x, ', x:', qx, ', block y:', q_block_y, ', y:', qy)\n",
    "        # qx, qy = q_idx % W, q_idx // W\n",
    "        # kvx, kvy = kv_idx % W, kv_idx // W\n",
    "        x_mask = qx // window_size == kvx // window_size\n",
    "        y_mask = qy // window_size == kvy // window_size\n",
    "        # mask = q_idx // mask_size == kv_idx // mask_size\n",
    "        if shift_size > 0:\n",
    "            if bcs[0, 0] == 0:\n",
    "                y_mask &= kvx < H - shift_size\n",
    "                y_mask &= qx < H - shift_size\n",
    "            elif bcs[0, 1] == 0:\n",
    "                x_mask &= kvx < W - shift_size\n",
    "                y_mask &= qx < H - shift_size\n",
    "        return x_mask & y_mask\n",
    "    return swindow_mask\n",
    "\n",
    "H = 4\n",
    "W = 4\n",
    "window_size = 2\n",
    "shift_size = 1\n",
    "bcs = torch.tensor([[0, 0]])\n",
    "device = \"cuda\"\n",
    "swindow_mask = create_swindow_mask(H, W, window_size, shift_size, bcs, device)\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "mask = swindow_mask(None, None, torch.arange(H*W)[:, None], torch.arange(H*W)[None, :])\n",
    "\n",
    "plt.imshow(mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 block_id: 0 , block x: 0 , x: 0 , block y: 0 , y: 0\n",
      "1 block_id: 0 , block x: 0 , x: 1 , block y: 0 , y: 0\n",
      "2 block_id: 0 , block x: 0 , x: 0 , block y: 0 , y: 1\n",
      "3 block_id: 0 , block x: 0 , x: 1 , block y: 0 , y: 1\n",
      "4 block_id: 1 , block x: 1 , x: 2 , block y: 0 , y: 0\n",
      "5 block_id: 1 , block x: 1 , x: 3 , block y: 0 , y: 0\n",
      "6 block_id: 1 , block x: 1 , x: 2 , block y: 0 , y: 1\n",
      "7 block_id: 1 , block x: 1 , x: 3 , block y: 0 , y: 1\n",
      "8 block_id: 2 , block x: 0 , x: 0 , block y: 1 , y: 2\n",
      "9 block_id: 2 , block x: 0 , x: 1 , block y: 1 , y: 2\n",
      "10 block_id: 2 , block x: 0 , x: 0 , block y: 1 , y: 3\n",
      "11 block_id: 2 , block x: 0 , x: 1 , block y: 1 , y: 3\n",
      "12 block_id: 3 , block x: 1 , x: 2 , block y: 1 , y: 2\n",
      "13 block_id: 3 , block x: 1 , x: 3 , block y: 1 , y: 2\n",
      "14 block_id: 3 , block x: 1 , x: 2 , block y: 1 , y: 3\n",
      "15 block_id: 3 , block x: 1 , x: 3 , block y: 1 , y: 3\n"
     ]
    }
   ],
   "source": [
    "for i in range(16):\n",
    "    swindow_mask(None, None, i, i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "well_venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
