# =================================================================================================#
# Description: Produces the coefficients for the house pricing data
# Author: Ryan Thompson
# =================================================================================================#

include("Estimators/contextual_lasso.jl")

import Cairo, CSV, DataFrames, Fontconfig, Gadfly, Random, StatsBase, Flux

# CUDA.jl is not reproducible with the default rng, so seed a dedicated generator and make
# Random.default_rng() hand it out
rng = Random.MersenneTwister(1)
function Random.default_rng()
    return rng
end

# Read the raw housing records into a DataFrame
data = CSV.read("Data/new.csv", DataFrames.DataFrame)

# Drop corrupted observations: only rows whose bathRoom entry is one of the strings "0"
# through "7" are kept
valid_bathrooms = string.(0:7)
data = filter(:bathRoom => b -> b in valid_bathrooms, data)

# Convert the room counts from strings to floating point numbers
data.bathRoom = parse.(Float64, data.bathRoom)
data.livingRoom = parse.(Float64, data.livingRoom)

# Number of observations that survive the cleaning step
n = DataFrames.nrow(data)

# Explanatory features
x = DataFrames.select(
    data,
    :elevator,
    # Parse the trailing `end - 1:end` slice of the floor string as a number
    :floor => DataFrames.ByRow(s -> parse(Float64, s[(end - 1):end])),
    # Shift the renovation condition down by two and floor the result at zero
    :renovationCondition =>
        DataFrames.ByRow(c -> max(c - 2.0, 0.0)) =>
        :renovationCondition,
    :livingRoom,
    :bathRoom,
)

# Contextual features: the coordinates, renamed to their full names
z = DataFrames.select(data, [:Lng, :Lat] .=> [:Longitude, :Latitude])

# Response: total price divided by square
y = DataFrames.select(
    data,
    [:totalPrice, :square] => DataFrames.ByRow((total, area) -> total / area) => :price,
)

# Remove missing type from the explanatory feature columns
DataFrames.disallowmissing!(x)

# Subsample data to facilitate plotting: draw three mutually disjoint sets of 15000
# observations each for training, validation, and testing
id = 1:n
train_id = StatsBase.sample(id, 15000, replace = false)

# Remove the training ids from the pool before drawing the validation ids
id = setdiff(id, train_id)
valid_id = StatsBase.sample(id, 15000, replace = false)

# Remove the validation ids from the pool BEFORE drawing the test ids. Previously the test
# ids were sampled from a pool that still contained the validation ids, so the two sets
# could overlap.
id = setdiff(id, valid_id)
test_id = StatsBase.sample(id, 15000, replace = false)

# Materialise each subsample as plain arrays; the response is flattened to a vector
x_train = Matrix(x[train_id, :])
z_train = Matrix(z[train_id, :])
y_train = vec(Matrix(y[train_id, :]))
x_valid = Matrix(x[valid_id, :])
z_valid = Matrix(z[valid_id, :])
y_valid = vec(Matrix(y[valid_id, :]))

# Only the contextual features of the test set are built here
z_test = Matrix(z[test_id, :])

# Network configuration: p columns of explanatory features, m columns of contextual features
p = size(x, 2)
m = size(z, 2)

# Width of each hidden layer, obtained from p and m via the closed-form expression below
radicand = (m + p + 3) ^ 2 - 8 * p + 8 * (m * p * 8)
n_neuron = round(Int, 0.25 * (sqrt(radicand) - m - p - 3))

# Three hidden layers of equal width
hidden_layers = fill(n_neuron, 3)

# Perform a contextual lasso fit on the training set, with the validation set passed alongside
# NOTE(review): sign_constraint has p + 1 entries, presumably intercept first - confirm
fit = ContextualLasso.classo(
    x_train, z_train, y_train, x_valid, z_valid, y_valid;
    sign_constraint = [0, 1, 0, 1, 1, 1],
    verbose = false,
    relax = true,
    hidden_layers = hidden_layers,
)

# Evaluate the fitted coefficients at the test-set contextual features
coef = ContextualLasso.coef(fit, z_test; lambda = "lambda_min", gamma = "gamma_min")

# Produce coefficient profile plots: one map per column of coef, drawn over the two columns
# of z_test. Locations with a nonzero coefficient are coloured by its value; locations with
# a zero coefficient are drawn in light grey.
labels = ["Intercept", "Elevator", "Floor", "Renovation condition", "No. of living rooms",
    "No. of bathrooms"]
plots = Vector{Gadfly.Plot}(undef, size(coef, 2))
for j in axes(coef, 2)
    is_active = coef[:, j] .!= 0
    is_inactive = .!is_active

    # Layer for the zero-coefficient locations; present in every plot
    grey_layer = Gadfly.layer(
        x = z_test[is_inactive, 1],
        y = z_test[is_inactive, 2],
        color = [Gadfly.colorant"light grey"],
    )

    # The value-coloured layer is only added when at least one coefficient is nonzero, and
    # it is listed ahead of the grey layer
    layers = if any(is_active)
        value_layer = Gadfly.layer(
            x = z_test[is_active, 1],
            y = z_test[is_active, 2],
            color = coef[is_active, j],
        )
        (value_layer, grey_layer)
    else
        (grey_layer,)
    end

    # Same guides, theme, and point size regardless of which layers are present
    plots[j] = Gadfly.plot(
        layers...,
        Gadfly.Guide.colorkey(""),
        Gadfly.Guide.title(labels[j]),
        Gadfly.Guide.xlabel("Longitude"),
        Gadfly.Guide.ylabel("Latitude"),
        Gadfly.Guide.xticks(label = false),
        Gadfly.Guide.yticks(label = false),
        Gadfly.Theme(highlight_width = 0Gadfly.mm);
        size = [0.02],
    )
end

# # Alternative layout (disabled): arrange all six coefficient profile plots in a 2 x 3 grid and export
# Gadfly.gridstack([plots[1] plots[2] plots[3]; plots[4] plots[5] plots[6]]) |> 
#   Gadfly.PDF("Figures/housing.pdf", 9Gadfly.inch, 4.3Gadfly.inch)

# Place the first four coefficient profile plots side by side and write the row to a PDF
panel_row = Gadfly.hstack(plots[1:4]...)
panel_row |> Gadfly.PDF("Figures/house.pdf", 9Gadfly.inch, 2.5Gadfly.inch)