import os, re, argparse, numpy as np
from itertools import chain

import torch
from reg_trainer import RegularizedTrainer, RegularizedTrainingArguments
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    default_data_collator
)

from modeling.utils import AdaptiveProjector
from datasets import load_metric
from peft import LoraConfig, get_peft_model
import time
from utils import *
from modeling.modeling_qwen3_kv import Qwen3ForCausalLM as ModelForCausalLM
from modeling.configuration_qwen3_kv import Qwen3Config as Config

# ------------------------------ main ------------------------------------- #
def main():
    """Train the adaptive projectors of the causal LM and save the result.

    Steps, in order: parse the command line, resolve the dataset class,
    create the output directory, load tokenizer and model from
    ``--base_model``, freeze everything except the modules selected by
    ``_freeze_all_but_projectors``, train with ``RegularizedTrainer`` and
    write the model and tokenizer to ``--output_dir``.

    Raises:
        ValueError: if ``--dataset_name`` contains neither ``"xsum"`` nor
            ``"piqa"``.
    """
    args = _build_arg_parser().parse_args()

    # Resolve the dataset before any expensive work so an unsupported name
    # fails immediately with a clear message, instead of surfacing as an
    # unbound local after the model has already been loaded onto the GPU.
    dataset_cls = _resolve_dataset_cls(args.dataset_name)

    os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=args.base_model,
        trust_remote_code=True,
    )

    model = _load_model(args)
    _freeze_all_but_projectors(model)

    model.cuda()
    print_trainable_parameters(model)

    train_ds = dataset_cls("train", tokenizer, mode="train")
    val_ds = dataset_cls("validation", tokenizer, mode="eval")

    targs = RegularizedTrainingArguments(
        output_dir=args.output_dir,
        deepspeed=args.ds_config,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        # No intermediate checkpoints; the final weights are written
        # explicitly after training (see the end of this function).
        save_strategy="no",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        bf16=True,
        fp16=False,
        remove_unused_columns=False,
        logging_steps=50,
        report_to=["none"],
        alpha=args.alpha,
    )

    trainer = RegularizedTrainer(
        model=model,
        args=targs,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=default_data_collator,
        tokenizer=tokenizer,
        dataset_text_field="text",
        preprocess_logits_for_metrics=None,
        packing=False,
        formatting_func=None,
    )

    trainer.train()

    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)


def _build_arg_parser():
    """Return the command-line parser for the training script."""
    pa = argparse.ArgumentParser()
    pa.add_argument("--base_model", default="Qwen/Qwen3-4B")
    pa.add_argument("--dataset_name", default="xsum")
    pa.add_argument("--output_dir", default="lora_out")
    pa.add_argument("--ds_config", default="./config/ds_config_lora.json", help="Path to DeepSpeed json.")
    pa.add_argument("--per_device_train_batch_size", type=int, default=4)
    pa.add_argument("--per_device_eval_batch_size", type=int, default=1)
    pa.add_argument("--num_train_epochs", type=int, default=1)
    pa.add_argument("--learning_rate", type=float, default=1e-4)
    pa.add_argument("--gradient_accumulation_steps", type=int, default=1)
    pa.add_argument("--max_length", type=int, default=1024)
    pa.add_argument("--streaming", action="store_true")
    pa.add_argument("--num_proj", type=int, default=2, help="The number of projector")
    pa.add_argument("--hidden_dim", type=int, default=256)
    pa.add_argument("--local_rank", type=int, default=-1)
    pa.add_argument("--alpha", type=float, default=1.0)

    # NOTE(review): the options below (and --max_length, --streaming,
    # --local_rank above) are parsed but not read anywhere in main();
    # they are kept so existing launch commands keep working.
    pa.add_argument("--eval_max_new_tokens", type=int, default=64)
    pa.add_argument("--eval_num_beams", type=int, default=4)
    pa.add_argument("--save_predictions", type=str, default="results/")
    return pa


def _resolve_dataset_cls(dataset_name):
    """Return the dataset class whose key is a substring of *dataset_name*.

    The dataset modules are imported lazily so only the selected one is
    loaded.

    Raises:
        ValueError: if *dataset_name* contains neither "xsum" nor "piqa".
    """
    if "xsum" in dataset_name:
        from data.xsum import XsumDataset
        return XsumDataset
    if "piqa" in dataset_name:
        from data.piqa import PiqaDataset
        return PiqaDataset
    raise ValueError(
        f"Unsupported --dataset_name {dataset_name!r}: "
        "expected a name containing 'xsum' or 'piqa'."
    )


def _load_model(args):
    """Load the config from ``args.base_model``, apply overrides, build the model."""
    config = Config.from_pretrained(
        pretrained_model_name_or_path=args.base_model,
        trust_remote_code=True,
    )
    # Truthiness checks: a value of 0 leaves the config's own setting untouched.
    if args.num_proj:
        config.num_proj = args.num_proj
    if args.hidden_dim:
        config.hidden_dim = args.hidden_dim
    return ModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=args.base_model,
        config=config,
        torch_dtype="bfloat16",
        trust_remote_code=True,
        attn_implementation="eager",
    )


def _freeze_all_but_projectors(model):
    """Initialise every ``AdaptiveProjector`` and freeze the other modules.

    A module is frozen when its qualified name does not contain "adapt".
    Because ``requires_grad_`` applies to all parameters beneath a module,
    this relies on ``named_modules()`` visiting parents before children:
    an ancestor is frozen first and the projector is re-enabled afterwards.
    """
    for name, module in model.named_modules():
        if isinstance(module, AdaptiveProjector):
            module.weight_init()
            module.requires_grad_(True)
        # NOTE(review): a projector whose qualified name lacks "adapt" is
        # frozen again right here - confirm the attribute naming in the model.
        if "adapt" not in name:
            module.requires_grad_(False)
    



# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()