# relearn_attack.py
import os
import argparse
import random
import torch
from torch.utils.data import DataLoader, Subset, SequentialSampler, DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from torch.nn.utils.rnn import pad_sequence
from models.transformer_model import TransformerModel
from utils.data_loader_wmdp import get_train_data_ddp, LanguageModelingDataset
from utils.lora import apply_lora
# Rendezvous defaults for a single-process torch.distributed setup (presumably
# consumed by get_train_data_ddp / DistributedSampler); setdefault leaves any
# value already exported in the environment untouched.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
# Hugging Face access token. Read from the HF_TOKEN environment variable so a
# real credential never has to be committed to source; the original
# placeholder is kept as the fallback for backward compatibility.
access_token = os.environ.get("HF_TOKEN", "your_huggingface_access_token")

def parse_args(argv=None):
    """Parse the command-line options of the relearning attack.

    Args:
        argv: Optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with model_name, save_path, dataset, batch_size,
        lr, num_epochs, seed, num_samples (None = whole forget corpus) and
        device (None = cuda:0 if available, else cpu).

    Raises:
        SystemExit: via argparse on missing/invalid options, an unknown
            dataset, a non-positive batch size or a non-positive num_samples.
    """
    parser = argparse.ArgumentParser(description="Relearning Attack Script")
    parser.add_argument("--model_name", type=str, required=True,
                        help="基座模型名称（HuggingFace 上的 ID），例如 mistralai/Mistral-7B-v0.1")
    parser.add_argument("--save_path", type=str, required=True,
                        help="Relearn Attack 完成后新的 LoRA adapter 保存目录")
    # The help texts below always advertised these defaults, but the options
    # were required; the defaults now actually exist (passing them explicitly
    # still works exactly as before).
    parser.add_argument("--dataset", type=str, default="wmdp-cyber",
                        choices=["wmdp-cyber", "wmdp-bio"],
                        help="数据集名称，默认为 wmdp-cyber")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Relearn 微调时的 batch_size（默认为 4）")
    parser.add_argument("--lr", type=float, default=1e-4,
                        help="Relearn 微调时的学习率，默认为 1e-4")
    parser.add_argument("--num_epochs", type=int, default=4,
                        help="Relearn 微调的 epoch 数，默认为 4")
    parser.add_argument("--seed", type=int, required=True,
                        help="随机种子，用于 forget 样本采样（--num_samples）与数据打乱")
    # Optional, new: both default to None so existing command lines keep
    # their behavior (full corpus; device picked automatically).
    parser.add_argument("--num_samples", type=int, default=None,
                        help="随机采样的 forget 样本数量；不指定则使用全部 forget 数据")
    parser.add_argument("--device", type=str, default=None,
                        help="训练设备，例如 cuda:2 或 cpu；不指定时优先使用 cuda:0")
    args = parser.parse_args(argv)

    # DataLoader rejects a non-positive batch size anyway; fail early with a
    # clear CLI error instead of a traceback after the model is loaded.
    if args.batch_size <= 0:
        parser.error("--batch_size must be a positive integer")
    if args.num_samples is not None and args.num_samples <= 0:
        parser.error("--num_samples must be a positive integer")
    return args

# Corpus names handed to get_train_data_ddp for each supported --dataset value,
# as (forget_corpora, retain_corpora).
_DATASET_CORPORA = {
    "wmdp-cyber": (("cyber-forget-corpus",), ("cyber-retain-corpus",)),
    "wmdp-bio": (("bio_forget",), ("bio-retain-corpus",)),
}


