"""
    plot_gpx(mod, c, class; yaxis=true, fontfamily="IPAPGothic")

Plot `mod.gm.gps[c].gptx` over `0:0.001:1` (with `ribbon_scale=3`), titled with
`class`, and overlay the points `(gp.times_x, gp.x)` as blue markers.

# Keywords
- `yaxis`: when `false`, `yaxis=nothing` is additionally forwarded to `Plots.plot`.
- `fontfamily`: font family forwarded to `Plots.plot`.

Returns the plot created by `Plots.plot`.
"""
function plot_gpx(
    mod::ODECoxProcess,
    c::Int,
    class::String;
    yaxis::Bool=true,
    fontfamily="IPAPGothic"  # Option: ["serif-roman", "Hiragino"]
)
    gp = mod.gm.gps[c]
    # Both variants share every attribute; only the `yaxis=nothing` keyword is optional.
    axis_kw = yaxis ? NamedTuple() : (yaxis=nothing,)
    p = Plots.plot(
        0:0.001:1, gp.gptx;
        c=:blue, linewidth=0, ribbon_scale=3,
        title="$(class)", fontfamily=fontfamily, legend=:none, axis_kw...)
    # Added to the current plot, which is the one created just above.
    Plots.scatter!(gp.times_x, gp.x, c=:blue, ms=3.0, markerstrokewidth=0)
    return p
end

"""
    plot_gpxs(mod, classes; size=(700, 1200), fontfamily="IPAPGothic")

Lay out one `plot_gpx` panel per entry of `classes` in a single row.

The first panel is built with `plot_gpx`'s default `yaxis`; every later panel is
built with `yaxis=false`. All panels use `ylims=(-3, 3)`.
"""
function plot_gpxs(
    mod,
    classes::Vector{String};
    size=(700, 1200),
    fontfamily="IPAPGothic"
)
    n_panels = length(classes)
    first_panel = plot_gpx(mod, 1, classes[1]; fontfamily=fontfamily)
    other_panels = map(2:n_panels) do c
        plot_gpx(mod, c, classes[c]; yaxis=false, fontfamily=fontfamily)
    end
    return Plots.plot(
        first_panel, other_panels...;
        layout=(1, n_panels),
        size=size, ylims=(-3, 3), sharey=true, titlefontsize=12, grid=false, margin=1mm)
end

