from __future__ import print_function
from __future__ import division

import torch
import numpy as np
import pdb
import argparse
import random
import os

import gym
import estimators
import utils
import yaml

def str2bool(v):
    """Convert a human-friendly boolean string to ``bool`` (argparse ``type``).

    Args:
        v: a string such as ``'yes'``/``'no'``, ``'true'``/``'false'``,
            ``'t'``/``'f'``, ``'y'``/``'n'``, ``'1'``/``'0'`` (case-insensitive),
            or an actual ``bool``, which is returned unchanged.

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is a string that matches none of
            the accepted spellings.
    """
    # A real bool has no .lower(); pass it through instead of crashing.
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

# Command-line interface. NOTE: the arguments are parsed at import time (see
# parse_args() below), so importing this module also consumes sys.argv.
parser = argparse.ArgumentParser()
# saving
# NOTE(review): --outfile is parsed but not read anywhere in this file; main()
# builds the output name from the environment and dataset names instead.
parser.add_argument('--outfile', default = None)

# common setup
# --env_name selects the environment and the top-level section of the config.
parser.add_argument('--env_name', type = str, required = True)
# --dataset_name is optional here, but main() uses it as a key into the config.
parser.add_argument('--dataset_name', type = str)
# No default: FLAGS.d4rl_dataset is None (falsy) when the flag is omitted.
parser.add_argument('--d4rl_dataset', type = str2bool)
# gamma is forwarded to the environment loader and the estimators
# (presumably the discount factor -- confirm against utils/estimators).
parser.add_argument('--gamma', default = 0.999, type = float)
# Number of trajectories main() requests for each on-policy estimate.
parser.add_argument('--oracle_num_trajs', default = 300, type = int)
# NOTE(review): --image_state is not read anywhere in this file.
parser.add_argument('--image_state', default = 'false', type = str2bool)

# variables
# Seed handed to utils.set_seed_everywhere() in main().
parser.add_argument('--seed', default = 0, type = int)

FLAGS = parser.parse_args()

# Experiment configuration; 'cfg.yaml' is resolved against the current
# working directory, not against this file's location.
with open('cfg.yaml', 'r') as file:
    config = yaml.safe_load(file)

def env_setup():
    """Return the environment built by ``utils.load_env`` from the CLI flags.

    Reads ``FLAGS.env_name`` and ``FLAGS.gamma`` and passes them, in that
    order, to ``utils.load_env``.
    """
    return utils.load_env(FLAGS.env_name, FLAGS.gamma)

def policies_setup(env, pi_num):
    """Load the policy identified by ``pi_num`` for ``env``.

    Args:
        env: environment handed to ``utils.load_policies``.
        pi_num: policy identifier taken from the config, or ``None``.

    Returns:
        ``None`` when ``pi_num`` is ``None`` (the collection helpers in this
        file treat a ``None`` policy as "use a random policy"); otherwise the
        result of ``utils.load_policies``.
    """
    # Identity check: `== None` can be overridden by a custom __eq__.
    if pi_num is None:
        return None
    return utils.load_policies(config, env, FLAGS.env_name, pi_num)

def _data_prep(gamma, mdp, data_collection_info):
    """Collect samples from each behavior policy and format them for estimation.

    Args:
        gamma: forwarded to ``utils.format_data_new``.
        mdp: environment to collect samples from.
        data_collection_info: pair ``(policies, sample_counts)``; the two
            sequences are walked in lockstep. A ``None`` policy means samples
            are collected with ``random_pi=True``.

    Returns:
        The output of ``utils.format_data_new`` applied to a dict holding the
        merged paths (``'ground_data'``) and the first observation of every
        merged path (``'initial_states'``).
    """
    behavior_policies, sample_counts = data_collection_info

    collected = []
    for policy, count in zip(behavior_policies, sample_counts):
        # Policies that were allotted no samples are skipped entirely.
        if not count > 0:
            continue
        if policy is None:
            paths, _ = utils.collect_data_samples(mdp, None, count, random_pi=True)
        else:
            paths, _ = utils.collect_data_samples(mdp, policy, count)
        collected.append(paths)

    merged_paths = utils.merge_datasets(collected)
    first_observations = [path['obs'][0] for path in merged_paths]
    raw = {
        'initial_states': np.array(first_observations),
        'ground_data': merged_paths,
    }
    # format data into relevant inputs needed by loss function
    return utils.format_data_new(raw, gamma)

def on_policy_estimate(num_trajs, gamma, mdp, pi, random_pi = False):
    """Estimate a policy's value from freshly collected on-policy rollouts.

    Args:
        num_trajs: number of trajectories requested from ``utils.collect_data``.
        gamma: passed to ``estimators.OnPolicy``.
        mdp: environment to roll out in.
        pi: policy to evaluate; when it is ``None``, ``random_pi`` is forced
            to ``True``.
        random_pi: forwarded to ``utils.collect_data``.

    Returns:
        Whatever ``estimators.OnPolicy(pi, gamma).estimate`` returns for the
        collected paths (``main`` reads its ``'avg_disc_ret'`` entry).
    """
    # Identity check: `== None` can be overridden by a policy's own __eq__.
    if pi is None:
        random_pi = True

    # on-policy ground case
    paths, _ = utils.collect_data(mdp, pi, num_trajs, random_pi=random_pi)
    return estimators.OnPolicy(pi, gamma).estimate(paths)

