# -*- coding: utf-8 -*-
import os
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    AutoConfig,
)
import torch
from transformers import TrainerCallback
from trl import SFTConfig, SFTTrainer
from meft import MeftTrainer,MeftConfig
import argparse
import numpy as np

import random

from trl import SFTTrainer
from meft.optimizer.lowrank_adamw import AdamW
from distutils.util import strtobool
class lowrankTrainer(SFTTrainer):
    """SFTTrainer that builds MEFT's low-rank ``AdamW`` instead of the default optimizer.

    Args:
        *args: Positional arguments forwarded unchanged to ``SFTTrainer``
            (the model comes first there).
        alpha: Forwarded unchanged to ``meft.optimizer.lowrank_adamw.AdamW``;
            its meaning is defined by that optimizer. Keyword-only so it can
            never swallow the trainer's first positional argument.
        **kwargs: Keyword arguments forwarded unchanged to ``SFTTrainer``.
    """

    def __init__(self, *args, alpha=1.0, **kwargs):
        # Store alpha before the parent constructor runs, so it is already
        # available if anything during construction builds the optimizer.
        self.alpha = alpha
        super().__init__(*args, **kwargs)

    def create_optimizer(self):
        """Create the low-rank AdamW on first call and return the cached instance afterwards.

        Only ``self.args.learning_rate`` and ``self.alpha`` are passed to the
        optimizer; other optimizer-related fields of the training arguments
        (e.g. weight decay) are not forwarded here.
        """
        if self.optimizer is None:
            self.optimizer = AdamW(
                self.model.parameters(),
                lr=self.args.learning_rate,
                alpha=self.alpha,
            )
        return self.optimizer
    
SEED = 42
# Seed the Python, NumPy and torch (CPU + every CUDA device) generators with
# the same value so repeated runs start from identical RNG state.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Hub id of the base model. local_files_only=True means the tokenizer files
# must already be present in the local Hugging Face cache.
model_path = r"Qwen/Qwen3-4B"
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    padding_side="left",
    local_files_only=True,
    trust_remote_code=True,
)
# Fall back to the EOS token when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token



def _str_to_bool(value):
    """Parse a command-line truth value with the same rules as ``distutils.util.strtobool``.

    Accepts y/yes/t/true/on/1 as True and n/no/f/false/off/0 as False
    (case-insensitive). Implemented locally because ``distutils`` was removed
    from the standard library in Python 3.12 (PEP 632).

    Raises:
        argparse.ArgumentTypeError: for any other string, so argparse reports
            a readable usage error naming the offending value.
    """
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(f"invalid truth value {value!r}")


parser = argparse.ArgumentParser(description="SFT fine-tuning with MEFT low-rank compression.")
# "--patch_locations" is accepted as an alias so the flag matches the
# underscore spelling of every other option; dest stays "patch_locations"
# (derived from the first long option), so args.patch_locations is unchanged.
parser.add_argument(
    "--patch-locations",
    "--patch_locations",
    type=str,
    nargs="+",
    default=None,
    help="Patch locations passed to MeftConfig (one or more).",
)
parser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="Directory for checkpoints, logs and args.txt.",
)
parser.add_argument(
    "--rank",
    type=int,
    default=32,
    help="Rank used by MEFT compression.",
)
parser.add_argument(
    "--niter",
    type=int,
    default=1,
    help="Iteration count forwarded to MEFT compression.",
)
parser.add_argument(
    "--num_train_epochs",
    type=int,
    default=10,
)
parser.add_argument(
    "--per_device_train_batch_size",
    type=int,
    default=8,
)
parser.add_argument(
    "--learning_rate",
    type=float,
    default=2e-5,
)
parser.add_argument(
    "--use_optimizer_compress",
    type=_str_to_bool,
    default=False,
    help="Boolean (true/false, yes/no, 1/0): enable optimizer-state compression.",
)
parser.add_argument(
    "--use_gradient_compress",
    type=_str_to_bool,
    default=False,
    help="Boolean (true/false, yes/no, 1/0): enable gradient compression.",
)
parser.add_argument(
    "--compressed_method",
    type=str,
    default="rqb",
    help="Compression method name forwarded to MEFT.",
)
parser.add_argument(
    "--alpha",
    type=float,
    default=1.0,
    help="alpha argument of the low-rank AdamW optimizer.",
)
args = parser.parse_args()

# Record the exact run configuration next to the outputs for reproducibility.
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "args.txt"), "w", encoding="utf-8") as f:
    for k, v in vars(args).items():
        f.write(f"{k}: {v}\n")


