using nonconvex_power_alm

using Statistics, DelimitedFiles

import LinearAlgebra as LA
import ProximalOperators
import ProximalAlgorithms
import Random
import AdaProx

# Seed the global RNG first: both the problem data and the initial point below draw from it.
Random.seed!(1)
LOG_TO_FILE = true

#################################################
### PARAMETERS
#################################################

DATASET = quadratics_gaussian
TOL_OBJ = 1e-3                  # Objective tolerance
TOL_CON = 1e-3                  # Constraint tolerance
VERBOSE = true                  # Verbosity of inner solver

#################################################
### Generate problem data
#################################################

# 1 = small instance (m = 20, n = 100), 2 = large instance (m = 80, n = 400).
PROBLEM_SIZE = 1

if PROBLEM_SIZE == 1
    # Dimensions used for the figure. The earlier figure used m = 10, n = 200;
    # going forward, only the m = 20, n = 100 figure is planned.
    # m = 10
    # n = 200
    m = 20
    n = 100
elseif PROBLEM_SIZE == 2
    m = 80
    n = 400
else
    # Fail here with a clear message instead of an UndefVarError on `m`/`n` later.
    throw(ArgumentError("unsupported PROBLEM_SIZE = $(PROBLEM_SIZE); expected 1 or 2"))
end

# Problem instance; this script reads its fields Q, A and q.
quadratics = generate_quadratics_data(m, n, DATASET)
# Reference optimal value; the logged objective error is |objective - opt_val|.
opt_val = -1e-10

# Random initial point, normalized to unit Euclidean norm.
x0 = Random.randn(n)
x0 ./= LA.norm(x0)

