#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : plot.py
# Author : Anonymous1
# Email  : anonymous1@anon
#
# Distributed under terms of the MIT license.

import argparse
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fontm
import copy
import sys
from collections import defaultdict
from itertools import chain
import six


def register_plot_args(parser):
    """Add the plotting command-line options to ``parser`` (in place).

    All options go into a single "plotting" argument group. Option names and
    defaults form the CLI interface; ``main`` and ``do_plot`` read them back
    from the parsed namespace.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.
    """
    plotting = parser.add_argument_group("plotting")
    plotting.add_argument(
        "--column",
        "-c",
        help="comma-separated description of each data column, for example "
        "'x,y,y'. Each entry starts with 'x' or 'y'; plot attributes separated "
        "by ';' may be appended after 'y', like 'ythick;cr' (thick line, "
        "colour r) or 'ydash'. By default, all columns are treated as 'y'.",
    )
    plotting.add_argument("--output", "-o", default="plot.png", help="output image")
    plotting.add_argument("--title", "-tt", default="", help="title of the figure")
    # str is the text type on Python 3 (what six.text_type resolved to).
    plotting.add_argument("--xlabel", "-xl", type=str, help="x label")
    plotting.add_argument("--ylabel", "-yl", type=str, help="y label")
    plotting.add_argument("--xlim", "-xlim", type=float, nargs=2, help="x lim")
    plotting.add_argument("--ylim", "-ylim", type=float, nargs=2, help="y lim")
    # Keep 98 percent of the data points inside the automatic y range by
    # default, which hides extreme outliers for better visualization.
    plotting.add_argument(
        "--ylim-ratio",
        "-yr",
        type=float,
        default=0.98,
        help="fraction of y data points kept inside the automatic y range "
        "(<= 0 disables the automatic y range)",
    )
    plotting.add_argument(
        "--legend-size", "-lgds", type=float, default=8, help="legend size"
    )
    plotting.add_argument(
        "--circle-size", "-cs", type=float, default=30, help="circle size"
    )
    plotting.add_argument("--scale", "-sc", help="scale of each y, separated by comma")
    plotting.add_argument(
        "--annotate-maximum",
        "-anomax",
        action="store_true",
        help="annotate maximum value in graph",
    )
    plotting.add_argument(
        "--annotate-minimum",
        "-anomin",
        action="store_true",
        help="annotate minimum value in graph",
    )
    plotting.add_argument("--xkcd", action="store_true", help="xkcd style")
    plotting.add_argument("--log-scale", "-log", action="store_true", help="log scale")
    plotting.add_argument(
        "--decay",
        "-dec",
        type=float,
        default=0,
        help="exponential decay rate to smooth Y",
    )
    # The option name keeps its historical spelling for CLI compatibility.
    plotting.add_argument("--delimeter", help="column delimiter", default="\t")
    plotting.add_argument(
        "--plot_minmax", "-mm", action="store_true", help="plot minmax"
    )


def filter_valid_range(points, rect):
    """Return the points lying inside an axis-aligned rectangle.

    ``rect`` is ``(min_x, max_x, min_y, max_y)``; bounds are inclusive.
    If no point qualifies, a one-element list holding ``points[0]`` is
    returned instead, so the result is never empty for non-empty input.
    """
    inside = [
        (x, y)
        for x, y in points
        if rect[0] <= x <= rect[1] and rect[2] <= y <= rect[3]
    ]
    return inside if inside else [points[0]]


def exponential_smooth(data, alpha):
    """Return an exponentially smoothed copy of ``data``.

    Recurrence: ``ret[0] = data[0]`` and
    ``ret[k] = alpha * ret[k - 1] + (1 - alpha) * data[k]``; ``alpha`` is the
    weight given to the previous smoothed value. The input is not modified.

    Args:
        data: sequence or numpy array; smoothing runs along the first axis.
        alpha: decay rate (weight of the previous smoothed value).

    Returns:
        A new numpy array. Integer/boolean input is promoted to float64 so the
        smoothed values are not truncated; floating input keeps its dtype.
        Empty input yields an empty array.
    """
    arr = np.asarray(data)
    if np.issubdtype(arr.dtype, np.inexact):
        ret = arr.copy()
    else:
        # A plain copy would keep an integer dtype and silently truncate
        # every smoothed value written back into it.
        ret = arr.astype(np.float64)
    if ret.size == 0:
        return ret
    now = ret[0]
    for k in range(len(ret)):
        # ret[k] still holds the (converted) raw data[k] at this point.
        ret[k] = now * alpha + ret[k] * (1 - alpha)
        now = ret[k]
    return ret


