import os, re, argparse, numpy as np
from itertools import chain

import torch
from kd_trainer import DistillationTrainer, DistillationTrainingArguments
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    default_data_collator
)
from modeling.utils import AdaptiveProjector
from datasets import load_metric
from peft import LoraConfig, get_peft_model
import time
from utils import *
import copy
from modeling.modeling_qwen3_kv import Qwen3ForCausalLM as ModelForCausalLM
from modeling.configuration_qwen3_kv import Qwen3Config as Config
    
# ------------------------------ main ------------------------------------- #
def main() -> None:
    """Train the adaptive projectors of a base causal LM on the selected dataset.

    Steps, in order:
      1. Parse the command line and resolve the dataset class (fails fast on
         an unsupported ``--dataset_name``).
      2. Load tokenizer, config and model from ``--base_model``.
      3. Pass the model through ``reduced_svd`` with ``--ratio`` and store the
         returned ``low_dim_list`` on ``model.config``.
      4. Initialise the ``AdaptiveProjector`` modules and freeze the rest.
      5. Train with ``DistillationTrainer`` and save model + tokenizer to
         ``--output_dir``.

    Raises:
        ValueError: if ``--dataset_name`` contains neither ``"xsum"`` nor
            ``"piqa"``.
    """
    args = _build_arg_parser().parse_args()

    # Resolve the dataset class before any expensive work. Previously an
    # unsupported name only surfaced as an UnboundLocalError after the model
    # had been loaded and processed.
    dataset_cls = _resolve_dataset_class(args.dataset_name)

    os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=args.base_model,
        trust_remote_code=True,
    )
    config = Config.from_pretrained(
        pretrained_model_name_or_path=args.base_model,
        trust_remote_code=True,
    )
    # No teacher is loaded here; the trainer receives ``None``.
    teacher_model = None

    model = ModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=args.base_model,
        config=config,
        torch_dtype="bfloat16",
        trust_remote_code=True,
    )
    model.cuda()

    model, low_dim_list = reduced_svd(model, args.ratio, r_type="none")
    # Stored on the config so it is written out by ``save_pretrained`` below.
    model.config.low_dim_list = low_dim_list

    _freeze_non_projector_params(model)
    model = model.to(torch.bfloat16)
    print_trainable_parameters(model)

    train_ds = dataset_cls("train", tokenizer, mode="train")
    val_ds = dataset_cls("validation", tokenizer, mode="eval")

    targs = DistillationTrainingArguments(
        output_dir=args.output_dir,
        deepspeed=args.ds_config,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        save_strategy="no",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        bf16=True,
        fp16=False,
        remove_unused_columns=False,
        logging_steps=50,
        report_to=["none"],
        alpha=args.alpha,
        beta=args.beta,
    )

    trainer = DistillationTrainer(
        teacher_model=teacher_model,
        model=model,
        args=targs,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=default_data_collator,
        tokenizer=tokenizer,
        dataset_text_field="text",
        preprocess_logits_for_metrics=None,
        packing=False,
        formatting_func=None,
    )

    trainer.train()
    # ``save_strategy="no"`` disables trainer checkpoints, so this is the only
    # place the trained weights are written.
    # NOTE(review): every process that runs this script executes these saves;
    # confirm this is safe for the DeepSpeed configuration in use.
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the training script.

    NOTE(review): ``--max_length``, ``--streaming``, ``--local_rank``,
    ``--eval_max_new_tokens``, ``--eval_num_beams`` and ``--save_predictions``
    are accepted but not read by ``main``; they are kept so existing launch
    commands keep working.
    """
    pa = argparse.ArgumentParser()
    pa.add_argument("--base_model", default="Qwen/Qwen3-4B")
    pa.add_argument("--dataset_name", default="xsum")
    pa.add_argument("--output_dir", default="lora_out")
    pa.add_argument("--ds_config", default="./config/ds_config_lora.json", help="Path to DeepSpeed json.")
    pa.add_argument("--per_device_train_batch_size", type=int, default=4)
    pa.add_argument("--per_device_eval_batch_size", type=int, default=4)
    pa.add_argument("--gradient_accumulation_steps", type=int, default=1)
    pa.add_argument("--num_train_epochs", type=int, default=1)
    pa.add_argument("--learning_rate", type=float, default=1e-4)
    pa.add_argument("--max_length", type=int, default=512)
    pa.add_argument("--streaming", action="store_true")
    pa.add_argument("--local_rank", type=int, default=-1)
    pa.add_argument("--alpha", type=float, default=1.0)
    pa.add_argument("--beta", type=float, default=0.0)
    pa.add_argument("--ratio", type=float, default=0.95)
    pa.add_argument("--eval_max_new_tokens", type=int, default=64)
    pa.add_argument("--eval_num_beams", type=int, default=4)
    pa.add_argument("--save_predictions", type=str, default="results/")
    return pa


def _resolve_dataset_class(dataset_name: str):
    """Return the dataset class whose key is a substring of *dataset_name*.

    "xsum" is checked before "piqa", matching the original branch order. The
    dataset modules are imported lazily so only the selected one is loaded.

    Raises:
        ValueError: if *dataset_name* contains neither supported key.
    """
    if "xsum" in dataset_name:
        from data.xsum import XsumDataset
        return XsumDataset
    if "piqa" in dataset_name:
        from data.piqa import PiqaDataset
        return PiqaDataset
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}: "
        "expected a name containing 'xsum' or 'piqa'."
    )


def _freeze_non_projector_params(model) -> None:
    """Initialise every AdaptiveProjector and freeze modules not named "adapt"."""
    for name, module in model.named_modules():
        if isinstance(module, AdaptiveProjector):
            module.weight_init()
            module.requires_grad_(True)
        # Independent of the check above: any module whose qualified name
        # lacks "adapt" is frozen, including an AdaptiveProjector that was
        # just enabled.
        # NOTE(review): this relies on projector attribute names containing
        # "adapt" — confirm against the model definition.
        if "adapt" not in name:
            module.requires_grad_(False)




# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()