import numpy as np
from matplotlib import pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cm
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from matplotlib.ticker import FuncFormatter
import matplotlib.patches as patches
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# TODO (low priority) general refactor


def create_grid_plot(ax, grid, color_map="binary"):
    """Draw ``grid`` on ``ax`` as a cell image with black grid lines.

    Args:
        ax: Matplotlib axes to draw on.
        grid: Array indexed ``[row, column]``; ``grid.shape[0]`` rows and
            ``grid.shape[1]`` columns are drawn.
        color_map: Name of the matplotlib colormap. With ``"binary"`` the
            values are replaced by ``1 - grid`` before drawing.

    Returns:
        The image object returned by ``ax.matshow``.
    """
    if color_map == "binary":
        grid = 1 - grid
    n_rows, n_cols = grid.shape[0], grid.shape[1]
    # The colour scale starts at 0 and always reaches at least 1.
    upper = max(float(np.max(grid)), 1)

    image = ax.matshow(
        np.flip(grid, 0),
        cmap=plt.get_cmap(color_map),
        extent=[0, n_cols, 0, n_rows],
        vmin=0,
        vmax=upper,
    )

    # Major ticks sit on integer coordinates (cell borders), minor ticks on
    # the half-integer cell centres; major tick labels are hidden.
    ax.set_xticks(np.arange(0, n_cols))
    ax.set_xticks(np.arange(0.5, n_cols + 0.5), minor=True)
    plt.setp(ax.get_xmajorticklabels(), visible=False)
    ax.set_yticks(np.arange(0, n_rows))
    ax.set_yticks(np.arange(0.5, n_rows + 0.5), minor=True)
    ax.invert_yaxis()
    plt.setp(ax.get_ymajorticklabels(), visible=False)

    # Zero-length ticks: only the grid lines remain visible.
    ax.tick_params(axis='both', which='both', length=0)
    ax.grid(color="black")
    return image


def create_grid_plot_values(ax, grid, color_map, coords, probs):
    """Overlay per-cell values on a grid plot and label them as percentages.

    Cells of ``grid`` equal to 0 are masked. For every ``coords[i]``
    (``(row, column)``) that refers to an unmasked, non-zero cell, the cell
    value is replaced by ``probs[i]`` and the truncated integer percentage
    ``int(probs[i] * 100)`` is written at the cell centre.

    Args:
        ax: Matplotlib axes to draw on.
        grid: 2-D array; zero entries are masked out. The masked array is
            built without an explicit copy, so the assignments may also be
            visible in the caller's array.
        color_map: Colormap name forwarded to ``create_grid_plot``.
        coords: Sequence of ``(row, column)`` pairs.
        probs: Sequence of values aligned with ``coords``.

    Returns:
        The image object returned by ``create_grid_plot``.
    """
    masked = np.ma.masked_array(grid, grid == 0)
    for idx, cell in enumerate(coords):
        row, col = cell[0], cell[1]
        # Comparing a masked entry yields the masked constant, which is
        # falsy, so masked cells are skipped here as well.
        if masked[row, col] != 0:
            masked[row][col] = probs[idx]
            label = "{}".format(int(probs[idx] * 100))
            ax.text(col + 0.5, row + 0.5, label, horizontalalignment='center',
                    verticalalignment='center', color="black", fontsize=7)
    return create_grid_plot(ax, masked, color_map=color_map)


def plot_options(ax, probs, coords, grid, title_suffix="", colorbar_size='10%'):
    """Plot option-selection probabilities on top of the grid layout.

    Args:
        ax: Matplotlib axes to draw on.
        probs: Object exposing ``.numpy()`` (presumably a torch tensor --
            confirm with callers); one value per entry of ``coords``.
        coords: ``(row, column)`` pairs the probabilities belong to.
        grid: Grid layout; converted to a float array for the value overlay.
        title_suffix: Text appended to the axes title.
        colorbar_size: Width of the colorbar axes appended on the right.
    """
    # Base layer first, then the coloured value overlay on a float copy.
    create_grid_plot(ax, grid)
    value_grid = np.array(grid, float)
    image = create_grid_plot_values(ax, value_grid, "YlGn", coords, probs.numpy())

    colorbar_axes = make_axes_locatable(ax).append_axes("right", size=colorbar_size, pad=0.025)
    percent = FuncFormatter(lambda y, _: '{:.0%}'.format(y))
    plt.colorbar(image, cax=colorbar_axes, ax=ax, format=percent, ticks=np.arange(0.0, 1.1, 0.25))
    ax.set_title("Probability of choosing an option in states" + title_suffix)


