#!/usr/bin/env python3

import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import AutoMinorLocator, MaxNLocator

import latency_env
import latency_env.misc.argparser_types as at
from latency_env.misc import Argument as Arg
from latency_env.misc import ArgumentList as ArgList
from latency_env.misc import ArgumentParser
from latency_env.delayed_mdp import (
    delay_from_string,
    ConstantDelay,
    RandomCategoricalDelay,
)

# Default figure size in inches. The output-file logic further down compares
# args.width / args.height against these to decide whether to append a size
# tag to the saved file name.
DEFAULT_WIDTH = 8.0
DEFAULT_HEIGHT = 6.0

# Command-line interface. `delay` is declared without flag strings (so it is
# presumably positional) and is converted by delay_from_string into a delay
# object; this script calls reset(), sample() and distribution() on it.
parser = ArgumentParser()
parser += ArgList(
    delay = Arg(type=delay_from_string,
                help="String representation of the delay to plot"),
    # When omitted, the plot is shown interactively instead of saved.
    output_file = Arg("-o", "--output-file", type=str, default=None,
                      help="Specific output file name. (Default: show only)"),
    title = Arg("-t", "--title", type=str, default=None,
                help="Title of the generated plot."),
    # NOTE(review): sizes are parsed with at.posint although the defaults are
    # floats, so fractional inches cannot be requested — confirm intended.
    width = Arg("--width", type=at.posint, default=DEFAULT_WIDTH, help="Width of the figure in inches."),
    height = Arg("--height", type=at.posint, default=DEFAULT_HEIGHT, help="Height of the figure in inches."),
    fontsize = Arg("--fontsize", type=at.posint, default=None, help="Custom font size for all text."),
    # --min / --max bound the x-range of the distribution (bar) plot; when
    # omitted, the smallest / largest delay in the distribution is used.
    min = Arg("--min", type=at.posint, default=None,
                help="Minimum delay to plot."),
    max = Arg("--max", type=at.posint, default=None,
                help="Maximum delay to plot."),
    seed = Arg("--seed", type=at.nonnegint, default=32,
                help="The seed to use."),
    hbar = Arg("--hbar", type=at.posint, default=None,
                help="Include a red horizontal bar at this y value (--time-series only)."),
    # NOTE(review): args.always_sample is not read anywhere in this script as
    # visible — confirm whether it is consumed elsewhere or is dead.
    always_sample = Arg("--always-sample", action=argparse.BooleanOptionalAction, default=False,
                        help="Ignore the special cases and always sample from the distribution."),
    plot_time_series = Arg("--time-series", action=argparse.BooleanOptionalAction, default=False,
                           help="Visualize the distribution as a time-series plot instead."),
)
# Presumably adds the project's shared logging options to the parser.
parser += latency_env.training.utils.logging_arguments
args = parser.parse_args()

# Seed via the project helper so that sampled delays are reproducible for a
# given --seed.
latency_env.training.utils.seed_all(args.seed)

# Keyword arguments forwarded to every label call: carries a font size only
# when --fontsize was given, otherwise matplotlib's defaults apply.
label_kwargs = {} if args.fontsize is None else {"fontsize": args.fontsize}

# One set of axes on a figure sized (in inches) from the command line.
fig, ax = plt.subplots(nrows=1, ncols=1)
fig.set_size_inches((args.width, args.height))

# Number of delay samples drawn for the time-series view.
NUM_TIME_STEPS = 1000

if args.plot_time_series:
    # Reset the delay process, then draw one delay per time step and plot the
    # raw sequence of samples.
    args.delay.reset()
    samples = [args.delay.sample() for _ in range(NUM_TIME_STEPS)]

    X = np.arange(NUM_TIME_STEPS)
    Y = np.array(samples)

    ax.plot(X, Y)
    ax.set_xlabel("Time Step", **label_kwargs)
    ax.set_ylabel("Sampled Delay", **label_kwargs)

    if args.hbar is not None:
        # Reference line drawn as data extending 100 steps beyond both ends of
        # the series (past the x-bounds set below) so it spans the whole plot.
        hbar_x = np.arange(NUM_TIME_STEPS + 200) - 100
        ax.plot(hbar_x, np.full(hbar_x.shape, args.hbar), color=(1.0, 0.0, 0.0, 0.75))

    # Integer-only major ticks on the delay axis, 4 minor subdivisions each.
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.yaxis.set_minor_locator(AutoMinorLocator(n=4))

    # Leave a 50-step margin on either side of the series.
    ax.set_xbound(lower=-50, upper=NUM_TIME_STEPS + 50)
    ax.grid(which="major", alpha=0.8, axis="y")
    ax.grid(which="minor", alpha=0.3, axis="y")
else:
    # Bar chart of the probability of sampling each delay in [x_min, x_max].
    # Assumes distribution() returns a mapping delay -> probability (it is used
    # via .keys()/.get() below); delays absent from it are drawn as 0.
    probs = args.delay.distribution()
    print(f"Ticks: {probs}")

    x_min = args.min if args.min is not None else min(probs.keys())
    x_max = args.max if args.max is not None else max(probs.keys())
    if x_min > x_max:
        # Explicit check rather than `assert`, which is stripped under `python -O`
        # and is the wrong tool for validating user input.
        raise SystemExit(
            f"error: minimum delay {x_min} exceeds maximum delay {x_max} (check --min/--max)"
        )

    delays = range(x_min, x_max + 1)
    X_labels = [str(d) for d in delays]
    X = np.arange(len(delays))
    Y = np.array([probs.get(d, 0.0) for d in delays])

    ax.bar(X, Y, tick_label=X_labels)
    ax.set_xlabel("Delay (steps)", **label_kwargs)
    ax.set_ylabel("Probability", **label_kwargs)

    ax.yaxis.set_minor_locator(AutoMinorLocator(n=4))

    ax.grid(which="major", alpha=0.8, axis="y")
    ax.grid(which="minor", alpha=0.3, axis="y")

# Apply the optional --title to the plot.
if args.title is not None:
    ax.set_title(args.title, **label_kwargs)

# --fontsize is documented as applying to all text, so extend it to the tick
# labels as well as the axis labels and title.
if args.fontsize is not None:
    ax.tick_params(axis="both", which="both", labelsize=args.fontsize)

if args.output_file is not None:
    # Split off the extension with os.path.splitext rather than the last "."
    # anywhere in the string, so a dot in a directory name (e.g.
    # "./plots/delay") is not mistaken for an extension. Default to PDF.
    oname, oending = os.path.splitext(args.output_file)
    if not oending:
        oending = ".pdf"
    # Append a tag to the base name for each non-default size/font setting.
    if args.width != DEFAULT_WIDTH:
        oname += f"-w{int(args.width)}"
    if args.height != DEFAULT_HEIGHT:
        oname += f"-h{int(args.height)}"
    if args.fontsize is not None:
        oname += f"-fontsize{int(args.fontsize)}"
    outname = f"{oname}{oending}"
    plt.savefig(outname, bbox_inches="tight")
    # Report only after the write has actually succeeded.
    print(f"Saved plot to file: {outname}")
else:
    plt.show()