def off_policy_estimate(data, gamma):
    """Run ``estimators.OffPolicy`` on ``data`` and return its two estimates.

    Args:
        data: formatted dataset passed to ``OffPolicy.estimate``.
        gamma: passed to the ``estimators.OffPolicy`` constructor.

    Returns:
        The pair produced by ``OffPolicy.estimate``; ``main`` prints it as
        ``(ret, rew)``.
    """
    est_ret, est_rew = estimators.OffPolicy(gamma).estimate(data)
    return est_ret, est_rew

def main():  # noqa
    """Collect a behavior-policy dataset, report its value, and save it.

    Steps, in order:
      1. seed everything with ``FLAGS.seed`` and make sure ``datasets/`` exists;
      2. build the environment and load the behavior policies (d4rl or custom,
         depending on ``FLAGS.d4rl_dataset``), optionally merged into a single
         mixture policy when the dataset config sets ``single_pi``;
      3. print an on-policy return estimate for every behavior policy;
      4. collect ``num_samples`` samples split across the policies according
         to the mixture ratios;
      5. print the off-policy estimate of the collected data and save a summary
         dict under ``datasets/``.

    Raises:
        ValueError: if the number of mixture ratios differs from the number of
            behavior policies, or the ratios do not sum to 1.
    """
    seed = FLAGS.seed
    utils.set_seed_everywhere(seed)

    directory = 'datasets/'
    if not os.path.exists(directory):
        os.makedirs(directory)

    mdp, gamma = env_setup(), FLAGS.gamma

    env_cfg = config[FLAGS.env_name]
    dataset_cfg = env_cfg[FLAGS.dataset_name]

    # whether to merge multiple policies into mixture
    single_pi = dataset_cfg.get('single_pi', False)

    if FLAGS.d4rl_dataset:
        pie_num = env_cfg['d4rl']['expert']
        pib_num = env_cfg['d4rl']['medium']
        # The expert policy is loaded (as before) but not used further below.
        pie = utils.load_d4rl_policy(config, mdp, FLAGS.env_name, pie_num)
        pib = utils.load_d4rl_policy(config, mdp, FLAGS.env_name, pib_num)
        print ('loaded d4rl policies: {}, {}'.format(pie_num, pib_num))
        # NOTE(review): this branch previously never defined `pibs` or
        # `mix_ratios`, so every later use raised NameError. Treating the
        # 'medium' policy as the sole behavior policy is an assumption based
        # on its `pib` name -- confirm.
        pibs = [pib]
        mix_ratios = [1.]
    else:
        pibs = []
        for pb in dataset_cfg['pibs']:
            # The config spells "no policy" (i.e. random) as the string 'None'.
            pb = None if pb == 'None' else pb
            pibs.append(policies_setup(mdp, pb))
        mix_ratios = dataset_cfg['mix']

    if single_pi:
        pi = utils.load_sb_mixture(mdp, pibs, mix_ratios)
        pibs = [pi] # overwrite list of behavior policies with the single one
        mix_ratios = [1.]

    # Explicit raises instead of asserts: asserts are stripped under `python -O`.
    if len(mix_ratios) != len(pibs):
        raise ValueError(
            'expected one mix ratio per behavior policy, got {} ratios for {} policies'.format(
                len(mix_ratios), len(pibs)))
    if not np.abs(sum(mix_ratios) - 1) <= 1e-10:
        raise ValueError('mix ratios must sum to 1, got {}'.format(sum(mix_ratios)))

    for behavior_pi in pibs:
        pib_est_ret_stats = on_policy_estimate(FLAGS.oracle_num_trajs, gamma, mdp, behavior_pi)
        pib_est_ret = pib_est_ret_stats['avg_disc_ret']
        print ('pib true (ret, rew): {} {}'.format(pib_est_ret, pib_est_ret * (1. - gamma)))

    samples_to_collect = dataset_cfg['num_samples']
    # Round away float noise before truncating: e.g. 0.57 * 100 evaluates to
    # 56.99999999999999, which a bare int() would turn into 56.
    sample_fracs = [int(round(ratio * samples_to_collect, 6)) for ratio in mix_ratios]
    print ('fraction: ' + str(sample_fracs))

    data_collection_info = (pibs, sample_fracs)

    dataset_name_part = '{}-{}'.format('d4rl' if FLAGS.d4rl_dataset else 'custom', FLAGS.dataset_name)
    dataset_name = 'name_{}'.format(dataset_name_part)

    ope_data = _data_prep(gamma, mdp, data_collection_info = data_collection_info)

    summary = {
        'dataset_name': '{}_{}'.format(FLAGS.env_name, dataset_name),
        'dataset': ope_data,
        'num_samples': samples_to_collect,
        'seed': seed,
    }

    data_est_ret, data_est_rew = off_policy_estimate(ope_data, gamma)
    print ('collected dataset; data value (ret, rew): {} {}'.format(data_est_ret, data_est_rew))
    outfile = summary['dataset_name']
    # np.save appends '.npy' and stores the dict as a pickled object array, so
    # loading it back needs np.load(..., allow_pickle=True).
    np.save(os.path.join(directory, outfile), summary)
    print ('saved dataset')

# Script entry point: main() runs only when this file is executed directly.
if __name__ == '__main__':
    main()
