# Evaluation script for the RL_GAN model: parses command-line options, builds or
# loads a model, rolls it out in a gym environment and reports the returns.
import tensorflow as tf
from gym.wrappers import TimeLimit

from causal_irl.algorithms.gail import GAIL
from causal_irl.envs.noisy_action_wrapper import NoisyActionWrapper
from causal_irl.envs.noisy_observation_wrapper import NoisyObservationWrapper

# Restrict TensorFlow's (v1-compat) logger to ERROR-level messages.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import os
# NOTE(review): presumably meant to hide every GPU so the script runs on CPU.
# It is assigned after tensorflow has already been imported -- confirm that it
# still takes effect in the TensorFlow version in use.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
from datetime import datetime
import numpy as np
import gym
# NOTE(review): `plt`, `evaluate_rejection`, `BC`, `BC_GAIL`, `name_to_models` and
# `get_expert_weights_and_rewards` are not referenced in the visible part of this
# file -- verify before removing them.
import matplotlib.pyplot as plt
from stable_baselines.bench import Monitor
from causal_irl.algorithms.common.evaluation import evaluate_policy_original as evaluate_policy
from causal_irl.algorithms.common.evaluation import evaluate_policy as evaluate_rejection

from causal_irl.algorithms.airl import AIRL
from causal_irl.algorithms.common.utils import encode_trajectories, name_to_models, get_expert_weights_and_rewards
from causal_irl.algorithms.bc import BC
from causal_irl.algorithms.gail.dataset import ExpertDataset

from causal_irl.algorithms.fusion import BC_GAIL
from causal_irl.algorithms.fusion.rl_gan import RL_GAN
from causal_irl.envs.my_observation_wrapper import MyObservationWrapper
import argparse
from pathlib import Path

# "<month>-<day>" tag (no zero padding) used to group results by calendar day.
# The clock is read exactly once so month and day always come from the same
# instant; two separate now() calls could disagree across a midnight rollover.
date = '{0.month}-{0.day}'.format(datetime.now())
# Sub-directory / file-name component identifying this model family.
model_name = 'rl_gan'
# Create <cwd>/results/<date>/<model_name> up front: it is the default value of
# --checkpoint_dir, --log_dir and --model_path.
Path(os.path.join(os.getcwd(), "results", date, model_name)).mkdir(parents=True, exist_ok=True)

