using LinearAlgebra
using JLD
using Optim, LineSearches
using StatsBase
using ProgressMeter

## Given a matrix, return the leverage scores

function leverages(A)
    # Return the leverage score of every row of A, one entry per row.
    # The score of row i is the squared L2 norm of row i of the thin ("economic")
    # Q factor of the QR decomposition of A.
    thin_q = Matrix(qr(A).Q)
    return [sum(abs2, row) for row in eachrow(thin_q)]
end

## Given a matrix, return the Lewis Weights after `iters` fixed-point iterations (4 by default)
function lewis_weights(A, p=1; iters=4)
    # Approximate the Lp Lewis weights of the rows of A with a fixed-point iteration.
    #   A     : (n, d) matrix
    #   p     : the Lp parameter; for p == 2 the leverage scores are returned directly
    #   iters : number of fixed-point iterations to run (default 4, the value that
    #           used to be hard-coded)
    # Returns a vector with one weight per row of A.

    if p == 2
        return leverages(A)
    end

    n = size(A, 1)
    w = ones(n) # start from all-ones weights
    for _ in 1:iters
        # Leverage scores of the reweighted matrix W^(1/2 - 1/p) * A ...
        levs = leverages(Diagonal(w .^ (1/2 - 1/p)) * A)
        # ... give the next iterate: w <- (w^(2/p - 1) .* levs)^(p/2)
        w = (w .^ (2/p - 1) .* levs) .^ (p/2)
    end
    return w
end

## Subsampling

function subsample(A, b, probs, p, s)
    # Draw s rows of the system (A, b) at random and rescale them.
    #   A     : (n, d) matrix
    #   b     : vector of length n
    #   probs : vector of length n of sampling weights (normalized below to sum to 1)
    #   p     : the Lp parameter, used in the rescaling
    #   s     : number of rows to keep
    # Returns (SA, Sb): the (s, d) matrix of sampled, rescaled rows of A, and the
    # length-s vector of the matching rescaled entries of b.

    n = size(A, 1)
    probs = probs / sum(probs) # sums to 1 now for sure
    samples = sample(1:n, Weights(probs), s) # List of row id's to keep (ids may repeat)

    # Row i is rescaled by (s * probs[i])^(-1/p). Only the factors of the drawn rows
    # are needed, so select first and scale afterwards instead of rescaling all of A.
    rescaling_factors = (probs[samples] * s) .^ (-1/p)

    SA = rescaling_factors .* A[samples, :]
    Sb = rescaling_factors .* b[samples]
    return SA, Sb
end

## Lp Regression (exact solution)

function LpRegression(A,b,p)
    # Minimize norm(A*x - b, p) over x and return the minimizer that was found.
    # Optim.jl is used as a black-box solver (L-BFGS with a backtracking line search
    # and forward-mode automatic differentiation), started from the zero vector,
    # since this experiment is only about row-wise sample complexity.
    ncols = size(A, 2)
    objective = x -> norm(A * x - b, p)
    result = optimize(objective, zeros(ncols),
                      LBFGS(linesearch=LineSearches.BackTracking());
                      autodiff = :forward)
    return result.minimizer
end

## Build rectangular vandermonde matrix

function Vandermonde(times, columns)
    # Build the rectangular Vandermonde matrix V of size (length(times), columns),
    # with V[i, j] = times[i]^(j - 1). The first column is all ones.
    nrows = length(times)
    V = ones(nrows, columns)
    # Fill column by column; column power + 1 holds the power-th powers of the times.
    for power in 1:(columns - 1)
        for row in 1:nrows
            V[row, power + 1] = times[row]^power
        end
    end
    return V
end

## Main Algo (Algorithm 2 in the paper)

function FastVandermondeRegression(times, values, degree, p, s, A=Vandermonde(times, degree+1))
    # Approximate polynomial Lp regression: solve the regression exactly on s rows of A
    # that are sampled according to Lewis weights.
    #   times  : real vector of times when the polynomial is observed
    #   values : real vector of the values the polynomial holds at those times
    #   degree : degree of the polynomial to fit
    #   p      : parameter of the Lp regression
    #   s      : number of rows to draw in the Lewis weight sampling
    #   A      : Vandermonde matrix of the problem, with degree + 1 columns. It can be
    #            precomputed and passed in to save time when many calls share it.
    # Returns the coefficient vector found on the subsampled problem.

    ncols = degree + 1 # There's one more column than there is degree

    # Split p as p = q * 2^r with r = floor(log2(p))
    r = floor(Int, log2(p))
    q = p / (2^r)

    # The Lq Lewis weights are computed on a wider Vandermonde matrix, of degree 2^r * degree
    M = Vandermonde(times, 2^r * (ncols - 1) + 1)
    LewisWeights = lewis_weights(M, q)

    # NOTE(review): the rescaling exponent handed to subsample is q, while the regression
    # below is solved in the Lp norm -- confirm this matches Algorithm 2 of the paper.
    T, values_sampled = subsample(A, values, LewisWeights, q, s)

    return LpRegression(T, values_sampled, p)
