using nonconvex_power_alm

import LinearAlgebra as LA
import ProximalOperators
import ProximalAlgorithms
import Random
import AdaProx

Random.seed!(1)

#################################################
### PARAMETERS
#################################################

SAHIN_INITIALIZATION = false    # Random if false
DATASET = clustering_fmnist     # MNIST, F-MNSIT
TOL_OBJ = 1e-4                  # Objective tolerance
TOL_CON = 1e-4                  # Constraint tolerance
VERBOSE = true                  # Verbosity of inner solver

#################################################
### Generate problem data
#################################################

### Load data set
n = 1000
s = 10 # Number of clusters
r = 20 # Rank in BM formulation
clustering, n, opt_val = generate_clustering_data(n, s, r, DATASET)

### Initialization
if SAHIN_INITIALIZATION
    # Use the same initialization as Sahin et al
    vars = nonconvex_power_alm.matread("experiments/clustering/initialization/sahin_x0.mat")
    x0 = vars["ans"]
    x0 = vec(x0')
    x0 ./= LA.norm(x0)
else
    # Random initialization
    x0 = Random.randn(n * r)
    x0 ./= LA.norm(x0)
end

λ = 0.001
normD = LA.opnorm(clustering.D)
γ_rule = (β, y) -> begin
    temp = .5 / (2 * n * β + normD + β * sqrt(2 * n) * LA.norm(y ./ β .- 1.)) 
    return temp
end
D_eigmin = LA.eigmin(clustering.D)
ρ_rule = (β, y) -> begin
    return β #* D_eigmin #* β
end
TRIPLE_LOOP = false
INNER_MAXIT = TRIPLE_LOOP ? 5000 : 1500
IPPM_MAX_IT = 5

# γ_rule = (β, y) -> begin
#     temp = .5 / (2 * n * β + normD + β * sqrt(2 * n) * LA.norm(y ./ β .- 1.)) / 5.
#     return temp
# end

#################################################
### Run power ALM with q = 0.65
#################################################

if DATASET === clustering_mnist
    println("Solving clustering problem with power ALM and q = 0.65...")
    include("run_power_alm_0.65.jl")
end

#################################################
### Run power ALM with q = 0.7
#################################################

println("Solving clustering problem with power ALM and q = 0.7...")
include("run_power_alm_0.7.jl")

# γ_rule = (β, y) -> begin
#     temp = .5 / (2 * n * β + normD + β * sqrt(2 * n) * LA.norm(y ./ β .- 1.)) / 2.
#     return temp
# end

#################################################
### Run power ALM with q = 0.75
#################################################

println("Solving clustering problem with power ALM and q = 0.75...")
include("run_power_alm_0.75.jl")

γ_rule = (β, y) -> begin
    temp = .5 / (2 * n * β + normD + β * sqrt(2 * n) * LA.norm(y ./ β .- 1.)) 
    return temp
end

#################################################
### Run power ALM with q = 0.8
#################################################

println("Solving clustering problem with power ALM and q = 0.8...")
include("run_power_alm_0.8.jl")

#################################################
### Run power ALM with q = 0.85
#################################################

println("Solving clustering problem with power ALM and q = 0.85...")
include("run_power_alm_0.85.jl")

#################################################
### Run power ALM with q = 0.9
#################################################

println("Solving clustering problem with power ALM and q = 0.9...")
include("run_power_alm_0.9.jl")

#################################################
### Run power ALM with q = 0.95
#################################################

println("Solving clustering problem with power ALM and q = 0.95...")
include("run_power_alm_0.95.jl")

#################################################
### Run power ALM with q = 1.00
#################################################

println("Solving clustering problem with power ALM and q = 1.00...")
include("run_power_alm_1.0.jl")

# #################################################
# ### Run classical ALM
# #################################################

# println("Solving clustering problem with classical ALM...")
# include("run_classical_alm.jl")

#################################################
### Results
#################################################

println("Total number of iterations: ")
if DATASET === clustering_mnist
    println("q = 0.65:\t\t\t$(res_power_065.nit)")
end
println("q = 0.70:\t\t\t$(res_power_07.nit)")
println("q = 0.75:\t\t\t$(res_power_075.nit)")
println("q = 0.80:\t\t\t$(res_power_08.nit)")
println("q = 0.85:\t\t\t$(res_power_085.nit)")
println("q = 0.90:\t\t\t$(res_power_09.nit)")
println("q = 0.95:\t\t\t$(res_power_095.nit)")
println("q = 1.00:\t\t\t$(res_power_100.nit)")
# println("Sahin ref:\t\t\t$(res_classical.nit)")