{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "238fba78-9f1d-4e1e-921e-fd3edcad8319",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "include(\"../helperFunctions.jl\");\n",
    "using BenchmarkTools\n",
    "using LinearAlgebra\n",
    "using Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "df302d7e-4a91-4ad3-b822-0eee5376d743",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Synthetic setup: ambient dimension D, subspace dimension d, and two sample groups\n",
    "# with per-group noise variances v and per-group sample counts N.\n",
    "D = 100;\n",
    "d = 10;\n",
    "v = [0.1, 100];\n",
    "N = [20, 490];\n",
    "# generateSubspace / generateData are not defined in this notebook (presumably they come\n",
    "# from ../helperFunctions.jl); U is used below as the reference basis and Y as the data.\n",
    "U = generateSubspace(D, d)\n",
    "Y = generateData(U, v, N);\n",
    "# Per-sample variance vector: v[g] repeated N[g] times, groups concatenated in order.\n",
    "varianceVector = reduce(vcat, fill.(v, N));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9717079a-a2fa-4338-b2c3-2ce78cc0897a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "grouplessVarianceUpdate (generic function with 1 method)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using TSVD\n",
    "function fastALPCAH(Y::Matrix, rank::Int; varfloor::Real=1e-9, alpcahIter::Int = 10, varianceMethod::Symbol = :groupless)\n",
    "    # Estimate a rank-`rank` basis for the columns of Y by alternating weighted\n",
    "    # least-squares updates of the factors L and R (Y ≈ L*R') with re-estimation of\n",
    "    # one noise variance per column of Y.\n",
    "    #   Y              : data matrix; columns are weighted individually\n",
    "    #   rank           : number of basis vectors to return\n",
    "    #   varfloor       : lower bound applied to every estimated variance\n",
    "    #   alpcahIter     : number of alternating iterations\n",
    "    #   varianceMethod : only :groupless is implemented\n",
    "    # Returns the left singular vectors of the final L.\n",
    "    # Throws ArgumentError for any other varianceMethod. Previously such a call left\n",
    "    # the variances at zero, so the weights became Inf and the result was NaN/Inf.\n",
    "    if varianceMethod !== :groupless\n",
    "        throw(ArgumentError(\"unsupported varianceMethod :$(varianceMethod); only :groupless is implemented\"))\n",
    "    end\n",
    "    # Truncated SVD initialization; the singular values are split evenly between the\n",
    "    # two factors so that L*R' equals the truncated reconstruction of Y.\n",
    "    T = tsvd(Y, rank)\n",
    "    L = T[1]*Diagonal(sqrt.(T[2]))\n",
    "    R = T[3]*Diagonal(sqrt.(T[2]))\n",
    "    # variance initialization from the residual of the initial fit\n",
    "    v = grouplessVarianceUpdate(Y, L*R'; varfloor=varfloor)\n",
    "    Π = Diagonal(v.^-1)\n",
    "    for i=1:alpcahIter\n",
    "        # left right updates (L uses the inverse-variance weights Π)\n",
    "        L = Y*Π*R*inv(R'*Π*R)\n",
    "        R = Y'*L*inv(L'*L)\n",
    "        # variance updates from the new residual\n",
    "        v = grouplessVarianceUpdate(Y, L*R'; varfloor=varfloor)\n",
    "        Π = Diagonal(v.^-1)\n",
    "    end\n",
    "    # extract left vectors from L\n",
    "    U = svd(L).U\n",
    "    return U\n",
    "end\n",
    "\n",
    "function fastALPCAH_KNOWN(Y::Matrix, rank::Int, v::Vector; alpcahIter::Int = 10)\n",
    "    # Factorized solver with known variances: v holds one variance per column of Y and\n",
    "    # stays fixed; only the factors L and R (Y ≈ L*R') are alternated.\n",
    "    # Returns the left singular vectors of the final L.\n",
    "    U0, σ, V0 = tsvd(Y, rank)\n",
    "    # split the singular values evenly between the two starting factors\n",
    "    rootσ = Diagonal(sqrt.(σ))\n",
    "    L = U0*rootσ\n",
    "    R = V0*rootσ\n",
    "    # inverse-variance weights\n",
    "    W = Diagonal(v.^-1)\n",
    "    for _ in 1:alpcahIter\n",
    "        L = Y*W*R*inv(R'*W*R)\n",
    "        R = Y'*L*inv(L'*L)\n",
    "    end\n",
    "    return svd(L).U\n",
    "end\n",
    "\n",
    "\n",
    "function grouplessVarianceUpdate(Y::Matrix, X::Matrix; varfloor::Real=1e-9)\n",
    "    # One variance estimate per column: squared norm of that column of Y - X scaled by\n",
    "    # 1/(number of rows of Y), then clamped from below by varfloor.\n",
    "    nrows = size(Y, 1)\n",
    "    residual = Y - X\n",
    "    estimates = [(1/nrows) * norm(col)^2 for col in eachcol(residual)]\n",
    "    return max.(estimates, varfloor)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ac28b1fb-888e-4db7-898a-8feae96ae2b7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#@btime U_ALPCAH = fastALPCAH_KNOWN(Y,d,varianceVector;alpcahIter = 1000);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3c11a044-c349-426c-b1f1-508e800854eb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.24475495197803412"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Factorized solver with the true per-sample variances, 1000 alternating iterations.\n",
    "# affinityError is not defined in this notebook (presumably from ../helperFunctions.jl);\n",
    "# it is used throughout to score an estimate against the reference basis U.\n",
    "U_ALPCAH = fastALPCAH_KNOWN(Y,d,varianceVector;alpcahIter = 1000);\n",
    "affinityError(U, U_ALPCAH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "bca9fd73-31fb-45c7-97c9-eadc727834b1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#@time U_ALPCAH = fastALPCAH(Y,d; alpcahIter = 1000);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "72e8bd8b-783f-4197-8f96-e7f074999cd4",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3385221824611817"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Factorized solver with estimated variances (default :groupless method), 1000 iterations.\n",
    "# Note: this rebinds U_ALPCAH, which the known-variance cell also assigns.\n",
    "U_ALPCAH = fastALPCAH(Y,d; alpcahIter = 1000);\n",
    "affinityError(U, U_ALPCAH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "69c82058-c8f4-4283-be31-4e1bb59a1371",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.1954342098372297"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Unweighted baseline: the first d left singular vectors of Y.\n",
    "U_PCA = svd(Y).U[:,1:d];\n",
    "affinityError(U, U_PCA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "259e5cf8-4079-4cb7-a128-15bce0da37bd",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.45883096842354226"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using HePPCAT\n",
    "# heppcat is called with a list holding every column of Y as a separate entry.\n",
    "# The comprehension builds a concretely typed vector of columns instead of an\n",
    "# untyped `[]` (Vector{Any}) filled by push! in a loop.\n",
    "YL = [Y[:, i] for i in axes(Y, 2)];\n",
    "# 1000 iterations with d factors; the basis is taken from the left singular vectors of res.F.\n",
    "res = heppcat(YL,d,1000)\n",
    "U_HEPPCAT = svd(res.F).U[:,1:d];\n",
    "affinityError(U,U_HEPPCAT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "6a656ffc-2df0-4d4b-92d8-0a2a1357a162",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ALPCAH_KNOWN (generic function with 1 method)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using MIRT\n",
    "\n",
    "function TSVT(A::Matrix,λ::Real,k::Int)\n",
    "    # Tail singular value thresholding: the first k singular values of A are kept\n",
    "    # as-is, every later one is soft-thresholded by λ, and the matrix is rebuilt.\n",
    "    F = svd(A)\n",
    "    σ = F.S\n",
    "    for j in (k+1):length(σ)\n",
    "        σ[j] = softTresh(σ[j], λ)\n",
    "    end\n",
    "    return F.U*Diagonal(σ)*F.V'\n",
    "end\n",
    "\n",
    "function softTresh(x::Real, λ::Real)\n",
    "    # Soft-thresholding: shrink |x| by λ (never below zero) and restore the sign of x.\n",
    "    shrunk = max(abs(x) - λ, 0)\n",
    "    return sign(x) * shrunk\n",
    "end\n",
    "\n",
    "function ALPCAH_KNOWN(Y::Matrix, v::Vector, k::Int; alpcahIter::Int = 1000, λr::Real = 1e6)\n",
    "    # Proximal-gradient solver (pogm_restart with mom=:fpgm) for known variances v.\n",
    "    # Returns the first output of tsvd(X, k) for the final iterate X.\n",
    "    weights = v.^-1\n",
    "    # largest weight, passed to pogm_restart as its fourth positional argument\n",
    "    Lf = maximum(weights)\n",
    "    W = Diagonal(weights)\n",
    "    # start from the projection of Y onto its truncated-SVD left vectors\n",
    "    U0 = tsvd(Y, k)[1]\n",
    "    X0 = U0*U0'*Y\n",
    "    grad = K -> -1*(Y-K)*W\n",
    "    g_prox = (z,c) -> TSVT(z, c*λr, k)\n",
    "    Xhat, _ = pogm_restart(X0, x -> 0, grad, Lf ; mom=:fpgm, g_prox, niter=alpcahIter);\n",
    "    return tsvd(Xhat, k)[1]\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "adc30bae-23d5-47fd-9967-add1b2beef31",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#@btime U_POGM = ALPCAH_KNOWN(Y, varianceVector, d; alpcahIter = 100);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f4dd68d2-47f2-4ba3-9300-faee7b17c181",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.24475498342306584"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Proximal-gradient solver with the true per-sample variances (λr left at its default).\n",
    "U_POGM = ALPCAH_KNOWN(Y, varianceVector, d; alpcahIter = 1000);\n",
    "affinityError(U,U_POGM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "245da74f-fb74-4e20-906d-9e3fe548adea",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ALPCAH (generic function with 1 method)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function ALPCAH(Y::Matrix, rank::Int, λr::Real; alpcahIter::Int = 200, pogmIter::Int = 5, varianceMethod::Symbol = :groupless, varfloor::Real=1e-9)\n",
    "    # Alternates a few pogm_restart iterations on X with re-estimation of one noise\n",
    "    # variance per column of Y.\n",
    "    #   Y              : data matrix; columns are weighted individually\n",
    "    #   rank           : number of leading singular values exempt from thresholding,\n",
    "    #                    and number of basis vectors returned\n",
    "    #   λr             : threshold scale passed to TSVT inside the proximal step\n",
    "    #   alpcahIter     : number of outer (variance update) iterations\n",
    "    #   pogmIter       : pogm_restart iterations per outer iteration\n",
    "    #   varianceMethod : only :groupless is implemented\n",
    "    #   varfloor       : lower bound applied to every estimated variance\n",
    "    # Returns the first output of tsvd(X, rank) for the final X.\n",
    "    # Throws ArgumentError for any other varianceMethod. Previously such a call left\n",
    "    # the variances at zero, so the weights and Lf became Inf.\n",
    "    if varianceMethod !== :groupless\n",
    "        throw(ArgumentError(\"unsupported varianceMethod :$(varianceMethod); only :groupless is implemented\"))\n",
    "    end\n",
    "    # Truncated SVD projection, used only to initialize the variances\n",
    "    U = tsvd(Y, rank)[1]\n",
    "    X = U*U'*Y\n",
    "    v = grouplessVarianceUpdate(Y, X; varfloor=varfloor)\n",
    "    weights = v.^-1\n",
    "    Lf = maximum(weights)\n",
    "    Π = Diagonal(weights)\n",
    "    X = zeros(size(Y)) # weirdly necessary to get lower error\n",
    "    # The closure captures the binding Π, which is reassigned on every outer iteration,\n",
    "    # so each pogm_restart call below sees the refreshed weights.\n",
    "    grad = K -> -1*(Y-K)*Π\n",
    "    g_prox = (z,c) -> TSVT(z, c*λr, rank)\n",
    "    for i=1:alpcahIter\n",
    "        # low-rank update, warm-started from the previous X\n",
    "        X, _ = pogm_restart(X, x -> 0, grad, Lf ; mom=:fpgm, g_prox, niter=pogmIter);\n",
    "        # variance updates\n",
    "        v = grouplessVarianceUpdate(Y, X; varfloor=varfloor)\n",
    "        weights = v.^-1\n",
    "        Lf = maximum(weights)\n",
    "        Π = Diagonal(weights)\n",
    "    end\n",
    "    # extract left vectors from X\n",
    "    U = tsvd(X, rank)[1]\n",
    "    return U\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "e489a025-15d6-411e-a9c9-c282a58427be",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Unknown-variance proximal-gradient solver with λr = 20 and 200 outer iterations.\n",
    "U_test = ALPCAH(Y, d, 20; alpcahIter = 200);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "ceff0303-9e43-49f1-b22e-70ab325bda13",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.22850122429780953"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scores whatever U_test currently holds; it is assigned in a separate cell, so run that cell first.\n",
    "affinityError(U, U_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "79291407-e21c-4270-9b47-b7c7fb915a58",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ALPCAH_KNOWN_ADMM (generic function with 1 method)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function ALPCAH_KNOWN_ADMM(Y::Matrix, k::Int, v::Vector; μ::Real=0.01, ρ::Real=1.0, λr::Real=1e6 ,alpcahIter::Int = 1000)\n",
    "    # ADMM-style iteration with known variances v: a thresholded low-rank update, a\n",
    "    # weighted residual update, and a multiplier update. The penalty starts at μ and\n",
    "    # is multiplied by ρ after each iteration.\n",
    "    # Returns the first k left singular vectors of the final low-rank iterate.\n",
    "    lowrank = zeros(size(Y))\n",
    "    resid = zeros(size(Y))\n",
    "    W = Diagonal(v.^-1)\n",
    "    # multiplier initialized from the sign pattern of Y, rescaled by a norm bound\n",
    "    signs = sign.(Y)\n",
    "    dual = signs ./ (max(opnorm(signs), (1/λr)*norm(signs, Inf)))\n",
    "    penalty = μ\n",
    "    for _ in 1:alpcahIter\n",
    "        lowrank = TSVT(Y-resid+(1/penalty)*dual, λr/penalty,k)\n",
    "        resid = penalty*(Y-lowrank+(1/penalty)*dual)*inv(W+penalty*I)\n",
    "        dual = dual + penalty*(Y-lowrank-resid)\n",
    "        penalty = ρ*penalty\n",
    "    end\n",
    "    return svd(lowrank).U[:,1:k]\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "45465e60-dbf7-42e9-9b46-7bb14de5f7e3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ALPCAH_UNKNOWN_ADMM (generic function with 1 method)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function ALPCAH_UNKNOWN_ADMM(Y::Matrix, k::Int, λr::Real; μ::Real=0.01, ρ::Real=1.0, alpcahIter::Int=1000, varfloor::Real=1e-9)\n",
    "    # ADMM-style iteration with estimated variances: same three updates as the\n",
    "    # known-variance version, plus a per-column variance re-estimate from Y and the\n",
    "    # low-rank iterate after every multiplier update.\n",
    "    # Returns the first k left singular vectors of the final low-rank iterate.\n",
    "    basis0 = svd(Y).U[:,1:k]\n",
    "    # start from the projection of Y onto its first k left singular vectors\n",
    "    lowrank = basis0*basis0'*Y\n",
    "    resid = Y-lowrank\n",
    "    vhat = grouplessVarianceUpdate(Y, lowrank; varfloor=varfloor)\n",
    "    W = Diagonal(vhat.^-1)\n",
    "    # multiplier initialized from the sign pattern of Y, rescaled by a norm bound\n",
    "    signs = sign.(Y)\n",
    "    dual = signs ./ (max(opnorm(signs), (1/λr)*norm(signs, Inf)))\n",
    "    penalty = μ\n",
    "    for _ in 1:alpcahIter\n",
    "        lowrank = TSVT(Y-resid+(1/penalty)*dual, λr/penalty,k)\n",
    "        resid = penalty*(Y-lowrank+(1/penalty)*dual)*inv(W+penalty*I)\n",
    "        dual = dual + penalty*(Y-lowrank-resid)\n",
    "        vhat = grouplessVarianceUpdate(Y, lowrank; varfloor=varfloor)\n",
    "        W = Diagonal(vhat.^-1)\n",
    "        penalty = ρ*penalty\n",
    "    end\n",
    "    return svd(lowrank).U[:,1:k]\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "f24686fa-d0ce-47bc-84ca-f022e9244738",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.24484468552092792"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ADMM solver with the true per-sample variances. The estimate gets its own name so\n",
    "# it no longer overwrites U_test, which a separate cell reads to score the ALPCAH result.\n",
    "U_ADMM_KNOWN = ALPCAH_KNOWN_ADMM(Y, d, varianceVector; μ=1, ρ=1.0, λr=1e6, alpcahIter=1000);\n",
    "affinityError(U, U_ADMM_KNOWN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "b6bccb7d-370b-4b59-9294-1e979826c088",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.4468192599518774"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ADMM solver with estimated variances (λr = 20). The estimate gets its own name so\n",
    "# it no longer overwrites U_test, which a separate cell reads to score the ALPCAH result.\n",
    "U_ADMM_UNKNOWN = ALPCAH_UNKNOWN_ADMM(Y, d, 20; μ=0.1, ρ=1.0, alpcahIter=1000);\n",
    "affinityError(U, U_ADMM_UNKNOWN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f717818b-6865-44a9-9566-846883587ec2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia (31 threads) 1.9.0",
   "language": "julia",
   "name": "julia-_31-threads_-1.9"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
