from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import pprint
from distutils.util import strtobool

from dotmap import DotMap

from MBExperiment_dtwil import MBExperiment_dtwil
from MPC_dtwil import MPC
from config import create_config
import gym
import d4rl
import env # We run this so that the env is registered

import torch
import numpy as np
import random
import tensorflow as tf

def set_global_seeds(seed):
    """Seed every random number generator the experiment relies on.

    Seeds PyTorch (and all visible CUDA devices when CUDA is available),
    NumPy's global RNG, Python's ``random`` module and TensorFlow's
    global seed.

    Args:
        seed: integer seed applied to every library above.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # TF1 exposes ``tf.set_random_seed``; TF2 removed it in favour of
    # ``tf.random.set_seed``. Prefer the TF1 entry point so behaviour is
    # unchanged on TF1, and fall back instead of raising AttributeError on TF2.
    set_tf_seed = getattr(tf, "set_random_seed", None)
    if set_tf_seed is None:
        set_tf_seed = tf.random.set_seed
    set_tf_seed(seed)


def main(env, ctrl_type, ctrl_args, overrides, logdir, constraint_type, constraint_ub, criteria, obs_clip, arc_beta,
         arc_step, dynamic_alignment_progress, state_normalize, seed=0):
    """Build the configuration, construct the MPC policy and run one experiment.

    Args:
        env: environment name, forwarded to ``create_config``.
        ctrl_type: controller type; only ``'MPC'`` is supported.
        ctrl_args: iterable of ``(key, value)`` pairs (e.g. from the
            ``-ca/--ctrl_arg`` flag), converted into a ``DotMap``.
        overrides: ``(key, value)`` override pairs, forwarded to
            ``create_config``.
        logdir: base log directory, forwarded to ``create_config``.
        constraint_type, constraint_ub, criteria, obs_clip, arc_beta,
        arc_step, dynamic_alignment_progress, state_normalize: forwarded
            unchanged to ``create_config``.
        seed: seed passed to ``set_global_seeds`` (defaults to 0, the value
            that used to be hard-coded).

    Raises:
        ValueError: if ``ctrl_type`` is not ``'MPC'``.
        OSError: (``FileExistsError`` on Python 3) if the experiment's log
            directory already exists.
    """
    # Validate before seeding / building the config so an unsupported
    # controller fails fast; an explicit raise also survives ``python -O``,
    # which strips ``assert`` statements.
    if ctrl_type != 'MPC':
        raise ValueError("Unsupported ctrl_type %r; only 'MPC' is supported." % (ctrl_type,))

    set_global_seeds(seed)

    ctrl_args = DotMap(**{key: val for (key, val) in ctrl_args})
    cfg = create_config(env, ctrl_type, ctrl_args, overrides, logdir, constraint_type, constraint_ub, criteria,
                        obs_clip, arc_beta, arc_step, dynamic_alignment_progress, state_normalize)
    cfg.pprint()

    # NOTE(review): the policy is attached at the doubly nested
    # cfg.exp_cfg.exp_cfg path; DotMap may silently create missing keys on
    # attribute access, so this relies on create_config using that layout —
    # confirm against config.create_config.
    cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg)
    exp = MBExperiment_dtwil(cfg.exp_cfg)

    # makedirs without exist_ok raises if the directory already exists, so an
    # existing run directory is never reused.
    os.makedirs(exp.logdir)
    with open(os.path.join(exp.logdir, "config.txt"), "w") as f:
        f.write(pprint.pformat(cfg.toDict()))

    exp.run_experiment()


if __name__ == "__main__":
    # Command-line entry point: each flag below maps one-to-one onto a
    # parameter of main(); the controller type is fixed to "MPC".
    arg_parser = argparse.ArgumentParser()

    # Environment / controller / override selection.
    arg_parser.add_argument(
        '-env', type=str, required=True,
        help='Environment name: select from [cartpole, reacher, pusher, halfcheetah]')
    arg_parser.add_argument(
        '-ca', '--ctrl_arg', action='append', nargs=2, default=[],
        help='Controller arguments, see https://github.com/kchua/handful-of-trials#controller-arguments')
    arg_parser.add_argument(
        '-o', '--override', action='append', nargs=2, default=[],
        help='Override default parameters, see https://github.com/kchua/handful-of-trials#overrides')
    arg_parser.add_argument(
        '-logdir', type=str, default='log',
        help='Directory to which results will be logged (default: ./log)')

    # Constraint settings.
    arg_parser.add_argument(
        '-constraint_ub', type=float, default=1.0,
        help='constraint upper bound, constraint_val')
    arg_parser.add_argument(
        '-constraint_type', type=str, default='no',
        help='constraint type, constraint options: no, box, power')
    arg_parser.add_argument(
        '-criteria', type=str, default='dtw',
        help='constraint value')
    arg_parser.add_argument(
        '-obs_clip', type=int, default=0,
        help='clip stationary obs')

    # ARC and DTW alignment settings.
    arg_parser.add_argument(
        '-arc_beta', type=float, default=0,
        help='beta in arc')
    arg_parser.add_argument(
        '-arc_step', type=int, default=5,
        help='How many steps after the current action we apply arc')
    # Boolean flags are parsed from strings such as "true"/"false"/"1"/"0".
    arg_parser.add_argument(
        '-dynamic_alignment_progress', type=lambda x: bool(strtobool(x)), default=True,
        help='Use dynamic time warping path to determine the alignment progress')
    arg_parser.add_argument(
        '-state_normalize', type=lambda x: bool(strtobool(x)), default=True,
        help='Normalize the state before DTW calculation')

    cli = arg_parser.parse_args()

    main(
        env=cli.env,
        ctrl_type="MPC",
        ctrl_args=cli.ctrl_arg,
        overrides=cli.override,
        logdir=cli.logdir,
        constraint_type=cli.constraint_type,
        constraint_ub=cli.constraint_ub,
        criteria=cli.criteria,
        obs_clip=cli.obs_clip,
        arc_beta=cli.arc_beta,
        arc_step=cli.arc_step,
        dynamic_alignment_progress=cli.dynamic_alignment_progress,
        state_normalize=cli.state_normalize,
    )
