# Copyright 2023 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The main training script to train RLHF using PPO algorithm."""

import argparse

import deepspeed
import torch
import torch.distributed as dist
from transformers import SchedulerType
from transformers.utils import is_torch_bf16_gpu_available, is_torch_tf32_available

from safe_rlhf.algorithms.ppo.trainer import PPOTrainer
from safe_rlhf.configs import get_deepspeed_eval_config, get_deepspeed_train_config
from safe_rlhf.datasets import parse_dataset
from safe_rlhf.logger import set_logger_level
from safe_rlhf.utils import seed_everything, str2bool


def parse_arguments() -> argparse.Namespace:
    """Parse and validate the command-line arguments for PPO training.

    Arguments are registered in the groups ``model``, ``dataset``, ``training``,
    ``generation``, ``evaluation``, ``logging`` and ``deepspeed``; DeepSpeed's own
    configuration arguments are then appended via ``deepspeed.add_config_arguments``.

    After parsing, the precision-related flags are validated and, when ``--tf32``
    was given explicitly and TF32 is available, the global
    ``torch.backends.cuda.matmul.allow_tf32`` switch is set accordingly.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised through ``parser.error`` when ``--local_rank`` is unset,
            when both ``--fp16`` and ``--bf16`` are enabled, or when ``--bf16`` is
            requested but not supported by the GPU. Also raised by ``argparse``
            itself for malformed command lines.
    """
    parser = argparse.ArgumentParser(
        prog='deepspeed --module safe_rlhf.algorithms.ppo',
        description='Train language model using RLHF with PPO algorithm.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Model
    model_parser = parser.add_argument_group('model')
    model_parser.add_argument(
        '--actor_model_name_or_path',
        type=str,
        help='Path to the model checkpoint or its name.',
        required=True,
    )
    model_parser.add_argument(
        '--reward_model_name_or_path',
        type=str,
        help='Path to the model checkpoint or its name.',
        required=True,
    )
    model_parser.add_argument(
        '--reward_critic_model_name_or_path',
        type=str,
        help='Path to the model checkpoint or its name.',
    )
    model_parser.add_argument(
        '--max_length',
        type=int,
        default=512,
        help='The maximum sequence length of the model.',
    )
    model_parser.add_argument(
        '--trust_remote_code',
        type=str2bool,
        default=False,
        help='Whether to trust the remote code.',
    )

    # Dataset
    dataset_parser = parser.add_argument_group('dataset')
    dataset_parser.add_argument(
        '--train_datasets',
        type=parse_dataset,
        nargs='+',
        metavar='DATASET[:PROPORTION[:PATH]]',
        help='Dataset name(s) registered in the raw dataset.',
        required=True,
    )
    dataset_parser.add_argument(
        '--ptx_datasets',
        type=parse_dataset,
        nargs='*',
        metavar='DATASET[:PROPORTION[:PATH]]',
        help='Dataset name(s) registered in the raw dataset.',
    )
    dataset_parser.add_argument(
        '--eval_datasets',
        type=parse_dataset,
        nargs='+',
        metavar='DATASET[:PROPORTION[:PATH]]',
        help='Dataset name(s) registered in the raw dataset.',
    )

    # Training
    training_parser = parser.add_argument_group('training')
    training_parser.add_argument(
        '--kl_coeff',
        type=float,
        default=0.02,
        help='The coefficient for the KL divergence between the reference and actor policy.',
    )
    training_parser.add_argument(
        '--clip_range_ratio',
        type=float,
        default=0.2,
        help='The clipping range for ratio between the old and new policy.',
    )
    training_parser.add_argument(
        '--ptx_coeff',
        type=float,
        default=0.0,
        help='The coefficient for the ptx loss.',
    )
    training_parser.add_argument(
        '--epochs',
        type=int,
        default=1,
        help='Total number of training epochs to perform.',
    )
    training_parser.add_argument(
        '--update_iters',
        type=int,
        default=1,
        help='The number of repeated updates on a generated batch.',
    )
    training_parser.add_argument(
        '--per_device_prompt_batch_size',
        type=int,
        default=16,
        help='Batch size (per device) for the training dataloader.',
    )
    training_parser.add_argument(
        '--per_device_train_batch_size',
        type=int,
        default=16,
        help='Batch size (per device) for the training dataloader.',
    )
    training_parser.add_argument(
        '--per_device_eval_batch_size',
        type=int,
        default=16,
        help='Batch size (per device) for the evaluation dataloader.',
    )
    training_parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help='Number of update steps to accumulate before performing a backward/update pass.',
    )
    training_parser.add_argument(
        '--actor_lr',
        type=float,
        default=1e-5,
        help='Initial learning rate (after the potential warmup period) for the actor model training.',
    )
    training_parser.add_argument(
        '--actor_weight_decay',
        type=float,
        default=0.0,
        help='Weight decay for the actor model training.',
    )
    training_parser.add_argument(
        '--critic_lr',
        type=float,
        default=5e-6,
        help='Initial learning rate (after the potential warmup period) for the critic model training.',
    )
    training_parser.add_argument(
        '--critic_weight_decay',
        type=float,
        default=0.0,
        help='Weight decay for the critic model training.',
    )
    training_parser.add_argument(
        '--lr_scheduler_type',
        type=SchedulerType,
        default='cosine',
        help='The scheduler type to use.',
        choices=[
            'linear',
            'cosine',
            'cosine_with_restarts',
            'polynomial',
            'constant',
            'constant_with_warmup',
        ],
    )
    training_parser.add_argument(
        '--num_warmup_steps',
        type=int,
        default=0,
        help='Number of steps for the warmup in the lr scheduler.',
    )
    training_parser.add_argument(
        '--actor_gradient_checkpointing',
        action='store_true',
        help='Enable gradient checkpointing for Actor model.',
    )
    training_parser.add_argument(
        '--critic_gradient_checkpointing',
        action='store_true',
        help='Enable gradient checkpointing for Critic model.',
    )
    training_parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='A seed for reproducible training.',
    )
    training_parser.add_argument(
        '--fp16',
        type=str2bool,
        default=False,
        help='Whether to use float16 precision.',
    )
    training_parser.add_argument(
        '--bf16',
        type=str2bool,
        default=False,
        help='Whether to use bfloat16 precision.',
    )
    # `None` (the default) means "leave the backend's TF32 setting untouched".
    training_parser.add_argument(
        '--tf32',
        type=str2bool,
        default=None,
        help='Whether to use tf32 mix precision.',
    )

    # Generation Config
    generation_parser = parser.add_argument_group('generation')
    generation_parser.add_argument(
        '--temperature',
        type=float,
        default=1.0,
        help='The value used to modulate the next token probabilities.',
    )
    generation_parser.add_argument(
        '--top_p',
        type=float,
        default=1.0,
        help=(
            'If set to float < 1, only the smallest set of most probable tokens with '
            'probabilities that add up to `top_p` or higher are kept for generation.'
        ),
    )
    generation_parser.add_argument(
        '--num_return_sequences',
        type=int,
        default=1,
        help='The number of independently computed returned sequences for each element in the batch.',
    )
    generation_parser.add_argument(
        '--diversity_penalty',
        type=float,
        default=0.0,
        help='The parameter to penalize new tokens based on their existing frequency in the text so far.',
    )
    generation_parser.add_argument(
        '--repetition_penalty',
        type=float,
        default=1.0,
        help='The parameter for repetition penalty. 1.0 means no penalty.',
    )

    # Evaluation
    evaluation_parser = parser.add_argument_group('evaluation')
    evaluation_parser.add_argument(
        '--eval_strategy',
        type=str,
        default='epoch',
        help='The evaluation strategy to adopt.',
        choices=['epoch', 'steps'],
    )
    evaluation_parser.add_argument(
        '--eval_interval',
        type=int,
        default=1000000,
        help='The interval to evaluate the model.',
    )
    evaluation_parser.add_argument(
        '--need_eval',
        default=False,
        help='Whether to evaluate the model during training.',
        action='store_true',
    )
    evaluation_parser.add_argument(
        '--eval_split_ratio',
        type=float,
        help='The split ratio of the evaluation dataset.',
    )

    # Logging
    logging_parser = parser.add_argument_group('logging')
    logging_parser.add_argument(
        '--output_dir',
        type=str,
        default=None,
        help='Where to store the model.',
    )
    logging_parser.add_argument(
        '--log_type',
        type=str,
        help='The type of logging.',
        default='wandb',
        choices=['wandb', 'tensorboard'],
    )
    logging_parser.add_argument(
        '--log_dir',
        type=str,
        help='The directory to store the logs.',
        default=None,
    )
    logging_parser.add_argument(
        '--log_project',
        type=str,
        help='The project name for logging.',
        default=None,
    )
    logging_parser.add_argument(
        '--log_run_name',
        type=str,
        help='The run name for logging.',
        default=None,
    )
    logging_parser.add_argument(
        '--save_16bit',
        action='store_true',
        help='Whether to save the model in 16-bit precision.',
    )
    logging_parser.add_argument(
        '--save_interval',
        type=int,
        default=1000000,
        help='The interval to save the model.',
    )

    # DeepSpeed
    deepspeed_parser = parser.add_argument_group('deepspeed')
    deepspeed_parser.add_argument(
        '--local_rank',
        type=int,
        default=-1,
        help='Local rank for distributed training on GPUs',
    )
    # Restrict to the ZeRO stages DeepSpeed defines so that a typo is rejected
    # at parse time instead of surfacing later during engine initialization.
    deepspeed_parser.add_argument(
        '--zero_stage',
        type=int,
        default=0,
        choices=[0, 1, 2, 3],
        help='ZeRO optimization stage for models.',
    )
    parser = deepspeed.add_config_arguments(parser)

    args = parser.parse_args()

    # The default of -1 means no launcher supplied `--local_rank`.
    if args.local_rank == -1:
        parser.error('`local_rank` not set, please use DeepSpeed launcher to run this script.')
    # The two half-precision modes are mutually exclusive.
    if args.fp16 and args.bf16:
        parser.error('Cannot use both bf16 and fp16 precision.')
    if args.bf16 and not is_torch_bf16_gpu_available():
        parser.error(
            'bf16 precision is not supported on this GPU. '
            'Please disable `--bf16` flag or use another precision flag (e.g., `--fp16`).',
        )
    # Only override the global TF32 matmul switch when the user asked explicitly
    # and TF32 is reported as available; otherwise keep PyTorch's current setting.
    if args.tf32 is not None and is_torch_tf32_available():
        torch.backends.cuda.matmul.allow_tf32 = args.tf32

    return args


