using Distributions
using StatsBase
using LinearAlgebra
using Random
using Base

"""
    objective(A, b, x)

Return the mean absolute residual `(1 / n) * ‖(A * x).^2 - b‖₁`, where `n = length(b)`.
"""
function objective(A::Array{Float64, 2}, b::Array{Float64, 1}, x::Array{Float64, 1})
    # Element-wise squared measurements minus their targets.
    residual = (A * x) .^ 2 - b
    scale = 1 / length(b)
    return scale * norm(residual, 1)
end

"""
    sto_gradient(A, b, x)

Return `sign(dot(A, x)^2 - b) * 2 * dot(A, x) * A` for a single measurement vector `A`
and scalar target `b`, i.e. a subgradient of `abs(dot(A, x)^2 - b)` with respect to `x`.
"""
function sto_gradient(A::Array{Float64, 1}, b::Float64, x::Array{Float64, 1})
    inner = dot(A, x)
    # Scalar factor first, then scale the measurement vector once.
    coeff = sign(inner^2 - b) * 2 * inner
    return coeff * A
end

"""
    loop_sgd_monte_carlo(A, b, x0, τ0, β, β2, max_ep, flag, fopt, mc_size, tol)

Call `sgd` `mc_size` times with identical arguments and collect the second element of
each result (the epoch count returned by `sgd`) into a `Vector{Float64}`.
"""
function loop_sgd_monte_carlo(
    A::Array{Float64, 2},
    b::Array{Float64, 1},
    x0::Array{Float64, 1},
    τ0::Float64,
    β::Float64,
    β2::Float64,
    max_ep::Int64,
    flag::Int64,
    fopt::Float64,
    mc_size::Int64,
    tol::Float64,
)
    # One independent run per trial; the typed comprehension stores each value as Float64.
    return Float64[
        sgd(A, b, x0, τ0, β, β2, max_ep, flag, fopt, tol)[2] for _ in 1:mc_size
    ]
end
"""
    loop_sgd_step_size(A, b, x0, τ0, β, β2, max_ep, flag, fopt, mc_size, tol)

Run `loop_sgd_monte_carlo` once for every base step size in the vector `τ0`.

Returns a `length(τ0) × mc_size` matrix whose row `i` holds the Monte Carlo results
obtained with step size `τ0[i]`.
"""
function loop_sgd_step_size(
    A::Array{Float64, 2},
    b::Array{Float64, 1},
    x0::Array{Float64, 1},
    τ0::Array{Float64, 1},
    β::Float64,
    β2::Float64,
    max_ep::Int64,
    flag::Int64,
    fopt::Float64,
    mc_size::Int64,
    tol::Float64,
)
    results = zeros(length(τ0), mc_size)
    for (row, step) in enumerate(τ0)
        results[row, :] = loop_sgd_monte_carlo(
            A, b, x0, step, β, β2, max_ep, flag, fopt, mc_size, tol
        )
    end

    return results
end

"""
    sgd(A, b, x0, τ0, β, β2, max_ep, flag, fopt, tol)

Run up to `max_ep` epochs of a stochastic method on `objective(A, b, x)`, starting from
`x0`, and record the objective value after every epoch.

# Arguments
- `A`, `b`: data (`n × d` matrix and length-`n` vector) used both for the objective and,
  explicitly, for every stochastic gradient step.
- `x0`: initial iterate.
- `τ0`: base step size; inner step number `cnt` uses `τ0 / sqrt(cnt + 1)`.
- `β`: momentum weight when `flag == 2`; first-moment decay when `flag == 3`.
- `β2`: second-moment decay when `flag == 3`.
- `max_ep`: maximum number of epochs; each epoch performs `n` stochastic steps.
- `flag`: `1` = SGD, `2` = momentum SGD, `3` = AMSGrad. Any other value performs no
  update, so the recorded objective simply repeats its initial value.
- `fopt`, `tol`: stop after the first epoch with `abs(objective - fopt) <= tol`.

# Returns
`(obj_vec, last_ep)`: `obj_vec` holds the initial objective followed by one value per
completed epoch; `last_ep` is the epoch at which the tolerance was met, or `max_ep` if it
never was.
"""
function sgd(
    A::Array{Float64, 2},
    b::Array{Float64, 1},
    x0::Array{Float64, 1},
    τ0::Float64,
    β::Float64,
    β2::Float64,
    max_ep::Int64,
    flag::Int64,
    fopt::Float64,
    tol::Float64,
)
    n, d = size(A)
    # The updates below never mutate in place, so sharing `x0` here is safe.
    x = x0
    x_avg = x
    # NOTE(review): the momentum buffer starts at `x0` rather than at zero — confirm
    # this initialisation is intended.
    z = x0
    m, v = zeros(d), zeros(d)
    vhat = 1e-11 * ones(d)
    obj_vec = [objective(A, b, x)]
    last_ep = max_ep
    for k in 1:max_ep
        # Only the length of `indices` is consumed by the epoch helpers (they redraw a
        # row per step); the draw is kept so the random stream matches earlier runs.
        indices = rand(1:n, n)
        if flag == 1 # SGD
            x, x_avg, τ0 = sgd_epoch(x, x_avg, τ0, indices, k, A, b)
        elseif flag == 2 # momentum SGD
            x, z, τ0 = momentum_sgd_epoch(x, x_avg, τ0, β, indices, k, z, A, b)
        elseif flag == 3 # AMSGrad
            x, m, v, vhat, τ0 = amsgrad_epoch(
                x, x_avg, τ0, β, β2, indices, k, m, v, vhat, A, b
            )
        end
        obj_val = objective(A, b, x)
        push!(obj_vec, obj_val)
        if abs(obj_val - fopt) <= tol
            last_ep = k
            break
        end
    end

    return obj_vec, last_ep
