from brl.utils import *


class LinePlot(object):
    '''
    A single matplotlib figure showing one curve that is grown point by point
    (add_point) plus any number of complete "baseline" curves (add_line), each
    optionally with a shaded interval estimate of y.

    The figure can be shown on screen (blocking or non-blocking) or saved to
    disk; saving also writes the plot data as a pickle ('<filename>.pkl') and
    as a human-readable text dump ('<filename>.csv').
    '''

    def __init__(self,
                 x_label, y_label,  # names of the x-axis variable and the y-axis variable
                 xlim=None, ylim=None,  # displayed value ranges (for x and y)
                 title='',  # displayed title of the plot
                 label='new',  # name (in legend) of the curve that is point-by-point added
                 ):
        # format
        self.x_label = x_label
        self.y_label = y_label
        self.xlim = xlim
        self.ylim = ylim
        self.fig_title = title
        self.label = label

        # data: start from the same empty state that clear_data() restores
        self.clear_data()

        # handles of GUI, created immediately by _new_window()
        self.fig, self.ax = None, None
        self._new_window()

    def add_point(self,
                  x, y,
                  y_low=None, y_high=None,  # lower and upper bounds to specify an interval estimate of y
                  y_range: tuple = None,  # alternative way to specify an interval estimate of y
                  ):
        '''
        Append one point to the point-by-point curve.

        The interval estimate of y is optional: give it either as scalar
        y_low / y_high, or as y_range=(low, high), which overrides y_low / y_high.
        Bounds that are not given are stored as None.
        '''
        # rejects e.g. a (low, high) tuple passed positionally in place of y_low
        assert y_low is None or isinstance(y_low, numbers.Number)
        if y_range is not None:
            y_low, y_high = y_range[0], y_range[1]

        self.x.append(x)
        self.y.append(y)
        self.y_low.append(y_low)
        self.y_high.append(y_high)

    def add_line(self,
                 x, y,
                 y_low=None, y_high=None,
                 y_range: tuple = None,
                 label=None
                 ):
        '''
        Add a complete baseline curve.

        x, y:          equal-length sequences of coordinates.
        y_low, y_high: optional equal-length sequences of interval bounds;
                       give both or neither.
        y_range:       optional sequence of (low, high) pairs; overrides
                       y_low / y_high when given.
        label:         legend name; defaults to 'baseline<k>', where k is the
                       1-based index of this baseline.
        '''
        # an empty y_low is allowed (empty baseline), so guard the [0] lookup
        assert y_low is None or len(y_low) == 0 or isinstance(y_low[0], numbers.Number)
        if y_range is not None:
            y_low = [bounds[0] for bounds in y_range]
            y_high = [bounds[1] for bounds in y_range]
        elif y_low is None:
            assert y_high is None
            # two distinct lists so the stored lower and upper bounds never alias
            y_low = [None] * len(x)
            y_high = [None] * len(x)
        else:
            assert y_high is not None, 'y_high must be given together with y_low'
        assert len(x) == len(y) == len(y_low) == len(y_high)
        if label is None:
            label = 'baseline' + str(len(self.baselines) + 1)

        self.baselines.append(
            {'x': x,
             'y': y,
             'y_low': y_low,
             'y_high': y_high,
             'label': label}
        )

    def load_data(self, filename):
        '''
        Replace this plot's format settings and data with those stored in a
        '<name>.pkl' file written by output() / _save().

        This instance keeps its own figure and axes.
        WARNING: unpickling can execute arbitrary code - only load trusted files.
        '''
        with open(filename, 'rb') as fp:
            data = pickle.load(fp)
        self.fig_title = data.fig_title
        self.x_label, self.y_label = data.x_label, data.y_label
        self.xlim, self.ylim = data.xlim, data.ylim
        self.label = data.label
        self.x, self.y, self.y_low, self.y_high = data.x, data.y, data.y_low, data.y_high
        self.baselines = data.baselines

    def clear_data(self):
        '''Remove the point-by-point curve and all baselines (format settings are kept).'''
        self.x = []
        self.y = []
        self.y_low = []
        self.y_high = []
        self.baselines = []

    def output(self,
               filename='screen',  # path to save the figure as well as the associated data (in both binary and human-readable format) to file system
               block=False  # set block = True will display the figure in blocking mode (if filename is 'screen'); has no effect if filename is not 'screen'
               ):
        '''
        By default, plot all stored curves into this object's own figure and
        show it on screen without blocking. Pass a path as 'filename' to save
        the figure and its data to the file system instead, or set 'block' to
        True for a blocking on-screen display.

        Baselines are drawn in colors C0, C1, ... in insertion order, and the
        point-by-point curve uses the next color. Interval estimates are
        shaded, and points whose bounds are None are left as gaps. A legend is
        drawn only when at least two curves are present.
        '''
        to_screen = filename == 'screen'
        if to_screen and not plt.fignum_exists(self.fig.number):
            # our figure has been closed (e.g. its window was closed); draw into a new one
            self._new_window()

        if to_screen and block is False:
            # NOTE(review): presumably lets the GUI process pending events before the redraw - confirm
            plt.draw()
            plt.pause(0.001)

        self.ax.clear()
        self.ax.set_xlabel(self.x_label)
        self.ax.set_ylabel(self.y_label)
        if self.xlim is not None:
            self.ax.set_xlim(self.xlim)
        if self.ylim is not None:
            self.ax.set_ylim(self.ylim)
        self.fig.suptitle(self.fig_title, wrap=True)
        self.ax.grid()

        for idx, line in enumerate(self.baselines):
            color = 'C' + str(idx)
            self.ax.plot(line['x'], line['y'], label=line['label'], color=color)
            self._fill_interval(line['x'], line['y_low'], line['y_high'], color)

        assert len(self.x) == len(self.y) == len(self.y_low) == len(self.y_high)
        if len(self.x) > 0:
            color = 'C' + str(len(self.baselines))
            self.ax.plot(self.x, self.y, label=self.label, color=color)
            self._fill_interval(self.x, self.y_low, self.y_high, color)

        n_curves = len(self.baselines) + (1 if len(self.x) > 0 else 0)
        if n_curves >= 2:
            self.ax.legend()

        if to_screen:
            if block:
                plt.show(block=True)
            else:
                plt.draw()
                plt.pause(0.001)
        else:
            self._save(filename)

    def show(self, block=False):
        '''Draw all curves on screen; see output().'''
        self.output(filename='screen', block=block)

    def __del__(self):
        # An instance may have no figure: __init__ may have failed before creating it,
        # and unpickled instances (written by _save) carry fig=None. pyplot.close(None)
        # closes the *current* figure, which may belong to another plot, so skip it.
        fig = getattr(self, 'fig', None)
        if fig is not None:
            plt.close(fig)

    def _new_window(self):
        '''Create a fresh figure/axes pair, switching pyplot to interactive mode if needed.'''
        if not matplotlib.is_interactive():
            plt.ion()
        self.fig, self.ax = plt.subplots()
        # Figure.set_tight_layout is discouraged since matplotlib 3.6 in favour of
        # set_layout_engine; fall back to it on versions without the newer API.
        if hasattr(self.fig, 'set_layout_engine'):
            self.fig.set_layout_engine('tight')
        else:
            self.fig.set_tight_layout(True)

    @staticmethod
    def _interval_or_none(y_low, y_high):
        '''
        Prepare interval bounds for fill_between.

        Returns None when no index has both bounds. Otherwise returns
        (low, high) as lists with every None replaced by NaN, so that
        incomplete points show up as gaps in the shading.
        '''
        if not any(lo is not None and hi is not None for lo, hi in zip(y_low, y_high)):
            return None
        nan = float('nan')
        return ([nan if v is None else v for v in y_low],
                [nan if v is None else v for v in y_high])

    def _fill_interval(self, x, y_low, y_high, color):
        '''Shade the interval estimate of one curve, if it has any complete bounds.'''
        bounds = self._interval_or_none(y_low, y_high)
        if bounds is None:
            return
        self.ax.fill_between(x, bounds[0], bounds[1], alpha=0.2, color=color)

    def _save(self, filename):
        '''
        Write the figure to 'filename', the whole object (without GUI handles)
        to 'filename.pkl', and a human-readable dump of all curves to
        'filename.csv'.
        '''
        self.fig.savefig(filename, transparent=False, bbox_inches='tight')

        with open(filename + '.pkl', 'wb') as fp:
            # detach the GUI handles so only format settings and data are pickled,
            # and restore them even if pickling fails
            fig, ax = self.fig, self.ax
            self.fig, self.ax = None, None
            try:
                pickle.dump(self, fp)
            finally:
                self.fig, self.ax = fig, ax

        with open(filename + '.csv', 'w', encoding='utf-8') as fp:
            # header: title, then x and y display ranges ('None' if unset)
            print(self.fig_title, file=fp)
            print('None' if self.xlim is None else '{},\t {}'.format(self.xlim[0], self.xlim[1]), file=fp)
            print('None' if self.ylim is None else '{},\t {}'.format(self.ylim[0], self.ylim[1]), file=fp)

            # one section per curve: label line, column-name line, then one row per point
            print(self.label, file=fp)
            print('{},\t {},\t (y_low),\t (y_high)'.format(self.x_label, self.y_label), file=fp)
            for row in zip(self.x, self.y, self.y_low, self.y_high):
                print('{},\t {},\t {},\t {}'.format(*row), file=fp)

            for line in self.baselines:
                print(line['label'], file=fp)
                print('{},\t {},\t (y_low),\t (y_high)'.format(self.x_label, self.y_label), file=fp)
                for row in zip(line['x'], line['y'], line['y_low'], line['y_high']):
                    print('{},\t {},\t {},\t {}'.format(*row), file=fp)

    # back-compatible interface
    def title(self, label):
        '''Set the figure title shown on the next output().'''
        self.fig_title = label

    def plot(self, x=None, y=None, y_range=None, block=False):
        '''
        Legacy entry point. Without x, redraw the current data on screen.
        With x, replace ALL data (including points from add_point) by a single
        baseline built from x, y and y_range, then show it.
        '''
        if x is None:
            self.show(block=block)
        else:
            self.clear_data()
            self.add_line(x, y, y_range=y_range)
            self.show(block=block)


