using Pkg
# Activate the fw-rde project *before* instantiating it: `Pkg.instantiate()` acts on
# the currently active environment, so calling it first would install/resolve the
# startup environment instead of this project's dependencies.
# NOTE: the project path is resolved relative to the current working directory.
Pkg.activate("SAIF/FrankWolfe.jl/fw-rde")
Pkg.instantiate()
using Statistics
using LinearAlgebra


# Read one line from stdin for the sub directory name. `keep=true` keeps the line
# terminator, so an empty result can only mean that stdin is closed (an empty line
# reads as "\n"); in that case fail loudly instead of silently continuing with "".
function read_subdir_from_stdin()
    line = readline(; keep=true)
    isempty(line) && error("No sub directory to run RDE in was given and stdin is closed.")
    return String(chomp(line))
end

# parse command line arguments if given
if length(ARGS) > 0
    subdir = ARGS[1]
# otherwise prompt user to specify
else
    print("Please enter sub directory to run RDE in: ")
    subdir = read_subdir_from_stdin()
end
# Input validation: the sub directory must exist next to this script. An empty name
# is rejected explicitly because `joinpath(@__DIR__, "")` is the script directory
# itself, which would pass the `isdir` check even though no experiment was named.
while isempty(subdir) || !isdir(joinpath(@__DIR__, subdir))
    print("Invalid directory $subdir. Please enter sub directory to run RDE in:")
    global subdir = read_subdir_from_stdin()
end

using PyCall
# Put the experiment directory at the front of Python's module search path so that
# its Python modules (such as `rde`, imported below) are found first.
let python_search_path = PyVector(pyimport("sys")["path"])
    experiment_dir = joinpath(@__DIR__, subdir)
    pushfirst!(python_search_path, experiment_dir)
end

import FrankWolfe
import FrankWolfe: LpNormLMO
# NOTE(review): "oralces" looks like a typo, but the string must match the file on
# disk; presumably this file defines NonNegKSparseLMO used below — confirm.
include("custom_oralces.jl")
include(joinpath(@__DIR__, subdir, "config.jl"))  # load indices, rates, max_iter
# Change into the experiment directory resolved against this script's location,
# consistent with the `isdir` validation and the `sys.path` entry above. A bare
# `cd(subdir)` would resolve against the caller's working directory and fail (or
# pick the wrong directory) when the script is launched from elsewhere.
cd(joinpath(@__DIR__, subdir))

# Get the Python side of RDE
rde = pyimport("rde")

# Run identifiers used only to build the log file name. `rate` is `rates` verbatim,
# so the name is most readable when config.jl configures a single rate.
rate = rates
idx0, idxl = minimum(indices), maximum(indices)
# Relative to the working directory, which is the experiment directory at this point.
logfname = "../logs/$subdir-d-$d-rate-$rate-idx-$idx0:$idxl-max_iter-$max_iter-mode-$mode-$suffix-$optim.txt"

# Accumulators filled by the sample loop and summarised after it.
success = 0         # number of successful (sample, rate) runs
l1 = Float64[]      # `iter_norm` of each successful run
l2 = Float64[]      # `l2_norm` of each successful run (not collected for CIFAR-10)
l_inf = Float64[]   # `l_inf_norm` of each successful run
s_vals = Float64[]  # `s_norm` from rde.get_s_sum for each successful run
p_l0 = Float64[]    # l0 norm of p per run (univariate mode only)
l0_flat = Float64[] # l0 with the channels flattened

fn = 0 # NOTE(review): not referenced anywhere else in this script

# CIFAR-10's Python helpers return extra values (true class label, `f_all`,
# perturbed label) that the other experiments do not. This cannot change between
# samples, so decide it once.
cifar = subdir == "cifar10"

# Return `v` as a `Vector{T}`, converting only when it is not one already.
to_vector_of(::Type{T}, v) where {T} = v isa Vector{T} ? v : convert(Vector{T}, v)

