using Distances
using MAT
using AdaProx

# Selector for the data source used by `generate_clustering_data`:
#   clustering_gaussian (1) - data generated on the fly,
#   clustering_mnist    (2) - distance matrix read from the MNIST digits .mat file,
#   clustering_fmnist   (3) - distance matrix read from the Fashion-MNIST .mat file.
@enum CLUSTERING_TYPE clustering_gaussian=1 clustering_mnist=2 clustering_fmnist=3

export clustering_gaussian, clustering_mnist, clustering_fmnist

"""
    generate_clustering_data(n, s, r, clustering_type::CLUSTERING_TYPE)

Build a `CLUSTERING` problem for the requested data source.

# Arguments
- `n`: number of points; only used for `clustering_gaussian`.
- `s`, `r`: forwarded unchanged to `CLUSTERING(D, s, r)`.
- `clustering_type`: which data source to use.

# Returns
- `clustering_gaussian`: the 2-tuple `(problem, n)`, where the distance matrix is the
  pairwise Euclidean distance between the columns of an `n x n` matrix drawn with
  `Random.rand`.
- `clustering_mnist` / `clustering_fmnist`: the 3-tuple `(problem, 1000, opt_val)`, where
  the matrix `C` and the value `opt_val` are read from the `"Problem"` entry of the
  corresponding `.mat` file. The `n` argument is ignored for these types.

# Throws
- An `ErrorException` if `clustering_type` is none of the three known values.

Note: the `.mat` paths are relative, so they are resolved against the current working
directory.
"""
function generate_clustering_data(n, s, r, clustering_type :: CLUSTERING_TYPE)
    if clustering_type === clustering_gaussian
        z = Random.rand(n, n)
        # `dims = 2` (columns are the points) is passed explicitly; Distances.jl
        # deprecates calling `pairwise` on matrices without it.
        D = pairwise(Euclidean(), z, z, dims = 2)

        return CLUSTERING(D, s, r), n
    end

    # Both dataset-backed types share the same file layout and differ only in the path.
    data_path = if clustering_type === clustering_mnist
        "src/data/clustering_data_mnist_digits.mat"
    elseif clustering_type === clustering_fmnist
        "src/data/clustering_data_fmnist.mat"
    else
        error("CLUSTERING type not recognized.")
    end

    problem_data = matread(data_path)["Problem"]
    D = problem_data["C"]
    opt_val = problem_data["opt_val"]

    # The problem size reported for the stored datasets is fixed at 1000.
    return CLUSTERING(D, s, r), 1000, opt_val
end

"""
    CLUSTERING(D, s, r)

Clustering problem data together with a reusable work buffer.

# Fields
- `D`: matrix handed in by the caller (used as the cost matrix of the objective).
- `s`, `r`: problem parameters sharing one type `R`; `r` is the number of columns of the
  work buffer and `sqrt(s)` is the norm bound applied in `prox!`.
- `x_mat`: `size(D, 1) x r` scratch matrix, zero-initialised. The evaluation routines
  overwrite it on every call with the matrix form of their input vector, so it must not
  be relied upon between calls.

The field `x_mat` is concretely typed (`Matrix{Float64}`, which is what `zeros` returns)
so that code reading it is type-stable.
"""
struct CLUSTERING{TD, R}
    D :: TD
    s :: R
    r :: R
    x_mat :: Matrix{Float64}
    function CLUSTERING{TD, R}(D :: TD, s :: R, r :: R) where {TD, R}
        return new{TD, R}(D, s, r, zeros(size(D, 1), r))
    end
end

# Convenience constructor that infers the type parameters from the arguments.
CLUSTERING(D :: TD, s :: R, r :: R) where {TD, R} = CLUSTERING{TD, R}(D, s, r)

"""
    A_eval(clustering::CLUSTERING, x)

Evaluate the constraint map at `x`.

`x` is reshaped into the work buffer `clustering.x_mat` (row `i` of the buffer holds
entries `(i - 1) * r + 1 : i * r` of `x`). Entry `i` of the returned vector is the inner
product of row `i` with the sum of all rows, minus one.

Side effect: overwrites `clustering.x_mat`.
"""
function A_eval(clustering :: CLUSTERING, x)
    buffer = clustering.x_mat
    nrows, ncols = size(buffer)

    # Fill the buffer row by row from the flat vector.
    buffer .= reshape(x, ncols, nrows)'

    row_total = sum(buffer, dims = 1)
    residual = buffer * row_total' .- 1.

    return vec(residual)