end

## Uniform Algo (Alternative to Algorithm 2)

function UnifVandermondeRegression(times, values, degree, p, s, A=Vandermonde(times, degree+1))
    # Baseline for FastVandermondeRegression: solve the Lp regression exactly on s rows
    # of A that are sampled uniformly at random.
    #   times  : real vector of times when the polynomial is observed
    #   values : real vector of the values the polynomial holds at those times
    #   degree : degree of the polynomial to fit
    #   p      : parameter of the Lp regression (also used to rescale the sampled rows)
    #   s      : number of rows to draw
    #   A      : Vandermonde matrix of the problem, with degree + 1 columns. It can be
    #            precomputed and passed in to save time when many calls share it.
    # Returns the coefficient vector found on the subsampled problem.

    # Equal weights for every row; subsample normalizes them into uniform probabilities
    uniform_weights = ones(length(times))
    T, values_sampled = subsample(A, values, uniform_weights, p, s)

    return LpRegression(T, values_sampled, p)
end

## Compare exact and approx methods on a single instance of the full Vandermonde Regression problem. Computes eps_empirical, formally speaking.

function compare_all_on_input(eg_times, eg_values, eg_degree, eg_p, eg_s)
    # Solve one polynomial Lp regression instance three ways (exactly, with Lewis weight
    # subsampling, with uniform subsampling) and compare the residuals on the full problem.
    # Returns (fast_rel_err, unif_rel_err): the relative gaps between each subsampled
    # solution's full-problem error and the exact solution's error.

    BigVander = Vandermonde(eg_times, eg_degree+1)

    # Lp norm of the residual of a coefficient vector, measured on the full problem
    full_residual(x) = norm(BigVander * x - eg_values, eg_p)

    true_err = full_residual(LpRegression(BigVander, eg_values, eg_p))
    fast_err = full_residual(
        FastVandermondeRegression(eg_times, eg_values, eg_degree, eg_p, eg_s, BigVander))
    unif_err = full_residual(
        UnifVandermondeRegression(eg_times, eg_values, eg_degree, eg_p, eg_s, BigVander))

    # Gap to the exact error, relative to the exact error
    relative_gap(err) = abs(true_err - err) / true_err

    return relative_gap(fast_err), relative_gap(unif_err) # This is eps_empirical in the paper
end

## Generate random data, in three slightly different flavors!

function generate_randn_data(eg_n)
    # Return (times, values): two independent standard normal vectors of length eg_n.
    # The times are drawn first, then the values.
    return randn(eg_n), randn(eg_n)
end

function generate_near_zero_response_data(eg_n, nnzeros = 10)
    # Return (times, values): eg_n standard normal times, and a response vector that is
    # zero everywhere except for its first `nnzeros` entries, which are set to 1.
    #   eg_n    : number of data points
    #   nnzeros : number of leading nonzero responses (default 10)

    eg_times = randn(eg_n)
    eg_values = zeros(eg_n)
    # Clamp the count so that asking for more nonzeros than there are entries fills the
    # whole vector instead of throwing a BoundsError.
    eg_values[1:min(nnzeros, eg_n)] .= 1

    return eg_times, eg_values
end

function generate_monomial_data(eg_n)
    # Return (times, values): eg_n standard normal times, and responses equal to the
    # 10th power of each time plus Gaussian noise scaled by 1e5.
    eg_times = randn(eg_n)
    noise = 1e5 * randn(eg_n)
    signal = eg_times.^10
    return eg_times, signal + noise
end

## Compute eps_{empirical} for a range of values of degree d, parameter p, number of rows s
# Note that in this code, s is the number of subsampled rows, while m is the equivalent symbol in the paper. Sorry about that!

function statistical_complexity(times, values, d_range, p_range, m_range)
    # Run compare_all_on_input for every combination of degree, p and subsample size.
    #   times   : the times to consider
    #   values  : the b vector
    #   d_range : collection of polynomial degrees
    #   p_range : collection of Lp regression p parameters
    #   m_range : collection of subsample amounts
    # Returns an array of size (length(d_range), length(p_range), length(m_range), 2);
    # slice [:, :, :, 1] holds the Lewis weight ("fast") relative errors and
    # slice [:, :, :, 2] the uniform sampling relative errors.
    rel_errs = zeros(length(d_range), length(p_range), length(m_range), 2)

    for (i, degree) in enumerate(d_range)
        for (j, p) in enumerate(p_range)
            for (k, m) in enumerate(m_range)
                # compare_all_on_input returns the pair (fast_err, unif_err)
                rel_errs[i, j, k, :] .= compare_all_on_input(times, values, degree, p, m)
            end
        end
    end

    return rel_errs
