# using Random, Distributions
# using Statistics
using Arpack
using TimerOutputs
using ExponentialUtilities

"""
    StochasticQuantumSoftBayes(numEpochs, period, io, ρ_true, data, f, compute_λ)

Run stochastic Q-Soft-Bayes with an anytime online-to-batch average of the iterates.

Each round draws one slice `M = data[:, :, k]`, with `k` sampled uniformly with
replacement from `1:size(data, 3)`, and applies

    ρ ← exp(log ρ + log((1 - η) I + (η / Re⟨M, ρ_hat⟩) M)),   ρ ← ρ / tr(ρ),

after which `ρ_hat` is updated to the running mean of all iterates. All metrics
are evaluated on `ρ_hat`.

# Arguments
- `numEpochs`: number of passes; the number of rounds is `numEpochs * size(data, 3)`.
- `period`: metrics are recorded and printed at round 1 and at every multiple of
  `period`; must be at least 1.
- `io`: stream receiving one tab-separated line per recorded round (also echoed
  to stdout).
- `ρ_true`: `d × d` reference matrix passed to `fidelity`.
- `data`: `d × d × T` array; each slice is used as `M` in the update.
- `f`: objective, evaluated as `f(compute_λ(ρ_hat))`.
- `compute_λ`: maps a `d × d` matrix to the argument of `f`.

# Returns
A `Dict` with `Float64` vectors `"numEpochs"` (round / T), `"fidelity"`, `"fval"`
and `"elapsedTime"` (accumulated timer value × 1e-9; 0 for the first record),
with exactly one entry per recorded round.

# Throws
- `ArgumentError` if `period < 1`.
"""
function StochasticQuantumSoftBayes(
    numEpochs::Int64, 
    period::Int64, 
    io::IOStream, 
    ρ_true::Array{ComplexF64, 2}, 
    data::Array{ComplexF64, 3}, 
    f::Function, 
    compute_λ::Function
    )

    # A non-positive period would divide by zero or size the output arrays
    # incorrectly (and the loop below writes them under @inbounds).
    period >= 1 || throw(ArgumentError("period must be at least 1, got $period"))

    T::Int64         = size(data, 3)
    numRounds::Int64 = numEpochs * T

    # Metrics are recorded at t == 1 and at every multiple of `period`. When
    # period == 1, t == 1 is itself such a multiple and must not be counted
    # twice (doing so left an all-zero trailing entry in every series).
    lengthOutput = numRounds ÷ period + ((period > 1 && numRounds > 0) ? 1 : 0)

    output = Dict()
    output["numEpochs"]   = zeros(Float64, lengthOutput)
    output["fidelity"]    = zeros(Float64, lengthOutput)
    output["fval"]        = zeros(Float64, lengthOutput)
    output["elapsedTime"] = zeros(Float64, lengthOutput)

    d::Int64 = size(ρ_true, 1)
    ρ     = Matrix{ComplexF64}(I, d, d) / d    # start at the maximally mixed state
    ρ_hat = Matrix{ComplexF64}(I, d, d) / d    # running average of the iterates

    iter::Int64 = 1    # next free slot in the output vectors

    println("Stochastic Q-Soft-Bayes starts.")

    to = TimerOutput()

    # Setup (sampling indices, step size) is charged to the same timer as the
    # iterations, so it is included in the reported elapsed time.
    @timeit to "iteration" begin
        idx = rand(1:T, numRounds)
        η::Float64 = sqrt(log(d) / numRounds / d)
        η = η / (1.0 + η)
        σ = (1.0 - η) * Matrix{ComplexF64}(I, d, d)
    end

    @inbounds for t = 1:numRounds

        if (mod(t, period) == 0) || (t == 1)
            output["numEpochs"][iter] = t
            output["fidelity"][iter]  = fidelity(ρ_true, ρ_hat)
            λ                         = compute_λ(ρ_hat)
            output["fval"][iter]      = f(λ)
            if t > 1
                output["elapsedTime"][iter] = TimerOutputs.time(to["iteration"])
            end

            for out in (io, stdout)
                println(out, t / T, "\t", output["fidelity"][iter], "\t",
                        output["elapsedTime"][iter], "\t", output["fval"][iter])
            end
            flush(io)
            iter += 1
        end

        @timeit to "iteration" begin
            M = view(data, :, :, idx[t])
            # The step is normalized by the averaged iterate ρ_hat rather than
            # by ρ (anytime online-to-batch conversion).
            ρ = exp(log(ρ) + log(σ + (η / real(M ⋅ ρ_hat)) * M))
            ρ /= tr(ρ)

            # Incremental mean of the iterates seen so far.
            ρ_hat = (t * ρ_hat + ρ) / (t + 1.0)
        end
    end

    output["numEpochs"] /= T
    output["elapsedTime"] *= 1e-9

    return output
end

