# generate.py
# Copyright 2023-2024 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Generate responses from a model given a prompt dataset.

Usage:
    deepspeed --module safe_rlhf.evaluate.generate \
        --model_name_or_path <model_path> \
        --datasets <dataset_name> \
        --output_file <path_to_save_responses>
"""

from __future__ import annotations

import argparse
import json
import os
from typing import Any

import deepspeed
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
from transformers.integrations.deepspeed import HfDeepSpeedConfig
from transformers.utils import is_torch_bf16_gpu_available, is_torch_tf32_available

from safe_rlhf.configs import get_deepspeed_eval_config
from safe_rlhf.datasets import PromptOnlyDataset, parse_dataset
from safe_rlhf.logger import set_logger_level
from safe_rlhf.models import load_pretrained_models
from safe_rlhf.utils import is_main_process, seed_everything, str2bool, to_device


def parse_arguments() -> argparse.Namespace:
    """Build the command-line parser, parse ``sys.argv`` and validate the flags.

    Options are registered in titled groups (model, dataset, generation, logging,
    deepspeed), followed by DeepSpeed's own config arguments.

    Returns:
        The parsed command-line namespace.

    Side effects:
        Exits through ``parser.error`` when ``--local_rank`` is missing, when both
        ``--fp16`` and ``--bf16`` are requested, or when ``--bf16`` is requested but
        ``is_torch_bf16_gpu_available()`` reports no support. When ``--tf32`` is given
        and TF32 is available, ``torch.backends.cuda.matmul.allow_tf32`` is set to it.
    """
    parser = argparse.ArgumentParser(
        prog='deepspeed --module safe_rlhf.evaluate.generate',
        description='Generate responses from a model.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (group title, [(flag, keyword arguments for ``add_argument``), ...]), kept in the
    # order the options should appear in ``--help``.
    argument_groups: list[tuple[str, list[tuple[str, dict[str, Any]]]]] = [
        (
            'model',
            [
                (
                    '--model_name_or_path',
                    {
                        'type': str,
                        'help': 'The name or path of the model to load from.',
                        'required': True,
                    },
                ),
                (
                    '--max_length',
                    {
                        'type': int,
                        'default': 512,
                        'help': 'The maximum sequence length of the model.',
                    },
                ),
                (
                    '--trust_remote_code',
                    {
                        'type': str2bool,
                        'default': False,
                        'help': 'Whether to trust the remote code.',
                    },
                ),
            ],
        ),
        (
            'dataset',
            [
                (
                    '--datasets',
                    {
                        'type': parse_dataset,
                        'nargs': '+',
                        'metavar': 'DATASET[:PROPORTION[:PATH]]',
                        'help': 'Dataset name(s) registered in the raw dataset.',
                        'required': True,
                    },
                ),
            ],
        ),
        (
            'generation',
            [
                (
                    '--per_device_batch_size',
                    {
                        'type': int,
                        'default': 16,
                        'help': 'Batch size (per device) for the dataloader.',
                    },
                ),
                (
                    '--seed',
                    {
                        'type': int,
                        'default': 42,
                        'help': 'A seed for reproducible generation.',
                    },
                ),
                (
                    '--fp16',
                    {
                        'type': str2bool,
                        'default': False,
                        'help': 'Whether to use float16 precision.',
                    },
                ),
                (
                    '--bf16',
                    {
                        'type': str2bool,
                        'default': False,
                        'help': 'Whether to use bfloat16 precision.',
                    },
                ),
                (
                    '--tf32',
                    {
                        'type': str2bool,
                        'default': None,
                        'help': 'Whether to use tf32 mix precision.',
                    },
                ),
            ],
        ),
        (
            'logging',
            [
                (
                    '--output_file',
                    {
                        'type': str,
                        'default': None,
                        'required': True,
                        'help': 'Where to store the generated responses, in JSONL format.',
                    },
                ),
            ],
        ),
        (
            'deepspeed',
            [
                (
                    '--local_rank',
                    {
                        'type': int,
                        'default': -1,
                        'help': 'Local rank for distributed training on GPUs',
                    },
                ),
                (
                    '--zero_stage',
                    {
                        'type': int,
                        'default': 0,
                        'choices': [0, 1, 2, 3],
                        'help': 'ZeRO optimization stage for models.',
                    },
                ),
                (
                    '--offload',
                    {
                        'type': str,
                        'default': 'none',
                        'choices': ['none', 'parameter', 'optimizer', 'all'],
                        'help': 'Offload parameters and/or optimizer states to CPU.',
                    },
                ),
            ],
        ),
    ]

    for title, options in argument_groups:
        group = parser.add_argument_group(title)
        for flag, option_kwargs in options:
            group.add_argument(flag, **option_kwargs)
    parser = deepspeed.add_config_arguments(parser)

    args = parser.parse_args()

    # The DeepSpeed launcher supplies --local_rank; -1 means it was not used.
    launched_by_deepspeed = args.local_rank != -1
    if not launched_by_deepspeed:
        parser.error('`local_rank` not set, please use DeepSpeed launcher to run this script.')

    if args.fp16 and args.bf16:
        parser.error('Cannot use both bf16 and fp16 precision.')

    if args.bf16 and not is_torch_bf16_gpu_available():
        parser.error(
            'bf16 precision is not supported on this GPU. '
            'Please disable `--bf16` flag or use another precision flag (e.g., `--fp16`).',
        )

    # ``None`` means "leave the PyTorch default untouched".
    if args.tf32 is not None and is_torch_tf32_available():
        torch.backends.cuda.matmul.allow_tf32 = args.tf32

    return args


def batch_generate(
    batch: dict[str, torch.Tensor],
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    args: argparse.Namespace,
    generation_config: GenerationConfig,
) -> list[dict[str, Any]]:
    """Generate responses for this rank's batch and collect all ranks' results on rank 0.

    This is a collective call: every rank must invoke it once per step with its own
    shard of prompts, because ``model.generate(..., synced_gpus=True)`` and
    ``dist.gather_object`` both need all ranks to participate.

    Args:
        batch: This rank's collated batch; must contain ``input_ids`` and
            ``attention_mask``.
        model: The causal language model used for generation.
        tokenizer: Tokenizer used to decode prompts and responses.
        args: Parsed command-line arguments; only ``args.device`` is read here.
        generation_config: Decoding settings passed through to ``model.generate``.

    Returns:
        On the main process, one ``{'prompt': ..., 'response_text': ...}`` dict per
        generated sequence from every rank, concatenated in rank order. On every
        other rank, an empty list.
    """
    batch = to_device(batch, args.device)
    input_ids = batch['input_ids']
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=batch['attention_mask'],
            generation_config=generation_config,
            synced_gpus=True,
        )

    # For a causal LM, ``generate`` returns the (padded) prompt followed by the new
    # tokens, so slice the prompt off to keep ``response_text`` response-only.
    prompt_length = input_ids.size(-1)
    prompts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    responses = tokenizer.batch_decode(output_ids[:, prompt_length:], skip_special_tokens=True)

    # ``generate`` emits ``num_return_sequences`` consecutive rows per prompt.
    num_return_sequences = generation_config.num_return_sequences or 1
    local_generations = [
        {'prompt': prompts[row // num_return_sequences], 'response_text': response}
        for row, response in enumerate(responses)
    ]

    # Decode on each rank and gather the Python objects instead of raw token tensors.
    # Each response then stays paired with the prompt from its own rank, and ranks
    # whose prompts/outputs were padded to different lengths need no re-padding.
    gathered = [None] * dist.get_world_size() if is_main_process() else None
    dist.gather_object(local_generations, gathered, dst=0)

    if not is_main_process():
        return []
    return [item for rank_generations in gathered for item in rank_generations]


def main() -> None:
    """Generate a response for every prompt in ``--datasets`` and save them as JSONL.

    Each rank loads the model through DeepSpeed and generates for its
    ``DistributedSampler`` shard of the prompts. The main process accumulates the
    results returned by ``batch_generate`` and writes one JSON object per line to
    ``--output_file``, creating the parent directory if needed.
    """
    args = parse_arguments()

    deepspeed.init_distributed()

    args.global_rank = dist.get_rank()
    args.device = torch.device('cuda', args.local_rank)
    torch.cuda.set_device(args.device)
    seed_everything(args.seed)
    set_logger_level()

    dist.barrier()

    # Forward --offload as well; without it the flag was parsed but silently ignored.
    ds_config = get_deepspeed_eval_config(
        stage=args.zero_stage,
        offload=args.offload,
        fp16=args.fp16,
        bf16=args.bf16,
    )

    # Under ZeRO-3 the HfDeepSpeedConfig must exist (and be kept referenced, hence
    # stored on ``args``) before the model is loaded so weights can be partitioned.
    if ds_config['zero_optimization']['stage'] == 3:
        args.dschf = HfDeepSpeedConfig(ds_config)

    # Left padding so every prompt ends at the same position for batched generation.
    model, tokenizer = load_pretrained_models(
        args.model_name_or_path,
        model_max_length=args.max_length,
        padding_side='left',
        auto_model_type=AutoModelForCausalLM,
        trust_remote_code=args.trust_remote_code,
    )

    model, *_ = deepspeed.initialize(model=model, config=ds_config)
    model.eval()

    dataset = PromptOnlyDataset(args.datasets, tokenizer)
    # NOTE(review): DistributedSampler (drop_last=False by default) pads the index list
    # to a multiple of the world size by repeating samples, so a few prompts may be
    # generated twice. The output also follows the per-rank sharding, not dataset order.
    dataloader = DataLoader(
        dataset,
        collate_fn=dataset.get_collator(),
        sampler=DistributedSampler(dataset, shuffle=False),
        batch_size=args.per_device_batch_size,
    )

    dist.barrier()

    generation_config = GenerationConfig(
        max_length=args.max_length,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    all_generations = []
    for batch in tqdm(
        dataloader,
        desc='Generating',
        disable=not is_main_process(),
    ):
        # Collective call: every rank must run it, even though only rank 0 keeps results.
        generations = batch_generate(batch, model, tokenizer, args, generation_config)
        if is_main_process():
            all_generations.extend(generations)

    if is_main_process():
        output_dir = os.path.dirname(args.output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # The file is UTF-8, so keep non-ASCII text readable instead of \u-escaping it.
        with open(args.output_file, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in all_generations)
        print(f"Responses saved to {args.output_file}")


if __name__ == '__main__':
    # Run generation only when executed as a script/module (e.g. through the DeepSpeed
    # launcher command in the module docstring), not when this file is imported.
    main()