import math
import torch
import torch.nn as nn
def Random_label(model, tokenizer, forget_data, optimizer, scheduler, epochs, chunk_size, device,
                 max_length=None):
    """Fine-tune ``model`` on ``forget_data`` against uniformly random token labels.

    Each chunk of texts is tokenized, and every non-padding position gets a label
    drawn uniformly from ``[0, tokenizer.vocab_size)``. Padding positions get
    ``-100``, which Hugging Face-style models ignore in their loss (assumption --
    confirm for non-HF models). The optimizer and scheduler are stepped once per
    chunk, and the mean chunk loss of each epoch is printed.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=..., labels=...)``
            that returns an object with a scalar ``.loss`` tensor.
        tokenizer: Callable returning a batch with ``input_ids`` and
            ``attention_mask`` entries that supports ``.to(device)``; must expose
            ``vocab_size``.
        forget_data: Sliceable sequence of texts to train on.
        optimizer: Optimizer over the model's parameters.
        scheduler: Learning-rate scheduler; ``step()`` is called after every
            optimizer step.
        epochs: Number of passes over ``forget_data``.
        chunk_size: Number of texts per optimization step. Also used as the token
            truncation length unless ``max_length`` is given.
        device: Device the tokenized batch and the random labels are placed on.
        max_length: Optional token truncation length. Defaults to ``chunk_size``
            to keep the original behavior.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if max_length is None:
        # Historical behavior: the batch size doubles as the truncation length.
        max_length = chunk_size

    data_length_forget = len(forget_data)
    num_chunks_forget = math.ceil(data_length_forget / chunk_size)
    # Loop-invariant: the label space is the tokenizer's base vocabulary.
    num_classes = tokenizer.vocab_size

    for epoch in range(epochs):
        print(f"\n[FineTuning] Epoch {epoch + 1} ...")
        model.train()  # Set model to training mode
        total_loss = 0.0

        for chunk_idx in range(num_chunks_forget):
            start_idx = chunk_idx * chunk_size
            forget_chunk_text = forget_data[start_idx:start_idx + chunk_size]

            inputs = tokenizer(
                forget_chunk_text,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
                padding=True,
            ).to(device)
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]

            # Uniformly random targets, generated directly on the target device;
            # padding positions are set to -100 so the loss skips them.
            random_labels = torch.randint(0, num_classes, input_ids.shape, device=device)
            random_labels = random_labels.masked_fill(attention_mask == 0, -100)

            loss = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=random_labels,
            ).loss
            total_loss += loss.item()

            # One plain optimization step per chunk (no gradient accumulation).
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

        # Mean over the chunks actually processed; empty data reports 0.0.
        avg_loss = total_loss / num_chunks_forget if num_chunks_forget else 0.0
        print(f"[Epoch {epoch + 1}] avg_loss={avg_loss:.4f}")
