# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
from sagemaker.pytorch import PyTorch
from sagemaker.huggingface import HuggingFace
from sagemaker.estimator import Framework

from syne_tune.backend.sagemaker_backend.custom_framework import CustomFramework
from syne_tune.backend.sagemaker_backend.sagemaker_utils import (
    get_execution_role,
    default_sagemaker_session,
)

logger = logging.getLogger(__name__)


def sagemaker_estimator_factory(
    entry_point: str,
    instance_type: str,
    framework: str = None,
    role: str = None,
    instance_count: int = 1,
    framework_version: str = None,
    py_version: str = None,
    dependencies: list = None,
    sagemaker_session=None,
    **kwargs,
) -> Framework:
    """
    Create a SageMaker estimator that runs ``entry_point``.

    :param entry_point: Training script passed to the estimator.
    :param instance_type: SageMaker instance type for the job.
    :param framework: ``"PyTorch"`` selects :class:`PyTorch`,
        ``"HuggingFace"`` selects :class:`HuggingFace`. Any other value
        (including ``None``) falls back to :class:`CustomFramework`, which
        requires ``image_uri`` to be given in ``kwargs``.
    :param role: Execution role; defaults to ``get_execution_role()``.
    :param instance_count: Number of instances, defaults to 1.
    :param framework_version: Passed as ``framework_version`` for PyTorch and
        as ``transformers_version`` for HuggingFace; not used by the
        CustomFramework fallback.
    :param py_version: Python version string, defaults to ``"py3"``; not used
        by the CustomFramework fallback.
    :param dependencies: If given, forwarded to the estimator as
        ``dependencies``.
    :param sagemaker_session: Session used by the estimator; defaults to
        ``default_sagemaker_session()``.
    :param kwargs: Additional keyword arguments forwarded to the estimator
        constructor.
    :return: The constructed estimator.
    """
    if role is None:
        role = get_execution_role()
    if py_version is None:
        py_version = "py3"
    # Only build a default session if the caller did not supply one; merging
    # a default into ``kwargs`` unconditionally would discard the caller's.
    if sagemaker_session is None:
        sagemaker_session = default_sagemaker_session()
    common_kwargs = dict(
        kwargs,
        instance_type=instance_type,
        instance_count=instance_count,
        role=role,
        sagemaker_session=sagemaker_session,
    )
    if dependencies is not None:
        common_kwargs["dependencies"] = dependencies
    if framework == "PyTorch":
        sm_estimator = PyTorch(
            entry_point,
            framework_version=framework_version,
            py_version=py_version,
            **common_kwargs,
        )
    elif framework == "HuggingFace":
        # HuggingFace takes ``py_version`` and ``entry_point`` positionally,
        # in that order; the framework version maps to ``transformers_version``.
        sm_estimator = HuggingFace(
            py_version,
            entry_point,
            **common_kwargs,
            transformers_version=framework_version,
        )
    else:
        if framework is not None:
            logger.info(
                f"framework = '{framework}' not supported, using " "CustomFramework"
            )
        assert (
            kwargs.get("image_uri") is not None
        ), "CustomFramework requires 'image_uri' to be specified"
        sm_estimator = CustomFramework(entry_point, **common_kwargs)
    return sm_estimator
