{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "06281d9b-4222-4b0e-b5b9-9781ec8be07a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5871530004726269"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using Plots\n",
    "using Random\n",
    "using Distributions\n",
    "using LinearAlgebra\n",
    "include(\"pogm_restart_original.jl\") # issues adding MIRT so will use julia file instead\n",
    "# Reproducible synthetic data: N samples in R^D near a random d-dimensional subspace,\n",
    "# with heteroscedastic Gaussian noise (first `goodpts` columns clean, the rest noisy).\n",
    "rng = Random.seed!(0)  # seeds the global RNG; every random draw below depends on it\n",
    "N = 100       # number of samples (columns)\n",
    "D = 100       # ambient dimension (rows)\n",
    "d = 10        # dimension of the true subspace\n",
    "σ1 = 2        # noise std of the first `goodpts` samples\n",
    "σ2 = 35       # noise std of the remaining samples\n",
    "goodpts = 10\n",
    "X = zeros(D, N)  # noise-free samples\n",
    "Y = zeros(D, N)  # observed (noisy) samples\n",
    "Π = zeros(N)     # per-sample noise variances\n",
    "U1 = svd(rand(D, N)).U[:, 1:d]  # orthonormal basis of the true subspace\n",
    "# Clean samples: project a Uniform(-100, 100) vector onto span(U1).\n",
    "for col in 1:N\n",
    "    X[:, col] = U1 * U1' * rand(Uniform(-100, 100), D)\n",
    "end\n",
    "# Noise is drawn only after all of X so the RNG stream is consumed in the same order.\n",
    "for col in 1:N\n",
    "    σ = col <= goodpts ? σ1 : σ2\n",
    "    Y[:, col] = X[:, col] + rand(Normal(0, σ), D)\n",
    "    Π[col] = σ^2\n",
    "end\n",
    "# Heteroscedastic PCA solved with POGM (`pogm_restart` from pogm_restart_original.jl):\n",
    "#     min_K  ½‖(Y - K) Π^{1/2}‖_F² + λr · Σ_{i>N} σ_i(K),   Π = Diagonal(1 ./ w)\n",
    "# i.e. a column-weighted least-squares fit plus a nuclear-norm penalty on all but the\n",
    "# top N singular values (applied through the partial-SVT proximal step).\n",
    "#   Y    : data matrix, one sample per column\n",
    "#   λr   : penalty weight on the singular values beyond the top N\n",
    "#   w    : per-column values whose inverses weight the fit (this notebook's Π holds noise variances)\n",
    "#   N    : number of leading singular values left unpenalized; also the rank of the warm start\n",
    "#   rank : currently unused (kept so existing calls keep working)\n",
    "#   ϵ    : number of POGM iterations (passed as `niter`)\n",
    "# Returns the low-rank estimate K (same size as Y), not a subspace basis.\n",
    "function HPCA_POGM(Y, λr, w, N, rank, ϵ)\n",
    "    Π = Diagonal(w .^ -1)\n",
    "    # ∇f(K) = -(Y - K)Π is Lipschitz with constant opnorm(Π) = maximum(1 ./ w).\n",
    "    # (minimum(w) only equals this when min(w) == 1; when min(w) < 1 it underestimates\n",
    "    # the constant and the 1/Lf step can make POGM diverge.)\n",
    "    Lf = maximum(inv, w)\n",
    "    # Warm start: projection of Y onto its top-N left singular subspace.\n",
    "    U_svd = svd(Y).U[:, 1:N]\n",
    "    x0 = U_svd * U_svd' * Y\n",
    "    grad = K -> -(Y - K) * Π\n",
    "    soft = (x, t) -> sign.(x) .* max.(abs.(x) .- t, 0)\n",
    "    # Partial SVT: keep the top N singular values, soft-threshold the rest by t.\n",
    "    function pssvt(x, t, N)\n",
    "        U, S, V = svd(x)\n",
    "        S[(N+1):end] = soft.(S[(N+1):end], t)\n",
    "        return U * Diagonal(S) * V'\n",
    "    end\n",
    "    # presumably pogm_restart calls g_prox(z, c) with the step size c (MIRT convention),\n",
    "    # so the threshold is c*λr — confirm against pogm_restart_original.jl.\n",
    "    prox1 = (z, c) -> pssvt(z, c * λr, N)\n",
    "    K, _ = pogm_restart(x0, x -> 0, grad, Lf; g_prox=prox1, niter=ϵ)\n",
    "    return K\n",
    "end\n",
    "# Heteroscedastic PCA via ADMM on the split Y = X + Z:\n",
    "#     min_{X,Z}  λr · Σ_{i>N} σ_i(X) + ½‖Z Π^{1/2}‖_F²   s.t.  Y = X + Z,   Π = Diagonal(1 ./ w)\n",
    "#   Y      : data matrix, one sample per column\n",
    "#   λr     : penalty weight on the singular values beyond the top N\n",
    "#   w      : per-column values whose inverses weight the fit (this notebook passes noise variances)\n",
    "#   N      : number of leading singular values left unpenalized\n",
    "#   μ, ρ   : initial ADMM penalty and its per-iteration growth factor (μ ← ρμ)\n",
    "#   d      : number of left singular vectors of the final X to return\n",
    "#   U_init : warm-start basis, X0 = U_init U_init' Y (an all-zero U_init gives X0 = 0, Z0 = Y)\n",
    "#   niter  : number of ADMM iterations (default 2000, the previously hard-coded value)\n",
    "# Returns the first d left singular vectors of the final X.\n",
    "function HPCA_ADMM(Y, λr, w, N, μ, ρ, d, U_init; niter=2000)\n",
    "    X = U_init * U_init' * Y\n",
    "    Z = Y - X\n",
    "    Π = Diagonal(w .^ -1)\n",
    "    # Dual initialisation: sign(Y) scaled by max(opnorm(sign(Y)), max|sign(Y)| / λr).\n",
    "    Λ0 = sign.(Y)\n",
    "    Λ = Λ0 ./ max(opnorm(Λ0), (1/λr) * norm(Λ0, Inf))\n",
    "    soft = (x, t) -> sign.(x) .* max.(abs.(x) .- t, 0)\n",
    "    # Partial SVT: keep the top N singular values, soft-threshold the rest by t.\n",
    "    function pssvt(x, t, N)\n",
    "        U, S, V = svd(x)\n",
    "        S[(N+1):end] = soft.(S[(N+1):end], t)\n",
    "        return U * diagm(S) * V'\n",
    "    end\n",
    "    for _ in 1:niter\n",
    "        # X-step: prox of (λr/μ)·partial nuclear norm\n",
    "        X = pssvt(Y - Z + (1/μ) * Λ, λr / μ, N)\n",
    "        # Z-step: closed form of Z(Π + μI) = Λ + μ(Y - X)\n",
    "        Z = μ * (Y - X + (1/μ) * Λ) * inv(Π + μ * I)\n",
    "        # dual ascent on the constraint Y = X + Z, then grow the penalty\n",
    "        Λ = Λ + μ * (Y - X - Z)\n",
    "        μ = ρ * μ\n",
    "    end\n",
    "    return svd(X).U[:, 1:d]\n",
    "end\n",
    "# Weighted PCA: the top-k eigenvectors of Σ = Σ_j w_j y_j y_j' = Y Diagonal(w) Y'.\n",
    "#   Y : data matrix, one sample per column\n",
    "#   w : per-column weights (this notebook passes inverse noise variances, Π.^-1)\n",
    "#   k : number of principal directions to return\n",
    "# Returns a size(Y,1)×k matrix with orthonormal columns, ordered by decreasing eigenvalue.\n",
    "# Uses no randomness, so calling it does not advance the global RNG.\n",
    "function weightedPCA(Y, w, k)\n",
    "    # Scaling each column by its weight equals summing over groups of equal weight.\n",
    "    # Symmetric(...) forces the real symmetric eigensolver (real vectors, ascending\n",
    "    # eigenvalues) even if rounding leaves Σ slightly asymmetric.\n",
    "    Σ = Symmetric((Y .* w') * Y')\n",
    "    U = reverse(eigvecs(Σ), dims=2)  # descending eigenvalue order\n",
    "    return U[:, 1:k]\n",
    "end\n",
    "# WPCA baselines: relative Frobenius subspace error ‖U Uᵀ - U1 U1ᵀ‖_F / ‖U1 U1ᵀ‖_F\n",
    "# at the true rank d = 10 and at the mis-specified ranks 8 and 12 (weights = 1 ./ Π).\n",
    "U_WPCA = weightedPCA(Y, inv.(Π), d)\n",
    "error_wpca_10 = norm(U_WPCA * U_WPCA' - U1 * U1') / norm(U1 * U1')\n",
    "U_WPCA = weightedPCA(Y, inv.(Π), 8)\n",
    "error_wpca_8 = norm(U_WPCA * U_WPCA' - U1 * U1') / norm(U1 * U1')\n",
    "U_WPCA = weightedPCA(Y, inv.(Π), 12)\n",
    "error_wpca_12 = norm(U_WPCA * U_WPCA' - U1 * U1') / norm(U1 * U1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b9554ec9-249a-47d2-81b5-0e461c15775a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n",
      "16\n",
      "17\n",
      "18\n",
      "19\n",
      "20\n",
      "21\n",
      "22\n",
      "23\n",
      "24\n",
      "25\n",
      "26\n",
      "27\n",
      "28\n",
      "29\n",
      "30\n"
     ]
    }
   ],
   "source": [
    "# λr grid shared by all three sweeps. NOTE: 1.0 appears twice (end of 0:0.1:1 and start\n",
    "# of 1:0.5:10); kept as-is so every sweep and the plot use the same 30-point grid.\n",
    "λr = vcat(0:0.1:1,1:0.5:10)\n",
    "λr[1] = 0.001  # start at 0.001 rather than 0 (HPCA_ADMM divides by λr when initialising its dual)\n",
    "error_hpca_zero = zeros(size(λr))\n",
    "# HPCA at the true rank α = 10, zero initialisation (zero U_init ⇒ X0 = 0, Z0 = Y).\n",
    "for (i, λ) in enumerate(λr)\n",
    "    println(i)\n",
    "    flush(stdout)\n",
    "    U_HPCA = HPCA_ADMM(Y, λ, Π, 10, 0.01, 1.005, d, zeros(D, d))\n",
    "    error_hpca_zero[i] = norm(U_HPCA*U_HPCA' - U1*U1',2)/norm(U1*U1',2)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "6a83bb11-df0f-4483-a615-170fcad07a44",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n",
      "16\n",
      "17\n",
      "18\n",
      "19\n",
      "20\n",
      "21\n",
      "22\n",
      "23\n",
      "24\n",
      "25\n",
      "26\n",
      "27\n",
      "28\n",
      "29\n",
      "30\n"
     ]
    }
   ],
   "source": [
    "# Same λr grid as the α = 10 sweep (1.0 appears twice; kept so all curves share one grid).\n",
    "λr = vcat(0:0.1:1,1:0.5:10)\n",
    "λr[1] = 0.001  # start at 0.001 rather than 0 (HPCA_ADMM divides by λr when initialising its dual)\n",
    "error_hpca_rank8 = zeros(size(λr))\n",
    "# HPCA with an under-estimated rank α = 8 (true d = 10), zero initialisation.\n",
    "for (i, λ) in enumerate(λr)\n",
    "    println(i)\n",
    "    flush(stdout)\n",
    "    U_HPCA = HPCA_ADMM(Y, λ, Π, 8, 0.01, 1.005, 8, zeros(D, 8))  # zero U_init ⇒ X0 = 0, Z0 = Y\n",
    "    error_hpca_rank8[i] = norm(U_HPCA*U_HPCA' - U1*U1',2)/norm(U1*U1',2)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "cc2d0704-465f-4d15-b156-e2b7fdf183e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n",
      "16\n",
      "17\n",
      "18\n",
      "19\n",
      "20\n",
      "21\n",
      "22\n",
      "23\n",
      "24\n",
      "25\n",
      "26\n",
      "27\n",
      "28\n",
      "29\n",
      "30\n"
     ]
    }
   ],
   "source": [
    "# Same λr grid as the α = 10 sweep (1.0 appears twice; kept so all curves share one grid).\n",
    "λr = vcat(0:0.1:1,1:0.5:10)\n",
    "λr[1] = 0.001  # start at 0.001 rather than 0 (HPCA_ADMM divides by λr when initialising its dual)\n",
    "error_hpca_rank12 = zeros(size(λr))\n",
    "# HPCA with an over-estimated rank α = 12 (true d = 10), zero initialisation.\n",
    "for (i, λ) in enumerate(λr)\n",
    "    println(i)\n",
    "    flush(stdout)\n",
    "    U_HPCA = HPCA_ADMM(Y, λ, Π, 12, 0.01, 1.005, 12, zeros(D, 12))  # zero U_init ⇒ X0 = 0, Z0 = Y\n",
    "    error_hpca_rank12[i] = norm(U_HPCA*U_HPCA' - U1*U1',2)/norm(U1*U1',2)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "286b0c89-57cd-472a-bbd5-a71fb3d5928f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8902024700565202"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Unweighted baseline: top-d left singular vectors of Y (ignores the per-sample noise levels).\n",
    "U_svd = svd(Y).U[:,1:d]\n",
    "error_svd = norm(U_svd*U_svd' - U1*U1', 2)/norm(U1*U1', 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "55c15805-2ebf-43ab-8679-71a128c73341",
   "metadata": {},
   "outputs": [],
   "source": [
    "using LaTeXStrings\n",
    "# Subspace error vs λr for HPCA (ADMM, zero init) at rank guesses α = 8, 10 (true), 12.\n",
    "# Dashed lines: WPCA at the same ranks; solid purple line: unweighted PCA at r = 10.\n",
    "plot(λr, error_hpca_zero, label=\"HPCA (α=10)\",linewidth=2,title=\"HPCA Rank Knowledge (ν known, zero init.)\", xlabel=\"λr\", ylabel=L\"\\Vert U_{*} U_{*}^{'}- UU' \\Vert_F/ \\Vert UU' \\Vert_F\",c=:deepskyblue, legendfontsize=14.0)\n",
    "plot!(λr, error_hpca_rank8, label=\"HPCA (α=8)\",linewidth=2, c=:seagreen)\n",
    "# this curve is the rank-12 sweep, so its legend entry must say α=12\n",
    "plot!(λr, error_hpca_rank12, label=\"HPCA (α=12)\",linewidth=2, c=:orange)\n",
    "hline!([error_wpca_8], label=\"\", linewidth=2, c=:seagreen,linestyle=:dash)\n",
    "hline!([error_wpca_10], label=\"\", c=:deepskyblue, linewidth=2, linestyle=:dash)\n",
    "hline!([error_wpca_12], label=\"\", linewidth=2, c=:orange,linestyle=:dash)\n",
    "annotate!(8, error_wpca_10+0.025, \"WPCA r=10\", :deepskyblue)\n",
    "annotate!(8, error_wpca_8+0.025, \"WPCA r=8\", :seagreen)\n",
    "annotate!(8, error_wpca_12+0.025, \"WPCA r=12\", :orange)\n",
    "annotate!(3, error_svd-0.025, \"PPCA r=10\", :purple2)\n",
    "hline!([error_svd], label=\"\", linewidth=2, c=:purple2)\n",
    "# TODO: add the HePPCAT baseline once `error_heppcat` is computed:\n",
    "#hline!([error_heppcat], label=\"HePPCAT\",linewidth=2)\n",
    "savefig(\"HPCA_KNOWN_RANK_KNOWLEDGE_ZERO_INIT.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "9fb5d85c-4225-4441-8f3e-8d0a436b92e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO(HePPCAT baseline): `heppCAT` is not defined anywhere in this notebook, and evaluating it\n",
    "# raised `UndefVarError: heppCAT not defined`. Compute `error_heppcat` here (e.g. from a\n",
    "# HePPCAT implementation) and then enable the commented-out HePPCAT `hline!` in the plot cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bf5aae7-f46c-4469-80a8-cb2d5d903d33",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.7.2",
   "language": "julia",
   "name": "julia-1.7"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
