import os
from datetime import datetime
import tensorflow as tf
from causal_irl.envs.noisy_action_wrapper import NoisyActionWrapper
from causal_irl.envs.noisy_observation_wrapper import NoisyObservationWrapper

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import gym
from gym.wrappers import TimeLimit
from stable_baselines import logger
import numpy as np
from stable_baselines.bench import Monitor
from stable_baselines.common import make_vec_env, tf_util
from stable_baselines.common.cmd_util import make_mujoco_env
from stable_baselines.common.evaluation import evaluate_policy
from stable_baselines.common.vec_env import SubprocVecEnv
from mpi4py import MPI
from causal_irl.algorithms.common.utils import encode_trajectories
from causal_irl.algorithms.gail import GAIL
from causal_irl.algorithms.gail.dataset import ExpertDataset
import tensorflow as tf
import os
# Hide every CUDA device so TensorFlow runs on the CPU only.
# NOTE(review): tensorflow is already imported above; this relies on TF not
# having initialised a GPU context yet when the variable is set.
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})
from causal_irl.envs.my_observation_wrapper import MyObservationWrapper
import argparse

# Results are grouped under <cwd>/results/<month>-<day>/<model_name>.
# Read the clock once so month and day come from the same instant; two separate
# now() calls could straddle midnight at a month boundary and yield a bogus date.
_now = datetime.now()
date = '{}-{}'.format(_now.month, _now.day)
# Only used for the argparse default directories; run() builds its own, more
# specific model name for result files.
model_name = 'gail'
from pathlib import Path
Path(os.path.join(os.getcwd(), "results", date, model_name)).mkdir(parents=True, exist_ok=True)

def argsparser(argv=None):
    """Define and parse the command-line options of this GAIL experiment script.

    :param argv: (list of str or None) arguments to parse. ``None`` (the default)
        makes argparse read ``sys.argv[1:]``, i.e. a normal script invocation.
    :return: (argparse.Namespace) the parsed options.
    """
    # Pass the title as ``description``: the first positional parameter of
    # ArgumentParser is ``prog``, which would replace the script name in the usage line.
    parser = argparse.ArgumentParser(description="Causal Confusion Setting of GAIL")
    # Output directories default to <cwd>/results/<date>/<model_name>, built from the
    # module-level ``date`` and ``model_name`` values.
    default_out_dir = os.path.join(os.getcwd(), 'results', date, model_name)

    parser.add_argument('--exp_code', default='000')
    parser.add_argument('--env_id', help='environment ID', default='Hopper-v2')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--expert_path', type=str,
                        default=os.path.join(os.getcwd(), "data/expert_demonstrations/hopper_trpo_subsampled.npz"))
    # NOTE(review): --policy_path is not read by run() in this file.
    parser.add_argument('--policy_path', type=str)
    parser.add_argument('--checkpoint_dir', help='the directory to save model', default=default_out_dir)
    parser.add_argument('--log_dir', help='the directory to save log file', default=default_out_dir)
    parser.add_argument('--model_path', type=str, default=default_out_dir)

    # For training
    # NOTE(review): parsed as float, so e.g. ``--traj_limitation 5`` becomes 5.0 (and shows
    # up as "traj5.0" in result file names); confirm ExpertDataset accepts a float limit.
    parser.add_argument('--traj_limitation', type=float, default=-1)
    parser.add_argument('--val_interval', type=int, default=1024 * 10)
    parser.add_argument('--load_model', dest='load_model', action='store_true')
    parser.add_argument('--no-load_model', dest='load_model', action='store_false')
    parser.set_defaults(load_model=False)
    parser.add_argument('--train_mode', type=str, default='testing',
                        help="'testing' trains on the raw env; any other value is passed to MyObservationWrapper")
    parser.add_argument('--save_model', dest='save_model', action='store_true')
    parser.add_argument('--no-save_model', dest='save_model', action='store_false')
    parser.set_defaults(save_model=False)
    parser.add_argument('--tensorboard_log', type=str, default='')
    parser.add_argument('--entcoeff', type=float, default=0.0)
    parser.add_argument('--adversary_entcoeff', type=float, default=1e-3)
    parser.add_argument('--vf_stepsize', type=float, default=3e-4)
    parser.add_argument('--vf_iters', type=int, default=3)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--lam', type=float, default=0.98)
    parser.add_argument('--cg_damping', type=float, default=1e-2)

    # For evaluation
    parser.add_argument('--eval_mode', type=str, default='testing',
                        help="one of 'testing', 'confounded', 'original', 'noisy', 'noisy-obs'")
    parser.add_argument('--eval_best', dest='eval_best', action='store_true')
    parser.add_argument('--no-eval_best', dest='eval_best', action='store_false')
    parser.set_defaults(eval_best=True)
    parser.add_argument('--save_eval', dest='save_eval', action='store_true')
    parser.add_argument('--no-save_eval', dest='save_eval', action='store_false')
    parser.set_defaults(save_eval=True)

    # ``default`` must already be an int: argparse only applies ``type`` to string
    # defaults, so the former ``default=5e6`` silently produced the float 5000000.0.
    parser.add_argument('--total_timesteps', help='number of timesteps to train GAIL for',
                        type=int, default=int(5e6))
    parser.add_argument('--total_eval_episodes', type=int, default=50)
    # parser.add_argument('--gravity', type=float, default=-9.81)
    parser.add_argument('--sigma', type=float, default=1.0,
                        help="noise scale for the 'noisy' and 'noisy-obs' eval modes")
    parser.add_argument('--dropout_rate', type=float, default=0.0)
    parser.add_argument('--action_noise', type=float, default=0.0)
    parser.add_argument('--ob_noise', type=float, default=0.0)
    return parser.parse_args(argv)