# Spectral norms used by the step-size rules. There is no γ_rule here: a
# step-size rule is assigned to `γ_rule` right before each group of runs
# below. In this file, normA is referenced only in commented-out variants.
normQ = LA.opnorm(quadratics.Q)
normA = LA.opnorm(quadratics.A)
normAA = LA.opnorm(quadratics.A' * quadratics.A)

# ρ_rule ignores its (β, y) arguments and always returns -λ_min(Q).
# The eigenvalue is computed once here instead of on every call.
# Assumes quadratics.Q is not mutated by the included run scripts — TODO confirm.
ρ_rule = let ρ_val = -LA.eigmin(quadratics.Q)
    (β, y) -> ρ_val
end

TRIPLE_LOOP = false
if PROBLEM_SIZE == 1
    INNER_MAXIT = TRIPLE_LOOP ? 100000 : 10000
    IPPM_MAX_IT = 10000
else
    # NOTE(review): both loop variants get the same inner budget here, so
    # TRIPLE_LOOP has no effect on INNER_MAXIT for the large problem — confirm.
    INNER_MAXIT = 1000000
    IPPM_MAX_IT = 100000
end

λ = TRIPLE_LOOP ? 0.01 : 0.001
# β(k) = c * 1.1^(k - 1), with c = 1 for the triple loop and c = 5 otherwise
# (presumably the penalty-parameter schedule of the ALM — verify in the run scripts).
β = TRIPLE_LOOP ? (k) -> 1.1^(k - 1) : (k) -> 5 * 1.1^(k - 1)

#################################################
### Run power ALM
#################################################

# Base step size 0.5 / (‖Q‖₂ + β‖AᵀA‖₂); the y argument is unused.
# Only the commented-out q = 0.3 / 0.4 runs below would consume this rule:
# γ_rule is reassigned again before the next active run.
γ_rule = (β, y) -> 0.5 / (normQ + β * normAA)  # + normA * LA.norm(y)) / 50

# #################################################
# ### Run power ALM with q = 0.3
# #################################################

# println("Solving quadratics problem with power ALM and q = 0.3...")
# include("run_power_alm_0.3.jl")

# #################################################
# ### Run power ALM with q = 0.4
# #################################################

# println("Solving quadratics problem with power ALM and q = 0.4...")
# include("run_power_alm_0.4.jl")

# γ_rule = (β, y) -> begin
#     temp = .5 / (normQ + β * normAA) / 50# + normA * LA.norm(y)) / 50
#     return temp
# end

# #################################################
# ### Run power ALM with q = 0.5
# #################################################

# println("Solving quadratics problem with power ALM and q = 0.5...")
# include("run_power_alm_0.5.jl")

# Step size for the q = 0.6 and q = 0.7 runs: the base rule
# 0.5 / (‖Q‖₂ + β‖AᵀA‖₂), shrunk by a factor of ten. y is unused.
γ_rule = function (β, y)
    base_step = 0.5 / (normQ + β * normAA)
    return base_step / 10  # + normA * LA.norm(y)) / 10
end

#################################################
### Run power ALM with q = 0.6 and q = 0.7
#################################################

# Both runs use the tenfold-reduced step size assigned just above. `include`
# always evaluates the script at global scope, so whatever it defines stays
# visible after the loop.
for q_label in ("0.6", "0.7")
    println("Solving quadratics problem with power ALM and q = $(q_label)...")
    include("run_power_alm_$(q_label).jl")
end

# Step size 0.5 / (‖Q‖₂ + β‖AᵀA‖₂), halved when PROBLEM_SIZE == 2 or n == 100
# (with PROBLEM_SIZE == 1 we have n == 100, so it is halved in both settings).
# The condition is evaluated on every call; y is unused.
γ_rule = (β, y) -> begin
    base_step = 0.5 / (normQ + β * normAA)  # + normA * LA.norm(y))
    halve = PROBLEM_SIZE == 2 || n == 100
    return halve ? base_step / 2.0 : base_step
end

# # #################################################
# # ### Run power ALM with q = 0.75
# # #################################################

# println("Solving quadratics problem with power ALM and q = 0.75...")
# include("run_power_alm_0.75.jl")

#################################################
### Run power ALM with q = 0.8
#################################################

# Announce the run, then execute its script. `include` evaluates at global
# scope even from inside this `let`, so the results it defines remain visible.
let script = "run_power_alm_0.8.jl"
    println("Solving quadratics problem with power ALM and q = 0.8...")
    include(script)
end

# #################################################
# ### Run power ALM with q = 0.85
# #################################################

# println("Solving quadratics problem with power ALM and q = 0.85...")
# include("run_power_alm_0.85.jl")

#################################################
### Run power ALM with q = 0.9
#################################################

# Back to the plain step size 0.5 / (‖Q‖₂ + β‖AᵀA‖₂). It is not reassigned
# afterwards, so it also applies to the q = 1.00 run. y is unused.
γ_rule = function (β, y)
    return 0.5 / (normQ + β * normAA)  # + normA * LA.norm(y)) / 50
end

# Announce the run, then execute its script. `include` evaluates at global
# scope even from inside this `let`, so the results it defines remain visible.
let script = "run_power_alm_0.9.jl"
    println("Solving quadratics problem with power ALM and q = 0.9...")
    include(script)
end

# # #################################################
# # ### Run power ALM with q = 0.95
# # #################################################

# println("Solving quadratics problem with power ALM and q = 0.95...")
# include("run_power_alm_0.95.jl")

################################################
## Run power ALM with q = 1.00
################################################

# Announce the run, then execute its script. `include` evaluates at global
# scope even from inside this `let`, so the results it defines remain visible.
let script = "run_power_alm_1.0.jl"
    println("Solving quadratics problem with power ALM and q = 1.00...")
    include(script)
end

# #################################################
# ### Run classical ALM
# #################################################

# println("Solving quadratics problem with classical ALM...")
# include("run_classical_alm.jl")

# #################################################
# ### Results
# #################################################

# Print the `nit` field of each active run's result, one tab-separated line per q.
# Each result is looked up only when its line is printed, in the order shown.
let report = (label, res) -> println("q = ", label, ":\t\t\t", res.nit)
    println("Total number of iterations: ")
    # report("0.30", res_power_03)
    # report("0.40", res_power_04)
    # report("0.50", res_power_05)
    report("0.60", res_power_06)
    report("0.70", res_power_07)
    # report("0.75", res_power_075)
    report("0.80", res_power_08)
    # report("0.85", res_power_085)
    report("0.90", res_power_09)
    # report("0.95", res_power_095)
    report("1.00", res_power_100)
    # println("Sahin ref:\t\t\t$(res_classical.nit)")
end

if LOG_TO_FILE
    # Results in CSV column order: q = 0.6, 0.7, 0.8, 0.9, 1.00.
    logged_runs = (res_power_06, res_power_07, res_power_08, res_power_09, res_power_100)

    # One row, three columns per run, read at history index k + 1:
    # feasibility, |objective - opt_val|, nit.
    data = reshape(
        Float64[
            v for r in logged_runs for v in (
                r.hist.feasibility[r.k + 1],
                abs(r.hist.objective[r.k + 1] - opt_val),
                r.hist.nit[r.k + 1],
            )
        ],
        1, :)

    # Objective value at each run's final iterate `x`, one row in the same run order.
    # NOTE(review): the 0.5 factor also scales the linear term, i.e. ½(xᵀQx + qᵀx);
    # confirm this matches the objective minimized by the run scripts.
    fs = reshape(
        Float64[0.5 * (r.x' * quadratics.Q * r.x + LA.dot(quadratics.q, r.x)) for r in logged_runs],
        1, :)

    # This path is resolved against the current working directory, whereas the
    # `include`d run scripts are resolved relative to this file.
    loop_tag = TRIPLE_LOOP ? "triple" : "double"
    log_stem = "experiments/quadratics/results/$(DATASET)_$(m)_$(n)_$(loop_tag)"
    # Create the results directory on first use; files are opened in append mode,
    # so each execution adds one row.
    mkpath(dirname(log_stem))
    # do-blocks guarantee the files are closed even if writedlm throws.
    open("$(log_stem).csv", "a") do io
        writedlm(io, data)
    end
    open("$(log_stem)_fs.csv", "a") do io
        writedlm(io, fs)
    end

    # io = open("experiments/quadratics/results/$(DATASET)_$(m)_$(n)_$(TRIPLE_LOOP ? "triple" : "double").csv", "r")
    # datas = readdlm(io)
    # close(io)
    # avgs = sum(datas, dims = 1) ./ size(datas)[1]
    # io = open("experiments/quadratics/results/$(DATASET)_$(m)_$(n)_$(TRIPLE_LOOP ? "triple" : "double").csv", "a")
    # writedlm(io, avgs)
    # close(io)
end