# enumerate-gsw-distribution.jl
# Chris Harshaw, Fredrik Savje, Dan Spielman, Peng Zhang 
# January 2020
#
# A brute force enumeration of the distribution of the Gram--Schmidt Walk. 
# This code is primarily for test purposes and for use with small samples, i.e., n <= 12.
#
# This code is meant to be internal to the package. Use at your own risk.

using LinearAlgebra
using Combinatorics

"""
    gs_walk_entire_dist(B; x = zeros(size(B)[2]), balanced=false)

Enumerate the probability of each assignment vector under the Gram--Schmidt Walk with random pivot via brute force.

# Arguments
- `B`: an m by n matrix which is input to Gram--Schmidt Walk
- `x`: the initial vector of fractional assignments (default: all zeros)
- `balanced`: set `true` to sample from balanced Gram--Schmidt Walk. (default: `false`)

# Output 
- `assign_list`: an array of +/- 1 assignment vectors generated by GSW 
- `prob_list`: an array of the assignment probabilities
"""
function gs_walk_entire_dist(B; x = zeros(size(B)[2]), balanced=false)

    # number of vectors (columns of B); tolerance used to decide two assignments are equal
    n = size(B, 2)
    tol = 1e-12

    # initialize assignment and probabilities lists
    assign_list = []
    prob_list = []

    # every ordering of the vectors is weighted equally, i.e. by 1 / n!
    num_perm = factorial(n)
    for perm in permutations(1:n)

        # position j of the permuted problem corresponds to original index perm[j];
        # inv_perm maps results of the permuted problem back to the original indexing
        inv_perm = invperm(perm)

        # permute the vectors AND the initial fractional assignments together, so that each
        # column keeps its own starting value; then get assignments & probabilities
        assign, prob = gs_walk_entire_dist_fixed_order(B[:,perm], x=x[perm], balanced=balanced)

        # merge the results of this ordering into the overall distribution
        for (a, p) in zip(assign, prob)

            # un-permute the assignment vector
            a_orig = a[inv_perm]

            # look for an already stored assignment equal to this one (l1 distance below tol)
            idx = findfirst(xs -> sum(abs.(a_orig - xs)) < tol, assign_list)
            if idx === nothing
                # first time this assignment has been generated
                push!(assign_list, a_orig)
                push!(prob_list, p / num_perm)
            else
                # already generated: accumulate the probability
                prob_list[idx] += p / num_perm
            end
        end # end updating probabilities
    end # end permutations
    return assign_list, prob_list
end

"""
    gs_walk_entire_dist_fixed_order(B; x = zeros(size(B)[2]), balanced=false)

Compute exactly the probabilities of assignment vectors when using the Gram--Schmidt Walk
with the columns of `B` taken in the given, fixed order.

The work is delegated to `gs_walk_entire_dist_recur`, which selects the pivot
deterministically as the first live index (via `findfirst`).

# Arguments
- `B`: the vectors (as columns) which are fed into Gram--Schmidt Walk
- `x`: the initial vector of fractional assignments (default: all zeros)
- `balanced`: set `true` to run the balanced GSW; otherwise, leave `false`

# Output
- `assign_list`: an array of +/- 1 assignment vectors
- `prob_list`: an array of the corresponding assignment probabilities
"""
function gs_walk_entire_dist_fixed_order(B; x = zeros(size(B)[2]), balanced=false)

    # start the recursion at the root: probability 1 and nothing enumerated yet
    curr_prob = 1.0
    return gs_walk_entire_dist_recur(x, B, curr_prob, [], [], balanced)
end

# Accumulate probability `prob` onto the stored assignment that matches `x_fixed` (l1 distance
# below `tol`), or append `x_fixed` as a new entry when no stored assignment matches.
function _gs_walk_record_assignment!(assign_list, prob_list, x_fixed, prob, tol)
    idx = findfirst(xs -> sum(abs.(x_fixed - xs)) < tol, assign_list)
    if idx === nothing
        push!(assign_list, x_fixed)
        push!(prob_list, prob)
    else
        prob_list[idx] += prob
    end
    return assign_list, prob_list
end

