# modified based on dynamic_irl main_irl.py
# time-invariant version
# two agent collaborative exploration version

import os, argparse
import numpy as np
import pickle
from src.optimize_weights import getMAP_weights
# from src.optimize_goal_maps import getMAP_goalmaps
# from src.compute_validation_ll import get_validation_ll
from itertools import permutations
import matplotlib.pyplot as plt
# from src.compute_conf_interval import compute_conf_interval, compute_inv_hessian, compute_conf_interval_ti, compute_inv_hessian_ti
from src.irl_for_gridworld import fit_irl_gridworld
from src.irl_for_gridworld_ind import fit_irl_gridworld_ind
from plot_utils.generate_colormap import generate_colormap
import glob


def get_gen_recovered_parameters(seed, num_trajs, max_iters,
                                 n_map_individual, n_map_interaction, lr_weights, lr_maps, lam1, lam2):
    """ load the generative rewards and the recovered parameters of one fit

        NOTE: reads the module-level names GEN_DIR_NAME and REC_DIR_NAME
        (assigned in the ``__main__`` section of this file); they must be
        defined before this function is called.

        args:
            seed (int): seed of the fit to load (part of the file names)
            num_trajs (int): # of trajectories of the fit (part of the file names)
            max_iters (int): # of iterations of the fit (part of the file names)
            n_map_individual (int): # of individual maps per agent
            n_map_interaction (int): # of interaction maps
            lr_weights (float): learning rate of the weights
            lr_maps (float): learning rate of the maps
            lam1, lam2 (float): part of the results directory name; the
                ``__main__`` section describes them as the l2 regularisation
                on the individual and on the interaction maps
        returns:
            (gen_rewards, rec_weights, rec_ind1_maps, rec_ind2_maps,
            rec_inter_maps): ``gen_rewards`` is the 'generative_rewards' entry
            of the generative pickle; every ``rec_*`` value is the last entry
            along the first axis of the corresponding saved array
            (presumably the final iterate of the optimisation -- confirm
            against the fitting code).
    """
    # directory holding the results of this hyper-parameter setting
    rec_dir_name = REC_DIR_NAME + "/maps_{}_{}_lr_{}_{}_lam_{}_{}/".format(
        n_map_individual, n_map_interaction, lr_weights, lr_maps, lam1, lam2)

    # `with` closes the file handle deterministically.
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(GEN_DIR_NAME + "/generative_parameters.pickle", 'rb') as gen_file:
        gen_rewards = pickle.load(gen_file)['generative_rewards']

    # every recovered array shares the same file-name suffix
    file_suffix = "_trajs_{}_seed_{}_iters_{}.npy".format(num_trajs, seed, max_iters)

    def _load_last(prefix):
        # last entry along the first axis of the saved array
        return np.load(rec_dir_name + prefix + file_suffix)[-1]

    rec_weights = _load_last("weights")
    rec_ind1_maps = _load_last("ind1_maps")
    rec_ind2_maps = _load_last("ind2_maps")
    rec_inter_maps = _load_last("inter_maps")

    # NOTE: the recovered parameters are returned as saved; no permutation,
    # sign, scale or offset alignment to the generative parameters is applied.
    return gen_rewards, rec_weights, rec_ind1_maps, rec_ind2_maps, rec_inter_maps



# def compute_scale_and_offset_goal_maps(gen_goal_maps, rec_goal_maps):
#     '''
#     scale recovered goal maps so that generative and recovered goal
#     maps have the same maximum and minimum values
#     '''
#     N_goal_MAPS = gen_goal_maps.shape[0]
#     map_offsets = []
#     map_scales = []
#     for m in range(N_goal_MAPS):
#         min_gen = np.min(gen_goal_maps[m])
#         max_gen = np.max(gen_goal_maps[m])
#         min_rec = np.min(rec_goal_maps[m])
#         max_rec = np.max(rec_goal_maps[m])
#         offset = (min_gen * max_rec - min_rec * max_gen) / (max_rec - min_rec)
#         scale = (max_gen - min_gen) / (max_rec - min_rec)
#         assert scale > 0, "scale parameter should be greater than 0"
#         map_offsets.append(offset)
#         map_scales.append(scale)
#     return map_scales, map_offsets


# def transform_recovered_parameters(gen_goal_maps, gen_weights,
#                                    rec_goal_maps, rec_weights, std_weights=None, std_maps=None):
#     '''
#     perform all transformations for recovered weights and goal maps:
#     perform the sign conversion and calculate the relevant scales and offsets
#     '''
#     #calculate the sign conversion between recovered and generative
#     # parameters.  Use time-varying weights to calculate sign since these
#     # are only going to be modified by a positive scaling factor later on
#     N_GOAL_MAPS = len(gen_goal_maps)
#     for k in range(N_GOAL_MAPS):
#         # modifying sign conversion 
#         diff_orig = np.linalg.norm(rec_weights[k]-gen_weights[k])
#         diff_flipped = np.linalg.norm(rec_weights[k]+gen_weights[k])
#         if diff_orig>diff_flipped:
#             rec_weights[k] = -1 * rec_weights[k].copy()
#             rec_goal_maps[k] = -1 * rec_goal_maps[k].copy()

