{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "fb7cf7a2-dc7d-4620-963a-9e970f711580",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "using Plots\n",
    "using Random\n",
    "using Distributions\n",
    "using LinearAlgebra\n",
    "include(\"pogm_restart.jl\")\n",
    "using ProgressMeter\n",
    "using HePPCAT\n",
    "using Distributed\n",
    "using JLD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "f24fbb7e-b7a0-47dd-8a0b-39e65bbd4885",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Experiment configuration: two groups of samples that differ only in their noise variance.\n",
    "D = 100 # ambient space dimension\n",
    "d = 5 # subspace dimension\n",
    "goodPoints = 10 # points in group 1\n",
    "badPoints = 100 # points in group 2\n",
    "ν1 = 1 # noise variance group 1\n",
    "ν2 = 100; # noise variance group 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "cabbd42e-979d-43db-aa0e-7877198d6657",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "generateTrainingData (generic function with 1 method)"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build an orthonormal basis (ambientSpace × dimSubspace) for a random subspace.\n",
    "# A matrix with uniform [0,1) entries and one column per data point is drawn,\n",
    "# and its leading `dimSubspace` left singular vectors are returned.\n",
    "function generateSubspace(goodPoints,badPoints,dimSubspace,ambientSpace)\n",
    "    totalPoints = goodPoints + badPoints\n",
    "    randomMatrix = rand(ambientSpace, totalPoints)\n",
    "    leftVectors = svd(randomMatrix).U\n",
    "    return leftVectors[:, 1:dimSubspace]\n",
    "end\n",
    "\n",
    "# Draw noisy samples from the subspace spanned by the columns of U.\n",
    "#   U          : basis matrix, size (ambient dimension) × (subspace dimension)\n",
    "#   ν1, ν2     : noise variances of the first and second group\n",
    "#   goodPoints : number of samples in the first group (columns 1:goodPoints)\n",
    "#   badPoints  : number of samples in the second group (remaining columns)\n",
    "# Returns a matrix with one sample per column, first group first.\n",
    "function generateTrainingData(U, ν1,ν2,goodPoints,badPoints)\n",
    "    halfWidth = 10 # subspace coefficients are Uniform(-halfWidth, halfWidth)\n",
    "    ambientSpace, dimSubspace = size(U)\n",
    "    totalPoints = goodPoints + badPoints\n",
    "    coefficients = rand(Uniform(-halfWidth,halfWidth), dimSubspace, totalPoints)\n",
    "    clean = U*coefficients\n",
    "    # Gaussian noise is drawn for the first group, then for the second group\n",
    "    noiseGood = rand(Normal(0,sqrt(ν1)), ambientSpace, goodPoints)\n",
    "    noiseBad = rand(Normal(0,sqrt(ν2)), ambientSpace, badPoints)\n",
    "    return hcat(clean[:,1:goodPoints] + noiseGood, clean[:,(goodPoints+1):end] + noiseBad)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "912c75c3-d681-4c16-accb-9007a77b74ff",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Seed the global RNG so the synthetic subspace and data are identical on every re-run.\n",
    "Random.seed!(0)\n",
    "U1 = generateSubspace(goodPoints,badPoints,d,D); # ground-truth basis, D × d\n",
    "Y = generateTrainingData(U1,ν1,ν2,goodPoints,badPoints); # noisy samples, one per column\n",
    "# Per-sample noise variances: ν1 for the first goodPoints columns, ν2 for the remaining badPoints.\n",
    "Π = vec(vcat(ν1*ones(goodPoints,1), ν2*ones(badPoints,1)));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "d085dbcc-ef9e-440d-beef-274394fbb56b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "HPCA_KNOWN (generic function with 1 method)"
      ]
     },
     "execution_count": 101,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Estimate a rank-r subspace from Y through a weighted factorization Y ≈ L*R'.\n",
    "#   Y       : data matrix, one sample per column\n",
    "#   Π       : vector with one entry per column of Y; column j of the residual is weighted by 1/Π[j]\n",
    "#   ϵ       : stop once the change of L in one iteration, relative to the norm of the initial L, is at most ϵ\n",
    "#   r       : target rank\n",
    "#   miniter : iterations always performed before the ϵ test may stop the loop\n",
    "#   maxiter : hard cap on iterations so a non-converging run cannot loop forever\n",
    "# Returns the first r left singular vectors of the final product L*R'.\n",
    "function fastHPCA(Y,Π,ϵ,r; miniter=100, maxiter=10_000)\n",
    "    Π1 = Diagonal(Π.^-1)\n",
    "    U,S,V = svd(Y)\n",
    "    U,S,V = U[:,1:r], S[1:r], V[:,1:r]\n",
    "    # balanced initialization from the rank-r truncated SVD of Y\n",
    "    L = U*Diagonal(sqrt.(S))\n",
    "    R = V*Diagonal(sqrt.(S))\n",
    "    Lk = copy(L)\n",
    "    Rk = copy(R)\n",
    "    normL = norm(L)\n",
    "    α = 0.5 # fixed step size used for both factors\n",
    "    count = 0\n",
    "    # Lk equals L on entry, so the miniter clause is what lets the loop start at all\n",
    "    while (norm(Lk - L)/normL > ϵ || count < miniter) && count < maxiter\n",
    "        L = Lk\n",
    "        R = Rk\n",
    "        residual = Y - L*R'\n",
    "        # right division solves against the r×r Gram matrix instead of forming its inverse\n",
    "        Lk = L + α*(residual*Π1*R)/(R'*R)\n",
    "        Rk = R + α*(Π1*residual'*L)/(L'*L)\n",
    "        count = count + 1\n",
    "    end\n",
    "    if count == maxiter && norm(Lk - L)/normL > ϵ\n",
    "        @warn \"fastHPCA stopped at maxiter before reaching the tolerance\" maxiter ϵ\n",
    "    end\n",
    "    U_HPCA = svd(Lk*Rk').U\n",
    "    return U_HPCA[:,1:r]\n",
    "end\n",
    "# Subspace estimate from a weighted low-rank fit solved with pogm_restart.\n",
    "#   Y  : data matrix, one sample per column\n",
    "#   λr : factor multiplying the step value handed to the proximal map to form the threshold\n",
    "#   w  : per-column values; their reciprocals weight the residual\n",
    "#   α  : number of leading singular values left unshrunk, and number of basis vectors returned\n",
    "#   ϵ  : forwarded to pogm_restart as `eps`\n",
    "# Returns the first α left singular vectors of the solver's output.\n",
    "function HPCA_KNOWN(Y, λr, w, α, ϵ)\n",
    "    invW = w.^-1\n",
    "    Lmax = maximum(invW) # fourth positional argument of pogm_restart — presumably the gradient Lipschitz constant; confirm in pogm_restart.jl\n",
    "    weights = Diagonal(invW)\n",
    "    start = zeros(size(Y))\n",
    "    gradient(K) = -1*(Y-K)*weights\n",
    "    # scalar soft-thresholding\n",
    "    shrink(s, τ) = sign(s) * max(abs(s) - τ, 0)\n",
    "    # soft-threshold every singular value after the first `keep`, leaving the leading ones untouched\n",
    "    function partialThreshold(M, τ, keep)\n",
    "        F = svd(M)\n",
    "        σ = F.S\n",
    "        σ[(keep+1):end] = shrink.(σ[(keep+1):end], τ)\n",
    "        return F.U*diagm(σ)*F.V'\n",
    "    end\n",
    "    proximal(z, c) = partialThreshold(z, c*λr, α)\n",
    "    W, _ = pogm_restart(start, x -> 0, gradient, Lmax ; g_prox=proximal, eps=ϵ, mom=:fpgm, restart=:gr)\n",
    "    return svd(W).U[:,1:α]\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "35b54c8d-cc75-4b06-bf8a-1726c94b4e66",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# fastHPCA estimate: weights from the known per-sample variances Π, tolerance 1e-4, rank d.\n",
    "U_HPCA = fastHPCA(Y,Π,1e-4,d);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "c17e2de0-5863-412e-bce6-8869686be678",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8357948879575671"
      ]
     },
     "execution_count": 103,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Relative error between the projector onto the fastHPCA subspace and the true projector U1*U1'.\n",
    "# Note: norm(A,2) on a matrix is the entrywise (Frobenius) norm; the spectral norm would be opnorm.\n",
    "norm(U_HPCA*U_HPCA' - U1*U1',2)/norm(U1*U1',2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "7d954c19-c5c4-43ef-8373-c7764b9669a0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Unweighted PCA baseline: the first d left singular vectors of Y (the variances Π are not used).\n",
    "U_PCA = svd(Y).U[:,1:d];"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "e46aa07d-985d-4ca7-95a5-10bf121632a7",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.338173546122687"
      ]
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Relative error between the projector onto the PCA subspace and the true projector U1*U1'.\n",
    "# Note: norm(A,2) on a matrix is the entrywise (Frobenius) norm; the spectral norm would be opnorm.\n",
    "norm(U_PCA*U_PCA' - U1*U1',2)/norm(U1*U1',2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "160ed510-9493-44a0-bd95-7b6aab8bb9ea",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# HPCA_KNOWN estimate: λr = 100000, known variances Π as w, d unshrunk singular values, tolerance 1e-5.\n",
    "U_HPCA_NN = HPCA_KNOWN(Y, 100000, Π, d, 1e-5);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "7ef31ff1-c10d-4eef-991d-7282b9aab464",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8113181217766545"
      ]
     },
     "execution_count": 107,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Relative error between the projector onto the HPCA_KNOWN subspace and the true projector U1*U1'.\n",
    "# Note: norm(A,2) on a matrix is the entrywise (Frobenius) norm; the spectral norm would be opnorm.\n",
    "norm(U_HPCA_NN*U_HPCA_NN' - U1*U1',2)/norm(U1*U1',2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "015b7f12-44da-4352-8cd7-1d015618491b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia (32 threads) 1.8.0",
   "language": "julia",
   "name": "julia-_32-threads_-1.8"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
