import argparse
import os
from datetime import datetime

import matplotlib.pyplot as plt
import matplotlib.style as mplstyle
import numpy as np
from matplotlib.ticker import FuncFormatter
from plot.read_data import collect_data_for_evo_plot
from plot.utils import (GREAT_GNUPLOT_COLORS, custom_sort, format_func,
                        pathname_directory, pathname_name)
from tqdm import tqdm

# Command-line interface for raw_files_to_evo_plot. Every `dest` below must
# match a parameter of raw_files_to_evo_plot because the entry point calls it
# with **vars(args).
mplstyle.use("fast")

parser = argparse.ArgumentParser(
    description="Plot the evolution of experiment monitors over time."
)

parser.add_argument(
    "--exp-dir",
    required=True,
    type=str,
    help="Directory whose sub-directories contain the experiments to plot.",
)

parser.add_argument(
    "--exp-name",
    type=str,
    help="Name of the experiment.",
)

parser.add_argument(
    "--monitors",
    required=True,
    nargs="+",
    help="Names of the raw monitor files to plot for every experiment.",
)


parser.add_argument(
    "--title", default=None, type=str, help="The title of the plot. Default: no title."
)
parser.add_argument(
    "--captions",
    default=[],
    nargs="*",
    help="""Captions for the lines on the plot.
                            Default: use the filenames of the
                            raw data files.""",
)
parser.add_argument(
    "--average-windows",
    # A list of strings, like the values argparse produces for nargs="*".
    # argparse returns a default verbatim, so a bare int here would reach the
    # entry point's len()/int() parsing unconverted.
    default=["5000"],
    nargs="*",
    help="""The size of the windows used for the rolling mean.
                            Either specify a single number used for all lines,
                            or specify a separate number per line or specify
                            None if no averaging needs to be performed.""",
)
parser.add_argument(
    "--window-type",
    default="standard",
    type=str,
    choices=["standard", "half"],
    help="Rolling window type.",
)
parser.add_argument(
    "--average-mode",
    default="mean",
    type=str,
    choices=["mean", "median"],
    help="Mode used to compute the average across the series.",
)
parser.add_argument(
    "--error-bars",
    default="stdev",
    type=str,
    choices=["stdev", "min_max", "percentile"],
    help="Mode used to compute error bars across the series.",
)
parser.add_argument(
    "--percentiles",
    default=[5, 95],
    type=int,
    nargs=2,
    help="""When using the percentile error bars,
                            use this argument to specify the quantiles used for
                            the lower- and higher-end, respectively.""",
)
parser.add_argument(
    "--error-bar-modes",
    default="lines",
    type=str,
    choices=["filled", "lines"],
    help="Mode used to draw the error bars.",
)

parser.add_argument(
    "--graphic-type",
    default="png",
    type=str,
    choices=["pdf", "png", "svg", "ps"],
    help="Type of the graphic file.",
)

