{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def svd_internal_dimensionality_reduction(tensor, num_components):\n",
    "    \"\"\"\n",
    "    Performs SVD dimensionality reduction, but returns the full tensor instead of just the reduced components.\n",
    "\n",
    "    The input is decomposed as U @ diag(S) @ Vh, only the first `num_components` singular\n",
    "    triplets (the largest singular values) are kept, and the product is re-formed. The output\n",
    "    therefore has the same shape as the input but rank at most `num_components`.\n",
    "\n",
    "    Args:\n",
    "        tensor: floating-point tensor of shape (..., m, n); leading dimensions are batch dimensions.\n",
    "        num_components: number of leading singular components to keep.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (..., m, n): the rank-truncated reconstruction of `tensor`.\n",
    "    \"\"\"\n",
    "    # torch.svd is deprecated, and the old slicing/transposing assumed a strictly 2-D input;\n",
    "    # torch.linalg.svd returns Vh directly and the ellipsis indexing below also handles batches.\n",
    "    u, s, vh = torch.linalg.svd(tensor, full_matrices=False)\n",
    "    u_k = u[..., :, :num_components]\n",
    "    s_k = s[..., None, :num_components]\n",
    "    vh_k = vh[..., :num_components, :]\n",
    "    # Broadcasting s_k over the rows of u_k scales each kept column, i.e. U_k @ diag(S_k).\n",
    "    return torch.matmul(u_k * s_k, vh_k)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank-3 test matrix: the product of a (10, 3) and a (3, 10) Gaussian factor\n",
    "A = torch.randn(10, 3)\n",
    "B = torch.randn(3, 10)\n",
    "C_low_rank = A @ B\n",
    "\n",
    "# Unstructured 10x10 Gaussian matrix used as the full-rank comparison case\n",
    "C_full_rank = torch.randn(10, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only 3 singular components of the full-rank matrix, then display the shape of the result\n",
    "C_full_rank_svd = svd_internal_dimensionality_reduction(C_full_rank, num_components=3)\n",
    "C_full_rank_svd.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Singular values of the truncated matrix.\n",
    "# torch.linalg.svdvals replaces the deprecated torch.svd(...).S and skips computing U and V.\n",
    "torch.linalg.svdvals(C_full_rank_svd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Singular values of the matrix built from the (10, 3) x (3, 10) product.\n",
    "# torch.linalg.svdvals replaces the deprecated torch.svd(...).S and skips computing U and V.\n",
    "torch.linalg.svdvals(C_low_rank)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def random_projection_dim_reduction(tensor, target_dim):\n",
    "    \"\"\"\n",
    "    Performs random projection dimensionality reduction according to the Johnson-Lindenstrauss lemma.\n",
    "    Only reduces the inner dimensionality, does not affect the shape of the tensor.\n",
    "\n",
    "    The last dimension is projected onto `target_dim` random Gaussian directions and then mapped\n",
    "    back to the original feature space through the pseudo-inverse of the projection matrix, so\n",
    "    the values along the last dimension lie in a subspace of dimension at most `target_dim`.\n",
    "\n",
    "    Args:\n",
    "        tensor: tensor of shape (..., d).\n",
    "        target_dim: number of random directions to project onto.\n",
    "\n",
    "    Returns:\n",
    "        Tensor with the same shape and dtype as `tensor`. The computation itself runs in float32.\n",
    "    \"\"\"\n",
    "    original_dtype = tensor.dtype\n",
    "    original_shape = tensor.shape\n",
    "    tensor = tensor.to(dtype=torch.float32)\n",
    "\n",
    "    # generate a random matrix with entries drawn from a normal distribution, columns scaled to unit norm\n",
    "    random_matrix = torch.randn(tensor.shape[-1], target_dim, dtype=torch.float32, device=tensor.device)\n",
    "    random_matrix /= torch.norm(random_matrix, dim=0, keepdim=True)\n",
    "\n",
    "    # project the tensor down to target_dim coordinates: (..., d) @ (d, target_dim)\n",
    "    projected = torch.matmul(tensor, random_matrix)\n",
    "\n",
    "    # map back up to d features: (..., target_dim) @ (target_dim, d). Returning `projected` directly\n",
    "    # made the shape check below fail whenever target_dim != d.\n",
    "    new_matrix = torch.matmul(projected, torch.linalg.pinv(random_matrix)).to(dtype=original_dtype)\n",
    "    assert new_matrix.shape == original_shape\n",
    "    return new_matrix\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 100 samples with 10 features each, projected onto 2 random directions\n",
    "# (this rebinds the names A and B used in the SVD example above)\n",
    "A = torch.randn(100, 10)\n",
    "B = random_projection_dim_reduction(A, target_dim=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the shape of B, the output of random_projection_dim_reduction above\n",
    "B.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "@torch.no_grad()\n",
    "def random_pruning(tensor, prune_ratio):\n",
    "    \"\"\"\n",
    "    Randomly zero out entries of `tensor`; the shape is left untouched.\n",
    "\n",
    "    Every entry gets its own uniform sample from [0, 1) and is kept only when that sample is\n",
    "    strictly greater than `prune_ratio`; all other entries are multiplied by zero.\n",
    "    \"\"\"\n",
    "    keep_mask = torch.rand_like(tensor).gt(prune_ratio)\n",
    "    return tensor.mul(keep_mask)\n",
    "\n",
    "# Demo input: a 10x10 tensor of uniform random values\n",
    "tensor = torch.rand((10, 10))\n",
    "\n",
    "# Pruning ratios to compare side by side\n",
    "prune_ratios = [0.1, 0.3, 0.5, 0.7, 0.9]\n",
    "\n",
    "# One row of panels: the original first, then one panel per ratio\n",
    "fig, axs = plt.subplots(1, 1 + len(prune_ratios), figsize=(20, 5))\n",
    "\n",
    "# Leftmost panel shows the unpruned tensor\n",
    "axs[0].imshow(tensor.numpy(), cmap='viridis')\n",
    "axs[0].set_title('Original Tensor')\n",
    "\n",
    "# Remaining panels show the tensor after random pruning at each ratio\n",
    "for i, prune_ratio in enumerate(prune_ratios):\n",
    "    panel = axs[i + 1]\n",
    "    pruned_tensor = random_pruning(tensor.clone(), prune_ratio)\n",
    "    panel.imshow(pruned_tensor.numpy(), cmap='viridis')\n",
    "    panel.set_title(f'Pruned Tensor (ratio = {prune_ratio})')\n",
    "\n",
    "# Render the figure\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "@torch.no_grad()\n",
    "def magnitude_pruning(tensor, prune_ratio):\n",
    "    \"\"\"\n",
    "    Zero out the smallest-magnitude entries of `tensor`; the shape is left untouched.\n",
    "\n",
    "    The cut-off is the `prune_ratio` quantile of the absolute values taken over all entries.\n",
    "    Only entries whose magnitude is strictly above the cut-off are kept.\n",
    "    \"\"\"\n",
    "    magnitudes = tensor.abs()\n",
    "    cutoff = torch.quantile(magnitudes.flatten(), prune_ratio)\n",
    "\n",
    "    # Cast the boolean keep-mask to the tensor's dtype before multiplying\n",
    "    keep_mask = (magnitudes > cutoff).to(dtype=tensor.dtype)\n",
    "    return tensor * keep_mask\n",
    "\n",
    "# Demo input: a 10x10 tensor of uniform random values\n",
    "tensor = torch.rand((10, 10))\n",
    "\n",
    "# Pruning ratios to compare side by side\n",
    "prune_ratios = [0.1, 0.3, 0.5, 0.7, 0.9]\n",
    "\n",
    "# One row of panels: the original first, then one panel per ratio\n",
    "fig, axs = plt.subplots(1, 1 + len(prune_ratios), figsize=(20, 5))\n",
    "\n",
    "# Leftmost panel shows the unpruned tensor\n",
    "axs[0].imshow(tensor.numpy(), cmap='viridis')\n",
    "axs[0].set_title('Original Tensor')\n",
    "\n",
    "# Remaining panels show the tensor after magnitude pruning at each ratio\n",
    "for i, prune_ratio in enumerate(prune_ratios):\n",
    "    panel = axs[i + 1]\n",
    "    pruned_tensor = magnitude_pruning(tensor.clone(), prune_ratio)\n",
    "    panel.imshow(pruned_tensor.numpy(), cmap='viridis')\n",
    "    panel.set_title(f'Pruned Tensor (ratio = {prune_ratio})')\n",
    "\n",
    "# Render the figure\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
