'''
=== Evaluate pre-trained model ===
Evaluate IMLB+SBHO+MLB model:
python src/py/min_rl/baseline_algos.py --eval -m PPO --log-dir results/IMLB+SBHO+MLB-sec8-day1 --config src/py/min_rl/config/rl/imlb+sbho+mlb.yaml

Evaluate IMLB+SBHO model:
python src/py/min_rl/baseline_algos.py --eval -m PPO --log-dir results/IMLB+SBHO-sec8-day1 --config src/py/min_rl/config/rl/imlb+sbho.yaml

Evaluate IMLB model:
python src/py/min_rl/baseline_algos.py --eval -m PPO --log-dir results/IMLB-sec8-day1 --config src/py/min_rl/config/rl/imlb.yaml

To evaluate the baseline, do not specify a model name:
python src/py/min_rl/baseline_algos.py --eval --log-dir results/IMLB-sec8-day1-baseline --config src/py/min_rl/config/rl/imlb.yaml

=== Train ===
The command below starts a training job and saves the results to logs/train_logs/my_run.
Change the --config argument to select the control parameters you are interested in using.
The command below controls the IMLB, SBHO, and MLB parameters:

python src/py/min_rl/baseline_algos.py --train -m PPO --log-dir logs/train_logs/my_run --config src/py/min_rl/config/rl/imlb+sbho+mlb.yaml

'''
import sys
from pathlib import Path
LIBPATH = str(Path(__file__).parents[1])
sys.path.append(LIBPATH)
import os
import numpy as np
import gym
import torch
import yaml
from stable_baselines3.common.monitor import Monitor
import click
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3 import A2C, PPO, SAC, TD3
from model_based.stable_ppo.custom_ppo import MPPO
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, CheckpointCallback, EvalCallback
from min_rl import env_utils
from saic5g.envs.sls_reward import SLSRewarder
from saic5g.vis.sls_obs import make_plots
from tqdm import tqdm
import namegenerator
import matplotlib.pyplot as plt

# define an environment function
def make_env(env_id, rank, params, log_dir, seed=0):
    """Seed the RNGs and return a zero-argument factory for a monitored env.

    Args:
        env_id: Id passed to ``gym.make``.
        rank: Index of the worker. Currently unused, because per-rank
            environment seeding is disabled.
        params: Config dict; ``params['env_kwargs']`` is forwarded to
            ``gym.make`` as keyword arguments.
        log_dir: Passed to ``Monitor`` as its second argument.
        seed: Seed given to ``set_random_seed`` and ``torch.manual_seed``.

    Returns:
        A callable taking no arguments that builds the environment and wraps
        it in ``Monitor`` with ``allow_early_resets=True``.
    """
    # Seeding is done once here, when the factory is created, not each time
    # the returned callable is invoked.
    set_random_seed(seed)
    torch.manual_seed(seed)

    def _build_monitored_env():
        # Imported inside the factory so the import runs wherever the factory
        # is called (presumably this registers the custom env ids - TODO confirm).
        from min_rl import env_utils
        raw_env = gym.make(env_id, **params['env_kwargs'])
        return Monitor(raw_env, log_dir, allow_early_resets=True)

    return _build_monitored_env

