module PlottingTools

using LinearAlgebra
using Statistics
using MATLAB


export plotTrajectory, plotTrajectoryOfMean, plotTrajectoryInSimplex, plotTrajectoryInSquare, plotHeatMap, saveFigure, newFigure



module PlotTypes

  # Rendering mode for the trajectory plotting functions: `Plot` issues a MATLAB
  # `plot` call (connected line), `ScatterPlot` issues a MATLAB `scatter` call.
  @enum PlotType begin
    Plot
    ScatterPlot
  end

end


"""
    saveFigure(matlab_session, filename="latest"; append=false, kwargs...)

Export the MATLAB figure held in the MATLAB variable `canvas` to
`Figures/<filename>.pdf`, additionally save it as `Code/Output/latest.fig`, and
close it.

`append` is passed to `exportgraphics` as its `'Append'` option. The keyword
`content_type` (default `"vector"`) is passed as its `'ContentType'` option; other
keyword arguments are ignored. Both paths are relative to MATLAB's working
directory (NOTE(review): presumably the project root — confirm).

Returns `matlab_session`.
"""
function saveFigure(matlab_session, filename="latest"; append=false, kwargs...)

  content_type = get(kwargs, :content_type, "vector")

  eval_string(matlab_session, """  
    hold off;

    exportgraphics(canvas, 'Figures/$filename.pdf', 'ContentType', ...
      '$content_type', 'Append', $append);

    savefig(canvas, 'Code/Output/latest.fig');
    close(canvas);
  """)

  return matlab_session
end


"""
    plotTrajectoryOfMean(matlab_session, trajectories; transform=identity, kwargs...)

Plot the pointwise mean of `transform(state)` across `trajectories` against time,
surrounded by a translucent band of ± one sample standard deviation.

Every trajectory is indexed by time step (`trajectory[t]`) and its elements must have
`time` and `state` fields. The time axis is read from the first trajectory, and every
other trajectory must have at least as many steps (NOTE(review): presumably all
trajectories share one time grid — confirm with callers). `transform` must map a
state to a real number.

# Keywords
- `new_figure=true`: create a new MATLAB figure `canvas` and configure its axes
  using the options marked "(new figure)" below.
- `golden_ratio=true` (new figure): golden-ratio plot box aspect; square otherwise.
- `log_scale=false` (new figure): logarithmic y-axis.
- `y_limits=(0, "inf")` (new figure): y-axis limits, interpolated verbatim into MATLAB.
- `grid=false` (new figure): turn the grid on.
- `labels=true` (new figure): label the x-axis; when `false`, remove all ticks and
  tick labels instead.
- `color="'Black'"`: MATLAB colour expression (MATLAB quotes included) for the band
  and the mean line.
- `line_width=1.5`, `line_style="'-'"`: mean-line appearance (MATLAB expressions).
- `name="''"`: legend `DisplayName` of the mean line (MATLAB expression).
- `legends=false`: keep the legend visible; by default it is switched off.

With a single trajectory the corrected sample standard deviation is `NaN`, so the
band's vertices are `NaN` (presumably MATLAB then draws no band — verify).

Returns `matlab_session`.
"""
function plotTrajectoryOfMean(matlab_session, trajectories; transform=identity, kwargs...)

  color = get(kwargs, :color, "'Black'")

  if get(kwargs, :new_figure, true)
    _setupTimeAxes(matlab_session, kwargs)
  end

  eval_string(matlab_session, "hold on;")


  times, values_mean, values_std = _trajectoryStatistics(trajectories, transform)

  # Closed outline for `patch`: the upper edge left→right, then the lower edge
  # right→left. The statistics are 1×T row matrices, so `vec` flattens the 1×2T
  # result of `hcat` into a plain vector.
  patch_area_x = vcat(times, reverse(times))
  patch_area_y = vec(hcat(
    values_mean .+ values_std,
    reverse(values_mean .- values_std)
  ))

  # Julia prints vectors as `[a, b]` and row matrices as `[a b]`; both are valid
  # MATLAB row-vector literals, so the data is interpolated directly.
  eval_string(matlab_session, """

    confidence_interval = patch($patch_area_x, $patch_area_y, $color);

    set(confidence_interval, 'FaceColor', $color);
    set(confidence_interval, 'FaceAlpha', 0.1);
    set(confidence_interval, 'EdgeColor', $color);
    set(confidence_interval, 'EdgeAlpha', 0.5);
    set(confidence_interval, 'LineWidth', 0.3);
    set(get(get(confidence_interval, 'Annotation'), 'LegendInformation'),'IconDisplayStyle','off');

    trajectory = plot($times, $values_mean);

    set(trajectory, 'Color', $color);
    set(trajectory, 'LineWidth', $(get(kwargs, :line_width, 1.5)));
    set(trajectory, 'LineStyle', $(get(kwargs, :line_style, "'-'")));
    set(trajectory, 'DisplayName', $(get(kwargs, :name, "''")));

    legend;

    hold off;
  """)

  if !get(kwargs, :legends, false)
    eval_string(matlab_session, "legend off;")
  end


  return matlab_session

