{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3520635e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tqdm\n",
    "from PIL import Image, ImageFilter\n",
    "import scipy.optimize\n",
    "import gurobipy as gp\n",
    "from gurobipy import GRB\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a569dfb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serif text (Times New Roman preferred) with Computer Modern ('cm') math glyphs.\n",
    "plt.rcParams['font.family'] = 'serif'\n",
    "# Prepend Times New Roman only if it is not already first, so re-running this\n",
    "# cell does not keep growing the font.serif fallback list.\n",
    "if plt.rcParams['font.serif'][:1] != ['Times New Roman']:\n",
    "    plt.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif']\n",
    "\n",
    "plt.rcParams['mathtext.fontset'] = 'cm'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "665a8a72",
   "metadata": {},
   "outputs": [],
   "source": [
    "def svd(A, tol=1e-5):\n",
    "    \"\"\"Thin SVD of A truncated to its numerical rank.\n",
    "\n",
    "    Singular values <= tol are treated as zero and their singular vectors dropped.\n",
    "\n",
    "    Args:\n",
    "        A: 2-D array.\n",
    "        tol: absolute cutoff on the singular values. The default 1e-5 is the\n",
    "            previously hard-coded threshold; it is not scaled by the size of A,\n",
    "            so raise it for inputs with very large entries.\n",
    "\n",
    "    Returns:\n",
    "        (U, S, Vt, rank) with U: (m, rank), S: (rank,), Vt: (rank, n).\n",
    "    \"\"\"\n",
    "    U, S, Vt = np.linalg.svd(A, full_matrices=False)\n",
    "    # S is sorted in descending order, so the kept values are a prefix of S.\n",
    "    rank = int(np.count_nonzero(S > tol))\n",
    "    return U[:, :rank], S[:rank], Vt[:rank, :], rank"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac553a06",
   "metadata": {},
   "outputs": [],
   "source": [
    "class state:\n",
    "    \"\"\"Streaming state of the online ridge-leverage-score coreset.\n",
    "\n",
    "    Tracks the indices of the rows of A selected so far and a Gaussian-sketched\n",
    "    square root of the (ridge-regularised) pseudo-inverse Gram matrix of those\n",
    "    rows, which is used to estimate leverage scores of later rows.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, A, k, rng=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            A: (n, d) matrix whose rows are streamed.\n",
    "            k: target rank; once the selected rows have rank > k a ridge term is added.\n",
    "            rng: optional random source with a normal(loc, scale, size) method,\n",
    "                e.g. np.random.default_rng(0); defaults to the global np.random.\n",
    "        \"\"\"\n",
    "        self.A = A\n",
    "        self.n, self.d = A.shape[0], A.shape[1]\n",
    "        self.k = k\n",
    "        self.rng = np.random if rng is None else rng\n",
    "        self.subset = []  # indices of the selected rows, in selection order\n",
    "        self.rank = 0  # numerical rank of A[subset]\n",
    "        # Sketched quadratic form; all zeros until the first append.\n",
    "        self.quadratic_form = np.zeros((self.d, self.d))\n",
    "        # Orthonormal row basis of A[subset]; the initial zero row spans nothing,\n",
    "        # so every nonzero row scores 1 before the first append.\n",
    "        self.basis = np.zeros((1, self.d))\n",
    "\n",
    "    def append(self, i):\n",
    "        \"\"\"Add row i of A to the coreset and rebuild the sketched quadratic form.\n",
    "\n",
    "        With A_S = A[subset] = U diag(S) Vt (numerical-rank thin SVD) the form is\n",
    "        G @ M^(1/2), where M = pinv(A_S^T A_S) while rank <= k, and\n",
    "        M = (A_S^T A_S + lam I)^(-1) with lam = (sum of S_j^2 for j >= k) / k\n",
    "        otherwise. G is an (m, d) Gaussian sketch, m = max(1, 3 * floor(ln n)).\n",
    "        \"\"\"\n",
    "        self.subset.append(i)\n",
    "        U, S, Vt, self.rank = svd(self.A[self.subset])\n",
    "        self.basis = Vt\n",
    "        # Clamp to >= 1: for n <= 2, 3 * int(ln n) is 0 and 1/m would divide by zero.\n",
    "        num_rows_gaussian = max(1, 3 * int(np.log(self.n)))\n",
    "        # numpy's `scale` is the standard deviation. Entries need variance 1/m\n",
    "        # (std 1/sqrt(m)) so that E||G y||^2 = ||y||^2; std 1/m would shrink every\n",
    "        # estimated score by a factor of m.\n",
    "        G = self.rng.normal(0, 1 / np.sqrt(num_rows_gaussian), (num_rows_gaussian, self.d))\n",
    "        GV = G @ Vt.T\n",
    "        if self.rank <= self.k:\n",
    "            # G V diag(1/S) V^T = G @ pinv(A_S^T A_S)^(1/2)\n",
    "            self.quadratic_form = (GV / S) @ Vt\n",
    "        else:\n",
    "            # Tail energy beyond the top-k singular values, averaged over k.\n",
    "            lam = np.sum(S[self.k:] ** 2) / self.k\n",
    "            # Inside span(Vt): 1/sqrt(S^2 + lam); on the orthogonal complement: 1/sqrt(lam).\n",
    "            self.quadratic_form = (GV / np.sqrt(S ** 2 + lam)) @ Vt + (G - GV @ Vt) / np.sqrt(lam)\n",
    "\n",
    "    def ridge_leverage_score(self, a):\n",
    "        \"\"\"Estimated (ridge) leverage score of the row vector a w.r.t. the coreset.\n",
    "\n",
    "        Returns 0 for a zero vector; 1 when rank <= k and a is not (numerically)\n",
    "        in the span of the selected rows; otherwise ||quadratic_form @ a||^2.\n",
    "        \"\"\"\n",
    "        norm_a = np.linalg.norm(a)\n",
    "        if norm_a == 0:\n",
    "            # A zero row carries no information; also avoids 0/0 in the span test.\n",
    "            return 0.0\n",
    "        if self.rank <= self.k:\n",
    "            # ||basis @ a_hat|| is the length of a_hat's projection onto the row span.\n",
    "            if np.linalg.norm(self.basis @ (a / norm_a)) < 0.999:  # The coreset does not span a\n",
    "                return 1\n",
    "        return np.linalg.norm(self.quadratic_form @ a) ** 2\n",
    "\n",
    "    def return_coreset(self):\n",
    "        \"\"\"Copy of the selected rows A[subset, :], in selection order.\"\"\"\n",
    "        return self.A[self.subset, :].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faaf2d91",
   "metadata": {},
   "outputs": [],
   "source": [
    "def our_coreset(A, k, verbose=True):\n",
    "    \"\"\"Single pass over the rows of A that keeps rows with a high leverage score.\n",
    "\n",
    "    Row i is appended to the coreset when its estimated ridge leverage score\n",
    "    (state.ridge_leverage_score) is at least 1 / (1 + 1/k) = k / (k + 1).\n",
    "\n",
    "    Args:\n",
    "        A: (n, d) matrix whose rows are streamed in order.\n",
    "        k: target rank; must be nonzero (it appears in 1/k).\n",
    "        verbose: if True, log every appended row.\n",
    "\n",
    "    Returns:\n",
    "        The final state; call .return_coreset() to get the selected rows.\n",
    "    \"\"\"\n",
    "    coreset = state(A, k)\n",
    "    threshold = 1 / (1 + 1 / k)  # loop-invariant\n",
    "    for i in tqdm.tqdm(range(A.shape[0])):\n",
    "        rls = coreset.ridge_leverage_score(A[i, :])\n",
    "        if rls >= threshold:\n",
    "            coreset.append(i)\n",
    "            if verbose:\n",
    "                # tqdm.write prints above the progress bar instead of breaking it.\n",
    "                tqdm.tqdm.write(\"Appending {}-th row with leverage score {}\".format(i, rls))\n",
    "    return coreset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "606a43f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_instance(n, d, k, signal_bound=100, noise_bound=5000, rng=None):\n",
    "    \"\"\"Synthetic low-rank-plus-noise matrix M = L @ R + G.\n",
    "\n",
    "    L (n x k) and R (k x d) have integer entries in [-signal_bound, signal_bound);\n",
    "    the noise G (n x d) has integer entries in [-noise_bound, noise_bound).\n",
    "\n",
    "    Args:\n",
    "        n, d: shape of the returned matrix.\n",
    "        k: inner dimension of the planted product L @ R.\n",
    "        signal_bound, noise_bound: magnitude bounds (defaults are the original\n",
    "            hard-coded 100 and 5000).\n",
    "        rng: optional np.random.Generator for reproducible draws; the global\n",
    "            np.random state is used when None (original behaviour).\n",
    "\n",
    "    Returns:\n",
    "        (M, R): the noisy matrix and the planted row factor R.\n",
    "    \"\"\"\n",
    "    draw = np.random.randint if rng is None else rng.integers\n",
    "    L = draw(-signal_bound, signal_bound, size=(n, k))\n",
    "    R = draw(-signal_bound, signal_bound, size=(k, d))\n",
    "    G = draw(-noise_bound, noise_bound, size=(n, d))\n",
    "    return L @ R + G, R"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0872f55",
   "metadata": {},
   "outputs": [],
   "source": [
    "def linf_cost(A, V):\n",
    "    \"\"\"Largest Euclidean distance from a row of A to the subspace spanned by V.\n",
    "\n",
    "    Args:\n",
    "        A: (n, d) matrix; each row is a point.\n",
    "        V: (d, r) matrix whose columns are an orthonormal basis (r may be 0).\n",
    "\n",
    "    Returns:\n",
    "        max_i sqrt(||a_i||^2 - ||V^T a_i||^2).\n",
    "    \"\"\"\n",
    "    # Rows are handled one at a time, so no dense float copy of A is created.\n",
    "    # The squared residual is clipped at 0: for rows (almost) inside span(V)\n",
    "    # rounding can make it slightly negative, and sqrt would then give nan.\n",
    "    squared = [max(np.linalg.norm(row) ** 2 - np.linalg.norm(V.T @ row) ** 2, 0.0) for row in A]\n",
    "    # sqrt is monotone, so taking it after the max gives the same result.\n",
    "    return np.sqrt(max(squared))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae9367b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def width_estimation(A, k):\n",
    "    \"\"\"Build a width estimator for the point set given by the rows of A.\n",
    "\n",
    "    Rows are translated so that A[0] is the origin (A[0] itself is dropped), a\n",
    "    coreset of the translated rows is computed with our_coreset, and the\n",
    "    returned function maps a direction x to max over coreset rows b of |b . x|.\n",
    "\n",
    "    Returns:\n",
    "        Callable x -> scalar.\n",
    "    \"\"\"\n",
    "    shifted = A[1:, :] - A[0, :]\n",
    "    coreset_matrix = our_coreset(shifted, k).return_coreset()\n",
    "\n",
    "    def estimate(x):\n",
    "        return np.max(np.abs(coreset_matrix @ x))\n",
    "\n",
    "    return estimate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac016e1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def true_width(A, x, verbose=True):\n",
    "    \"\"\"Exact width of the rows of A along direction x: max(A @ x) - min(A @ x).\n",
    "\n",
    "    Args:\n",
    "        A: (n, d) matrix of points.\n",
    "        x: (d,) direction.\n",
    "        verbose: print the max and then the min projection (original behaviour).\n",
    "    \"\"\"\n",
    "    projections = A @ x  # computed once instead of three times\n",
    "    hi, lo = np.max(projections), np.min(projections)\n",
    "    if verbose:\n",
    "        print(hi)\n",
    "        print(lo)\n",
    "    return hi - lo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d44ca301",
   "metadata": {},
   "outputs": [],
   "source": [
    "def lra_image(image, k):\n",
    "    \"\"\"Rank-k truncated-SVD approximation of a (grayscale) image.\n",
    "\n",
    "    Args:\n",
    "        image: PIL image or 2-D array-like of pixel intensities.\n",
    "        k: number of leading singular triplets kept.\n",
    "\n",
    "    Returns:\n",
    "        PIL image built from the rank-k approximation, clipped to [0, 255] and\n",
    "        truncated to uint8.\n",
    "    \"\"\"\n",
    "    image_array = np.array(image)\n",
    "    # Thin SVD: only the top-k factors are used, so the full square U/Vt are not needed.\n",
    "    U, Sigma, Vt = np.linalg.svd(image_array, full_matrices=False)\n",
    "    lra = (U[:, :k] * Sigma[:k]) @ Vt[:k, :]\n",
    "    # Vectorised version of the old per-pixel quantiser: clip to [0, 255], then the\n",
    "    # float -> uint8 cast truncates toward zero like int(x) did.\n",
    "    quantized = np.clip(lra, 0, 255).astype(np.uint8)\n",
    "    return Image.fromarray(quantized)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5e18861",
   "metadata": {},
   "source": [
    "# Estimating Distortion using Linear Programs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf637fe7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def distortion_fixing(j, A, B):\n",
    "    \"\"\"Solve the width-distortion LP for row A[j] against the coreset rows B.\n",
    "\n",
    "    Directions are restricted to span(B) by writing x = B^T y. The LP is\n",
    "\n",
    "        minimise    t\n",
    "        subject to  -t <= (B B^T y)_i <= t   for every row i of B\n",
    "                    (B A[j]) . y = 1          (i.e. A[j] . x = 1)\n",
    "\n",
    "    with y free and t >= 0, so the optimum t* is the smallest possible\n",
    "    max_i |B_i . x| over x in span(B) with A[j] . x = 1. Callers use 1 / t*\n",
    "    as the distortion of row j.\n",
    "\n",
    "    Args:\n",
    "        j: index of the row of A to test.\n",
    "        A: (n, d) original matrix.\n",
    "        B: (s, d) coreset matrix.\n",
    "\n",
    "    Returns:\n",
    "        scipy.optimize.OptimizeResult; `fun` is t* when `success` is True.\n",
    "    \"\"\"\n",
    "    # Work in float: avoids integer overflow in B @ B.T and uses BLAS matmul.\n",
    "    B = np.asarray(B, dtype=float)\n",
    "    a_j = np.asarray(A[j, :], dtype=float)\n",
    "    s = B.shape[0]\n",
    "    BBt = B @ B.T\n",
    "\n",
    "    # Variables are [y_1, ..., y_s, t]. Rows: BBt y - t <= 0 and -BBt y - t <= 0.\n",
    "    ub_matrix = np.zeros((2 * s, s + 1))\n",
    "    ub_matrix[:s, :s] = BBt\n",
    "    ub_matrix[s:, :s] = -BBt\n",
    "    ub_matrix[:, s] = -1\n",
    "    ub_vector = np.zeros(2 * s)\n",
    "\n",
    "    eq_matrix = np.zeros((1, s + 1))\n",
    "    eq_matrix[0, :s] = B @ a_j\n",
    "    eq_vector = np.ones(1)\n",
    "\n",
    "    c = np.zeros(s + 1)\n",
    "    c[-1] = 1  # minimise t\n",
    "\n",
    "    # linprog defaults every variable to bounds (0, None). That forced y >= 0, so\n",
    "    # only the cone generated by B's rows was searched instead of all of span(B),\n",
    "    # overestimating t*. Leave y free; t >= 0 is already implied by the rows above.\n",
    "    bounds = [(None, None)] * s + [(0, None)]\n",
    "    return scipy.optimize.linprog(c, A_ub=ub_matrix, b_ub=ub_vector,\n",
    "                                  A_eq=eq_matrix, b_eq=eq_vector, bounds=bounds)\n",
    "\n",
    "def max_width_distortion(A, B, verbose=True, failure_value=100):\n",
    "    \"\"\"Solve distortion_fixing for every row of A and collect the optimal values.\n",
    "\n",
    "    A is the original matrix and B the coreset (a submatrix of A). Since B\n",
    "    typically has lower rank than A, the distortion can in principle be infinite,\n",
    "    so only directions x in span(B) are analysed.\n",
    "\n",
    "    Args:\n",
    "        A: (n, d) original matrix.\n",
    "        B: (s, d) coreset matrix; s is the sketch size.\n",
    "        verbose: print each optimum (or a failure message) as it is computed.\n",
    "        failure_value: stored for rows where the LP solver does not succeed.\n",
    "            Defaults to the original sentinel 100; note that 1 / 100 is small,\n",
    "            so failures are easy to miss when taking max(1 / r).\n",
    "\n",
    "    Returns:\n",
    "        (n,) float array of per-row LP optima t_j.\n",
    "    \"\"\"\n",
    "    n = A.shape[0]\n",
    "    results = np.zeros(n)\n",
    "    for j in tqdm.tqdm(range(n)):\n",
    "        result = distortion_fixing(j, A, B)\n",
    "        if result.success:\n",
    "            results[j] = result.fun\n",
    "            message = str(result.fun)\n",
    "        else:\n",
    "            results[j] = failure_value\n",
    "            message = \"Optimal Not Found\"\n",
    "        if verbose:\n",
    "            # tqdm.write keeps the progress bar intact, unlike a bare print.\n",
    "            tqdm.tqdm.write(message)\n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dba5abd4",
   "metadata": {},
   "source": [
    "# Experiments with Synthetic Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb6a85f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic instance from create_instance: M = L @ R + G (planted inner dimension k\n",
    "# plus integer noise); true_factor is the planted row factor R. M is a dense\n",
    "# 40000 x 10000 integer array (roughly 3.2 GB if the default int is 64-bit).\n",
    "n = 40000\n",
    "d = 10000\n",
    "k = 20\n",
    "# Seed NumPy's global RNG so the instance, and the Gaussian sketches drawn later\n",
    "# by the coreset, are reproducible under Restart & Run All.\n",
    "np.random.seed(0)\n",
    "M, true_factor = create_instance(n, d, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf4f4684",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stream the rows of M once; `coreset` is the resulting state object.\n",
    "coreset = our_coreset(A=M, k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7ce7373",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Selected rows of M as a dense array; shape is (number of selected rows, d).\n",
    "coreset_matrix = coreset.return_coreset()\n",
    "coreset_matrix.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "481025c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the max distance to V_i = span(first i rows of the planted factor)\n",
    "# measured on all rows of M vs. on the coreset rows only. Includes i = k (the\n",
    "# full planted factor); range(k) stopped at k - 1.\n",
    "results = {}\n",
    "for i in range(k + 1):\n",
    "    # Orthonormal basis (as columns) of span(true_factor[:i]).\n",
    "    subspace, _ = np.linalg.qr(true_factor[:i,:].T)\n",
    "    true_cost = linf_cost(M, subspace)\n",
    "    coreset_cost = linf_cost(coreset_matrix, subspace)\n",
    "    results[i] = (true_cost, coreset_cost)\n",
    "    print((i, true_cost, coreset_cost, true_cost/coreset_cost))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b96607ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# results is keyed by i, the number of planted-factor rows spanning V_i, so the\n",
    "# keys themselves are the x values; the old i + 1 shifted every point one unit\n",
    "# to the right of its subspace.\n",
    "dims = sorted(results)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(dims, [results[i][0] for i in dims], '--o', label='Cost of $V_i$ w.r.t $A$')\n",
    "ax.plot(dims, [results[i][1] for i in dims], '--^', label='Cost of $V_i$ w.r.t $A_S$')\n",
    "ax.legend(loc='upper right')\n",
    "ax.set_xlabel('$i$')\n",
    "ax.set_ylabel('Max. distance to subspace $V_i$')\n",
    "ax.set_xticks(np.arange(0, max(dims) + 1, 2))\n",
    "fig.savefig('synthetic_lra.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cd23e2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second synthetic instance with smaller entries (factors in [-10, 10), noise in\n",
    "# [-50, 50)), used for the width-distortion LP experiment below.\n",
    "n = 40000\n",
    "d = 10000\n",
    "k = 20\n",
    "np.random.seed(1)  # reproducible instance and coreset sketches\n",
    "L = np.random.randint(-10,10,size=(n, k))\n",
    "R = np.random.randint(-10,10,size=(k,d))\n",
    "G = np.random.randint(-50,50,size=(n,d))\n",
    "\n",
    "A = L @ R + G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "639ba194",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: B is the coreset *state* object; its rows are extracted in the next cell.\n",
    "B = our_coreset(A=A, k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45414b93",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense matrix of the rows of A selected by the coreset.\n",
    "B_matrix = B.return_coreset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b238b18f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (number of selected rows, d)\n",
    "B_matrix.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ef90b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One LP per row of A (40000 rows here), so expect this cell to take a while.\n",
    "res = max_width_distortion(A=A, B=B_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e11cf4d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-row LP optima t_j (the value 100 marks rows where the solver failed).\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c23c36b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Worst-case width distortion over all rows: max_j 1 / t_j.\n",
    "max(1 / r for r in res)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5e59a80",
   "metadata": {},
   "source": [
    "# Chess Wiki Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ebc10d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the JPEG and convert to 8-bit grayscale ('L' mode). convert() loads the\n",
    "# pixel data, so the context manager can close the underlying file right away.\n",
    "with Image.open(\"chess_wiki.jpg\") as chess_file:\n",
    "    chess = chess_file.convert(\"L\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbb43a08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rich inline display of the grayscale image.\n",
    "chess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "720cd871",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank parameter for the chess experiment: used for the rank-k approximation,\n",
    "# the coreset, and the number of subspace dimensions evaluated below.\n",
    "k = 150"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8ff5c9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank-k SVD approximation of the image, for visual comparison.\n",
    "chess_lra = lra_image(image=chess, k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fb089a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display inline via the image's rich repr; Image.show() hands the image to an\n",
    "# external viewer, which may open a separate window instead of rendering here.\n",
    "chess_lra"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fd0d6ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coreset of the image rows; pixels are cast to int32 before streaming.\n",
    "chess_coreset = our_coreset(A=np.array(chess, dtype=np.int32), k=k)\n",
    "chess_coreset_matrix = chess_coreset.return_coreset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f8531e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (number of selected image rows, number of image columns)\n",
    "chess_coreset_matrix.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c2f77ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (height, width) of the grayscale image in pixels.\n",
    "np.array(chess).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "934e355c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pixel array of the grayscale image; each row is one image row.\n",
    "chess_array = np.array(chess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "538f7780",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only Vt is used below. The thin SVD skips forming the full square U (and the\n",
    "# extra rows of Vt) while spanning the same leading right-singular subspaces.\n",
    "U, Sigma, Vt = np.linalg.svd(chess, full_matrices=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d916482",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For i = 0..k-1, compare the max distance to the top-i right singular subspace\n",
    "# of the image measured on all rows vs. on the coreset rows only.\n",
    "results = {}\n",
    "for i in range(k):\n",
    "    subspace = Vt[:i].T  # columns: the first i right singular vectors\n",
    "    true_cost, coreset_cost = (linf_cost(chess_array, subspace),\n",
    "                               linf_cost(chess_coreset_matrix, subspace))\n",
    "    results[i] = true_cost, coreset_cost\n",
    "    print((i, true_cost, coreset_cost, true_cost / coreset_cost))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81b7b54f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Worst-case ratio (cost on all rows) / (cost on coreset rows) over the subspaces.\n",
    "max(results[i][0] / results[i][1] for i in range(k))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "029131fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One LP per image row against the chess coreset.\n",
    "res_width_chess = max_width_distortion(A=chess_array, B=chess_coreset_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cc6fb53",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Worst-case width distortion over all image rows: max_j 1 / t_j.\n",
    "max(1 / r for r in res_width_chess)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3465ecc",
   "metadata": {},
   "source": [
    "# Galaxy Wiki Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ead769f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the JPEG and convert to 8-bit grayscale ('L' mode). convert() loads the\n",
    "# pixel data, so the context manager can close the underlying file right away.\n",
    "with Image.open(\"galaxy_wiki.jpg\") as galaxy_file:\n",
    "    galaxy = galaxy_file.convert(\"L\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4edae5f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rich inline display of the grayscale image.\n",
    "galaxy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "058d61d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank parameter for the galaxy experiment: used for the rank-k approximation,\n",
    "# the coreset, and the number of subspace dimensions evaluated below.\n",
    "k = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0cd0a8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank-k SVD approximation of the image, for visual comparison.\n",
    "galaxy_lra = lra_image(image=galaxy, k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fba11708",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (height, width) of the grayscale image in pixels.\n",
    "np.array(galaxy).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b62d57eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display inline via the image's rich repr; Image.show() hands the image to an\n",
    "# external viewer, which may open a separate window instead of rendering here.\n",
    "galaxy_lra"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bd79ae5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coreset of the image rows; pixels are cast to int32 before streaming.\n",
    "galaxy_coreset = our_coreset(A=np.array(galaxy, dtype=np.int32), k=k)\n",
    "galaxy_coreset_matrix = galaxy_coreset.return_coreset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1008d76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (number of selected image rows, number of image columns)\n",
    "galaxy_coreset_matrix.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01985c9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pixel array of the grayscale image; each row is one image row.\n",
    "galaxy_array = np.array(galaxy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "653f9570",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only Vt is used below. The thin SVD skips forming the full square U (and the\n",
    "# extra rows of Vt) while spanning the same leading right-singular subspaces.\n",
    "U, Sigma, Vt = np.linalg.svd(galaxy, full_matrices=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81f49a46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For i = 0..k-1, compare the max distance to the top-i right singular subspace\n",
    "# of the image measured on all rows vs. on the coreset rows only.\n",
    "results = {}\n",
    "for i in range(k):\n",
    "    subspace = Vt[:i].T  # columns: the first i right singular vectors\n",
    "    true_cost, coreset_cost = (linf_cost(galaxy_array, subspace),\n",
    "                               linf_cost(galaxy_coreset_matrix, subspace))\n",
    "    results[i] = true_cost, coreset_cost\n",
    "    print((i, true_cost, coreset_cost, true_cost / coreset_cost))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "711f6ab9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Worst-case ratio (cost on all rows) / (cost on coreset rows) over the subspaces.\n",
    "max(results[i][0] / results[i][1] for i in range(k))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b07f83c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One LP per image row against the galaxy coreset.\n",
    "res_width_galaxy = max_width_distortion(A=galaxy_array, B=galaxy_coreset_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab28cac4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Worst-case width distortion over all image rows: max_j 1 / t_j.\n",
    "max(1 / r for r in res_width_galaxy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c16f70b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-row LP optima t_j for the galaxy image (100 marks rows where the solver failed).\n",
    "res_width_galaxy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "009756d7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
