from __future__ import print_function
from __future__ import division

import torch
import numpy as np
import pdb
import argparse
import random
import os

import gym
import estimators
import utils
import yaml

def str2bool(v):
    """Convert a yes/no style string into a ``bool``.

    Matching is case-insensitive:
      * ``yes``, ``true``, ``t``, ``y``, ``1``  -> ``True``
      * ``no``, ``false``, ``f``, ``n``, ``0``  -> ``False``

    Any other value raises ``argparse.ArgumentTypeError``, which makes this
    usable as an argparse ``type=`` callable.
    """
    lowered = v.lower()
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    raise argparse.ArgumentTypeError('Boolean value expected.')

# Command-line interface. FLAGS (below) is read as a module-level global by
# env_setup(), policies_setup() and main().
parser = argparse.ArgumentParser()
# saving
# NOTE(review): --outfile is parsed but never read in this file; main() builds
# its own output file name. Confirm whether this flag is still needed.
parser.add_argument('--outfile', default = None)

# common setup
parser.add_argument('--env_name', type = str, required = True)
# forwarded as env_type to utils.load_env and utils.load_policies
parser.add_argument('--env_type', type = str, default = 'dmc')
# discount factor; passed to utils.load_env and to the on-policy estimator
parser.add_argument('--gamma', default = 0.999, type = float)
# number of rollouts collected per policy in on_policy_estimate()
parser.add_argument('--oracle_num_trajs', default = 300, type = int)
# index of the policy to evaluate. No default and not required, so it is None
# when omitted; that None flows into utils.load_policies and the output file
# name. NOTE(review): consider required=True.
parser.add_argument('--pie_num', type = int)

# variables
parser.add_argument('--seed', default = 0, type = int)

# Parsed at import time, so importing this module consumes sys.argv.
FLAGS = parser.parse_args()

# Project config, read from the current working directory at import time and
# passed to utils.load_policies. safe_load avoids constructing arbitrary
# Python objects from the YAML.
with open('cfg.yaml', 'r') as file:
    config = yaml.safe_load(file)

def env_setup():
    """Create the environment selected on the command line.

    Delegates to ``utils.load_env`` with ``FLAGS.env_name``, ``FLAGS.gamma``
    and ``FLAGS.env_type``. main() uses the result as the MDP for rollouts.
    """
    return utils.load_env(
        FLAGS.env_name,
        FLAGS.gamma,
        env_type=FLAGS.env_type,
    )

def policies_setup(env, pi_num):
    """Load the policy identified by ``pi_num`` for ``env``.

    Thin wrapper over ``utils.load_policies`` that supplies the YAML
    ``config`` plus the environment name and type from the command-line
    flags.
    """
    return utils.load_policies(
        config,
        env,
        FLAGS.env_name,
        pi_num,
        env_type=FLAGS.env_type,
    )

def on_policy_estimate(num_trajs, gamma, mdp, pi, random_pi=False):
    """Estimate the return of ``pi`` from freshly collected on-policy rollouts.

    Args:
        num_trajs: number of trajectories requested from
            ``utils.collect_data``.
        gamma: discount factor handed to ``estimators.OnPolicy``.
        mdp: environment to collect trajectories in.
        pi: policy to evaluate. main() passes ``None`` together with
            ``random_pi=True``.
        random_pi: forwarded to ``utils.collect_data``. Presumably this makes
            collection act randomly instead of following ``pi``; confirm in
            utils.

    Returns:
        The result of ``OnPolicy.estimate`` on the collected paths. main()
        indexes it as a dict (e.g. 'avg_disc_ret', 'std_disc_ret').
    """
    # collect_data yields a pair; only the trajectories are needed here.
    paths, _unused = utils.collect_data(mdp, pi, num_trajs,
                                        random_pi=random_pi)
    estimator = estimators.OnPolicy(pi, gamma)
    return estimator.estimate(paths)

def main():  # noqa
    """Evaluate policy ``FLAGS.pie_num`` against a random baseline and save stats.

    Seeds every RNG through ``utils.set_seed_everywhere``. It then runs
    ``FLAGS.oracle_num_trajs`` on-policy rollouts for the selected policy and
    for a random policy. Finally it writes a summary dict to
    ``pie_configs/<env_name>_<gamma>_<pie_num>.npy``.

    ``np.save`` stores the dict as a pickled 0-d object array, so read it
    back with ``np.load(path, allow_pickle=True).item()``.
    """
    seed = FLAGS.seed
    utils.set_seed_everywhere(seed)

    # A single output directory is shared by all environments. The env name
    # is already encoded in the file name below, so there is no per-env
    # subdirectory.
    directory = 'pie_configs'
    if not os.path.exists(directory):
        os.makedirs(directory)

    mdp, gamma = env_setup(), FLAGS.gamma

    # Oracle estimate for the policy being evaluated. The stats are unpacked
    # immediately so a missing key fails before the random rollouts run.
    pie = policies_setup(mdp, FLAGS.pie_num)
    pie_ret_stats = on_policy_estimate(FLAGS.oracle_num_trajs, gamma, mdp, pie)
    pie_avg_est_ret = pie_ret_stats['avg_disc_ret']
    pie_std_est_ret = pie_ret_stats['std_disc_ret']
    pie_path_states = pie_ret_stats['pi_path_states']
    pie_path_acts = pie_ret_stats['pi_path_acts']
    pie_path_sa_vals = pie_ret_stats['pi_sa_val']

    # Random-policy baseline. No policy object is given; random_pi=True is
    # presumably what makes collection act randomly (see utils.collect_data).
    rand_ret_stats = on_policy_estimate(
        FLAGS.oracle_num_trajs, gamma, mdp, pi=None, random_pi=True)
    rand_avg_est_ret = rand_ret_stats['avg_disc_ret']
    rand_std_est_ret = rand_ret_stats['std_disc_ret']

    summary = {
        'seed': seed,
        'pie_num': FLAGS.pie_num,
        'gamma': FLAGS.gamma,
        'pie_avg_est_ret': pie_avg_est_ret,
        'pie_std_est_ret': pie_std_est_ret,
        'rand_avg_est_ret': rand_avg_est_ret,
        'rand_std_est_ret': rand_std_est_ret,
        'pie_path_states': pie_path_states,
        'pie_path_acts': pie_path_acts,
        'pie_path_sa_vals': pie_path_sa_vals,
    }

    outfile = '{}_{}_{}.npy'.format(FLAGS.env_name, FLAGS.gamma, FLAGS.pie_num)
    np.save(os.path.join(directory, outfile), summary)
    print('saved dataset {}'.format(summary))

# Script entry point. Note that FLAGS and config are already populated at
# import time, before this guard runs.
if __name__ == '__main__':
    main()
