{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "kOW_JdlujI0L"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import math\n",
        "from numba import njit, prange\n",
        "import matplotlib.pyplot as plt\n",
        "from sklearn.metrics.pairwise import euclidean_distances"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "@njit(fastmath=True)\n",
        "def KDE_estimator_sampling_exponential(X, q, m, scale):\n",
        "    \"\"\"\n",
        "    Monte Carlo estimate of sum_j exp(-||X[j] - q|| / scale).\n",
        "\n",
        "    X: numpy array of shape (n, d), the data points\n",
        "    q: numpy array of shape (d,), the query point\n",
        "    m: integer, number of rows of X drawn uniformly with replacement\n",
        "    scale: value dividing the Euclidean distance inside the exponential\n",
        "\n",
        "    Returns the sum of the m sampled kernel values multiplied by n / m.\n",
        "    \"\"\"\n",
        "    num_points, dim = X.shape\n",
        "\n",
        "    acc = 0.0\n",
        "    for _ in range(m):\n",
        "        # One uniform draw from the n rows per sample\n",
        "        row = np.random.randint(0, num_points)\n",
        "        sq_norm = 0.0\n",
        "        for c in range(dim):\n",
        "            delta = X[row, c] - q[c]\n",
        "            sq_norm += delta * delta\n",
        "        acc += math.exp(-math.sqrt(sq_norm) / scale)\n",
        "\n",
        "    return acc * num_points / m\n",
        "\n",
        "@njit(parallel=True, fastmath=True)\n",
        "def kernel_mvp_exact_exponential(X, Y, v, scale = 1.0):\n",
        "    \"\"\"\n",
        "    Exact product K @ v for the |X| by |Y| exponential kernel matrix\n",
        "    K[i, j] = exp(-||X[i] - Y[j]|| / scale).\n",
        "\n",
        "    X: numpy array of shape (n1, d), one output entry per row\n",
        "    Y: numpy array of shape (n2, d), the points the weights belong to\n",
        "    v: numpy array of shape (n2,), the weights\n",
        "    scale: value dividing the Euclidean distance inside the exponential\n",
        "\n",
        "    Returns a numpy array of shape (n1,).\n",
        "    \"\"\"\n",
        "    n1 = X.shape[0]\n",
        "    n2 = Y.shape[0]\n",
        "    d = X.shape[1]\n",
        "    output = np.zeros(n1)\n",
        "    # prange parallelises the outer loop; iteration i writes only output[i]\n",
        "    for i in prange(n1):\n",
        "        curr_value = 0.0\n",
        "        for j in range(n2):\n",
        "            # Accumulate the squared distance in a scalar to avoid temporary arrays\n",
        "            dist_sq = 0.0\n",
        "            for k in range(d):\n",
        "                # The column point is read from Y, so X and Y may be different sets\n",
        "                diff = X[i, k] - Y[j, k]\n",
        "                dist_sq += diff * diff\n",
        "            dist = np.sqrt(dist_sq)\n",
        "            curr_value += v[j] * np.exp(-dist/scale)\n",
        "        output[i] = curr_value\n",
        "    return output\n",
        "\n",
        "\n",
        "@njit(fastmath=True)\n",
        "def get_geometric_buckets_fast(v, eps):\n",
        "    \"\"\"\n",
        "    Group the entries of v into geometric buckets by sorting on their exponent.\n",
        "\n",
        "    Entries with v[i] < eps / n**1.5 are discarded. Each remaining entry gets\n",
        "    the integer exponent floor(log(v[i]) / log(base)) with base = 1 - eps/2,\n",
        "    and entries sharing an exponent form one bucket.\n",
        "\n",
        "    Returns: (indices, start_indices, exponents, base)\n",
        "    Types:   (int64[], int64[], float64[], float64)\n",
        "      indices       - positions in v of the kept entries, ordered by exponent\n",
        "      start_indices - bucket b covers indices[start_indices[b]:start_indices[b+1]]\n",
        "      exponents     - the exponent shared by each bucket\n",
        "      base          - 1 - eps/2\n",
        "    When no entry is kept, the three arrays are empty.\n",
        "    \"\"\"\n",
        "    # base is needed by every return path, so compute it first\n",
        "    base = 1.0 - (eps / 2.0)\n",
        "    size = len(v)\n",
        "\n",
        "    # Empty input: empty arrays with the same dtypes as the regular return\n",
        "    if size == 0:\n",
        "        return (np.zeros(0, dtype=np.int64),\n",
        "                np.zeros(0, dtype=np.int64),\n",
        "                np.zeros(0, dtype=np.float64),\n",
        "                base)\n",
        "\n",
        "    cutoff = eps / (size ** 1.5)\n",
        "    log_base = math.log(base)\n",
        "\n",
        "    # Keep only the entries that reach the cutoff\n",
        "    keep = v >= cutoff\n",
        "    kept_positions = np.where(keep)[0]\n",
        "    kept_values = v[keep]\n",
        "\n",
        "    # Nothing reached the cutoff: same empty result as for empty input\n",
        "    if len(kept_values) == 0:\n",
        "        return (np.zeros(0, dtype=np.int64),\n",
        "                np.zeros(0, dtype=np.int64),\n",
        "                np.zeros(0, dtype=np.float64),\n",
        "                base)\n",
        "\n",
        "    # Exponent of each kept entry: floor(log(val) / log(base))\n",
        "    exps = np.floor(np.log(kept_values) / log_base).astype(np.int64)\n",
        "\n",
        "    # Sorting by exponent makes every bucket a contiguous run\n",
        "    order = np.argsort(exps)\n",
        "    ordered_positions = kept_positions[order]\n",
        "    ordered_exps = exps[order]\n",
        "\n",
        "    # Walk the sorted exponents and record where each run begins.\n",
        "    # There is at least one kept entry here, so the first bucket starts at 0.\n",
        "    total = len(ordered_exps)\n",
        "    starts = np.empty(total + 1, dtype=np.int64)\n",
        "    bucket_exps = np.empty(total, dtype=np.float64)\n",
        "    starts[0] = 0\n",
        "    bucket_exps[0] = ordered_exps[0]\n",
        "    num_buckets = 1\n",
        "    for pos in range(1, total):\n",
        "        if ordered_exps[pos] != ordered_exps[pos - 1]:\n",
        "            starts[num_buckets] = pos\n",
        "            bucket_exps[num_buckets] = ordered_exps[pos]\n",
        "            num_buckets += 1\n",
        "    # Closing boundary so that bucket b always ends at starts[b + 1]\n",
        "    starts[num_buckets] = total\n",
        "\n",
        "    return (\n",
        "        ordered_positions.astype(np.int64),\n",
        "        starts[:num_buckets + 1].copy(),\n",
        "        bucket_exps[:num_buckets].copy(),\n",
        "        base\n",
        "    )\n",
        "\n",
        "@njit(fastmath=True, parallel=True)\n",
        "def kernel_mvp_approx_fast(X, v, eps, m, scale=1.0):\n",
        "    \"\"\"\n",
        "    Approximate K @ v for K[i, j] = exp(-||X[i] - X[j]|| / scale).\n",
        "\n",
        "    The entries of v are grouped by get_geometric_buckets_fast and every entry\n",
        "    of a bucket is replaced by the bucket value base**exponent. A bucket with\n",
        "    at most m members has its kernel sum taken over all members; a larger\n",
        "    bucket is estimated from m members drawn with replacement and rescaled by\n",
        "    count / m. The samples of a bucket are drawn once and shared by all rows.\n",
        "\n",
        "    X: numpy array of shape (n, d)\n",
        "    v: numpy array of shape (n,)\n",
        "    eps: bucketing parameter passed to get_geometric_buckets_fast\n",
        "    m: integer, per-bucket sample budget\n",
        "    scale: value dividing the Euclidean distance inside the exponential\n",
        "\n",
        "    Returns a float64 numpy array of shape (n,).\n",
        "    \"\"\"\n",
        "    n, d = X.shape\n",
        "    output = np.zeros(n, dtype=np.float64)\n",
        "\n",
        "    # Group the indices of v by exponent\n",
        "    sorted_indices, bucket_starts, unique_exponents, base = get_geometric_buckets_fast(v, eps)\n",
        "\n",
        "    for b in range(len(unique_exponents)):\n",
        "        # Rows of X whose v entry falls in this bucket\n",
        "        members = sorted_indices[bucket_starts[b]:bucket_starts[b + 1]]\n",
        "        count = len(members)\n",
        "\n",
        "        # Value standing in for every v entry of the bucket\n",
        "        rep_val = base ** unique_exponents[b]\n",
        "\n",
        "        if count <= m:\n",
        "            # Small bucket: sum over every member\n",
        "            sources = members\n",
        "            weight = rep_val\n",
        "        else:\n",
        "            # Large bucket: sum over m sampled members, rescaled by count / m\n",
        "            picks = np.random.choice(count, m)\n",
        "            sources = members[picks]\n",
        "            weight = (count / m) * rep_val\n",
        "        num_sources = len(sources)\n",
        "\n",
        "        # Add this bucket's contribution to every row; parallel over rows\n",
        "        for i in prange(n):\n",
        "            partial = 0.0\n",
        "            for s in range(num_sources):\n",
        "                src = sources[s]\n",
        "                dist_sq = 0.0\n",
        "                for k in range(d):\n",
        "                    diff = X[i, k] - X[src, k]\n",
        "                    dist_sq += diff * diff\n",
        "                partial += math.exp(-math.sqrt(dist_sq) / scale)\n",
        "\n",
        "            output[i] += partial * weight\n",
        "\n",
        "    return output\n",
        "\n",
        "def noisy_power_method(X, eps, m, iters, scale=1.0):\n",
        "    \"\"\"\n",
        "    Estimate the top eigenvalue of the exponential kernel matrix of X by power\n",
        "    iteration, with each matrix-vector product computed by kernel_mvp_approx_fast.\n",
        "\n",
        "    X: numpy array of shape (n, d)\n",
        "    eps: parameter forwarded to kernel_mvp_approx_fast\n",
        "    m: integer, sample budget forwarded to kernel_mvp_approx_fast\n",
        "    iters: integer >= 1, number of power iterations\n",
        "    scale: kernel scale forwarded to kernel_mvp_approx_fast\n",
        "\n",
        "    Returns v.dot(Kv) from the last iteration, where v is the unit-norm\n",
        "    iterate entering that iteration.\n",
        "    Raises ValueError if iters < 1.\n",
        "    \"\"\"\n",
        "    if iters < 1:\n",
        "        # Without at least one product there is no estimate to return\n",
        "        raise ValueError(\"iters must be at least 1\")\n",
        "    n = X.shape[0]\n",
        "    # Deterministic start: the all-ones vector scaled to unit norm\n",
        "    v = np.ones(n)/((n)**(0.5))\n",
        "    lambda_1 = 0.0\n",
        "    for _ in range(iters):\n",
        "        Kv = kernel_mvp_approx_fast(X, v, eps, m, scale)\n",
        "        lambda_1 = Kv.dot(v)\n",
        "        # Renormalise so the next iterate has unit norm\n",
        "        v = Kv / np.linalg.norm(Kv)\n",
        "    return lambda_1\n",
        "\n",
        "\n",
        "def power_method(X, iters, scale=1.0):\n",
        "    \"\"\"\n",
        "    Estimate the top eigenvalue of the exponential kernel matrix of X by power\n",
        "    iteration, with each matrix-vector product computed exactly by\n",
        "    kernel_mvp_exact_exponential.\n",
        "\n",
        "    X: numpy array of shape (n, d)\n",
        "    iters: integer >= 1, number of power iterations\n",
        "    scale: kernel scale forwarded to kernel_mvp_exact_exponential\n",
        "\n",
        "    Returns v.dot(Kv) from the last iteration, where v is the unit-norm\n",
        "    iterate entering that iteration.\n",
        "    Raises ValueError if iters < 1.\n",
        "    \"\"\"\n",
        "    if iters < 1:\n",
        "        # Without at least one product there is no estimate to return\n",
        "        raise ValueError(\"iters must be at least 1\")\n",
        "    n = X.shape[0]\n",
        "    # Random start drawn from NumPy's global generator, scaled to unit norm\n",
        "    v = np.random.random(n)\n",
        "    v = v / np.linalg.norm(v)\n",
        "    lambda_1 = 0.0\n",
        "    for _ in range(iters):\n",
        "        Kv = kernel_mvp_exact_exponential(X, X, v, scale)\n",
        "        lambda_1 = Kv.dot(v)\n",
        "        # Renormalise so the next iterate has unit norm\n",
        "        v = Kv / np.linalg.norm(Kv)\n",
        "    return lambda_1\n",
        "\n"
      ],
      "metadata": {
        "id": "zYuUGJSkkozd"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from sklearn.datasets import fetch_openml\n",
        "\n",
        "# Fixed seed so the same images are subsampled on every run\n",
        "SEED = 0\n",
        "N_SAMPLES = 1000\n",
        "\n",
        "# Download the data (might take a moment the first time)\n",
        "mnist = fetch_openml('mnist_784', version=1)\n",
        "# np.asarray accepts the data whether it arrives as a DataFrame or as an ndarray\n",
        "X_full = np.asarray(mnist[\"data\"], dtype=np.float64)\n",
        "# Rescale so the largest pixel value is 1\n",
        "X_full = X_full / X_full.max()\n",
        "rng = np.random.default_rng(SEED)\n",
        "random_indices = rng.choice(X_full.shape[0], size=N_SAMPLES, replace=False)\n",
        "X = X_full[random_indices, :]"
      ],
      "metadata": {
        "id": "qXUr5yuFnHKb"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Reference value: top eigenvalue estimate from the exact power method\n",
        "scale, iters = 2.1, 200\n",
        "exact_lambda_1 = power_method(X, iters=iters, scale=scale)\n",
        "print(exact_lambda_1)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xYL-BpRunikG",
        "outputId": "ec81b523-2b8b-45bd-b1b3-83c416b59f24"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "13.680711558992103\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "eps = 0.1\n",
        "itrs = 25\n",
        "# round() rather than int(): 1/eps**2 evaluates to slightly below 100 in floating point, and int() would truncate it to 99\n",
        "num_samples = round(1 / eps**2)\n",
        "# Pass the same scale as the exact reference so the two eigenvalues are comparable\n",
        "our_ans = noisy_power_method(X, eps, num_samples, itrs, scale)\n",
        "# Relative error of the approximate eigenvalue against the exact one\n",
        "print(abs(our_ans - exact_lambda_1)/exact_lambda_1)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uEIvsFT6nQcu",
        "outputId": "7a38094f-2eb2-41ce-df77-1af009231230"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0.026659150732345798\n"
          ]
        }
      ]
    }
  ]
}