"""
    plot_gradient_match(mod, c; size=(800, 400), legend=true)

Draw a two-row figure for component `c` of `mod.gm.gps`.

The top panel overlays `gp.gpt`, `gp.gptxw` and `gp.gptx` (each with
`ribbon_scale=3`) together with the points `(gp.times_x, gp.x)` and
`(gp.times_y, gp.y)`. The bottom panel shows column `c` of
`ode_gradmean_ℝ(mod, gm.gps, gm.θ.params)` and `gp.gpgradmean` at `gp.times_x`,
both with error bars, plus a dotted horizontal line at `gp.μ`.

# Keywords
- `size`: figure size forwarded to the final `Plots.plot` call.
- `legend`: when `true`, both panels get x-limits `(-0.02, 1.5)`, a white span over
  `[1.02, 1.5]`, a hand-drawn baseline up to the last tick, and the combined figure
  uses `legend=:right`; when `false`, the combined figure uses `legend=:none`.
"""
function plot_gradient_match(mod::ODECoxProcess, c::Int; size=(800, 400), legend=true)
    gm = mod.gm
    gp = gm.gps[c]
    inducing_points = gp.times_x
    window_points = gp.times_y
    α = exp(gp.ϕ.kernel_params[1])
    μ = gp.μ
    ode_ℝ = ode_gradmean_ℝ(mod, gm.gps, gm.θ.params)

    ticks = LinRange(0, 1, 5)
    tickstr = map(ti -> @sprintf("%.2f", ti), ticks)

    # --- Top panel. Every `!` call below targets the current plot, i.e. `p1`. ---
    p1 = Plots.plot(
        gp.gpt;
        ribbon_scale=3,
        label="GP prior",
        title="Latent GP of $(mod.data.classes[c])",
        widen=false)
    Plots.plot!(gp.gptxw, ribbon_scale=3, c=:red, label=:none)
    Plots.plot!(0:0.001:1, gp.gptx; c=:blue, ribbon_scale=3, label="f ~ GP")
    Plots.scatter!(inducing_points, gp.x, c=:blue, markerstrokewidth=0, label="x = f(t)")
    Plots.scatter!(window_points, gp.y, c=:red, ms=3, markerstrokewidth=0, label="y = f(w)")
    if legend
        # Marker at x = -1, left of the x-limits set below; it carries the "±3 σ" label.
        Plots.scatter!(
            [-1, -1], [0, 0];
            c=:red, marker=:square, markerstrokewidth=0, α=0.3, label="±3 σ")
        Plots.vspan!([1.02, 1.5], c=:white, label=:none)
        Plots.xlims!(-0.02, 1.5)
        Plots.ylims!(-3 * α - 0.5 + μ, 3 * α + 0.5 + μ)
        # Baseline along the lower y-limit, from the left x-limit to the last tick.
        Plots.plot!(
            [xlims(p1)[1], ticks[end]], [ylims(p1)[1], ylims(p1)[1]];
            lw=2, lc=:black, label=false, widen=false)
        Plots.xticks!(ticks, tickstr)
    end

    # --- Bottom panel. From here on the current plot is `p2`. ---
    p2 = Plots.scatter(
        inducing_points, ode_ℝ[:, c];
        ms=4,
        color=:magenta,
        markerstrokewidth=2,
        markerstrokecolor=:magenta,
        label="ODE grad",
        yerror=3 * gm.γ)
    Plots.scatter!(
        inducing_points, gp.gpgradmean;
        ms=3,
        color=:blue,
        markerstrokewidth=1,
        markerstrokecolor=:blue,
        label="GP grad",
        backgroundcolor=:none,
        yerror=3 * sqrt.(diag(gp.gpgradcov)),
        widen=false)
    Plots.plot!([0, 1], [μ, μ], ls=:dot, lw=2, color=:black, label=:none)
    if legend
        Plots.vspan!([1.02, 1.5], c=:white, label=:none)
        Plots.xlims!(-0.02, 1.5)
        Plots.plot!(
            [xlims(p2)[1], ticks[end]], [ylims(p2)[1], ylims(p2)[1]];
            lw=2, lc=:black, label=false, widen=false)
        Plots.xticks!(ticks, tickstr)
    end

    # The combined figure differs between the two modes only in the legend attribute.
    legend_position = legend ? :right : :none
    return Plots.plot(
        p1, p2;
        layout=Plots.grid(2, 1, heights=[0.65, 0.35]), link=:x, xaxis=false,
        legend=legend_position, titlefontsize=10, size=size)
end

"""
    plot_competitive_relationships_bipartite(mod; size=(500, 500), labelsize=10, markersize=10)

Draw the matrix returned by `competitive_coef_matrix` as a bipartite diagram.

For every index pair `(i, j)` of the `C × C` grid a black segment is drawn from
`(0, i)` to `(1, j)` whose line width is the corresponding entry of `A`, and blue
markers are placed at both end points. The y-axis is flipped and labelled with
`mod.data.classes[c]` for `c in 1:C`.

# Keywords
- `size`: figure size forwarded to `Plots.plot`.
- `labelsize`: font size of the y tick labels.
- `markersize`: size of the end-point markers.

Requires `mod.gm.θ.params` to be a `CompetitionParams` (checked with `@assert`).
"""
function plot_competitive_relationships_bipartite(
    mod::ODECoxProcess; size=(500, 500), labelsize=10, markersize=10)

    # `isa` also accepts parametric instantiations, which `typeof(x) == T` rejects.
    @assert mod.gm.θ.params isa CompetitionParams
    C = mod.data.C
    aℝ = param2tuple(mod.gm.θ)[3]
    a = scaled_logistic.(aℝ, mod.gm.θ.priors[3])
    A = competitive_coef_matrix(a, C)

    p = Plots.plot(
        size=size,
        yflip=true,
        showaxis=false,
        legend=:none,
        xticks=nothing,
        yticks=(1:C, [mod.data.classes[c] for c in 1:C]),
        ytickfont=font(labelsize)
    )
    # `CartesianIndices` and `vec(A)` are both traversed in column-major order, so each
    # index pair is matched with its own matrix entry.
    for (idx, width) in zip(CartesianIndices((C, C)), vec(A))
        i, j = Tuple(idx)
        Plots.plot!([0, 1], [i, j], lw=width, c=:black)
        Plots.scatter!([0, 1], [i, j], c=:blue, markersize=markersize)
    end
    return p