if __name__ == "__main__":
    # Manual, interactive smoke test of LinePlot. Several steps use blocking
    # windows that must be closed by hand to continue. The script also writes
    # plot2.png, plot3.png and plot_empty.png, plus their .pkl/.csv companions,
    # to the current working directory.
    # NOTE(review): interval, RV, time and matplotlib/plt are expected to come
    # from `from brl.utils import *` - not visible here.
    print('matplotlib backend: ', matplotlib.get_backend())
    # The bare string literal below is disabled code (it is never executed).
    # Paste its content back in to check that the matplotlib backend itself can
    # draw interactively, in case the tests below fail.
    '''
    # try this code block in case the tests below fail to double check that the matplotlib backend is working
    #plt.axis([-50,50,0,10000])
    plt.ion()
    plt.draw()
    plt.pause(0.001)

    x = np.arange(-50, 51)
    for pow in range(1,5):   # plot x^1, x^2, ..., x^4
        time.sleep(5)
        print(pow)
        y = [Xi**pow for Xi in x]
        plt.plot(x, y)
        #plt.draw()
        plt.pause(0.001)
        #input("Press [enter] to continue.")
    '''

    # standard interface testing
    # plot1: grow a single curve point by point, redrawing (non-blocking) after each point;
    # bounds alternate between the y_low/y_high form (+-5) and the y_range form (+-3)
    plot1 = LinePlot('step', 'performance', xlim=[0, 100], ylim=(0, 100), title='add_point + show\n(you should see a growing blue curve)')
    for x in interval(0, 100, gap=10):
        if x % 20 == 0:
            plot1.add_point(x, x, x-5, x+5)
        else:
            #plot1.add_point(x, x, (x-3,x+3))  # should lead to error: a tuple passed positionally as y_low fails add_point's assertion
            plot1.add_point(x, x, y_range=(x-3, x+3))
        time.sleep(1)
        #for i in range(100000): print(i)
        plot1.show()
    
    # plot2: add five baselines y = i*x with band half-width 2*i, then a point-by-point curve y = -x
    plot2 = LinePlot('step', 'performance')
    plot2.title('add_line + show\n(you should first see 5 ascending "baseline" curves,\nthen see a "new" curve at the bottom)')
    for i in range(5):
        x = list(range(100))
        y = [i * xx for xx in x]
        y_range = [(yy - 2*i, yy + 2*i) for yy in y]
        plot2.add_line(x, y, y_range=y_range)
        plot2.show()
        time.sleep(2)
    for x in range(100):
        plot2.add_point(x, -x)
    plot2.show()

    # plot3: same as above, but every show() blocks until the window is closed
    plot3 = LinePlot('step', 'performance', title='a title to be replaced soon')
    plot3.title('add_point + show(block)\n(you should now see an empty plot,\nclose all windows to proceed)')
    # NOTE(review): if interval(0, 100, 20) yields multiples of 20, x % 20 == 0 always holds
    # and the y_range branch below never runs - confirm against interval's definition
    for x in interval(0, 100, 20):
        if x % 20 == 0:
            plot3.add_point(x, x, x-5, x+5)
        else:
            plot3.add_point(x, x, y_range=(x-5, x+5))
        plot3.show(block=True)
        plot3.title('add_point + show(block)\n(you should now see a blue curve with narrowing error bar,\nclose this window to proceed)')

    plot3.title('add_line + show(block)\n(you should now see 5 "baseline" curves inserted below\n the original "new" curve, close this window to proceed)')
    for i in range(5):
        x = list(range(100))
        y = [-i * xx for xx in x]
        plot3.add_line(x, y)
        plot3.show(block=True)

    # save figures plus .pkl/.csv data files next to each image
    plot2.output('plot2.png')
    plot3.output('plot3.png')

    # round-trip: restore plot2's data from its pickle into a fresh LinePlot
    plot_copy = LinePlot('x', 'y')
    plot_copy.load_data('plot2.png.pkl')
    plot_copy.show(block=True)

    plot_copy.clear_data()
    plot_copy.title('add_line + show\n(but the data was deleted so you should now see an empty plot,\nclose this window to proceed)')
    plot_copy.show(block=True)
    plot_copy.output('plot_empty.png')

    plot_copy.load_data('plot3.png.pkl')
    plot_copy.show(block=True)

    # dropping the last reference runs LinePlot.__del__ (immediately in CPython), which closes each figure
    del plot2
    del plot3
    del plot_copy

    # legacy interface testing
    # each iteration rebinds `plot`, releasing the previous LinePlot (and, via __del__, its figure);
    # the band half-width 10*i uses the outer loop's i (the comprehension's i has its own scope)
    for i in range(10):
        plot = LinePlot('step', 'performance')
        plot.title('test')
        x = [i for i in range(100)]
        y = x
        y_range = [(xx - 10 * i, xx + 10 * i) for xx in x]
        plot.plot(x, y, y_range)
        time.sleep(1)

    # plot() without arguments just redraws the points added so far
    plot = LinePlot('step', 'performance')
    plot.add_point(0, 10.2)
    plot.plot()
    time.sleep(1)
    plot.add_point(12, 20.4)
    plot.plot()
    time.sleep(1)
    plot.add_point(15, 31.2)
    plot.plot()
    time.sleep(1)
    plot.add_point(20, 15.6)
    plot.plot()
    time.sleep(1)

    # RV is defined outside this file; its plot() behavior is not visible here
    perf = RV(save_data=True)
    for i in range(100):
        perf.append(i)
    perf.plot(blocking=False)
    perf.plot()

    print('finished')