{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-09-29T05:14:11.880436Z",
     "start_time": "2025-09-29T05:14:10.358372Z"
    }
   },
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Make the project root (the parent of the current working directory) importable so\n",
    "# the local `csgp` package resolves. NOTE(review): assumes the kernel's cwd is the\n",
    "# directory that contains this notebook.\n",
    "PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))\n",
    "sys.path.insert(0, PROJECT_ROOT)\n",
    "\n",
    "from csgp.layers.kernels import LaplaceL1Kernel\n",
    "from csgp.design_class import HyperbolicCrossDesign\n",
    "from csgp.chol_inv import mk_chol_inv"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "code",
   "id": "9466ace5ee5803a9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-29T05:14:12.961924Z",
     "start_time": "2025-09-29T05:14:12.867041Z"
    }
   },
   "source": [
    "# Number of dyadic levels; the level-L grid has m = 2^L - 1 points in (0, 1).\n",
    "L = 3\n",
    "\n",
    "design_builder = HyperbolicCrossDesign(dyadic_sort=True, return_neighbors=True)\n",
    "dyadic_design = design_builder(deg=L, input_lb=0, input_ub=1)\n",
    "design_points = dyadic_design.points.reshape(-1, 1)  # (m, 1) column of design points\n",
    "print(f'Dyadic sorted design points: {dyadic_design.points}')"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dyadic sorted design points: tensor([0.5000, 0.2500, 0.7500, 0.1250, 0.3750, 0.6250, 0.8750])\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "id": "7e60f44d9d21a607",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-29T05:14:14.213570Z",
     "start_time": "2025-09-29T05:14:14.210610Z"
    }
   },
   "source": [
    "import torch\n",
    "\n",
    "x = torch.tensor([0.35, 0.65])"
   ],
   "outputs": [],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "id": "ffdcb75d2c0d762b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-29T05:14:17.214465Z",
     "start_time": "2025-09-29T05:14:17.187006Z"
    }
   },
   "source": [
    "###########################################################\n",
    "# PART 1: dense features via Cholesky, phi(x) = k(x, U) L^{-T}\n",
    "###########################################################\n",
    "# NOTE(review): upper=True presumably returns the (m, m) upper-triangular\n",
    "# factor L^{-T}; confirm against csgp.chol_inv.mk_chol_inv.\n",
    "chol_inv = mk_chol_inv(dyadic_design=dyadic_design,\n",
    "                       markov_kernel=LaplaceL1Kernel(lengthscale=1.),\n",
    "                       upper=True)\n",
    "\n",
    "# Cross-covariance k(x, U) between the query inputs and the design points U.\n",
    "k_xu = LaplaceL1Kernel(lengthscale=1.)(x, design_points)\n",
    "phi = torch.matmul(k_xu, chol_inv)\n",
    "\n",
    "print(f\"x: {x}\")\n",
    "print(f\"phi(x): {phi}\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x: tensor([0.3500, 0.6500])\n",
      "phi(x): tensor([[ 8.6071e-01,  3.7387e-01, -1.4539e-07, -5.6704e-09,  2.8185e-01,\n",
      "          1.5002e-07, -3.2482e-08],\n",
      "        [ 8.6071e-01, -1.4539e-07,  3.7387e-01, -3.2482e-08,  2.0258e-07,\n",
      "          2.8185e-01, -5.6704e-09]])\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "code",
   "id": "685b43f5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-29T05:14:20.798581Z",
     "start_time": "2025-09-29T05:14:20.787237Z"
    }
   },
   "source": [
    "def dyadic_nonzero_indices(x: torch.Tensor, L: int, return_anchor=False):\n",
    "        if x.ndim == 0:\n",
    "            x = x.unsqueeze(0)  # promote scalar to shape (1,)\n",
    "\n",
    "        # 2^s for s=1..L\n",
    "        pow2 = torch.pow(2, torch.arange(1, L+1, device=x.device, dtype=torch.int64))  # (L,)\n",
    "\n",
    "        # k_s = ceil(2^s * x) clamped to [1, 2^s - 1]\n",
    "        ks = torch.ceil(x[..., None] * pow2.to(x.dtype)).to(torch.int64)  # (..., L)\n",
    "        ks = torch.clamp(ks, min=1)\n",
    "        ks_max = (pow2 - 1)  # (L,)\n",
    "        ks = torch.minimum(ks, ks_max)  # (..., L)\n",
    "\n",
    "        # r_s^(odd): force to be odd (right endpoint index made odd)\n",
    "        # if ks even -> ks-1, else ks\n",
    "        rs = ks - ((ks & 1) == 0).to(torch.int64)  # (..., L), odd in {1,3,...,2^s-1}\n",
    "\n",
    "        # position within level s: t_s in {1,...,2^{s-1}}\n",
    "        ts = (rs + 1) // 2  # (..., L)\n",
    "\n",
    "        # offsets: number of columns before level s (0-based indexing)\n",
    "        offsets = (pow2 // 2) - 1  # (L,)\n",
    "\n",
    "        # global 0-based indices: J_s = offset(s) + (t_s - 1)\n",
    "        idx = offsets + (ts - 1)  # (..., L)\n",
    "        \n",
    "        if return_anchor:\n",
    "            anchor = torch.tensor([2**L-1, 2**L], device=x.device, dtype=torch.int64).expand((*idx.shape[:-1], 2))  # (..., 2)\n",
    "            idx = torch.cat([idx, anchor], dim=-1)  # (..., L+2)\n",
    "\n",
    "        return idx\n",
    "    \n",
    "# Indices for the query tensor x, without and with the two anchor columns.\n",
    "nonzero_idx = dyadic_nonzero_indices(x, L)\n",
    "nonzero_idx_anchor = dyadic_nonzero_indices(x, L, return_anchor=True)\n",
    "\n",
    "print(f\"Non-zero indices:\\n {nonzero_idx}\")\n",
    "print(f\"idx shape: {nonzero_idx.shape}\")\n",
    "\n",
    "print(f\"Non-zero indices with anchor:\\n {nonzero_idx_anchor}\")\n",
    "print(f\"idx with anchor shape: {nonzero_idx_anchor.shape}\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Non-zero indices:\n",
      " tensor([[0, 1, 4],\n",
      "        [0, 2, 5]])\n",
      "idx shape: torch.Size([2, 3])\n",
      "Non-zero indices with anchor:\n",
      " tensor([[0, 1, 4, 7, 8],\n",
      "        [0, 2, 5, 7, 8]])\n",
      "idx with anchor shape: torch.Size([2, 5])\n"
     ]
    }
   ],
   "execution_count": 5
  },
  {
   "cell_type": "code",
   "id": "ff22c4153305963c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-29T05:14:23.800703Z",
     "start_time": "2025-09-29T05:14:23.795332Z"
    }
   },
   "source": [
    "###########################################################\n",
    "# PART 2: Use compact support and sparse psi\n",
    "###########################################################\n",
    "import numpy as np\n",
    "\n",
    "def anchor_points(x: torch.Tensor, ell_c: float = 1.0):\n",
    "        x = torch.exp(- (x / ell_c)) + torch.exp(- ((1 - x) / ell_c))\n",
    "        coeff = torch.tensor([1.0 / np.sqrt(2.0 * (1 + np.exp(- 1.0 / ell_c))), 1.0 / np.sqrt(2.0 * (1 - np.exp(- 1.0 / ell_c)))], device=x.device, dtype=x.dtype)\n",
    "        res = x.unsqueeze(-1) @ coeff.unsqueeze(0)  # (..., 1)\n",
    "        return res  # (..., 2)\n",
    "    \n",
    "def dyadic_psi(x: torch.Tensor, L: int, sigma: float = 1.0, ell_c: float = 1.0, return_anchor: bool = False):\n",
    "    \"\"\"\n",
    "    Batch-wise dyadic nonzero indices.\n",
    "\n",
    "    Args\n",
    "    ----\n",
    "    x : (...,) tensor with values in [0, 1].\n",
    "        Works with any number of leading batch dims.\n",
    "    L : int, number of dyadic levels (total columns m = 2^L - 1).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    idx : (..., L) long tensor\n",
    "        0-based global column indices in dyadic order for each level (DC is level 1).\n",
    "        The returned shape matches the leading shape of x, with an extra trailing dim of size L.\n",
    "    \"\"\"\n",
    "    if x.ndim == 0:\n",
    "        x = x.unsqueeze(0)  # promote scalar to shape (1,)\n",
    "    device, dtype = x.device, x.dtype\n",
    "\n",
    "    # 2^s for s=1..L\n",
    "    pow2 = torch.pow(2, torch.arange(1, L+1, device=device, dtype=torch.int64))  # (L,)\n",
    "\n",
    "    # k_s = ceil(2^s * x) clamped to [1, 2^s - 1]\n",
    "    ks = torch.ceil(x[..., None] * pow2.to(x.dtype)).to(torch.int64)  # (..., L)\n",
    "    ks = torch.clamp(ks, min=1)\n",
    "    ks_max = (pow2 - 1)  # (L,)\n",
    "    ks = torch.minimum(ks, ks_max)  # (..., L)\n",
    "\n",
    "    # r_s^(odd): force to be odd (right endpoint index made odd)\n",
    "    # if ks even -> ks-1, else ks\n",
    "    rs = ks - ((ks & 1) == 0).to(torch.int64) # (..., L), odd in {1,3,...,2^s-1}\n",
    "\n",
    "    # position within level s: t_s in {1,...,2^{s-1}}\n",
    "    ts = (rs + 1) // 2  # (..., L)\n",
    "\n",
    "    # offsets: number of columns before level s (0-based indexing)\n",
    "    offsets = (pow2 // 2) - 1  # (L,)\n",
    "\n",
    "    # global 0-based indices: J_s = offset(s) + (t_s - 1)\n",
    "    idx = offsets + (ts - 1)  # (..., L)\n",
    "\n",
    "    # u = HyperbolicCrossDesign(dyadic_sort=True, return_neighbors=True)(deg=L, input_lb=0, input_ub=1).points # (2^L-1,)\n",
    "    # view_shape = (1,) * x.dim() + (u.shape[0],)      # (1,1,...,1, 2^L-1)\n",
    "    # u_selected = torch.gather(u.view(view_shape).expand(*x.shape, -1), dim=-1, index=idx) # (..., L)\n",
    "    # delta = torch.abs(x.unsqueeze(-1) - u_selected) # |x - m2^{-l}|\n",
    "\n",
    "    delta = torch.abs(x.unsqueeze(-1) - rs/pow2)  # (..., L)\n",
    "    pow2_f = (1.0 / pow2).to(x.dtype)\n",
    "\n",
    "    psi =  sigma * torch.sqrt(2 / torch.sinh(pow2_f * 2 * ell_c)) * torch.sinh(ell_c * (pow2_f - delta))\n",
    "    \n",
    "    if return_anchor:\n",
    "        anchor_idx = torch.tensor([2**L-1, 2**L], device=x.device, dtype=torch.int64).expand((*idx.shape[:-1], 2))  # (..., 2)\n",
    "        idx = torch.cat([idx, anchor_idx], dim=-1)  # (..., L+2)\n",
    "        anchor_vals = anchor_points(x, ell_c=ell_c)  # (..., 2)\n",
    "        psi = torch.cat([psi, anchor_vals], dim=-1)  # (...,\n",
    "\n",
    "    return psi, idx  # (..., L)"
   ],
   "outputs": [],
   "execution_count": 6
  },
  {
   "cell_type": "code",
   "id": "232c8570557046fa",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-29T05:14:30.942325Z",
     "start_time": "2025-09-29T05:14:30.927510Z"
    }
   },
   "source": [
    "# Sparse dyadic features at the query points x. The selected columns should be the\n",
    "# columns where the dense Cholesky features `phi` from PART 1 are non-zero (up to round-off).\n",
    "# NOTE(review): in the saved output only the finest level's values agree with phi;\n",
    "# the coarser-level psi values differ.\n",
    "psi, idx = dyadic_psi(x, L)\n",
    "phi_nonzero = phi.gather(-1, idx)\n",
    "psi_anchor, idx_anchor = dyadic_psi(x, L, return_anchor=True)\n",
    "\n",
    "print(f\"Non-zero idx:\\n {idx}\")\n",
    "print(f\"Non-zero phi(x):\\n {phi_nonzero}\")\n",
    "print(f\"psi(x):\\n {psi}\")\n",
    "\n",
    "print(f\"Non-zero idx with anchor:\\n {idx_anchor}\")\n",
    "print(f\"psi(x) with anchor:\\n {psi_anchor}\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Non-zero idx:\n",
      " tensor([[0, 1, 4],\n",
      "        [0, 2, 5]])\n",
      "Non-zero phi(x):\n",
      " tensor([[0.8607, 0.3739, 0.2818],\n",
      "        [0.8607, 0.3739, 0.2818]])\n",
      "psi(x):\n",
      " tensor([[0.4660, 0.2950, 0.2818],\n",
      "        [0.4660, 0.2950, 0.2818]])\n",
      "Non-zero idx with anchor:\n",
      " tensor([[0, 1, 4, 7, 8],\n",
      "        [0, 2, 5, 7, 8]])\n",
      "psi(x) with anchor:\n",
      " tensor([[0.4660, 0.2950, 0.2818, 0.7417, 1.0910],\n",
      "        [0.4660, 0.2950, 0.2818, 0.7417, 1.0910]])\n"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "code",
   "id": "9334be0a6a5bf7f1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-29T05:20:10.648906Z",
     "start_time": "2025-09-29T05:20:10.643217Z"
    }
   },
   "source": [
    "###########################################################\n",
    "# PART 3: verify batch-wise computation of psi(x)\n",
    "###########################################################\n",
    "# A (B=4, N=2) batch of inputs, kept in its own name so the (2,) query tensor `x`\n",
    "# used by PART 1 and PART 2 is not overwritten when this cell runs.\n",
    "x_batch = torch.tensor([[0.1, 0.8], [0.8, 0.1], [0.35, 0.65], [0.65, 0.35]])  # (B=4, N=2)\n",
    "psi, idx = dyadic_psi(x_batch, L)\n",
    "print(f\"Non-zero idx:\\n {idx}\")\n",
    "print(f\"idx shape: {idx.shape}\")\n",
    "print(f\"psi shape: {psi.shape}\")\n",
    "\n",
    "psi_anchor, idx_anchor = dyadic_psi(x_batch, L, return_anchor=True)\n",
    "print(f\"idx shape with anchor: {idx_anchor.shape}\")\n",
    "print(f\"psi shape with anchor: {psi_anchor.shape}\")\n",
    "\n",
    "# Batch invariants: one column per level, plus two anchor columns when requested.\n",
    "assert psi.shape == idx.shape == (*x_batch.shape, L)\n",
    "assert psi_anchor.shape == idx_anchor.shape == (*x_batch.shape, L + 2)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Non-zero idx:\n",
      " tensor([[[0, 1, 3],\n",
      "         [0, 2, 6]],\n",
      "\n",
      "        [[0, 2, 6],\n",
      "         [0, 1, 3]],\n",
      "\n",
      "        [[0, 1, 4],\n",
      "         [0, 2, 5]],\n",
      "\n",
      "        [[0, 2, 5],\n",
      "         [0, 1, 4]]])\n",
      "idx shape: torch.Size([4, 2, 3])\n",
      "psi shape: torch.Size([4, 2, 3])\n",
      "idx shape with anchor: torch.Size([4, 2, 5])\n",
      "psi shape with anchor: torch.Size([4, 2, 5])\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "cell_type": "code",
   "id": "24a4fea4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-29T05:20:12.892328Z",
     "start_time": "2025-09-29T05:20:12.888953Z"
    }
   },
   "source": [
    "def dyadic_to_dense(vals, idx, m):\n",
    "    \"\"\"\n",
    "    Convert dyadic sparse representation to dense.\n",
    "\n",
    "    Args\n",
    "    ----\n",
    "    vals : (..., fsize, m) tensor\n",
    "        Values at the dyadic non-zero indices.\n",
    "    idx : (..., fsize, m) long tensor\n",
    "        0-based global column indices in dyadic order for each level (DC is level 1).\n",
    "        The leading shape of idx must match that of vals.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dense : (..., m) tensor\n",
    "        Dense representation.\n",
    "    \"\"\"\n",
    "    dense = torch.zeros(*vals.shape[:-1], m, device=vals.device, dtype=vals.dtype)\n",
    "    dense.scatter_(-1, idx, vals)\n",
    "    return dense"
   ],
   "outputs": [],
   "execution_count": 9
  },
  {
   "cell_type": "code",
   "id": "74ebd2d9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-29T05:20:15.095374Z",
     "start_time": "2025-09-29T05:20:15.088842Z"
    }
   },
   "source": [
    "n_cols = 2**L - 1  # one column per dyadic grid point; the anchored variant adds 2 more\n",
    "\n",
    "psi_dense = dyadic_to_dense(psi, idx, n_cols)\n",
    "print(f\"psi_dense:\\n {psi_dense}\")\n",
    "print(f\"psi_dense shape: {psi_dense.shape}\")\n",
    "\n",
    "psi_dense_anchor = dyadic_to_dense(psi_anchor, idx_anchor, n_cols + 2)\n",
    "print(f\"psi_dense shape with anchor: {psi_dense_anchor.shape}\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "psi_dense:\n",
      " tensor([[[0.1307, 0.1962, 0.0000, 0.2818, 0.0000, 0.0000, 0.0000],\n",
      "         [0.2627, 0.0000, 0.3944, 0.0000, 0.0000, 0.0000, 0.1407]],\n",
      "\n",
      "        [[0.2627, 0.0000, 0.3944, 0.0000, 0.0000, 0.0000, 0.1407],\n",
      "         [0.1307, 0.1962, 0.0000, 0.2818, 0.0000, 0.0000, 0.0000]],\n",
      "\n",
      "        [[0.4660, 0.2950, 0.0000, 0.0000, 0.2818, 0.0000, 0.0000],\n",
      "         [0.4660, 0.0000, 0.2950, 0.0000, 0.0000, 0.2818, 0.0000]],\n",
      "\n",
      "        [[0.4660, 0.0000, 0.2950, 0.0000, 0.0000, 0.2818, 0.0000],\n",
      "         [0.4660, 0.2950, 0.0000, 0.0000, 0.2818, 0.0000, 0.0000]]])\n",
      "psi_dense shape: torch.Size([4, 2, 7])\n",
      "psi_dense shape with anchor: torch.Size([4, 2, 9])\n"
     ]
    }
   ],
   "execution_count": 10
  },
  {
   "cell_type": "code",
   "id": "b3836180",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-29T05:20:16.716856Z",
     "start_time": "2025-09-29T05:20:16.711131Z"
    }
   },
   "source": [
    "import numpy as np\n",
    "\n",
    "def anchor_points(x: torch.Tensor, ell_c: float = 1.0):\n",
    "        x = torch.exp(- (x / ell_c)) + torch.exp(- ((1 - x) / ell_c))\n",
    "        coeff = torch.tensor([1.0 / np.sqrt(2.0 * (1 + np.exp(- 1.0 / ell_c))), 1.0 / np.sqrt(2.0 * (1 - np.exp(- 1.0 / ell_c)))], device=x.device, dtype=x.dtype)\n",
    "        res = x.unsqueeze(-1) @ coeff.unsqueeze(0)  # (..., 1)\n",
    "        return res  # (..., 2)\n",
    "    \n",
    "x = torch.Tensor(2, 3)\n",
    "print(x)\n",
    "res = anchor_points(x)\n",
    "print(res)\n",
    "print(res.shape)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[3.0779e-34, 0.0000e+00, 0.0000e+00],\n",
      "        [0.0000e+00, 0.0000e+00, 0.0000e+00]])\n",
      "tensor([[[0.8270, 1.2166],\n",
      "         [0.8270, 1.2166],\n",
      "         [0.8270, 1.2166]],\n",
      "\n",
      "        [[0.8270, 1.2166],\n",
      "         [0.8270, 1.2166],\n",
      "         [0.8270, 1.2166]]])\n",
      "torch.Size([2, 3, 2])\n"
     ]
    }
   ],
   "execution_count": 11
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dgp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