def plot_terminations(ax, probs, coords, grid, title_suffix="", colorbar_size='10%'):
    """Plot termination probabilities on top of the grid layout.

    Args:
        ax: Matplotlib axes to draw on.
        probs: Object exposing ``.numpy()`` (presumably a torch tensor --
            confirm with callers); one value per entry of ``coords``.
        coords: ``(row, column)`` pairs the probabilities belong to.
        grid: Grid layout; converted to a float array for the value overlay.
        title_suffix: Text appended to the axes title.
        colorbar_size: Width of the colorbar axes appended on the right.
    """
    # Base layer first, then the coloured value overlay on a float copy.
    create_grid_plot(ax, grid)
    value_grid = np.array(grid, float)
    image = create_grid_plot_values(ax, value_grid, "OrRd", coords, probs.numpy())

    colorbar_axes = make_axes_locatable(ax).append_axes("right", size=colorbar_size, pad=0.025)
    percent = FuncFormatter(lambda y, _: '{:.0%}'.format(y))
    plt.colorbar(image, cax=colorbar_axes, ax=ax, format=percent, ticks=np.arange(0.0, 1.1, 0.25))
    ax.set_title("Termination probabilities in states" + title_suffix)


def plot_policy(ax, arrow_data, grid, title_suffix="", values=True, headwidth=9, headlength=20, colorbar_size='10%'):
    """Draw a quiver plot of per-state actions on top of the grid layout.

    Args:
        ax: Matplotlib axes to draw on.
        arrow_data: 5-tuple ``(x_pos, y_pos, x_dir, y_dir, color)`` of
            aligned sequences; ``color`` must provide ``.min()``/``.max()``.
        grid: Grid layout drawn underneath the arrows.
        title_suffix: Text appended to the axes title.
        values: If true, write ``color[i] * 100`` (truncated by ``%2d``)
            next to every arrow.
        headwidth: Arrow head width passed to ``ax.quiver``.
        headlength: Used for both ``headlength`` and ``headaxislength``.
        colorbar_size: Width of the colorbar axes appended on the right.
    """
    create_grid_plot(ax, grid)
    x_pos, y_pos, x_dir, y_dir, color = arrow_data

    quiv = ax.quiver(
        x_pos, y_pos, x_dir, y_dir, color,
        cmap=plt.get_cmap("viridis"),
        norm=colors.Normalize(vmin=color.min(), vmax=color.max()),
        angles='xy', scale_units='xy', scale=1, pivot='middle', clim=(0.3, 1),
        headwidth=headwidth, headaxislength=headlength, headlength=headlength,
    )
    divider = make_axes_locatable(ax)

    if values:
        for i, x in enumerate(x_pos):
            y = y_pos[i]
            # Shift the label off the arrow shaft: sideways when the arrow
            # has no x component, otherwise along y.
            if x_dir[i] == 0:
                x -= 0.25
            else:
                y -= 0.25
            ax.text(x, y, "%2d" % (color[i] * 100), horizontalalignment='center',
                    verticalalignment='center', color="black", fontsize=7)

    colorbar_axes = divider.append_axes("right", size=colorbar_size, pad=0.025)
    percent = FuncFormatter(lambda y, _: '{:.0%}'.format(y))
    plt.colorbar(quiv, cax=colorbar_axes, ax=ax, format=percent, ticks=np.arange(0.3, 1.1, 0.1))
    ax.set_title("Maximum likelihood actions in states" + title_suffix)