end

## Relative Error vs row-count (Figure 2a)
# Run the Lp approx vs. exact regression problem many times, and store the output in a JLD file
# The JLD file can then be opened to compute medians and quartiles for the LaTeX Figure Graphics

function lin_log_convergence_rate_p_range(data_gen_func, filename, eg_n = 1000, n_trials = 30, resolution = 8)
    # Measure the relative errors of both subsampling methods over a range of p values,
    # on freshly generated data for every trial, and save the results to a JLD file.
    #   data_gen_func : function mapping a size n to a (times, values) pair
    #   filename      : path of the JLD file to write
    #   eg_n          : number of data points generated per trial
    #   n_trials      : number of independent trials
    #   resolution    : number of p values, evenly spaced from 2 to 25

    d_base = 20   # polynomial degree, fixed
    s_base = 1000 # number of subsampled rows, fixed
    p_range = range(2, stop=25, length=resolution)

    # One row per trial, one column per value of p
    fast_rel_errs = zeros(n_trials, length(p_range))
    unif_rel_errs = zeros(n_trials, length(p_range))

    @showprogress for t = 1:n_trials
        eg_times, eg_values = data_gen_func(eg_n)

        # The degree and sample-size axes of the result have length 1; index them away
        errs = statistical_complexity(eg_times, eg_values, [d_base], p_range, [s_base])
        fast_rel_errs[t, :] = errs[1, :, 1, 1]
        unif_rel_errs[t, :] = errs[1, :, 1, 2]
    end

    JLD.save(filename,
             "fast_rel_errs", fast_rel_errs,
             "unif_rel_errs", unif_rel_errs,
             "d_base", d_base, "p_range", p_range, "s_base", s_base,
             "n", eg_n, "n_trials", n_trials, "resolution", resolution)

end

##

# lin_log_convergence_rate_p_range(generate_monomial_data, "lin_log_convergence_sparse_25k_30trial_p_range_monomial.jld", 25000, 30);
# lin_log_convergence_rate_p_range(generate_near_zero_response_data, "lin_log_convergence_sparse_25k_500trial.jld", 25000, 500);
# lin_log_convergence_rate_p_range(generate_near_zero_response_data, "lin_log_convergence_sparse_1k_30trial.jld", 500, 10);

## Relative Error vs row-count (Figure XX)
# Run the Lp approx vs. exact regression vs. uniform sampling many times, and store the output in a JLD file
# The JLD file can then be opened to compute medians and quartiles for the LaTeX Figure Graphics

function lin_log_convergence_rate_m_range(data_gen_func, filename, eg_n = 1000, n_trials = 30, resolution = 8)
    # Measure the relative errors of both subsampling methods over a range of subsample
    # sizes, on freshly generated data for every trial, and save the results to a JLD file.
    #   data_gen_func : function mapping a size n to a (times, values) pair
    #   filename      : path of the JLD file to write
    #   eg_n          : number of data points generated per trial
    #   n_trials      : number of independent trials
    #   resolution    : number of subsample sizes, evenly spaced from 100 to 3000
    #                   and rounded down to integers

    d_base = 20 # polynomial degree, fixed
    p_base = 6  # Lp parameter, fixed
    m_range = floor.(Int, range(100, 3000, length=resolution))

    # One row per trial, one column per subsample size
    fast_rel_errs = zeros(n_trials, length(m_range))
    unif_rel_errs = zeros(n_trials, length(m_range))

    @showprogress for t = 1:n_trials
        eg_times, eg_values = data_gen_func(eg_n)

        # The degree and p axes of the result have length 1; index them away
        errs = statistical_complexity(eg_times, eg_values, [d_base], [p_base], m_range)
        fast_rel_errs[t, :] = errs[1, 1, :, 1]
        unif_rel_errs[t, :] = errs[1, 1, :, 2]
    end

    # Note: the subsample sizes are stored under the key "s_range"
    JLD.save(filename,
             "fast_rel_errs", fast_rel_errs,
             "unif_rel_errs", unif_rel_errs,
             "d_base", d_base, "p_base", p_base, "s_range", m_range,
             "n", eg_n, "n_trials", n_trials, "resolution", resolution)

end

##

# Script entry point: this call executes whenever the file is run or included.
# It runs the subsample-size experiment on monomial data with n = 25000 points and
# 30 trials (default resolution of 8), and writes the results to the JLD file named below.
lin_log_convergence_rate_m_range(generate_monomial_data, "lin_log_convergence_sparse_25k_30trial_m_range_monomial_save_me.jld", 25000, 30);
