# =================================================================================================#
# Description: Plots the experimental results for the synthetic linear data
# Author: Ryan Thompson
# =================================================================================================#

# Work from the experiments root: the results CSV read and the figure PDFs written later
# in this script use paths relative to it.
"/Experiments" |> cd

import Cairo, ColorSchemes, CSV, DataFrames, Fontconfig, Gadfly, Pipe, PlotUtils, Statistics

# Standard error of the mean across replicates. `Statistics.std` uses the n - 1
# denominator, so a group with a single replicate yields NaN here.
sem(x) = Statistics.std(x) / sqrt(length(x))

# Summarise results for plotting. This does not depend on the configuration, so it is done
# once: reshape the metric columns to long format, then compute the mean and a
# mean +/- one standard error band per (configuration, estimator, metric) group.
summarised = Pipe.@pipe CSV.read("Results/synthetic_linear.csv", DataFrames.DataFrame) |>
    DataFrames.stack(_, [:bscore, :shd, :f1score, :auroc, :sparsity]) |>
    DataFrames.groupby(_, [:n, :p, :s, :graph_type, :dist, :estimator, :variable]) |>
    DataFrames.combine(
        _,
        :value => Statistics.mean => :mean,
        :value => (x -> Statistics.mean(x) - sem(x)) => :low,
        :value => (x -> Statistics.mean(x) + sem(x)) => :high
    )

"""
    plot_metric(df, ylabel, yticks, padding; show_key = false)

Return a Gadfly panel plotting one summarised metric against sample size.

`df` must hold the `n`, `mean`, `low`, `high` and `estimator` columns produced by the
summary step of this script. Points and lines show `mean`, error bars span `low` to
`high`, each estimator gets its own colour, and the x axis is log10-scaled. `yticks` fixes
the y-axis tick positions and `padding` is passed to the theme as `plot_padding`. When
`show_key` is true the colour key is drawn on top; otherwise it is hidden.
"""
function plot_metric(df, ylabel, yticks, padding; show_key = false)
    key_theme = show_key ?
        (key_position = :top, key_label_font_size = 8Gadfly.pt) :
        (key_position = :none,)
    return Gadfly.plot(
        df,
        x = :n,
        y = :mean,
        ymin = :low,
        ymax = :high,
        color = :estimator,
        Gadfly.Geom.point,
        Gadfly.Geom.line,
        Gadfly.Geom.yerrorbar,
        Gadfly.Guide.xlabel("Sample size"),
        Gadfly.Guide.ylabel(ylabel),
        Gadfly.Guide.colorkey(title = ""),
        Gadfly.Guide.yticks(ticks = yticks),
        Gadfly.Scale.x_log10,
        # `k` is the number of colour levels requested (one per estimator)
        Gadfly.Scale.DiscreteColorScale(k -> ColorSchemes.seaborn_deep[1:k]),
        Gadfly.Theme(; key_theme..., plot_padding = padding,
            major_label_font_size = 10Gadfly.pt, minor_label_font_size = 7Gadfly.pt,
            key_label_font = "sans-serif")
    )
end

for (p, s, graph_type, dist) in [
    (20, 40, "erdos_renyi", "gaussian"),
    (100, 200, "erdos_renyi", "gaussian"),
    (20, 60, "erdos_renyi", "gaussian"),
    (20, 80, "erdos_renyi", "gaussian"),
    (20, 40, "scale_free_2", "gaussian"),
    (20, 40, "scale_free_3", "gaussian"),
    (20, 40, "erdos_renyi", "gumbel"),
    (20, 40, "erdos_renyi", "exponential"),
    (100, 200, "erdos_renyi", "gumbel"),
    (100, 200, "erdos_renyi", "exponential"),
    (20, 40, "erdos_renyi", "heteroscedastic_gaussian")
    ]

    # Restrict the summary to the current configuration
    result = DataFrames.filter(row -> row.p == p && row.s == s && row.dist == dist &&
        row.graph_type == graph_type, summarised)
    if DataFrames.nrow(result) == 0
        # Without this guard `minimum` below throws on an empty collection, which would
        # abort all remaining figures as well
        @warn "No results for configuration; skipping figure" p s graph_type dist
        continue
    end

    bdf = DataFrames.filter(:variable => ==("bscore"), result)
    sdf = DataFrames.filter(:variable => ==("shd"), result)
    fdf = DataFrames.filter(:variable => ==("f1score"), result)
    adf = DataFrames.filter(:variable => ==("auroc"), result)

    # Brier score and SHD share y-axis ticks for consistent plots
    lo = min(minimum(bdf.low), minimum(sdf.low))
    hi = max(maximum(bdf.high), maximum(sdf.high))
    yticks, _ = PlotUtils.optimize_ticks(lo, hi, k_min = 5, k_max = 6)

    # Only the first panel shows the colour key. The other panels get 12.4mm of top
    # padding (third entry), presumably so their plot areas line up with the first panel,
    # whose key is drawn on top.
    p1 = plot_metric(bdf, "Brier score", yticks,
        [0Gadfly.mm, 5Gadfly.mm, 0Gadfly.mm, 0Gadfly.mm]; show_key = true)
    p2 = plot_metric(sdf, "Expected SHD", yticks,
        [2.5Gadfly.mm, 2.5Gadfly.mm, 12.4Gadfly.mm, 0Gadfly.mm])
    p3 = plot_metric(fdf, "Expected F1 score", [0.00, 0.25, 0.50, 0.75, 1.00],
        [5Gadfly.mm, 0Gadfly.mm, 12.4Gadfly.mm, 0Gadfly.mm])
    p4 = plot_metric(adf, "AUROC", [0.50, 0.60, 0.70, 0.80, 0.90, 1.00],
        [2.5Gadfly.mm, 2.5Gadfly.mm, 12.4Gadfly.mm, 0Gadfly.mm])

    # Combine and save plots
    Gadfly.hstack(p1, p2, p3, p4) |>
        Gadfly.PDF("Figures/synthetic_linear_$(p)_$(s)_$(graph_type)_$(dist).pdf",
            10Gadfly.inch, 3.3Gadfly.inch)

end
