"""
# Full training
python trl/scripts/sft.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/Capybara \
    --learning_rate 2.0e-5 \
    --num_train_epochs 1 \
    --packing \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --logging_steps 25 \
    --eval_strategy steps \
    --eval_steps 100 \
    --output_dir Qwen2-0.5B-SFT \
    --push_to_hub

# LoRA
python sft_train_first.py \
    --model_name_or_path cognitivecomputations/dolphin-2.1-mistral-7b \
    --dataset_name argilla/ultrafeedback-binarized-preferences-cleaned \
    --learning_rate 2.0e-4 \
    --num_train_epochs 1 \
    --packing \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --logging_steps 25 \
    --eval_strategy steps \
    --eval_steps 100 \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 16 \
    --output_dir SFT_dolphin_ultra

accelerate launch sft_train_first.py \
    --model_name_or_path cognitivecomputations/dolphin-2.1-mistral-7b \
    --dataset_name argilla/ultrafeedback-binarized-preferences-cleaned \
    --learning_rate 2.0e-4 \
    --num_train_epochs 1 \
    --packing \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --gradient_checkpointing \
    --logging_steps 25 \
    --eval_strategy steps \
    --eval_steps 100 \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 16 \
    --output_dir SFT_dolphin_ultra

deepspeed sft_train_first.py     --model_name_or_path cognitivecomputations/dolphin-2.1-mistral-7b   --dataset_name /home/ubuntu/arizonafiles/cvxdpo/ultra_dataset_full.json  --learning_rate 2.0e-4     --num_train_epochs 1     --packing     --per_device_train_batch_size 2     --gradient_accumulation_steps 8     --gradient_checkpointing     --logging_steps 25     --eval_strategy steps     --eval_steps 100     --use_peft     --lora_r 32     --lora_alpha 16     --output_dir SFT_dolphin_ultra     --deepspeed ds_config_zero3.json --weight_decay 0.01 --bf16 True --warmup_steps 1000


"""
import argparse
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, LlamaTokenizer
from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_peft_config,
    get_quantization_config,
)
from transformers import AutoTokenizer


# Use Flash Attention 2 only when a CUDA GPU of compute capability >= 8
# (Ampere or newer) is present. The availability check comes first because
# torch.cuda.get_device_capability() raises on hosts without CUDA, which would
# otherwise make this module impossible to import (e.g. just to build the
# argument parser or print --help).
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    attn_implementation = "flash_attention_2"
else:
    attn_implementation = "eager"


def _format_sft_data(batch):
    """Collapse each ``prompt``/``chosen`` pair of a batch into one training string.

    Args:
        batch: Column-oriented batch (as supplied by ``Dataset.map(..., batched=True)``)
            holding parallel ``"prompt"`` and ``"chosen"`` sequences.

    Returns:
        A dict ``{"text": [...]}`` with one string per example: the prompt and the
        chosen response joined by a single newline.
    """
    texts = []
    for prompt, chosen in zip(batch["prompt"], batch["chosen"]):
        # A list-valued prompt is space-joined, so its items must already be strings.
        prompt_text = " ".join(prompt) if isinstance(prompt, list) else str(prompt)
        if isinstance(chosen, list):
            # List items may be dicts carrying a "text" field; anything else is stringified.
            chosen_text = " ".join(
                item["text"] if isinstance(item, dict) and "text" in item else str(item)
                for item in chosen
            )
        else:
            chosen_text = str(chosen)
        texts.append(prompt_text + "\n" + chosen_text)

    return {"text": texts}


def main(script_args, training_args, model_args):
    """Fine-tune a causal LM with ``SFTTrainer`` on a JSON prompt/chosen dataset.

    The model is loaded in bf16 onto CUDA, the JSON data is split 90/10 into
    train/eval sets, every example is flattened to a ``"text"`` column, and the
    trained model is saved to ``training_args.output_dir``.

    Args:
        script_args: Parsed ``ScriptArguments``; ``dataset_name`` is passed as
            ``data_files`` to the ``json`` dataset loader.
        training_args: Parsed ``SFTConfig`` handed to the trainer.
        model_args: Parsed ``ModelConfig`` (model path, ``trust_remote_code``,
            PEFT settings, quantization flags).
    """
    import warnings

    ################
    # Model Init & Tokenizer
    ################

    quantization_config = get_quantization_config(model_args)
    if quantization_config is not None:
        # The load below never receives this config, so make the mismatch visible
        # instead of silently training an unquantized model.
        warnings.warn(
            "Quantization was requested via the model arguments, but this script "
            "loads the model in bf16 without applying a quantization config."
        )

    # low_cpu_mem_usage stays False: per the original author's note this is
    # required when training under DeepSpeed ZeRO-3.
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
        trust_remote_code=model_args.trust_remote_code,
        low_cpu_mem_usage=False,
    ).to("cuda")

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'
    tokenizer.truncation_side = 'left'

    ################
    # Dataset
    ################

    dataset = load_dataset("json", data_files=script_args.dataset_name)

    # Split dataset into train (90%) and test (10%); the fixed seed keeps the
    # split reproducible across runs.
    split_dataset = dataset["train"].train_test_split(test_size=0.1, seed=1024)
    train_dataset = split_dataset["train"]
    eval_dataset = split_dataset["test"]

    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Eval dataset size: {len(eval_dataset)}")

    # Add the flattened "text" column to both splits.
    train_dataset = train_dataset.map(_format_sft_data, batched=True, num_proc=4)
    eval_dataset = eval_dataset.map(_format_sft_data, batched=True, num_proc=4)

    # Enable DeepSpeed. A config supplied via --deepspeed is respected; the
    # default ZeRO-3 config file name is only used when none was given
    # (previously the CLI value was unconditionally overwritten).
    # NOTE(review): this default is assigned after the arguments were parsed;
    # confirm the trainer fully picks it up, or pass --deepspeed explicitly.
    if not training_args.deepspeed:
        training_args.deepspeed = "ds_config_zero3.json"

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        # Skip the eval split entirely when evaluation is disabled.
        eval_dataset=eval_dataset if training_args.eval_strategy != "no" else None,
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    trainer.train()

    # Save trained model
    trainer.save_model(training_args.output_dir)


def make_parser(subparsers: argparse._SubParsersAction = None):
    """Build the argument parser for the SFT script.

    Args:
        subparsers: Optional sub-parser action of a parent CLI. When given, an
            ``"sft"`` sub-command is registered on it; otherwise a standalone
            ``TrlParser`` is created.

    Returns:
        The parser covering ``ScriptArguments``, ``SFTConfig`` and ``ModelConfig``.
    """
    # The three dataclasses whose fields become command-line options.
    config_classes = (ScriptArguments, SFTConfig, ModelConfig)

    if subparsers is None:
        return TrlParser(config_classes)

    # NOTE(review): presumably the parent's parser class must accept the
    # `dataclass_types` keyword (e.g. TrlParser) — confirm against the caller.
    return subparsers.add_parser(
        "sft",
        help="Run the SFT training script",
        dataclass_types=config_classes,
    )


if __name__ == "__main__":
    # Script entry point: parse the three argument groups (command line and/or
    # config), then hand them to the training routine.
    script_args, training_args, model_args = make_parser().parse_args_and_config()
    main(script_args, training_args, model_args)