def _make_env(env_id):
    """Create the gym environment ``env_id`` with this script's episode-length caps.

    ``Pendulum-v0`` is additionally wrapped in a 100-step ``TimeLimit`` and
    ``Ant-v2`` in a 500-step one; every other id is returned as built by
    ``gym.make``. Shared by training and evaluation so both use the same cap.
    """
    env = gym.make(env_id)
    if env_id == 'Pendulum-v0':
        return TimeLimit(env, max_episode_steps=100)
    if env_id == 'Ant-v2':
        return TimeLimit(env, max_episode_steps=500)
    return env


def _wrap_eval_env(env, eval_mode, sigma, monitor_file=None):
    """Apply the evaluation-time wrapper selected by ``eval_mode``.

    :param env: (gym.Env) the freshly created evaluation environment.
    :param eval_mode: (str) ``'testing'`` (no wrapper), ``'confounded'`` /
        ``'original'`` (``MyObservationWrapper``), ``'noisy'``
        (``NoisyActionWrapper``) or ``'noisy-obs'`` (``NoisyObservationWrapper``).
    :param sigma: (float) noise scale forwarded to the two noisy wrappers.
    :param monitor_file: (str or None) when given, the result is additionally
        wrapped in a ``Monitor`` logging to this file; ``None`` skips the Monitor.
    :return: (gym.Env) the wrapped environment.
    :raises ValueError: for an unknown ``eval_mode``.
    """
    if eval_mode == 'testing':
        wrapped = env
    elif eval_mode in ('confounded', 'original'):
        wrapped = MyObservationWrapper(env, eval_mode)
    elif eval_mode == 'noisy':
        wrapped = NoisyActionWrapper(env, sigma=sigma)
    elif eval_mode == 'noisy-obs':
        wrapped = NoisyObservationWrapper(env, sigma=sigma)
    else:
        raise ValueError("Unknown eval_mode: {!r}".format(eval_mode))
    if monitor_file is not None:
        wrapped = Monitor(wrapped, filename=monitor_file)
    return wrapped