parser.add_argument(
    "--only-x-last-interactions",
    default=None,
    type=int,
    help="Only plot the x last interactions.",
)
parser.add_argument(
    "--start", default=None, type=int, help="Only plot the interactions after start."
)
parser.add_argument(
    "--end", default=None, type=int, help="Only plot the interactions before end."
)
parser.add_argument(
    "--series-numbers",
    default=None,
    type=int,
    nargs="*",
    help="Which series to plot. Series start from 1.",
)
parser.add_argument(
    "--key-location",
    default="best",
    type=str,
    choices=[
        "best",
        "upper right",
        "upper left",
        "lower right",
        "lower left",
        "center left",
        "center right",
        "lower center",
        "upper center",
        "center",
    ],
    help="Location of the legend on the plot.",
)
parser.add_argument(
    # float, not int: the default is 1.5 and fractional widths are common.
    "--line-width", default=1.5, type=float, help="Width of the lines on the plot."
)
parser.add_argument(
    "--dashed", dest="dashed", action="store_true", help="Whether to use dashed lines."
)
parser.add_argument(
    "--solid",
    dest="dashed",
    action="store_false",
    help="Whether not to use dashed lines",
)
parser.add_argument(
    "--colors",
    default=GREAT_GNUPLOT_COLORS,
    type=str,
    nargs="*",
    help="Colors to use for plotting the lines.",
)
parser.add_argument("--font-size", default=10, type=int, help="Font size on the plot.")
parser.add_argument(
    "--use-y-axis",
    default=None,
    type=int,
    choices=[1, 2],
    nargs="*",
    help="""Which y-axis to use for each line.
                            The left y-axis is denoted by 1.
                            The right y-axis is denoted by 2.""",
)
parser.add_argument(
    "--y1-min", default=None, type=float, help="Minimum value of the left y-axis."
)
parser.add_argument(
    "--y1-max", default=None, type=float, help="Maximum value of the left y-axis."
)
parser.add_argument(
    "--y2-min", default=None, type=float, help="Minimum value of the right y-axis."
)
parser.add_argument(
    "--y2-max", default=None, type=float, help="Maximum value of the right y-axis."
)
parser.add_argument(
    "--draw-y1-grid",
    dest="draw_y1_grid",
    action="store_true",
    help="Whether to draw the grid for the left y-axis.",
)
parser.add_argument(
    "--hide-y1-grid",
    dest="draw_y1_grid",
    action="store_false",
    help="Whether to hide the grid for the left y-axis.",
)
parser.add_argument(
    "--draw-y2-grid",
    dest="draw_y2_grid",
    action="store_true",
    help="Whether to draw the grid for the right y-axis.",
)
parser.add_argument(
    "--hide-y2-grid",
    dest="draw_y2_grid",
    action="store_false",
    help="Whether to hide the grid for the right y-axis.",
)
parser.add_argument(
    "--grid-line-width",
    default=0.5,
    type=float,
    help="Line width to use for drawing the grid.",
)
parser.add_argument(
    "--x-label", default="Number of games played", type=str, help="Label on the x-axis."
)
parser.add_argument(
    "--y1-label", default=None, type=str, help="Label of the left y-axis."
)
parser.add_argument(
    "--y2-label", default=None, type=str, help="Label on the right y-axis."
)
parser.add_argument(
    "--logscale",
    action="store_true",
    help="Whether to use a logarithmic scale for the x-axis.",
)
parser.add_argument(
    "--export", dest="export", action="store_true", help="Whether to export the file."
)
parser.add_argument(
    "--not-export",
    dest="export",
    action="store_false",
    help="Whether not to export the file.",
)
parser.add_argument(
    "--per-agent",
    default=None,
    type=int,
    help="Plot per agent, either None or the amount of agents in the population.",
)
parser.add_argument(
    "--every-x-interactions",
    default=None,
    type=int,
    help="Reduce amount of data points by plotting every x-th datapoint.",
)
parser.add_argument(
    "--plot-type",
    choices=["single", "multi"],
    required=True,
    type=str,
    help="'single' or 'multi'; selects the captions and the export directory.",
)

# Paired on/off flags share a dest; make every such default explicit here
# instead of relying on which of the two actions argparse registered first.
parser.set_defaults(dashed=True, draw_y1_grid=True, draw_y2_grid=False, export=True)


# Matplotlib line style for each supported monitor, looked up by the name of
# the monitor's raw data file. Despite the name, this is a mapping.
line_list = dict(
    [
        ("communicative-success", "dotted"),
        ("lexicon-coherence", "dashed"),
        ("unique-form-usage", "dashdot"),
    ]
)


def _collect_raw_file_paths(exp_dir, monitors):
    """Return one raw-data path per (experiment directory, monitor) pair.

    Every sub-directory of ``exp_dir`` is one experiment. If an experiment
    directory itself contains sub-directories, the data is read from the first
    of them in sorted order; otherwise from the experiment directory itself.
    Paths are grouped per experiment: path ``i`` belongs to monitor
    ``monitors[i % len(monitors)]``.
    """
    with os.scandir(exp_dir) as entries:
        experiment_dirs = custom_sort([e.path for e in entries if e.is_dir()])

    raw_file_paths = []
    for experiment_dir in experiment_dirs:
        with os.scandir(experiment_dir) as entries:
            # Sorted so the choice is deterministic: os.scandir yields entries
            # in arbitrary, filesystem-dependent order.
            nested_dirs = sorted(e.path for e in entries if e.is_dir())
        data_dir = nested_dirs[0] if nested_dirs else experiment_dir
        raw_file_paths.extend(os.path.join(data_dir, monitor) for monitor in monitors)
    return raw_file_paths


