"""
Script to train the reasoner model.

Usage:
    train_rl.py [options]

Options:
    -h --help                               Show this screen.

    --load-model-from LFM                   Path to the model to be loaded.

    --model-name MN                         Model name.

    --mode (multiple_samples|one_sample)    In what mode to train.

    --processors PS                         Which processors to use. String of comma separated values.
                                            [default: MPNN]

    --freeze-proc                           Whether to freeze processor's weights.

    --seed S                                Random seed to set. [default: 47]
"""

import os
import schema
import  pytorch_lightning as pl
from docopt import docopt
from datetime import datetime
from models.policy_net import PolicyNet
from envs.tsp_env import TSPDistCost
from models.ppo_mask_recurrent import MaskableRecurrentPPO
from functools import partial
from models.gnns import _PROCESSSOR_DICT
import wandb
from wandb.integration.sb3 import WandbCallback
from utils.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import CallbackList

if __name__ == '__main__':
    # Entry point: validate the CLI options declared in the module docstring,
    # build the training / evaluation TSP environments, and train a
    # MaskableRecurrentPPO agent while logging to Weights & Biases.

    # Bound to `args_schema` (not `schema`) so the imported `schema` module is
    # not shadowed by its own Schema instance.
    args_schema = schema.Schema({
        '--help': bool,
        # Either not given, or a path that exists on disk.
        '--load-model-from': schema.Or(None, os.path.exists),
        '--model-name': schema.Or(None, schema.Use(str)),
        # Must be one of the two listed modes; a missing value (None) fails
        # this check as well.
        '--mode': lambda m: m in ['multiple_samples', 'one_sample'],
        # Split the comma-separated string into a list and require every
        # entry to be a key of _PROCESSSOR_DICT.
        '--processors': schema.And(
            schema.Use(lambda x: x.split(',')),
            lambda lst: all(x in _PROCESSSOR_DICT for x in lst)),
        '--freeze-proc': bool,
        '--seed': schema.Use(int)
    })

    args = docopt(__doc__)
    args = args_schema.validate(args)

    # Fall back to a timestamp when no explicit model name was given, so that
    # unnamed runs do not all share one save directory.
    name = (args['--model-name']
            if args['--model-name'] is not None
            else datetime.now().strftime('%b-%d-%Y-%H-%M'))

    run = wandb.init(
        project='conar',
        entity='d-n-d',
        config=args,
        group='RL',
        sync_tensorboard=True)

    # Top-level `pl.seed_everything` is the stable public entry point; the
    # `pl.utilities.seed` path is deprecated in newer Lightning releases.
    pl.seed_everything(args['--seed'])

    tr_env = TSPDistCost('train',
                         mask_invalid_actions=True,
                         optimisation_mode=args['--mode'])
    # `tr_env.state` is handed to the policy below; presumably reset() is what
    # populates it -- confirm against TSPDistCost.
    _ = tr_env.reset()

    # Separate evaluation env. It always uses 'multiple_samples', regardless
    # of the --mode chosen for training.
    eval_env = Monitor(
        TSPDistCost('val',
                    mask_invalid_actions=True,
                    optimisation_mode='multiple_samples')
    )

    # Use the resolved `name` (not the raw, possibly-None CLI value) so that
    # unnamed runs are not all written to './serialised_models/rl/None'.
    SAVE_PATH = f'./serialised_models/rl/{name}'
    LOG_PATH = f'./logs/{run.id}'
    for pth in [SAVE_PATH, LOG_PATH]:
        # exist_ok avoids the check-then-create race of a manual exists() test.
        os.makedirs(pth, exist_ok=True)

    eval_callback = EvalCallback(eval_env, best_model_save_path=SAVE_PATH,
                                 log_path=LOG_PATH, eval_freq=4096,
                                 deterministic=False, render=False)

    # Pre-bind the environment- and CLI-dependent constructor arguments so the
    # PPO class can be given a policy class/callable as its first argument.
    PolicyNetSB = partial(PolicyNet,
                          spec=tr_env.dataset.spec,
                          data=tr_env.state,
                          processors=args['--processors'],
                          load_processor=args['--load-model-from'],
                          freeze_processor=args['--freeze-proc'])

    # PPO hyper-parameters.
    N_STEPS = 2048
    BATCH_SIZE = 8192
    EPOCHS = 32
    MAX_GRAD_NORM = 0.2
    CLIP_RANGE_VF = 0.2
    CLIP_RANGE = 0.2
    LEARNING_RATE = 3e-4

    # Record every hyper-parameter that is passed to the model below, so the
    # logged run configuration fully describes the training setup.
    run.config.update({
        'n_steps': N_STEPS,
        'learning_rate': LEARNING_RATE,
        'batch_size': BATCH_SIZE,
        'n_epochs': EPOCHS,
        'max_grad_norm': MAX_GRAD_NORM,
        'clip_range_vf': CLIP_RANGE_VF,
        'clip_range': CLIP_RANGE
    })

    model = MaskableRecurrentPPO(PolicyNetSB,
                                 tr_env,
                                 verbose=1,
                                 n_steps=N_STEPS,
                                 learning_rate=LEARNING_RATE,
                                 batch_size=BATCH_SIZE,
                                 n_epochs=EPOCHS,
                                 max_grad_norm=MAX_GRAD_NORM,
                                 clip_range=CLIP_RANGE,
                                 clip_range_vf=CLIP_RANGE_VF,
                                 tensorboard_log=LOG_PATH)

    # Run the evaluation callback and the W&B callback together during
    # training; both write into SAVE_PATH.
    callbacks = CallbackList([
        eval_callback,
        WandbCallback(model_save_path=SAVE_PATH, gradient_save_freq=100),
    ])
    model.learn(total_timesteps=350_000,
                callback=callbacks,
                log_interval=1)

    run.finish()