"""
    BatchMethods(alg, numEpochs, period, io, ρ_true, N, f, ∇f, compute_λ)

Run a full-gradient method for `numEpochs` iterations, starting from the
maximally mixed state.

Supported values of `alg` (with `G = ∇f(λ)` at the current iterate):
- `"RRR"`: `ρ ← G ρ G`, then normalized to unit trace.
- `"QEM"`: `ρ ← exp(log ρ + log(-G))`, then normalized to unit trace.
- `"FrankWolfe"`: `ρ ← ρ + η (v v' - ρ)` with `η = 2 / (t + 1)` and `v` the
  eigenvector of `-G` whose eigenvalue has the largest real part (via
  `Arpack.eigs`); the step is kept only if it does not increase `f`.

Metrics are recorded every iteration (before the update) and printed to `io`
and stdout at iteration 1 and at every multiple of `period`.

# Arguments
- `alg`: one of the algorithm names above.
- `numEpochs`: number of iterations.
- `period`: printing period.
- `io`: stream receiving one tab-separated line per printed iteration.
- `ρ_true`: `d × d` reference matrix passed to `fidelity`.
- `N`: width of the initial placeholder `λ = zeros(1, N)`; it is replaced by
  `compute_λ(ρ)` at the first iteration before being used.
- `f`, `∇f`: objective and gradient, both applied to the output of `compute_λ`.
- `compute_λ`: maps a `d × d` matrix to the argument of `f` and `∇f`.

# Returns
A `Dict` with `Float64` vectors of length `numEpochs`: `"numEpochs"`,
`"fidelity"`, `"fval"` and `"elapsedTime"` (accumulated timer value × 1e-9
before iteration `t`; 0 for the first entry).

# Throws
- `ArgumentError` if `alg` is not one of the supported names.
"""
function BatchMethods(
    alg::String, 
    numEpochs::Int64, 
    period::Int64, 
    io::IOStream, 
    ρ_true::Array{ComplexF64, 2}, 
    N::Int64, 
    f::Function, 
    ∇f::Function, 
    compute_λ::Function
)
    # Fail fast instead of silently returning a partially filled result.
    if !(alg in ("RRR", "QEM", "FrankWolfe"))
        throw(ArgumentError(
            "unknown algorithm \"$alg\"; expected \"RRR\", \"QEM\" or \"FrankWolfe\""))
    end

    d::Int64 = size(ρ_true, 1)
    ρ = Matrix{ComplexF64}(I, d, d) / d    # start at the maximally mixed state

    output = Dict()
    output["numEpochs"]   = zeros(Float64, numEpochs)
    output["fidelity"]    = zeros(Float64, numEpochs)
    output["fval"]        = zeros(Float64, numEpochs)
    output["elapsedTime"] = zeros(Float64, numEpochs)

    println(alg, " starts.")

    λ = zeros(Float64, 1, N)    # placeholder; replaced by compute_λ(ρ) at t == 1

    # Objective value of the current Frank-Wolfe iterate, carried over to the
    # next iteration so it is not recomputed. Kept as a function local (it was
    # previously a `global`, which leaked into module scope, was type-unstable
    # and not safe for concurrent calls).
    fval_next = NaN

    to = TimerOutput()

    @inbounds for t = 1:numEpochs
        output["numEpochs"][t] = t
        output["fidelity"][t]  = fidelity(ρ_true, ρ)
        if (t == 1) || (alg != "FrankWolfe")
            λ = compute_λ(ρ)
            output["fval"][t] = f(λ)
        else
            # Frank-Wolfe already evaluated f at the current iterate.
            output["fval"][t] = fval_next
        end
        if t > 1
            output["elapsedTime"][t] = TimerOutputs.time(to["iteration"])
        end

        @timeit to "iteration" begin
            grad = ∇f(λ)
        end

        if alg == "RRR"
            @timeit to "iteration" begin
                ρ = grad * ρ * grad
                ρ /= tr(ρ)
            end

        elseif alg == "QEM"
            @timeit to "iteration" begin
                ρ = exp(log(ρ) + log(-grad))
                ρ /= tr(ρ)
            end

        else    # alg == "FrankWolfe" (validated above)
            @timeit to "iteration" begin
                # Linear minimization oracle over density matrices (grad assumed
                # Hermitian): argmin_V tr(grad V) = v v', where v is the
                # eigenvector of -grad for its largest eigenvalue. `:LR` selects
                # that one; `:LM` (largest magnitude) could instead return a
                # large negative eigenvalue.
                _, v   = eigs(-grad, nev = 1, which = :LR)
                V      = v * v'
                η      = 2.0 / (t + 1.0)
                ρ_next = ρ + η * (V - ρ)

                λ_next    = compute_λ(ρ_next)
                fval_next = f(λ_next)
                # Accept the step only if it does not increase the objective.
                if fval_next <= output["fval"][t]
                    λ = λ_next
                    ρ = ρ_next
                else
                    fval_next = output["fval"][t]
                end
            end
        end

        if (mod(t, period) == 0) || (t == 1)
            for out in (io, stdout)
                println(out, t, "\t", output["fidelity"][t], "\t",
                        output["elapsedTime"][t], "\t", output["fval"][t])
            end
            flush(io)
        end
    end

    output["elapsedTime"] .*= 1e-9

    return output
end