# Axis-aligned multivariate normal (independent coordinates, i.e., a diagonal covariance matrix) using a conjugate prior.
module MVNaaC

module MVNaaCmodel # submodule for component family definitions
export Theta, Data, log_marginal, log_marginal_multiple, new_theta, Theta_clear!, Theta_clear2!, Theta_adjoin!, Theta_adjoin2!, Theta_remove!, Theta_remove2!,
       Hyperparameters, construct_hyperparameters, update_hyperparameters!

using Statistics
using SpecialFunctions
using Distributions: Gamma
# log|Γ(x)|: `logabsgamma` returns the pair (log|Γ(x)|, sign); only the log-magnitude is kept.
lgamma_(x) = first(logabsgamma(x))

# Type alias for data: a one-dimensional Float64 array.
const Data = Vector{Float64}

"""
    Theta(d)

Running sufficient statistics of one cluster of `d`-dimensional points, kept in two
groups: points added with `Theta_adjoin!` update `n1`, `sum_x`, `sum_xx`, and points
added with `Theta_adjoin2!` update `n2`, `sum_x2`, `sum_xx2`. A new `Theta` is empty.
"""
mutable struct Theta
    n1::Int64                  # number of first-group points in the cluster
    n2::Int64                  # number of second-group points in the cluster
    d::Int64                   # dimension
    sum_x::Array{Float64,1}    # coordinate-wise sum of first-group points
    sum_x2::Array{Float64,1}   # coordinate-wise sum of second-group points
    sum_xx::Array{Float64,1}   # coordinate-wise sum of squares, first group
    sum_xx2::Array{Float64,1}  # coordinate-wise sum of squares, second group
    function Theta(d)
        dim = convert(Int64, d)  # the typed field `d` applies this same conversion
        return new(0, 0, dim, zeros(d), zeros(d), zeros(d), zeros(d))
    end
end

# Empty cluster statistics sized to the dimension stored in the hyperparameters `H`.
function new_theta(H)
    return Theta(H.d)
end

# Add (`op = +`) or subtract (`op = -`) the point `x` and its coordinate-wise squares
# to/from the running sums `s` and `ss`, over coordinates 1:d.
function _theta_shift!(s, ss, x, d, op)
    for k in 1:d
        xk = x[k]
        s[k] = op(s[k], xk)
        ss[k] = op(ss[k], xk * xk)
    end
    return nothing
end

# Reset both groups of statistics to the empty state, in place.
function Theta_clear!(p)
    for acc in (p.sum_x, p.sum_x2, p.sum_xx, p.sum_xx2)
        fill!(acc, 0.0)
    end
    p.n1 = 0
    p.n2 = 0
end

# Reset only the second-group statistics, in place.
function Theta_clear2!(p)
    fill!(p.sum_x2, 0.0)
    fill!(p.sum_xx2, 0.0)
    p.n2 = 0
end

# Add point `x1` to the first-group statistics; returns the updated count `n1`.
function Theta_adjoin!(p, x1)
    _theta_shift!(p.sum_x, p.sum_xx, x1, p.d, +)
    p.n1 += 1
end

# Add point `x2` to the second-group statistics; returns the updated count `n2`.
function Theta_adjoin2!(p, x2)
    _theta_shift!(p.sum_x2, p.sum_xx2, x2, p.d, +)
    p.n2 += 1
end

# Remove point `x1` from the first-group statistics; `n1` must stay non-negative.
function Theta_remove!(p, x1)
    _theta_shift!(p.sum_x, p.sum_xx, x1, p.d, -)
    p.n1 -= 1
    @assert p.n1 >= 0 "Theta_remove!: n1 became negative"
end

# Remove point `x2` from the second-group statistics; `n2` must stay non-negative.
function Theta_remove2!(p, x2)
    _theta_shift!(p.sum_x2, p.sum_xx2, x2, p.d, -)
    p.n2 -= 1
    @assert p.n2 >= 0 "Theta_remove2!: n2 became negative"
end

"""
    log_marginal(p, H)

Log marginal likelihood of all points summarized in `p`, with both groups (`n1` and
`n2`) pooled and the cluster mean integrated out under its conjugate normal prior.

Each coordinate is modelled as `x ~ N(mu, 1/λ)` with `mu ~ N(m, 1/(c·λ))`, where
`λ = H.lambda`, `m = H.m` and `c = H.c`. Let `n = p.n1 + p.n2`, `x̄` be the sample mean
and `S = Σx² - (Σx)²/n` the within-cluster sum of squares. Each coordinate then contributes

    (n/2)·log(λ/2π) + ½·log(c/(c+n)) - (λ/2)·[S + (c·n/(c+n))·(x̄ - m)²]

The contributions are summed over the `H.d` coordinates. An empty cluster (`n == 0`)
has log marginal 0; previously this case produced NaN from 0/0.
"""
function log_marginal(p, H)
    n = p.n1 + p.n2
    n == 0 && return 0.0   # empty product of likelihoods; avoids 0/0 below
    d = H.d
    lambda, m, c = H.lambda, H.m, H.c
    shrink = c * n / (c + n)   # weight of the sample mean's deviation from the prior mean
    within = 0.0               # Σ_j within-cluster sum of squares
    between = 0.0              # Σ_j (x̄_j - m)²
    @inbounds for j in 1:d
        sum_x = p.sum_x[j] + p.sum_x2[j]
        sum_xx = p.sum_xx[j] + p.sum_xx2[j]
        within += sum_xx - sum_x^2 / n
        between += (sum_x / n - m)^2
    end
    return 0.5 * n * d * (log(lambda) - log(2 * pi)) +
           0.5 * d * (log(c) - log(c + n)) -
           0.5 * lambda * (within + shrink * between)