def _resolve_device(requested=None):
    """Return the torch device to train on.

    Args:
        requested: Explicit device string such as ``"cuda:2"`` or ``"cpu"``.
            When falsy, the first CUDA device is used if CUDA is available,
            otherwise the CPU.
    """
    if requested:
        return torch.device(requested)
    # Default to cuda:0 instead of a hard-coded higher index, which fails with
    # "invalid device ordinal" on hosts with fewer GPUs. Select another GPU
    # with --device or CUDA_VISIBLE_DEVICES.
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _resolve_pad_token_id(tokenizer):
    """Return the token id used to pad ``input_ids`` within a batch.

    Falls back to ``eos_token_id`` when the tokenizer defines no pad token
    (some causal-LM tokenizers do not). Padded positions are masked out of
    both attention and the loss, so the concrete id does not affect training.

    Raises:
        ValueError: if the tokenizer defines neither a pad nor an EOS token.
    """
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    if pad_id is None:
        raise ValueError("Tokenizer defines neither pad_token_id nor eos_token_id; cannot pad batches.")
    return pad_id


def _make_collate_fn(pad_token_id):
    """Build a DataLoader ``collate_fn`` for the forget-corpus items.

    Each item is expected to be a dict with ``input_ids`` and
    ``attention_mask`` sequences (assumed from LanguageModelingDataset — TODO
    confirm). The returned function right-pads both to the longest item in the
    batch: ``input_ids`` with ``pad_token_id`` and ``attention_mask`` with 0.
    """
    def collate_fn(batch):
        # as_tensor avoids a copy/warning if items are already tensors.
        input_ids = [torch.as_tensor(item["input_ids"], dtype=torch.long) for item in batch]
        attention_mask = [torch.as_tensor(item["attention_mask"], dtype=torch.long) for item in batch]
        return {
            "input_ids": pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id),
            "attention_mask": pad_sequence(attention_mask, batch_first=True, padding_value=0),
        }
    return collate_fn


def _lm_labels(input_ids, attention_mask):
    """Return causal-LM labels: a copy of ``input_ids`` with padding set to -100.

    -100 is the ignore index of the cross-entropy loss computed by Hugging Face
    causal LMs when ``labels`` is passed, so pad tokens do not contribute to
    the loss. The model shifts labels internally, so no manual shift is needed.
    """
    return input_ids.masked_fill(attention_mask == 0, -100)


def _build_relearn_dataset(dataset_name, tokenizer, batch_size, num_samples=None):
    """Load the forget corpus for ``dataset_name`` and optionally subsample it.

    Args:
        dataset_name: Key of ``_DATASET_CORPORA``.
        tokenizer: Tokenizer forwarded to ``get_train_data_ddp``.
        batch_size: Forwarded to ``get_train_data_ddp``; only the loader's
            dataset is used here, batching is redone by the caller.
        num_samples: If given and smaller than the corpus, that many examples
            are drawn without replacement with the ``random`` module (seeded by
            the caller). ``None`` keeps the whole corpus.

    Returns:
        The full forget dataset, or a ``torch.utils.data.Subset`` of it.

    Raises:
        ValueError: for an unknown dataset, an empty corpus, or a
            non-positive ``num_samples``.
    """
    if dataset_name not in _DATASET_CORPORA:
        raise ValueError(f"Unsupported dataset: {dataset_name}")
    if num_samples is not None and num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    forget_corpora, retain_corpora = _DATASET_CORPORA[dataset_name]
    all_forget_loaders, _ = get_train_data_ddp(
        list(forget_corpora), list(retain_corpora), tokenizer,
        batch_size=batch_size,
        sampler_cls=DistributedSampler,  # single process: world_size=1, rank=0
        world_size=1, rank=0,
    )
    # get_train_data_ddp presumably returns one loader per forget corpus; only
    # one corpus is requested, so take the first.
    full_dataset = all_forget_loaders[0].dataset
    total_num = len(full_dataset)
    print(f"Total forget samples in dataset: {total_num}")
    if total_num == 0:
        raise ValueError(f"Forget corpus for {dataset_name} is empty")
    if num_samples is None or num_samples >= total_num:
        return full_dataset
    return Subset(full_dataset, random.sample(range(total_num), num_samples))


