#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

from ...utils.dataclasses import (
    ComputeEnvironment,
    DistributedType,
    DynamoBackend,
    FP8BackendType,
    PrecisionType,
    SageMakerDistributedType,
)
from ..menu import BulletMenu


# Backend names offered for torch dynamo, in menu order. `_convert_dynamo_backend` indexes into this list and
# passes the selected name to `DynamoBackend`, so every entry has to be spelled exactly like the enum value.
DYNAMO_BACKENDS = [
    "EAGER",
    "AOT_EAGER",
    "INDUCTOR",
    "AOT_TS_NVFUSER",
    "NVPRIMS_NVFUSER",
    "CUDAGRAPHS",
    "OFI",
    "FX2TRT",
    "ONNXRT",
    "TENSORRT",
    "AOT_TORCHXLA_TRACE_ONCE",
    "TORCHXLA_TRACE_ONCE",
    "IPEX",
    "TVM",
]


def _ask_field(input_text, convert_value=None, default=None, error_message=None):
    """
    Prompt with `input` until an acceptable answer is entered, then return it.

    Args:
        input_text: Prompt string passed to `input`.
        convert_value: Optional callable applied to the raw answer. If it raises, the prompt is shown again.
        default: Returned as-is (without conversion) when the answer is empty. Ignored when `None`.
        error_message: Optional text printed each time an answer is rejected.

    Returns:
        `default`, the converted answer, or the raw answer string when no converter is given.
    """
    while True:
        answer = input(input_text)
        # An empty answer selects the default, which deliberately bypasses `convert_value`.
        if default is not None and not answer:
            return default
        if convert_value is None:
            return answer
        try:
            return convert_value(answer)
        except Exception:
            # Any conversion failure means "invalid answer": optionally explain, then ask again.
            if error_message is not None:
                print(error_message)


def _ask_options(input_text, options=None, convert_value=None, default=0):
    """
    Show a `BulletMenu` for `options` and return the (optionally converted) selection.

    Args:
        input_text: Prompt handed to `BulletMenu`.
        options: Choices handed to `BulletMenu`. `None` (the default) means an empty list.
        convert_value: Optional callable applied to whatever `BulletMenu.run` returns.
        default: Forwarded to `BulletMenu.run` as `default_choice`.

    Returns:
        The result of `BulletMenu.run`, passed through `convert_value` when one is given.
    """
    # Build a fresh list per call instead of sharing one mutable default object across all calls.
    if options is None:
        options = []
    menu = BulletMenu(input_text, options)
    choice = menu.run(default_choice=default)
    return convert_value(choice) if convert_value is not None else choice


def _convert_compute_environment(value):
    """
    Turn a selected index (anything `int()` accepts) into the matching `ComputeEnvironment`.
    """
    environments = ["LOCAL_MACHINE", "AMAZON_SAGEMAKER"]
    selected = environments[int(value)]
    return ComputeEnvironment(selected)


def _convert_distributed_mode(value):
    """
    Turn a selected index (anything `int()` accepts) into the matching `DistributedType`.
    """
    # The position in this list is what the index refers to, so the order is significant.
    modes = [
        "NO",
        "MULTI_CPU",
        "MULTI_XPU",
        "MULTI_HPU",
        "MULTI_GPU",
        "MULTI_NPU",
        "MULTI_MLU",
        "MULTI_SDAA",
        "MULTI_MUSA",
        "XLA",
    ]
    selected = modes[int(value)]
    return DistributedType(selected)


def _convert_dynamo_backend(value):
    """
    Turn a selected index into `DYNAMO_BACKENDS` into the `.value` of the matching `DynamoBackend`.
    """
    backend_name = DYNAMO_BACKENDS[int(value)]
    backend = DynamoBackend(backend_name)
    return backend.value


def _convert_mixed_precision(value):
    """
    Turn a selected index (anything `int()` accepts) into the matching `PrecisionType`.
    """
    precisions = ["no", "fp16", "bf16", "fp8"]
    selected = precisions[int(value)]
    return PrecisionType(selected)


def _convert_sagemaker_distributed_mode(value):
    """
    Turn a selected index (anything `int()` accepts) into the matching `SageMakerDistributedType`.
    """
    modes = ["NO", "DATA_PARALLEL", "MODEL_PARALLEL"]
    selected = modes[int(value)]
    return SageMakerDistributedType(selected)


def _convert_fp8_backend(value):
    """
    Turn a selected index (anything `int()` accepts) into the matching `FP8BackendType`.
    """
    backends = ["TE", "MSAMP"]
    selected = backends[int(value)]
    return FP8BackendType(selected)


def _convert_yes_no_to_bool(value):
    """
    Map "yes"/"no" (case-insensitive) to `True`/`False`; any other string raises `KeyError`.
    """
    answers = {"yes": True, "no": False}
    normalized = value.lower()
    return answers[normalized]


class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter):
    """
    A help formatter for subcommands that strips the `<command> [<args>] ` placeholder from the usage text.
    """

    def _format_usage(self, usage, actions, groups, prefix):
        # Let argparse render the usage first, then drop every occurrence of the placeholder.
        rendered = super()._format_usage(usage, actions, groups, prefix)
        return rendered.replace("<command> [<args>] ", "")
