# modified based on main_irl.py
# Adapted to hallway task described in Ho et al

import os, argparse, glob
import numpy as np
import pickle
from src.optimize_weights import getMAP_weights
from src.optimize_goal_maps import getMAP_goalmaps
from src.compute_validation_ll import get_validation_ll
import matplotlib.pyplot as plt
from src.irl_for_gridworld import fit_irl_gridworld
from src.irl_for_gridworld_ind import fit_irl_gridworld_ind
from plot_utils.generate_colormap import generate_colormap


def get_gen_recovered_parameters(seed, num_trajs, max_iters, 
                                 n_map_individual, n_map_interaction, lr_weights, lr_maps, lam1, lam2,
                                 rec_dir=None):
    """ loads the final recovered weights and maps saved for one fitting run
        args:
            seed (int): initialization seed of the run to load
            num_trajs (int): number of trajectories used in the run (part of the file names)
            max_iters (int): iteration count of the run (part of the file names)
            n_map_individual (int): # of individual maps per agent
            n_map_interaction (int): # of interaction maps
            lr_weights (float): learning rate of the weights used in the run
            lr_maps (float): learning rate of the maps used in the run
            lam1 (float): first regularization strength used in the directory name
            lam2 (float): second regularization strength used in the directory name
            rec_dir (str, optional): base directory holding the recovered
                parameters. When None, the module-level REC_DIR_NAME is used,
                which is only defined when this file is run as a script.
        returns:
            tuple (None, rec_weights, rec_ind1_maps, rec_ind2_maps, rec_inter_maps);
            each array is the last entry along axis 0 of the corresponding
            saved .npy file. The first element is always None (loading of the
            generative rewards is currently disabled).
        raises:
            NameError: if rec_dir is None and REC_DIR_NAME is not defined.
    """
    if rec_dir is None:
        # REC_DIR_NAME only exists when the file is executed as a script, so
        # give importers an actionable message instead of a bare NameError.
        try:
            rec_dir = REC_DIR_NAME
        except NameError as err:
            raise NameError("REC_DIR_NAME is not defined; pass rec_dir explicitly "
                            "when calling get_gen_recovered_parameters from an import") from err

    # directory of this hyper-parameter configuration
    rec_dir_name = rec_dir + "/maps_{}_{}_lr_{}_{}_lam_{}_{}/".format(n_map_individual, n_map_interaction, lr_weights, lr_maps, lam1,lam2)
    # common tail shared by every saved array of this run
    file_suffix = "_trajs_" + str(num_trajs) + "_seed_" + str(seed) + "_iters_" + str(max_iters) + ".npy"

    def _load_last(prefix):
        # keep only the last entry along axis 0 of the saved array
        return np.load(rec_dir_name + prefix + file_suffix)[-1]

    # load recovered parameters for this seed
    rec_weights = _load_last("weights")
    rec_ind1_maps = _load_last("ind1_maps")
    rec_ind2_maps = _load_last("ind2_maps")
    rec_inter_maps = _load_last("inter_maps")

    # NOTE: comparison against generative parameters (offsets, permutation
    # matching, confidence intervals) is not implemented here, hence the
    # None placeholder in the first return slot.
    return None, rec_weights, rec_ind1_maps, rec_ind2_maps, rec_inter_maps


