{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "fb7cf7a2-dc7d-4620-963a-9e970f711580",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "using Plots\n",
    "using Random\n",
    "using Distributions\n",
    "using LinearAlgebra\n",
    "include(\"pogm_restart.jl\")\n",
    "using ProgressMeter\n",
    "using HePPCAT\n",
    "using Distributed\n",
    "using JLD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "f24fbb7e-b7a0-47dd-8a0b-39e65bbd4885",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "D = 10 # ambient space dimension\n",
    "d = 3 # subspace dimension\n",
    "goodPoints = 6 # points in group 1\n",
    "badPoints = 600 # points in group 2\n",
    "ν1 = 0.1 # noise variance group 1\n",
    "ν2 = 0.1; # noise variance grup 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "cabbd42e-979d-43db-aa0e-7877198d6657",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "generateTrainingData (generic function with 1 method)"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function generateSubspace(goodPoints,badPoints,dimSubspace,ambientSpace)\n",
    "    U = svd(rand(ambientSpace,goodPoints+badPoints)).U[:,1:dimSubspace]\n",
    "    return U\n",
    "end\n",
    "\n",
    "function generateTrainingData(U, ν1,ν2,goodPoints,badPoints)\n",
    "    window = 10\n",
    "    ambientSpace, dimSubspace = size(U)\n",
    "    X = U*rand(Uniform(-window,window),dimSubspace,goodPoints+badPoints) #U*U'*rand(Uniform(-100,100),D,N)\n",
    "    Y = zeros(ambientSpace,goodPoints+badPoints)\n",
    "    Y[:,1:goodPoints] = X[:,1:goodPoints] +  rand(Normal(0,sqrt(ν1)),ambientSpace,goodPoints)\n",
    "    Y[:,(goodPoints+1):end] = X[:,(goodPoints+1):end] +  rand(Normal(0,sqrt(ν2)),ambientSpace,badPoints)\n",
    "    return Y\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "912c75c3-d681-4c16-accb-9007a77b74ff",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "U1 = generateSubspace(goodPoints,badPoints,d,D);\n",
    "Y = generateTrainingData(U1,ν1,ν2,goodPoints,badPoints);\n",
    "Π = vec(vcat(ν1*ones(goodPoints,1), ν2*ones(badPoints,1)));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "d085dbcc-ef9e-440d-beef-274394fbb56b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "HPCA_KNOWN (generic function with 1 method)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function fastHPCA(Y,Π,ϵ,r)\n",
    "    Π1 = Diagonal(Π.^-1)\n",
    "    specPi = opnorm(Π1)\n",
    "    U,S,Vt = svd(Y)\n",
    "    U,S,Vt = U[:,1:r], S[1:r], Vt[:,1:r]\n",
    "    L = U*Diagonal(sqrt.(S))\n",
    "    R = Vt*Diagonal(sqrt.(S))\n",
    "    Lk = deepcopy(L)\n",
    "    Rk = deepcopy(R)\n",
    "    normL = norm(L)\n",
    "    count = 0\n",
    "    while norm(Lk - L)/normL > ϵ || count < 100\n",
    "        L = Lk\n",
    "        R = Rk\n",
    "        αL = 0.5 #(specPi*opnorm(R)^2)^-1\n",
    "        αR = 0.5 #(specPi*opnorm(L)^2)^-1\n",
    "        Lk = L + αL*(Y-L*R')*Π1*R*inv(R'*R)\n",
    "        Rk = R + αR*Π1*(Y'-R*L')*L*inv(L'*L)\n",
    "        count = count + 1\n",
    "    end\n",
    "    println(count)\n",
    "    U_HPCA = svd(Lk*Rk').U\n",
    "    return U_HPCA[:,1:r]\n",
    "end\n",
    "\n",
    "function W_NO_GROUPS(Y, L)\n",
    "    d = size(Y)[1]\n",
    "    Π = diag((1/d)*(Y-L)'*(Y-L))\n",
    "    return max.(Π, 1e-9)\n",
    "end\n",
    "\n",
    "function fastHPCA2(Y,niter,r)\n",
    "    U,S,Vt = svd(Y)\n",
    "    U,S,Vt = U[:,1:r], S[1:r], Vt[:,1:r]\n",
    "    L = U*Diagonal(sqrt.(S))\n",
    "    R = Vt*Diagonal(sqrt.(S))\n",
    "    Lk = deepcopy(L)\n",
    "    Rk = deepcopy(R)\n",
    "    Π = Diagonal(W_NO_GROUPS(Y, Lk*Rk'))\n",
    "    normL = norm(L)\n",
    "    for i=1:niter\n",
    "        L = Lk\n",
    "        R = Rk\n",
    "        #α = 0.5 \n",
    "        #Lk = L + α*(Y-L*R')*pinv(Π)*R*inv(R'*R)\n",
    "        #Rk = R + α*pinv(Π)*(Y'-R*L')*L*inv(L'*L)\n",
    "        Lk = Y*pinv(Π)*R*inv(R'*pinv(Π)*R)\n",
    "        Rk = Y'*L*inv(L'*L)\n",
    "        Π = Diagonal(W_NO_GROUPS(Y, Lk*Rk'))\n",
    "        #Zk = Y-Lk*Rk'\n",
    "        #Π = Π - 1*( (-1/2)*((Zk'*Zk).*(pinv(Π).^2)) + (D/2)*pinv(Π) )\n",
    "    end\n",
    "    U_HPCA = svd(Lk*Rk').U\n",
    "    return U_HPCA[:,1:r], Π\n",
    "end\n",
    "\n",
    "function ALPCAH_ALTMIN(Y,niter,r)\n",
    "    U,S,V = svd(Y)\n",
    "    U,S,V = U[:,1:r], S[1:r], V[:,1:r]\n",
    "    L = U*Diagonal(sqrt.(S))\n",
    "    R = V*Diagonal(sqrt.(S))\n",
    "    Π = Diagonal(W_NO_GROUPS(Y, L*R'))\n",
    "    Π1 = pinv(Π)\n",
    "    for i=1:niter\n",
    "        L = Y*Π1*R*inv(R'*Π1*R)\n",
    "        R = Y'*L*inv(L'*L)\n",
    "        Π = Diagonal(W_NO_GROUPS(Y, L*R'))\n",
    "        Π1 = pinv(Π)\n",
    "    end\n",
    "    U, S, V = svd(L*R')\n",
    "    return U[:,1:r]\n",
    "end\n",
    "\n",
    "function HPCA_KNOWN(Y, λr, w, α, ϵ)\n",
    "    Π = w.^-1\n",
    "    Lf = maximum(Π)\n",
    "    Π = Diagonal(Π)\n",
    "    x0 = zeros(size(Y))\n",
    "    grad = K -> -1*(Y-K)*Π\n",
    "    soft = (x,t) -> sign.(x) .* max.(abs.(x) .- t, 0)\n",
    "    function pssvt(x,t,N)\n",
    "        U,S,V = svd(x)\n",
    "        S[(N+1):end] = soft.(S[(N+1):end],t)\n",
    "        return U*diagm(S)*V'\n",
    "    end\n",
    "    prox1 = (z,c) -> pssvt(z, c*λr, α)\n",
    "    W, _ = pogm_restart(x0, x -> 0, grad, Lf ; g_prox=prox1, eps=ϵ, mom=:fpgm, restart=:gr) # objective(x,Y-x,λr,w)\n",
    "    U_final = svd(W).U[:,1:α]\n",
    "    return U_final\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "a6bbbeae-99e6-474f-8e2f-7d8140447e7c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.011292800842927771"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "U_UNK = ALPCAH_ALTMIN(Y,1000,d);\n",
    "norm(U_UNK*U_UNK' - U1*U1',2)/norm(U1*U1',2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "35b54c8d-cc75-4b06-bf8a-1726c94b4e66",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100\n"
     ]
    }
   ],
   "source": [
    "U_HPCA = fastHPCA(Y,10*Π,1e-4,d);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "c17e2de0-5863-412e-bce6-8869686be678",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.00932136044942067"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "norm(U_HPCA*U_HPCA' - U1*U1',2)/norm(U1*U1',2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "7d954c19-c5c4-43ef-8373-c7764b9669a0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "U_PCA = svd(Y).U[:,1:d];"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "e46aa07d-985d-4ca7-95a5-10bf121632a7",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.009321360449420583"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "norm(U_PCA*U_PCA' - U1*U1',2)/norm(U1*U1',2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "160ed510-9493-44a0-bd95-7b6aab8bb9ea",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "U_HPCA_NN = HPCA_KNOWN(Y, 100000, Π, d, 1e-5);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "7ef31ff1-c10d-4eef-991d-7282b9aab464",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.009321360449420566"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "norm(U_HPCA_NN*U_HPCA_NN' - U1*U1',2)/norm(U1*U1',2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "c66620cf-9a61-46b2-b7c8-8b714912201a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.011302598151816647"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "heppCAT_nogroups = []\n",
    "for k = 1:size(Y)[2]\n",
    "    push!(heppCAT_nogroups, Y[:,k])\n",
    "end\n",
    "heppCAT_NOG = heppcat(heppCAT_nogroups, d, 1000; varfloor=1e-9)\n",
    "error_heppcat = norm(heppCAT_NOG.U*heppCAT_NOG.U' - U1*U1', 2)/norm(U1*U1', 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c34299f9-e2c6-4b31-9569-dba72d03ae08",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia (32 threads) 1.8.0",
   "language": "julia",
   "name": "julia-_32-threads_-1.8"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