def evaluate(model, env, log_dir, verbose):
    """Roll out ``env`` until it reports done, print metrics and save a plot.

    One line is printed per step with the step reward, five SLS metrics and
    the action taken, followed by an ``AVERAGE`` line over all steps. A figure
    of the collected SLS observations is written to
    ``<log_dir>/figures/observations.png``.

    Args:
        model: Policy exposing ``predict(obs, deterministic=True)``, or
            ``None`` to evaluate the baseline. In the baseline case ``None``
            is passed to ``env.step`` as the action.
        env: Environment to evaluate. Besides ``reset``/``step`` it must
            expose ``all_sls_obs`` and ``actor``.
        log_dir: Directory under which the ``figures`` folder is created.
        verbose: When truthy, every observation is printed.
    """
    s = env.reset()
    if verbose:
        print(' -> Obs', s)

    def arr_to_str(lst):
        # Comma-separated so the action occupies a single space-delimited column.
        return ','.join(map(str, lst))

    reward_presets = SLSRewarder.get_presets()
    rewarders = [
        SLSRewarder(agent_type='sector_0', **reward_presets['worst_cell_per_ue_tput']['kwargs']), # minIPTput
        SLSRewarder(agent_type='sector_0', **reward_presets['std_per_ue_tput']['kwargs']), # stdIPTput
        SLSRewarder(agent_type='sector_0', **reward_presets['total_tput']['kwargs']), # average transmission
        SLSRewarder(agent_type='sector_0', **reward_presets['dead_cell_count']['kwargs']), # dead cell count
        SLSRewarder(agent_type='sector_0', **reward_presets['icu_count']['kwargs']) # ICU count
    ]

    done = False
    count = 0
    all_rewards = []
    all_metrics = []
    all_obs = []

    # You may directly copy the output to a spreadsheet with space as the delimiter.
    print(' reward min_tput std_tput avg_tran dead_cell_count icu_count action_params')
    while not done:
        if model is not None:
            action, _ = model.predict(s, deterministic=True)
        else:
            action = None
        s, r, done, _ = env.step(action)
        # The most recent raw SLS observation is what the metric rewarders score.
        sls_obs = env.all_sls_obs[-1]
        if verbose:
            print(' -> Obs', s)
        metrics = [rewarder.reward(sls_obs, done) for rewarder in rewarders]
        metrics_format = ['%.2f' % rew for rew in metrics]
        if action is not None:
            if hasattr(env.actor, 'action_offset'):
                # Undo action offset if needed. Action offsets are used
                # when training with the multi-discrete action space
                action = env.actor.action_offset()+action
            print(count, '%.2f' % r, ' '.join(metrics_format), arr_to_str(action))
        else:
            print(count, '%.2f' % r, ' '.join(metrics_format), 'baseline')
        count += 1
        all_rewards.append(r)
        all_metrics.append(metrics)
        all_obs.append(sls_obs)

    avg_reward = np.mean(all_rewards)
    # axis=0 averages each metric column over the steps.
    avg_metrics = np.mean(all_metrics, axis=0)
    avg_metrics_format = ['%.2f' % rew for rew in avg_metrics]

    print('AVERAGE', '%.2f' % avg_reward, ' '.join(avg_metrics_format))

    fig = make_plots(all_obs)
    plt.tight_layout()
    figures_dir = os.path.join(log_dir, 'figures')
    Path(figures_dir).mkdir(parents=True, exist_ok=True)
    fig.savefig(os.path.join(figures_dir, 'observations.png'))
    # The figure is not returned, so release it; pyplot otherwise keeps it
    # alive and repeated evaluations accumulate open figures.
    plt.close(fig)

# Module-level globals.
# NOTE(review): neither name is read anywhere else in the visible part of this
# file - confirm they are still needed before relying on or removing them.
best_mean_reward = -np.inf
n_steps = 0

def get_callback(eval_env, eval_freq, checkpoint_freq, log_dir, model_path, n_eval_episodes):
    """Build the training callbacks: periodic checkpointing plus evaluation.

    Args:
        eval_env: Environment handed to ``EvalCallback``.
        eval_freq: Forwarded to ``EvalCallback`` as ``eval_freq``.
        checkpoint_freq: Forwarded to ``CheckpointCallback`` as ``save_freq``.
        log_dir: Evaluation results are logged under
            ``<log_dir>/eval_callback_result``.
        model_path: Used both as the checkpoint ``save_path`` and as the
            ``best_model_save_path`` of the evaluation callback.
        n_eval_episodes: Forwarded to ``EvalCallback`` as ``n_eval_episodes``.

    Returns:
        A ``CallbackList`` holding the checkpoint callback followed by the
        evaluation callback.
    """
    periodic_saver = CheckpointCallback(save_freq=checkpoint_freq, save_path=model_path)
    eval_log_path = os.path.join(log_dir, 'eval_callback_result')
    periodic_eval = EvalCallback(
        eval_env,
        best_model_save_path=model_path,
        log_path=eval_log_path,
        eval_freq=eval_freq,
        n_eval_episodes=n_eval_episodes,
        deterministic=True,
        render=False,
    )
    return CallbackList([periodic_saver, periodic_eval])

