using LinearAlgebra
using Random

using Turing
using Distributions
using MCMCChains
using Plots
using StatsPlots
using Distances
using CSV
using DataFrames
using MLDataUtils: shuffleobs, splitobs, rescale!

# Bayesian vs. OLS linear regression on Bayesian/tgfb.csv (target column `:y`):
# standardise the data, sample a Turing model with NUTS, fit GLM's OLS on the same
# design, and report the mean squared deviation (`msd`) of both on the original scale.
Random.seed!(0)

# Set to `true` to hold out 20% of the rows as a test set and report test losses.
# This flag must remain a Bool: the held-out feature matrix is stored in
# `test_features` (not in `test`) so the later `if test` checks keep working.
test = false

data = DataFrame(CSV.File("Bayesian/tgfb.csv"))

# Preview the first rows (only shown when evaluated interactively).
first(data, 6)

# Optional, dataset-specific preprocessing kept for reference:
# transform!(data, :y => (x -> round.(x, digits = 5)))
# select!(data, Not([:X227899_at, :X211302_s_at, :X214594_x_at, :X1554662_at]))

if test
    trainset, testset = splitobs(shuffleobs(data), 0.8)
else
    trainset = data
end

# Column holding the response; every other column is used as a feature.
target = :y

# `Matrix(select(...))` and `df[:, col]` build copies, so the in-place standardisation
# below never mutates `data`.
train = Matrix(select(trainset, Not(target)))
train_target = trainset[:, target]

# Standardise the features (per column) and the target with training-set statistics.
μ, σ = rescale!(train; obsdim = 1)
μtarget, σtarget = rescale!(train_target; obsdim = 1)

if test
    # Apply the *training* mean/std to the test set so both share one scale.
    test_features = Matrix(select(testset, Not(target)))
    test_target = testset[:, target]
    rescale!(test_features, μ, σ; obsdim = 1)
    rescale!(test_target, μtarget, σtarget; obsdim = 1)
end

# Bayesian linear regression: y ~ MvNormal(intercept .+ x * coefficients, σ₂ * I).
#   x : feature matrix, one row per observation (size(x, 2) features)
#   y : response vector, one entry per row of x
@model function linear_regression(x, y)
    # Prior on the noise *variance* σ₂ (not a std): Normal(0, 100) truncated to [0, ∞).
    σ₂ ~ truncated(Normal(0, 100), 0, Inf)

    # Intercept prior.
    intercept ~ Normal(0, sqrt(3))

    # Independent Normal(0, std = sqrt(10)) coefficient priors, i.e. covariance 10 * I.
    nfeatures = size(x, 2)
    coefficients ~ MvNormal(zeros(nfeatures), 10.0 * I)

    # Linear predictor for every observation.
    mu = intercept .+ x * coefficients

    # Isotropic Gaussian likelihood with variance σ₂; equivalent to
    # Turing.@addlogprob! loglikelihood(MvNormal(mu, σ₂ * I), y)
    y ~ MvNormal(mu, σ₂ * I)
end

model = linear_regression(train, train_target)
# NUTS with a target acceptance rate of 0.65; 3000 draws in a single chain.
chain = sample(model, NUTS(0.65), 3_000)

# Posterior mean and standard deviation of the intercept and of each coefficient.
coef_names = vcat(:intercept, namesingroup(chain, :coefficients))
sampled_coefs_df = select(DataFrame(chain), coef_names)
chain_coefs = mean.(eachcol(sampled_coefs_df))
chain_coef_σs = std.(eachcol(sampled_coefs_df))

# Trace/density plots (only shown when evaluated interactively).
plot(chain)

# One table per parameter: summary statistics followed by quantiles.
summary_stats, quantile_stats = DataFrame.(describe(chain))
stats = hcat(summary_stats, select(quantile_stats, Not(:parameters)))

# Import the GLM package.
using GLM

# Ordinary least squares on the same standardised data; the intercept is supplied
# explicitly as a leading column of ones.
train_with_intercept = hcat(ones(size(train, 1)), train)
ols = lm(train_with_intercept, train_target)
ols_coefs = GLM.coef(ols)

# Training-set OLS predictions, mapped back to the original target scale.
p = GLM.predict(ols)
train_prediction_ols = μtarget .+ σtarget .* p

if test
    # Test-set OLS predictions, mapped back to the original target scale.
    test_with_intercept = hcat(ones(size(test_features, 1)), test_features)
    p = GLM.predict(ols, test_with_intercept)
    test_prediction_ols = μtarget .+ σtarget .* p
end

"""
    prediction(chain, x)

Return the posterior-mean prediction for each row of `x`: for every draw in `chain`
compute `intercept + x * coefficients`, then average over the draws. The result is on
the (standardised) scale the model was fitted on.
"""
function prediction(chain, x)
    params = get_params(chain[:, :, :])
    # One column per posterior draw after broadcasting the transposed intercept row.
    # NOTE(review): written for the single-chain `chain` sampled above — confirm the
    # shapes returned by `get_params` before using this with multiple chains.
    targets = params.intercept' .+ x * reduce(hcat, params.coefficients)'
    return vec(mean(targets; dims = 2))
end

# Bayesian predictions for the training set, mapped back to the original target scale.
p = prediction(chain, train)
train_prediction_bayes = μtarget .+ σtarget .* p

if test
    p = prediction(chain, test_features)
    test_prediction_bayes = μtarget .+ σtarget .* p

    # Show the predictions on the test data set next to the observed targets;
    # `display` is needed because an `if` block's value is otherwise discarded.
    display(DataFrame(
        y = testset[!, target],
        Bayes = test_prediction_bayes,
        OLS = test_prediction_ols,
    ))
end

# Losses are computed against the untouched target column, on the original scale.
println(
    "Training set:",
    "\n\tBayes loss: ",
    msd(train_prediction_bayes, trainset[!, target]),
    "\n\tOLS loss: ",
    msd(train_prediction_ols, trainset[!, target])
)

if test
    println(
        "Test set:",
        "\n\tBayes loss: ",
        msd(test_prediction_bayes, testset[!, target]),
        "\n\tOLS loss: ",
        msd(test_prediction_ols, testset[!, target])
    )
end