if __name__=='__main__':
    # Script entry point: optionally fit the IRL model(s), then load the
    # recovered parameters saved on disk and write the summary plots
    # (maps, losses, validation log-likelihood) next to them.

    parser = argparse.ArgumentParser(description='enter environment specifics')
    parser.add_argument('--TRAIN_NOW', type=int, default=0,
                        help='whether to train now')
    parser.add_argument('--TRAIN_NOW_IND', type=int, default=0,
                        help='whether to train agents independently')
    parser.add_argument('--TAG_TRAIN', type=int, default=0, help='which interaction type used to infer')
    parser.add_argument('--VERSION', type=int, default=1, help='trajectories (version=1) or expert trajectories (version=2)')
    parser.add_argument('--SingleAgent', type=int, default=0, help='which agent to plot')


    args = parser.parse_args()

    TRAIN_NOW = args.TRAIN_NOW 
    TRAIN_NOW_IND = args.TRAIN_NOW_IND
    VERSION = args.VERSION
    TAG_TRAIN = args.TAG_TRAIN

    # TAG_TRAIN indexes this list; index 3 maps to an empty tag
    # NOTE(review): confirm whether the empty entry is intentional.
    tags = ['centralized', 'independent_control', 'independent_control_w_uniform_prediction',
             '','selfish']
    tag_train = tags[TAG_TRAIN]

    # input data directory and output directory, both keyed by VERSION
    GEN_DIR_NAME = 'data/experiment_hallway/'+str(VERSION)
    REC_DIR_NAME = 'recovered_parameters/hallway/'
    REC_DIR_NAME = REC_DIR_NAME + 'fit_'+tag_train

    REC_DIR_NAME = REC_DIR_NAME + '/'+str(VERSION)

    grid_H, grid_W = 3, 5 

    num_trajs = 200 # number of simulated trajectories to use
    max_iters = 20 # max iters to run SGD for optimization of goal maps and weights during each outer loop of dirl
    n_maps_individual = 1 # individual maps per agent
    n_maps_interaction = 1 # interaction maps to use
    lr_maps = 0.005 # lr of goal maps
    lr_weights = 0.01 # lr of weights
    seed = 1 # initialization seed
    gamma = 0.9 # discount factor
    lam1 = 2. # l2 reg on individual maps
    lam2 = 10. # l2 reg on interactions maps

    info = {'Neval':0}

    if TRAIN_NOW:
        fit_irl_gridworld(num_trajs, lr_weights, lr_maps, max_iters, gamma, n_maps_individual, n_maps_interaction,
                          seed, GEN_DIR_NAME, REC_DIR_NAME, lam1, lam2, TAG_TRAIN, info)
    
    if TRAIN_NOW_IND:
        fit_irl_gridworld_ind(num_trajs, lr_weights, lr_maps, max_iters, gamma, n_maps_individual, n_maps_interaction,
                          seed, GEN_DIR_NAME, REC_DIR_NAME, lam1, lam2, TAG_TRAIN, info)

    
    # directory of this hyper-parameter configuration
    rec_dir_name = REC_DIR_NAME + "/maps_{}_{}_lr_{}_{}_lam_{}_{}/".format(n_maps_individual, n_maps_interaction, lr_weights, lr_maps, lam1,lam2)
    if args.SingleAgent:
        # a non-zero SingleAgent selects the per-agent sub-directory
        rec_dir_name = rec_dir_name + 'agent_{}/'.format(args.SingleAgent)
    # keep only the last entry along axis 0 of each saved array
    rec_weights = np.load(rec_dir_name + "weights_trajs_" + str(num_trajs) +"_seed_" + str(seed) + "_iters_" + str(max_iters) +".npy")[-1]
    rec_ind1_maps = np.load(rec_dir_name + "ind1_maps_trajs_" + str(num_trajs) + "_seed_" + str(seed) + "_iters_" + str(max_iters) + ".npy")[-1]
    rec_ind2_maps = np.load(rec_dir_name + "ind2_maps_trajs_" + str(num_trajs) + "_seed_" + str(seed) + "_iters_" + str(max_iters) + ".npy")[-1]
    rec_inter_maps = np.load(rec_dir_name + "inter_maps_trajs_" + str(num_trajs) + "_seed_" + str(seed) + "_iters_" + str(max_iters) + ".npy")[-1]

    # ---------------------------------------------------------
    # Begin Plotting
    # ---------------------------------------------------------
    # figures are written into the same directory the parameters were loaded from
    save_dir = rec_dir_name 

    LEGEND_SIZE = 10
    SMALL_SIZE = 15
    BIGGER_SIZE = 20

    plt.rc('font', size=LEGEND_SIZE)          # controls default text sizes
    plt.rc('axes', titlesize=LEGEND_SIZE)     # fontsize of the axes title
    plt.rc('axes', labelsize=LEGEND_SIZE)    # fontsize of the x and y labels
    plt.rc('xtick', labelsize=LEGEND_SIZE)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=LEGEND_SIZE)    # fontsize of the tick labels
    plt.rc('legend', fontsize=LEGEND_SIZE)    # legend fontsize
    colors = ['steelblue', '#D85427', 'tab:green', 'k']

    MAP_LABELS = ['home', 'water'] # labels of the goal maps
    STATE_LABELS = ['home', 'water'] # labels of states in gridworld to plot rewards for

    # Plot the maps and list the weight
    fig = plt.figure(figsize=(5,5))
    plt.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.93, wspace=0.3, hspace=0.4)

    # RECOVERED interaction map, plotted against the distinct squared grid
    # offsets i**2 + j**2 that can occur on the grid
    plt.subplot(2, 2, 2)
    plt.plot(np.unique([i**2+j**2 for i in range(grid_H) for j in range(grid_W)]), rec_inter_maps[0,:])
    plt.ylabel('Strength')
    plt.xlabel('square grid difference')
    plt.title('rec. diff map: {:.2f}'.format(rec_weights[-1,0]), fontsize=8)

    # RECOVERED agent 1 map
    plt.subplot(2, 2, 3)
    plt.imshow(np.reshape(rec_ind1_maps[0,:],(grid_H, grid_W),order='F'))
    plt.colorbar()
    plt.title('rec. agent 1: {:.2f}'.format(rec_weights[0,0]), fontsize=8)

    # RECOVERED agent 2 map
    plt.subplot(2, 2, 4)
    plt.imshow(np.reshape(rec_ind2_maps[0,:],(grid_H, grid_W),order='F'))
    plt.colorbar()
    plt.title('rec. agent 2: {:.2f}'.format(rec_weights[1,0]), fontsize=8)
    plt.tight_layout()
    fig.savefig(save_dir + 'maps.png')


    # load in loss function and plot
    losses_all_maps = np.load(save_dir + 'losses_maps_trajs_{}_seed_{}_iters_{}.npy'.format(num_trajs, seed, max_iters))
    losses_weights = np.load(save_dir + 'losses_weights_trajs_{}_seed_{}_iters_{}.npy'.format(num_trajs, seed, max_iters))

    fig, axs = plt.subplots(2,1,figsize=(4.4, 4.4))
    axs[0].plot(losses_weights)
    axs[0].set_title('Loss for weight')
    axs[1].plot(losses_all_maps)
    axs[1].set_title('Loss for maps')
    plt.tight_layout()
    # file name kept as-is ('loses.pdf') so existing outputs/tooling keep working
    fig.savefig(save_dir + 'loses.pdf')

    # load in test set ll and plot
    validation_pattern = save_dir+'validation*.npy'
    # glob order depends on the file system; sort so the chosen file is deterministic
    validation_files = sorted(glob.glob(validation_pattern))
    if not validation_files:
        raise FileNotFoundError(
            'no validation log-likelihood file matching {!r}'.format(validation_pattern))
    test_ll = np.load(validation_files[0])
    fig = plt.figure(figsize=(2.2,2.2))
    plt.plot(np.array(test_ll))
    plt.xlabel('Iteration (outer loop)')
    plt.ylabel('Test ll per decision')
    plt.title('Final: {:.4f}'.format(test_ll[-1]))
    plt.tight_layout()
    fig.savefig(save_dir + 'test_ll.pdf')


