import os
import sys
from pathlib import Path

import click
import gym
import matplotlib.pyplot as plt
import namegenerator
import numpy as np
import pybullet_envs
import torch
import yaml
from causal_world.evaluation import CausalWorld
from causal_world.task_generators.task import generate_task
from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, CheckpointCallback, EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from torch._C import device
from tqdm import tqdm

from model_based.stable_ppo.custom_ppo import MPPO
# Directory one level above the folder that contains this script; appended to
# the import search path.
# NOTE(review): this runs after the module's import statements, so it cannot
# help resolve those imports (e.g. ``model_based``) — confirm whether it should
# move above them.
LIBPATH = os.fspath(Path(__file__).parents[1])
sys.path.append(LIBPATH)

# Environment factory helpers.

def make_causal_env(rank, task_name, seed=0, **env_kwargs):
    """Return a zero-argument factory that builds a CausalWorld env for one task.

    Mirrors ``make_env``: the returned callable is meant to be handed to a
    ``DummyVecEnv``/``SubprocVecEnv`` so each worker constructs its own env.

    Args:
        rank: Worker index; added to ``seed`` so workers get distinct seeds.
        task_name: Task generator id forwarded to ``generate_task``.
        seed: Base random seed (default 0, matching ``make_env``).
        **env_kwargs: Extra keyword arguments forwarded to ``CausalWorld``.

    Returns:
        The ``_init`` callable that creates and returns the environment.
    """
    def _init():
        task = generate_task(task_generator_id=task_name)
        return CausalWorld(task=task, seed=seed + rank, **env_kwargs)

    set_random_seed(seed)
    return _init


def make_env(env_id, rank, params, log_dir, seed=0):
    """Return a zero-argument factory that builds one seeded, monitored gym env.

    Intended for ``SubprocVecEnv``/``DummyVecEnv``: each worker calls the
    returned ``_init`` to construct its own environment.

    Args:
        env_id: Gym id passed to ``gym.make``.
        rank: Worker index; offsets the seed and names this worker's monitor log.
        params: Parsed config; not read here (kept for call-site compatibility).
        log_dir: Directory for the ``Monitor`` logs, or None to disable logging.
        seed: Base seed; the env and its spaces are seeded with ``seed + rank``.

    Returns:
        The ``_init`` callable.
    """
    # Give every worker its own monitor file. SB3's Monitor maps a directory to
    # ``<dir>/monitor.csv``, so passing ``log_dir`` to all workers made them
    # overwrite the same file; ``<log_dir>/<rank>`` becomes
    # ``<log_dir>/<rank>.monitor.csv``, which ``load_results(log_dir)`` still finds.
    monitor_path = None if log_dir is None else os.path.join(log_dir, str(rank))
    worker_seed = seed + rank

    def _init():
        env_ = gym.make(env_id)
        env_ = Monitor(env_, monitor_path, allow_early_resets=True)
        env_.seed(worker_seed)
        env_.action_space.seed(worker_seed)
        env_.observation_space.seed(worker_seed)
        return env_

    set_random_seed(seed)
    torch.manual_seed(seed)
    return _init


def _plot_observations(all_obs):
    """Return a figure plotting every observation dimension against the step index.

    Each observation is flattened, so ``all_obs`` may hold scalars or arrays of
    any consistent shape.
    """
    obs_matrix = np.asarray(all_obs, dtype=float).reshape(len(all_obs), -1)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(obs_matrix)
    ax.set_xlabel('step')
    ax.set_ylabel('observation value')
    ax.set_title('Observations during evaluation')
    return fig


