
"""The main training script to supervised finetune a model using Hugging Face Transformers Trainer."""

import argparse
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union

import transformers
from transformers.training_args import OptimizerNames

from safe_rlhf.datasets import SupervisedDataset, parse_dataset
from safe_rlhf.models import load_pretrained_models


@dataclass
class ModelArguments:
    """Arguments for models.

    One of the dataclasses handed to ``transformers.HfArgumentParser`` in
    ``parse_arguments``.
    """

    # Declared without a default. ``main`` forwards it as the first positional
    # argument of ``load_pretrained_models``.
    model_name_or_path: str


@dataclass
class DataArguments:
    """Arguments for datasets.

    One of the dataclasses handed to ``transformers.HfArgumentParser`` in
    ``parse_arguments``.
    """

    # The list element "type" is the ``parse_dataset`` callable imported from
    # ``safe_rlhf.datasets`` rather than an ordinary class.
    # NOTE(review): presumably the argument parser applies it to every value given
    # on the command line -- confirm against HfArgumentParser's handling of List.
    # ``main`` passes this attribute unchanged to ``SupervisedDataset``; because the
    # default is ``None``, nothing in this file enforces that a dataset was given.
    datasets: List[parse_dataset] = field(
        default=None,
        metadata={'help': 'Path to the training data.'},
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Training-loop options layered on top of ``transformers.TrainingArguments``."""

    # Forwarded by ``main`` to ``load_pretrained_models`` as ``cache_dir``.
    cache_dir: Optional[str] = None

    # Optimizer selection; defaults to ``OptimizerNames.ADAMW_TORCH``.
    optim: Union[OptimizerNames, str] = field(
        metadata={'help': 'The optimizer to use.'},
        default=OptimizerNames.ADAMW_TORCH,
    )

    # Forwarded by ``main`` to ``load_pretrained_models`` as ``model_max_length``.
    model_max_length: int = field(
        metadata={
            'help': 'Maximum sequence length. Sequences will be right padded (and possibly truncated).',
        },
        default=512,
    )


def parse_arguments() -> Tuple['TrainingArguments', 'ModelArguments', 'DataArguments']:
    """Parse the command-line arguments.

    Returns:
        A ``(training_args, model_args, data_args)`` tuple. Each element is an
        instance of the corresponding dataclass, in the same order the dataclass
        types are registered with the parser below (not ``argparse.Namespace``
        objects).
    """
    parser = transformers.HfArgumentParser([TrainingArguments, ModelArguments, DataArguments])
    # pylint: disable-next=unbalanced-tuple-unpacking
    training_args, model_args, data_args = parser.parse_args_into_dataclasses()
    return training_args, model_args, data_args


def main() -> None:
    """Run the supervised fine-tuning pipeline from start to finish.

    Steps, in order: parse the command-line arguments, load the model and
    tokenizer, build the training dataset and its collator, run the
    ``transformers.Trainer``, then save the trainer state and the model.
    """
    # pylint: disable=no-member
    training_args, model_args, data_args = parse_arguments()

    # Keyword options for the loader, gathered in one place for readability.
    loader_options = {
        'model_max_length': training_args.model_max_length,
        'padding_side': 'right',
        'cache_dir': training_args.cache_dir,
        'trust_remote_code': True,
    }
    model, tokenizer = load_pretrained_models(
        model_args.model_name_or_path,
        **loader_options,
    )

    sft_dataset = SupervisedDataset(
        data_args.datasets,
        tokenizer=tokenizer,
        seed=training_args.seed,
    )

    # The dataset supplies its own collator for batching.
    trainer = transformers.Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=sft_dataset,
        data_collator=sft_dataset.get_collator(),
    )

    trainer.train()
    trainer.save_state()
    trainer.save_model()


# Start training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
