import os,json
from dataclasses import dataclass, field
from typing import Optional

import transformers

# deepspeed = {

#     "bf16": {
#         "enabled": "auto"
#     },

#     "optimizer": {
#         "type": "AdamW",
#         "params": {
#             "lr": "auto",
#             "betas": "auto",
#             "eps": "auto",
#             "weight_decay": "auto"
#         }
#     },

#     "zero_optimization": {
#         "stage": 1
#     },

#     "gradient_accumulation_steps": "auto",
#     "gradient_clipping": "auto",
#     "steps_per_print": "auto",
#     "train_batch_size": "auto",
#     "train_micro_batch_size_per_gpu": "auto",
#     "wall_clock_breakdown": False
# }

# def save_deepspeed_config(config_dict, config_path="./deepspeed_config.json"):
#     """Save the DeepSpeed config as a JSON file."""
#     import json
#     with open(config_path, 'w') as f:
#         json.dump(config_dict, f, indent=4)
#     return config_path


@dataclass
class ModelArguments:
    """Options describing where the model lives and which sharing modes are on.

    Every field has a default, so the class can be built with no arguments.
    Fields whose default is ``None`` are annotated ``Optional[...]`` so the
    declared type matches the value actually stored when the option is unset.
    """

    # Unset (None) by default, hence Optional[str] rather than a bare str.
    local_dir: Optional[str] = field(
        default=None, metadata={"help": "Local Path of storing inputs and outputs "}
    )
    input_model_filename: Optional[str] = field(
        default="test-input", metadata={"help": "Input model relative path"}
    )
    output_model_filename: Optional[str] = field(
        default="test-output", metadata={"help": "Output model relative path"}
    )
    share_embedding: Optional[bool] = field(
        default=True, metadata={"help": "whether to share input/output embedding"}
    )
    layer_sharing: Optional[bool] = field(
        default=True, metadata={"help": "whether to do layer sharing"}
    )
    # NOTE(review): no help text is defined for this option and its meaning is
    # not visible in this file; it is unset (None) by default.
    first_stage: Optional[str] = field(
        default=None
    )


@dataclass
class DataArguments:
    """Locations of the training and evaluation data.

    Both paths are unset (``None``) by default.
    """

    train_data_local_path: Optional[str] = field(
        metadata=dict(help="Train data local path"),
        default=None,
    )
    eval_data_local_path: Optional[str] = field(
        metadata=dict(help="Eval data local path"),
        default=None,
    )



@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Extension of ``transformers.TrainingArguments`` with project defaults.

    Declares ``cache_dir`` and ``model_max_length`` in addition to the
    inherited options, and sets this project's defaults for ``optim`` and
    ``output_dir``.
    """

    cache_dir: Optional[str] = None
    optim: Optional[str] = "adamw_torch"
    output_dir: Optional[str] = "/tmp/output/"
    model_max_length: Optional[int] = field(
        metadata=dict(
            help=(
                "Maximum sequence length. Sequences will be right padded "
                "(and possibly truncated). 512 or 1024"
            )
        ),
        default=512,
    )


def process_args():
    """Parse command-line arguments into the three argument dataclasses.

    Returns:
        A ``(model_args, data_args, training_args)`` tuple holding instances
        of ``ModelArguments``, ``DataArguments`` and ``TrainingArguments``
        respectively, as produced by ``transformers.HfArgumentParser``.
    """
    dataclass_types = (ModelArguments, DataArguments, TrainingArguments)
    arg_parser = transformers.HfArgumentParser(dataclass_types)
    parsed_model, parsed_data, parsed_training = (
        arg_parser.parse_args_into_dataclasses()
    )

    # Disabled: writing a DeepSpeed config and attaching it to the training
    # arguments.
    # deepspeed_config_path = save_deepspeed_config(deepspeed)
    # training_args.deepspeed = deepspeed_config_path

    return parsed_model, parsed_data, parsed_training
