using MKL
using TimerOutputs;
using Random
using Dates 

include("functions.jl")
include("algorithms.jl")

# Run the dense linear algebra on 8 BLAS threads.
# NOTE(review): `BLAS` is not brought into scope by the `using` lines at the top of
# this file; presumably functions.jl does `using LinearAlgebra` — confirm.
BLAS.set_num_threads(8);

# ---- Problem size -----------------------------------------------------------
const q = 4                 # number of qubits
const d = 2 ^ q             # Hilbert-space dimension
const Q = 4                 # per-qubit parameter passed to pauli_povm (M = Q^q below)
                            # NOTE(review): an earlier comment described this as
                            # "two outcomes", which does not match the value 4 — confirm.
const M = Q ^ q             # number of observables
const N = M * 100           # sample size (100 × number of observables)

# ---- Iteration budgets ------------------------------------------------------
const numEpochs      = 200  # epochs for StochasticQuantumSoftBayes
const numEpochsBatch = 200  # epochs for each BatchMethods run

# ---- Ground-truth state -----------------------------------------------------
const ρ_true = w_state(q)   # true density matrix

# Alternative ground truth: a random density matrix. To use it, replace the
# `ρ_true` definition above with the lines below (requires `tr` from LinearAlgebra).
# A = randn(ComplexF64, d, d)
# A = A * A'
# const ρ_true = A / tr(A)

# Name of the log file that the algorithm drivers write to, taken from the launch
# time. The file itself is opened further down, once setup has succeeded.
const filename = Dates.format(now(), "yyyy-mm-dd-HH-MM-SS")

const to = TimerOutput();
reset_timer!(to)

# Build the measurement setting and the data set once, timed under "Setup":
# N observable indices are drawn uniformly from 1:M, then `measure` and
# `generate_data` (defined in the included files) turn them into `data`.
# NOTE(review): no RNG seed is set, so idx_obs (and presumably the simulated
# outcomes) differ from run to run; `Random` is imported but not used here.
@timeit to "Setup" begin
    const POVM = pauli_povm(q, Q)
    const idx_obs = rand(1: M, N)
    const outcomes = measure(ρ_true, POVM, idx_obs)
    const data = generate_data(POVM, idx_obs, outcomes)
end

# NOTE(review): POVM is declared `const` above, so this reassignment would likely
# fail with a constant-redefinition error if uncommented.
# POVM = 0

# Single-argument methods that bind the global `data` for the gradient and for
# compute_λ, plus short-hand entry points that fill in this script's settings
# for each algorithm driver. `io` is looked up only when these are called.
# NOTE(review): the literal `1` passed to BatchMethods and the `N ÷ 10` passed to
# StochasticQuantumSoftBayes are interpreted by algorithms.jl — presumably a step
# parameter and a mini-batch size; confirm there.
∇f(λ) = ∇f(data, λ)
compute_λ(ρ) = compute_λ(data, ρ)
BatchMethods(alg) = BatchMethods(alg, numEpochsBatch, 1, io, ρ_true, N, f, ∇f, compute_λ)
StochasticQuantumSoftBayes() = StochasticQuantumSoftBayes(numEpochs, N ÷ 10, io, ρ_true, data, f, compute_λ)

# Open the log only after setup has succeeded, immediately before the guarded
# region. A failure during setup then leaves neither an unclosed handle nor an
# empty log file. (The `const` setup cannot live inside `try`, which would make
# it a local scope.)
const io = open(filename, "a")

try
    global SQSB = StochasticQuantumSoftBayes()
    # global QEM  = BatchMethods("QEM")
    global RRR  = BatchMethods("RRR")
    global FW   = BatchMethods("FrankWolfe")
finally
    # Always release the log file, even if one of the runs throws.
    close(io)
end

#= 
figure(1)
plot(SQSB["numEpochs"], SQSB["fidelity"])
hold 
plot(FW["numEpochs"], FW["fidelity"])
plot(RRR["numEpochs"], RRR["fidelity"])
legend(["SQSB", "FW", "RrhoR"])
xlabel("Number of Epochs")
ylabel("Fidelity")
xlim([0, numEpochs])
ylim([0, 1])
grid("on")

figure(2)
approxOpt = minimum(vcat(SQSB["fval"], RRR["fval"], FW["fval"]))
semilogy(SQSB["numEpochs"], SQSB["fval"] .- approxOpt)
hold
semilogy(FW["numEpochs"], FW["fval"] .- approxOpt)
semilogy(RRR["numEpochs"], RRR["fval"] .- approxOpt)
legend(["SQSB", "FW", "RrhoR"])
xlim([0, numEpochs])
ylim([1e-5, 1e-1])
xlabel("Number of epochs")
ylabel("Approximate optimization error")
grid("on")

figure(3)
semilogx(SQSB["elapsedTime"], SQSB["fidelity"])
hold 
semilogx(RRR["elapsedTime"], RRR["fidelity"])
semilogx(FW["elapsedTime"], FW["fidelity"])
legend(["SQSB", "RrhoR", "FW"])
xlabel("Elapsed time (seconds)")
ylabel("Fidelity")
ylim([0, 1])
grid("on")

figure(4)
loglog(SQSB["elapsedTime"], SQSB["fval"] .- approxOpt)
hold
loglog(RRR["elapsedTime"], RRR["fval"] .- approxOpt)
loglog(FW["elapsedTime"], FW["fval"] .- approxOpt)
legend(["SQSB", "RrhoR", "FW"])
xlabel("Elapsed time (seconds)")
ylabel("Approximate optimization error")
grid("on") 
=#
