from __future__ import print_function
from __future__ import division
import numpy as np
import pdb
import argparse
import random

from tabular_metrics import SimilarityMetric
from gridworld import Gridworld
#from policies import GridworldPolicy
from utils_gw import collect_data_discrete, heatmap, GridworldPolicy, triang_heatmap, compute_Q_values, group_clusters
#import estimators

def str2bool(v):
    """Parse a command-line string into a bool (an argparse ``type=`` helper).

    Case-insensitively accepts 'yes', 'true', 't', 'y', '1' as True and
    'no', 'false', 'f', 'n', '0' as False.

    Raises:
        argparse.ArgumentTypeError: for any other spelling.
    """
    lookup = dict.fromkeys(('yes', 'true', 't', 'y', '1'), True)
    lookup.update(dict.fromkeys(('no', 'false', 'f', 'n', '0'), False))
    verdict = lookup.get(v.lower())
    if verdict is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return verdict

def _build_parser():
    """Create the command-line parser for this script.

    All options live in one ordered table so ``--help`` lists them in the
    same order as before: saving, common setup, variables, misc.
    """
    specs = (
        # saving
        ('--outfile', dict(default = None)),

        # common setup
        ('--env_name', dict(type = str, required = True)),
        ('--mdp_num', dict(default = 0, type = int)),
        ('--gamma', dict(default = 0.999, type = float)),
        ('--pi_set', dict(default = 0, type = int)),
        ('--mix_ratio', dict(default = 0.7, type = float)),
        ('--epochs', dict(default = 2000, type = int)),
        ('--phi_epochs', dict(default = 50000, type = int)),
        ('--abs_pie_finetune_epochs', dict(default = 5000, type = int)),
        ('--oracle_batch_size', dict(default = 5, type = int)),
        ('--pib_est', dict(default = 'false', type = str2bool)),
        ('--avg_rew', dict(default = 'false', type = str2bool)),
        ('--image_state', dict(default = 'false', type = str2bool)),

        # variables
        ('--seed', dict(default = 0, type = int)),
        ('--batch_size', dict(default = 5, type = int)),
        ('--traj_len', dict(default = 5, type = int)),
        ('--phi_lr', dict(default = 3e-4, type = float)),
        ('--Q_lr', dict(default = 1e-4, type = float)),
        ('--W_lr', dict(default = 1e-4, type = float)),
        ('--lam_lr', dict(default = 1e-4, type = float)),
        ('--rep_alpha', dict(default = 0.5, type = float)),

        # misc
        ('--plot', dict(default = 'false', type = str2bool)),
        ('--exp_name', dict(default = 'gan', type = str)),
        ('--print_log', dict(default = 'false', type = str2bool)),
    )
    new_parser = argparse.ArgumentParser()
    for flag, options in specs:
        new_parser.add_argument(flag, **options)
    return new_parser


parser = _build_parser()
FLAGS = parser.parse_args()

def env_setup():
    """Construct the environment selected by ``--env_name`` / ``--mdp_num``.

    Returns:
        ``Gridworld(length = 3)`` for ``--env_name Gridworld --mdp_num 0``.

    Raises:
        ValueError: for any other flag combination. Raising here replaces
            a fall-through to ``return env`` with ``env`` never assigned.
    """
    if FLAGS.env_name == 'Gridworld':
        if FLAGS.mdp_num == 0:
            return Gridworld(length = 3)
        raise ValueError('Unsupported mdp_num {} for env_name {}'.format(
            FLAGS.mdp_num, FLAGS.env_name))
    raise ValueError('Unsupported env_name: {}'.format(FLAGS.env_name))

def policies_setup(env):
    """Build the (evaluation, behaviour) policy pair selected by ``--pi_set``.

    The evaluation policy ``pie`` is always a ``'det-L'`` GridworldPolicy
    with ``L_right_prob = 1``. The behaviour policy ``pib`` is of type
    ``'random'`` for ``--pi_set 0`` and ``'CW'`` for ``--pi_set 1``.

    Args:
        env: the environment returned by ``env_setup``.

    Returns:
        Tuple ``(pie, pib)``.

    Raises:
        ValueError: if ``--env_name`` or ``--pi_set`` is unsupported. Without
            this check the function would hit an unbound local at the return.
    """
    if FLAGS.env_name != 'Gridworld':
        raise ValueError('Unsupported env_name: {}'.format(FLAGS.env_name))
    if FLAGS.pi_set == 0:
        pib_type = 'random'
    elif FLAGS.pi_set == 1:
        pib_type = 'CW'
    else:
        raise ValueError('Unsupported pi_set: {}'.format(FLAGS.pi_set))
    # Alternative behaviour policy used during experimentation:
    # GridworldPolicy(env, typ = 'file', f_name = 'gridworld_policies/p2.txt')
    pie = GridworldPolicy(env, typ = 'det-L', L_right_prob = 1)
    pib = GridworldPolicy(env, typ = pib_type)
    return pie, pib