end


# Create the figure `canvas` and configure its axes for `plotTrajectoryOfMean`,
# reading the figure-level options from that function's keyword arguments.
function _setupTimeAxes(matlab_session, kwargs)

  y_limits = get(kwargs, :y_limits, (0, "inf"))

  eval_string(matlab_session, "canvas = figure();")

  if get(kwargs, :golden_ratio, true)
    eval_string(matlab_session, "pbaspect([(1 + sqrt(5)) / 2, 1, 1]);")
  else
    eval_string(matlab_session, "pbaspect([1, 1, 1]);")
  end

  eval_string(matlab_session, "set(gca, 'Box', 'off');")

  if get(kwargs, :log_scale, false)
    eval_string(matlab_session, "set(gca, 'YScale', 'log');")
  end

  eval_string(matlab_session, """
    set(gca, 'YLim', [$(y_limits[begin]), $(y_limits[end])]);
  """)

  if get(kwargs, :grid, false)
    eval_string(matlab_session, "grid on;")
  end

  if get(kwargs, :labels, true)
    eval_string(matlab_session, """
      xlabel('Time (\$t\$)', 'Interpreter', 'latex');
    """)
  else
    # Tick positions are numeric properties, so they are cleared with an empty
    # numeric array; the tick-label properties take an empty cell array.
    eval_string(matlab_session, """
      set(gca, 'XTick', []);
      set(gca, 'YTick', []);
      set(gca, 'XTickLabel', {});
      set(gca, 'YTickLabel', {});
    """)
  end

  return nothing
end


# Return the time axis of the first trajectory together with the per-step mean and
# corrected sample standard deviation of `transform(state)` across `trajectories`,
# each statistic as a 1×T row matrix.
function _trajectoryStatistics(trajectories, transform)

  times = Float64[timepoint.time for timepoint ∈ first(trajectories)]

  # Rows index trajectories, columns index time steps.
  values = Float64[
    transform(trajectory[t].state)
    for trajectory ∈ trajectories, t ∈ axes(times, 1)
  ]

  values_mean = mean(values; dims=1)
  values_std = stdm(values, values_mean; dims=1)

  return times, values_mean, values_std
end


"""
    plotTrajectory(matlab_session, trajectory; kwargs...)

Plot a single `trajectory` by handing it to `plotTrajectoryOfMean` as a one-element
collection; every keyword argument is forwarded unchanged.
"""
plotTrajectory(matlab_session, trajectory; kwargs...) =
  plotTrajectoryOfMean(matlab_session, [trajectory]; kwargs...)
  