def _normalize_average_windows(average_windows, count):
    """Expand ``average_windows`` to one rolling-mean window per raw file.

    ``None`` disables averaging for every line. A single int is reused for
    every line. A list must already contain exactly ``count`` entries. Any
    other value is returned unchanged.

    Raises:
        ValueError: if a list of the wrong length is given.
    """
    if average_windows is None:
        return [None] * count
    if isinstance(average_windows, int):
        return [average_windows] * count
    if isinstance(average_windows, list) and len(average_windows) != count:
        raise ValueError(
            f"Length of average-windows ({len(average_windows)}) should be equal "
            f"to the number of raw files ({count}), or it can be a single number."
        )
    return average_windows


def raw_files_to_evo_plot(
    exp_dir,
    exp_name,
    monitors,
    plot_type,
    title=None,
    captions=None,
    average_windows=5000,
    window_type=None,
    average_mode="mean",
    error_bars="stdev",
    percentiles=None,
    error_bar_modes="lines",
    graphic_type="pdf",
    only_x_last_interactions=None,
    start=None,
    end=None,
    series_numbers=None,
    key_location="best",
    line_width=1.5,
    dashed=True,
    colors=None,
    font_size=10,
    use_y_axis=None,
    y1_min=None,
    y1_max=None,
    y2_min=None,
    y2_max=None,
    draw_y1_grid=True,
    draw_y2_grid=False,
    grid_line_width=0.5,
    x_label="Number of games played",
    y1_label=None,
    y2_label=None,
    logscale=False,
    export=True,
    per_agent=None,
    every_x_interactions=None,
):
    """Plot monitor evolution curves for every experiment under ``exp_dir``.

    For every experiment sub-directory one raw file per entry of ``monitors``
    is loaded through ``collect_data_for_evo_plot``. The figure has three
    stacked subplots sharing the x-axis: communicative success (top), lexicon
    coherence (middle) and unique form usage (bottom). Lines of one experiment
    share a colour, and monitors differ by line style (``line_list``). The band
    between ``average - error_low`` and ``average + error_high`` is shaded
    around every line.

    Args:
        exp_dir: Directory whose sub-directories are the experiments.
        monitors: Raw file names to load per experiment. Each must be a key of
            ``line_list``.
        plot_type: ``"single"`` or ``"multi"``. Selects the captions and the
            export directory.
        title: Optional figure title.
        captions: Legend labels per line, used for ``"single"`` only. Missing
            entries fall back to the raw file names. For ``"multi"`` the
            captions are the upper-cased 7th component of each raw file path.
        average_windows: ``None``, a single int, or one entry per raw file.
        percentiles: Defaults to ``[5, 95]``.
        colors: Palette with one colour per experiment, cycled if it is too
            short. An empty/``None`` value means ``GREAT_GNUPLOT_COLORS``.
        y1_min, y1_max: y-limits of the top subplot.
        y2_min, y2_max: y-limits of the middle subplot.
        draw_y1_grid: Set fixed y-ticks and draw a horizontal grid on all
            three subplots.
        export: Save the figure under ``<ancestor of exp_dir>/plots/<plot_type>``.

        window_type, average_mode, error_bars, only_x_last_interactions, start,
        end, series_numbers, per_agent and every_x_interactions are forwarded
        to ``collect_data_for_evo_plot``. exp_name, error_bar_modes,
        key_location, dashed, font_size, use_y_axis, draw_y2_grid, y1_label and
        y2_label are accepted so the function can be called with
        ``**vars(args)``, but are currently unused.

    Returns:
        The matplotlib figure. It is not closed, so the caller can show or
        close it.

    Raises:
        ValueError: on an unknown ``plot_type``, an empty ``monitors`` list, a
            monitor without a subplot, or an ``average_windows`` list of the
            wrong length.
    """
    if plot_type not in ("single", "multi"):
        raise ValueError(f"plot_type must be 'single' or 'multi', got {plot_type!r}")
    if not monitors:
        raise ValueError("At least one monitor must be given.")
    if captions is None:
        captions = []
    if percentiles is None:
        percentiles = [5, 95]
    if not colors:
        colors = GREAT_GNUPLOT_COLORS

    raw_file_paths = _collect_raw_file_paths(exp_dir, monitors)
    n_monitors = len(monitors)
    n_experiments = len(raw_file_paths) // n_monitors

    # Subplot and line style are chosen from the raw file's name; reject
    # unsupported monitors before the (expensive) data loading.
    monitor_names = [pathname_name(path) for path in raw_file_paths]
    unsupported = sorted(set(monitor_names).difference(line_list))
    if unsupported:
        raise ValueError(
            f"Unsupported monitor(s) {unsupported}; expected one of {sorted(line_list)}."
        )

    average_windows = _normalize_average_windows(average_windows, len(raw_file_paths))

    if plot_type == "multi":
        # NOTE(review): assumes path component 6 is the experiment name, which
        # only holds for one specific absolute depth of exp_dir.
        captions = [path.split("/")[6].upper() for path in raw_file_paths]
    else:
        # Missing captions fall back to the raw file names instead of raising
        # IndexError when plotting.
        captions = list(captions) + monitor_names[len(captions):]

    # One colour per experiment, repeated for each of its monitors.
    palette = [colors[k % len(colors)] for k in range(n_experiments)]
    line_colors = custom_sort(palette * n_monitors)

    data_set = collect_data_for_evo_plot(
        raw_file_paths,
        only_x_last_interactions=only_x_last_interactions,
        start=start,
        end=end,
        series_numbers=series_numbers,
        window_type=window_type,
        windows=average_windows,
        average_mode=average_mode,
        error_mode=error_bars,
        percentiles=percentiles,
        per_agent=per_agent,
        every_x_interactions=every_x_interactions,
    )

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(5, 6), sharex=True)
    axis_for_monitor = {
        "communicative-success": ax1,
        "lexicon-coherence": ax2,
        "unique-form-usage": ax3,
    }

    handlers = []
    for i, data in tqdm(enumerate(data_set), total=len(data_set)):
        monitor_name = monitor_names[i]
        ax = axis_for_monitor[monitor_name]
        (line,) = ax.plot(
            data["average"],
            linestyle=line_list[monitor_name],
            color=line_colors[i],
            label=captions[i],
            linewidth=line_width,
        )
        handlers.append(line)
        ax.fill_between(
            data.index.values.tolist(),
            y1=data["average"] - data["error_low"],
            y2=data["average"] + data["error_high"],
            alpha=0.25,
            linewidth=0,
            color=line_colors[i],
        )

    ax1.set_xmargin(0)
    ax1.xaxis.set_major_formatter(FuncFormatter(format_func))
    if title:
        fig.suptitle(title)

    ax1.set_ylim(bottom=y1_min)
    ax1.set_ylim(top=y1_max)
    ax2.set_ylim(bottom=y2_min)
    ax2.set_ylim(top=y2_max)

    if draw_y1_grid:
        ratio_ticks = np.arange(0, 1.01, 0.25)
        ax1.set_yticks(ratio_ticks)
        ax2.set_yticks(ratio_ticks)
        ax3.set_yticks(np.arange(0, 201, 50))
        for ax in (ax1, ax2, ax3):
            ax.grid(True, axis="y", linewidth=grid_line_width)

    ax3.set_xlabel(x_label)

    ax1.set_ylabel("Comm. success")
    ax2.set_ylabel("Ling. coherence")
    ax3.set_ylabel("Avg. inventory size")

    if logscale:
        ax1.set_xscale("log")

    # One legend entry per experiment: every n_monitors-th line starts the
    # next experiment (lines share a colour within an experiment).
    ax3.legend(
        handles=handlers[::n_monitors],
        labels=captions[: len(handlers) : n_monitors],
        loc="upper center",
        bbox_to_anchor=(0.5, -0.3),
        fancybox=False,
        shadow=False,
        ncol=3,
    )

    if export:
        # "single" plots are stored three directory levels above exp_dir,
        # "multi" plots two levels above it.
        levels_up = 3 if plot_type == "single" else 2
        base_dir = exp_dir
        for _ in range(levels_up):
            base_dir = pathname_directory(base_dir)
        output_dir = os.path.join(base_dir, "plots", plot_type)
        os.makedirs(output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y-%m-%d_%Hh%Mm%Ss")
        output_path = os.path.join(
            output_dir, f"{os.path.basename(exp_dir)}-{timestamp}.{graphic_type}"
        )
        fig.savefig(output_path, format=graphic_type, bbox_inches="tight")

    return fig


if __name__ == "__main__":
    args = parser.parse_args()
    if len(args.average_windows) == 1:
        if args.average_windows[0] == "None":
            args.average_windows = None
        else:
            args.average_windows = int(args.average_windows[0])
    else:
        args.average_windows = [
            None if x == "None" else int(x) for x in args.average_windows
        ]
    raw_files_to_evo_plot(**vars(args))