end


# function plot_posterior_logA_matrix(
#     chains::TemperedChains{T, I};
#     from::Int,
#     to::Int,
#     n_thinning::Int,
#     competitors::Union{Nothing, Vector{String}}=nothing,
#     kwargs...
#     ) where {T<:Real, I<:Integer}

#     logA_mean = get_posterior_logA_matrix(chains, from=from, to=to, n_thinning=n_thinning)
#     competitors = isnothing(competitors) ? sort(chains.mods[1].data.classes).vals : competitors
#     @assert length(competitors)==chains.mods[1].data.C
#     p = heatmap(
#         competitors, competitors, logA_mean,
#         yflip=true,
#         c = cgrad([:white, :black], [0, 1.]);
#         kwargs...
#     )
#     return p
# end

# function plot_posterior_logA_graph(
#     chains::TemperedChains{T, I};
#     from::Int,
#     to::Int,
#     n_thinning::Int,
#     savefig::Bool=false,
#     savepath::String="graph.pdf",
#     figsize::Tuple=(12cm, 12cm),
#     competitors::Union{Nothing, Vector{String}}=nothing,
#     nodesize=0.2,
#     nodelabelsize=12,
#     nodefillcolor="lightgray",
#     arrowlengthfrac=0.1,
#     arrowangleoffset=π/9,
#     outangle=π/20,
#     fontfamily::String="Times"
#     ) where {T<:Real, I<:Integer}

#     logA = get_posterior_logA_matrix(chains, from=from, to=to, n_thinning=n_thinning)
#     gA = Graphs.SimpleDiGraph(logA);
#     normalized_logA = (logA .- minimum(logA)) ./ (maximum(logA) - minimum(logA)) ;
#     colors = [Gray(0.9*(1 - a)) for a in normalized_logA]
#     colorvec = colors[[ij[1] != ij[2] for ij in CartesianIndices(size(logA))]]
#     widths = (0.2 .+ 0.8 .* normalized_logA) * 10
#     widthvec = widths[[ij[1] != ij[2] for ij in CartesianIndices(size(logA))]]

#     competitors = isnothing(competitors) ? sort(chains.mods[1].data.classes).vals : competitors
#     nodelabels = [uppercase(string(name[1])) for name in competitors]
#     @assert length(competitors)==chains.mods[1].data.C
#     p = GraphPlot.gplot(gA,
#        layout=circular_layout,
#        NODESIZE=nodesize,
#        nodelabel=nodelabels,
#        NODELABELSIZE = nodelabelsize,
#        nodefillc=nodefillcolor,
#        edgestrokec=colorvec,
#        edgelinewidth=widthvec,
#        arrowlengthfrac=arrowlengthfrac,
#        arrowangleoffset=arrowangleoffset,
#        linetype="curve",
#        outangle=outangle
#     )

#     p = Compose.compose(p, Compose.font(fontfamily))
#     if savefig
#         # NOTE: The Cairo and Fontconfig packages are necessary for saving as PDF.
#         # Add them with the package manager if necessary,
#         # then run import Cairo, Fontconfig before invoking PDF.
#         Compose.draw(Compose.PDF(savepath, figsize[1], figsize[2]), p)
#     end
#     return p
# end