#     #now compute scaling factors and offsets:
#     map_scales, map_offsets = compute_scale_and_offset_goal_maps(
#         gen_goal_maps, rec_goal_maps)

#     for k in range(N_GOAL_MAPS):
#         rec_goal_maps[k] = map_scales[k]*rec_goal_maps[k].copy() + \
#                               map_offsets[k]
#         rec_weights[k] = (1 / map_scales[k]) * rec_weights[k].copy()
#         if std_weights is not None:
#             std_weights[k] = (1 / map_scales[k]) * std_weights[k].copy()
#         if std_maps is not None:
#             std_maps[k] = map_scales[k]*std_maps[k].copy() 
#     if std_weights is None:
#         return rec_goal_maps, rec_weights
#     else:
#         return rec_goal_maps, rec_weights, std_weights


# def calculate_permutation(gen_goal_maps, gen_weights, rec_goal_maps,
#                           rec_weights):
#     '''
#     loop through all permutations, perform appropriate transformations and
#     calculate the distance between the generative and recovered parameters.
#     identify the permutation of the recovered weights and goal maps so as
#     to minimize the distance between generative and recovered parameters
#     '''
#     N_GOAL_MAPS = gen_goal_maps.shape[0]
#     perms = list(permutations(range(N_GOAL_MAPS)))
#     dist_vec = []
#     for permutation in perms:
#         permuted_maps = rec_goal_maps[np.array(permutation)]
#         permuted_weights = rec_weights[np.array(permutation)]
#         # calculate transformation:
#         final_permuted_goal_maps, final_permuted_weights = \
#             transform_recovered_parameters(gen_goal_maps, gen_weights, permuted_maps, permuted_weights)
#         dist_vec.append(np.linalg.norm(final_permuted_goal_maps - gen_goal_maps))
#     optimal_permutation = perms[np.argmin(dist_vec)]
#     return optimal_permutation


# def get_confidence_intervals_ti(rec_goal_maps, rec_weights):
#     """ compute the confidence intervals of the recovered time-invariant weights
#         returns:
#             std_weights (1 X N_MAPS): std dev of weights """
    
#     # load parameters
#     P_a = pickle.load(open(GEN_DIR_NAME + "/generative_parameters.pickle", 'rb'))['P_a']
#     sigma = pickle.load(open(GEN_DIR_NAME + "/generative_parameters.pickle", 'rb'))['sigmas']
#     # load expert trajectories
#     all_trajectories = pickle.load(open(GEN_DIR_NAME + "/expert_trajectories.pickle", 'rb'))
#     val_indices = np.arange(start=0, stop=num_trajs, step=5)
#     train_indices = np.delete(np.arange(num_trajs), val_indices)
#     expert_trajectories = [all_trajectories[train_idx] for train_idx in train_indices]

#     N_GOAL_MAPS = rec_goal_maps.shape[0]
#     N_STATES = grid_H*grid_W
#     T = len(expert_trajectories[0]["actions"])

#     # compute inverse hessian of the MAP objective at the recovered parameters
#     inv_hess = compute_inv_hessian_ti(seed, P_a, expert_trajectories, [sigma]*N_GOAL_MAPS, rec_weights, rec_goal_maps, gamma=0.9)
#     # compute confidence intervals now
#     std_weights = compute_conf_interval_ti(inv_hess, T, N_GOAL_MAPS, N_STATES)
#     return std_weights



