import os

from dataclasses import dataclass, field
from typing import Optional, List, Literal

from transformers import HfArgumentParser


@dataclass
class PreprocessArg:
    """Options for the dataset-preprocessing (binarization) step.

    ``data_dir`` and ``dest_dir`` have no default and therefore must be given;
    every other field falls back to the value listed here.
    """

    # Where to read raw text from and where to write the binarized output.
    data_dir: str = field(
        metadata=dict(help="The directory that includes raw text dataset"),
    )
    dest_dir: str = field(
        metadata=dict(help="The directory that saves binarized dataset."),
    )

    # Output naming and degree of parallelism.
    save_name: str = field(
        metadata=dict(help="The prefix name for saving files."),
        default="train",
    )
    worker: int = field(
        metadata=dict(help="The number of workers to process data."),
        default=50,
    )

    # How the input text relates to token ids.
    numberized: Optional[bool] = field(
        metadata=dict(help="Whether the dataset has been numberized."),
        default=True,
    )
    tokenizer: Optional[str] = field(
        metadata=dict(help="The tokenizer to tokenize dataset."),
        default="hf-internal-testing/llama-tokenizer",
    )
    append_eos: Optional[bool] = field(
        metadata=dict(help="Append eos token id to each sentence."),
        default=True,
    )

    # Reading strategy for large input files.
    chunk_load: Optional[bool] = field(
        metadata=dict(help="Read large file by chunked lines."),
        default=False,
    )

    # Authentication token; the default "hf_xxx" looks like a placeholder,
    # so a real token presumably has to be supplied when one is needed.
    auth: str = field(
        metadata=dict(help="The authenticated token."),
        default="hf_xxx",
    )


