import os.path
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cm
from matplotlib.table import Table


class Painter:
    """Render a grid-world solver's state values and policy with matplotlib.

    The solver is expected to expose (as used below) ``get_V_2D()`` returning a
    2-D value array, ``policy`` indexable as ``policy[s, :]``, ``env.grid``
    (cells equal to 1.0 are drawn as walls) and ``env.xytos(x, y)`` mapping a
    column/row pair to a state index.
    """

    def __init__(self, solver):
        self.solver = solver
        self._new_figure()

    def _new_figure(self):
        """Create a fresh figure/axes pair and an empty Table spanning the axes."""
        # squeeze=False keeps `axes` a 2-D array so [0, 0] always works.
        self.fig, self.axes = plt.subplots(nrows=1, ncols=1, figsize=(14, 12), squeeze=False)
        self.ax = self.axes[0, 0]
        self.ax.set_aspect('equal')
        self.ax.set_axis_off()
        self.tb = Table(self.ax, bbox=[0, 0, 1, 1])
        # True once the value table has been drawn; draw_policy then uses
        # axes-fraction coordinates instead of unit-cell coordinates.
        self.flag_table = False

    def determine_edge_color(self, rgb):
        """Return a text colour ('white' or 'black') that contrasts with ``rgb``.

        ``rgb`` is a sequence whose first three entries are channel values in
        [0, 1]. Brightness is the 0.299/0.587/0.114 weighted sum on a 0-255
        scale; fills darker than 150 get white text.
        """
        red = rgb[0] * 255
        green = rgb[1] * 255
        blue = rgb[2] * 255
        if (red * 0.299 + green * 0.587 + blue * 0.114) < 150:
            return 'white'
        return 'black'

    @staticmethod
    def _pcolor_mesh(nrows, ncols):
        """Return (X, Y) cell-corner arrays of shape (nrows+1, ncols+1) for pcolor."""
        return np.meshgrid(np.arange(ncols + 1), np.arange(nrows + 1))

    @staticmethod
    def _cell_center(i, j, nrows, ncols, in_table):
        """Return (x, y, cell_w, cell_h) for grid cell (row i, column j).

        Row 0 is placed at the top. In table mode cells are fractions of the
        unit square (matching the Table bbox [0, 0, 1, 1]); otherwise each cell
        is one data unit wide and tall.
        """
        if in_table:
            cell_w, cell_h = 1.0 / ncols, 1.0 / nrows
        else:
            cell_w, cell_h = 1.0, 1.0
        x = (j + 0.5) * cell_w
        y = (nrows - 1 - i + 0.5) * cell_h
        return x, y, cell_w, cell_h

    @staticmethod
    def _arrow_endpoints(x, y, action, prob, cell_w, cell_h):
        """Return ((head_x, head_y), (tail_x, tail_y)) for one policy arrow.

        Actions 0/1/2/3 point toward -y/+x/+y/-x in drawing coordinates, with
        length proportional to ``prob``. Any other action index yields a
        zero-length arrow. A deterministic action (prob == 1.0) is shifted
        back so the arrow spans the cell instead of starting at its centre.
        """
        directions = {0: (0.0, -1.0), 1: (1.0, 0.0), 2: (0.0, 1.0), 3: (-1.0, 0.0)}
        dx, dy = directions.get(action, (0.0, 0.0))
        head_x = x + dx * prob * 0.9 * cell_w
        head_y = y + dy * prob * 0.9 * cell_h
        tail_x, tail_y = x, y
        if prob == 1.0:
            head_x -= dx * 0.5 * cell_w
            head_y -= dy * 0.5 * cell_h
            tail_x -= dx * 0.4 * cell_w
            tail_y -= dy * 0.4 * cell_h
        return (head_x, head_y), (tail_x, tail_y)

    def draw_state_value_by_table(self, value_2d=None, min=None, max=None, show_values=False):
        """Draw state values as a colour-coded Table filling the axes, plus a colorbar.

        Args:
            value_2d: 2-D value array; defaults to ``solver.get_V_2D()``. Its
                shape is assumed to match ``solver.env.grid`` (TODO confirm).
            min, max: colour-scale limits. When None, they default to the
                floor/ceil of the data. (These names shadow the builtins but
                are kept for keyword compatibility.)
            show_values: when True, print each non-wall value (2 decimals) in
                its cell using a contrasting text colour.
        """
        ax = self.ax
        tb = self.tb
        tb.set_zorder(0)
        # Artists below zorder 1 (the table) are rasterized when saving.
        ax.set_rasterization_zorder(1)

        if value_2d is None:
            value_2d = self.solver.get_V_2D()

        nrows, ncols = value_2d.shape
        width, height = 1.0 / ncols, 1.0 / nrows

        self.vmin = np.floor(np.min(value_2d)) if min is None else min
        self.vmax = np.ceil(np.max(value_2d)) if max is None else max
        norm = colors.Normalize(vmin=self.vmin, vmax=self.vmax)
        # pyplot.get_cmap survives the removal of cm.get_cmap in Matplotlib 3.9.
        cmap = plt.get_cmap('jet')

        mappable = cm.ScalarMappable(cmap=cmap, norm=norm)
        # Public replacement for the old `_A = []` hack (older Matplotlib needs an array set).
        mappable.set_array([])
        cbar = self.fig.colorbar(mappable, ax=ax)
        ticks = np.linspace(norm.vmin, norm.vmax, 5)
        cbar.set_ticks(ticks)
        cbar.ax.set_yticklabels(['{0:.2f}'.format(s) for s in ticks])

        for (i, j), wall in np.ndenumerate(self.solver.env.grid):
            if wall == 1.0:
                facecolor = 'dimgrey'
                text = ''
                txt_color = 'black'
            else:
                # Normalize maps vmin == vmax to 0 instead of dividing by zero.
                facecolor = colors.to_rgb(cmap(norm(value_2d[i, j])))
                txt_color = self.determine_edge_color(facecolor)
                text = '{0:.2f}'.format(value_2d[i, j]) if show_values else ''

            cell = tb.add_cell(i, j, width, height, text=text, loc='center',
                               facecolor=facecolor, edgecolor='gray')
            if text:
                cell.get_text().set_color(txt_color)
                cell.get_text().set_weight('bold')

        ax.add_table(tb)
        self.flag_table = True

    def draw_state_value(self):
        """Draw state values as a pcolor heatmap with unit-sized cells and a colorbar.

        Row 0 of the value array is drawn at the top. This matches the table
        view and the cell centres used by draw_policy.
        """
        ax = self.ax
        value_2d = self.solver.get_V_2D()
        nrows, ncols = value_2d.shape

        X, Y = self._pcolor_mesh(nrows, ncols)
        # pcolor puts C[0] at y in [0, 1] (bottom); flip rows so row 0 lands on top.
        mesh = ax.pcolor(X, Y, value_2d[::-1], cmap='jet', edgecolors='k', linewidths=2)
        self.fig.colorbar(mesh, ax=ax)

    def draw_policy(self):
        """Overlay policy arrows on every non-wall cell.

        Each action with non-zero probability gets an arrow whose length and
        head size grow with that probability. Cells without a deterministic
        action (max probability != 1.0) also get a small black square marker.
        """
        ax = self.ax
        value_2d = self.solver.get_V_2D()
        policy = self.solver.policy
        nrows, ncols = value_2d.shape

        # Arrow widths are in points, so scale them by the fraction of the
        # figure a cell occupies, whatever the drawing mode.
        headsize = 130 * (1.0 / nrows + 1.0 / ncols) / 2.0

        for (i, j), wall in np.ndenumerate(self.solver.env.grid):
            if wall == 1.0:
                continue
            s = self.solver.env.xytos(j, i)
            pi = policy[s, :]
            x, y, cell_w, cell_h = self._cell_center(i, j, nrows, ncols, self.flag_table)

            for a, prob in enumerate(pi):
                if prob == 0.0:
                    continue
                head, tail = self._arrow_endpoints(x, y, a, prob, cell_w, cell_h)
                tip = headsize * np.sqrt(np.sqrt(prob))
                ax.annotate('', xy=head, xytext=tail,
                            arrowprops=dict(fc='black', ec='grey', width=headsize / 3,
                                            headwidth=tip, headlength=tip))
            if np.max(pi) != 1.0:
                ax.scatter(x, y, s=15.9, marker='s', c='black', edgecolor='gainsboro', zorder=100)

    def reset_plot(self):
        """Discard the current figure and start drawing on a fresh one."""
        # Close the old figure so pyplot does not keep every reset figure alive.
        plt.close(self.fig)
        self._new_figure()

    def save_grid(self, filename):
        """Save this painter's figure (transparent background) and print the filename."""
        self.fig.savefig(filename, transparent=True)
        print(filename)