if __name__=='__main__':
    # Script entry point:
    #   1. parse the command line and build the data / results directories
    #   2. optionally run the fit(s)
    #   3. load the recovered parameters of one fit
    #   4. save summary figures (maps, losses, test log-likelihood) next to them

    parser = argparse.ArgumentParser(description='enter environment specifics')
    parser.add_argument('--TRAIN_NOW', type=int, default=0,
                        help='whether to load and plot previously saved results or train DIRL now')
    parser.add_argument('--TRAIN_NOW_IND', type=int, default=0,
                        help='whether to train independent control now')
    parser.add_argument('--TAG_SIM', type=int, default=0, help='which grid interaction type to use')
    parser.add_argument('--TAG_TRAIN', type=int, default=0, help='which interaction type used to infer')
    parser.add_argument('--VERSION', type=int, default=1, help='which grid env to use')
    parser.add_argument('--SingleAgent', type=int, default=0, help='which agent to plot')

    args = parser.parse_args()

    TRAIN_NOW = args.TRAIN_NOW
    TRAIN_NOW_IND = args.TRAIN_NOW_IND
    VERSION = args.VERSION
    TAG = args.TAG_SIM
    TAG_TRAIN = args.TAG_TRAIN

    # --TAG_SIM / --TAG_TRAIN are indices into these lists
    tags_sim = ['centralized', 'independent_control', 'independent_control_w_uniform_prediction',
                 'independent_control_w_policy_prediction', 'selfish_perfect_prediction', 'selfish_selfish']
    tags_train = ['centralized', 'independent_control', 'independent_control_w_uniform_prediction', '', 'selfish']
    tag_sim = tags_sim[TAG]
    tag_train = tags_train[TAG_TRAIN]

    # GEN_DIR_NAME / REC_DIR_NAME are module-level on purpose:
    # get_gen_recovered_parameters() reads them as globals.
    GEN_DIR_NAME = 'data/simulated_gridworld_data'
    REC_DIR_NAME = 'recovered_parameters/gridworld_recovered_params'
    REC_DIR_NAME = REC_DIR_NAME + '/fit_'+tag_train

    REC_DIR_NAME = REC_DIR_NAME + '/'+tag_sim+str(VERSION)
    GEN_DIR_NAME = GEN_DIR_NAME + "/"+tag_sim+str(VERSION)

    # grid size: 5x5 by default, 9x9 for VERSION 3
    grid_H, grid_W = 5, 5
    if VERSION == 3:
        grid_H, grid_W = 9, 9

    num_trajs = 200 # number of simulated trajectories to use
    max_iters = 20 # max iters to run SGD for optimization of goal maps and weights during each outer loop of dirl
    n_maps_individual = 1 # individual maps per agent
    n_maps_interaction = 1 # interaction maps to use
    lr_maps = 0.005 # lr of goal maps
    lr_weights = 0.01 # lr of weights
    seed = 1 # initialization seed
    gamma = 0.9 # discount factor
    lam1 = 5. # l2 reg on individual maps
    lam2 = 1. # l2 reg on interactions maps

    info = {'Neval':0}
    # info = {'Neval':0,'ll':1}

    if TRAIN_NOW:
        fit_irl_gridworld(num_trajs, lr_weights, lr_maps, max_iters, gamma, n_maps_individual, n_maps_interaction,
                          seed, GEN_DIR_NAME, REC_DIR_NAME, lam1, lam2, TAG_TRAIN, info)
    if TRAIN_NOW_IND:
        fit_irl_gridworld_ind(num_trajs, lr_weights, lr_maps, max_iters, gamma, n_maps_individual, n_maps_interaction,
                            seed, GEN_DIR_NAME, REC_DIR_NAME, lam1, lam2, TAG_TRAIN, info)

    # ---------------------------------------------------------
    # Load generative rewards and recovered parameters
    # ---------------------------------------------------------
    rec_dir_name = REC_DIR_NAME + "/maps_{}_{}_lr_{}_{}_lam_{}_{}/".format(n_maps_individual, n_maps_interaction, lr_weights, lr_maps, lam1,lam2)
    if args.SingleAgent:
        # results of a single agent are read from an 'agent_<k>/' sub-directory
        rec_dir_name = rec_dir_name + 'agent_{}/'.format(args.SingleAgent)

    # `with` closes the file handle deterministically.
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(GEN_DIR_NAME + "/generative_parameters.pickle", 'rb') as gen_file:
        gen_rewards = pickle.load(gen_file)['generative_rewards']

    # every saved array of this fit shares the same file-name suffix
    file_suffix = "_trajs_{}_seed_{}_iters_{}.npy".format(num_trajs, seed, max_iters)
    # [-1]: keep only the last entry along the first axis of each saved array
    rec_weights = np.load(rec_dir_name + "weights" + file_suffix)[-1]
    rec_ind1_maps = np.load(rec_dir_name + "ind1_maps" + file_suffix)[-1]
    rec_ind2_maps = np.load(rec_dir_name + "ind2_maps" + file_suffix)[-1]
    rec_inter_maps = np.load(rec_dir_name + "inter_maps" + file_suffix)[-1]

    # ---------------------------------------------------------
    # Begin Plotting
    # ---------------------------------------------------------
    save_dir = rec_dir_name  # figures are written next to the recovered parameters

    LEGEND_SIZE = 10
    SMALL_SIZE = 15
    BIGGER_SIZE = 20

    plt.rc('font', size=LEGEND_SIZE)          # controls default text sizes
    plt.rc('axes', titlesize=LEGEND_SIZE)     # fontsize of the axes title
    plt.rc('axes', labelsize=LEGEND_SIZE)    # fontsize of the x and y labels
    plt.rc('xtick', labelsize=LEGEND_SIZE)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=LEGEND_SIZE)    # fontsize of the tick labels
    plt.rc('legend', fontsize=LEGEND_SIZE)    # legend fontsize
    colors = ['steelblue', '#D85427', 'tab:green', 'k']

    MAP_LABELS = ['home', 'water'] # labels of the goal maps
    STATE_LABELS = ['home', 'water'] # labels of states in gridworld to plot rewards for

    # Plot the maps and list the weight
    fig = plt.figure(figsize=(5,5))
    plt.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.93, wspace=0.3, hspace=0.4)

    # GENERATIVE rewards (first column), reshaped to the grid in column-major order
    plt.subplot(2, 2, 1)
    cmap = plt.get_cmap('viridis')
    # NOTE(review): new_cmap is not passed to any plot below -- confirm whether it is still needed
    new_cmap = generate_colormap(cmap, 0.5, 1)
    plt.imshow(np.reshape(gen_rewards[:,0],(grid_H, grid_W),order='F'))
    plt.colorbar()
    plt.title('gen. rewards', fontsize=8)
    plt.axis('off')

    # RECOVERED interaction map, plotted against the unique values of i**2 + j**2
    # (assumes the map has exactly one entry per unique value -- plt.plot raises otherwise)
    plt.subplot(2, 2, 2)
    squared_offsets = np.unique([i**2+j**2 for i in range(grid_H) for j in range(grid_W)])
    plt.plot(squared_offsets, rec_inter_maps[0,:])
    plt.ylabel('Strength')
    plt.xlabel('sqaure grid difference ')
    plt.title('rec. diff map: {:.2f}'.format(rec_weights[-1,0]), fontsize=8)
    # plt.axis('off')

    # RECOVERED agent 1 map
    plt.subplot(2, 2, 3)
    plt.imshow(np.reshape(rec_ind1_maps[0,:],(grid_H, grid_W),order='F'))
    plt.colorbar()
    plt.title('rec. agent 1: {:.2f}'.format(rec_weights[0,0]), fontsize=8)
    # plt.axis('off')

    # RECOVERED agent 2 map
    plt.subplot(2, 2, 4)
    plt.imshow(np.reshape(rec_ind2_maps[0,:],(grid_H, grid_W),order='F'))
    plt.colorbar()
    plt.title('rec. agent 2: {:.2f}'.format(rec_weights[1,0]), fontsize=8)
    # plt.axis('off')
    plt.tight_layout()
    fig.savefig(save_dir + 'maps.png')
    # fig.savefig(save_dir + 'maps.pdf')

    # # Plot the time-invariant weight as barplot
    # fig = plt.figure(figsize=(2.2,2.2))
    # plt.subplots_adjust(left=0.3, bottom=0.3, right=0.9, top=0.9)
    # plt.errorbar(gen_time_invariant_weights[:,0], final_rec_weights[:,0], yerr=final_std_weights[:,0], fmt='o', markersize=10)
    # xlim_min, xlim_max = plt.xlim()
    # plt.plot([xlim_min, xlim_max], [xlim_min, xlim_max], 'k--')
    # plt.xlabel('Gen a')
    # plt.ylabel('Rec a')
    # fig.savefig(REC_DIR_NAME + '/maps_{}_lr_{}_{}/weights.pdf'.format(n_maps, lr_maps, lr_weights))


    # load in loss function and plot
    losses_all_maps = np.load(save_dir + 'losses_maps' + file_suffix)
    losses_weights = np.load(save_dir + 'losses_weights' + file_suffix)

    fig, axs = plt.subplots(2,1,figsize=(4.4, 4.4))
    axs[0].plot(losses_weights)
    axs[0].set_title('Loss for weight')
    axs[1].plot(losses_all_maps)
    axs[1].set_title('Loss for maps')
    plt.tight_layout()
    fig.savefig(save_dir + 'loses.pdf')

    # load in test set ll and plot
    # glob returns matches in filesystem-dependent order; sort so that the
    # chosen file is deterministic, and fail with a clear message if none exists
    validation_files = sorted(glob.glob(save_dir+'validation*.npy'))
    if not validation_files:
        raise FileNotFoundError("no 'validation*.npy' file found in " + save_dir)
    test_ll = np.load(validation_files[0])
    fig = plt.figure(figsize=(2.2,2.2))
    plt.plot(np.array(test_ll))
    plt.xlabel('Iteration (outer loop)')
    plt.ylabel('Test ll per decision')
    plt.title('Final: {:.4f}'.format(test_ll[-1]))
    plt.tight_layout()
    fig.savefig(save_dir + 'test_ll.pdf')


