"""DeepSpeed configuration for training and evaluation."""

from __future__ import annotations

import json
import pathlib
from typing import Any, Literal

import torch.distributed as dist


# Names exported by ``from <module> import *``. The two template file paths below are
# not listed here.
__all__ = ['TEMPLATE_DIR', 'get_deepspeed_train_config', 'get_deepspeed_eval_config']


# Absolute path of the directory containing this module; the JSON templates are looked
# up relative to it.
TEMPLATE_DIR = pathlib.Path(__file__).absolute().parent
# Template loaded by `get_deepspeed_train_config`.
TRAIN_TEMPLATE_FILE = TEMPLATE_DIR / 'ds_train_config_template.json'
# Template loaded by `get_deepspeed_eval_config`.
EVAL_TEMPLATE_FILE = TEMPLATE_DIR / 'ds_eval_config_template.json'


def get_deepspeed_train_config(
    *,
    micro_batch_size_per_gpu: int = 16,
    gradient_accumulation_steps: int = 1,
    stage: int = 3,
    offload: Literal['none', 'parameter', 'optimizer', 'all'] = 'none',
    enable_hybrid_engine: bool = False,
    max_length: int = 512,
    fp16: bool = False,
    bf16: bool = False,
) -> dict[str, Any]:
    """Get the DeepSpeed config for training.

    The config is loaded from the training template JSON file and then overridden with
    the values given here. The global ``train_batch_size`` is computed as
    ``micro_batch_size_per_gpu * world_size * gradient_accumulation_steps``, where the
    world size is 1 if ``torch.distributed`` has not been initialized.

    Args:
        micro_batch_size_per_gpu (int, optional): The micro batch size per GPU. Defaults to 16.
        gradient_accumulation_steps (int, optional): The number of gradient accumulation steps.
            Defaults to 1.
        stage (int, optional): The stage of ZeRO. Defaults to 3.
        offload (Literal['none', 'parameter', 'optimizer', 'all'], optional): The offload mode.
            ``'parameter'`` and ``'optimizer'`` set the corresponding offload device to
            ``'cpu'``; ``'all'`` does both. Defaults to 'none'.
        enable_hybrid_engine (bool, optional): Whether to enable the DeepSpeed hybrid engine.
            Defaults to False.
        max_length (int, optional): The maximum length of the input sequence. It is written
            to the hybrid engine's ``max_out_tokens``. Defaults to 512.
        fp16 (bool, optional): Whether to use FP16 precision. Defaults to False.
        bf16 (bool, optional): Whether to use BF16 precision. Defaults to False.

    Returns:
        The DeepSpeed config for training.

    Raises:
        ValueError: If ``offload`` is not one of the supported modes.
    """
    # Validate explicitly rather than with ``assert`` so the check survives ``python -O``.
    valid_offload_modes = ('none', 'parameter', 'optimizer', 'all')
    if offload not in valid_offload_modes:
        raise ValueError(
            f'Invalid offload mode {offload!r}; expected one of {valid_offload_modes}.',
        )

    with TRAIN_TEMPLATE_FILE.open(mode='rt', encoding='utf-8') as f:
        train_config = json.load(f)

    # Fall back to a single process when the distributed backend is not initialized.
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    train_batch_size = micro_batch_size_per_gpu * world_size * gradient_accumulation_steps

    train_config['train_batch_size'] = train_batch_size
    train_config['train_micro_batch_size_per_gpu'] = micro_batch_size_per_gpu
    train_config['gradient_accumulation_steps'] = gradient_accumulation_steps

    # ``setdefault`` keeps any settings already present in the template while tolerating
    # a template that omits the section entirely.
    zero_config = train_config.setdefault('zero_optimization', {})
    zero_config['stage'] = stage
    if offload in {'parameter', 'all'}:
        zero_config.setdefault('offload_param', {})['device'] = 'cpu'
    if offload in {'optimizer', 'all'}:
        zero_config.setdefault('offload_optimizer', {})['device'] = 'cpu'

    hybrid_engine_config = train_config.setdefault('hybrid_engine', {})
    hybrid_engine_config['enabled'] = enable_hybrid_engine
    hybrid_engine_config['max_out_tokens'] = max_length

    # Only touch a precision section when it is requested or already in the template, so
    # that a template-provided section is explicitly switched off when the flag is False.
    if fp16 or 'fp16' in train_config:
        train_config.setdefault('fp16', {})['enabled'] = fp16
    if bf16 or 'bf16' in train_config:
        train_config.setdefault('bf16', {})['enabled'] = bf16
    return train_config


def get_deepspeed_eval_config(
    *,
    stage: int = 3,
    offload: Literal['none', 'parameter', 'optimizer', 'all'] = 'none',
    fp16: bool = False,
    bf16: bool = False,
) -> dict[str, Any]:
    """Get the DeepSpeed config for evaluation.

    The config is loaded from the evaluation template JSON file and then overridden with
    the values given here. The micro batch size per GPU and the number of gradient
    accumulation steps are both fixed to 1, and ``train_batch_size`` is set to ``None``.

    Args:
        stage (int, optional): The stage of ZeRO. Stages 1 and 2 are replaced with stage 0.
            Defaults to 3.
        offload (Literal['none', 'parameter', 'optimizer', 'all'], optional): The offload mode.
            Only parameter offload is applied here (for ``'parameter'`` and ``'all'``);
            ``'optimizer'`` is accepted but changes nothing. Defaults to 'none'.
        fp16 (bool, optional): Whether to use FP16 precision. Defaults to False.
        bf16 (bool, optional): Whether to use BF16 precision. Defaults to False.

    Returns:
        The DeepSpeed config for evaluation.

    Raises:
        ValueError: If ``offload`` is not one of the supported modes.
    """
    # Validate explicitly rather than with ``assert`` so the check survives ``python -O``.
    valid_offload_modes = ('none', 'parameter', 'optimizer', 'all')
    if offload not in valid_offload_modes:
        raise ValueError(
            f'Invalid offload mode {offload!r}; expected one of {valid_offload_modes}.',
        )

    with EVAL_TEMPLATE_FILE.open(mode='rt', encoding='utf-8') as f:
        eval_config = json.load(f)

    if stage in {1, 2}:
        # The evaluation config only works for ZeRO stage 0 and ZeRO stage 3
        stage = 0

    eval_config['train_batch_size'] = None
    eval_config['train_micro_batch_size_per_gpu'] = 1
    eval_config['gradient_accumulation_steps'] = 1

    # ``setdefault`` keeps any settings already present in the template while tolerating
    # a template that omits the section entirely.
    zero_config = eval_config.setdefault('zero_optimization', {})
    zero_config['stage'] = stage
    if offload in {'parameter', 'all'}:
        zero_config.setdefault('offload_param', {})['device'] = 'cpu'

    # Only touch a precision section when it is requested or already in the template, so
    # that a template-provided section is explicitly switched off when the flag is False.
    if fp16 or 'fp16' in eval_config:
        eval_config.setdefault('fp16', {})['enabled'] = fp16
    if bf16 or 'bf16' in eval_config:
        eval_config.setdefault('bf16', {})['enabled'] = bf16
    return eval_config