def argsparser():
    """Declare every command-line option of this script and parse the command line.

    Boolean options come as ``--<name>`` / ``--no-<name>`` pairs that store
    ``True`` / ``False`` into the same destination.

    Returns:
        argparse.Namespace: the result of ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(prog="Rejection Sampling")
    option = parser.add_argument
    # Shared default for --checkpoint_dir, --log_dir and --model_path.
    results_dir = os.path.join(os.getcwd(), 'results', date, model_name)

    def toggle(name, default):
        """Register the ``--<name>`` / ``--no-<name>`` pair with its default."""
        option('--' + name, dest=name, action='store_true')
        option('--no-' + name, dest=name, action='store_false')
        parser.set_defaults(**{name: default})

    # Experiment identity, input data and model locations.
    option('--exp_code', default='000')
    option('--env_id', help='environment ID', default='Hopper-v2')
    option('--seed', help='RNG seed', type=int, default=0)
    option('--expert_path', type=str,
           default=os.path.join(os.getcwd(), "data/expert_demonstrations/hopper_trpo_subsampled.npz"))
    option('--generator_path', type=str)
    option('--checkpoint_dir', help='the directory to save model', default=results_dir)
    option('--log_dir', help='the directory to save log file', default=results_dir)
    option('--model_path', type=str, default=results_dir)
    option('--generator', type=str, default='bc')
    option('--dis_path', type=str, default=None)
    toggle('use_airl', False)
    toggle('use_random', False)

    # Training options.
    option('--val_interval', type=int, default=None)
    option('--traj_limitation', type=float, default=-1)
    toggle('load_model', False)
    option('--train_mode', type=str, default='testing')
    toggle('save_model', False)

    # Evaluation options.
    option('--eval_mode', type=str, default='testing')
    toggle('eval_best', True)
    toggle('save_eval', True)
    # The eval_rejection / eval_state / eval_mpc toggles and the --gravity option
    # that used to live here are currently disabled.
    option('--n_epochs', help='Max iteration for training GAN', type=int, default=1000)
    option('--plan_interval', type=int, default=1)
    option('--look_ahead', type=int, default=1000)
    option('--n_envs', type=int, default=10)
    option('--total_eval_episodes', type=int, default=50)
    option('--sigma', type=float, default=1.0)

    parsed = parser.parse_args()
    return parsed


# Evaluation modes whose output file names additionally carry the noise scale.
NOISY_EVAL_MODES = ('noisy', 'noisy-obs')


def make_base_env(env_id):
    """Create ``env_id`` and apply the episode cap this script uses for it.

    Pendulum-v0 is wrapped in a 100-step ``TimeLimit``, Ant-v2 in a 500-step
    one; every other environment is returned as ``gym.make`` built it.
    """
    env = gym.make(env_id)
    if env_id == 'Pendulum-v0':
        return TimeLimit(env, max_episode_steps=100)
    if env_id == 'Ant-v2':
        return TimeLimit(env, max_episode_steps=500)
    return env


def wrap_eval_env(env, eval_mode, sigma):
    """Apply the wrapper selected by ``eval_mode`` to the evaluation environment.

    Args:
        env: the environment to wrap.
        eval_mode: 'testing' (no wrapper), 'confounded' / 'original'
            (``MyObservationWrapper``), 'noisy' (``NoisyActionWrapper``) or
            'noisy-obs' (``NoisyObservationWrapper``).
        sigma: value forwarded to the two noisy wrappers.

    Returns:
        The wrapped environment, or ``env`` itself for 'testing'.

    Raises:
        ValueError: if ``eval_mode`` is not one of the modes listed above.
    """
    if eval_mode == 'testing':
        return env
    if eval_mode in ('confounded', 'original'):
        return MyObservationWrapper(env, eval_mode)
    if eval_mode == 'noisy':
        return NoisyActionWrapper(env, sigma)
    if eval_mode == 'noisy-obs':
        return NoisyObservationWrapper(env, sigma)
    # Explicit exception instead of `assert`, which is stripped under `python -O`.
    raise ValueError("unknown eval_mode: {!r}".format(eval_mode))


def eval_output_name(args, kind, run_name, rejection_code, use_random, trained_epochs):
    """Build the base file name of one evaluation artefact.

    Args:
        args: parsed options; reads ``exp_code``, ``env_id``, ``traj_limitation``,
            ``train_mode``, ``eval_mode``, ``sigma`` and ``seed``.
        kind: artefact tag, e.g. ``"eval_logs"`` or ``"episode_rewards"``.
        run_name: model-family tag placed after ``kind``.
        rejection_code: ``"<n_envs>_<plan_interval>_<look_ahead>"`` tag.
        use_random: ``"random"`` or ``"gail"``.
        trained_epochs: epoch counter of the evaluated model.

    Returns:
        str: the file name, without directory or extension.
    """
    mode_tag = args.eval_mode
    if args.eval_mode in NOISY_EVAL_MODES:
        # Noisy evaluations are additionally keyed by the noise scale.
        mode_tag = "{}_sigma{}".format(args.eval_mode, args.sigma)
    return "{}_{}_{}_{}_{}_{}_traj{}_{}_{}_{}_{}".format(
        args.exp_code, rejection_code, kind, run_name, use_random, args.env_id,
        args.traj_limitation, args.train_mode, mode_tag, trained_epochs, args.seed)


def main(args):
    """Build or load an RL_GAN model, evaluate it and optionally store the results.

    Steps:
      1. Create the environment and the expert dataset; when ``train_mode`` is
         not 'testing' the environment is wrapped in ``MyObservationWrapper``
         and the expert trajectories are passed through ``encode_trajectories``.
      2. Load the model from ``model_path`` (``--load_model``) or construct it
         around a discriminator loaded from ``dis_path``.
      3. Optionally save the model, and optionally swap in the "best" checkpoint
         found under ``model.save_path``.
      4. Roll the model out for ``total_eval_episodes`` episodes in a freshly
         created evaluation environment wrapped according to ``eval_mode``.
      5. With ``--save_eval``, write Monitor logs and the per-episode rewards and
         lengths under ``log_dir``; always print mean and std of the rewards.

    Args:
        args: namespace produced by ``argsparser``.

    Raises:
        RuntimeError: if neither ``--load_model`` nor ``--dis_path`` is given.
        ValueError: if ``eval_mode`` is not a supported mode.
    """
    rejection_code = "{}_{}_{}".format(args.n_envs, args.plan_interval, args.look_ahead)
    use_random = "random" if args.use_random else "gail"
    print("Train Mode: {}, Eval Mode: {}".format(args.train_mode, args.eval_mode))
    env = make_base_env(args.env_id)

    # set the seed so that train-val split is consistent
    np.random.seed(args.seed)
    if args.train_mode != 'testing':
        env = MyObservationWrapper(env, args.train_mode)
        # Encode the expert rollouts with the wrapped env before loading them.
        encoded_traj = encode_trajectories(env, args.expert_path, args.seed)
        expert_dataset = ExpertDataset(traj_data=encoded_traj, verbose=1,
                                       traj_limitation=args.traj_limitation)
    else:
        expert_dataset = ExpertDataset(expert_path=args.expert_path, verbose=1,
                                       traj_limitation=args.traj_limitation)

    if args.load_model:
        model = RL_GAN.load(args.model_path)
        model.expert_dataset = expert_dataset
    elif args.dis_path is None:
        raise RuntimeError("must pass in a discriminiator")
    else:
        discriminator_cls = AIRL if args.use_airl else GAIL
        dis_model = discriminator_cls.load(args.dis_path)
        model = RL_GAN('MlpPolicy', env, discriminator=dis_model, expert_dataset=expert_dataset,
                       generator_name=args.generator, generator_path=args.generator_path,
                       verbose=0, val_interval=args.val_interval, seed=args.seed,
                       save_path=args.model_path, using_airl=args.use_airl,
                       using_random=args.use_random, code=args.exp_code,
                       plan_interval=args.plan_interval, look_ahead=args.look_ahead,
                       n_rollout_envs=args.n_envs)

    if args.save_model:
        model.save(os.path.join(
            args.checkpoint_dir,
            "{}_{}_{}_{}_traj{}_{}_{}_{}".format(args.exp_code, model_name, use_random, args.env_id,
                                                 args.traj_limitation, args.train_mode,
                                                 model.trained_epochs, args.seed)))

    if not args.load_model and args.eval_best:
        best_name = args.exp_code + "_best_rl_gan_{}_{}_traj{}_{}_{}".format(
            use_random, model.env.unwrapped.spec.id, model.expert_dataset.num_traj,
            args.train_mode, model.seed)
        eval_model = model.load(os.path.join(model.save_path, best_name))
    else:
        eval_model = model

    # The eval-mode wrapper is applied whether or not results are saved, so that
    # --no-save_eval evaluates in the same environment as --save_eval.
    env_eval = wrap_eval_env(make_base_env(args.env_id), args.eval_mode, args.sigma)

    eval_rets_file = None
    if args.save_eval:
        eval_logs_file = os.path.join(
            args.log_dir,
            eval_output_name(args, "eval_logs", model_name, rejection_code, use_random,
                             eval_model.trained_epochs))
        eval_rets_file = os.path.join(
            args.log_dir,
            eval_output_name(args, "episode_rewards", model_name, rejection_code, use_random,
                             eval_model.trained_epochs))
        env_eval = Monitor(env_eval, filename=eval_logs_file)

    env_eval.seed(args.seed)
    eval_model.set_env(env_eval)
    traj_rewards, traj_lengths = evaluate_policy(eval_model, env_eval, args.total_eval_episodes,
                                                 deterministic=True, render=False,
                                                 return_episode_rewards=True)
    if eval_rets_file is not None:
        # np.save appends ".npy"; the dict is stored as a pickled object array,
        # so reading it back needs np.load(..., allow_pickle=True).
        np.save(eval_rets_file, {"traj_rewards": traj_rewards, "traj_lengths": traj_lengths})
    print(np.mean(traj_rewards), np.std(traj_rewards))


if __name__ == '__main__':
    # Script entry point: parse the command line and hand the options to main().
    main(argsparser())