def _arrow_triangle(x, y, dx, dy, forward, backward, width):
    """Return the 3x2 vertex array of a triangle at (x, y) pointing along the axis-aligned (dx, dy)."""
    half_width = width / 2
    # For an axis-aligned direction, (|dy|, |dx|) is the perpendicular axis.
    perp_x, perp_y = abs(dy), abs(dx)
    base_x = x - dx * backward
    base_y = y - dy * backward
    return np.array([[x + dx * forward, y + dy * forward],
                     [base_x - perp_x * half_width, base_y - perp_y * half_width],
                     [base_x + perp_x * half_width, base_y + perp_y * half_width]])


def plot_policy_alternative(ax, arrow_data, grid, title_suffix="", values=True, colorbar_size='10%', fixed_colors=True):
    """Draw per-state actions as triangles whose size scales with ``color[i]``.

    Args:
        ax: Matplotlib axes to draw on.
        arrow_data: 5-tuple ``(x_pos, y_pos, x_dir, y_dir, color)`` of
            aligned sequences. Only the directions ``(1, 0)``, ``(-1, 0)``,
            ``(0, 1)`` and ``(0, -1)`` produce a triangle; other entries are
            skipped (but still labelled when ``values`` is true).
        grid: Grid layout drawn underneath the triangles.
        title_suffix: Text appended to the axes title.
        values: If true, write ``color[i] * 100`` (truncated by ``%2d``)
            next to every position.
        colorbar_size: Width of the colorbar axes appended on the right
            (only used when ``fixed_colors`` is false).
        fixed_colors: If true, each direction gets a fixed colour; otherwise
            triangles are coloured by ``color[i]`` through the viridis
            colormap and a 0-100% colorbar is added.
    """
    create_grid_plot(ax, grid)
    cmap = plt.get_cmap('viridis')
    x_pos, y_pos, x_dir, y_dir, color = arrow_data

    MAX_ARROW_WIDTH = 0.5
    MAX_ARROW_FORWARD = 0.4
    MAX_ARROW_BACKWARD = 0.5

    # (x direction, y direction, colour used when fixed_colors is set)
    directions = (
        (1, 0, 'limegreen'),   # +x
        (-1, 0, 'magenta'),    # -x
        (0, 1, 'b'),           # +y
        (0, -1, 'r'),          # -y
    )

    for i in range(len(x_pos)):
        weight = color[i]
        # All four directions use the float-converted position (the former
        # "-x" branch mixed in the raw y_pos[i] value).
        x = float(x_pos[i])
        y = float(y_pos[i])
        for dx, dy, fixed in directions:
            if x_dir[i] == dx and y_dir[i] == dy:
                c = fixed if fixed_colors else cmap(weight)
                vertices = _arrow_triangle(x, y, dx, dy,
                                           MAX_ARROW_FORWARD * weight,
                                           MAX_ARROW_BACKWARD * weight,
                                           MAX_ARROW_WIDTH * weight)
                ax.add_patch(patches.Polygon(vertices, edgecolor=c, facecolor=c))
                break

        if values:
            label_x = x_pos[i]
            label_y = y_pos[i]
            # Shift the label off the triangle: sideways when there is no
            # x component, otherwise along y.
            if x_dir[i] == 0:
                label_x -= 0.25
            else:
                label_y -= 0.25
            ax.text(label_x, label_y, "%2d" % (color[i] * 100), horizontalalignment='center',
                    verticalalignment='center', color="black", fontsize=7)

    if not fixed_colors:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size=colorbar_size, pad=0.025)
        sm = cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=1))
        # Public replacement for poking the private ``_A`` attribute.
        sm.set_array([])
        plt.colorbar(sm, cax=cax, ax=ax, format=FuncFormatter(lambda y, _: '{:.0%}'.format(y)), ticks=np.arange(0, 1.1, 0.25))

    ax.set_title("Maximum likelihood actions in states" + title_suffix)


