{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb7bac73",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from functools import partial\n",
    "from jax.numpy import ceil, log2, isclose, floor\n",
    "import itertools\n",
    "from jax import jit, vmap\n",
    "# https://pypi.org/project/hadamard-transform/\n",
    "from hadamard_transform import hadamard_transform\n",
    "import torch\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 326,
   "id": "a9a67ce8",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit\n",
    "def convert_vec_to_binary(v):\n",
    "    n = v.shape[0]\n",
    "    return jnp.dot( 2 ** jnp.arange(n-1,-1,-1), v)\n",
    "\n",
    "hash_to_index = jit(vmap(convert_vec_to_binary, in_axes=(1)))\n",
    "\n",
    "\n",
    "def run(ref_wht, shifted_wht, freqs, amps, seed):\n",
    "    \"\"\"\n",
    "    ref_wht is 2^b array. Coressponding to zero shift\n",
    "    shifted_wht is n * 2^b array. Coressponding to e_i shifts\n",
    "    freqs is n * B array consisiting of current estimate of frequencies\n",
    "    amps is B array consisting of current estimates of amplitudes\n",
    "    \"\"\"\n",
    "    # hashing matrix setup\n",
    "    key = jax.random.PRNGKey(seed)\n",
    "    hashing_matrix = jax.random.randint(key=key, shape=(n,b), minval=0, maxval=2)\n",
    "    # Hash frequenices to buckets and peel\n",
    "    hashed_freqs = (hashing_matrix.T @ freqs) % 2\n",
    "    index = hash_to_index(hashed_freqs)                                                                                                    \n",
    "    ref_wht_peeled = ref_wht.at[index].add(-amps)                                                                                                 \n",
    "    signed_amps = jnp.where(freqs==0, 1, -1) * amps\n",
    "    shifted_wht_peeled = shifted_wht.at[:, index].add(-signed_amps)\n",
    "    #recover requencies, n * B array\n",
    "    recovered_freqs = jnp.where(jnp.sign(ref_wht_peeled)==jnp.sign(shifted_wht_peeled), 0, 1)\n",
    "    print(hashing_matrix, hashed_freqs, ref_wht, ref_wht_peeled, shifted_wht, shifted_wht_peeled, freqs, amps)\n",
    "    return recovered_freqs, ref_wht_peeled\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 323,
   "id": "048fdb3e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0 0]\n",
      " [1 1]\n",
      " [0 0]\n",
      " [1 0]\n",
      " [1 0]] [[0 1 1 0]\n",
      " [0 1 0 1]] [0. 0. 0. 0.] [-1. -7. -5. -2.] [[0. 0. 0. 0.]\n",
      " [0. 0. 0. 0.]\n",
      " [0. 0. 0. 0.]\n",
      " [0. 0. 0. 0.]\n",
      " [0. 0. 0. 0.]] [[-1. -7. -5. -2.]\n",
      " [-1.  7. -5.  2.]\n",
      " [ 1. -7. -5. -2.]\n",
      " [-1.  7. -5. -2.]\n",
      " [-1. -7.  5. -2.]] [[0 0 0 0]\n",
      " [0 1 0 1]\n",
      " [1 0 0 0]\n",
      " [0 0 0 1]\n",
      " [0 0 1 0]] [1. 2. 5. 7.]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[0, 0, 0, 0],\n",
       "              [0, 1, 0, 1],\n",
       "              [1, 0, 0, 0],\n",
       "              [0, 1, 0, 0],\n",
       "              [0, 0, 1, 0]], dtype=int32, weak_type=True),\n",
       " DeviceArray([-1., -7., -5., -2.], dtype=float32))"
      ]
     },
     "execution_count": 323,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n = 5\n",
    "b = 2\n",
    "B = 2 ** b\n",
    "T = 3\n",
    "freqs = jnp.array([[0,0,0,0],\n",
    "                   [0,1,0,1],\n",
    "                   [1,0,0,0],\n",
    "                   [0,0,0,1],\n",
    "                   [0,0,1,0]], dtype=jnp.int32)\n",
    "amps  = jnp.array( [1,2,5,7], dtype=jnp.float32)   \n",
    "run(jnp.zeros(B), jnp.zeros((n,B)), freqs, amps, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 324,
   "id": "3df9fc9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sparse_wht(ref_wht, shifted_wht):\n",
    "    freqs = jnp.zeros((n, T*B), dtype=jnp.int32)\n",
    "    amps = jnp.zeros((T*B), dtype=jnp.int32)\n",
    "    for i in range(T):\n",
    "        recovered_freqs, recovered_amps = run(ref_wht[i], shifted_wht[i], freqs, amps, i)\n",
    "        print(\"recovered\", recovered_freqs, recovered_amps)\n",
    "        freqs = freqs.at[:,i*B:(i+1)*B].set(recovered_freqs)\n",
    "        amps = amps.at[i*B:(i+1)*B].set(recovered_amps)\n",
    "        \n",
    "    return freqs, amps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d802770d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 316,
   "id": "024da642",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_time_samples(n, b, T):\n",
    "    t_list = jnp.array(list(itertools.product([0, 1], repeat=b))).T\n",
    "    ret_value = []\n",
    "    for i in range(T):\n",
    "        key = jax.random.PRNGKey(i)\n",
    "        hashing_matrix = jax.random.randint(key=key, shape=(n,b), minval=0, maxval=2)\n",
    "        time_samples = (hashing_matrix @ t_list) %2\n",
    "        ret_value.append(time_samples)\n",
    "    # T * n * 2^b array have to ad the shifts yourself\n",
    "    return jnp.array(ret_value)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "id": "93aa0f3e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([[[0, 1, 0, 1],\n",
       "              [0, 0, 1, 1],\n",
       "              [0, 1, 1, 0],\n",
       "              [0, 1, 0, 1],\n",
       "              [0, 0, 0, 0]],\n",
       "\n",
       "             [[0, 1, 0, 1],\n",
       "              [0, 1, 1, 0],\n",
       "              [0, 1, 1, 0],\n",
       "              [0, 1, 0, 1],\n",
       "              [0, 1, 1, 0]],\n",
       "\n",
       "             [[0, 0, 0, 0],\n",
       "              [0, 1, 1, 0],\n",
       "              [0, 0, 1, 1],\n",
       "              [0, 0, 0, 0],\n",
       "              [0, 0, 0, 0]]], dtype=int32)"
      ]
     },
     "execution_count": 130,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = get_time_samples(n, b, T)\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "id": "45436ab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SparseFourierDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, time_samples):\n",
    "        # T * n * 2^b array\n",
    "        time_samples = np.array(time_samples)\n",
    "        T, n, B = time_samples.shape\n",
    "        self.T, self.n, self.B = T, n, B\n",
    "        shifted_samples = []\n",
    "        for i in range(T):\n",
    "            shifts = np.eye(self.n, dtype=np.int32).reshape((n,n,1))\n",
    "            # no_shifts * n * 2^b array \n",
    "            time_samples_with_shifts = (np.tile(np.array(time_samples[i]), reps=(n, 1, 1)) + shifts)%2\n",
    "            shifted_samples.append(time_samples_with_shifts)\n",
    "        # T * no_shifts * n * 2^b array\n",
    "        shifted_samples = np.array(shifted_samples)\n",
    "        # T * n * (no_shifts*2^b array)\n",
    "        all_samples = [np.concatenate([time_samples[i], np.swapaxes(shifted_samples[i], 0,1).reshape((n,n*B))], axis=1) \n",
    "                        for i in range(T)] \n",
    "        # T * n * (no_shifts*2^b array)\n",
    "        all_samples = np.array(all_samples)\n",
    "        # n * (T * no_shifts * 2^b) array\n",
    "        all_samples = np.swapaxes(all_samples, 0,1).reshape((n,-1))\n",
    "        self.all_samples = all_samples.T\n",
    "        self.no_samples = T * B * (n+1)\n",
    "    def __len__(self):\n",
    "        return self.no_samples\n",
    "\n",
    "    # get a row at an index\n",
    "    def __getitem__(self, idx):\n",
    "        return self.all_samples[idx]\n",
    "        \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 214,
   "id": "17993ce2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dummy_f(x):\n",
    "    y = torch.where(x[:,0]==0, 2, -2)\n",
    "    return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 289,
   "id": "e87275de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_and_wht(f, n, b, T, batch_size=20, num_workers=5):\n",
    "    \"\"\"\n",
    "    f: callable to be sampled\n",
    "    n: dimension of domain of f\n",
    "    b: 2^b is no buckets \n",
    "    T: no. of peeling rounds\n",
    "    \"\"\"\n",
    "    # perpeare timesamples\n",
    "    time_samples = get_time_samples(n, b, T)\n",
    "    dataset = SparseFourierDataset(time_samples)\n",
    "    no_samples = len(dataset)\n",
    "    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n",
    "    # evaluate function on time samples\n",
    "    y = torch.zeros((no_samples))\n",
    "    print(\"sampling\")\n",
    "    for i, x in tqdm.tqdm(enumerate(dataloader)):\n",
    "        y[i*batch_size:(i+1)*batch_size] = f(x)\n",
    "    \n",
    "    y = y.reshape((T*(n+1), -1))\n",
    "    # Get WHT\n",
    "    y_wht = torch.zeros(y.shape)\n",
    "    print(\"geting wht\")\n",
    "    for i in tqdm.tqdm(range(y.shape[0])):\n",
    "        y_wht[i] = hadamard_transform(y[i])\n",
    "    ref_wht_index = slice(0,-1,n+1)\n",
    "    # convert from torch tensors to jax arrays \n",
    "    print(\"converting to jax\")\n",
    "    ref_wht = jnp.array(y_wht[ref_wht_index])\n",
    "    shifted_wht = jnp.array(np.delete(y_wht, ref_wht_index, axis=0)).reshape((T,n,-1))\n",
    "    return ref_wht, shifted_wht"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 290,
   "id": "895bdbe5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sampling\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "4it [00:00, 88.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "geting wht\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 4450.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "converting to jax\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "ref_wht, shifted_wht = sample_and_wht(dummy_f, n, b, T, batch_size=20, num_workers=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 297,
   "id": "855c65c5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[0., 4., 0., 0.],\n",
       "              [0., 4., 0., 0.],\n",
       "              [4., 0., 0., 0.]], dtype=float32),\n",
       " DeviceArray([[[ 0., -4.,  0.,  0.],\n",
       "               [ 0.,  4.,  0.,  0.],\n",
       "               [ 0.,  4.,  0.,  0.],\n",
       "               [ 0.,  4.,  0.,  0.],\n",
       "               [ 0.,  4.,  0.,  0.]],\n",
       " \n",
       "              [[ 0., -4.,  0.,  0.],\n",
       "               [ 0.,  4.,  0.,  0.],\n",
       "               [ 0.,  4.,  0.,  0.],\n",
       "               [ 0.,  4.,  0.,  0.],\n",
       "               [ 0.,  4.,  0.,  0.]],\n",
       " \n",
       "              [[-4.,  0.,  0.,  0.],\n",
       "               [ 4.,  0.,  0.,  0.],\n",
       "               [ 4.,  0.,  0.,  0.],\n",
       "               [ 4.,  0.,  0.,  0.],\n",
       "               [ 4.,  0.,  0.,  0.]]], dtype=float32))"
      ]
     },
     "execution_count": 297,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ref_wht, shifted_wht"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 327,
   "id": "d6aa2580",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0 1]\n",
      " [1 0]\n",
      " [1 1]\n",
      " [0 1]\n",
      " [0 0]] [[0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0]] [0. 4. 0. 0.] [0. 4. 0. 0.] [[ 0. -4.  0.  0.]\n",
      " [ 0.  4.  0.  0.]\n",
      " [ 0.  4.  0.  0.]\n",
      " [ 0.  4.  0.  0.]\n",
      " [ 0.  4.  0.  0.]] [[ 0. -4.  0.  0.]\n",
      " [ 0.  4.  0.  0.]\n",
      " [ 0.  4.  0.  0.]\n",
      " [ 0.  4.  0.  0.]\n",
      " [ 0.  4.  0.  0.]] [[0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0]] [0 0 0 0 0 0 0 0 0 0 0 0]\n",
      "recovered [[0 1 0 0]\n",
      " [0 0 0 0]\n",
      " [0 0 0 0]\n",
      " [0 0 0 0]\n",
      " [0 0 0 0]] [0. 4. 0. 0.]\n",
      "[[0 1]\n",
      " [1 1]\n",
      " [1 1]\n",
      " [0 1]\n",
      " [1 1]] [[0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 1 0 0 0 0 0 0 0 0 0 0]] [0. 4. 0. 0.] [0. 0. 0. 0.] [[ 0. -4.  0.  0.]\n",
      " [ 0.  4.  0.  0.]\n",
      " [ 0.  4.  0.  0.]\n",
      " [ 0.  4.  0.  0.]\n",
      " [ 0.  4.  0.  0.]] [[0. 0. 0. 0.]\n",
      " [0. 0. 0. 0.]\n",
      " [0. 0. 0. 0.]\n",
      " [0. 0. 0. 0.]\n",
      " [0. 0. 0. 0.]] [[0 1 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0]] [0 4 0 0 0 0 0 0 0 0 0 0]\n",
      "recovered [[0 0 0 0]\n",
      " [0 0 0 0]\n",
      " [0 0 0 0]\n",
      " [0 0 0 0]\n",
      " [0 0 0 0]] [0. 0. 0. 0.]\n",
      "[[0 0]\n",
      " [1 1]\n",
      " [1 0]\n",
      " [0 0]\n",
      " [0 0]] [[0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0]] [4. 0. 0. 0.] [0. 0. 0. 0.] [[-4.  0.  0.  0.]\n",
      " [ 4.  0.  0.  0.]\n",
      " [ 4.  0.  0.  0.]\n",
      " [ 4.  0.  0.  0.]\n",
      " [ 4.  0.  0.  0.]] [[0. 0. 0. 0.]\n",
      " [0. 0. 0. 0.]\n",
      " [0. 0. 0. 0.]\n",
      " [0. 0. 0. 0.]\n",
      " [0. 0. 0. 0.]] [[0 1 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0]] [0 4 0 0 0 0 0 0 0 0 0 0]\n",
      "recovered [[0 0 0 0]\n",
      " [0 0 0 0]\n",
      " [0 0 0 0]\n",
      " [0 0 0 0]\n",
      " [0 0 0 0]] [0. 0. 0. 0.]\n"
     ]
    }
   ],
   "source": [
    "freqs, amps = sparse_wht(ref_wht, shifted_wht)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 328,
   "id": "e04f3342",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "              [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "              [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "              [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "              [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32),\n",
       " DeviceArray([0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32))"
      ]
     },
     "execution_count": 328,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "freqs, amps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb43983b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
