import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
from datetime import datetime
import tensorflow as tf
from gym.wrappers import TimeLimit
import matplotlib.pyplot as plt

from causal_irl.envs.noisy_action_wrapper import NoisyActionWrapper
from causal_irl.envs.noisy_observation_wrapper import NoisyObservationWrapper

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import numpy as np
import gym
from stable_baselines.bench import Monitor
from stable_baselines.common.evaluation import evaluate_policy

from causal_irl.algorithms.airl import AIRL
from causal_irl.algorithms.common.utils import encode_trajectories
from causal_irl.algorithms.bc import BC
from causal_irl.algorithms.gail.dataset import ExpertDataset

from causal_irl.algorithms.fusion import BC_GAIL
from causal_irl.envs.my_observation_wrapper import MyObservationWrapper
import argparse

# Results are grouped under <cwd>/results/<month>-<day>/<model_name>/
# (e.g. results/3-7/bc). Read the clock once: two separate datetime.now()
# calls straddling midnight could pair one day's month with the next day's
# day-of-month.
_now = datetime.now()
date = '{}-{}'.format(_now.month, _now.day)
model_name = 'bc'
from pathlib import Path
# Created eagerly at import time; argsparser() uses this same directory as its
# default checkpoint/log/model path.
Path(os.path.join(os.getcwd(), "results", date, model_name)).mkdir(parents=True, exist_ok=True)


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', 'yes' or '0'.

    argparse's ``type=bool`` simply calls ``bool(str)``, which is True for every
    non-empty string, so ``--flag False`` would evaluate to True. This maps the
    usual spellings explicitly and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def argsparser():
    """Build the BC command-line parser and parse ``sys.argv``.

    The checkpoint/log/model path defaults all point at
    ``<cwd>/results/<date>/<model_name>``, using the module-level ``date`` and
    ``model_name`` values.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser("BC")
    default_dir = os.path.join(os.getcwd(), 'results', date, model_name)

    parser.add_argument('--exp_code', default='000',
                        help='prefix added to output file names')
    parser.add_argument('--env_id', help='environment ID', default='Hopper-v2')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--expert_path', type=str,
                        default=os.path.join(os.getcwd(), "data/expert_demonstrations/hopper_trpo_subsampled.npz"))
    parser.add_argument('--checkpoint_dir', help='the directory to save model', default=default_dir)
    parser.add_argument('--log_dir', help='the directory to save log file', default=default_dir)
    parser.add_argument('--model_path', type=str, default=default_dir)

    # For training
    # type=bool would turn any non-empty string (including "False") into True,
    # so parse the value explicitly. nargs='?' + const=True also accepts a bare
    # `--deterministic`, while `--deterministic True/False` keeps working.
    parser.add_argument('--deterministic', type=_str2bool, nargs='?', const=True, default=False,
                        help='passed to BC as deterministic= (accepts True/False, yes/no, 1/0)')
    parser.add_argument('--traj_limitation', type=float, default=-1)
    parser.add_argument('--val_interval', type=int, default=1)
    parser.add_argument('--load_model', dest='load_model', action='store_true')
    parser.add_argument('--no-load_model', dest='load_model', action='store_false')
    parser.set_defaults(load_model=False)
    parser.add_argument('--train_mode', type=str, default='testing')
    parser.add_argument('--save_model', dest='save_model', action='store_true')
    parser.add_argument('--no-save_model', dest='save_model', action='store_false')
    parser.set_defaults(save_model=False)

    # For evaluation
    parser.add_argument('--eval_mode', type=str, default='testing')
    parser.add_argument('--eval_best', dest='eval_best', action='store_true')
    parser.add_argument('--no-eval_best', dest='eval_best', action='store_false')
    parser.set_defaults(eval_best=True)
    parser.add_argument('--save_eval', dest='save_eval', action='store_true')
    parser.add_argument('--no-save_eval', dest='save_eval', action='store_false')
    parser.set_defaults(save_eval=True)

    parser.add_argument('--n_epochs', help='Max iteration for training BC', type=int, default=1000)
    parser.add_argument('--total_eval_episodes', type=int, default=50)
    parser.add_argument('--sigma', type=float, default=1.0,
                        help='noise scale for the noisy / noisy-obs eval modes')
    parser.add_argument('--dropout_rate', type=float, default=0.0)
    parser.add_argument('--action_noise', type=float, default=0.0)
    parser.add_argument('--ob_noise', type=float, default=0.0)
    return parser.parse_args()


# Per-environment episode caps applied on top of gym.make(); other environments
# are returned exactly as gym.make() builds them.
_MAX_EPISODE_STEPS = {'Pendulum-v0': 100, 'Ant-v2': 500}

# Values accepted for --eval_mode (see _wrap_eval_env).
_EVAL_MODES = ('testing', 'confounded', 'original', 'noisy', 'noisy-obs')
_NOISY_EVAL_MODES = ('noisy', 'noisy-obs')


def _make_env(env_id):
    """Create ``env_id`` with gym, applying the episode cap from _MAX_EPISODE_STEPS.

    Used for both the training and the evaluation environment so the two share
    the same episode horizon.
    """
    env = gym.make(env_id)
    max_steps = _MAX_EPISODE_STEPS.get(env_id)
    if max_steps is not None:
        env = TimeLimit(env, max_episode_steps=max_steps)
    return env


def _plot_training_curve(epochs, train_loss, val_loss, out_path):
    """Save a train/validation loss-versus-epoch plot to ``out_path``."""
    fig, ax = plt.subplots()
    ax.plot(epochs, train_loss, label='train_loss', c='blue')
    ax.plot(epochs, val_loss, label='validation_loss', c='red')
    ax.legend()
    ax.set_xlabel("epochs")
    ax.set_ylabel("loss")
    fig.savefig(out_path)
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)


def _eval_output_files(args, model_name, eval_model, model):
    """Return ``(eval_logs_file, eval_rets_file)`` paths under ``args.log_dir``.

    Noisy eval modes add a ``sigma<value>`` component to both names.

    NOTE(review): the seed component is inconsistent in this naming scheme. The
    Monitor log always uses ``model.seed``. The episode-rewards file uses
    ``args.seed`` for non-noisy modes and ``model.seed`` for noisy ones. It is
    preserved so that result-file names do not change.
    """
    fields = [model_name, args.env_id, args.traj_limitation, args.train_mode, args.eval_mode]
    if args.eval_mode in _NOISY_EVAL_MODES:
        pattern = "_{}_{}_traj{}_{}_{}_sigma{}_{}_{}"
        fields.append(args.sigma)
        rets_seed = model.seed
    else:
        pattern = "_{}_{}_traj{}_{}_{}_{}_{}"
        rets_seed = args.seed
    logs_name = args.exp_code + "_eval_logs" + pattern.format(*fields, eval_model.trained_epochs, model.seed)
    rets_name = args.exp_code + "_episode_rewards" + pattern.format(*fields, eval_model.trained_epochs, rets_seed)
    return os.path.join(args.log_dir, logs_name), os.path.join(args.log_dir, rets_name)


def _wrap_eval_env(env_eval, eval_mode, sigma, eval_logs_file):
    """Apply the ``eval_mode`` wrapper to ``env_eval`` and wrap the result in a Monitor.

    ``eval_logs_file`` may be None, in which case the Monitor writes no log file.

    Raises:
        ValueError: if ``eval_mode`` is not one of _EVAL_MODES.
    """
    if eval_mode == 'testing':
        wrapped = env_eval
    elif eval_mode in ('confounded', 'original'):
        wrapped = MyObservationWrapper(env_eval, eval_mode)
    elif eval_mode == 'noisy':
        wrapped = NoisyActionWrapper(env_eval, sigma=sigma)
    elif eval_mode == 'noisy-obs':
        wrapped = NoisyObservationWrapper(env_eval, sigma=sigma)
    else:
        raise ValueError("Unknown eval_mode {!r}; expected one of {}".format(eval_mode, _EVAL_MODES))
    return Monitor(wrapped, filename=eval_logs_file)


def main(args):
    """Train (or load) a BC policy, optionally save it, then evaluate it.

    Steps:
      1. Build the training env and expert dataset (encoded through
         MyObservationWrapper unless ``train_mode == 'testing'``).
      2. Load a model from ``args.model_path`` or train a new one for
         ``args.n_epochs``, saving a loss curve to ``args.log_dir``.
      3. Optionally save the model to ``args.checkpoint_dir``.
      4. Evaluate either the freshly trained model or, with ``eval_best``, the
         best checkpoint reloaded from ``args.model_path``. Evaluation runs on an
         env wrapped according to ``args.eval_mode``. The per-episode rewards and
         lengths are saved when ``args.save_eval`` is set, and their mean/std are
         printed.

    Args:
        args: namespace produced by ``argsparser()``.

    Raises:
        ValueError: if ``args.eval_mode`` is not recognised. This is checked
            before any training so a typo does not waste a full run.
    """
    if args.eval_mode not in _EVAL_MODES:
        raise ValueError("Unknown eval_mode {!r}; expected one of {}".format(args.eval_mode, _EVAL_MODES))

    # Output name encodes the regularisation settings that differ from 0.
    model_name = 'bc' + \
                 ('' if args.dropout_rate == 0.0 else str(args.dropout_rate)) + \
                 ('' if args.action_noise == 0.0 else '-action-noise{}'.format(args.action_noise)) + \
                 ('' if args.ob_noise == 0.0 else '-ob-noise{}'.format(args.ob_noise))
    det = 'deterministic' if args.deterministic else 'stochastic'
    print("Train Mode: {}, Eval Mode: {}".format(args.train_mode, args.eval_mode))
    env = _make_env(args.env_id)

    # Set the seed so that the train/validation split is consistent.
    np.random.seed(args.seed)
    if args.train_mode != 'testing':
        env = MyObservationWrapper(env, args.train_mode)
        encoded_traj = encode_trajectories(env, args.expert_path, args.seed)
        expert_dataset = ExpertDataset(traj_data=encoded_traj, verbose=1, traj_limitation=args.traj_limitation)
    else:
        expert_dataset = ExpertDataset(expert_path=args.expert_path, verbose=1, traj_limitation=args.traj_limitation)

    if args.load_model:
        model = BC.load(args.model_path)
        model.expert_dataset = expert_dataset
    else:
        model = BC('MlpPolicy', env, expert_dataset, deterministic=args.deterministic, verbose=1, val_interval=args.val_interval, seed=args.seed, save_path=args.model_path,
                   dropout_rate=args.dropout_rate, action_noise=args.action_noise, ob_noise=args.ob_noise, exp_code=args.exp_code)

        i_list, train_loss, val_loss = model.learn(args.n_epochs)
        _plot_training_curve(i_list, train_loss, val_loss,
                             os.path.join(args.log_dir, args.exp_code + "_training_curve_{}_{}_{}.png".format(
                                 args.env_id, args.n_epochs, args.seed)))

    if args.save_model:
        model.save(os.path.join(args.checkpoint_dir,
                                "{}_{}_traj{}_{}_{}_{}_{}".format(model_name, args.env_id,
                                                                  args.traj_limitation,
                                                                  args.train_mode,
                                                                  model.trained_epochs,
                                                                  det, args.seed)))

    if not args.load_model and args.eval_best:
        # Reload the best checkpoint written during training.
        # NOTE(review): this path scheme is assumed to match what BC writes to
        # save_path -- confirm against BC's checkpointing code.
        eval_model = BC.load(os.path.join(args.model_path, args.exp_code + "_best_{}_{}_traj{}_{}_{}_{}".format(
            model_name,
            model.env.unwrapped.spec.id,
            model.expert_dataset.num_traj,
            model.det,
            args.train_mode,
            model.seed)))
    else:
        eval_model = model

    eval_logs_file, eval_rets_file = None, None
    if args.save_eval:
        eval_logs_file, eval_rets_file = _eval_output_files(args, model_name, eval_model, model)

    # The eval-mode wrapper is applied regardless of --save_eval; only the
    # Monitor's log file depends on it.
    env_eval = _wrap_eval_env(_make_env(args.env_id), args.eval_mode, args.sigma, eval_logs_file)
    env_eval.seed(args.seed)
    eval_model.set_env(env_eval)
    try:
        traj_rewards, traj_lengths = evaluate_policy(eval_model, env_eval, args.total_eval_episodes, deterministic=True,
                                                     render=False, return_episode_rewards=True)
    finally:
        # Closes the Monitor's log file handle (if any) and the underlying env.
        env_eval.close()

    if eval_rets_file is not None:
        np.save(eval_rets_file,
                {"traj_rewards": traj_rewards,
                 "traj_lengths": traj_lengths})

    print(np.mean(traj_rewards), np.std(traj_rewards))


if __name__ == '__main__':
    # Script entry point: read the command line, then train/save/evaluate.
    args = argsparser()
    main(args=args)