"""
    plotTrajectoryInSimplex(matlab_session, trajectory; transform=identity, kwargs...)

Draw `trajectory` inside the 2-simplex, rendered as the triangle with corners
(0, 0), (1/2, √3/2) and (1, 0).

Each element of `trajectory` must have a `state` field. `transform(state)` must
return something indexable whose first two entries `p₁, p₂` are mapped to the plane
point (p₁ + p₂/2, (√3/2)·p₂): `(1, 0, …)` lands on the corner (1, 0), `(0, 1, …)` on
the apex, and `p₁ = p₂ = 0` on the origin.

# Keywords
- `new_figure=true`: open a new figure `canvas` with hidden axes and a lightly
  shaded triangle.
- `color="'Black'"`: MATLAB colour expression for the triangle and the trajectory.
- `plot_type=PlotTypes.Plot`: `Plot` draws a line (`line_width`, default 1.5);
  `ScatterPlot` draws markers (`marker_size`, default 1). Any other value draws
  nothing.
- `point_text=nothing`: three MATLAB string expressions labelling, in order, the
  (1, 0) corner, the apex and the origin; they are placed `text_offset` (default
  0.075) inwards from their corners.

Returns `matlab_session`.
"""
function plotTrajectoryInSimplex(matlab_session, trajectory; transform=identity, kwargs...)

  color = get(kwargs, :color, "'Black'")

  if get(kwargs, :new_figure, true)
    eval_string(matlab_session, """
      canvas = figure();

      pbaspect([1, 1, 1]);
      axis off;

      set(gca, 'XLim', [0, 1]);
      set(gca, 'YLim', [0, 1]);

      hold on;

      simplex = plot(polyshape([0, 0.5, 1], [0, sqrt(3) / 2, 0]));

      set(simplex, 'FaceColor', $color);
      set(simplex, 'FaceAlpha', 0.075);
      set(simplex, 'EdgeColor', $color);
      set(simplex, 'EdgeAlpha', 0.25);

      hold off;
    """)
  end

  eval_string(matlab_session, "hold on;")

  # Project the first two coordinates onto the triangle: p₁ pulls towards the
  # (1, 0) corner, p₂ towards the apex at height √3/2.
  apex_height = 0.5 * √3
  points = [transform(timepoint.state) for timepoint ∈ trajectory]
  x_data = [p[1] + 0.5 * p[2] for p ∈ points]
  y_data = [apex_height * p[2] for p ∈ points]

  plot_type = get(kwargs, :plot_type, PlotTypes.Plot)
  line_width = get(kwargs, :line_width, 1.5)
  marker_size = get(kwargs, :marker_size, 1)

  # `nothing` when `plot_type` is neither supported kind, in which case nothing is drawn.
  drawing = if plot_type == PlotTypes.Plot
    """
      trajectory = plot($x_data, $y_data);

      set(trajectory, 'Color', $color);
      set(trajectory, 'LineWidth', $line_width);
    """
  elseif plot_type == PlotTypes.ScatterPlot
    """
      trajectory = scatter($x_data, $y_data);

      set(trajectory, 'MarkerEdgeColor', $color);
      set(trajectory, 'MarkerFaceColor', $color);
      set(trajectory, 'SizeData', $marker_size);
    """
  end

  isnothing(drawing) || eval_string(matlab_session, drawing)

  point_text = get(kwargs, :point_text, nothing)

  if point_text !== nothing
    text_offset = get(kwargs, :text_offset, 0.075)

    # Label anchors: near the origin along the 30° bisector, its mirror image near
    # (1, 0), and straight below the apex.
    origin_anchor = [cosd(30), sind(30)] * text_offset
    corner_anchor = [1 - origin_anchor[1], origin_anchor[2]]
    apex_anchor = [0.5, sqrt(3) / 2 - text_offset]

    eval_string(matlab_session, """

      point_1 = text($(corner_anchor[1]), $(corner_anchor[2]), $(point_text[1]));

      set(point_1, 'HorizontalAlignment', 'center');
      set(point_1, 'VerticalAlignment', 'middle');
      set(point_1, 'Rotation', 60);
      set(point_1, 'FontWeight', 'bold');

      point_2 = text($(apex_anchor[1]), $(apex_anchor[2]), $(point_text[2]));

      set(point_2, 'HorizontalAlignment', 'center');
      set(point_2, 'VerticalAlignment', 'middle');
      set(point_2, 'Rotation', 180);
      set(point_2, 'FontWeight', 'bold');

      point_3 = text($(origin_anchor[1]), $(origin_anchor[2]), $(point_text[3]));

      set(point_3, 'HorizontalAlignment', 'center');
      set(point_3, 'VerticalAlignment', 'middle');
      set(point_3, 'Rotation', -60);
      set(point_3, 'FontWeight', 'bold');
    """)
  end

  eval_string(matlab_session, "hold off;")

  return matlab_session
end


"""
    plotTrajectoryInSquare(matlab_session, trajectory; transform=identity, kwargs...)

Draw `trajectory` in the plane, using the first two entries of
`transform(timepoint.state)` as x and y coordinates.

# Keywords
- `new_figure=true`: open a new square figure `canvas` with the axis limits below.
- `x_limits=(0, 1)`, `y_limits=(0, 1)`: axis limits (applied only with a new figure).
- `color₁="'Black'"`: line colour, or marker face colour for scatter plots.
- `color₂=color₁`: marker edge colour for scatter plots.
- `plot_type=PlotTypes.Plot`: `Plot` draws a line; `ScatterPlot` draws markers
  (`marker_symbol`, default `"'diamond'"`; `marker_size`, default 1). Any other value
  draws nothing.
- `line_width=1.5`: line width of the line or the marker outlines.

Colours and marker symbols are MATLAB expressions and must carry MATLAB quotes.
Returns `matlab_session`.
"""
function plotTrajectoryInSquare(matlab_session, trajectory; transform=identity, kwargs...)

  primary_color = get(kwargs, :color₁, "'Black'")
  secondary_color = get(kwargs, :color₂, primary_color)

  x_limits = get(kwargs, :x_limits, (0, 1))
  y_limits = get(kwargs, :y_limits, (0, 1))

  if get(kwargs, :new_figure, true)
    eval_string(matlab_session, """
      canvas = figure();

      pbaspect([1, 1, 1]);

      set(gca, 'XLim', [$(x_limits[begin]), $(x_limits[end])]);
      set(gca, 'YLim', [$(y_limits[begin]), $(y_limits[end])]);
    """)
  end

  eval_string(matlab_session, "hold on;")

  points = [transform(timepoint.state) for timepoint ∈ trajectory]
  x_data = [p[1] for p ∈ points]
  y_data = [p[2] for p ∈ points]

  plot_type = get(kwargs, :plot_type, PlotTypes.Plot)
  line_width = get(kwargs, :line_width, 1.5)

  # `nothing` when `plot_type` is neither supported kind, in which case nothing is drawn.
  drawing = if plot_type == PlotTypes.Plot
    """
      trajectory = plot($x_data, $y_data);

      set(trajectory, 'LineWidth', $line_width);
      set(trajectory, 'Color', $primary_color);
    """
  elseif plot_type == PlotTypes.ScatterPlot
    marker_symbol = get(kwargs, :marker_symbol, "'diamond'")
    marker_size = get(kwargs, :marker_size, 1)
    """
      trajectory = scatter($x_data, $y_data);

      set(trajectory, 'Marker', $marker_symbol);
      set(trajectory, 'LineWidth', $line_width);
      set(trajectory, 'MarkerEdgeColor', $secondary_color);
      set(trajectory, 'SizeData', $marker_size);
      set(trajectory, 'MarkerFaceColor', $primary_color);
    """
  end

  isnothing(drawing) || eval_string(matlab_session, drawing)

  eval_string(matlab_session, "hold off;")

  return matlab_session
