

import os
import numpy as np
import pickle
import matplotlib.pyplot as plt
from plot_utils.generate_colormap import generate_colormap
from plot_utils.plot_simulated_data_gridworld import plot_gridworld_trajectories
import glob
import string



# Font sizes shared by every figure below.
LEGEND_SIZE = 10   # legend-sized / tick-label text
SMALL_SIZE = 15    # axis titles and axis labels
BIGGER_SIZE = 20   # bold panel letters (A, B, C, ...)

plt.rc('font', family='Helvetica')          # default font family: Helvetica
plt.rc('font', size=LEGEND_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=SMALL_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=LEGEND_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=LEGEND_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
# All text is rendered through LaTeX, so titles may contain math such as $\alpha$.
plt.rcParams.update({"text.usetex": True})
# Alternative palettes kept for reference:
# colors = ['steelblue', '#D85427', 'tab:green', 'k']
# colors = ['#fdae61','#ffffbf','#abdda4','#2b83ba']
colors = ['#e41a1c','#377eb8','#4daf4a','#984ea3']


# Directory-name tags of the simulated datasets (selected by `tag_sim` below).
tags_sim = ['centralized', 'independent_control', 'independent_control_w_uniform_prediction',
            'independent_control_w_policy_prediction', 'selfish_perfect_prediction', 'selfish_selfish']
# Directory-name tags of the fitted models (selected by `tag_train` below).
# Index 3 is an empty placeholder that no figure in this script selects;
# presumably it is kept so the remaining indices stay fixed - confirm before removing.
tags_train = ['centralized', 'independent_control', 'independent_control_w_uniform_prediction', '', 'selfish']

plot_dir = 'figures/'
# exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(plot_dir, exist_ok=True)

# Which figure section below to run (1-5).
figure_idx = 4
print('Generating figure {}'.format(figure_idx))

# ------------------------------
# Collab Foraging Task : Trajectories, value map, diff map from centralized control
# ------------------------------
# Panels: (A) generative reward map, (B, C) one example trajectory split per agent,
# (D, E) recovered individual maps m(s_1) / n(s_2), (F) recovered interaction map
# phi(s_1, s_2) painted by squared distance from the grid centre.
if figure_idx == 1:
    grid_W, grid_H = 5, 5
    tag_sim, tag_train = 0, 0
    num_trajs, seed, max_iters = 500, 1, 20
    lr_maps, lr_weights = 0.01, 0.005
    lam1, lam2 = 5.0, 1.0
    n_maps_individual, n_maps_interaction = 1, 1
    version = 1
    GEN_DIR_NAME = 'data/simulated_gridworld_data/{}{}/'.format(tags_sim[tag_sim], version)
    REC_DIR_NAME = 'recovered_parameters/gridworld_recovered_params/fit_{}/'.format(tags_train[tag_train])
    # Fits are stored as fit_<train tag>/<sim tag><version>/... (same layout as
    # figures 2 and 3), so the sub-folder is named after the *simulated* dataset.
    rec_dir_name = REC_DIR_NAME + '{}{}/maps_{}_{}_lr_{}_{}_lam_{:.1f}_{:.1f}/'.format(
        tags_sim[tag_sim], version, n_maps_individual, n_maps_interaction,
        lr_maps, lr_weights, lam1, lam2)

    # Use context managers so the pickle file handles are closed deterministically.
    with open(GEN_DIR_NAME + 'generative_parameters.pickle', 'rb') as f:
        gen_rewards = pickle.load(f)['generative_rewards']
    with open(GEN_DIR_NAME + 'expert_trajectories.pickle', 'rb') as f:
        trajs = pickle.load(f)

    # All recovered-parameter arrays share this file-name suffix; [-1] keeps the last
    # entry along the first axis (presumably the final training iteration).
    file_suffix = '_trajs_{}_seed_{}_iters_{}.npy'.format(num_trajs, seed, max_iters)
    rec_weights = np.load(rec_dir_name + 'weights' + file_suffix)[-1]
    rec_ind1_maps = np.load(rec_dir_name + 'ind1_maps' + file_suffix)[-1]
    rec_ind2_maps = np.load(rec_dir_name + 'ind2_maps' + file_suffix)[-1]
    rec_inter_maps = np.load(rec_dir_name + 'inter_maps' + file_suffix)[-1]

    fig, axs = plt.subplots(2, 3, figsize=(7.5, 5))

    # Generative reward map. Take the largest *absolute* value so the colour range
    # is symmetric around 0 even if some rewards are negative (as in other panels).
    c_max = np.max(np.abs(gen_rewards[:, 0]))
    im = axs[0, 0].imshow(np.reshape(gen_rewards[:, 0], (grid_H, grid_W), order='F'), vmin=-c_max, vmax=c_max)
    plt.colorbar(im, ax=axs[0, 0], fraction=0.046, pad=0.04, ticks=[-c_max, 0, c_max])
    axs[0, 0].set_title('actual rewards')
    axs[0, 0].set_axis_off()

    # Split the joint 4-component state: the first two components go to agent 1,
    # the last two to agent 2.
    traj = trajs[1]
    traj1 = np.array([(a, b) for (a, b, c, d) in traj['states4d']])
    traj2 = np.array([(c, d) for (a, b, c, d) in traj['states4d']])
    plot_gridworld_trajectories(grid_H, grid_W, {'states2d': traj1}, fig, axs[0, 1])
    plot_gridworld_trajectories(grid_H, grid_W, {'states2d': traj2}, fig, axs[0, 2])
    axs[0, 1].set_title('agent 1')
    axs[0, 2].set_title('agent 2')

    # Both recovered individual maps share one symmetric colour range.
    c_max = np.max(np.abs([rec_ind1_maps[0, :], rec_ind2_maps[0, :]]))

    # RECOVERED agent 1 map
    im = axs[1, 0].imshow(np.reshape(rec_ind1_maps[0, :], (grid_H, grid_W), order='F'), vmin=-c_max, vmax=c_max)
    c_bar = plt.colorbar(im, ax=axs[1, 0], fraction=0.046, pad=0.04, ticks=[-c_max, 0, c_max])
    c_bar.set_ticks([-c_max, 0, c_max])
    c_bar.set_ticklabels(['{:.2f}'.format(-c_max), '0', '{:.2f}'.format(c_max)])
    axs[1, 0].set_title(r'$m(s_1): (\alpha={:.2f})$'.format(rec_weights[0, 0]))
    axs[1, 0].set_axis_off()

    # RECOVERED agent 2 map
    im = axs[1, 1].imshow(np.reshape(rec_ind2_maps[0, :], (grid_H, grid_W), order='F'), vmin=-c_max, vmax=c_max)
    c_bar = plt.colorbar(im, ax=axs[1, 1], fraction=0.046, pad=0.04, ticks=[-c_max, 0, c_max])
    c_bar.set_ticks([-c_max, 0, c_max])
    c_bar.set_ticklabels(['{:.2f}'.format(-c_max), '0', '{:.2f}'.format(c_max)])
    axs[1, 1].set_title(r'$n(s_2): (\alpha={:.2f})$'.format(rec_weights[1, 0]))
    axs[1, 1].set_axis_off()

    # Recovered interaction ("diff") map: rec_inter_maps[0] is indexed by position
    # in the sorted list of distinct squared offsets; each cell gets the value for
    # its squared distance to the grid centre.
    diff_square = np.unique([i**2 + j**2 for i in range(grid_H) for j in range(grid_W)])
    center = (grid_H // 2, grid_W // 2)  # (row, col)
    diff_maps = np.zeros((grid_H, grid_W))
    for i in range(grid_H):
        for j in range(grid_W):
            dist = (i - center[0])**2 + (j - center[1])**2
            dist_idx = np.argwhere(diff_square == dist)[0][0]
            diff_maps[i, j] = rec_inter_maps[0, dist_idx]

    c_max = np.max(np.abs(diff_maps))
    im = axs[1, 2].imshow(diff_maps, vmin=-c_max, vmax=c_max)
    c_bar = plt.colorbar(im, ax=axs[1, 2], fraction=0.046, pad=0.04, ticks=[-c_max, 0, c_max])
    c_bar.set_ticks([-c_max, 0, c_max])
    c_bar.set_ticklabels(['{:.2f}'.format(-c_max), '0', '{:.2f}'.format(c_max)])
    # Axes.text takes (x, y) = (column, row); mark the centre cell with a star.
    axs[1, 2].text(center[1], center[0], '*', fontsize=30, ha='center', va='center')
    axs[1, 2].set_title(r'$\phi(s_1,s_2): (\alpha={:.2f})$'.format(rec_weights[2, 0]))
    axs[1, 2].set_axis_off()

    # add subplot labels
    for n, ax in enumerate(axs.flat):
        ax.text(-0.25, 1.15, string.ascii_uppercase[n], transform=ax.transAxes, size=BIGGER_SIZE, weight='bold')

    plt.tight_layout(pad=2)
    fig.savefig(plot_dir + 'rec_value.png')
    fig.savefig(plot_dir + 'rec_value.svg', transparent=True)


# ---------------------------------------------
# Collab Foraging Task :  Joint test LL between models
# ----------------------------------------------
# Panels: (A) mean +/- std task length per simulated condition, (B) held-out LL
# matrix (rows: simulated condition, columns: fitted model), (C) B min-max
# normalised per row.
if figure_idx == 2:

    # Alternative configuration including the selfish + perfect-prediction condition:
    # tags_label = ['Centralized','Perfect prediction', 'Uniform prediction', 'Selfish + perfect prediction']
    # tags_sim_idx = [0, 1, 2, 4]
    # tags_train_idx = [0, 1, 2, 4]

    tags_label = ['Centralized', 'Perfect prediction', 'Uniform prediction']
    tags_sim_idx = [0, 1, 2]
    tags_train_idx = [0, 1, 2]

    version = 1
    mdl_name = 'maps_1_1_lr_0.01_0.005_lam_5.0_1.0'

    # Manual toggle: when True, build the joint validation-LL files of the
    # "selfish_perfect_prediction" fit by summing agent 1's LL from the selfish fit
    # and agent 2's LL from the independent-control fit.
    COMBINE_INDEPENDENT_FITS = False
    if COMBINE_INDEPENDENT_FITS:
        rec_dir = 'recovered_parameters/gridworld_recovered_params/fit_selfish_perfect_prediction/'
        rec_dir_agent1 = 'recovered_parameters/gridworld_recovered_params/fit_selfish/'
        rec_dir_agent2 = 'recovered_parameters/gridworld_recovered_params/fit_independent_control/'
        os.makedirs(rec_dir, exist_ok=True)
        for tag_sim in tags_sim_idx:
            sim_subdir = '{}{}/{}/'.format(tags_sim[tag_sim], version, mdl_name)
            rec_dir_subfolder = rec_dir + sim_subdir
            os.makedirs(rec_dir_subfolder, exist_ok=True)
            # glob order is arbitrary; sort so "last file" is well defined.
            files1 = sorted(glob.glob(rec_dir_agent1 + sim_subdir + 'agent_1/validation*.npy'))
            files2 = sorted(glob.glob(rec_dir_agent2 + sim_subdir + 'agent_2/validation*.npy'))
            assert files1 and files2, 'Missing validation files under {}'.format(sim_subdir)
            val_name1 = os.path.basename(files1[-1])
            val_name2 = os.path.basename(files2[-1])
            assert val_name1 == val_name2, 'Incompatible file names: {} and {}'.format(val_name1, val_name2)
            # Element-wise sum; unlike zip() this fails loudly on a length mismatch.
            ll_joint = np.asarray(np.load(files1[-1])) + np.asarray(np.load(files2[-1]))
            np.save(rec_dir_subfolder + val_name1, ll_joint)

    # ll_mat[i, j]: last held-out LL of model tags_train_idx[j] on data tags_sim_idx[i].
    ll_mat = np.zeros((len(tags_sim_idx), len(tags_train_idx)))
    for i, tag_sim in enumerate(tags_sim_idx):
        for j, tag_train in enumerate(tags_train_idx):
            save_dir = 'recovered_parameters/gridworld_recovered_params/fit_{}/{}{}/{}/'.format(
                tags_train[tag_train], tags_sim[tag_sim], version, mdl_name)
            joint_files = sorted(glob.glob(save_dir + 'validation*.npy'))
            if joint_files:
                ll_mat[i, j] = np.load(joint_files[0])[-1]
            else:
                # No joint file: fall back to the sum of the per-agent fits.
                print('No joint file found for {}'.format(save_dir))
                print('Use indept fit')
                for agent in [1, 2]:
                    pattern = save_dir + 'agent_{}/validation*.npy'.format(agent)
                    agent_files = glob.glob(pattern)
                    assert len(agent_files) == 1, 'Expected exactly one file for {}, found {}'.format(pattern, len(agent_files))
                    ll_mat[i, j] += np.load(agent_files[0])[-1]

    print(ll_mat)

    fig, axs = plt.subplots(1, 3, figsize=(10, 3))
    ax = axs[1]
    im = ax.imshow(ll_mat)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, ticks=[np.min(ll_mat), 0, np.max(ll_mat)])
    ax.set_xticks(range(ll_mat.shape[1]))
    ax.set_xticklabels(tags_label, rotation=45, ha='right')
    # Rows are simulated conditions, so the y ticks follow shape[0].
    ax.set_yticks(range(ll_mat.shape[0]))
    ax.set_yticklabels(tags_label)
    # ax.set_xlabel('Inference')
    ax.set_ylabel('Simulation')
    ax.set_title('LL per decision')

    # Row-wise min-max normalisation; constant rows (zero range) map to 0 instead of NaN.
    row_min = ll_mat.min(axis=1, keepdims=True)
    row_range = ll_mat.max(axis=1, keepdims=True) - row_min
    norm = (ll_mat - row_min) / np.where(row_range > 0, row_range, 1.0)
    ax = axs[2]
    im = ax.imshow(norm)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, ticks=[0, 1])
    ax.set_xticks(range(ll_mat.shape[1]))
    ax.set_xticklabels(tags_label, rotation=45, ha='right')
    ax.set_yticks([])
    ax.set_title('Row norm.')

    # Trajectory lengths of the simulated data, one list per condition.
    tasks = []
    for tag in tags_sim_idx:
        gen_dir = 'data/simulated_gridworld_data/{}{}'.format(tags_sim[tag], version)
        with open(gen_dir + '/expert_trajectories.pickle', 'rb') as f:
            all_expert_trajectories = pickle.load(f)
        tasks.append([len(traj['states']) for traj in all_expert_trajectories])

    # Calculate the mean and standard deviation of task_length
    mean_task_length = [np.mean(task) for task in tasks]
    std_task_length = [np.std(task) for task in tasks]

    # Plot the barplot
    ax = axs[0]
    ax.bar(tags_label, mean_task_length, yerr=std_task_length, capsize=5)
    ax.set_ylabel(r'Length (mean $\pm$ std)')
    ax.set_xticks(range(len(tags_label)))
    ax.set_xticklabels(tags_label, rotation=45, ha='right')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    for n, ax in enumerate(axs.flat):
        ax.text(-0.5, 1.15, string.ascii_uppercase[n], transform=ax.transAxes, size=BIGGER_SIZE, weight='bold')

    plt.tight_layout()

    fig.savefig(plot_dir + 'll_test.png')
    fig.savefig(plot_dir + 'll_test.svg', transparent=True)


