# Copyright 2023 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DeepSpeed configuration for training and evaluation."""

from __future__ import annotations

import json
import pathlib
from typing import Any


__all__ = ['TEMPLATE_DIR', 'get_deepspeed_train_config', 'get_deepspeed_eval_config']


# Absolute path of the directory containing this module; the JSON templates
# below are looked up inside it.
TEMPLATE_DIR = pathlib.Path(__file__).absolute().parent
# Template loaded by ``get_deepspeed_train_config``.
TRAIN_TEMPLATE_FILE = TEMPLATE_DIR.joinpath('ds_train_config_template.json')
# Template loaded by ``get_deepspeed_eval_config``.
EVAL_TEMPLATE_FILE = TEMPLATE_DIR.joinpath('ds_eval_config_template.json')


def get_deepspeed_train_config(
    *,
    batch_size: int = 128,
    micro_batch_size_per_gpu: int = 16,
    stage: int = 3,
    enable_hybrid_engine: bool = False,
    max_length: int = 512,
    fp16: bool = False,
    bf16: bool = False,
) -> dict[str, Any]:
    """Get the DeepSpeed config for training.

    The config is loaded from the JSON template ``TRAIN_TEMPLATE_FILE`` and the values
    given as arguments are then written on top of it. The ``zero_optimization`` and
    ``hybrid_engine`` sections are created if the template does not contain them. The
    ``fp16`` / ``bf16`` sections are only written when the corresponding flag is set or
    the template already contains the section.

    Args:
        batch_size (int, optional): The batch size for training. Defaults to 128.
        micro_batch_size_per_gpu (int, optional): The micro batch size per GPU for training.
            Defaults to 16.
        stage (int, optional): The stage of ZeRO. Defaults to 3.
        enable_hybrid_engine (bool, optional): Whether to enable the DeepSpeed hybrid engine.
            Defaults to False.
        max_length (int, optional): The maximum length of the input sequence. It is stored
            as ``max_out_tokens`` of the hybrid engine section. Defaults to 512.
        fp16 (bool, optional): Whether to use FP16 precision. Defaults to False.
        bf16 (bool, optional): Whether to use BF16 precision. Defaults to False.

    Returns:
        The DeepSpeed config for training.
    """
    with TRAIN_TEMPLATE_FILE.open(mode='rt', encoding='utf-8') as f:
        train_config = json.load(f)

    train_config['train_batch_size'] = batch_size
    train_config['train_micro_batch_size_per_gpu'] = micro_batch_size_per_gpu

    # Create nested sections on demand so that a template that omits them does not
    # raise ``KeyError``; this matches how the precision sections are handled below.
    train_config.setdefault('zero_optimization', {})['stage'] = stage
    hybrid_engine = train_config.setdefault('hybrid_engine', {})
    hybrid_engine['enabled'] = enable_hybrid_engine
    hybrid_engine['max_out_tokens'] = max_length

    # Only touch a precision section when it is requested or already in the template,
    # so that an unused precision mode does not add a new section to the config.
    for precision, enabled in (('fp16', fp16), ('bf16', bf16)):
        if enabled or precision in train_config:
            train_config.setdefault(precision, {})['enabled'] = enabled
    return train_config


def get_deepspeed_eval_config(
    *,
    stage: int = 3,
    fp16: bool = False,
    bf16: bool = False,
) -> dict[str, Any]:
    """Get the DeepSpeed config for evaluation.

    The config is loaded from the JSON template ``EVAL_TEMPLATE_FILE`` and the values
    given as arguments are then written on top of it. The ``zero_optimization`` section
    is created if the template does not contain it. The ``fp16`` / ``bf16`` sections are
    only written when the corresponding flag is set or the template already contains the
    section.

    Args:
        stage (int, optional): The stage of ZeRO. Stages 1 and 2 are replaced by stage 0.
            Defaults to 3.
        fp16 (bool, optional): Whether to use FP16 precision. Defaults to False.
        bf16 (bool, optional): Whether to use BF16 precision. Defaults to False.

    Returns:
        The DeepSpeed config for evaluation.
    """
    with EVAL_TEMPLATE_FILE.open(mode='rt', encoding='utf-8') as f:
        eval_config = json.load(f)

    if stage in {1, 2}:
        # The evaluation config only works for ZeRO stage 0 and ZeRO stage 3
        stage = 0

    # No global batch size is specified; only the per-GPU micro batch size and the
    # number of gradient accumulation steps are fixed (both to 1).
    eval_config['train_batch_size'] = None
    eval_config['train_micro_batch_size_per_gpu'] = 1
    eval_config['gradient_accumulation_steps'] = 1

    # Create the section on demand so that a template that omits it does not raise
    # ``KeyError``; this matches how the precision sections are handled below.
    eval_config.setdefault('zero_optimization', {})['stage'] = stage

    # Only touch a precision section when it is requested or already in the template,
    # so that an unused precision mode does not add a new section to the config.
    for precision, enabled in (('fp16', fp16), ('bf16', bf16)):
        if enabled or precision in eval_config:
            eval_config.setdefault(precision, {})['enabled'] = enabled
    return eval_config
