import unsloth
from datasets import load_from_disk,load_dataset
from trl import SFTTrainer, SFTConfig
from unsloth import FastLanguageModel
import torch
import pandas as pd
import os
from datasets import Dataset
import argparse
from typing import Optional, Union
def get_args(argv=None):
    """Parse command-line options for the SFT script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests). ``None`` keeps the normal
            command-line behaviour.

    Returns:
        ``argparse.Namespace`` with ``model_dir``, ``model_name`` and
        ``content_field`` attributes.
    """
    parser = argparse.ArgumentParser(description="SFT Script")

    parser.add_argument(
        "--model_dir",
        type=str,
        required=False,
        default="model",
        help="Path to the model directory or model name.",
    )

    parser.add_argument(
        "--model_name",
        type=str,
        required=False,
        default="Qwen/Qwen3-8B",
        help=(
            "Model name; joined onto --model_dir to form the load path and "
            "used as the prefix of the saved fine-tuned model name."
        ),
    )

    parser.add_argument(
        "--content_field",
        type=int,
        required=False,
        default=15,
        help=(
            "Suffix N of the dataset column 'split_answer_N' used as the "
            "assistant response."
        ),
    )

    return parser.parse_args(argv)
def generate_conversation(examples, content_field=15):
    """Build chat-style conversations from a batch of dataset rows.

    Used with ``Dataset.map(..., batched=True)``, so ``examples`` maps each
    column name to the list of that column's values in the current batch.

    Args:
        examples: Batch mapping with an ``"origin_question"`` column and a
            ``f"split_answer_{content_field}"`` column.
        content_field: Suffix selecting which ``split_answer_*`` column
            supplies the assistant reply.

    Returns:
        ``{"conversations": [...]}`` with one two-message list per row:
        the user question followed by the assistant answer.
    """
    questions = examples["origin_question"]
    answer_column = f"split_answer_{content_field}"
    # The answer column is indexed inside the comprehension, so an empty
    # batch never touches it.
    dialogues = [
        [
            {"role": "user", "content": question},
            {"role": "assistant", "content": examples[answer_column][row]},
        ]
        for row, question in enumerate(questions)
    ]
    return {"conversations": dialogues}
def load_sft_data(
    tokenizer,
    content_field: int = 15,
    data_path: str = "gsm8k_batch_inference_v3",
    shuffle_seed: Optional[int] = 3407,
    data_split: str = "train",
) -> Dataset:
    """Load a dataset split and render it into chat-template text for SFT.

    Each row's ``origin_question`` becomes the user turn and its
    ``split_answer_{content_field}`` column becomes the assistant turn
    (via ``generate_conversation``). The conversations are then rendered
    to plain strings with the tokenizer's chat template.

    Args:
        tokenizer: Tokenizer exposing ``apply_chat_template``.
        content_field: Suffix selecting the ``split_answer_*`` column.
        data_path: Path or dataset name passed to ``load_dataset``.
        shuffle_seed: Seed for shuffling the result; ``None`` disables
            shuffling.
        data_split: Split to take from the object returned by
            ``load_dataset``.

    Returns:
        A ``Dataset`` with a single ``"text"`` column.

    Raises:
        ValueError: If the selected split contains no rows.
    """
    ds = load_dataset(data_path)[data_split]
    print(f"finish load data len:{len(ds)}")
    if len(ds) == 0:
        # Fail early with a clear message; the example print at the end
        # indexes row 0 and would otherwise crash with a bare IndexError.
        raise ValueError(
            f"Split {data_split!r} of {data_path!r} is empty; nothing to train on."
        )

    data = ds.map(
        generate_conversation,
        batched=True,
        batch_size=100,
        fn_kwargs={"content_field": content_field},
    )

    # tokenize=False makes the chat template return rendered strings rather
    # than token ids; they become the "text" column below.
    rendered_texts = tokenizer.apply_chat_template(
        data["conversations"],
        tokenize=False,
    )

    combined_dataset = Dataset.from_dict({"text": rendered_texts})

    if shuffle_seed is not None:
        combined_dataset = combined_dataset.shuffle(seed=shuffle_seed)
    print(f"finish load sft data len:{len(combined_dataset)}")
    print(f"sft data example:{combined_dataset[0]}")
    return combined_dataset

def main():
    """Fully fine-tune the model on the SFT data, save it locally, and push a merged copy.

    Steps:
      1. Parse CLI args and load the model/tokenizer from
         ``{model_dir}/{model_name}`` with full fine-tuning enabled.
      2. Build the training set with ``load_sft_data``.
      3. Train with ``SFTTrainer`` on the ``"text"`` column.
      4. Save model and tokenizer to ``{model_name}_SFT_{content_field}``.
      5. Push a merged copy to ``{model_name}_SFT_{content_field}_merged``.

    The Hugging Face token is read from the ``HF_TOKEN`` environment
    variable. When it is unset, the empty string is used, as before.
    """
    args = get_args()
    print(args)
    model_path = f"{args.model_dir}/{args.model_name}"
    # Avoid hard-coding credentials; default keeps the previous "" value.
    hf_token = os.environ.get("HF_TOKEN", "")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=8192,     # Context length - can be longer, but uses more memory
        load_in_4bit=False,      # 4bit uses much less memory
        load_in_8bit=False,      # A bit more accurate, uses 2x memory
        full_finetuning=True,    # Full fine-tuning (no LoRA adapters)
        token=hf_token,          # Needed for gated models
    )
    train_dataset = load_sft_data(tokenizer, content_field=args.content_field)
    # For LoRA instead of full fine-tuning, wrap the model with
    # FastLanguageModel.get_peft_model(...) here and set full_finetuning=False.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=None,  # Can set up evaluation!
        args=SFTConfig(
            dataset_text_field="text",
            per_device_train_batch_size=8,
            warmup_steps=5,
            num_train_epochs=1,  # One full pass over the data.
            learning_rate=2e-5,
            logging_steps=1,
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=3407,
            report_to="wandb",
        ),
    )

    print("start training")
    trainer_stats = trainer.train()
    print(trainer_stats)

    sft_model_name = f"{args.model_name}_SFT_{args.content_field}"
    print(sft_model_name)
    model.save_pretrained(sft_model_name)  # Local saving
    tokenizer.save_pretrained(sft_model_name)
    print("=" * 20 + "finish save_pretrained" + "=" * 20)
    # The merged upload uses the local save name with a "_merged" suffix.
    merged_model_name = f"{sft_model_name}_merged"
    model.push_to_hub_merged(merged_model_name, tokenizer, token=hf_token)
    print("=" * 20 + "finish push_to_hub_merged" + "=" * 20)


if __name__ == '__main__':
    main()