# ---------------------------------------------
# Collab Foraging Task :  Individual test LL and value map for selfish agent
# ----------------------------------------------
# Panels: (A) held-out LL of one agent under three fitted models, (B) recovered
# interaction map, (C, D) recovered individual maps of the selfish fit.
if figure_idx == 3:
    grid_H, grid_W = 5, 5
    tag_sim = 4
    agent = 1
    tags_train_idx = [1, 2, 4]
    version = 1
    mdl_name = 'maps_1_1_lr_0.01_0.005_lam_5.0_1.0'
    tags_label = ['Perfect prediction', 'Chance prediction', "Selfish"]

    def _load_last_entry(pattern):
        """Load the first (sorted) file matching `pattern` and return its last entry along axis 0."""
        matches = sorted(glob.glob(pattern))
        assert matches, 'No file found for {}'.format(pattern)
        return np.load(matches[0])[-1]

    # Last held-out LL of `agent` under each fitted model, on data tags_sim[tag_sim].
    ll = []
    for tag_train in tags_train_idx:
        save_dir = 'recovered_parameters/gridworld_recovered_params/fit_{}/{}{}/{}/'.format(
            tags_train[tag_train], tags_sim[tag_sim], version, mdl_name)
        val_files = glob.glob(save_dir + 'agent_{}/validation*.npy'.format(agent))
        assert len(val_files) == 1, 'No file found for {}'.format(save_dir) + 'agent_{}/validation*.npy'.format(agent)
        ll.append(np.load(val_files[0])[-1])

    fig, axs = plt.subplots(1, 4, figsize=(10, 3))

    # LL bars; model names are written next to each bar instead of as y tick labels.
    ax = axs[0]
    ax.barh(np.arange(len(ll)), ll, color=colors[:len(ll)])
    ax.set_xlim([-1.7, -1.2])
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_visible(False)
    for i, l in enumerate(ll):
        ax.text(l - 0.15, i, tags_label[i], ha='center', va='center', color='black', fontsize=LEGEND_SIZE)
    ax.set_yticks([])
    ax.set_title('LL per decision')

    # Recovered maps of the selfish fit for `agent` (the same agent as the LL panel).
    tag_train = 4
    rec_dir_name = 'recovered_parameters/gridworld_recovered_params/fit_{}/{}{}/{}/agent_{}'.format(
        tags_train[tag_train], tags_sim[tag_sim], version, mdl_name, agent)
    rec_weights = _load_last_entry(rec_dir_name + '/weights_trajs*.npy')
    rec_ind1_maps = _load_last_entry(rec_dir_name + '/ind1_maps_trajs*.npy')
    rec_ind2_maps = _load_last_entry(rec_dir_name + '/ind2_maps_trajs*.npy')
    rec_inter_maps = _load_last_entry(rec_dir_name + '/inter_maps_trajs*.npy')

    # One symmetric colour range shared by all three maps.
    c_max1 = np.max(np.abs([rec_ind1_maps[0, :], rec_ind2_maps[0, :]]))
    c_max2 = np.max(np.abs(rec_inter_maps[0, :]))
    c_max = np.max([c_max1, c_max2])

    # RECOVERED agent 1 map
    ax = axs[2]
    im = ax.imshow(np.reshape(rec_ind1_maps[0, :], (grid_H, grid_W), order='F'), vmin=-c_max, vmax=c_max)
    c_bar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, ticks=[-c_max, 0, c_max])
    c_bar.set_ticks([-c_max, 0, c_max])
    c_bar.set_ticklabels(['{:.2f}'.format(-c_max), '0', '{:.2f}'.format(c_max)])
    ax.set_title(r'$m(s_1): (\alpha={:.2f})$'.format(rec_weights[0, 0]))
    ax.set_axis_off()

    # RECOVERED agent 2 map
    ax = axs[3]
    im = ax.imshow(np.reshape(rec_ind2_maps[0, :], (grid_H, grid_W), order='F'), vmin=-c_max, vmax=c_max)
    c_bar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, ticks=[-c_max, 0, c_max])
    c_bar.set_ticks([-c_max, 0, c_max])
    c_bar.set_ticklabels(['{:.2f}'.format(-c_max), '0', '{:.2f}'.format(c_max)])
    ax.set_title(r'$n(s_2): (\alpha={:.2f})$'.format(rec_weights[1, 0]))
    ax.set_axis_off()

    # Recovered interaction ("diff") map: rec_inter_maps[0] is indexed by position
    # in the sorted list of distinct squared offsets; each cell gets the value for
    # its squared distance to the grid centre.
    diff_square = np.unique([i**2 + j**2 for i in range(grid_H) for j in range(grid_W)])
    center = (grid_H // 2, grid_W // 2)  # (row, col)
    diff_maps = np.zeros((grid_H, grid_W))
    for i in range(grid_H):
        for j in range(grid_W):
            dist = (i - center[0])**2 + (j - center[1])**2
            dist_idx = np.argwhere(diff_square == dist)[0][0]
            diff_maps[i, j] = rec_inter_maps[0, dist_idx]

    ax = axs[1]
    im = ax.imshow(diff_maps, vmin=-c_max, vmax=c_max)
    c_bar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, ticks=[-c_max, 0, c_max])
    c_bar.set_ticks([-c_max, 0, c_max])
    c_bar.set_ticklabels(['{:.2f}'.format(-c_max), '0', '{:.2f}'.format(c_max)])
    # Axes.text takes (x, y) = (column, row); mark the centre cell with a star.
    ax.text(center[1], center[0], '*', fontsize=30, ha='center', va='center')
    ax.set_title(r'$\phi(s_1,s_2): (\alpha={:.2f})$'.format(rec_weights[2, 0]))
    ax.set_axis_off()

    for n, ax in enumerate(axs.flat):
        ax.text(-0.25, 1.15, string.ascii_uppercase[n], transform=ax.transAxes, size=BIGGER_SIZE, weight='bold')
    plt.tight_layout(pad=1)
    fig.savefig(plot_dir + 'll_test_selfish.png')
    fig.savefig(plot_dir + 'll_test_selfish.svg', transparent=True)


