#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

import yaml

from ...utils import ComputeEnvironment, DistributedType, SageMakerDistributedType
from ...utils.constants import SAGEMAKER_PYTHON_VERSION, SAGEMAKER_PYTORCH_VERSION, SAGEMAKER_TRANSFORMERS_VERSION


# Cache root resolution: `HF_HOME` takes precedence; otherwise `$XDG_CACHE_HOME/huggingface` is used, with
# `~/.cache` standing in for `XDG_CACHE_HOME` when that variable is unset.
hf_cache_home = os.path.expanduser(
    os.environ.get("HF_HOME", os.path.join(os.environ.get("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
cache_dir = os.path.join(hf_cache_home, "accelerate")
# The two default paths must differ, otherwise the backward-compatibility check below can never pick the JSON file.
default_json_config_file = os.path.join(cache_dir, "default_config.json")
default_yaml_config_file = os.path.join(cache_dir, "default_config.yaml")

# For backward compatibility: the default config is the json one if it's the only existing file.
if os.path.isfile(default_yaml_config_file) or not os.path.isfile(default_json_config_file):
    default_config_file = default_yaml_config_file
else:
    default_config_file = default_json_config_file


def load_config_from_file(config_file):
    """
    Loads a configuration object from `config_file`, or from the default config file when `config_file` is `None`.

    Files whose name ends in `.json` are parsed as JSON, everything else as YAML. The `compute_environment` entry
    of the parsed content (defaulting to `ComputeEnvironment.LOCAL_MACHINE`) selects the class to build:
    `ClusterConfig` for the local machine, `SageMakerConfig` otherwise.

    Raises:
        FileNotFoundError: if an explicit `config_file` is given but is not an existing file.
    """
    if config_file is None:
        config_file = default_config_file
    elif not os.path.isfile(config_file):
        raise FileNotFoundError(
            f"The passed configuration file `{config_file}` does not exist. "
            "Please pass an existing file to `accelerate launch`, or use the default one "
            "created through `accelerate config` and run `accelerate launch` "
            "without the `--config_file` argument."
        )

    with open(config_file, encoding="utf-8") as handle:
        is_json = config_file.endswith(".json")
        # First pass over the file: only used to find out which config class applies.
        raw_config = json.load(handle) if is_json else yaml.safe_load(handle)
        environment = raw_config.get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
        if environment == ComputeEnvironment.LOCAL_MACHINE:
            config_class = ClusterConfig
        else:
            config_class = SageMakerConfig
        # The selected class re-reads the file itself to build the full config object.
        if is_json:
            return config_class.from_json_file(json_file=config_file)
        return config_class.from_yaml_file(yaml_file=config_file)

@dataclass
class BaseConfig:
    """
    Fields and (de)serialization helpers shared by the concrete configuration classes.

    String values passed for `compute_environment` and `distributed_type` are converted to the matching enum in
    `__post_init__`.
    """

    compute_environment: ComputeEnvironment
    distributed_type: Union[DistributedType, SageMakerDistributedType]
    mixed_precision: str
    use_cpu: bool
    debug: bool

    def to_dict(self):
        """
        Returns the instance attributes as a new dictionary suitable for serialization.

        Enums are replaced by their underlying value (including inside nested dictionaries), empty dictionaries
        become `None`, and top-level entries whose value is `None` are left out. The instance and its nested
        dictionaries are not modified.
        """

        # For serialization, it's best to convert Enums to strings (or their underlying value type).
        def _convert_enums(value):
            if isinstance(value, Enum):
                return value.value
            if isinstance(value, dict):
                if not value:
                    return None
                # Build a copy rather than writing back into `value`, which still belongs to the instance.
                return {key: _convert_enums(item) for key, item in value.items()}
            return value

        converted = {key: _convert_enums(value) for key, value in self.__dict__.items()}
        return {key: value for key, value in converted.items() if value is not None}

    @staticmethod
    def process_config(config_dict):
        """
        Processes `config_dict` and sets default values for any missing keys

        The dictionary is updated in place and also returned. Legacy `fp16` and `dynamo_backend` entries are
        converted to `mixed_precision` and `dynamo_config`.

        Raises:
            ValueError: if `distributed_type` is missing.
        """
        # Not `setdefault`: the enum default is only looked up when the key is actually missing.
        if "compute_environment" not in config_dict:
            config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
        if "distributed_type" not in config_dict:
            raise ValueError("A `distributed_type` must be specified in the config file.")
        if "num_processes" not in config_dict and config_dict["distributed_type"] == DistributedType.NO:
            config_dict["num_processes"] = 1
        if "mixed_precision" not in config_dict:
            config_dict["mixed_precision"] = "fp16" if config_dict.get("fp16") else None
        # Convert the config to the new format.
        config_dict.pop("fp16", None)
        if "dynamo_backend" in config_dict:  # Convert the config to the new format.
            dynamo_backend = config_dict.pop("dynamo_backend")
            config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend}
        config_dict.setdefault("use_cpu", False)
        config_dict.setdefault("debug", False)
        config_dict.setdefault("enable_cpu_affinity", False)
        return config_dict

    @classmethod
    def _from_config_dict(cls, config_dict, source):
        # Shared tail of `from_json_file` / `from_yaml_file`: apply defaults, reject keys that are not dataclass
        # fields of `cls`, then instantiate. `source` is only used in the error message.
        config_dict = cls.process_config(config_dict)
        extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
        if extra_keys:
            raise ValueError(
                f"The config file at {source} had unknown keys ({extra_keys}), please try upgrading your `accelerate`"
                " version or fix (and potentially remove) these keys from your config file."
            )
        return cls(**config_dict)

    @classmethod
    def from_json_file(cls, json_file=None):
        """
        Builds a config from the JSON file at `json_file` (the default JSON config file when `None`).

        Raises:
            ValueError: if the file contains keys that are not fields of this class.
        """
        json_file = default_json_config_file if json_file is None else json_file
        with open(json_file, encoding="utf-8") as f:
            config_dict = json.load(f)
        return cls._from_config_dict(config_dict, json_file)

    def to_json_file(self, json_file):
        """
        Writes `to_dict()` to `json_file` as indented JSON with sorted keys and a trailing newline.
        """
        with open(json_file, "w", encoding="utf-8") as f:
            content = json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
            f.write(content)

    @classmethod
    def from_yaml_file(cls, yaml_file=None):
        """
        Builds a config from the YAML file at `yaml_file` (the default YAML config file when `None`).

        Raises:
            ValueError: if the file contains keys that are not fields of this class.
        """
        yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
        with open(yaml_file, encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        return cls._from_config_dict(config_dict, yaml_file)

    def to_yaml_file(self, yaml_file):
        """
        Writes `to_dict()` to `yaml_file` as YAML.
        """
        with open(yaml_file, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.to_dict(), f)

    def __post_init__(self):
        if isinstance(self.compute_environment, str):
            self.compute_environment = ComputeEnvironment(self.compute_environment)
        if isinstance(self.distributed_type, str):
            # The enum used for `distributed_type` depends on the compute environment.
            if self.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
                self.distributed_type = SageMakerDistributedType(self.distributed_type)
            else:
                self.distributed_type = DistributedType(self.distributed_type)
        # `dynamo_config` is declared by subclasses; `getattr` also covers classes that do not declare it.
        if getattr(self, "dynamo_config", None) is None:
            self.dynamo_config = {}

@dataclass
class ClusterConfig(BaseConfig):
    """
    Configuration used when the compute environment is the local machine / a cluster.

    Dictionary-valued plugin sections that are left as `None` are replaced by empty dictionaries in
    `__post_init__`.
    """

    num_processes: int = -1  # For instance if we use SLURM and the user manually passes it in
    machine_rank: int = 0
    num_machines: int = 1
    gpu_ids: Optional[str] = None
    main_process_ip: Optional[str] = None
    main_process_port: Optional[int] = None
    rdzv_backend: Optional[str] = "static"
    same_network: Optional[bool] = False
    main_training_function: str = "main"
    enable_cpu_affinity: bool = False

    # FP8 training options
    fp8_config: dict = None
    # deepspeed_plugin options
    deepspeed_config: dict = None
    # fsdp options
    fsdp_config: dict = None
    # megatron_lm options
    megatron_lm_config: dict = None
    # ipex options
    ipex_config: dict = None
    # mpirun options
    mpirun_config: dict = None
    # TPU options
    downcast_bf16: bool = False

    # TPU pod options
    tpu_name: str = None
    tpu_zone: str = None
    tpu_use_cluster: bool = False
    tpu_use_sudo: bool = False
    command_file: str = None
    commands: list[str] = None
    tpu_vm: list[str] = None
    tpu_env: list[str] = None

    # dynamo options
    dynamo_config: dict = None

    def __post_init__(self):
        # Give every unset plugin section its own empty dict before running the base-class normalization.
        plugin_sections = (
            "deepspeed_config",
            "fsdp_config",
            "megatron_lm_config",
            "ipex_config",
            "mpirun_config",
            "fp8_config",
        )
        for section in plugin_sections:
            if getattr(self, section) is None:
                setattr(self, section, {})
        return super().__post_init__()


@dataclass
class SageMakerConfig(BaseConfig):
    """
    Configuration selected by `load_config_from_file` when the compute environment is not the local machine.

    `ec2_instance_type` and `iam_role_name` have no default and must be provided; every other field is optional.
    """

    ec2_instance_type: str
    iam_role_name: str
    image_uri: Optional[str] = None
    profile: Optional[str] = None
    region: str = "us-east-1"
    num_machines: int = 1
    gpu_ids: str = "all"
    # NOTE: this f-string is evaluated once, while the class body executes, so `num_machines` here is the class-level
    # default above (1). The default is therefore always "accelerate-sagemaker-1", whatever an instance is given.
    base_job_name: str = f"accelerate-sagemaker-{num_machines}"
    # Version defaults come from the package-level SageMaker constants.
    pytorch_version: str = SAGEMAKER_PYTORCH_VERSION
    transformers_version: str = SAGEMAKER_TRANSFORMERS_VERSION
    py_version: str = SAGEMAKER_PYTHON_VERSION
    sagemaker_inputs_file: str = None
    sagemaker_metrics_file: str = None
    additional_args: dict = None
    # Left as `None` here; `BaseConfig.__post_init__` replaces a `None` `dynamo_config` with `{}`.
    dynamo_config: dict = None
    enable_cpu_affinity: bool = False