def main() -> None:
    """Set up the distributed environment, then train and save with ``PPOTrainer``.

    The steps are: parse the command line, initialize the distributed backend
    through DeepSpeed, bind this process to its CUDA device, seed the RNGs,
    build the DeepSpeed train/eval configurations and finally run
    ``PPOTrainer.train()`` followed by ``PPOTrainer.save()``.
    """
    args = parse_arguments()

    deepspeed.init_distributed()

    # Per-process runtime details are stored on `args` because the namespace is
    # handed to the trainer below.
    args.global_rank = dist.get_rank()
    args.device = torch.device('cuda', args.local_rank)
    torch.cuda.set_device(args.device)
    seed_everything(args.seed)
    set_logger_level()

    # Synchronize all ranks before building the configurations.
    dist.barrier()

    micro_batch_size = args.per_device_train_batch_size
    # Global batch size = per-device batch size x number of processes x accumulation steps.
    global_batch_size = (
        micro_batch_size * dist.get_world_size() * args.gradient_accumulation_steps
    )
    # ZeRO stage and precision flags are identical for the train and eval configs.
    shared_options = {
        'stage': args.zero_stage,
        'fp16': args.fp16,
        'bf16': args.bf16,
    }

    ds_train_config = get_deepspeed_train_config(
        batch_size=global_batch_size,
        micro_batch_size_per_gpu=micro_batch_size,
        **shared_options,
    )
    ds_eval_config = get_deepspeed_eval_config(**shared_options)

    trainer = PPOTrainer(args, ds_train_config, ds_eval_config)
    trainer.train()
    trainer.save()


# Run the training routine only when this file is executed as a script/module
# (e.g. `deepspeed --module safe_rlhf.algorithms.ppo`), not when it is imported.
if __name__ == '__main__':
    main()
