{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "309e4a47-9540-4cf2-9a99-bcf4aca9f49d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/ext3/miniconda3/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import torch\n",
    "import time \n",
    "import numpy as np\n",
    "from scipy import sparse\n",
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
    "\n",
    "from datasets import load_dataset\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "\n",
    "from MatrixSketch.JL import JL\n",
    "from MatrixSketch.PrioritySampling import PrioritySampling\n",
    "from MatrixSketch.ThresholdSampling import ThresholdSampling\n",
    "\n",
    "from Hash.Hash import Hash\n",
    "from Hash.KWiseHash import KWiseHash\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from utils import load_android_app, load_android_app_transformer, load_imdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "953b7ba7-53e1-4815-83d0-c46ca3f0cc86",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "from matplotlib import rc\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "def plot_data(data, ylabel, xlabel, optimal_error, file_name):\n",
    "    \"\"\"Plot one curve per entry of ``data`` and save the figure as a PDF.\n",
    "\n",
    "    Note that the ``matplotlib.rc`` calls below change the global matplotlib\n",
    "    configuration (LaTeX text rendering, sans-serif fonts).\n",
    "\n",
    "    Args:\n",
    "        data: Sequence of ``(x_values, y_values, label)`` entries, one per curve.\n",
    "        ylabel: Text for the y axis, rendered through LaTeX.\n",
    "        xlabel: Text for the x axis, rendered through LaTeX.\n",
    "        optimal_error: Height of the horizontal reference line labelled Optimal.\n",
    "        file_name: Base name of the output; the figure is written to\n",
    "            ``./Plots/{file_name}.pdf``.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    matplotlib.rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})\n",
    "    matplotlib.rc('text', usetex=True)\n",
    "    matplotlib.rcParams['text.latex.preamble'] = r\"\\usepackage{amsmath}\"\n",
    "    matplotlib.rc('font', family='sans-serif', size=20)\n",
    "\n",
    "    xsize = 24\n",
    "    tsize = 24\n",
    "    lsize = 20\n",
    "    figsize = (7.5, 5.8)\n",
    "    clrs = [\n",
    "        \"#d62728\",  # Brick red\n",
    "        \"#1f77b4\",  # Muted blue\n",
    "        '#FFA500',  # Orange\n",
    "        \"#ff7f0e\",  # Safety orange\n",
    "        \"#2ca02c\",  # Cooked asparagus green\n",
    "        \"#9467bd\",  # Muted purple\n",
    "        \"#8c564b\",  # Chestnut brown\n",
    "        \"#e377c2\",  # Raspberry yogurt pink\n",
    "        \"#7f7f7f\",  # Middle gray\n",
    "        \"#bcbd22\",  # Curry yellow-green\n",
    "        \"#17becf\"  # Blue-teal\n",
    "    ]\n",
    "    mrk = ['o', 's', '^', '*']\n",
    "    yticks = [0.2, 0.4, 0.6, 0.8, 1.0]\n",
    "\n",
    "    plt.figure(figsize=figsize)\n",
    "    plt.grid()\n",
    "    for i, series in enumerate(data):\n",
    "        # Cycle through colours and markers so that passing more curves than\n",
    "        # there are markers (4) or colours (11) does not raise IndexError.\n",
    "        plt.plot(\n",
    "            series[0],\n",
    "            series[1],\n",
    "            linewidth=3,\n",
    "            label=series[2],\n",
    "            color=clrs[i % len(clrs)],\n",
    "            marker=mrk[i % len(mrk)],\n",
    "            markersize=12,\n",
    "        )\n",
    "\n",
    "    plt.axhline(y=optimal_error, color=\"#9467bd\", linestyle='dashdot', label=\"Optimal\", linewidth=4.0)\n",
    "\n",
    "    plt.xticks(fontsize=tsize, usetex=True, fontname=\"Times\")\n",
    "    plt.yticks(yticks, fontsize=tsize, usetex=True, fontname=\"Times\")\n",
    "    plt.xlabel(xlabel, fontsize=xsize, usetex=True)\n",
    "    plt.ylabel(ylabel, fontsize=xsize, usetex=True)\n",
    "    plt.legend(fontsize=lsize, loc='best', frameon=True, fancybox=True, framealpha=0.8, edgecolor='k')\n",
    "    plt.tight_layout()\n",
    "    # savefig does not create missing directories, so make sure the target exists.\n",
    "    os.makedirs('./Plots', exist_ok=True)\n",
    "    plt.savefig(f'./Plots/{file_name}.pdf', dpi=None, facecolor='w', format='pdf', bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d39bb767-ccc8-433c-ab55-80d6c920aec8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gather_data_regression_priority_sampling(A, H, b, sampling_ratio_A, sampling_ratio_b, matrices_size):\n",
    "    \"\"\"Average storage and relative regression error of priority-sampling sketches.\n",
    "\n",
    "    For each sketch size in a fixed grid, ``rep`` (a notebook-level variable)\n",
    "    sketches of ``A`` and ``b`` are built with a shared, seeded hash function\n",
    "    and the resulting storage and error are averaged over the repetitions.\n",
    "\n",
    "    Args:\n",
    "        A: Design matrix; its number of rows sizes the hash function.\n",
    "        H: Unused; kept so that existing callers keep working.\n",
    "        b: Regression target.\n",
    "        sampling_ratio_A: Passed as ``sampling_ratio`` when sketching ``A``.\n",
    "        sampling_ratio_b: Passed as ``sampling_ratio`` when sketching ``b``.\n",
    "        matrices_size: Divisor used to normalise the reported storage.\n",
    "\n",
    "    Returns:\n",
    "        ``(storage, err)``: two float64 tensors with one entry per sketch size,\n",
    "        the mean normalised storage and the mean of ``||A x - b|| / ||b||``.\n",
    "    \"\"\"\n",
    "    sample_size = torch.tensor([1000, 1250, 1500, 1750, 2000, 2250])\n",
    "    storage = torch.zeros_like(sample_size, dtype=torch.float64)\n",
    "    err = torch.zeros_like(sample_size, dtype=torch.float64)\n",
    "\n",
    "    norm_b = torch.linalg.norm(b).to('cpu')\n",
    "    # Size the hash from A itself instead of the notebook-level `n`, so the\n",
    "    # function stays correct when that global is stale or undefined.\n",
    "    num_rows = A.shape[0]\n",
    "\n",
    "    for i in range(len(sample_size)):\n",
    "        for r in range(rep):\n",
    "            # A distinct seed for every (sketch size, repetition) pair.\n",
    "            hash_func, _ = KWiseHash().hash(torch.ones(num_rows), seed=(i * rep + r))\n",
    "\n",
    "            # Both sketches share one hash function.\n",
    "            s_a = PrioritySampling(hash_function=hash_func, sample_size=sample_size[i])\n",
    "            s_b = PrioritySampling(hash_function=hash_func, sample_size=sample_size[i])\n",
    "\n",
    "            s_a.hash(data_set=A, sampling_ratio=sampling_ratio_A)\n",
    "            s_b.hash(data_set=b, sampling_ratio=sampling_ratio_b)\n",
    "            # NOTE(review): the 1.25 factor presumably accounts for per-sample\n",
    "            # bookkeeping overhead -- confirm against MatrixSketch.\n",
    "            storage[i] += ((len(s_a) + len(s_b)) * 1.25) / matrices_size\n",
    "\n",
    "            # NOTE(review): `*` between sketches is presumably an estimate of the\n",
    "            # corresponding matrix product (A^T A and A^T b) -- confirm.\n",
    "            reg_error = torch.linalg.norm(A @ torch.linalg.pinv(s_a * s_a) @ (s_a * s_b) - b).to('cpu')\n",
    "\n",
    "            # A norm is already non-negative, so no abs() is needed.\n",
    "            err[i] += reg_error / norm_b\n",
    "    return storage / rep, err / rep\n",
    "\n",
    "def gather_data_regression_threshold_sampling(A, H, b, sampling_ratio_A, sampling_ratio_b, matrices_size):\n",
    "    \"\"\"Average storage and relative regression error of threshold-sampling sketches.\n",
    "\n",
    "    For each target size in a fixed grid, ``rep`` (a notebook-level variable)\n",
    "    sketches of ``A`` and ``b`` are built with a shared, seeded hash function.\n",
    "    The threshold ``tau`` of each sketch is the target size divided by the sum\n",
    "    of the corresponding sampling ratios.\n",
    "\n",
    "    Args:\n",
    "        A: Design matrix; its number of rows sizes the hash function.\n",
    "        H: Unused; kept so that existing callers keep working.\n",
    "        b: Regression target.\n",
    "        sampling_ratio_A: Passed as ``sampling_ratio`` when sketching ``A``.\n",
    "        sampling_ratio_b: Passed as ``sampling_ratio`` when sketching ``b``.\n",
    "        matrices_size: Divisor used to normalise the reported storage.\n",
    "\n",
    "    Returns:\n",
    "        ``(storage, err)``: two float64 tensors with one entry per target size,\n",
    "        the mean normalised storage and the mean of ``||A x - b|| / ||b||``.\n",
    "    \"\"\"\n",
    "    sample_size = torch.tensor([4000, 4500, 5000, 5500, 6000, 6500, 7000])\n",
    "    storage = torch.zeros_like(sample_size, dtype=torch.float64)\n",
    "    err = torch.zeros_like(sample_size, dtype=torch.float64)\n",
    "\n",
    "    norm_b = torch.linalg.norm(b).to('cpu')\n",
    "\n",
    "    sampling_ratio_acc_A = torch.sum(sampling_ratio_A).cpu()\n",
    "    sampling_ratio_acc_b = torch.sum(sampling_ratio_b).cpu()\n",
    "    # Size the hash from A itself instead of the notebook-level `n`, so the\n",
    "    # function stays correct when that global is stale or undefined.\n",
    "    num_rows = A.shape[0]\n",
    "\n",
    "    for i in range(len(sample_size)):\n",
    "        for r in range(rep):\n",
    "            # A distinct seed for every (target size, repetition) pair.\n",
    "            hash_func, _ = KWiseHash().hash(torch.ones(num_rows), seed=(i * rep + r))\n",
    "\n",
    "            # Both sketches share one hash function; only the threshold differs.\n",
    "            s_a = ThresholdSampling(hash_function=hash_func, tau=sample_size[i] / sampling_ratio_acc_A)\n",
    "            s_b = ThresholdSampling(hash_function=hash_func, tau=sample_size[i] / sampling_ratio_acc_b)\n",
    "\n",
    "            s_a.hash(data_set=A, sampling_ratio=sampling_ratio_A)\n",
    "            s_b.hash(data_set=b, sampling_ratio=sampling_ratio_b)\n",
    "            # NOTE(review): the 1.25 factor presumably accounts for per-sample\n",
    "            # bookkeeping overhead -- confirm against MatrixSketch.\n",
    "            storage[i] += ((len(s_a) + len(s_b)) * 1.25) / matrices_size\n",
    "\n",
    "            # NOTE(review): `*` between sketches is presumably an estimate of the\n",
    "            # corresponding matrix product (A^T A and A^T b) -- confirm.\n",
    "            reg_error = torch.linalg.norm(A @ torch.linalg.pinv(s_a * s_a) @ (s_a * s_b) - b).to('cpu')\n",
    "\n",
    "            # A norm is already non-negative, so no abs() is needed.\n",
    "            err[i] += reg_error / norm_b\n",
    "    return storage / rep, err / rep\n",
    "\n",
    "def gather_data_regression_jl(A, b, matrices_size):\n",
    "    \"\"\"Average storage and relative regression error of Gaussian (JL) sketches.\n",
    "\n",
    "    For each sketch size ``k`` in a fixed grid, ``rep`` (a notebook-level\n",
    "    variable) random ``k x n`` matrices with N(0, 1/k) entries are drawn, the\n",
    "    regression is solved on the sketched data and the error is measured on\n",
    "    the full data.\n",
    "\n",
    "    Args:\n",
    "        A: Design matrix of shape ``(n, d)``.\n",
    "        b: Regression target with ``n`` rows.\n",
    "        matrices_size: Divisor used to normalise the reported storage.\n",
    "\n",
    "    Returns:\n",
    "        ``(storage, err)``: two float64 tensors with one entry per sketch size,\n",
    "        the mean normalised storage and the mean of ``||A x - b|| / ||b||``.\n",
    "    \"\"\"\n",
    "    sample_size = torch.tensor([200, 225, 250, 275, 300, 325, 350, 375, 400, 425,])\n",
    "\n",
    "    storage = torch.zeros_like(sample_size, dtype=torch.float64)\n",
    "    err = torch.zeros_like(sample_size, dtype=torch.float64)\n",
    "    norm_b = torch.linalg.norm(b).to('cpu')\n",
    "    # Read the dimensions from A instead of the notebook-level `n` and `d`.\n",
    "    num_rows, num_cols = A.shape\n",
    "\n",
    "    for i in range(len(sample_size)):\n",
    "        k = int(sample_size[i])\n",
    "        for r in range(rep):\n",
    "            # Seed every draw the same way the sampling experiments seed their\n",
    "            # hash functions, so that a re-run reproduces the same numbers.\n",
    "            generator = torch.Generator().manual_seed(i * rep + r)\n",
    "            pi_sketch = torch.normal(\n",
    "                0.0, math.sqrt(1 / k), (k, num_rows), generator=generator, dtype=torch.float64\n",
    "            ).to(A.device)  # follow A's device rather than assuming CUDA\n",
    "\n",
    "            # Storage of the sketched A (k x d) plus one sketched column for b.\n",
    "            storage[i] += (k * num_cols + k) / matrices_size\n",
    "\n",
    "            # Sketch once and reuse; this avoids the d x n intermediate that\n",
    "            # A.T @ pi.T @ pi @ A builds when evaluated left to right.\n",
    "            sketched_A = pi_sketch @ A\n",
    "            sketched_b = pi_sketch @ b\n",
    "            S_A_S_b = torch.linalg.pinv(sketched_A.T @ sketched_A) @ (sketched_A.T @ sketched_b)\n",
    "            reg_error = torch.linalg.norm(A @ S_A_S_b - b).to('cpu')\n",
    "\n",
    "            # A norm is already non-negative, so no abs() is needed.\n",
    "            err[i] += reg_error / norm_b\n",
    "    return storage / rep, err / rep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "442c57f0-dcc3-444c-a181-d9b9c521d551",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_regression_error(A, b, file_name):\n",
    "    \"\"\"Run the three sketching experiments on ``(A, b)`` and plot their errors.\n",
    "\n",
    "    The relative error of the exact least-squares solution is computed as a\n",
    "    reference, the priority-sampling, threshold-sampling and JL experiments\n",
    "    are run, and the results are written to ``./Plots/{file_name}.txt`` and\n",
    "    plotted through ``plot_data``.\n",
    "\n",
    "    Args:\n",
    "        A: Two-dimensional design matrix.\n",
    "        b: Two-dimensional regression target with as many rows as ``A``.\n",
    "        file_name: Base name shared by the text dump and the figure.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    H = torch.linalg.pinv(A.T @ A)\n",
    "    # Squared row norms of b are the sampling ratios used for b.\n",
    "    row_norm_b = torch.sum(b ** 2, dim=1)\n",
    "\n",
    "    # Squared row norms of U, divided by the number of columns of A (read\n",
    "    # from A rather than from the notebook-level `d`).\n",
    "    U, _, _ = torch.linalg.svd(A, full_matrices=False)\n",
    "    leverage_scores_A = torch.sum(U ** 2, dim=1) / A.shape[1]\n",
    "\n",
    "    matrices_size = (A.shape[0] * A.shape[1] + b.shape[0] * b.shape[1])\n",
    "    # Multiply right to left so the n x n matrix A H A^T is never materialised.\n",
    "    optimal_error = (torch.linalg.norm(A @ (H @ (A.T @ b)) - b) / torch.linalg.norm(b)).cpu()\n",
    "\n",
    "    storage_priority_sampling, err_priority_sampling = gather_data_regression_priority_sampling(A, H, b, leverage_scores_A, row_norm_b, matrices_size)\n",
    "    storage_threshold_sampling, err_threshold_sampling = gather_data_regression_threshold_sampling(A, H, b, leverage_scores_A, row_norm_b, matrices_size)\n",
    "    storage_jl, err_jl = gather_data_regression_jl(A, b, matrices_size)\n",
    "\n",
    "    data = [\n",
    "    (storage_priority_sampling, err_priority_sampling, 'Priority Sampling'),\n",
    "    (storage_jl, err_jl, 'JL Sketch'),\n",
    "    (storage_threshold_sampling, err_threshold_sampling, 'Threshold Sampling')\n",
    "    ]\n",
    "\n",
    "    print(f\"Optimal error : {optimal_error}\")\n",
    "\n",
    "    # open() does not create missing directories, so make sure the target exists.\n",
    "    os.makedirs('./Plots', exist_ok=True)\n",
    "    with open(f\"./Plots/{file_name}.txt\", 'w') as file:\n",
    "        for storage, err, method in data:\n",
    "            file.write(f\"{storage}, {err}, {method}\\n\")\n",
    "\n",
    "    plot_data(data = data, \n",
    "            ylabel=r'Error $\\frac{\\|A \\tilde{x} - b\\|_2}{\\|b\\|_2}$',\n",
    "            xlabel=r'Sketch Size',\n",
    "            optimal_error=optimal_error,\n",
    "            file_name=file_name,\n",
    "             )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c49487b2-49a4-4f91-867e-3febe5c18641",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of repetitions the experiment functions average over for every sketch size.\n",
    "rep = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3142263-2911-47fa-8149-2c44649d9a15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Requested problem size, passed to the loader as `n` and `top_k`.\n",
    "n, d = 10000, 256\n",
    "A_matrix, b_matrix = load_android_app_transformer(n=n, top_k=d)\n",
    "# Work in double precision.\n",
    "A_matrix = A_matrix.to(torch.float64)\n",
    "b_matrix = b_matrix.to(torch.float64)\n",
    "# Re-read n and d from the loaded matrix so they match its actual shape\n",
    "# (the output file name below is built from them).\n",
    "n, d = A_matrix.shape\n",
    "# Keep only the first column of b as the regression target.\n",
    "b_matrix = b_matrix[:, :1]\n",
    "plot_regression_error(A_matrix, b_matrix, f\"android_{n},{d}_sparse\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "159b6e05-7b71-4949-8d38-aa73327a76cf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_kivi",
   "language": "python",
   "name": "env_kivi"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
