# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from pydantic import BaseModel
from .constants import *


class CommsConfig(BaseModel):
    """Shared pydantic base for the communication-related config models.

    Subclasses inherit the strict validation settings declared below, so a
    config section containing an unrecognised key is rejected rather than
    silently ignored.
    """

    class Config:
        # Reject any key that is not a declared field (typos surface as
        # validation errors instead of being dropped).
        extra = 'forbid'
        # Store the underlying value of enum fields rather than the enum member.
        use_enum_values = True
        # Re-run validation whenever a field is assigned after construction.
        validate_assignment = True
        # Validate default values too, not only explicitly supplied ones.
        validate_all = True


class CommsLoggerConfig(CommsConfig):
    """Options for the communication logger.

    Built from the ``comms_logger`` section of the DeepSpeed config (see
    ``DeepSpeedCommsConfig``). Because it derives from ``CommsConfig``, keys
    that are not declared here are rejected. Default values come from the
    ``COMMS_LOGGER_*_DEFAULT`` constants in ``.constants``.
    """
    # Master switch for comms logging (presumably; confirm against the logger).
    enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
    # Presumably: profile every communication op rather than a selected subset.
    prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
    # Presumably: names of the specific ops to profile when not profiling all.
    prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
    # Presumably: emit per-op log output as ops execute.
    verbose: bool = COMMS_LOGGER_VERBOSE_DEFAULT
    # Presumably: include extra debugging detail in logged records.
    debug: bool = COMMS_LOGGER_DEBUG_DEFAULT


class DeepSpeedCommsConfig:
    """Communication-related settings extracted from a DeepSpeed config.

    Attributes:
        comms_logger_enabled: True when the config contains a ``comms_logger``
            section. NOTE: this reflects only the presence of the section; the
            section's own ``enabled`` field lives on ``comms_logger.enabled``.
        comms_logger: the parsed ``CommsLoggerConfig`` when the section is
            present, otherwise ``None``.
    """

    def __init__(self, ds_config):
        """Parse the communication settings out of ``ds_config``.

        Args:
            ds_config: the user's DeepSpeed config as a mapping (presumably the
                parsed JSON/dict config; confirm against callers).

        Raises:
            Whatever ``CommsLoggerConfig`` raises when the ``comms_logger``
            section contains unknown keys or values that fail validation.
        """
        self.comms_logger_enabled = 'comms_logger' in ds_config

        # Always define the attribute so that code reading it without first
        # checking ``comms_logger_enabled`` gets None, not AttributeError.
        self.comms_logger = None
        if self.comms_logger_enabled:
            self.comms_logger = CommsLoggerConfig(**ds_config['comms_logger'])
