import os
import pickle
import argparse
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


# Module-level switch read by main(). When True, main() negates the 'dx'
# samples and their step-to-step deltas, applies additional sign flips to the
# plotted grid/arrows, and labels the axes d / mu instead of
# delta-x / delta-theta.
use_d_mu = True


def main():
    """Plot binned neuron activations with a state-transition quiver overlay.

    Reads two pickles from ``--results-dir``:

    * ``dt/dt.pkl`` -- a mapping whose ``'cmd_neuron_indices'`` entry is used
      as a ``(start, stop)`` slice into the stacked RNN state.
    * ``results.pkl`` -- an iterable of episodes; each episode is a sequence
      of step dicts providing ``'ego_agent_id'``, ``'logs'`` and
      ``'rnn_state'``.

    For every step that has a successor, the ``dx`` / ``dyaw`` values and
    their change to the next step are collected. The samples are binned on a
    ``--hist-bins`` x ``--hist-bins`` grid; for every non-empty cell the mean
    activation of each selected neuron and the mean state change are
    computed. One subplot per neuron is drawn (scatter coloured by mean
    activation, arrows showing the mean state change) and the figure is
    written to ``./local/test.png``.
    """
    # parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--results-dir', type=str, required=True)
    parser.add_argument('--hist-bins', type=int, default=30)
    parser.add_argument('--title-size', type=float, default=26)
    parser.add_argument('--label-size', type=float, default=26)
    parser.add_argument('--tick-size', type=float, default=16)
    parser.add_argument('--marker-size', type=float, default=70)
    parser.add_argument('--arrow-width', type=float, default=0.0045)
    args = parser.parse_args()

    # read data
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # point --results-dir at results that come from a trusted source.
    dt_path = os.path.join(args.results_dir, 'dt/dt.pkl')
    with open(dt_path, 'rb') as f:
        dt = pickle.load(f)
    cmd_neuron_indices = dt['cmd_neuron_indices']

    results_path = os.path.join(args.results_dir, 'results.pkl')
    with open(results_path, 'rb') as f:
        data = pickle.load(f)

    # collect relevant information
    state_names = ['dx', 'dyaw']
    assert state_names == ['dx', 'dyaw'], 'Double check everything before changing to other state names'
    states = {state_names[0]: [], state_names[1]: [], 'dq': []}
    rnn_states_of_interest = []
    for ep_data in data:
        for step_i, step_data in enumerate(ep_data[:-1]): # make sure we can get next state
            ego_agent_id = step_data['ego_agent_id']
            step_logs = step_data['logs'][ego_agent_id]

            next_step_data = ep_data[step_i + 1]
            next_step_logs = next_step_data['logs'][ego_agent_id]

            # dq_i holds the one-step change of every state, in state_names order
            dq_i = []
            for state_name in state_names:
                states[state_name].append(step_logs[state_name])
                dq_i.append(next_step_logs[state_name] - step_logs[state_name])
                if use_d_mu and state_name == 'dx':
                    # in the (d, mu) convention the 'dx' sample and its delta are negated
                    states[state_name][-1] *= -1
                    dq_i[-1] *= -1
            states['dq'].append(dq_i)

            # keep only the slice of the RNN state selected by cmd_neuron_indices
            rnn_state = step_data['rnn_state']
            rnn_state_of_interest = np.stack(rnn_state).squeeze()
            rnn_state_of_interest = rnn_state_of_interest[cmd_neuron_indices[0]:cmd_neuron_indices[1]]
            rnn_states_of_interest.append(rnn_state_of_interest)
    rnn_states_of_interest = np.array(rnn_states_of_interest)
    for k, v in states.items():
        states[k] = np.array(v)

    # convert state samples to discretize state space (grid)
    bins = dict()
    discrete_state_space_idcs = dict()
    for k in state_names:
        bins[k] = np.histogram_bin_edges(states[k], bins=args.hist_bins)
        discrete_state_space_idcs[k] = np.digitize(states[k], bins=bins[k])

    # grid coordinates are the bin centers
    state_grid = dict()
    for k, v in bins.items():
        state_grid[k] = (v[1:] + v[:-1]) / 2.
    if use_d_mu: # DEBUG
        # NOTE(review): this flips the 'dx' bin centers back after the samples
        # were negated above; kept as in the original -- confirm the intent.
        state_grid['dx'] *= -1
    fl_state_grid = dict()
    fl_state_grid[state_names[0]], fl_state_grid[state_names[1]] = \
        flatten_comb(state_grid[state_names[0]], state_grid[state_names[1]])

    # get bin allocation for the entire state space grid
    # Each entry is the set of sample indices falling in one bin. The bins are
    # kept in a plain list: they generally hold different numbers of samples,
    # and building an ndarray from such ragged data raises on NumPy >= 1.24.
    # Sets are built once here instead of once per grid cell below.
    bin_idcs = dict()
    for k, v in discrete_state_space_idcs.items():
        v_idcs = np.arange(v.shape[0])
        bin_idcs[k] = []
        for i in range(1, v.max()): # NOTE: ignore rightmost bound; loop through all bins
            bin_idcs[k].append(set(v_idcs[v == i]))

    # get bins of dq and neuron activation
    n_state_dim = rnn_states_of_interest.shape[1]
    neuron_act_grid = np.zeros((args.hist_bins, args.hist_bins, n_state_dim))
    dq_grid = np.zeros((args.hist_bins, args.hist_bins, 2))
    for i, bin_idcs_0_i in enumerate(bin_idcs[state_names[0]]):
        for j, bin_idcs_1_j in enumerate(bin_idcs[state_names[1]]):
            # samples that fall in bin i of the first state AND bin j of the second
            bin_idcs_ij = bin_idcs_0_i.intersection(bin_idcs_1_j)
            if bin_idcs_ij:
                bin_idcs_ij = np.array(list(bin_idcs_ij))
                neuron_act_grid[i,j] = rnn_states_of_interest[bin_idcs_ij].mean(0)
                dq_grid[i,j] = states['dq'][bin_idcs_ij].mean(0)
    fl_dq_grid = dq_grid.reshape(-1, 2)
    if use_d_mu:
        # NOTE(review): this negates column 1 (the second state's delta),
        # whereas the sample negation above is applied to 'dx' (column 0);
        # kept as in the original -- confirm the intent.
        fl_dq_grid[:,1] *= -1

    # plot
    # squeeze=False keeps `axes` two-dimensional so indexing also works when
    # there is only a single neuron (a single subplot).
    fig, axes = plt.subplots(1, n_state_dim, figsize=(6.4*n_state_dim, 4.8*1.5), squeeze=False)
    colormap = plt.cm.viridis
    for i in range(n_state_dim):
        ax = axes[0, i]
        ax.set_title(f'Neuron {i:02d}', fontsize=args.title_size)
        if use_d_mu:
            ax.set_xlabel(r'$d$', fontsize=args.label_size)
            ax.set_ylabel(r'$\mu$', fontsize=args.label_size)
        else:
            ax.set_xlabel(r'$\delta x$', fontsize=args.label_size)
            ax.set_ylabel(r'$\delta \theta$', fontsize=args.label_size)
        ax.tick_params(axis='both', which='major', labelsize=args.tick_size)
        ax.tick_params(axis='both', which='minor', labelsize=args.tick_size)

        # NOTE(review): the original computed a row-reversed variant
        # (neuron_act_grid[::-1,:,i]) for the use_d_mu case and then
        # unconditionally overwrote it with the line below. The dead branch
        # was dropped and the effective behaviour kept -- confirm whether the
        # reversal was meant to be active.
        fl_neuron_act_grid_i = neuron_act_grid[:,:,i].flatten()

        # cells that are exactly 0.0 are treated as "no data" for colouring
        valid_fl_neuron_act_grid_i = fl_neuron_act_grid_i[fl_neuron_act_grid_i != 0.0]
        if valid_fl_neuron_act_grid_i.size > 0:
            neuron_act_norm_i = matplotlib.colors.Normalize(vmin=valid_fl_neuron_act_grid_i.min(), vmax=valid_fl_neuron_act_grid_i.max())
        else:
            # no valid cell for this neuron: min()/max() would raise on an
            # empty array, so fall back to a fixed range (all points are gray)
            neuron_act_norm_i = matplotlib.colors.Normalize(vmin=0.0, vmax=1.0)

        color = colormap(neuron_act_norm_i(fl_neuron_act_grid_i))
        color[fl_neuron_act_grid_i == 0.,:] = np.array([211/255.]*3 + [0.5]) # make invalid point gray
        ax.scatter(fl_state_grid[state_names[0]], fl_state_grid[state_names[1]], c=color, s=args.marker_size)
        ax.quiver(fl_state_grid[state_names[0]], fl_state_grid[state_names[1]], fl_dq_grid[:,0], fl_dq_grid[:,1], color='k', angles='xy', width=args.arrow_width)
        # ax.set_aspect('equal')

        # horizontal colorbar placed below the subplot
        mappable = matplotlib.cm.ScalarMappable(norm=neuron_act_norm_i, cmap=colormap)
        divider = make_axes_locatable(ax)
        cax = divider.new_vertical(size="5%", pad=0.7, pack_start=True)
        fig.add_axes(cax)
        cbar = fig.colorbar(mappable, cax=cax, orientation="horizontal")
        cbar.ax.tick_params(labelsize=args.tick_size)
    fig.canvas.draw()
    fig.tight_layout()

    # make sure the output directory exists; savefig does not create it
    out_path = './local/test.png'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path)


def flatten_comb(arr1, arr2):
    """Return all pairings of the entries of two arrays as two flat arrays.

    The first output repeats every entry of ``arr1`` ``len(arr2)`` times in a
    row; the second output is ``arr2`` concatenated ``len(arr1)`` times. For
    1-D inputs, position ``i * len(arr2) + j`` of the outputs therefore holds
    the pair ``(arr1[i], arr2[j])``.

    Args:
        arr1: Array-like with at least one dimension (intended to be 1-D).
        arr2: Array-like with at least one dimension (intended to be 1-D).

    Returns:
        Tuple ``(out1, out2)`` of NumPy arrays as described above. Both are
        empty when either input is empty.
    """
    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)
    n1, n2 = arr1.shape[0], arr2.shape[0]

    out1 = np.repeat(arr1, n2)
    if n1 == 0:
        # np.concatenate refuses an empty list of arrays, so build the empty
        # result directly (keeps arr2's dtype and trailing dimensions)
        out2 = arr2[:0]
    else:
        out2 = np.concatenate([arr2] * n1, axis=0)
    return out1, out2


# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
