# =================================================================================================#
# Description: Plots the experimental results for the comparisons of estimators on the synthetic
# classification data
# Author: Ryan Thompson
# =================================================================================================#

import Cairo, ColorSchemes, CSV, DataFrames, Fontconfig, Gadfly, Pipe, Statistics

# Read the raw experiment results from disk into a data frame
result = let path = "Results/classification.csv"
    CSV.read(path, DataFrames.DataFrame)
end

# Drop the ablation rows: keep every estimator except the non-relaxed contextual lasso
result = DataFrames.filter(:estimator => !=("Contextual lasso (non-relaxed)"), result)

# Blank out (set to missing) the relative loss of the contextual linear model in the settings
# where it is very large, so that those points do not dominate the plot
result = let
    # True for the rows whose relative loss is hidden from the plot
    is_hidden(estimator, n, p, m) =
        estimator == "Contextual linear model" && n < 200 && p == 50 && m == 5

    # Row-wise replacement: missing for hidden rows, the original loss otherwise
    masked_loss(estimator, n, p, m, rel_loss) =
        is_hidden(estimator, n, p, m) ? missing : rel_loss

    DataFrames.transform(
        result,
        [:estimator, :n, :p, :m, :rel_loss] => DataFrames.ByRow(masked_loss) => :rel_loss
    )
end

# Summarise results for plotting: for every (n, m, p, estimator, metric) combination compute the
# mean and a band of one standard error either side, then attach the labels used by the figure
result = let
    # Standard error of the mean of the values in x
    std_error = x -> Statistics.std(x) / sqrt(length(x))

    # Display name of each stacked metric (anything not matched is labelled as the F1 score)
    metric_label = x -> x == "rel_loss" ? "Relative loss" :
                        x == "sparsity" ? "Prop. of nonzeros" : "F1 of nonzeros"

    # Long format: one row per observation and metric (columns :variable and :value)
    long = DataFrames.stack(result, [:rel_loss, :sparsity, :f1_score])
    grouped = DataFrames.groupby(long, [:n, :m, :p, :estimator, :variable])

    stats = DataFrames.combine(
        grouped,
        :value => Statistics.mean => :mean,
        :value => (x -> Statistics.mean(x) - std_error(x)) => :low,
        :value => (x -> Statistics.mean(x) + std_error(x)) => :high
    )

    # Column facet label combining p and m
    with_pm = DataFrames.transform(
        stats,
        [:p, :m] => DataFrames.ByRow((p, m) -> "p = $p, m = $m") => :pm
    )

    # Row facet label; must be applied before the intercept step, which matches on the new label
    labelled = DataFrames.transform(
        with_pm,
        :variable => DataFrames.ByRow(metric_label) => :variable
    )

    # Horizontal reference line at 0.1, drawn only on the "Prop. of nonzeros" panels
    # (presumably the true proportion of nonzeros in the simulation; confirm against the design)
    DataFrames.transform(
        labelled,
        :variable => DataFrames.ByRow(x -> x == "Prop. of nonzeros" ? 0.1 : missing) => :intercept
    )
end

# Plot results: a grid of panels (columns = p/m setting, rows = metric) showing the mean of each
# estimator against sample size with error bars, saved as a PDF
let
    # Geometry shared by every panel; each row of panels gets its own y axis, capped above at 1
    panels = Gadfly.Geom.subplot_grid(
        Gadfly.Geom.point,
        Gadfly.Geom.line,
        Gadfly.Geom.yerrorbar,
        Gadfly.Geom.hline(color = "black", style = :dot),
        Gadfly.Coord.cartesian(ymax = 1);
        free_y_axis = true
    )

    # Take the first k colours of the Tsimshian scheme for the k estimators
    palette = Gadfly.Scale.DiscreteColorScale(k -> ColorSchemes.Tsimshian[1:k])

    # Legend on top, no outer padding
    theme = Gadfly.Theme(
        key_position = :top,
        key_label_font_size = 9 * Gadfly.pt,
        plot_padding = [0 * Gadfly.mm]
    )

    figure = Gadfly.plot(
        result,
        panels,
        Gadfly.Guide.xlabel("Sample size"),
        Gadfly.Guide.ylabel(""),
        Gadfly.Guide.colorkey(title = ""),
        Gadfly.Scale.x_log10,
        palette,
        theme;
        x = :n,
        y = :mean,
        ymin = :low,
        ymax = :high,
        color = :estimator,
        xgroup = :pm,
        ygroup = :variable,
        yintercept = :intercept
    )

    # Render the figure to disk
    figure |> Gadfly.PDF("Figures/classification.pdf", 9 * Gadfly.inch, 6.4 * Gadfly.inch)
end