@dataclass
class ModelArg:
    """Model checkpoint/config source and architecture hyper-parameters.

    ``use_hf_checkpoint`` / ``use_hf_config`` choose whether the model is built
    from a Hugging Face checkpoint or config (per their help text); the
    remaining integer/float fields describe the architecture directly.
    """

    # --- where the model comes from ---
    model_name_or_path: Optional[str] = field(
        default='meta-llama/Llama-2-7b-hf',
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    use_hf_checkpoint: Optional[bool] = field(
        default=False,
        metadata={"help": "build model from hugging face checkpoint."}
    )
    use_hf_config: Optional[bool] = field(
        default=False,
        metadata={"help": "build model from hugging face config."}
    )

    # --- architecture hyper-parameters ---
    hidden_size: int = field(
        default=4096,
        metadata={"help": "Dimension of the hidden representations."},
    )
    intermediate_size: int = field(
        default=11008,
        metadata={"help": "Dimension of the MLP representations."},
    )
    num_hidden_layers: int = field(
        default=32,
        metadata={"help": "Number of hidden layers in the Transformer decoder."},
    )
    num_attention_heads: int = field(
        default=32,
        metadata={"help": "Number of attention heads for each attention layer in the Transformer decoder."}
    )
    norm_eps: float = field(
        default=1e-6,
        metadata={"help": "The value for numerical stability."}
    )
    dropout: float = field(
        default=0.1,
        metadata={"help": "The ratio of dropout layer."}
    )
    # -1 is a sentinel: per the help text the real size comes from the
    # tokenizer (presumably filled in by the caller after loading it - confirm).
    vocab_size: Optional[int] = field(
        default=-1,
        metadata={"help": "The number of vocabulary, determined by tokenizer."}
    )
    auth: str = field(
        default="hf_xxx",
        metadata={"help": "The authenticated token."}
    )
    # The default is None, so the annotation must admit None.
    model_init: Optional[str] = field(
        default=None,
        metadata={"help": "Use model initialization"}
    )


@dataclass
class TrainArg:
    """Data, optimisation, checkpointing and trainer options for LLM training."""

    # --- data ---
    data_dir: str = field(
        default="../data-bin",
        metadata={"help": "The directory that include binarized datasets."},
    )
    tokenizer: Optional[str] = field(
        default="hf-internal-testing/llama-tokenizer",
        metadata={"help": "The tokenizer to tokenize dataset"},
    )
    # default_factory gives every instance its own fresh list.
    train_split: List[str] = field(
        default_factory=lambda: [f'train{i}' for i in range(1, 11)],
        metadata={"help": "The prefix for training dataset."}
    )
    valid_split: List[str] = field(
        default_factory=lambda: ['valid'],
        metadata={"help": "The prefix for validation dataset"}
    )

    # --- checkpointing ---
    save_dir: str = field(
        default="checkpoints",
        metadata={"help": "The save path for saving checkpoints."}
    )
    save_every_N_steps: int = field(
        default=2000,
        metadata={"help": "Save checkpoint every N steps."}
    )

    # --- sequence / batch shape ---
    seq_len: int = field(
        default=2048,
        metadata={"help": "The maximum length for pre-training (next token prediction)."},
    )
    sample_len: int = field(
        default=1024,
        metadata={"help": "The length of sparse training."}
    )
    batch_size: int = field(
        default=32,
        metadata={"help": "The batch size."},
    )
    eval_batch_size: int = field(
        default=32,
        metadata={"help": "The batch size for evaluation."}
    )
    accumulation_step: int = field(
        default=1,
        metadata={"help": "The gradient accumulation steps."},
    )
    num_worker: int = field(
        default=20,
        metadata={"help": "The number of subprocesses to use for data loading."},
    )
    precision: Literal["fp32", "fp16", "bf16"] = field(
        default="fp32",
        metadata={"help": "The precision for training / inference."}
    )

    # --- run length / reproducibility ---
    # NOTE(review): -1 presumably means "no epoch limit" so that max_update
    # bounds training - confirm against the trainer.
    max_epoch: int = field(
        default=-1,
        metadata={"help": "The maximum number of training epoch"},
    )
    max_update: int = field(
        default=100000,
        metadata={"help": "The maximum training steps."}
    )
    seed: int = field(
        default=1234,
        metadata={"help": "The seed number for experiment running."}
    )
    append_bos: bool = field(
        default=False,
        metadata={"help": "Append bos token to sequence."}
    )

    # --- loss / optimisation ---
    criterion: str = field(
        default="cross_entropy",
        metadata={"help": "The criterion for calculate loss."}
    )
    lr_scheduler: Literal["linear", "cosine", "polynomial", "constant", "constant_with_warmup"] = field(
        default="linear",
        metadata={"help": "The scheduler type to use."}
    )
    learning_rate: float = field(
        default=1e-4,
        metadata={"help": "The learning rate for adam."}
    )
    num_warmup_steps: int = field(
        default=500,
        metadata={"help": "The warmup steps for adam."}
    )

    # --- trainer backend ---
    trainer: Literal["ddp", "accelerate", "deepspeed"] = field(
        default="ddp",
        metadata={"help": "The kind of trainer."}
    )
    use_zero: bool = field(
        default=False,
        metadata={"help": "Use zero in ddp trainer."}
    )


def _int_from_env(name: str, default: int = 0) -> int:
    """Read environment variable ``name`` as an ``int``.

    Returns ``default`` when the variable is unset or blank; raises
    ``ValueError`` if it is set to a non-integer value.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    return int(raw)


@dataclass
class DistArg:
    '''
    Distributed-training options.

    Running Command:
        python -m torch.distributed.launch --use-env --nproc-per-node ${NUM_OF_GPUs} train.py

    ``distributed_rank``, ``local_rank`` and ``distributed_world_size`` default
    to the ``RANK`` / ``LOCAL_RANK`` / ``WORLD_SIZE`` environment variables
    (0 when unset or blank). They use ``default_factory`` so the default is an
    ``int`` (``os.getenv`` returns a ``str``) and is read each time a default is
    needed instead of once at module import time.
    '''
    distributed_backend: Optional[str] = field(
        default="nccl",
        metadata={
            "help": "distributed backend",
            "choices": ["nccl"],
        },
    )
    distributed_rank: int = field(
        default_factory=lambda: _int_from_env("RANK"),
        metadata={"help": "The rank of the GPU process on all server."},
    )
    local_rank: int = field(
        default_factory=lambda: _int_from_env("LOCAL_RANK"),
        # Trailing space keeps the two implicitly-concatenated pieces apart.
        metadata={"help": "The rank of the GPU process on local machine, "
                          "same as distributed rank on one server."},
    )
    use_accelerate: bool = field(
        default=False,
        metadata={"help": "Use huggingface accelerate to speedup training."}
    )
    use_deepspeed: bool = field(
        default=False,
        metadata={"help": "Use deepspeed for acceleration."}
    )
    deepspeed_config: str = field(
        default="../../deepspeed_configs/zero_stage2_offload_config.json",
        metadata={"help": "The path of deepspeed configuration."}
    )
    distributed_world_size: int = field(
        default_factory=lambda: _int_from_env("WORLD_SIZE"),
        metadata={"help": "The world size of all gpu processes."}
    )


def parse_training_args():
    """Parse command-line arguments for LLM training.

    Builds an ``HfArgumentParser`` over ``DistArg``, ``ModelArg`` and
    ``TrainArg``, restricts ``--criterion`` to the names returned by
    ``LLMProxy.criterion.get_criterion_list()``, and parses the process
    arguments. Unrecognised arguments are still ignored, but a warning listing
    them is emitted so a misspelled flag does not silently fall back to its
    default.

    Returns:
        tuple: ``(dist_args, model_args, train_args)`` dataclass instances.
    """
    import warnings

    # create arguments for LLM training
    parser = HfArgumentParser([DistArg, ModelArg, TrainArg])

    # Function-scope import, presumably to avoid an import cycle with the
    # LLMProxy package - confirm before hoisting it to module level.
    from LLMProxy.criterion import get_criterion_list
    # argparse has no public API to change an existing action's choices, so
    # the registered --criterion action is patched in place.
    parser._option_string_actions['--criterion'].choices = get_criterion_list()

    args, unknown = parser.parse_known_args()
    if unknown:
        warnings.warn(
            f"Ignoring unrecognized training arguments: {unknown}",
            stacklevel=2,
        )
    dist_args, model_args, train_args = parser.parse_dict(vars(args))
    return (dist_args, model_args, train_args)


def parse_preprocess_args():
    """Parse command-line arguments for dataset preprocessing.

    Unrecognised arguments are still ignored, but a warning listing them is
    emitted so a misspelled flag does not silently fall back to its default.

    Returns:
        PreprocessArg: the parsed preprocessing options.
    """
    import warnings

    # create arguments for preprocessing datasets
    parser = HfArgumentParser([PreprocessArg])
    args, unknown = parser.parse_known_args()
    if unknown:
        warnings.warn(
            f"Ignoring unrecognized preprocessing arguments: {unknown}",
            stacklevel=2,
        )

    # parse_dict yields one instance per dataclass given to the parser; there
    # is exactly one here.
    preprocess_args, = parser.parse_dict(vars(args))
    return preprocess_args