def update_sector_day(params, scenario, sector, day):
    """Override scenario, sector and day in ``params['env_kwargs']`` in place.

    Arguments that are ``None`` leave the corresponding entry untouched.

    Args:
        params: Config dict holding an ``'env_kwargs'`` dict; mutated in place.
        scenario: Value stored under ``'scenario'``, or ``None``.
        sector: Value stored under ``'sector'``, or ``None``.
        day: A value ``int()`` can convert (stored as an int), the literal
            string ``'all'`` (stored unchanged), or ``None``.

    Raises:
        ValueError: If ``day`` cannot be converted to an integer and is not
            ``'all'``.
    """
    if scenario is not None:
        params['env_kwargs']['scenario'] = scenario
    if sector is not None:
        params['env_kwargs']['sector'] = sector
    if day is None:
        return

    try:
        day_value = int(day)
    except ValueError:
        # Raise explicitly instead of asserting: assertions are stripped under
        # ``python -O``, which would let an invalid day through silently.
        if day != 'all':
            raise ValueError(
                "day must be an integer or the string 'all', got %r" % (day,)
            ) from None
        day_value = day
    params['env_kwargs']['day'] = day_value

@click.command()
@click.option('--gpu_device', '-d', default=0, type=int, help='The gpu number want to use')
@click.option('--model_name', '-m', type=str, help='The RL algorithm name. Not specifying this parameter during evaluation will cause the baseline to be evaluated.')
@click.option('--train', is_flag=True)
@click.option('--eval', is_flag=True)
@click.option('--log-dir', type=str, help='The location for loading or saving data related to the execution. This is where the trained model and figures are saved.')
@click.option('--config', type=str, required=True, help='Path to yaml config file')
@click.option('--scenario_set', type=str, required=False, help='Can be 10sector or 320sector.')
@click.option('--sector', type=int, required=False, help='Sector id to use for the scenario.')
@click.option('--day', type=str, required=False, help='Day to use for the scenario. Can be an integer or the string "all"')
@click.option('--delay', required=False, type=int, help='Number of steps by which to delay the observation.')
@click.option('--delay_prep_start', required=False, type=int, help='If delay is not 0, this gives the time step from which to start the preparation steps.')
@click.option('--delay_prep_steps', required=False, type=int, help='If delay is not 0, this gives the number of steps to take for preparation.')
@click.option('--predictor_csv_path', type=str, help='The path to the CSV prediction files.')
@click.option('--predictor_obs_error', default=0, type=int, help='The percentage random error to add on the predicted state.')
@click.option('--verbose', default=0, type=int, help='Verbosity during evaluation, 1 for printing out the state.')
def run(gpu_device, train, eval, model_name, log_dir, config, scenario_set, sector,
        day, predictor_csv_path, predictor_obs_error, delay, delay_prep_start, delay_prep_steps, verbose):
    """Train and/or evaluate an RL model on the environment described by --config."""
    # The parameter name ``eval`` shadows the builtin; it is kept because it
    # is the name click derives from the ``--eval`` option.
    #
    # Raises ValueError when --train is given without a model name or --eval
    # without a log directory, and NotImplementedError for an unknown model name.

    # Select the requested CUDA device when one is available.
    if torch.cuda.is_available():
        torch.cuda.set_device(gpu_device)

    # Close the config file deterministically instead of leaking the handle.
    with open(config, 'r') as config_file:
        params = yaml.safe_load(config_file)

    # Values given on the command line override those from the YAML config.
    update_sector_day(params, scenario_set, sector, day)
    env_kwargs = params['env_kwargs']
    if delay is not None:
        env_kwargs['delay'] = delay
    if delay_prep_start is not None:
        env_kwargs['delay_prep_start'] = delay_prep_start
    if delay_prep_steps is not None:
        env_kwargs['delay_prep_steps'] = delay_prep_steps
    if predictor_csv_path is not None:
        env_kwargs['predictor_csv_path'] = predictor_csv_path
        env_kwargs['predictor_obs_error'] = predictor_obs_error
    env_kwargs['verbose'] = verbose

    # Used by the evaluation callback during training and by the final evaluation.
    eval_env = gym.make(params['env_id'], **params['env_kwargs'])
    if train:
        if model_name is None:
            raise ValueError('model name must be specified during training.')
        if log_dir is None:
            log_dir = os.path.join('logs/train_logs/', namegenerator.gen())
        Path(log_dir).mkdir(parents=True, exist_ok=True)
        model_path = os.path.join(log_dir, 'model')
        # set env
        if params['n_processes'] == 1:
            env = gym.make(params['env_id'], **params['env_kwargs'])
            env = Monitor(env, log_dir, allow_early_resets=True)
            train_env = DummyVecEnv([lambda: env])
        else:
            train_env = SubprocVecEnv(
                [make_env(params['env_id'], i + params['n_processes'], params, log_dir) for i in
                 range(params['n_processes'])],
                start_method='spawn')
        print("####start training#####")
        train_env.reset()
        if model_name == "SAC":
            model = SAC('MlpPolicy', train_env, verbose=0, tensorboard_log=log_dir)
        elif model_name == "PPO":
            model = PPO('MlpPolicy', train_env, verbose=1, tensorboard_log=log_dir, batch_size=params['batch_size'], n_steps=params['env_kwargs']['n_steps'])
        elif model_name == 'A2C':
            model = A2C('MlpPolicy', train_env, verbose=0, tensorboard_log=log_dir)
        elif model_name == 'TD3':
            model = TD3('MlpPolicy', train_env, verbose=0, tensorboard_log=log_dir)
        elif model_name == 'MPPO':
            model = MPPO('MlpPolicy', train_env, verbose=0, tensorboard_log=log_dir, batch_size=params['batch_size'], n_steps=params['env_kwargs']['n_steps'])
        else:
            # ``assert <exception instance>`` is always true and never fired;
            # raise so an unknown name fails here with a clear message.
            raise NotImplementedError("model name doesn't exist")
        callback = get_callback(eval_env=eval_env, eval_freq=params['eval_freq'],
                                checkpoint_freq=params['checkpoint_freq'], log_dir=log_dir, model_path=model_path,
                                n_eval_episodes=params['n_eval_episodes'])
        model.learn(total_timesteps=params['total_num_steps'], callback=callback)

    if eval:
        print("####start evaluating#####")
        if log_dir is None:
            raise ValueError('Please specify a log directory when evaluating.')
        if model_name is None:
            model = None # Evaluate the baseline
        else:
            model_path = os.path.join(log_dir,'model/best_model')
            print('Loading trained model from', model_path)
            if model_name == "SAC":
                model = SAC.load(model_path)
            elif model_name in ("PPO", "MPPO"):
                # NOTE(review): PPO checkpoints have always been loaded through
                # MPPO here; kept as is - confirm this is intentional.
                # 'MPPO' is accepted too so models trained with -m MPPO can be
                # evaluated.
                model = MPPO.load(model_path)
            elif model_name == 'A2C':
                model = A2C.load(model_path)
            elif model_name == 'TD3':
                model = TD3.load(model_path)
            else:
                raise NotImplementedError("model name doesn't exist")

        evaluate(model, eval_env, log_dir, verbose)


# Script entry point: invoke the click command when this file is executed directly.
if __name__ == '__main__':
    run()
