import os
import pickle
import argparse
import tqdm
import numpy as np
import skvideo.io as io
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


def main():
    """Render per-episode videos that pair recorded frames with state "activation" circles.

    Inputs (under ``--results-dir``):
      * ``monitor/*.mp4`` -- one video per episode, matched to episodes by
        sorted filename order.
      * ``results.pkl`` -- per-episode lists of per-step records; element 2 of
        each record is the state.

    Output: ``<results-dir>/out<--out-postfix>/ep_XXX.mp4``, one per episode.
    Each output frame shows the video frame on the left. On the right is one
    circle per selected state dimension, coloured (viridis) by that dimension's
    value. Dimensions are normalised over all episodes, or per episode with
    ``--normalize-within-episode``.

    Ctrl-C stops after finalising the video currently being written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--results-dir', type=str, required=True)
    parser.add_argument('--out-postfix', type=str, default='')
    parser.add_argument('--state-dims', nargs='+', type=int, default=None)
    parser.add_argument('--state-dims-is-range', action='store_true', default=False)
    parser.add_argument('--normalize-within-episode', action='store_true', default=False)
    args = parser.parse_args()

    video_root_dir = os.path.join(args.results_dir, 'monitor')
    video_paths = [
        os.path.join(video_root_dir, name)
        for name in sorted(os.listdir(video_root_dir))
        if os.path.splitext(name)[-1] == '.mp4'
    ]

    # NOTE: pickle can execute arbitrary code on load; only point this at
    # results files you produced yourself.
    results_path = os.path.join(args.results_dir, 'results.pkl')
    with open(results_path, 'rb') as f:
        data = pickle.load(f)

    out_dir = os.path.join(args.results_dir, 'out' + args.out_postfix)
    os.makedirs(out_dir, exist_ok=True)

    # The original note says the layout is n_episodes x n_steps x n_states x
    # state_dim. Assumes every episode has the same number of steps, otherwise
    # np.array builds a ragged object array -- TODO confirm.
    state_data = np.array([[vv[2] for vv in v] for v in data])
    n_state_dims = state_data.shape[-1]
    if not args.normalize_within_episode:
        # Global per-dimension [min, max], shape (n_state_dims, 2).
        fl_state_data = state_data.reshape(-1, n_state_dims)
        state_bound = np.stack([fl_state_data.min(0), fl_state_data.max(0)], axis=-1)

    # The state dimensions to visualise do not depend on the episode, so they
    # are resolved once here.
    if args.state_dims is not None:
        state_dims = args.state_dims
        if args.state_dims_is_range:
            state_dims = range(*state_dims)
    else:
        # Default to every dimension. The count comes from the data, not from
        # state_bound, so this also works with --normalize-within-episode
        # (where state_bound is only computed per episode).
        state_dims = list(range(n_state_dims))

    # Each episode gets a fresh figure, so its colorbars never accumulate
    # across episodes.
    add_colorbar = True
    for ep_i, ep_data in tqdm.tqdm(enumerate(data), desc='Episode', total=len(data)):
        n_steps = len(ep_data)

        # Set up the video reader and writer.
        video_reader = io.FFmpegReader(video_paths[ep_i], inputdict={}, outputdict={})
        # Use a single frame generator. Its first frame initialises the image
        # and the step loop continues from the next frame. A second
        # nextFrame() generator would try to read one frame past the end.
        frames = video_reader.nextFrame()

        out_path = os.path.join(out_dir, f'ep_{ep_i:03d}.mp4')
        rate = '2'  # '10'
        video_writer = io.FFmpegWriter(out_path,
                                       inputdict={'-r': rate},
                                       outputdict={'-vcodec': 'libx264',
                                                   '-pix_fmt': 'yuv420p',
                                                   '-r': rate,})

        fig, axes = plt.subplots(1, 2)
        try:
            # Left panel: the recorded video frame.
            ax = axes[0]
            ax.set_xticks([])
            ax.set_yticks([])
            im = ax.imshow(next(frames))

            # Per-episode min/max summary of the selected dimensions.
            fl_ep_state_data = state_data[ep_i].reshape(-1, n_state_dims)
            if args.normalize_within_episode:
                state_bound = np.stack([fl_ep_state_data.min(0), fl_ep_state_data.max(0)], axis=-1)
            min_max_text = ''
            for state_dim in state_dims:
                column = fl_ep_state_data[:, state_dim]
                ep_min_idx = column.argmin(0)
                ep_max_idx = column.argmax(0)
                ep_min = column[ep_min_idx]
                ep_max = column[ep_max_idx]
                min_max_text += f'[State {state_dim}] min-id: {ep_min_idx} max-id: {ep_max_idx}\n'
                min_max_text += f'    min: {ep_min:.3f} max: {ep_max:.3f}\n'
            print(min_max_text)
            axes[1].set_title(min_max_text)

            # Right panel: one circle per selected dimension, stacked vertically,
            # each with its own normalisation.
            circles = {'artist': [], 'norm': [], 'colormap': plt.cm.viridis}
            ax = axes[1]
            radius = 7
            diameter = radius * 2
            for state_dim_i, state_dim in enumerate(state_dims):
                norm = matplotlib.colors.Normalize(vmin=state_bound[state_dim, 0], vmax=state_bound[state_dim, 1])
                color = circles['colormap'](norm(0.))
                artist = matplotlib.patches.Circle((0, radius + diameter * state_dim_i), radius, color=color)
                ax.add_patch(artist)
                circles['artist'].append(artist)
                circles['norm'].append(norm)
            ax.set_xlim(-diameter, diameter)
            dy = 0.1
            ax.set_ylim(diameter * (-dy), diameter * (len(state_dims) + dy))
            ax.set_aspect('equal')
            ax.axis('off')

            # Colorbars: one per dimension. They are all appended to the right
            # of the same axes, so several of them may overlap.
            if add_colorbar:
                if len(state_dims) > 1:
                    print('May have overlapped colorbar')
                ax = axes[1]
                for state_dim_i, dim_norm in enumerate(circles['norm']):
                    divider = make_axes_locatable(ax)
                    cax = divider.append_axes('right', size='5%', pad=0.05 * (state_dim_i + 1))
                    mappable = matplotlib.cm.ScalarMappable(norm=dim_norm, cmap=circles['colormap'])
                    fig.colorbar(mappable, cax=cax, orientation='vertical')

            # Run the episode. zip() stops at whichever runs out first, the
            # remaining frames or the recorded steps, so extra video frames no
            # longer cause an IndexError on state_data.
            for step_i, frame in tqdm.tqdm(zip(range(n_steps), frames), desc='Step', total=n_steps):
                axes[0].set_title(f'Frame {step_i:04d}')
                # NOTE(review): after flatten(), `state_dim` indexes the first
                # state's entries when n_states > 1, while state_bound is per
                # last-axis dim -- confirm this is intended.
                state = state_data[ep_i, step_i].flatten()

                for state_dim_i, state_dim in enumerate(state_dims):
                    # Colour each circle with its own dimension's normalisation.
                    dim_norm = circles['norm'][state_dim_i]
                    color = circles['colormap'](dim_norm(state[state_dim]))
                    circles['artist'][state_dim_i].set_color(color)

                im.set_data(frame)
                video_writer.writeFrame(fig2arr(fig))
        except KeyboardInterrupt:
            # Stop processing further episodes. The finally clause still
            # finalises the partially written video.
            break
        finally:
            video_writer.close()
            video_reader.close()
            plt.close(fig)


def fig2arr(fig):
    """Render a Matplotlib figure and return its pixels as an RGB image.

    Args:
        fig: Figure whose canvas provides ``draw()`` and ``buffer_rgba()``
            (e.g. the Agg canvas).

    Returns:
        ``np.uint8`` array of shape ``(height, width, 3)``. It is an
        independent copy, so later redraws of the figure do not modify it.
    """
    fig.canvas.draw()
    # ``tostring_rgb`` is deprecated since Matplotlib 3.8 and removed in 3.10.
    # ``buffer_rgba`` replaces it and carries the rendered buffer's own
    # (height, width, 4) shape, instead of relying on get_width_height().
    buf = fig.canvas.buffer_rgba()
    rgba = np.asarray(buf)
    if rgba.ndim != 3:
        # Very old Matplotlib returned a flat byte buffer here.
        width, height = fig.canvas.get_width_height()
        rgba = np.frombuffer(buf, dtype=np.uint8).reshape(height, width, 4)
    # Drop the alpha channel. np.array() copies, detaching the result from the
    # canvas buffer.
    return np.array(rgba[..., :3], dtype=np.uint8)


if __name__ == '__main__':
    # Script entry point: parse command-line arguments and render the videos.
    main()
