{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e6925ad2-5c3f-4313-9dde-49f235b14b2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import time \n",
    "import numpy as np\n",
    "\n",
    "from MatrixSketch.JL import JL\n",
    "from MatrixSketch.PrioritySampling import PrioritySampling\n",
    "from MatrixSketch.ThresholdSampling import ThresholdSampling\n",
    "\n",
    "from utils import plot_data, create_matrix_pair_with_outlier\n",
    "\n",
    "from Hash.Hash import Hash\n",
    "from Hash.KWiseHash import KWiseHash\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "eb657e34-52a9-4523-948b-d4876ad647e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gather_data_regression_priority_sampling(A, B, sampling_ratio_A, sampling_ratio_B, matrices_size):\n",
    "    \"\"\"Average storage and error of priority-sampling sketches of A.T @ B.\n",
    "\n",
    "    For every target sample size the experiment is repeated `rep` times, each\n",
    "    time with a different hash seed. Both matrices are sketched with the same\n",
    "    hash function, and the Frobenius norm of the difference between the sketched\n",
    "    product and the exact product is divided by ||A||_F * ||B||_F.\n",
    "\n",
    "    Reads the notebook-level globals `rep`, `n` and `zero_row_fraction`, so those\n",
    "    must be defined before this function is called.\n",
    "\n",
    "    Args:\n",
    "        A, B: matrices whose product A.T @ B is approximated.\n",
    "        sampling_ratio_A, sampling_ratio_B: sampling weights forwarded to\n",
    "            `PrioritySampling.hash` (the caller in this notebook passes squared row norms).\n",
    "        matrices_size: normalisation constant for the storage figure (the caller\n",
    "            in this notebook passes the total number of entries of A and B).\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(storage, err)` of float64 tensors with one entry per sample size,\n",
    "        each averaged over the `rep` repetitions.\n",
    "    \"\"\"\n",
    "    sample_size = torch.tensor([100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]) / (1.25 * (1-zero_row_fraction))\n",
    "    storage = torch.zeros_like(sample_size, dtype=torch.float64)\n",
    "    err = torch.zeros_like(sample_size, dtype=torch.float64)\n",
    "\n",
    "    norm_B = torch.linalg.norm(B).to('cpu')\n",
    "    norm_A = torch.linalg.norm(A).to('cpu')\n",
    "    # The exact product is the same for every sample size and repetition,\n",
    "    # so it is computed once instead of inside the double loop.\n",
    "    exact_product = A.T @ B\n",
    "\n",
    "    for i in range(len(sample_size)):\n",
    "        for r in range(rep):\n",
    "            # A distinct seed per (sample size, repetition) pair.\n",
    "            hash_func, _ = KWiseHash().hash(torch.ones(n), seed=(i * rep + r))\n",
    "\n",
    "            s_a = PrioritySampling(hash_function=hash_func, sample_size=int(sample_size[i]))\n",
    "            s_b = PrioritySampling(hash_function=hash_func, sample_size=int(sample_size[i]))\n",
    "\n",
    "            s_a.hash(data_set=A, sampling_ratio=sampling_ratio_A)\n",
    "            s_b.hash(data_set=B, sampling_ratio=sampling_ratio_B)\n",
    "\n",
    "            # Storage is based on the realised sketch sizes (len) rather than the nominal sample size.\n",
    "            # NOTE(review): the 1.25 factor presumably accounts for per-sample overhead - confirm.\n",
    "            storage[i] += (len(s_a) + len(s_b)) / matrices_size * 1.25\n",
    "\n",
    "            reg_error = torch.linalg.norm((s_a * s_b) - exact_product).to('cpu')\n",
    "\n",
    "            # A norm is non-negative, so no abs() is needed before normalising.\n",
    "            err[i] += reg_error / (norm_B * norm_A)\n",
    "    return storage / rep, err / rep\n",
    "\n",
    "def gather_data_regression_threshold_sampling(A, B, sampling_ratio_A, sampling_ratio_B, matrices_size):\n",
    "    \"\"\"Average storage and error of threshold-sampling sketches of A.T @ B.\n",
    "\n",
    "    For every target sample size the experiment is repeated `rep` times, each\n",
    "    time with a different hash seed. Both matrices are sketched with the same\n",
    "    hash function, and the Frobenius norm of the difference between the sketched\n",
    "    product and the exact product is divided by ||A||_F * ||B||_F.\n",
    "\n",
    "    Reads the notebook-level globals `rep`, `n` and `zero_row_fraction`, so those\n",
    "    must be defined before this function is called.\n",
    "\n",
    "    Args:\n",
    "        A, B: matrices whose product A.T @ B is approximated.\n",
    "        sampling_ratio_A, sampling_ratio_B: sampling weights forwarded to\n",
    "            `ThresholdSampling.hash`; their sums also set the threshold `tau`.\n",
    "        matrices_size: normalisation constant for the storage figure (the caller\n",
    "            in this notebook passes the total number of entries of A and B).\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(storage, err)` of float64 tensors with one entry per sample size,\n",
    "        each averaged over the `rep` repetitions.\n",
    "    \"\"\"\n",
    "    sample_size = torch.tensor([100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]) /(1.25 * (1-zero_row_fraction))\n",
    "    storage = torch.zeros_like(sample_size, dtype=torch.float64)\n",
    "    err = torch.zeros_like(sample_size, dtype=torch.float64)\n",
    "\n",
    "    norm_B = torch.linalg.norm(B).to('cpu')\n",
    "    norm_A = torch.linalg.norm(A).to('cpu')\n",
    "    # The exact product is the same for every sample size and repetition,\n",
    "    # so it is computed once instead of inside the double loop.\n",
    "    exact_product = A.T @ B\n",
    "\n",
    "    # Summed sampling weights; tau is the target sample size divided by these sums.\n",
    "    sampling_ratio_acc_A = torch.sum(sampling_ratio_A).cpu()\n",
    "    sampling_ratio_acc_B = torch.sum(sampling_ratio_B).cpu()\n",
    "\n",
    "    for i in range(len(sample_size)):\n",
    "        for r in range(rep):\n",
    "            # A distinct seed per (sample size, repetition) pair.\n",
    "            hash_func, _ = KWiseHash().hash(torch.ones(n), seed=(i * rep + r))\n",
    "\n",
    "            s_a = ThresholdSampling(hash_function=hash_func, tau=sample_size[i]/sampling_ratio_acc_A)\n",
    "            s_b = ThresholdSampling(hash_function=hash_func, tau=sample_size[i]/sampling_ratio_acc_B)\n",
    "\n",
    "            s_a.hash(data_set=A, sampling_ratio=sampling_ratio_A)\n",
    "            s_b.hash(data_set=B, sampling_ratio=sampling_ratio_B)\n",
    "\n",
    "            # Storage is based on the realised sketch sizes (len) rather than the nominal sample size.\n",
    "            # NOTE(review): the 1.25 factor presumably accounts for per-sample overhead - confirm.\n",
    "            storage[i] += (len(s_a) + len(s_b)) / matrices_size * 1.25\n",
    "\n",
    "            reg_error = torch.linalg.norm((s_a * s_b) - exact_product).to('cpu')\n",
    "\n",
    "            # A norm is non-negative, so no abs() is needed before normalising.\n",
    "            err[i] += reg_error / (norm_B * norm_A)\n",
    "    return storage / rep, err / rep\n",
    "\n",
    "def gather_data_regression_jl(A, B, matrices_size):\n",
    "    \"\"\"Average storage and error of Gaussian JL sketches of A.T @ B.\n",
    "\n",
    "    For every sketch size k a fresh k x A.shape[0] Gaussian matrix with standard\n",
    "    deviation sqrt(1 / k) is drawn `rep` times. The Frobenius norm of the\n",
    "    difference between the sketched product and the exact product is divided by\n",
    "    ||A||_F * ||B||_F.\n",
    "\n",
    "    Reads the notebook-level globals `rep` and `d`, so those must be defined\n",
    "    before this function is called.\n",
    "\n",
    "    Args:\n",
    "        A, B: matrices with the same number of rows whose product A.T @ B is approximated.\n",
    "        matrices_size: normalisation constant for the storage figure (the caller\n",
    "            in this notebook passes the total number of entries of A and B).\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(storage, err)` of float64 tensors with one entry per sketch size,\n",
    "        each averaged over the `rep` repetitions.\n",
    "    \"\"\"\n",
    "    sample_size = torch.tensor([100, 200, 300, 400, 500, 600, 700, 800, 900, 1000])\n",
    "\n",
    "    storage = torch.zeros_like(sample_size, dtype=torch.float64)\n",
    "    err = torch.zeros_like(sample_size, dtype=torch.float64)\n",
    "\n",
    "    norm_B = torch.linalg.norm(B).to('cpu')\n",
    "    norm_A = torch.linalg.norm(A).to('cpu')\n",
    "    # The exact product is the same for every sketch size and repetition,\n",
    "    # so it is computed once instead of inside the double loop.\n",
    "    exact_product = A.T @ B\n",
    "    # The sketch must have one column per row of A for the matmul below,\n",
    "    # so take that count from A instead of a notebook global.\n",
    "    n_rows = A.shape[0]\n",
    "\n",
    "    for i in range(len(sample_size)):\n",
    "        k = int(sample_size[i])\n",
    "        for _ in range(rep):\n",
    "            # Drawn on the CPU, then moved to wherever A lives (instead of a hard-coded 'cuda').\n",
    "            pi_sketch = torch.normal(0, torch.sqrt(1 / sample_size[i]), (k, n_rows), dtype=torch.float64).to(A.device)\n",
    "\n",
    "            # Two sketches of sample_size[i] rows and d columns each.\n",
    "            storage[i] += (sample_size[i] * d + sample_size[i] * d) / matrices_size\n",
    "            S_A_S_b = A.T @ pi_sketch.T @ pi_sketch @ B\n",
    "            reg_error = torch.linalg.norm(S_A_S_b - exact_product).to('cpu')\n",
    "\n",
    "            # A norm is non-negative, so no abs() is needed before normalising.\n",
    "            err[i] += reg_error / (norm_B * norm_A)\n",
    "    return storage / rep, err / rep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a1ae4e11-edc7-43cb-a307-329a4e945677",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_regression_error(A, B, file_name):\n",
    "    \"\"\"Run the three sketching experiments on (A, B), then save and plot the curves.\n",
    "\n",
    "    The squared row norms of A and B are used as sampling weights for the\n",
    "    priority- and threshold-sampling experiments. The resulting\n",
    "    (storage, error, label) triples are written to ./Plots/<file_name>.txt and\n",
    "    then handed to `plot_data`.\n",
    "\n",
    "    Args:\n",
    "        A, B: matrices whose product A.T @ B is approximated.\n",
    "        file_name: base name (no extension) for the text dump; also forwarded to `plot_data`.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    # Make sure the output folder exists before the long-running experiments start,\n",
    "    # so a missing directory cannot discard the results at the very end.\n",
    "    os.makedirs('./Plots', exist_ok=True)\n",
    "\n",
    "    row_norm_B = torch.sum(B ** 2, dim=1)\n",
    "    row_norm_A = torch.sum(A ** 2, dim=1)\n",
    "\n",
    "    # Total number of entries of both inputs; used to express storage as a fraction.\n",
    "    matrices_size = (A.shape[0] * A.shape[1] + B.shape[0] * B.shape[1])\n",
    "\n",
    "    storage_priority_sampling, err_priority_sampling = gather_data_regression_priority_sampling(A, B, row_norm_A, row_norm_B, matrices_size)\n",
    "    storage_threshold_sampling, err_threshold_sampling = gather_data_regression_threshold_sampling(A, B, row_norm_A, row_norm_B, matrices_size)\n",
    "    storage_jl, err_jl = gather_data_regression_jl(A, B, matrices_size)\n",
    "\n",
    "    data = [\n",
    "        (storage_priority_sampling, err_priority_sampling, 'Priority Sampling'),\n",
    "        (storage_jl, err_jl, 'JL Sketch'),\n",
    "        (storage_threshold_sampling, err_threshold_sampling, 'Threshold Sampling'),\n",
    "    ]\n",
    "\n",
    "    # Keep a plain-text copy of the raw curves next to the figure.\n",
    "    with open(f\"./Plots/{file_name}.txt\", 'w') as file:\n",
    "        for storage, err, method in data:\n",
    "            file.write(f\"{storage}, {err}, {method}\\n\")\n",
    "\n",
    "    plot_data(\n",
    "        data=data,\n",
    "        ylabel=r'Error $\\frac{\\|\\tilde{A}^T\\tilde{B} - A^T B\\|_F}{\\|A\\|_F\\|B\\|_F}$',\n",
    "        xlabel=r'Sketch Size',\n",
    "        file_name=file_name,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "066bcf09-0601-44ef-ae44-a06ce191f175",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 1\n",
    "rep = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcc276e4-db95-46f3-82ad-6f9a367a7999",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem parameters; the gather_* functions read n, d and zero_row_fraction as globals.\n",
    "n = 10000\n",
    "d = 128\n",
    "zero_row_fraction = 0\n",
    "\n",
    "# Build the synthetic matrix pair for this configuration.\n",
    "A_matrix, B_matrix = create_matrix_pair_with_outlier(\n",
    "    n=n,\n",
    "    d=d,\n",
    "    outlier_fraction=0.1,\n",
    "    zero_row_fraction=zero_row_fraction,\n",
    ")\n",
    "\n",
    "# Run all three sketching experiments and save the curves under this name.\n",
    "plot_regression_error(A=A_matrix, B=B_matrix, file_name=\"synthetic_10000,128,zero=0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ad2b4f4-1c2d-4c95-82b1-846ddab95f6f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_kivi",
   "language": "python",
   "name": "env_kivi"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