def get_tf_data(dir: str, tag: str, timesteps: int, subdir_regex=None):
    """Load the scalar series ``tag`` from the event files found under ``dir``.

    Args:
        dir: Root directory handed to ``get_tf_paths``.
        tag: Scalar tag to read from each event file.
        timesteps: Only the first ``timesteps`` values of each series are kept.
        subdir_regex: Optional regular expression forwarded to ``get_tf_paths``.

    Returns:
        ``np.array`` of the values: the bare series when exactly one event
        file was found, otherwise one entry per event file (an empty array
        when none was found).
    """
    paths = get_tf_paths(dir=dir, subdir_regex=subdir_regex)

    series = []
    for path in paths:
        accumulator = EventAccumulator(path)
        accumulator.Reload()
        # Each scalar event unpacks into three fields; only the last one
        # (the value) is used.
        _, _, vals = zip(*accumulator.Scalars(tag))
        series.append(vals[:timesteps])

    # A single run is returned without the extra leading dimension.
    if len(paths) == 1:
        return np.array(series[0])
    return np.array(series)


def smooth(scalars, weight):  # Weight between 0 and 1
    """Exponentially smooth a sequence of values.

    Each output is ``previous * weight + (1 - weight) * point``, where
    ``previous`` starts at the first input value.

    Args:
        scalars: Sized, indexable sequence of numbers.
        weight: Smoothing factor between 0 and 1; 0 returns the input
            values unchanged, 1 repeats the first value.

    Returns:
        A list with one smoothed value per input value; an empty list for
        empty input.
    """
    # ``len`` rather than truthiness so that NumPy arrays are accepted too.
    if len(scalars) == 0:
        return []

    last = scalars[0]  # Seed with the first value (first timestep)
    smoothed = []
    for point in scalars:
        last = last * weight + (1 - weight) * point
        smoothed.append(last)
    return smoothed


def slidind_window_avg(vals, window_size):
    """Trailing moving average along the first axis.

    Element ``i`` of the result is the mean of the last ``window_size``
    values up to and including ``vals[i]``; for the first ``window_size - 1``
    elements, where the window is not yet full, it is the mean of all values
    seen so far.

    Args:
        vals: Array-like of numbers. Integer and boolean input is promoted
            to float so the averages are not truncated.
        window_size: Positive integer window length.

    Returns:
        ``np.ndarray`` with the same shape as ``vals`` and a floating
        (or complex) dtype.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window_size must be a positive integer")

    data = np.asarray(vals)
    if not np.issubdtype(data.dtype, np.inexact):
        # An integer result buffer would silently truncate every average.
        data = data.astype(float)

    n = data.shape[0]
    res = np.empty_like(data)
    if n == 0:
        return res

    csum = np.cumsum(data, axis=0)

    # Warm-up part: running mean over the values seen so far.
    head = min(window_size, n)
    counts = np.arange(1, head + 1).reshape((-1,) + (1,) * (data.ndim - 1))
    res[:head] = csum[:head] / counts

    # Full windows: difference of cumulative sums window_size apart.
    if n > window_size:
        res[window_size:] = (csum[window_size:] - csum[:-window_size]) / window_size
    return res


def get_tf_paths(dir="../../run_data/test", subdir_regex=None):
    """Collect the paths of all event files below ``dir``.

    A file counts as an event file when its name contains ``"tfevent"``.

    Args:
        dir: Root directory that is walked recursively.
        subdir_regex: Optional regular expression; when given, only paths
            for which ``re.search`` matches the full joined path are kept.

    Returns:
        List of matching file paths in ``os.walk`` order.
    """
    import os
    import re

    matcher = re.compile(subdir_regex) if subdir_regex is not None else None

    found = []
    for folder, _subdirs, filenames in os.walk(dir):
        candidates = (os.path.join(folder, name) for name in filenames if "tfevent" in name)
        found.extend(path for path in candidates if matcher is None or matcher.search(path))
    return found


if __name__ == '__main__':
    # Ad-hoc run on a 0..99 ramp: computes the sliding-window average and a
    # centred box-filter convolution of the same width. Nothing is printed
    # or checked; the results are only left in `b` and `c`.
    a = np.arange(100, dtype=float)
    b = slidind_window_avg(a, 10)
    c = np.convolve(a, np.ones(10) / 10, mode='same')