def _train(model, loader, optimizer, device, num_epochs):
    """Run ``num_epochs`` epochs of causal-LM fine-tuning, logging the loss.

    The loss is printed every 10 steps and on the last step of each epoch,
    followed by the epoch's mean batch loss.
    """
    num_batches = len(loader)
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for step, batch in enumerate(loader, start=1):
            input_ids = batch["input_ids"].to(device)            # [batch_size, seq_len]
            attention_mask = batch["attention_mask"].to(device)  # [batch_size, seq_len]

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=_lm_labels(input_ids, attention_mask),
            )
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            epoch_loss += loss_value
            if step % 10 == 0 or step == num_batches:
                print(f"[Epoch {epoch+1} Step {step}/{num_batches}] loss = {loss_value:.4f}")

        # max(..., 1) guards the division; the dataset is non-empty by construction.
        avg_loss = epoch_loss / max(num_batches, 1)
        print(f"==> Epoch {epoch+1} done, average loss = {avg_loss:.4f}")


def main():
    """Run the relearning attack: fine-tune LoRA weights on the forget corpus.

    Steps:
    1. Parse the CLI arguments and seed the RNGs.
    2. Load the tokenizer and the base model.
    3. Attach a LoRA adapter with ``apply_lora`` and freeze every non-LoRA
       parameter.
    4. Fine-tune with AdamW on the forget corpus, or on a random sample of it
       when ``--num_samples`` is given.
    5. Save the model and tokenizer to ``--save_path``.
    """
    args = parse_args()
    # Reject an unknown --dataset before spending minutes loading weights.
    if args.dataset not in _DATASET_CORPORA:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    # --device / --num_samples are optional; getattr keeps this working with a
    # parser that does not define them.
    requested_device = getattr(args, "device", None)
    num_samples = getattr(args, "num_samples", None)

    # Fix the seeds for reproducibility: torch drives DataLoader shuffling
    # (and any torch-based init), Python's random drives subsampling.
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = _resolve_device(requested_device)
    print(f"Using device: {device}")
    # NOTE(review): TransformerModel is only used for its tokenizer here; if its
    # constructor also loads model weights, that memory is wasted — confirm.
    t = TransformerModel(args.model_name)
    tokenizer = t.get_tokenizer()
    collate_fn = _make_collate_fn(_resolve_pad_token_id(tokenizer))

    print("Loading base model and LoRA adapter...")
    base_model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        # fp16 only when training on a GPU; half precision on CPU is slow and
        # poorly supported.
        torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
        use_cache=False,  # the KV cache is not needed during training
    )
    # NOTE(review): apply_lora presumably attaches a fresh rank-8 adapter to the
    # checkpoint named by --model_name; no previously trained (e.g. unlearned)
    # adapter is loaded here, so --model_name must already be the model under
    # attack. Also confirm the adapter weights are fp32: AdamW on fp16
    # parameters is numerically fragile.
    model = apply_lora(base_model, r=8, model_name=args.model_name)
    model = model.to(device)
    model.train()

    # Train only LoRA parameters (matched by name); freeze everything else.
    for name, param in model.named_parameters():
        param.requires_grad = "lora" in name.lower()
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    if not trainable_params:
        raise RuntimeError("No parameter name contains 'lora'; nothing to fine-tune.")

    sample_desc = "all" if num_samples is None else num_samples
    print(f"Building forget DataLoader ({sample_desc} examples requested)...")
    relearn_dataset = _build_relearn_dataset(args.dataset, tokenizer, args.batch_size, num_samples)
    relearn_loader = DataLoader(
        relearn_dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        shuffle=True,
        drop_last=False,
    )

    # Only the trainable (LoRA) parameters are handed to the optimizer.
    optimizer = torch.optim.AdamW(trainable_params, lr=args.lr)

    print("Starting fine-tuning ...")
    _train(model, relearn_loader, optimizer, device, args.num_epochs)

    # Save the fine-tuned model (LoRA adapter) and tokenizer.
    print(f"save fine tuning 后的 LoRA 模型到 {args.save_path} ...")
    os.makedirs(args.save_path, exist_ok=True)
    model.save_pretrained(args.save_path)
    tokenizer.save_pretrained(args.save_path)
    print("✅ fine tuning 完成，模型已保存。")

if __name__ == "__main__":
    # Run the relearning attack only when executed as a script, not on import.
    main()