# ---------------------------------------------
# Collab Foraging Task :  Trajectory comparison between models
# ----------------------------------------------
# Bar plot of mean +/- std trajectory length for each simulated condition.
if figure_idx == 4:

    version = 1
    tags_sim_idx = [0, 1, 2, 4, 5]
    tags_sim_label = ['Centralized', 'Optimal prediction', 'Chance prediction',
                      'Egocentric + optimal prediction', 'Egocentric']

    # Trajectory lengths of the simulated data, one list per condition.
    tasks = []
    for tag in tags_sim_idx:
        gen_dir = 'data/simulated_gridworld_data/{}{}'.format(tags_sim[tag], version)
        # Context manager closes the pickle file handle deterministically.
        with open(gen_dir + '/expert_trajectories.pickle', 'rb') as f:
            all_expert_trajectories = pickle.load(f)
        tasks.append([len(traj['states']) for traj in all_expert_trajectories])

    # Calculate the mean and standard deviation of task_length
    mean_task_length = [np.mean(task) for task in tasks]
    std_task_length = [np.std(task) for task in tasks]

    # Plot the barplot (explicit Figure/Axes objects, like the other figures).
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.bar(tags_sim_label, mean_task_length, yerr=std_task_length, capsize=5)
    ax.set_ylabel(r'Task Length (mean $\pm$ std)')
    ax.set_xticks(range(len(tags_sim_label)))
    ax.set_xticklabels(tags_sim_label, rotation=45, ha='right')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    fig.tight_layout()
    fig.savefig(plot_dir + 'task_length.png')
    fig.savefig(plot_dir + 'task_length.eps', transparent=True)