"""
    plot_dynamics(mod; true_X, space, with_gp_mean, with_event, with_gm_error,
                  colors, title, kwargs...)

Plot, for every component `c in 1:mod.data.C`, column `c` of `get_X(gps)` against
`gps[c].times_x`, optionally with additional overlays, and return the figure.

# Keywords
- `true_X`: optional matrix with `mod.data.C` columns. Each column is drawn as a dotted
  line on an evenly spaced grid from `0` to `1 + mod.gm.extrapolation_time_length`.
- `space`: `:real` plots the values as they are; `:nonnegative` plots `exp.` of them.
- `with_gp_mean`: draw a horizontal line at `gps[c].μ` for every component.
- `with_event`: append the output of `plot_event` below the main panel
  (height ratio 0.9 / 0.1).
- `with_gm_error`: draw star markers at the `times_x` points, sized by
  `calc_gm_error_std_ratio`; ratios below 1 get size zero.
- `colors`: one colour per component; at least `mod.data.C` entries are required. The
  same colours are forwarded to `plot_event` so both panels agree.
- `title`: optional title of the main panel.
- `kwargs...`: forwarded to the final `Plots.plot` call.
"""
function plot_dynamics(
    mod::ODECoxProcess;
    true_X::Union{Nothing,Matrix{T}}=nothing,
    space::Symbol=:real, # [:real, :nonnegative]
    with_gp_mean::Bool=false,
    with_event::Bool=true,
    with_gm_error::Bool=true,
    colors::Union{Vector{String},Vector{Symbol}}=[
        "#CB3C33", "#4063D8", "#389826", "#9558B2", "#ff8c00"],
    title::Union{Nothing,String}=nothing,
    kwargs...
) where {T<:Real}

    if !isnothing(true_X)
        @assert size(true_X, 2) == mod.data.C
    end
    @assert length(colors) >= mod.data.C
    @assert space in [:real, :nonnegative]
    @unpack gps, θ = mod.gm

    # All series are added to `p1` explicitly rather than to the implicit current plot.
    p1 = Plots.plot(legend=:none)
    if !isnothing(title)
        Plots.title!(p1, title)
    end
    X = get_X(gps)
    Y = get_Y(gps)
    gaussians = agm_pogs(mod, gps, θ.params)
    # Loop-invariant: end of the time axis used for `true_X`.
    time_length = 1 + mod.gm.extrapolation_time_length

    for c in 1:mod.data.C
        est_F = space == :real ? X[:, c] : exp.(X[:, c])
        est_Y = space == :real ? Y[:, c] : exp.(Y[:, c])
        gm_error_ratio = calc_gm_error_std_ratio(gaussians[c])
        # Ratios below 1 are hidden (size zero); `zero(x)` keeps the element type uniform.
        gm_error_size = gm_error_ratio .|> x -> x < 1 ? zero(x) : x

        if !isnothing(true_X)
            true_F = space == :real ? true_X[:, c] : exp.(true_X[:, c])
            # A length-based range always has exactly one grid point per row of `true_X`,
            # which a range built from a computed floating-point step does not guarantee.
            true_times = range(0, time_length; length=size(true_X, 1))
            Plots.plot!(p1, true_times, true_F, c=colors[c], s=:dot, l=5)
        end

        Plots.plot!(p1, gps[c].times_x, est_F, c=colors[c], l=3)
        if c in mod.gm.observed_c_ids
            Plots.scatter!(
                p1, gps[c].times_y, est_Y, c=colors[c], ms=3, markerstrokewidth=0)
        end
        if with_gm_error
            Plots.scatter!(
                p1, gps[c].times_x, est_F, c=:yellow, markerstrokecolor=:black,
                ms=gm_error_size * 3, markerstrokewidth=1, m=:star)
        end
        if with_gp_mean
            Plots.hline!(p1, [gps[c].μ], c=colors[c], s=:dashdotdot, l=2)
        end
    end

    if with_event
        # Forward `colors` so the event strip matches the curves and does not fall back
        # to `plot_event`'s shorter default palette.
        p2 = plot_event(mod, with_count=false, border=:none, colors=colors)
        p = Plots.plot(p1, p2; layout=Plots.grid(2, 1, heights=[0.9, 0.1]), kwargs...)
    else
        p = Plots.plot(p1; kwargs...)
    end
    return p
end

