import os
import pickle
import argparse
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


def _load_results(results_dir):
    """Load the command-neuron index range and the recorded episodes.

    Args:
        results_dir: Directory containing ``dt/dt.pkl`` and ``results.pkl``.

    Returns:
        ``(cmd_neuron_indices, data)``: ``cmd_neuron_indices`` is the
        ``'cmd_neuron_indices'`` entry of ``dt.pkl``; ``data`` is the
        unpickled contents of ``results.pkl``.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # at results directories you produced yourself.
    with open(os.path.join(results_dir, 'dt', 'dt.pkl'), 'rb') as f:
        dt = pickle.load(f)
    with open(os.path.join(results_dir, 'results.pkl'), 'rb') as f:
        data = pickle.load(f)
    return dt['cmd_neuron_indices'], data


def _wrap_angle(angle):
    """Wrap *angle* (radians) into the interval [-pi, pi)."""
    return (angle + np.pi) % (2 * np.pi) - np.pi


def _collect_states(data, cmd_neuron_indices):
    """Flatten episodes into per-step state arrays and selected RNN activations.

    Each step of each episode is unpacked as
    ``(obs, action, rnn_state, next_obs, reward, done, info)``, with
    ``obs`` / ``next_obs`` unpacked as ``(cos theta, sin theta, theta_dot)``.

    Args:
        data: Iterable of episodes, each an iterable of step tuples.
        cmd_neuron_indices: ``[start, stop)`` slice bounds applied to the
            stacked, squeezed RNN state of every step.

    Returns:
        ``(rnn_states, states)``: ``rnn_states`` stacks the sliced RNN state of
        every step; ``states`` maps ``'theta'``, ``'theta_dot'``,
        ``'theta_ddot'`` (the action) and ``'dq'`` (``[dtheta, dtheta_dot]``)
        to arrays stacked over steps.
    """
    start, stop = cmd_neuron_indices[0], cmd_neuron_indices[1]
    rnn_states = []
    states = dict(theta=[], theta_dot=[], theta_ddot=[], dq=[])
    for ep_data in data:
        for obs, action, rnn_state, next_obs, _reward, _done, _info in ep_data:
            cos_theta, sin_theta, theta_dot = obs
            theta = np.arctan2(sin_theta, cos_theta)
            next_cos_theta, next_sin_theta, next_theta_dot = next_obs
            next_theta = np.arctan2(next_sin_theta, next_cos_theta)
            # Wrap so a crossing of +/-pi yields the short signed step
            # instead of a jump of roughly 2*pi.
            dtheta = _wrap_angle(next_theta - theta)

            states['theta'].append(theta)
            states['theta_dot'].append(theta_dot)
            states['theta_ddot'].append(action)
            states['dq'].append(np.array([dtheta, next_theta_dot - theta_dot]))

            rnn_states.append(np.stack(rnn_state).squeeze()[start:stop])
    return np.array(rnn_states), {k: np.array(v) for k, v in states.items()}


def _bin_index(values, edges):
    """Return the 0-based bin of each value for the given histogram edges.

    Matches ``np.histogram``: bins are half-open ``[e[k], e[k+1])`` except the
    last, which also contains the right edge, so the maximum sample is kept.
    """
    n_bins = len(edges) - 1
    return np.clip(np.digitize(values, edges) - 1, 0, n_bins - 1)


def _grid_mean(row_idx, col_idx, values, shape):
    """Average per-sample *values* into the cells of a 2-D grid.

    Args:
        row_idx: Integer row coordinate of each sample.
        col_idx: Integer column coordinate of each sample.
        values: Array of shape ``(n_samples, ...)``.
        shape: ``(n_rows, n_cols)`` of the grid.

    Returns:
        ``(mean, count)``: ``mean`` has shape ``shape + values.shape[1:]`` and
        is 0 in cells without samples; ``count`` is the number of samples per
        cell.
    """
    values = np.asarray(values, dtype=float)
    sums = np.zeros(tuple(shape) + values.shape[1:])
    counts = np.zeros(shape, dtype=int)
    # np.add.at accumulates correctly even when several samples share a cell.
    np.add.at(sums, (row_idx, col_idx), values)
    np.add.at(counts, (row_idx, col_idx), 1)
    occupied = counts > 0
    sums[occupied] /= counts[occupied].reshape((-1,) + (1,) * (values.ndim - 1))
    return sums, counts


def main():
    """Plot mean RNN command-neuron activation over a (theta, theta_dot) grid.

    Reads ``<results-dir>/dt/dt.pkl`` and ``<results-dir>/results.pkl``, bins
    the visited states into a ``hist_bins x hist_bins`` grid, and saves one
    panel per selected RNN unit: grid points colored by the unit's mean
    activation (gray where no sample fell), overlaid with arrows showing the
    mean state change ``dq`` in that cell.

    Raises:
        ValueError: if ``results.pkl`` contains no transitions.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--results-dir', type=str, required=True)
    parser.add_argument('--hist-bins', type=int, default=20)
    parser.add_argument('--title-size', type=float, default=20)
    parser.add_argument('--label-size', type=float, default=18)
    parser.add_argument('--tick-size', type=float, default=14)
    parser.add_argument('--marker-size', type=float, default=50)
    parser.add_argument('--output', type=str, default='./local/test.png',
                        help='path of the saved figure')
    args = parser.parse_args()

    cmd_neuron_indices, data = _load_results(args.results_dir)
    rnn_states_of_interest, states = _collect_states(data, cmd_neuron_indices)
    if rnn_states_of_interest.shape[0] == 0:
        raise ValueError(f'no transitions found in {args.results_dir}')

    # Discretize theta and theta_dot into hist_bins equal-width bins each.
    n_bins = args.hist_bins
    edges = {k: np.histogram_bin_edges(states[k], bins=n_bins)
             for k in ('theta', 'theta_dot')}
    centers = {k: (e[1:] + e[:-1]) / 2. for k, e in edges.items()}
    theta_idx = _bin_index(states['theta'], edges['theta'])
    theta_dot_idx = _bin_index(states['theta_dot'], edges['theta_dot'])
    fl_theta, fl_theta_dot = flatten_comb(centers['theta'], centers['theta_dot'])

    grid_shape = (n_bins, n_bins)
    neuron_act_grid, counts = _grid_mean(
        theta_idx, theta_dot_idx, rnn_states_of_interest, grid_shape)
    dq_grid, _ = _grid_mean(theta_idx, theta_dot_idx, states['dq'], grid_shape)
    # Row-major flattening matches flatten_comb's theta-major ordering.
    fl_dq_grid = dq_grid.reshape(-1, 2)
    fl_occupied = counts.ravel() > 0

    n_state_dim = rnn_states_of_interest.shape[1]
    # squeeze=False keeps a 2-D array of Axes even when n_state_dim == 1.
    fig, axes = plt.subplots(1, n_state_dim, figsize=(6.4 * n_state_dim, 4.8),
                             squeeze=False)
    colormap = plt.cm.viridis
    empty_color = np.array([211 / 255.] * 3 + [0.5])  # translucent gray: unvisited cell
    for i, ax in enumerate(axes[0]):
        ax.set_title(f'Neuron {i:02d}', fontsize=args.title_size)
        ax.set_xlabel(r'$\theta$', fontsize=args.label_size)
        ax.set_ylabel(r'$d\theta/dt$', fontsize=args.label_size)
        ax.tick_params(axis='both', which='both', labelsize=args.tick_size)

        fl_act_i = neuron_act_grid[:, :, i].ravel()
        # Color limits come from visited cells only, so empty cells don't skew them.
        valid_act_i = fl_act_i[fl_occupied]
        norm_i = matplotlib.colors.Normalize(vmin=valid_act_i.min(), vmax=valid_act_i.max())
        color = colormap(norm_i(fl_act_i))
        color[~fl_occupied] = empty_color
        ax.scatter(fl_theta, fl_theta_dot, c=color, s=args.marker_size)
        ax.quiver(fl_theta, fl_theta_dot, fl_dq_grid[:, 0], fl_dq_grid[:, 1],
                  color='k', angles='xy')

        mappable = matplotlib.cm.ScalarMappable(norm=norm_i, cmap=colormap)
        cax = make_axes_locatable(ax).append_axes('bottom', size='5%', pad=0.7)
        cbar = fig.colorbar(mappable, cax=cax, orientation='horizontal')
        cbar.ax.tick_params(labelsize=args.tick_size)
    fig.canvas.draw()
    fig.tight_layout()
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(args.output)
    plt.close(fig)


def flatten_comb(arr1, arr2):
    """Return the flattened Cartesian product of two 1-D arrays.

    Element ``k`` of the outputs is ``(arr1[k // len(arr2)], arr2[k % len(arr2)])``.
    ``arr1`` varies slowest, matching the row-major flattening of a
    ``(len(arr1), len(arr2))`` grid.

    Args:
        arr1: 1-D array for the slow (row) axis.
        arr2: 1-D array for the fast (column) axis.

    Returns:
        ``(out1, out2)``, each of length ``len(arr1) * len(arr2)``; both are
        empty when either input is empty.
    """
    # indexing='ij' gives grid1[i, j] = arr1[i] and grid2[i, j] = arr2[j].
    grid1, grid2 = np.meshgrid(arr1, arr2, indexing='ij')
    return grid1.ravel(), grid2.ravel()


if __name__ == '__main__':
    # Run the plotting pipeline only when executed as a script, not on import.
    main()