def evaluate(model, env, log_dir, verbose):
    """Roll out one episode, print per-step results and save an observation plot.

    Args:
        model: A trained Stable-Baselines3 model (anything with
            ``predict(obs, deterministic=True)``), or None to run a
            random-action baseline.
        env: A gym environment using the ``reset()``/``step()`` API that
            returns ``(obs, reward, done, info)``.
        log_dir: Directory under which ``figures/observations.png`` is written.
        verbose: When truthy, print every observation.
    """
    def arr_to_str(values):
        # np.ravel also accepts scalar actions (e.g. from Discrete spaces).
        return ','.join(map(str, np.ravel(values)))

    obs = env.reset()
    if verbose:
        print(' -> Obs', obs)

    done = False
    step = 0
    all_rewards = []
    all_obs = []

    # Space-delimited so the output can be pasted directly into a spreadsheet.
    print('step reward action')
    while not done:
        if model is not None:
            action, _ = model.predict(obs, deterministic=True)
        else:
            # No trained model: use a uniformly random action as the baseline.
            action = env.action_space.sample()
        obs, reward, done, _ = env.step(action)
        if verbose:
            print(' -> Obs', obs)
        label = arr_to_str(action) if model is not None else 'baseline'
        print(step, '%.2f' % reward, label)
        step += 1
        all_rewards.append(reward)
        all_obs.append(obs)

    # The loop body runs at least once, so all_rewards is never empty here.
    print('AVERAGE', '%.2f' % np.mean(all_rewards))

    fig = _plot_observations(all_obs)
    fig.tight_layout()
    figures_dir = os.path.join(log_dir, 'figures')
    Path(figures_dir).mkdir(parents=True, exist_ok=True)
    fig.savefig(os.path.join(figures_dir, 'observations.png'))
    # Release the figure so repeated evaluations do not accumulate open figures.
    plt.close(fig)


# Module-level running statistics (best mean reward seen, steps counted).
# Not referenced anywhere else in this file.
best_mean_reward = -np.inf
n_steps = 0


def get_callback(eval_env, eval_freq, checkpoint_freq, log_dir, model_path, n_eval_episodes):
    """Bundle periodic checkpointing and periodic evaluation into one callback.

    Args:
        eval_env: Environment the ``EvalCallback`` evaluates on.
        eval_freq: Evaluation interval passed to ``EvalCallback``.
        checkpoint_freq: Save interval passed to ``CheckpointCallback``.
        log_dir: Evaluation logs go to ``<log_dir>/eval_callback_result``.
        model_path: Directory for both checkpoints and the best model.
        n_eval_episodes: Episodes per evaluation run.

    Returns:
        A ``CallbackList`` holding the checkpoint callback followed by the
        (deterministic, non-rendering) evaluation callback.
    """
    eval_log_path = os.path.join(log_dir, 'eval_callback_result')
    callbacks = [
        CheckpointCallback(save_freq=checkpoint_freq, save_path=model_path),
        EvalCallback(
            eval_env,
            best_model_save_path=model_path,
            log_path=eval_log_path,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            deterministic=True,
            render=False,
        ),
    ]
    return CallbackList(callbacks)


def _make_train_env(params, log_dir, seed):
    """Build the vectorized training env described by ``params``.

    With ``n_processes == 1`` a single monitored env runs in-process in a
    ``DummyVecEnv``; otherwise one ``make_env`` factory per process feeds a
    ``SubprocVecEnv`` using the ``spawn`` start method.
    """
    n_processes = params['n_processes']
    if n_processes == 1:
        env = gym.make(params['env_id'])
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        env = Monitor(env, log_dir, allow_early_resets=True)
        return DummyVecEnv([lambda: env])
    # NOTE(review): ranks start at n_processes rather than 0, so worker seeds are
    # seed + n_processes + i. Kept as-is to preserve existing seeding.
    factories = [make_env(params['env_id'], i + n_processes, params, log_dir, seed=seed)
                 for i in range(n_processes)]
    return SubprocVecEnv(factories, start_method='spawn')


def _build_model(model_name, train_env, log_dir, params, num_teachers, num_students):
    """Instantiate an untrained model of algorithm ``model_name`` on ``train_env``.

    Raises:
        NotImplementedError: if ``model_name`` is not a supported algorithm.
    """
    common = {'verbose': 1, 'tensorboard_log': log_dir}
    if model_name == 'SAC':
        return SAC('MlpPolicy', train_env, **common)
    if model_name == 'PPO':
        return PPO('MlpPolicy', train_env, batch_size=params['batch_size'], **common)
    if model_name == 'A2C':
        return A2C('MlpPolicy', train_env, **common)
    if model_name == 'TD3':
        return TD3('MlpPolicy', train_env, **common)
    if model_name == 'MPPO':
        return MPPO('MlpPolicy', train_env, device='cpu',
                    num_teachers=num_teachers,
                    num_students=num_students,
                    batch_size=params['batch_size'],
                    **common)
    raise NotImplementedError("model name doesn't exist: %s" % model_name)