# Load the base model in bfloat16; device_map="auto" lets the loader place the
# weights on the available devices. local_files_only=True matches the
# tokenizer load, so the weights are also read from the local cache only.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    local_files_only=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# The KV cache is only needed for generation; disable it while training.
model.config.use_cache = False
# NOTE(review): Trainer.train() calls gradient_checkpointing_enable() when the
# training config has gradient_checkpointing=True, which would override this.
model.gradient_checkpointing_disable()
# NOTE(review): placeholder paths - point these at the real parquet file of each split.
_DATA_FILES = {
    "train": "/path/to/train-00000-of-00001.parquet",
    "test": "/path/to/test-00000-of-00001.parquet",
}
dataset = load_dataset("parquet", data_files=_DATA_FILES)
# NOTE(review): both the prompt and the target template spell "boxed{...}"
# without a leading backslash; the raw-string prefix suggests "\boxed" may have
# been intended. Confirm before changing, since this text defines the training target.
SYSTEM_PROMPT = r"You are a helpful assistant. You need to solve some math problems and present the answers enclosed in boxed{}."
# Assistant target: the reasoning, a blank line, then the answer inside boxed{...}.
# The doubled braces are literal braces under str.format.
XML_COT_FORMAT = """
{think}

boxed{{{answer}}}

"""


def extract_cot(text: str) -> str:
    """Convert a '<reasoning> #### <answer>' solution string into the assistant target.

    The text after the LAST '####' is taken as the final answer and everything
    before it as the reasoning. A stray '####' inside the reasoning therefore
    stays part of the reasoning instead of truncating the example.

    Args:
        text: Solution string expected to contain a '####' answer marker.

    Returns:
        ``XML_COT_FORMAT`` filled with the stripped reasoning and answer, or
        ``""`` when ``text`` contains no '####' marker.
    """
    think, marker, answer = text.rpartition("####")
    if not marker:
        return ""
    return XML_COT_FORMAT.format(think=think.strip(), answer=answer.strip())

def _to_messages(example):
    """Build the chat-format ``messages`` column for one dataset row.

    The conversation has three turns: the fixed system prompt, the row's
    ``question`` as the user turn, and ``extract_cot`` of the row's
    ``answer`` as the assistant turn.
    """
    system_turn = {'role': 'system', 'content': SYSTEM_PROMPT}
    user_turn = {'role': 'user', 'content': example['question']}
    assistant_turn = {'role': 'assistant', 'content': extract_cot(example['answer'])}
    return {'messages': [system_turn, user_turn, assistant_turn]}


dataset = dataset.map(_to_messages)

patch_locations = args.patch_locations
# Compression settings forwarded to MEFT; every value comes from the CLI
# except compress_workers, which is fixed at 2.
compress_kwargs = {
    "rank": args.rank,
    "niter": args.niter,
    "use_optimizer_compress": args.use_optimizer_compress,
    "method": args.compressed_method,
    "use_gradient_compress": args.use_gradient_compress,
}
meftconfig = MeftConfig(
    patch_locations=patch_locations,
    compress_kwargs=compress_kwargs,
    compress_workers=2,
)

class SaveWeightOnlyCheckpointCallback(TrainerCallback):
    """After each epoch, write ``model.save_pretrained`` output to ``<output_dir>/checkpoint-<global_step>``.

    Only the model is saved. Optimizer, scheduler and trainer state are not
    written.
    """

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        # Every process receives this event in distributed runs; only the
        # global main process writes, so ranks do not race on the same files.
        if not state.is_world_process_zero:
            return
        output_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        model.save_pretrained(output_dir)

sft_config = SFTConfig(
    output_dir=args.output_dir,
    num_train_epochs=args.num_train_epochs,
    per_device_train_batch_size=args.per_device_train_batch_size,
    per_device_eval_batch_size=16,
    eval_strategy="epoch",
    # Built-in checkpointing is off; SaveWeightOnlyCheckpointCallback (passed to
    # the trainer) writes per-epoch weights instead.
    save_strategy="no",
    learning_rate=args.learning_rate,
    lr_scheduler_type="linear",
    warmup_ratio=0.0,
    logging_steps=10,
    # Check is_available() first: some older torch releases raise from
    # is_bf16_supported() on CPU-only machines instead of returning False.
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    # Pin explicitly: Trainer.train() enables gradient checkpointing when this
    # is True (the SFTConfig default depends on the TRL version), which would
    # undo the model.gradient_checkpointing_disable() call made at load time.
    gradient_checkpointing=False,
    max_length=1024,
    packing=False,
    dataset_text_field=None,
    # NOTE(review): presumably the chat template already inserts the special
    # tokens, so tokenization must not add them again - confirm.
    dataset_kwargs={"add_special_tokens": False},
    report_to="tensorboard",
)

# MeftTrainer is parameterized with lowrankTrainer (SFTTrainer + low-rank AdamW);
# meft_config carries the patch locations and compression settings.
trainer = MeftTrainer[lowrankTrainer](
    model=model,
    args=sft_config,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    # Use the tokenizer configured at load time (local files only, pad-token
    # fallback). Without it, SFTTrainer creates its own processor from the
    # model's name/path and the configuration above is ignored.
    processing_class=tokenizer,
    callbacks=[SaveWeightOnlyCheckpointCallback()],
    meft_config=meftconfig,
    alpha=args.alpha,
)
trainer.train()