# Run the Frank-Wolfe attack for every configured sample, once per rate.
for idx in indices
    # Load data sample and distortion functional
    if cifar
        x, fname, true_cl = rde.get_data_sample(idx)
        fname = "$true_cl/$fname"
    else
        x, fname = rde.get_data_sample(idx)
    end

    if cifar
        f, df_s, df_p, node, target_node, f_all = rde.get_distortion(x, mode=mode, optim=optim)
    else
        f, df_s, df_p, node, target_node = rde.get_distortion(x, mode=mode, optim=optim)
    end

    if save
        rde.store_single_result(x, "xb4", fname, 0, d, test_name)
    end

    # Wrap objective and gradient functions: arguments coming from FrankWolfe are
    # converted to `Vector{eltype(x)}` before being passed to the Python callbacks;
    # the gradients are copied into FrankWolfe's `storage` buffer.
    xT = eltype(x)
    func(s, p) = f(to_vector_of(xT, s), to_vector_of(xT, p))
    grad_s!(storage, s, p) = (storage .= df_s(to_vector_of(xT, s), to_vector_of(xT, p)))
    grad_p!(storage, s, p) = (storage .= df_p(to_vector_of(xT, s), to_vector_of(xT, p)))

    for rate in rates
        print("\nrate: ", rate, "\n")
        # Run FrankWolfe
        println("Running sample $idx with rate $rate")

        # Both variables start at the origin (one entry per element of x).
        # lmo_s: presumably a non-negative k-sparse feasible set, judging by the name
        # NonNegKSparseLMO — confirm in custom_oralces.jl.
        s0 = zeros(xT, length(x))
        lmo_s = NonNegKSparseLMO(rate, 1.0)

        # lmo_p: l_inf-norm ball of radius d.
        p0 = zeros(xT, length(x))
        lmo_p = LpNormLMO{Float64,Inf}(d)

        @time s, v_s, p, v_p, primal, dual_gap_s, dual_gap_p = FrankWolfe.frank_wolfe_2var(
            func,
            grad_s!,
            grad_p!,
            lmo_s,
            lmo_p,
            s0,
            p0;
            fw_arguments...,
        )
        # reset adaptive step size if necessary
        if fw_arguments.line_search isa FrankWolfe.MonotonousNonConvexStepSize
            fw_arguments.line_search.factor = 0
        end

        if cifar
            iter_success, iter_norm, pert_lab, l_inf_norm = rde.get_model_prediction(x, s, p, node, target_node, mode, optim, f_all(s,p), logfname, fname)
        else
            iter_success, iter_norm, l2_norm, l_inf_norm = rde.get_model_prediction(x, s, p, node, target_node, mode, optim, logfname, fname)
        end

        if optim == "univariate"
            push!(p_l0, norm(p, 0))
        end

        s_norm, s_flat = rde.get_s_sum(s)

        # Norm statistics are only collected for successful runs.
        global success += iter_success
        if iter_success == 1
            push!(l1, iter_norm)
            push!(l_inf, l_inf_norm)
            if !cifar
                push!(l2, l2_norm)
            end
            push!(s_vals, s_norm)
            push!(l0_flat, s_flat)
        end

        fname_tar = "$fname-$target_node"

        if save
            rde.store_single_result(p, "p", fname_tar, rate, d, test_name)
            if cifar
                rde.store_pert_img(x, s, p, fname_tar, rate, d, test_name, optim, true_cl, node, target_node, pert_lab)
            else
                rde.store_pert_img(x, s, p, fname_tar, rate, d, test_name, optim)
            end
        end
    end
end

# Compute stats
# Every sample is attacked once per rate and `success` counts successful
# (sample, rate) runs, so normalise by the total number of runs. Dividing by
# `length(indices)` alone reports an accuracy above 1 when several rates are
# configured (a single scalar rate has length 1, giving the same result as before).
success = success / (length(indices) * length(rates))

# Norm statistics are over successful runs only, normalised by `n_pixels`
# (presumably defined in config.jl). For CIFAR-10 no L2 norms are collected, so
# `l2` is empty and its mean/std come out as NaN.
stdevl2 = std(l2)/n_pixels
mean_l2 = mean(l2)/n_pixels

stdevl1 = std(l1)/n_pixels
mean_l1 = mean(l1)/n_pixels

# NOTE(review): the division by 255 presumably rescales from the 0-255 pixel range — confirm.
stdevlinf = std(l_inf)/255
l_inf_avg = mean(l_inf)/255

stdevs = std(s_vals)/n_pixels
meanofs = mean(s_vals)/n_pixels

stdev_flat_s = std(l0_flat)/n_pixels
meanof_flat_s = mean(l0_flat)/n_pixels

println("\n| accuracy: $success \n| avg L1 norm: $mean_l1 +- $stdevl1 \n| avg L2 norm: $mean_l2 +- $stdevl2 \n| avg l_inf norm: $l_inf_avg +- $stdevlinf \n| avg s: $meanofs +- $stdevs \n| flat s: $meanof_flat_s +- $stdev_flat_s")

# The l0 norms of p are only collected in univariate mode (see the `norm(p, 0)`
# bookkeeping in the sample loop), so the sparsity summary is printed only then.
if optim == "univariate"
    mean_of_l0_p, std_of_l0_p = mean(p_l0) / n_pixels, std(p_l0) / n_pixels
    print("\n| sparsity of p: ", mean_of_l0_p, " +- ", std_of_l0_p, "\n")
end

# Append a framed summary of this run's statistics to the log file.
let banner = "\n *********************************** \n",
    summary = "\n| accuracy: $success \n| avg L1 norm: $mean_l1 +- $stdevl1 \n| avg L2 norm: $mean_l2 +- $stdevl2 \n| avg l_inf norm: $l_inf_avg +- $stdevlinf \n| avg s: $meanofs +- $stdevs \n| avg flat s: $meanof_flat_s +- $stdev_flat_s"

    open(logfname, "a") do io
        println(io, banner)
        println(io, summary)
        # p-sparsity figures only exist in univariate mode.
        optim == "univariate" && println(io, "\n| sparsity of p: $mean_of_l0_p +- $std_of_l0_p")
        println(io, banner)
    end
end