def annotate_min_max(data_x, data_y, ax, args):
    """Annotate the maximum and/or minimum point of one curve on ``ax``.

    Which annotations are drawn is controlled by ``args.annotate_maximum`` and
    ``args.annotate_minimum``. The first occurrence of each extreme is used.
    The label is offset from the point by 5% of the x range and 2.5% of the
    y range. Of the candidate offsets, the first one inside the current axis
    limits (``ax.axis()``) is used, falling back to the first candidate.

    Args:
        data_x: x coordinates of the curve (same length as ``data_y``).
        data_y: y coordinates of the curve; must be non-empty.
        ax: matplotlib Axes to draw on.
        args: namespace with boolean ``annotate_maximum`` / ``annotate_minimum``.
    """
    x_range = max(data_x) - min(data_x)
    y_range = max(data_y) - min(data_y)

    # Both extremes start at the first point. The x of the maximum must come
    # from data_x (it was previously seeded from data_y by mistake).
    x_max, y_max = data_x[0], data_y[0]
    x_min, y_min = data_x[0], data_y[0]
    for x, y in zip(data_x[1:], data_y[1:]):
        if y > y_max:
            x_max, y_max = x, y
        if y < y_min:
            x_min, y_min = x, y

    dx = 0.05 * x_range
    dy = 0.025 * y_range
    rect = ax.axis()

    def _annotate(label, px, py, dy_first):
        # Try the vertical side given by dy_first (above the maximum, below
        # the minimum) first, then the opposite side.
        candidates = [
            (px + dx, py + dy_first),
            (px - dx, py + dy_first),
            (px + dx, py - dy_first),
            (px - dx, py - dy_first),
        ]
        text_x, text_y = filter_valid_range(candidates, rect)[0]
        ax.annotate(
            label,
            xy=(px, py),
            xytext=(text_x, text_y),
            arrowprops=dict(arrowstyle="->"),
        )

    if args.annotate_maximum:
        _annotate("maximum ({:d},{:.3f})".format(int(x_max), y_max), x_max, y_max, dy)
    if args.annotate_minimum:
        _annotate("minimum ({:d},{:.3f})".format(int(x_min), y_min), x_min, y_min, -dy)


def plot_args_from_column_desc(desc):
    """Translate a ';'-separated column attribute string into plot kwargs.

    Recognised tokens: ``thick`` sets line width 3 (otherwise 0.8), ``dash``
    sets a dashed line style, and ``c<colour>`` sets the colour (the last
    such token wins). Unknown tokens are ignored.

    >>> plot_args_from_column_desc("thick;cr")
    {'lw': 3, 'color': 'r'}
    """
    tokens = desc.split(";")
    kwargs = {"lw": 3 if "thick" in tokens else 0.8}
    if "dash" in tokens:
        kwargs["ls"] = "--"
    colours = [tok[1:] for tok in tokens if tok.startswith("c")]
    if colours:
        kwargs["color"] = colours[-1]
    return kwargs