def _load_model(model_name, model_path):
    """Load a saved model of algorithm ``model_name`` from ``model_path``.

    Raises:
        NotImplementedError: if ``model_name`` is not a supported algorithm.
    """
    model_classes = {'SAC': SAC, 'PPO': PPO, 'A2C': A2C, 'TD3': TD3, 'MPPO': MPPO}
    try:
        model_cls = model_classes[model_name]
    except KeyError:
        raise NotImplementedError("model name doesn't exist: %s" % model_name) from None
    return model_cls.load(model_path)


@click.command()
@click.option('--gpu_device', '-d', default=0, type=int, help='The gpu number want to use')
@click.option('--model_name', '-m', default='MPPO', type=str, help='The RL algorithm name. Not specifying this parameter during evaluation will cause the baseline to be evaluated.')
@click.option('--train', is_flag=True)
@click.option('--eval', is_flag=True)
@click.option('--log-dir', type=str, help='The location for loading or saving data related to the execution. This is where the trained model and figures are saved.')
@click.option('--config', type=str, default='src/py/min_rl/config/rl/ant.yaml', help='Path to yaml config file')
@click.option('--verbose', default=0, type=int, help='Verbosity during evaluation, 1 for printing out the state.')
@click.option('--num_students', default=1, type=int, help='numer of ensamble students')
@click.option('--seed', default=0, type=int, help='seed number ')
@click.option('--num_teachers', default=3, type=int)
def run(gpu_device, train, eval, model_name, log_dir, config, verbose, num_teachers, num_students, seed):
    """Train and/or evaluate an RL agent on the gym env named in the YAML config.

    With --train, a model is trained with periodic checkpoints and evaluations,
    saved under <log-dir>/model. With --eval, <log-dir>/model/best_model is
    loaded and rolled out for one episode.
    """
    # NOTE: gpu_device is currently unused. MPPO is pinned to the CPU below and
    # the other algorithms use Stable-Baselines3's default device selection.
    with open(config, 'r') as config_file:
        params = yaml.safe_load(config_file)

    eval_env = gym.make(params['env_id'])
    if train:
        if model_name is None:
            raise ValueError('model name must be specified during training.')
        if log_dir is None:
            log_dir = os.path.join('logs/train_logs/', namegenerator.gen())
        Path(log_dir).mkdir(parents=True, exist_ok=True)
        model_path = os.path.join(log_dir, 'model')
        train_env = _make_train_env(params, log_dir, seed)
        try:
            print("####start training#####")
            train_env.reset()
            model = _build_model(model_name, train_env, log_dir, params,
                                 num_teachers, num_students)
            callback = get_callback(eval_env=eval_env, eval_freq=params['eval_freq'],
                                    checkpoint_freq=params['checkpoint_freq'],
                                    log_dir=log_dir,
                                    model_path=model_path,
                                    n_eval_episodes=params['n_eval_episodes'])
            # Pass the callback so checkpoints and model/best_model are written;
            # the evaluation branch below loads that best model.
            model.learn(total_timesteps=params['total_num_steps'], callback=callback)
        finally:
            # Shut down worker processes (SubprocVecEnv) even if training fails.
            train_env.close()

    if eval:
        print("####start evaluating#####")
        if log_dir is None:
            raise ValueError('Please specify a log directory when evaluating.')
        if model_name is None:
            model = None  # Evaluate the baseline
        else:
            model_path = os.path.join(log_dir, 'model/best_model')
            print('Loading trained model from', model_path)
            model = _load_model(model_name, model_path)

        evaluate(model, eval_env, log_dir, verbose)


if __name__ == '__main__':
    # Script entry point: invoke the click command defined above.
    run()