def run(args):
    """Train (or load) a GAIL model, evaluate it and report episode returns.

    Steps, in order:
      1. Build the training env and the expert dataset. For any ``train_mode``
         other than ``'testing'`` the env is wrapped in ``MyObservationWrapper``
         and the expert trajectories are re-encoded through it.
      2. Either load a model from ``args.model_path`` (``--load_model``), or build
         a new GAIL model and train it for ``args.total_timesteps``.
      3. Optionally save the model (``--save_model``).
      4. Choose the model to evaluate. When a fresh model was trained and
         ``--eval_best`` is set, this is the ``<exp_code>_best_...`` checkpoint
         under ``model.save_path``. Otherwise it is the in-memory model.
      5. Build the evaluation env and wrap it according to ``eval_mode``. When
         ``--save_eval`` is set, a ``Monitor`` log is added.
      6. Run ``args.total_eval_episodes`` deterministic episodes. When
         ``--save_eval`` is set, store the per-episode rewards and lengths with
         ``np.save``. Finally print their mean and std.

    :param args: (argparse.Namespace) options as produced by ``argsparser``.
    """
    model_name = 'gail' + \
                 ('' if args.dropout_rate == 0.0 else str(args.dropout_rate)) + \
                 ('' if args.action_noise == 0.0 else '-action-noise{}'.format(args.action_noise)) + \
                 ('' if args.ob_noise == 0.0 else '-ob-noise{}'.format(args.ob_noise))
    print("Train Mode: {}, Eval Mode: {}".format(args.train_mode, args.eval_mode))
    env = _make_env(args.env_id)

    if args.train_mode != 'testing':
        env = MyObservationWrapper(env, args.train_mode)
        # Re-encode the expert rollouts through the same wrapper before building the
        # dataset (presumably so demonstrations match the agent's observations).
        encoded_traj = encode_trajectories(env, args.expert_path, args.seed)
        dataset = ExpertDataset(traj_data=encoded_traj, verbose=1, traj_limitation=args.traj_limitation)
    else:
        dataset = ExpertDataset(expert_path=args.expert_path, verbose=1, traj_limitation=args.traj_limitation)

    if args.load_model:
        model = GAIL.load(args.model_path, tensorboard_log=args.tensorboard_log)
        # Optional hyper-parameter overrides for a loaded model (uncomment to use):
        # model.cg_damping = args.cg_damping
        # model.vf_stepsize = args.vf_stepsize
        # model.vf_iters = args.vf_iters
        # model.gamma = args.gamma
        # model.lam = args.lam
        model.expert_dataset = dataset
        model.set_env(env)
    else:
        tensorboard_dir = os.path.join(
            args.log_dir,
            args.exp_code + "_tensorboard_logs_{}_{}_traj{}_{}_{}_{}_".format(model_name,
                                                                             args.env_id,
                                                                             args.traj_limitation,
                                                                             args.train_mode,
                                                                             args.total_timesteps,
                                                                             args.seed))
        model = GAIL('MlpPolicy', env, dataset, verbose=1, val_interval=args.val_interval, seed=args.seed,
                     save_path=args.model_path, exp_code=args.exp_code,
                     entcoeff=args.entcoeff,
                     adversary_entcoeff=args.adversary_entcoeff,
                     vf_stepsize=args.vf_stepsize,
                     vf_iters=args.vf_iters,
                     gamma=args.gamma,
                     lam=args.lam,
                     cg_damping=args.cg_damping,
                     dropout_rate=args.dropout_rate,
                     action_noise=args.action_noise,
                     ob_noise=args.ob_noise,
                     tensorboard_log=tensorboard_dir)
        # NOTE(review): learn() only runs for a freshly built model, so
        # ``not args.load_model`` is always True here and a loaded model is never
        # trained further; confirm this is intended.
        model.learn(args.total_timesteps, reset_num_timesteps=not args.load_model)

    if args.save_model:
        # NOTE(review): unlike the other artefacts there is no '_' between exp_code
        # and the model name here; kept as-is so existing checkpoint names stay stable.
        model.save(os.path.join(args.checkpoint_dir,
                                args.exp_code + "{}_{}_traj{}_{}_{}_{}_".format(model_name, args.env_id,
                                                                              args.traj_limitation,
                                                                              args.train_mode,
                                                                              model.num_timesteps,
                                                                              model.seed)))

    if not args.load_model and args.eval_best:
        best_path = os.path.join(model.save_path,
                                 args.exp_code + "_best_{}_{}_traj{}_{}_{}".format(model_name,
                                                                                  model.env.unwrapped.spec.id,
                                                                                  model.expert_dataset.num_traj,
                                                                                  args.train_mode,
                                                                                  model.seed))
        eval_model = GAIL.load(best_path)
    else:
        eval_model = model

    # Result file names are only needed (and only built) when evaluation output is saved.
    eval_logs_file = None
    eval_rets_file = None
    if args.save_eval:
        if args.eval_mode in ('testing', 'confounded', 'original'):
            eval_logs_file = os.path.join(
                args.log_dir,
                args.exp_code + "_eval_logs_{}_{}_traj{}_{}_{}_{}_{}".format(model_name, args.env_id,
                                                                            args.traj_limitation,
                                                                            args.train_mode,
                                                                            args.eval_mode,
                                                                            eval_model.num_timesteps,
                                                                            model.seed))
            eval_rets_file = os.path.join(
                args.log_dir,
                args.exp_code + "_episode_rewards_{}_{}_traj{}_{}_{}_{}_{}".format(model_name,
                                                                                  args.env_id,
                                                                                  args.traj_limitation,
                                                                                  args.train_mode,
                                                                                  args.eval_mode,
                                                                                  eval_model.num_timesteps,
                                                                                  args.seed))
        else:
            # Noisy modes also record sigma in the file names.
            eval_logs_file = os.path.join(
                args.log_dir,
                args.exp_code + "_eval_logs_{}_{}_traj{}_{}_{}_sigma{}_{}_{}".format(model_name, args.env_id,
                                                                                    args.traj_limitation,
                                                                                    args.train_mode,
                                                                                    args.eval_mode,
                                                                                    args.sigma,
                                                                                    eval_model.num_timesteps,
                                                                                    model.seed))
            eval_rets_file = os.path.join(
                args.log_dir,
                args.exp_code + "_episode_rewards_{}_{}_traj{}_{}_{}_sigma{}_{}_{}".format(model_name, args.env_id,
                                                                                          args.traj_limitation,
                                                                                          args.train_mode,
                                                                                          args.eval_mode,
                                                                                          args.sigma,
                                                                                          eval_model.num_timesteps,
                                                                                          model.seed))

    # The eval_mode wrapper is applied whether or not results are saved; only the
    # Monitor log depends on --save_eval (eval_logs_file is None otherwise).
    env_eval = _wrap_eval_env(_make_env(args.env_id), args.eval_mode, args.sigma,
                              monitor_file=eval_logs_file)

    env_eval.seed(args.seed)
    # env_eval.unwrapped.sim.model.opt.gravity[2] = args.gravity
    eval_model.set_env(env_eval)
    traj_rewards, traj_lengths = evaluate_policy(eval_model, env_eval, args.total_eval_episodes, deterministic=True,
                                                 render=False, return_episode_rewards=True)
    # np.save cannot write to a None path, so only save when a target file was built.
    if eval_rets_file is not None:
        np.save(eval_rets_file,
                {"traj_rewards": traj_rewards,
                 "traj_lengths": traj_lengths})

    print(np.mean(traj_rewards), np.std(traj_rewards))

def main():
    """Script entry point: parse the command line and launch one run."""
    run(argsparser())


if __name__ == '__main__':
    main()