end

"""
    JA_eval(clustering::CLUSTERING, x)

Return `(JA, Ax)`: the dense `n x (n * r)` Jacobian of the constraint map at `x` and the
constraint value itself (the same vector `A_eval` computes).

Row `i` of `JA` holds the `i`-th length-`r` block of `x` in every column block; the
diagonal block `(i, i)` additionally has the sum of all rows of the matrix form of `x`
added to it.

Side effect: overwrites `clustering.x_mat`.
"""
function JA_eval(clustering :: CLUSTERING, x)
    r = clustering.r
    n = size(clustering.D, 1)

    X = clustering.x_mat
    X .= reshape(x, size(X, 2), size(X, 1))'

    x_sum = sum(X, dims = 1)'
    Ax = vec(X * x_sum .- 1.)

    # Index range of the idx-th length-r block inside the flat vector.
    block(idx) = (idx - 1) * r + 1 : idx * r

    JA = zeros(n, n * r)
    for i in 1:n
        row_i = x[block(i)]'

        # Every column block of row i receives the i-th block of x ...
        for j in 1:n
            JA[i, block(j)] = row_i
        end

        # ... and the diagonal block also receives the row sum.
        JA[i, block(i)] = row_i .+ x_sum'
    end

    return JA, Ax
end

# Objective value at `x`: with X the matrix form of `x` (stored in the work buffer
# `f.x_mat`), returns the sum of the entries of `X .* (D * X)`.
# Side effect: overwrites `f.x_mat`.
function (f :: CLUSTERING)(x)
    X = f.x_mat
    X .= reshape(x, size(X, 2), size(X, 1))'

    weighted = X .* (f.D * X)

    return sum(vec(weighted))
end