def do_plot(data_xs, data_ys, args, legends):
    """Draw every y series on one new figure and save it to ``args.output``.

    data_xs: list of 1d arrays, either of size 1 (one x shared by every
        series) or of size len(data_ys) (one x per series).
    data_ys: list of dicts, one per series. Each must hold a "mean" array and
        may hold "min", "max" and "std" arrays of the same length. The dicts
        are only read, never modified.
    args: parsed namespace from ``register_plot_args`` that also carries
        ``y_column`` (set by ``main``).
    legends: optional list of legend labels, one per series.
    """
    fig = plt.figure(figsize=(16.18 / 1.2, 10 / 1.2))
    ax = fig.add_axes((0.1, 0.2, 0.8, 0.7))
    nr_y = len(data_ys)
    y_column = args.y_column

    if args.scale:
        # Build a real list. On Python 3, map() is a lazy iterator without
        # len(), which made the length check raise TypeError.
        scale = [float(s) for s in args.scale.split(",")]
        assert len(scale) == nr_y, "--scale needs {} values, got {}".format(
            nr_y, len(scale)
        )
    else:
        scale = [1.0] * nr_y

    minx, maxx = None, None
    all_datapoints_y = []
    for yidx in range(nr_y):
        plotargs = plot_args_from_column_desc(y_column[yidx][1:])
        now_scale = scale[yidx]
        raw_y = data_ys[yidx]
        assert isinstance(raw_y, dict)
        assert "mean" in raw_y

        # Scale into a new dict. This leaves the caller's data untouched, so
        # repeated calls do not compound the scale factor.
        data_y = {key: np.asarray(val) * now_scale for key, val in raw_y.items()}
        mean = data_y["mean"]

        leg = legends[yidx] if legends else None
        if now_scale != 1:
            leg = "{}*{}".format(
                now_scale if int(now_scale) != now_scale else int(now_scale), leg
            )

        data_x = data_xs[0] if len(data_xs) == 1 else data_xs[yidx]
        if minx is None:
            minx, maxx = min(data_x), max(data_x)
        else:
            minx, maxx = min(minx, min(data_x)), max(maxx, max(data_x))
        assert len(data_x) >= len(
            mean
        ), "x column is shorter than y column! {} < {}".format(len(data_x), len(mean))
        truncate_data_x = data_x[: len(mean)]
        p = plt.plot(truncate_data_x, mean, label=leg, **plotargs)
        c = p[0].get_color()

        if args.plot_minmax:
            # min is drawn dotted and max dashed, both in the mean's colour.
            for key, linestyle in (("min", ":"), ("max", "--")):
                if key in data_y:
                    plt.plot(
                        truncate_data_x,
                        data_y[key],
                        **dict(plotargs, ls=linestyle, color=c)
                    )

        if "std" in data_y:
            std = data_y["std"]
            plt.fill_between(
                truncate_data_x, mean - std, mean + std, alpha=0.1, facecolor=c
            )

        # Mark the last point with a circle when y is shorter than x.
        cur = len(mean)
        if cur < len(data_x):
            plt.scatter(data_x[cur - 1], mean[-1], s=args.circle_size, c=c)

        if args.annotate_maximum or args.annotate_minimum:
            annotate_min_max(truncate_data_x, mean, ax, args)

        all_datapoints_y.extend(mean)

    all_datapoints_y = sorted(all_datapoints_y)
    num_datapoints = len(all_datapoints_y)
    if args.ylim_ratio > 0.0 and num_datapoints > 0:
        ratio = args.ylim_ratio
        # Keep about `ratio` of the sorted points in view. With smaller_better
        # the largest values are trimmed, otherwise the smallest. The kept
        # extreme is padded by 1% of the span (no padding on a log scale).
        extra = 0 if args.log_scale else 0.01
        if getattr(args, "smaller_better", False):
            ylow = all_datapoints_y[0]
            yhigh = all_datapoints_y[int((num_datapoints - 1) * ratio)]
            ylow -= (yhigh - ylow) * extra
        else:
            ylow = all_datapoints_y[int((num_datapoints - 1) * (1 - ratio))]
            yhigh = all_datapoints_y[-1]
            yhigh += (yhigh - ylow) * extra
        plt.ylim(ylow, yhigh)

    if args.xlabel:
        plt.xlabel(args.xlabel, fontsize="20")
    if args.ylabel:
        plt.ylabel(args.ylabel, fontsize="20")
    if args.ylim:
        plt.ylim(args.ylim[0], args.ylim[1])
    plt.legend(loc="best", fontsize=args.legend_size)

    if args.xlim:
        # An explicit --xlim wins over the automatic x range.
        plt.xlim(args.xlim[0], args.xlim[1])
    elif minx is not None:
        # Pad the right edge by 5% of the combined x span.
        plt.xlim(minx, maxx + (maxx - minx) * 0.05)

    for label in chain.from_iterable([ax.get_xticklabels(), ax.get_yticklabels()]):
        label.set_fontproperties(fontm.FontProperties(size=15))

    ax.grid(color="gray", linestyle="dashed")
    if args.log_scale:
        ax.set_yscale("log")

    plt.title(args.title, fontdict={"fontsize": "30"})
    if args.output != "":
        plt.savefig(args.output, bbox_inches="tight")


def main(data_xs, data_ys, args, legends=None):
    """Validate the column description, optionally smooth the data, then plot.

    Derives ``args.y_column``, ``args.y_column_idx``, ``args.x_column`` and
    ``args.x_column_idx`` from ``args.column`` (all ``'y'`` when unset). When
    ``args.decay`` is non-zero, every non-empty series in each dict of
    ``data_ys`` is replaced in place by its exponentially smoothed version.
    Finally the data is handed to ``do_plot``.
    """
    nr_column = len(data_ys)

    column = (
        ["y"] * nr_column
        if args.column is None
        else args.column.strip().split(",")
    )
    for desc in column:
        assert desc[0] in ["x", "y"]
    assert nr_column == len(
        column
    ), "Column and data doesn't have same length. {}!={}".format(nr_column, len(column))

    y_positions = [pos for pos, desc in enumerate(column) if desc[0] == "y"]
    x_positions = [pos for pos, desc in enumerate(column) if desc[0] == "x"]
    args.y_column = [column[pos] for pos in y_positions]
    args.y_column_idx = y_positions
    args.x_column = [column[pos] for pos in x_positions]
    args.x_column_idx = x_positions

    if len(x_positions) > 1:
        assert len(x_positions) == len(
            y_positions
        ), "If multiple x columns are used, nr_x_column must equals to nr_y_column"

    for series_dict in data_ys:
        if args.decay == 0:
            continue
        for key, series in list(series_dict.items()):
            if len(series) > 0:
                series_dict[key] = exponential_smooth(series, args.decay)

    do_plot(data_xs, data_ys, args, legends)


if __name__ == "__main__":
    # Demo: plot three series of random batch statistics against x = 0..n-1.
    parser = argparse.ArgumentParser()
    register_plot_args(parser)
    args = parser.parse_args()
    n = 50
    data_xs = [list(range(n))]

    def get_rand_y(m):
        """Return mean/min/max/std arrays over n batches of 10 random ints below m."""
        batches = [np.random.randint(m, size=10) for _ in range(n)]
        return {
            "mean": np.array([np.mean(batch) for batch in batches]),
            "min": np.array([np.min(batch) for batch in batches]),
            "max": np.array([np.max(batch) for batch in batches]),
            "std": np.array([np.std(batch) for batch in batches]),
        }

    data_ys = [get_rand_y(upper) for upper in (5, 10, 15)]
    main(data_xs, data_ys, args)