# ---------------------------------------------
# Hallway Task: Test LL
# ----------------------------------------------
# Held-out LL matrix (rows: data version, columns: fitted model) plus its row-wise
# normalisation, then one value-map figure per (model, version) pair.
if figure_idx == 5:

    tags_label = ['Centralized', 'Perfect prediction', 'Uniform prediction', 'Selfish']
    versions = [1, 2, 3]
    # Rows are loaded in `versions` order and reordered to versions [2, 1, 3] so they
    # line up with the row labels ['Expert', 'Success', 'Fail'] used below.
    shuffle_idx = [1, 0, 2]
    tags_train_idx = [0, 1, 2, 4]
    mdl_name = 'maps_1_1_lr_0.01_0.005_lam_2.0_1.0'

    grid_H, grid_W = 3, 5

    # ll_version[v][j]: last held-out LL of model tags_train_idx[j] on version versions[v].
    ll_version = []
    for version in versions:
        lls = []
        for tag_train in tags_train_idx:
            save_dir = 'recovered_parameters/hallway/fit_{}/{}/{}/'.format(tags_train[tag_train], version, mdl_name)
            joint_files = glob.glob(save_dir + 'validation*.npy')
            if joint_files:
                lls.append(np.load(joint_files[0])[-1])
            else:
                # No joint file: sum the per-agent held-out LLs.
                ll_joint = 0
                for agent in [1, 2]:
                    pattern = save_dir + 'agent_{}/validation*.npy'.format(agent)
                    agent_files = glob.glob(pattern)
                    assert agent_files, 'No file found for {}'.format(pattern)
                    ll_joint += np.load(agent_files[0])[-1]
                lls.append(ll_joint)
        ll_version.append(lls)

    # Reorder rows and keep only the first three models (the Selfish column is dropped).
    ll_mat = np.array(ll_version)[shuffle_idx, :3]

    fig, axs = plt.subplots(1, 2, figsize=(5, 2.5))
    ax = axs[0]
    im = ax.imshow(ll_mat)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, ticks=[-2.5, -2.0])
    ax.set_xticks(range(ll_mat.shape[1]))
    ax.set_xticklabels(tags_label[:3], ha='right', rotation=45)
    ax.set_yticks(range(ll_mat.shape[0]))
    ax.set_yticklabels(['Expert', 'Success', 'Fail'])
    ax.set_title('LL per decision')

    # Row-wise min-max normalisation; constant rows (zero range) map to 0 instead of NaN.
    row_min = ll_mat.min(axis=1, keepdims=True)
    row_range = ll_mat.max(axis=1, keepdims=True) - row_min
    norm = (ll_mat - row_min) / np.where(row_range > 0, row_range, 1.0)
    ax = axs[1]
    im = ax.imshow(norm)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, ticks=[0, 1])
    ax.set_xticks(range(ll_mat.shape[1]))
    ax.set_xticklabels(tags_label[:3], ha='right', rotation=45)
    ax.set_yticks([])
    ax.set_title('Row norm.')

    for n, ax in enumerate(axs.flat):
        ax.text(-0.25, 1.15, string.ascii_uppercase[n], transform=ax.transAxes, size=BIGGER_SIZE, weight='bold')

    plt.tight_layout()
    fig.savefig(plot_dir + 'hallway_ll_test.png')
    fig.savefig(plot_dir + 'hallway_ll_test.svg', transparent=True)
    plt.close(fig)

    # Recovered individual maps of agent 1 for each prediction model and version.
    for tag_train_idx in [1, 2]:
        for version in [2, 3]:
            save_dir = 'recovered_parameters/hallway/fit_{}/{}/{}/'.format(tags_train[tag_train_idx], version, mdl_name)
            tmp1 = glob.glob(save_dir + 'agent_1/ind1_maps*.npy')
            tmp2 = glob.glob(save_dir + 'agent_1/ind2_maps*.npy')
            tmp3 = glob.glob(save_dir + 'agent_1/weights*.npy')
            assert tmp1 and tmp2 and tmp3, 'Missing recovered-parameter files in {}agent_1/'.format(save_dir)
            # [-1]: last entry along the first axis (presumably the final iteration).
            rec_ind1_maps = np.load(tmp1[0])[-1]
            rec_ind2_maps = np.load(tmp2[0])[-1]
            rec_weights = np.load(tmp3[0])[-1]

            fig, axs = plt.subplots(1, 2, figsize=(5, 2.5))

            c_max = np.max(np.abs([rec_ind1_maps[0, :], rec_ind2_maps[0, :]]))
            # RECOVERED agent 1 map
            ax = axs[0]
            im = ax.imshow(np.reshape(rec_ind1_maps[0, :], (grid_H, grid_W), order='F'), vmin=-c_max, vmax=c_max)
            c_bar = plt.colorbar(im, ax=ax, fraction=0.028, pad=0.04, ticks=[-c_max, 0, c_max])
            c_bar.set_ticks([-c_max, 0, c_max])
            c_bar.set_ticklabels(['{:.2f}'.format(-c_max), '0', '{:.2f}'.format(c_max)])
            ax.set_title(r'$m(s_1): (\alpha={:.2f})$'.format(rec_weights[0, 0]))
            ax.set_axis_off()

            # RECOVERED agent 2 map
            ax = axs[1]
            im = ax.imshow(np.reshape(rec_ind2_maps[0, :], (grid_H, grid_W), order='F'), vmin=-c_max, vmax=c_max)
            c_bar = plt.colorbar(im, ax=ax, fraction=0.028, pad=0.04, ticks=[-c_max, 0, c_max])
            c_bar.set_ticks([-c_max, 0, c_max])
            c_bar.set_ticklabels(['{:.2f}'.format(-c_max), '0', '{:.2f}'.format(c_max)])
            ax.set_title(r'$n(s_2): (\alpha={:.2f})$'.format(rec_weights[1, 0]))
            ax.set_axis_off()

            # Version 2 is the "Expert" data and version 3 the "Fail" data (see row labels above).
            title = 'Expert: value map' if version == 2 else 'Fail: value map'
            fig.suptitle(title)

            for n, ax in enumerate(axs.flat):
                ax.text(-0.25, 1.15, string.ascii_uppercase[n], transform=ax.transAxes, size=BIGGER_SIZE, weight='bold')
            plt.tight_layout()
            fig.savefig(plot_dir + 'hallway_value_{}_{}.png'.format(tags_train[tag_train_idx], version))
            fig.savefig(plot_dir + 'hallway_value_{}_{}.svg'.format(tags_train[tag_train_idx], version), transparent=True)
            # Release each per-model figure; otherwise they accumulate in pyplot's registry.
            plt.close(fig)