import argparse

from pprint import pprint


def get_args() -> argparse.Namespace:
    """ Get the arguments from the command line.

    Returns:
        argparse.Namespace: The arguments from the command line.
    """

    parser = argparse.ArgumentParser()

    parser.add_argument(
        '-bs',
        '--batch_size',
        type=int,
        required=True,
        help='The batch size for training.',
    )
    parser.add_argument(
        '-mbs',
        '--micro_batch_size',
        type=int,
        required=True,
        help='The micro batch size for training.',
    )
    parser.add_argument(
        '-gn',
        '--gpu_number',
        type=int,
        required=True,
        help='The number of GPUs for training.',
    )
    parser.add_argument(
        '-dn',
        '--data_number',
        default=-1,
        type=int,
        required=False,
        help=
        'The number of training data. If not provided, the program will calculate it based on one GPU steps.',
    )
    parser.add_argument(
        '-ogs',
        '--one_gpu_steps',
        default=-1,
        type=int,
        required=False,
        help=
        'The number of steps for one GPU. If not provided, the program will calculate it based on data number.',
    )
    parser.add_argument(
        '-lsr',
        '--log_steps_ratio',
        default=0.01,
        type=float,
        required=False,
        help=
        'The ratio (0 to 1) of log steps to one GPU steps. The default value is 0.01.',
    )
    parser.add_argument(
        '-fssr',
        '--full_save_steps_ratio',
        default=0.2,
        type=float,
        required=False,
        help=
        'The ratio (0 to 1) of full save steps to one GPU steps. The default value is 0.2.',
    )
    parser.add_argument(
        '-pssr',
        '--partial_save_steps_ratio',
        default=0.1,
        type=float,
        required=False,
        help=
        'The ratio (0 to 1) of partial save steps to one GPU steps. The default value is 0.1.',
    )

    return parser.parse_args()


def compute_training_steps(
    batch_size: int,
    micro_batch_size: int,
    gpu_number: int,
    data_number: int = -1,
    one_gpu_steps: int = -1,
    log_steps_ratio: float = 0.01,
    full_save_steps_ratio: float = 0.2,
    partial_save_steps_ratio: float = 0.1,
    warmup_steps_ratio: float = 0.1,
) -> dict[str, int]:
    """ Derive the step counts of a training run from its sizing parameters.

    Args:
        batch_size: The batch size. Must be positive and divisible by
            ``micro_batch_size``, by ``gpu_number`` and by their product.
        micro_batch_size: The micro batch size. Must be positive.
        gpu_number: The number of GPUs. Must be positive.
        data_number: The number of training data, or -1 when not provided.
            When provided it must be positive and divisible by
            ``batch_size``, and it takes precedence: ``one_gpu_steps`` is
            then recomputed from it and any passed value is ignored.
        one_gpu_steps: The number of steps for one GPU, or -1 when not
            provided. Only used (and then required to be positive) when
            ``data_number`` is -1.
        log_steps_ratio: Ratio of log steps to one GPU steps.
        full_save_steps_ratio: Ratio of full save steps to one GPU steps.
        partial_save_steps_ratio: Ratio of partial save steps to one GPU
            steps.
        warmup_steps_ratio: Ratio of warm-up steps to update steps.

    Returns:
        dict[str, int]: The three sizing inputs plus the derived
        ``gradient_accumulation_steps``, ``data_number``, ``one_gpu_steps``,
        ``log_steps``, ``full_save_steps``, ``partial_save_steps``,
        ``update_steps`` and ``warmup_steps``.

    Raises:
        ValueError: If any of the constraints above is violated, or if
            neither ``data_number`` nor ``one_gpu_steps`` is provided.
    """

    def require(condition: bool, message: str) -> None:
        # Explicit exceptions instead of ``assert`` so that the validation
        # still runs when Python is started with -O.
        if not condition:
            raise ValueError(message)

    require(batch_size > 0, 'The batch size must be greater than 0.')
    require(micro_batch_size > 0,
            'The micro batch size must be greater than 0.')
    require(gpu_number > 0, 'The GPU number must be greater than 0.')

    # All three values are positive here, so "divisible by x" already implies
    # "greater than or equal to x".
    samples_per_step = micro_batch_size * gpu_number
    require(
        batch_size % micro_batch_size == 0,
        'The batch size must be greater than or equal to the micro batch size and divisible by it.',
    )
    require(
        batch_size % gpu_number == 0,
        'The batch size must be greater than or equal to the GPU number and divisible by it.',
    )
    require(
        batch_size % samples_per_step == 0,
        'The batch size must be greater than or equal to the product of the micro batch size and GPU number and divisible by it.',
    )

    # Integer division is exact here (divisibility was just checked) and,
    # unlike float division, does not lose precision on large values.
    gradient_accumulation_steps = batch_size // samples_per_step

    require(
        data_number != -1 or one_gpu_steps != -1,
        'Either the number of training data or the number of steps for one GPU must be provided.',
    )

    if data_number == -1:
        require(
            one_gpu_steps > 0,
            'The number of steps for one GPU must be greater than 0 (if provided).',
        )
        data_number = one_gpu_steps * samples_per_step
    else:
        require(
            data_number > 0,
            'The number of training data must be greater than 0 (if provided).',
        )
        require(
            data_number % batch_size == 0,
            'The number of training data must be greater than or equal to the batch size and divisible by it.',
        )
        one_gpu_steps = data_number // samples_per_step

    # Truncating division: a trailing partial accumulation cycle does not
    # count as an update step.
    update_steps = one_gpu_steps // gradient_accumulation_steps

    return {
        'batch_size': batch_size,
        'micro_batch_size': micro_batch_size,
        'gpu_number': gpu_number,
        'gradient_accumulation_steps': gradient_accumulation_steps,
        'data_number': data_number,
        'one_gpu_steps': one_gpu_steps,
        'log_steps': int(round(one_gpu_steps * log_steps_ratio)),
        'full_save_steps': int(round(one_gpu_steps * full_save_steps_ratio)),
        'partial_save_steps':
        int(round(one_gpu_steps * partial_save_steps_ratio)),
        'update_steps': update_steps,
        'warmup_steps': int(round(update_steps * warmup_steps_ratio)),
    }


if __name__ == '__main__':
    args = get_args()

    try:
        result_dict = compute_training_steps(
            batch_size=args.batch_size,
            micro_batch_size=args.micro_batch_size,
            gpu_number=args.gpu_number,
            data_number=args.data_number,
            one_gpu_steps=args.one_gpu_steps,
            log_steps_ratio=args.log_steps_ratio,
            full_save_steps_ratio=args.full_save_steps_ratio,
            partial_save_steps_ratio=args.partial_save_steps_ratio,
        )
    except ValueError as error:
        # A string payload makes the interpreter print the message to stderr
        # and exit with status 1, without a traceback.
        raise SystemExit(str(error)) from None

    pprint(result_dict)