def on_policy_estimate(seed, batch_size, truncated_horizon, gamma, mdp, pie):
    """Estimate the value of ``pie`` from fresh on-policy rollouts.

    Args:
        seed: not used inside this function; kept for interface compatibility.
        batch_size: forwarded to ``collect_data_discrete`` as its batch size.
        truncated_horizon: forwarded to ``collect_data_discrete`` as the horizon.
        gamma: discount passed to ``estimators.OnPolicy``.
        mdp: environment to roll out in.
        pie: policy to evaluate.

    Returns:
        The first value from ``OnPolicy.estimate`` when ``--avg_rew`` is set,
        otherwise the second. Presumably these are the average-reward and
        return estimates respectively; confirm against ``estimators``.
    """
    # Imported lazily so the module still loads when ``estimators`` is not
    # available; only this function needs it.
    import estimators
    g_pie_paths, _ = collect_data_discrete(mdp, pie, batch_size, truncated_horizon)
    g_pie_estimator = estimators.OnPolicy(pie, gamma)
    g_pie_est, g_pie_ret_est = g_pie_estimator.estimate(g_pie_paths)
    return g_pie_est if FLAGS.avg_rew else g_pie_ret_est


def q_heats(mdp, pi, data, gamma, epochs, name, show_cb = False, groupings = None):
    """Fit Q-values for ``pi`` and render them as a triangular heatmap.

    Prints ``pi.pi_matrix``, then computes Q-values with ``compute_Q_values``.
    For actions 1, 0, 3, 2 (north, east, south, west, the order the plotting
    code expects), it reshapes that action's column into a
    ``(mdp.length, mdp.length)`` grid, flipped vertically to match the
    gridworld layout. The stacked grids are written to ``name + '.jpg'``.
    When ``groupings`` is given, it is drawn with the grids as labels;
    otherwise the grids themselves are drawn.

    Returns:
        The Q-value array from ``compute_Q_values``. Column indexing by
        action is assumed to be ``(n_state, n_action)``; TODO confirm.
    """
    print (pi.pi_matrix)
    q_values = compute_Q_values(mdp, pi, data, gamma, epochs = epochs)
    grid_shape = (mdp.length, mdp.length)
    grids = np.array([
        np.flip(np.array(q_values[:, action]).reshape(grid_shape), axis = 0)
        for action in (1, 0, 3, 2)
    ])

    out_path = name + '.jpg'
    if groupings is None:
        triang_heatmap(grids, out_path, show_cb = show_cb)
    else:
        triang_heatmap(groupings, out_path, show_cb = show_cb, labels = grids)
    return q_values

def dist_heats(mdp, metric, name, show_cb = False, refs = None, relative = False):
    """Plot each reference (state, action) pair's learned distance to every pair.

    For every reference pair, the row of ``metric.distances`` for that pair is
    split into one ``(mdp.length, mdp.length)`` grid per action. Actions are
    taken in the order 1, 0, 3, 2 (north, east, south, west, matching the
    plotting code), and each grid is flipped vertically to match the
    gridworld layout. The stack is written to
    ``'{name}_metric_{state}_{action}.pdf'``.

    Args:
        mdp: environment exposing ``n_state``, ``n_action`` and ``length``.
            ``n_state`` must equal ``length * length`` for the reshape.
        metric: object whose 2-D ``distances`` are indexed on both axes by
            the flattened pair ``state * n_action + action``.
        name: output file prefix.
        show_cb: forwarded to ``triang_heatmap`` as ``show_cb``.
        refs: iterable of ``(state, action)`` reference pairs. Defaults to
            ``[(1, 1)]``, the previously hard-coded choice.
        relative: if True, plot ``d[i, j] - 0.5 * (d[i, i] + d[j, j])``
            instead of the raw distances. Entries equal to ``-inf`` are kept
            as is. This mode is opt-in; the previous version always computed
            it with a quadratic Python loop and then discarded the result.
    """
    distances = np.asarray(metric.distances)
    if relative:
        diag = np.diag(distances)
        # inf - inf yields nan; those entries are masked out or carried as before.
        with np.errstate(invalid = 'ignore'):
            shifted = distances - 0.5 * (diag[:, None] + diag[None, :])
        distances = np.where(distances != -np.inf, shifted, distances)

    if refs is None:
        refs = [(1, 1)]

    n_action = mdp.n_action
    n_flat = mdp.n_state * n_action
    grid_shape = (mdp.length, mdp.length)
    for ref_state, ref_action in refs:
        dists_from_ref = distances[ref_state * n_action + ref_action]
        d_vals = []
        for a in [1, 0, 3, 2]:  # north, east, south, west, matching plotting code
            # Strided slice picks entries s * n_action + a for s in range(n_state).
            d_val = np.array(dists_from_ref[a:n_flat:n_action]).reshape(grid_shape)
            # done to get aligned with gridworld layout
            d_vals.append(np.flip(d_val, axis = 0))
        d_vals = np.array(d_vals)

        triang_heatmap(d_vals, '{}_metric_{}_{}.pdf'.format(name, ref_state, ref_action),
                       show_cb = show_cb)
        