"""
    animate_dynamics(
        mod, mcp, n_frames, n_iter_per_frame;
        true_X, space, with_gp_mean, with_event, colors, kwargs...
    )

Build an animation with `n_frames` frames. For each frame,
`train!(mod, n_iter_per_frame, mcp, show_progress=false)` is called and the result of
`plot_dynamics(mod; ...)` is recorded, titled `"iter \$(i*n_iter_per_frame)"`.

# Keywords
- `true_X`, `space`, `with_gp_mean`, `with_event`, `colors`: forwarded to
  `plot_dynamics`.
- `kwargs...`: any further keywords are forwarded to `plot_dynamics` as well.

The default `colors` is the same palette `plot_dynamics` uses by default, so calls that
do not set `colors` render as they did before the keyword was forwarded.

Returns the animation object produced by `@animate`.
"""
function animate_dynamics(
    mod::ODECoxProcess,
    mcp::Dict,
    n_frames::Int,
    n_iter_per_frame::Int;
    true_X::Union{Nothing,Matrix{T}}=nothing,
    space::Symbol=:real, # [:real, :nonnegative]
    with_gp_mean::Bool=false,
    with_event::Bool=true,
    colors::Union{Vector{String},Vector{Symbol}}=[
        "#CB3C33", "#4063D8", "#389826", "#9558B2", "#ff8c00"],
    kwargs...) where {T<:Real}

    progress = Progress(n_frames, showspeed=true)
    anim = @animate for i in 1:n_frames
        train!(mod, n_iter_per_frame, mcp, show_progress=false)
        # `colors` must be passed explicitly: it is a named keyword of this function and
        # therefore never part of `kwargs`.
        plot_dynamics(
            mod;
            true_X=true_X, space=space, with_event=with_event, with_gp_mean=with_gp_mean,
            colors=colors, title="iter $(i*n_iter_per_frame)", kwargs...)
        next!(progress)
    end
    return anim
end


"""
    plot_event(mod; ms=2, with_count=true, colors, kwargs...)

Plot the event times in `mod.data.df.time` as one row of `+` markers per class.

Class names are the values of `mod.data.classes`, taken in sorted order of the
collected entries. Row `i` is drawn at height `0.1 * i` using `colors[i]`, and the
x-axis spans `0` to `1 + mod.gm.extrapolation_time_length`.

# Keywords
- `ms`: marker size of the event markers.
- `with_count`: when `true`, a horizontal bar chart with the number of events per class
  is placed to the right of the event panel (width ratio 0.75 / 0.25) and the combined
  figure is drawn with `yflip=true`.
- `colors`: one colour per class; at least `mod.data.C` entries are required.
- `kwargs...`: forwarded to the final `Plots.plot` call.
"""
function plot_event(
    mod::ODECoxProcess;
    ms=2, with_count::Bool=true,
    colors::Union{Vector{String},Vector{Symbol}}=["#CB3C33", "#4063D8", "#389826"],
    kwargs...)

    @assert length(colors) >= mod.data.C
    times_n = mod.data.df.time
    classes = mod.data.df.class_orig
    c_names = [c for (i, c) in sort(collect(mod.data.classes))]
    C = length(c_names)
    height = 0.1 * C
    p1 = Plots.plot(
        legend=:none, yticks=(collect(0.1:0.1:height), c_names), xlabel="time (normalized)")
    n_pat = Int[]
    for (i, c) in enumerate(c_names)
        event_times = times_n[classes.==c]
        n = length(event_times)
        push!(n_pat, n)
        Plots.scatter!(
            event_times, fill(i * 0.1, n), marker=:+, ms=ms, label=c, c=colors[i])
    end
    Plots.xlims!(0, 1 + mod.gm.extrapolation_time_length)
    Plots.ylims!(0, height + 0.1)

    if with_count
        p2 = Plots.bar(legend=:none, xlabel="# of events")
        for i in eachindex(c_names)
            # `:horizontal` is the documented value; the original had it misspelled.
            Plots.bar!(
                [i * 0.1], [n_pat[i]], orientation=:horizontal,
                bar_width=0.08, width=0, yticks=false, c=colors[i])
        end
        Plots.ylims!(0, height + 0.1)

        p = Plots.plot(
            p1, p2, yflip=true, layout=Plots.grid(1, 2, widths=[0.75, 0.25]);
            kwargs...)
    else
        p = Plots.plot(p1; kwargs...)
    end
    return p
end

"""
    plot_event_swarm(mod)

Return a Gadfly horizontal beeswarm plot of `mod.data.df`, with the `time` column on
the x-axis and the `class_orig` column on the (flipped) y-axis.
"""
function plot_event_swarm(mod::ODECoxProcess)
    events = mod.data.df
    # Plot elements, kept in the order they are handed to Gadfly.
    elements = (
        Guide.xlabel("time (normalized)"),
        Guide.ylabel("class"),
        Geom.beeswarm(orientation=:horizontal, padding=0.03mm),
        Theme(point_size=0.3mm, default_color="blue", background_color="white"),
        Coord.Cartesian(yflip=true),
    )
    return Gadfly.plot(events, elements...; y="class_orig", x="time")
end