"""
    gs_walk_entire_dist_recur(x, B, curr_prob, assign_list, prob_list, balanced)

Compute exactly the probabilities of assignment vectors recursively.
The code is not optimized in any way, but it is self contained.

A coordinate is live while its magnitude differs from 1 by more than the tolerance. Each call
takes the first live coordinate as the pivot, computes a step direction, and branches on the
positive and the negative step. A branch whose result is fully frozen is recorded in the
lists; any other branch recurses. If `x` has no live coordinate on entry, `sign.(x)` is
recorded with probability `curr_prob`.

# Arguments
- `x`: the vector of fractional assignments
- `B`: the vectors (as columns) which are fed into Gram--Schmidt Walk
- `curr_prob`: the probability of the current branch
- `assign_list`: an array of +/- 1 assignment vectors seen so far (mutated in place)
- `prob_list`: an array of the assignment probabilities seen so far (mutated in place)
- `balanced`: set `true` to run the balanced GSW; otherwise, leave `false`

# Output
- `assign_list`: an array of +/- 1 assignment vectors seen so far
- `prob_list`: an array of the assignment probabilities seen so far
"""
function gs_walk_entire_dist_recur(x, B, curr_prob, assign_list, prob_list, balanced)

    # get dimensions, set tolerance
    n = size(B, 2)
    tol = 100*eps()

    # define alive variables, pivot (first live index)
    live = convert(BitArray, [ abs(1 - abs(xi)) > tol for xi in x])
    p = findfirst(live)

    # base case: nothing left to move, so x already is a +/- 1 assignment
    if p === nothing
        _gs_walk_record_assignment!(assign_list, prob_list, sign.(x), curr_prob, tol)
        return assign_list, prob_list
    end

    # get live not pivot variables
    live_not_pivot = copy(live)
    live_not_pivot[p] = false
    Bt = B[:,live_not_pivot]
    vp = B[:,p]

    # create u vector: 1 at the pivot, 0 at frozen coordinates
    u = zeros(n)
    u[p] = 1

    # compute values for alive not pivot variables
    if !balanced
        # least-squares coefficients; pinv also covers singular / empty Bt' * Bt
        z = pinv(Bt' * Bt) * (- Bt' * vp)
    else
        # build system of linear equations
        k = sum(live_not_pivot)

        # construct coefficient matrix
        A = zeros(k+1, k+1)
        A[1:k, 1:k] = Bt' * Bt
        A[1:k,k+1] = ones(k)/2
        A[k+1,1:k] = ones(k)

        # construct rhs coefficient vector
        b = zeros(k+1)
        b[1:k] = - Bt' * vp
        b[k+1] = - 1

        # solve the system; only the first k entries are step coordinates
        y = pinv(A)*b
        z = y[1:k]
    end
    u[live_not_pivot] = z

    # compute the step sizes: the largest steps along +u and -u keeping every x[i] in [-1, 1]
    del_plus = Inf
    del_minus = Inf
    for i in 1:n
        # skip (numerically) zero directions, which impose no limit on the step
        if live[i] && abs(u[i]) > tol
            dp = (sign(u[i]) - x[i]) / u[i]
            dm = (sign(u[i]) + x[i]) / u[i]
            del_plus = min(del_plus, dp)
            del_minus = min(del_minus, dm)
        end
    end

    # branch probabilities
    prob_plus = del_minus / (del_plus + del_minus)
    prob_minus = 1.0 - prob_plus

    # explore the plus branch first, then the minus branch
    branches = ((x + del_plus * u, curr_prob * prob_plus),
                (x - del_minus * u, curr_prob * prob_minus))
    for (x_next, next_prob) in branches
        if abs(n - sum(abs.(x_next))) < tol
            # x is completely frozen: record the assignment
            _gs_walk_record_assignment!(assign_list, prob_list, sign.(x_next), next_prob, tol)
        else
            # if not all variables are frozen, recurse! (the lists are mutated in place)
            gs_walk_entire_dist_recur(x_next, B, next_prob, assign_list, prob_list, balanced)
        end
    end

    # return the list of assignments and probabilities
    return assign_list, prob_list
end

"""
    exact_mean_cov(assign_list, prob_list)
    
Compute the mean and covariance of assignments given explicit distribution.
    
# Arguments
- `assign_list`: an array of +/- 1 assignment vectors generated by GSW 
- `prob_list`: an array of the assignment probabilities

# Outputs
- `x_mean`: the mean of +/- 1 assignment vectors, n array
- `x_cov`: the covariance of +/- 1 assignment vectors, n by n array
"""
function exact_mean_cov(assign_list, prob_list)

    # size of the support and length of each assignment vector
    num_outcomes = length(assign_list)
    n = length(assign_list[1])

    # mean: probability-weighted sum of the assignments, accumulated in list order
    x_mean = foldl(1:num_outcomes; init=zeros(n)) do acc, k
        acc + assign_list[k] * prob_list[k]
    end

    # covariance: probability-weighted sum of outer products of the centered assignments
    x_cov = foldl(1:num_outcomes; init=zeros(n, n)) do acc, k
        centered = assign_list[k] - x_mean
        acc + prob_list[k] * (centered * centered')
    end

    return x_mean, x_cov
end