end

# Log marginal likelihood of cluster `p` with the extra first-group point `x1` included.
# `p` is extended in place, scored, and then `x1` is removed again; counts are restored
# exactly, and the floating-point sums are restored up to rounding.
log_marginal(x1, p, H) = (
    Theta_adjoin!(p, x1);
    lm = log_marginal(p, H);   # statistics now include x1
    Theta_remove!(p, x1);      # undo the temporary insertion
    lm
)

# Log marginal likelihood of cluster `p` with one extra first-group point `x1` and every
# point of `x2_list` added as second-group points. `p` is modified in place for the
# evaluation and the insertions are undone before returning; counts are restored
# exactly, and the floating-point sums are restored up to rounding.
function log_marginal_multiple(x1, x2_list, p, H)
    Theta_adjoin!(p, x1)
    foreach(x2 -> Theta_adjoin2!(p, x2), x2_list)
    lm = log_marginal(p, H)
    # Remove the temporary points in the same order they were added.
    Theta_remove!(p, x1)
    foreach(x2 -> Theta_remove2!(p, x2), x2_list)
    return lm
end

# Prior settings plus the current value of the shared precision `lambda`.
mutable struct Hyperparameters
    d::Int64                  # dimension of each data point
    m::Float64                # prior mean of the cluster means (mu's)
    c::Float64                # prior precision multiplier for the mu's
    a::Float64                # Gamma prior shape of lambda
    b::Float64                # Gamma prior rate of lambda
    constant::Float64         # 0.5*log(c) + a*log(b) - lgamma(a), set at construction
    log_Ga::Vector{Float64}   # lgamma(a + k/2) for k = 1:n+1, set at construction
    lambda::Float64           # current value of the shared precision
end

"""
    construct_hyperparameters(options)

Build the initial `Hyperparameters` from `options.x1` and `options.x2`, the two
collections of data points. The prior settings are fixed at `m = 0`, `c = 1`, `a = 1`
and `b = 1`. The initial shared precision `lambda` is drawn from `Gamma(a, 1/b)`
(shape `a`, scale `1/b`).
"""
function construct_hyperparameters(options)
    x1 = options.x1
    x2 = options.x2
    n = length(x1) + length(x2)   # total number of data points
    d = length(x1[1])             # dimension, read from the first point of x1

    m, c, a, b = 0.0, 1.0, 1.0, 1.0
    # log Γ(a + k/2) for k = 1, …, n+1, precomputed once
    log_Ga = [lgamma_(a + 0.5 * k) for k in 1:(n + 1)]
    constant = 0.5 * log(c) + a * log(b) - lgamma_(a)
    lambda = rand(Gamma(a, 1 / b))
    return Hyperparameters(d, m, c, a, b, constant, log_Ga, lambda)
end

"""
    update_hyperparameters!(H, theta, list, t, x1, x2, z)

Gibbs update of the shared precision `H.lambda`. The new value is drawn from its
conditional posterior given the statistics of the clusters `theta[list[i]]`,
`i = 1:t`, with the cluster means integrated out. This matches the λ-dependent part of
`log_marginal`, so the update is consistent with the likelihood used elsewhere.

The conditional posterior is `Gamma(a′, 1/b′)` (shape, scale), where

    a′ = H.a + ½·N·d
    b′ = H.b + ½·Σ_k Σ_j [S_kj + (c·n_k/(c+n_k))·(x̄_kj - m)²]

Here `N` is the total number of points in those clusters, `S_kj` is the within-cluster
sum of squares of cluster `k` in coordinate `j`, `m = H.m` and `c = H.c`.

The previous version dropped the `(x̄ - m)²` shrinkage term that `log_marginal`
carries. `x1`, `x2` and `z` are accepted for interface compatibility but are not used.

Returns the new value of `H.lambda`.
"""
function update_hyperparameters!(H, theta, list, t, x1, x2, z)
    d = H.d
    m, c = H.m, H.c
    total_ss = 0.0
    point_count = 0

    for i in 1:t
        cluster = theta[list[i]]
        n_k = cluster.n1 + cluster.n2
        n_k > 0 || continue   # empty clusters contribute nothing
        shrink = c * n_k / (c + n_k)
        for j in 1:d
            sum_x = cluster.sum_x[j] + cluster.sum_x2[j]
            sum_xx = cluster.sum_xx[j] + cluster.sum_xx2[j]
            # within-cluster scatter plus the prior-mean term from the marginal likelihood
            total_ss += sum_xx - sum_x^2 / n_k + shrink * (sum_x / n_k - m)^2
        end
        point_count += n_k
    end

    a_new = H.a + 0.5 * point_count * d
    b_new = H.b + 0.5 * total_ss
    H.lambda = rand(Gamma(a_new, 1 / b_new))   # Gamma(shape, scale = 1/rate)
    return H.lambda
end

end # module MVNaaCmodel
using .MVNaaCmodel

# Include generic code
include("generic.jl")

# Include t_proposal code
include("T_proposal.jl")

# Include upsampling code
include("upsampling.jl")

# Include core sampler code
include("coreConjugate.jl")

end # module MVNaaC



