import os
import matplotlib.pyplot as plt
import importlib
import numpy as np
if importlib.util.find_spec('matplotlib'):
    import matplotlib
    import matplotlib.pyplot as plt
    from matplotlib.font_manager import FontProperties


def plot(shape, pos, action, L, reward, lcmap, save=None):
    """Draw one grid-world frame: reward heat map, labelled cells and an action arrow.

    Args:
        shape: ``(n_rows, n_cols)`` of the grid.
        pos: ``(row, col)`` of the agent; the arrow is drawn at x=col, y=row.
        action: one of ``'U'``, ``'D'``, ``'R'``, ``'L'``; any other value draws
            no arrow.
        L: mapping ``(row, col) -> tuple`` of labels for every cell; the first
            element of a non-empty tuple is written in that cell.
        reward: 2-D array drawn with ``imshow`` on a diverging colormap whose
            limits are symmetric around zero.
        lcmap: mapping ``label tuple -> colour``; cells whose label tuple is a
            key get a filled circle in that colour.
        save: output image path. When None, the figure is not written to disk.

    Returns:
        The created matplotlib Figure (the caller is responsible for closing it).

    Note:
        Turns on LaTeX text rendering globally via ``plt.rc('text', usetex=True)``,
        so a working LaTeX installation is needed to draw/save the figure.
    """
    size = 5
    colormap = plt.cm.RdBu
    color = 'black'

    bold = FontProperties(weight='bold')
    fontname = 'Times New Roman'
    fontsize = 20

    n_rows, n_cols = shape
    fig = plt.figure(figsize=(size, size))
    plt.rc('text', usetex=True)

    value = np.copy(reward)
    # Colour limits are +/- twice the largest |reward|, so cell colours only use
    # the inner half of the colormap (presumably to keep labels legible).
    threshold = np.nanmax(np.abs(value)) * 2
    if threshold == 0 or not np.isfinite(threshold):
        # All-zero (or all-NaN) rewards: use a unit range so vmin < vmax.
        threshold = 1

    plt.imshow(value, interpolation='nearest', cmap=colormap, vmax=threshold, vmin=-threshold)

    ax = fig.axes[0]

    # Major ticks label the cell centres ...
    ax.set_xticks(np.arange(0, n_cols, 1))
    ax.set_yticks(np.arange(0, n_rows, 1))

    ax.set_xticklabels(np.arange(n_cols), fontsize=fontsize)
    ax.set_yticklabels(np.arange(n_rows), fontsize=fontsize)

    # ... minor ticks sit on the cell borders and carry the grid lines.
    ax.set_xticks(np.arange(-.5, n_cols, 1), minor=True)
    ax.set_yticks(np.arange(-.5, n_rows, 1), minor=True)

    ax.xaxis.tick_top()

    ax.grid(which='minor', color='lightgray', linestyle='-', linewidth=1, alpha=0.5)

    for side in ('right', 'top', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

    # Booleans, not the legacy 'off' strings: a non-empty string is truthy and
    # would leave the ticks visible.
    ax.tick_params(bottom=False, left=False)

    # Action arrow: (x, y, dx, dy) per action, in data coordinates.
    row, col = pos[0], pos[1]
    arrows = {
        'U': (col, row, 0, -0.2),
        'D': (col, row - .3, 0, 0.2),
        'R': (col - .15, row - 0.15, 0.2, 0),
        'L': (col + .15, row - 0.15, -0.2, 0),
    }
    if action in arrows:
        ax.arrow(*arrows[action], head_width=.2, head_length=.15, color=color)

    # Cell labels: coloured circle (if the label has a colour) plus its text.
    for i in range(n_rows):
        for j in range(n_cols):
            labels = L[(i, j)]
            if labels in lcmap:
                ax.add_artist(plt.Circle((j, i + 0.24), 0.2, color=lcmap[labels]))
            if labels:
                ax.text(j, i + 0.4, labels[0], horizontalalignment='center', color=color,
                        fontproperties=bold, fontname=fontname, fontsize=fontsize + 5)

    if save is not None:
        fig.savefig(save, bbox_inches='tight')
    return fig


def multi_plot(shape, episode, L, reward, lcmap, animation,
               ffmpeg='/Users/dongmingshen/ffmpeg', fps=2, pad=5):
    """Render one PNG per episode step with ``plot`` and encode them into an MP4.

    Frames are written as ``<animation>/<t zero-padded to pad digits>.png``.
    ffmpeg is then run as ``ffmpeg -r <fps> -i <animation>/%0<pad>d.png <animation>.mp4``.

    Args:
        shape: ``(n_rows, n_cols)`` grid shape, forwarded to ``plot``.
        episode: iterable of steps; ``step[0]`` is the position and ``step[1]``
            the action passed to ``plot``.
        L: cell -> label-tuple mapping, forwarded to ``plot``.
        reward: reward array, forwarded to ``plot``.
        lcmap: label-tuple -> colour mapping, forwarded to ``plot``.
        animation: directory receiving the frames; also the stem of the video file.
        ffmpeg: ffmpeg executable. The default is the path this script was
            written against; pass e.g. ``'ffmpeg'`` to resolve it from PATH.
        fps: value for ffmpeg's input frame-rate option ``-r``.
        pad: zero-padding width of the frame file names.

    Returns:
        ffmpeg's exit status, or None if the executable could not be launched.
    """
    import subprocess

    os.makedirs(animation, exist_ok=True)

    for t, step in enumerate(episode):
        frame_path = os.path.join(animation, str(t).zfill(pad) + '.png')
        try:
            plot(shape, step[0], step[1], L, reward, lcmap, save=frame_path)
        finally:
            # Always release the frame's figure, even if plotting failed.
            plt.close()

    # Argument list (no shell): paths with spaces/metacharacters are passed verbatim.
    pattern = os.path.join(animation, '%0{}d.png'.format(pad))
    cmd = [ffmpeg, '-r', str(fps), '-i', pattern, animation + '.mp4']
    try:
        return subprocess.run(cmd).returncode
    except OSError as err:
        # Encoding is best-effort (as with the original os.system call): report, don't crash.
        print('could not run ffmpeg ({}): {}'.format(ffmpeg, err))
        return None


def main(file_name="result/video_txt/timed6newServer.txt", k=1000, is_multi=True,
         animation='case6_22', n_steps=349):
    """Replay one logged simulation run as an MP4 animation.

    Finds the first line of ``file_name`` starting with
    ``"==========[Running Simulation at k=<k>,"``. It then parses up to
    ``n_steps`` following lines into ``(state, action)`` pairs and hands them to
    ``multi_plot``, which writes frames to ``<animation>/`` and ``<animation>.mp4``.

    Args:
        file_name: simulation log to read.
        k: value searched for in the marker line.
        is_multi: if True, each state is parsed from columns 9:15 of a step line.
            The cell parsed from columns 17:23 of the first step line also gets
            an extra ``('b',)`` label (presumably the run's second 'b' location;
            TODO confirm). If False, states come from columns 9:13.
        animation: output directory / video stem passed to ``multi_plot``.
        n_steps: number of step lines to read after the marker.

    Returns:
        None. Prints a message and returns early if the marker is missing or no
        step lines follow it.
    """
    import ast

    shape = (4, 4)
    L = {(i, j): () for i in range(shape[0]) for j in range(shape[1])}
    L[(2, 2)] = ('b',)
    L[(2, 3)] = ('a',)
    L[(3, 0)] = ('c',)
    L[(0, 3)] = ('d',)

    lcmap = {
        ('a',): 'yellow',
        ('b',): 'greenyellow',
        ('c',): 'turquoise',
        ('d',): 'pink'
    }

    reward = np.zeros(shape)
    reward[3][0] = 5
    reward[0][3] = 2

    line_mark = "==========[Running Simulation at k={},".format(k)
    with open(file_name, 'r') as fh:
        lines = [line.rstrip() for line in fh]

    start = next((idx for idx, line in enumerate(lines) if line.startswith(line_mark)), -1)
    if start == -1:
        print("wrong k, not in range")
        return
    print(file_name)
    print(lines[start])

    step_lines = lines[start + 1:start + 1 + n_steps]
    if len(step_lines) < n_steps:
        print("warning: expected {} step lines after the marker, found {}".format(
            n_steps, len(step_lines)))
    if not step_lines:
        return

    # literal_eval instead of eval: the log is data and must never be executed as code.
    if is_multi:
        L[ast.literal_eval(step_lines[0][17:23])] = ('b',)
    state_cols = slice(9, 15) if is_multi else slice(9, 13)
    episode = [(ast.literal_eval(line[state_cols]), line[-1]) for line in step_lines]

    multi_plot(shape, episode, L, reward, lcmap, animation)


# Script entry point: build the animation only when run directly, not on import.
if __name__ == "__main__":
    main()