end


"""
    plotHeatMap(matlab_session, colors::Matrix{Float64}; kwargs...)

Render `colors` as a heat map. Black contour lines are drawn first; a scaled image
(`imagesc`) with `AlphaData` 0.8 is drawn after them.

`colors[i, j]` is placed at `(x_j, y_i)`, where the x and y coordinates are evenly
spaced over `x_limits` and `y_limits`. `colors` therefore needs at least two rows
and two columns unless the corresponding limits coincide.

# Keywords
- `new_figure=true`: open a new square figure `canvas` with the axis limits below.
- `x_limits=(0, 1)`, `y_limits=(0, 1)`: coordinate ranges of the grid.
- `levels=15`: number of contour levels, evenly spaced from the minimum to the
  maximum of `colors`.
- `colormap="jet"`: MATLAB colormap expression, interpolated unquoted.

Returns `matlab_session`.
"""
function plotHeatMap(matlab_session, colors::Matrix{Float64}; kwargs...)

  x_limits = get(kwargs, :x_limits, (0, 1))
  y_limits = get(kwargs, :y_limits, (0, 1))

  # Meshgrid-style coordinate matrices with the same shape as `colors`:
  # x varies along columns, y along rows.
  x_ticks = LinRange(x_limits..., size(colors, 2))
  y_ticks = LinRange(y_limits..., size(colors, 1))
  x_data = [x for _ ∈ axes(colors, 1), x ∈ x_ticks]
  y_data = [y for y ∈ y_ticks, _ ∈ axes(colors, 2)]

  if get(kwargs, :new_figure, true)
    eval_string(matlab_session, """
      canvas = figure();

      pbaspect([1, 1, 1]);

      set(gca, 'XLim', [$(x_limits[begin]), $(x_limits[end])]);
      set(gca, 'YLim', [$(y_limits[begin]), $(y_limits[end])]);
    """)
  end

  eval_string(matlab_session, "hold on;")

  lowest, highest = extrema(colors)
  levels = collect(LinRange(lowest, highest, get(kwargs, :levels, 15)))
  palette = get(kwargs, :colormap, "jet")

  eval_string(matlab_session, """
    colormap($palette);

    [contour_map, contours] = contour($x_data, $y_data, $colors);
    set(contours, 'LevelList', $levels);
    set(contours, 'EdgeColor', 'Black');


    heatmap = imagesc([$(x_limits[begin]), $(x_limits[end])], [$(y_limits[begin]), $(y_limits[end])], $colors);
    
    set(heatmap, 'AlphaData', 0.8);


    hold off;
  """)

  return matlab_session
end


"""
    newFigure(fn, filename; kwargs...)

Open a fresh MATLAB session and call `fn(matlab_session)` to draw into it. Then
export the figure with `saveFigure(matlab_session, filename; kwargs...)` and close
the session.

`fn` comes first so the function can be used with `do`-block syntax:

    newFigure("my_plot") do matlab_session
      plotTrajectory(matlab_session, trajectory)
    end

The session is closed even when `fn` or `saveFigure` throws; the exception is then
rethrown. On success, the result of `close(matlab_session)` is returned.
"""
function newFigure(fn, filename; kwargs...)

  matlab_session = MSession()

  try
    fn(matlab_session)
    saveFigure(matlab_session, filename; kwargs...)
  catch
    # Close the session before propagating the error so a failed plot does not
    # leave it open.
    close(matlab_session)
    rethrow()
  end

  return close(matlab_session)
end



end