# Write the gradient of the objective at `x` into `y` and return the objective value.
#
# With X the matrix form of `x` (stored in the work buffer `f.x_mat`), the objective is
# the sum of the entries of `X .* (D * X)` and the gradient written here is `2 * D * X`,
# flattened in the same row-by-row layout as `x`. This expression is the exact gradient
# only when `D` is symmetric.
#
# The product `D * X` is computed once and shared between the gradient and the value.
# Side effect: overwrites `f.x_mat`.
function ProximalOperators.:gradient!(y, f :: CLUSTERING, x)
    x_mat = f.x_mat
    x_mat .= reshape(x, size(x_mat)[2], size(x_mat)[1])'

    DX = f.D * x_mat

    # Update gradient value and evaluate f
    y .= 2 .* vec(DX')

    return sum(vec(x_mat .* DX))
end

# function ProximalOperators.:gradient!(y, f :: CLUSTERING, x)
    # D, s, r = f.D, f.s, f.r
    # n, _ = size(D)

#     x_mat = f.x_mat
#     x_mat .= reshape(x, size(x_mat)[2], size(x_mat)[1])'

#     # Update gradient value and evaluate f
#     fx = 0.
#     for i = 1:n
#         grad_i = zeros(r)
#         for j = 1:n
#             # Grad
#             grad_i .+= D[i, j] .* x[(j - 1) * r + 1 : (j - 1) * r + r]
#             grad_i .+= D[j, i] .* x[(j - 1) * r + 1 : (j - 1) * r + r]

#             # Func
#             fx += D[i, j] * LA.dot(x[(i - 1) * r + 1 : (i - 1) * r + r], x[(j - 1) * r + 1 : (j - 1) * r + r])
#         end

#         y[(i - 1) * r + 1 : (i - 1) * r + r] .= grad_i
#     end

#     return sum(vec(x_mat .* (D * x_mat)))

#     return fx
# end

# Proximal step for the constraint set handled by `g`: clip `x` at zero from below, then
# shrink the result so that its norm does not exceed `sqrt(g.s)`. The result is written
# into `y`; the step size `γ` is not used. Always returns `0.`.
function ProximalOperators.:prox!(y, g :: CLUSTERING, x, γ)
    # Keep only the non-negative part.
    y .= max.(x, 0.)

    # Scale down only when the norm exceeds the radius (the factor is capped at 1).
    radius = sqrt(g.s)
    shrink = min(radius / LA.norm(y), 1.)
    y .*= shrink

    return 0.
end

"""
    CLUSTERINGSolver(; problem, x0, y0, σ0, initial_inner_tol, λ, β, kwargs...)

Keyword-constructed bundle of a `CLUSTERING` problem and solver settings.

# Fields
- `problem`: the `CLUSTERING` instance.
- `f`, `g`: smooth and non-smooth terms; both default to `problem`.
- `A`: constraint map; defaults to `x -> A_eval(problem, x)`.
- `JA`: constraint Jacobian; defaults to `x -> JA_eval(problem, x)`.
- `x0`, `y0`: starting values supplied by the caller.
- `σ0`, `initial_inner_tol`, `λ`: scalar settings sharing the type `R`.
- `β`: callable evaluated as `β(k)` by the solver routines.
"""
Base.@kwdef struct CLUSTERINGSolver{Tf, Tg, TA, TJA, Tx, Ty, R, Tβ}
    problem::CLUSTERING
    f::Tf = problem
    g::Tg = problem
    A::TA = x -> A_eval(problem, x)
    JA::TJA = x -> JA_eval(problem, x)
    x0::Tx
    y0::Ty
    σ0::R
    initial_inner_tol::R
    λ::R
    β::Tβ
end

"""
    ProximalOperatorsCLUSTERING(solver, state, k)

Callable wrapper that groups a `CLUSTERINGSolver`, an `ALMState` and an iteration index
`k`. Calling it evaluates `al` at `state.y`, and `gradient!` on it forwards to
`al_gradx!`.
"""
Base.@kwdef mutable struct ProximalOperatorsCLUSTERING
    solver::CLUSTERINGSolver
    state::ALMState
    k
end

# Evaluate `al` at `x`, using the multiplier `f.state.y` and iteration index `f.k`.
function (f::ProximalOperatorsCLUSTERING)(x)
    state = f.state
    return al(f.solver, state, x, state.y, f.k)
end

# Gradient of the wrapped function with respect to `x`: forwards to `al_gradx!` with the
# multiplier `state.y` and the iteration index stored in the wrapper. The gradient is
# written into `grad`; the value returned by `al_gradx!` is passed through.
function ProximalOperators.:gradient!(grad, clustering :: ProximalOperatorsCLUSTERING, x)
    state = clustering.state
    return al_gradx!(grad, clustering.solver, state, x, state.y, clustering.k)
end

"""
    al_gradx!(grad, solver::CLUSTERINGSolver, state::ALMState, x, y, k)

Write into `grad` the gradient with respect to `x` of

    f(x) + dot(y, A(x)) + β(k) / (q + 1) * norm(A(x))^(q + 1)

and return that value, where `f` is `solver.f`, `A(x)` is the constraint value computed
from the matrix form of `x`, `β` is `solver.β` and `q` is `state.q`.

The Jacobian-transpose products are formed directly from the matrix form of `x` instead
of building the dense Jacobian. The penalty part of the gradient is skipped when
`norm(A(x))` is exactly zero.

Side effect: overwrites `solver.problem.x_mat`.
"""
function al_gradx!(grad, solver :: CLUSTERINGSolver, state :: ALMState, x, y, k)
    problem = solver.problem
    n = size(problem.D, 1)

    X = problem.x_mat
    X .= reshape(x, size(X, 2), size(X, 1))'

    x_sum = sum(X, dims = 1)'
    Ax = vec(X * x_sum .- 1.)

    grad_fx, fx = ProximalOperators.gradient(solver.f, x)

    # Jacobian-transpose applied to y and to Ax, flattened row by row like x.
    JATy_vec = vec((repeat(y' * X, n) .+ (y .* x_sum'))')
    JATAx_vec = vec((repeat(Ax' * X, n) .+ (Ax .* x_sum'))')

    Ax_norm = LA.norm(Ax)

    if Ax_norm != 0.
        grad[:] .= grad_fx + JATy_vec + solver.β(k) * JATAx_vec * Ax_norm^(state.q - 1)
    else
        # No penalty contribution when the constraint value is exactly zero.
        grad[:] .= grad_fx + JATy_vec
    end

    return fx + LA.dot(y, Ax) + solver.β(k) / (state.q + 1) * Ax_norm^(state.q + 1)
end

# AdaProx hook: return the value of `f` at `x` together with a zero-argument closure that
# yields the gradient computed by `gradient!` at the same point.
function AdaProx.eval_with_pullback(f :: ProximalOperatorsCLUSTERING, x)
    grad = copy(x)
    value = ProximalOperators.gradient!(grad, f, x)

    pullback() = grad

    return value, pullback
end

# Build the inner solver object from the settings in `state`.
# An explicit `gamma`, taken from `state.γ_rule(solver.β(k), state.y)`, is passed only
# when adaptive step sizes are disabled.
function _clustering_inner_solver(solver :: CLUSTERINGSolver, state :: ALMState, k)
    # Stop once the residual scaled by the step size drops below the inner tolerance.
    stop = (inner_iter, inner_state) ->
        LA.norm(inner_state.res, Inf) / inner_state.gamma < state.inner_tol

    common = (
        tol = state.inner_tol,
        maxit = round(state.inner_max_it),
        minimum_gamma = state.minimum_gamma,
        stop = stop,
        adaptive = state.adaptive_gamma,
    )

    if state.adaptive_gamma
        return state.inner_solver(; common...)
    end

    return state.inner_solver(; common..., gamma = state.γ_rule(solver.β(k), state.y))
end

"""
    solve_primal!(solver::CLUSTERINGSolver, state::ALMState, k)

Run the primal update for outer iteration `k`, overwrite `state.x` with the result and
return the number of inner iterations spent.

With `state.triple_loop` set, up to `state.ippm_max_it` regularised sub-problems
(`ProximalOperators.Regularize` around `state.x` with weight `2ρ`) are solved, using
either `AdaProx.adaptive_proxgrad` or the solver built from `state.inner_solver`;
otherwise a single inner solve is performed on the un-regularised function.

When `state.verbose` is set a summary is printed; when `state.logging` is set the
objective, the constraint-violation norm and the cumulative iteration count are stored
in `state.hist` at index `k + 1` (and at index 1 before the first update).

Changes relative to the previous version:
- the inner loop no longer reuses the name `k`, so `solver.β(k)` inside the loop refers
  to the outer iteration index (as it does everywhere else in this function);
- in triple-loop mode with an inner solver other than `AdaProx.adaptive_proxgrad`, the
  solver that was built is now actually called, instead of failing on an undefined
  `x_ffb`.
"""
function solve_primal!(solver :: CLUSTERINGSolver, state :: ALMState, k)
    clustering = ProximalOperatorsCLUSTERING(solver, state, k)

    if state.logging && k == 1
        state.hist.objective[1] = solver.f(state.x)
        state.hist.feasibility[1] = LA.norm(solver.A(state.x))
        state.hist.nit[1] = 0
    end

    if state.triple_loop
        use_adaprox = state.inner_solver == AdaProx.adaptive_proxgrad
        ffb = use_adaprox ? nothing : _clustering_inner_solver(solver, state, k)

        x_temp = copy(state.x)
        # Kept as a Float64 so the returned count has the same type as before.
        iters = 0.
        ρ = state.ρ_rule(solver.β(k), state.y)

        # NOTE(review): `state.x` (the regularisation centre and the adaptive_proxgrad
        # starting point) is not updated inside this loop, so each pass appears to solve
        # the same sub-problem - confirm whether `x_temp` was meant to be used instead.
        for _ in 1:state.ippm_max_it
            F = ProximalOperators.Regularize(
                clustering,
                2 * ρ,
                state.x
            )

            if use_adaprox
                x_ffb, iters_ffb = state.inner_solver(
                    state.x,
                    f = F,
                    g = solver.g,
                    rule = AdaProx.OurRule(gamma = state.γ_rule(solver.β(k), state.y)),
                    tol = state.inner_tol,
                    maxit = round(state.inner_max_it),
                )
            else
                x_ffb, iters_ffb = ffb(x0 = x_temp, f = F, g = solver.g)
            end

            copyto!(x_temp, x_ffb)
            iters += iters_ffb

            if 2 * ρ * LA.norm(x_temp - state.x, Inf) <= 2 * state.inner_tol
                break
            end
        end

        copyto!(state.x, x_temp)
    else
        ffb = _clustering_inner_solver(solver, state, k)

        x_temp, iters = ffb(x0 = state.x, f = clustering, g = solver.g)
        copyto!(state.x, x_temp)
    end

    if state.verbose
        println("-----------------\tk=$(k)")
        println("Inner iterations: \t$(iters)")
        println("Objective: \t\t$(solver.f(state.x))")
        println("Constraint violation: \t$(LA.norm(solver.A(state.x)))")
        println()
    end

    if state.logging
        state.hist.objective[k+1] = solver.f(state.x)
        state.hist.feasibility[k+1] = LA.norm(solver.A(state.x))
        state.hist.nit[k+1] = state.nit + iters
    end

    return iters
end

export generate_clustering_data, CLUSTERING, CLUSTERINGSolver