end


"""
    sgd_epoch(x, x_avg, τ0, indices, k, A, b)

Perform `length(indices)` plain SGD steps for epoch number `k` and return
`(x, x_avg, τ0)`, where `x_avg` is the running average of the iterates.

Only the length of `indices` is used: it sets both the number of steps and the range
`1:n` from which a fresh row index is drawn at every step.

`A` and `b` should always be passed explicitly. When omitted they fall back to variables
named `A` and `b` in the enclosing (global) scope, which is what this helper read
before it took the data as arguments; the fallback exists only for old call sites.
"""
function sgd_epoch(x, x_avg, τ0, indices, k, A=A, b=b)
    n = length(indices)
    for t in 1:n
        ind = rand(1:n, 1)[1]
        # Global step counter across epochs; drives the step-size decay and the average.
        cnt = (k - 1) * n + t
        τ = τ0 / sqrt(cnt + 1)
        x = sgd_iter(x, τ, A[ind, :], b[ind])
        x_avg = 1 / cnt * x + (1 - 1 / cnt) * x_avg
    end

    return x, x_avg, τ0
end

"""
    sgd_iter(x, τ, Ai, bi)

Return `x - τ * sto_gradient(Ai, bi, x)`, one stochastic subgradient step.
"""
function sgd_iter(x, τ, Ai, bi)
    return x - τ * sto_gradient(Ai, bi, x)
end

"""
    momentum_sgd_epoch(x, x_avg, τ0, β, indices, k, z, A, b)

Perform `length(indices)` momentum-SGD steps for epoch number `k` and return
`(x, z, τ0)`, where `z` is the updated momentum buffer.

`x_avg` is accepted for signature compatibility and is not used. Only the length of
`indices` is used; a fresh row index is drawn at every step.

`A` and `b` should always be passed explicitly. When omitted they fall back to variables
named `A` and `b` in the enclosing (global) scope, kept only for old call sites.
"""
function momentum_sgd_epoch(x, x_avg, τ0, β, indices, k, z, A=A, b=b)
    n = length(indices)
    for t in 1:n
        ind = rand(1:n, 1)[1]
        cnt = (k - 1) * n + t
        τ = τ0 / sqrt(cnt + 1)
        x, z = momentum_sgd_iter(x, τ, β, A[ind, :], b[ind], z)
    end

    return x, z, τ0
end


"""
    momentum_sgd_iter(x, τ, β, Ai, bi, z)

Step along the current momentum buffer (`x - τ * z`), then refresh the buffer as
`β * g + (1 - β) * z` with `g` the stochastic gradient at the new iterate.
Returns `(x, z)`.
"""
function momentum_sgd_iter(x, τ, β, Ai, bi, z)
    x = x - τ * z
    z = β * sto_gradient(Ai, bi, x) + (1 - β) * z

    return x, z
end


"""
    amsgrad_epoch(x, x_avg, τ0, β1, β2, indices, k, m, v, vhat, A, b)

Perform `length(indices)` AMSGrad steps for epoch number `k` and return
`(x, m, v, vhat, τ0)`.

On the very first step overall (`cnt == 1`) `τ0` is rescaled once by
`sqrt(maximum(β2 * g .^ 2))`, with `g` the first stochastic gradient, so that AMSGrad's
effective step size is comparable to the other methods; the rescaled `τ0` is returned so
the caller can carry it into later epochs.

`x_avg` is accepted for signature compatibility and is not used. Only the length of
`indices` is used; a fresh row index is drawn at every step.

`A` and `b` should always be passed explicitly. When omitted they fall back to variables
named `A` and `b` in the enclosing (global) scope, kept only for old call sites.
"""
function amsgrad_epoch(x, x_avg, τ0, β1, β2, indices, k, m, v, vhat, A=A, b=b)
    n = length(indices)
    for t in 1:n
        ind = rand(1:n, 1)[1]
        cnt = (k - 1) * n + t
        Ai = A[ind, :]
        bi = b[ind]
        if cnt == 1
            first_grad = sto_gradient(Ai, bi, x)
            τ0 = τ0 * sqrt(maximum(β2 * first_grad .^ 2))
        end
        τ = τ0 / sqrt(cnt + 1)
        x, m, v, vhat = amsgrad_iter(x, τ, β1, β2, Ai, bi, m, v, vhat)
    end

    return x, m, v, vhat, τ0
end


"""
    amsgrad_iter(x, τ, β1, β2, Ai, bi, m, v, vhat)

One AMSGrad update: exponential averages `m` (gradient) and `v` (squared gradient) are
refreshed, `vhat` keeps the element-wise running maximum of `v`, and `x` moves by
`τ * m ./ sqrt.(vhat)`. Returns `(x, m, v, vhat)`.
"""
function amsgrad_iter(x, τ, β1, β2, Ai, bi, m, v, vhat)
    g = sto_gradient(Ai, bi, x)
    m = β1 * m + (1 - β1) * g
    v = β2 * v + (1 - β2) * g .^ 2
    vhat = max.(v, vhat)
    x = x .- (τ ./ sqrt.(vhat)) .* m
    return x, m, v, vhat
end
