# =================================================================================================#
# Description: Plots the experimental results for the ablation study comparing the relaxed and
# non-relaxed versions of the contextual lasso
# Author: Ryan Thompson
# =================================================================================================#

import Cairo, ColorSchemes, CSV, DataFrames, Fontconfig, Gadfly, Pipe, Statistics

# Read the raw regression results from disk
result = CSV.read("Results/regression.csv", DataFrames.DataFrame)

# Keep only the two contextual lasso variants used in the ablation study, then relabel the
# default estimator as "(relaxed)" so both variants carry an explicit suffix
result = DataFrames.transform(
    DataFrames.filter(
        :estimator => name ->
            name == "Contextual lasso" || name == "Contextual lasso (non-relaxed)",
        result
    ),
    :estimator => DataFrames.ByRow(
        name -> name == "Contextual lasso" ?
            "Contextual lasso (relaxed)" : "Contextual lasso (non-relaxed)"
    ) => :estimator
)

# Summarise results for plotting: average each metric within every configuration, add bands of
# one standard error either side of the mean, and attach the labels consumed by the plot
result = let
    # Standard error of the mean of a vector of values
    std_error = v -> Statistics.std(v) / sqrt(length(v))

    # Long format: the three metric columns are stacked into :variable / :value pairs
    long = DataFrames.stack(result, [:rel_loss, :sparsity, :f1_score])
    grouped = DataFrames.groupby(long, [:n, :m, :p, :estimator, :loss, :variable])

    # Mean and mean -/+ one standard error per group
    stats = DataFrames.combine(
        grouped,
        :value => Statistics.mean => :mean,
        :value => (v -> Statistics.mean(v) - std_error(v)) => :low,
        :value => (v -> Statistics.mean(v) + std_error(v)) => :high
    )

    # Facet label combining p and m
    stats = DataFrames.transform(
        stats,
        [:p, :m] => DataFrames.ByRow((p, m) -> "p = $p, m = $m") => :pm
    )

    # Human-readable metric names
    stats = DataFrames.transform(
        stats,
        :variable => DataFrames.ByRow(
            v -> v == "rel_loss" ? "Relative loss" :
                 v == "sparsity" ? "Prop. of nonzeros" : "F1 of nonzeros"
        ) => :variable
    )

    # Reference value of 0.1 for the "Prop. of nonzeros" rows only; missing for every other metric.
    # This step must run after the relabelling above because it matches on the new label.
    DataFrames.transform(
        stats,
        :variable => DataFrames.ByRow(
            v -> v == "Prop. of nonzeros" ? 0.1 : missing
        ) => :intercept
    )
end

# Plot ablation results: one column of panels per (p, m) label, one row per metric, and write the
# figure to Figures/ablation.pdf
let
    # Only the rows whose loss is "mse" are plotted
    mse_results = DataFrames.filter(:loss => loss -> loss == "mse", result)

    # Each panel shows points, connecting lines and error bars, plus a dotted black horizontal
    # line at the :intercept value
    panels = Gadfly.Geom.subplot_grid(
        Gadfly.Geom.point,
        Gadfly.Geom.line,
        Gadfly.Geom.yerrorbar,
        Gadfly.Geom.hline(color = "black", style = :dot),
        Gadfly.Coord.cartesian(ymin = 0, ymax = 1.1);
        free_y_axis = true
    )

    # Two colours (entries 6 and 7) taken from the Tsimshian scheme
    estimator_colors = Gadfly.Scale.DiscreteColorScale(_ -> ColorSchemes.Tsimshian[[6, 7]])

    appearance = Gadfly.Theme(
        key_position = :top,
        key_label_font_size = 9 * Gadfly.pt,
        plot_padding = [0 * Gadfly.mm]
    )

    figure = Gadfly.plot(
        mse_results,
        panels,
        Gadfly.Guide.xlabel("Sample size"),
        Gadfly.Guide.ylabel(""),
        Gadfly.Guide.colorkey(title = ""),
        Gadfly.Scale.x_log10,
        estimator_colors,
        appearance;
        x = :n,
        y = :mean,
        ymin = :low,
        ymax = :high,
        color = :estimator,
        xgroup = :pm,
        ygroup = :variable,
        yintercept = :intercept
    )

    figure |> Gadfly.PDF("Figures/ablation.pdf", 9 * Gadfly.inch, 6.05 * Gadfly.inch)
end