def _collect_transitions(mdp, policy, batch_size, truncated_horizon):
    """Roll out ``policy`` and flatten all paths into one transition dict.

    Returns a dict with keys 'states', 'rewards', 'actions', 'next_states' and
    'dones'. Each value concatenates, over all paths, the per-path 'obs',
    'rews', 'acts', 'nobs' and 'dones' sequences from ``collect_data_discrete``.
    """
    paths, _ = collect_data_discrete(mdp, policy, batch_size, truncated_horizon)
    key_map = (
        ('states', 'obs'),
        ('rewards', 'rews'),
        ('actions', 'acts'),
        ('next_states', 'nobs'),
        ('dones', 'dones'),
    )
    data = {}
    for out_key, path_key in key_map:
        merged = []
        for path in paths:
            merged.extend(path[path_key])
        data[out_key] = merged
    return data


def _learn_metric(mdp, algo, pie, pib, gamma, data, epochs = 50000):
    """Build a state-action ``SimilarityMetric`` for ``algo``, train it on ``data`` and return it."""
    metric = SimilarityMetric(mdp = mdp, typ = 'state-action',
        algo = algo, pie = pie, pib = pib, gamma = gamma)
    metric.learn(data, epochs = epochs)
    return metric


def main():  # noqa
    """Learn several state-action similarity metrics and plot their clusterings.

    Steps:
      1. Seed the RNGs from ``--seed``.
      2. Build the environment and the (pie, pib) pair.
      3. Collect 100 paths of horizon 500 with a ``'random'``-type GridworldPolicy.
      4. Learn the 'pie-SA-MICO', 'pib-SA-MICO', 'rand-SA' and 'pie-SA-PSM'
         metrics, cluster each with ``group_clusters``, and overlay the MICO
         groupings on Q-value heatmaps.
    """
    seed = FLAGS.seed
    # Seed before any data is collected so --seed actually controls the run.
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:  # torch is optional here; seed it only when present
        torch = None
    if torch is not None:
        torch.manual_seed(seed)

    mdp, gamma = env_setup(), FLAGS.gamma
    pie, pib = policies_setup(mdp)

    data = _collect_transitions(mdp, GridworldPolicy(mdp, typ = 'random'), 100, 500)

    # Q-values are not precomputed for the grouping plots.
    q_values = None

    pie_sa = _learn_metric(mdp, 'pie-SA-MICO', pie, pib, gamma, data)
    rope_groupings = group_clusters(mdp, pie_sa, name = 'rope_group', q_values = q_values, show_cb = False)
    q_heats(mdp, pie, data, gamma, epochs = 2000, name = 'pie_q_values', show_cb = False, groupings = rope_groupings)

    pib_sa = _learn_metric(mdp, 'pib-SA-MICO', pie, pib, gamma, data)
    mico_groupings = group_clusters(mdp, pib_sa, name = 'mico_group', q_values = q_values, show_cb = False)
    q_heats(mdp, pib, data, gamma, epochs = 2000, name = 'pib_q_values', groupings = mico_groupings)

    rand_sa = _learn_metric(mdp, 'rand-SA', pie, pib, gamma, data)
    group_clusters(mdp, rand_sa, name = 'rand_group', q_values = q_values, show_cb = False)

    pie_psm_sa = _learn_metric(mdp, 'pie-SA-PSM', pie, pib, gamma, data)
    group_clusters(mdp, pie_psm_sa, name = 'psm_group', q_values = q_values, show_cb = False)

    # NOTE(review): a phi-learning / OPE summary stage used to follow a
    # pdb.set_trace() here. It depended on names never defined in this module
    # (run_experiment_learn_phi, oracle_est, all_results, algos, utils), so it
    # could not run. It was removed together with the breakpoint.

if __name__ == '__main__':
    main()
