import math
import torch
import random
def kl_loss(prob_p, prob_q, eps=1e-12):
    """Return the mean KL divergence KL(p || q) between two probability tensors.

    The divergence is summed over the last dimension (the distribution axis)
    and averaged over every leading dimension.

    Args:
        prob_p: probabilities of p (e.g. a softmax output).
        prob_q: probabilities of q; must broadcast against ``prob_p``.
        eps: small constant added inside both logs so ``log(0)`` stays finite.

    Returns:
        0-dim tensor, equal to 0 when ``prob_p == prob_q``.

    Note: this subtracts the entropy of p from the cross-entropy
    ``-sum p * log q``. Without that term, minimizing w.r.t. p would collapse
    p onto argmax q instead of pulling p toward q.
    """
    log_ratio = torch.log(prob_p + eps) - torch.log(prob_q + eps)
    return (prob_p * log_ratio).sum(-1).mean()
def _tokenize_chunk(tokenizer, text, max_length, device):
    """Tokenize ``text`` and return ``(input_ids, attention_mask)`` on ``device``.

    Sequences are truncated to at most ``max_length`` tokens and padded to the
    longest sequence in the batch (``padding=True``).
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding=True,
    ).to(device)
    return encoded["input_ids"], encoded["attention_mask"]


def Gradient_ascent_KL(model, infer_model,tokenizer, forget_data, retain_set, optimizer, scheduler, epochs, chunk_size, device):
    """Unlearn ``forget_data`` by gradient ascent, anchored on retain text by a KL term.

    For each chunk the minimized objective is
    ``-model_loss(forget_chunk) + kl_loss(p_model, p_ref)``. The KL term is
    evaluated on the retain chunk at the same slice indices, and ``p_ref``
    comes from ``infer_model`` run under ``torch.no_grad()``. One optimizer
    step and one scheduler step are taken per chunk, on that chunk's
    gradients only.

    Args:
        model: trainable model. It is called with ``input_ids``,
            ``attention_mask`` (and ``labels`` for the forget chunk) and must
            expose ``.loss`` / ``.logits`` on its output.
        infer_model: reference model that supplies the retain-set target
            distribution. It is switched to eval mode.
        tokenizer: HF-style tokenizer whose output supports ``.to(device)``
            and yields ``input_ids`` / ``attention_mask``.
        forget_data: sliceable text data. It is cut into slices of
            ``chunk_size`` elements (presumably characters of one string;
            confirm with the caller).
        retain_set: sequence of retain strings. A shuffled copy is joined with
            spaces and sliced with the same indices as ``forget_data``.
            The caller's sequence is not modified.
        optimizer: stepped once per chunk.
        scheduler: stepped once per chunk.
        epochs: number of passes over ``forget_data``.
        chunk_size: slice length, also used as the tokenizer ``max_length``.
        device: device the tokenized tensors are moved to.
    """
    num_chunks_forget = math.ceil(len(forget_data) / chunk_size)

    # Shuffle a copy so the caller's retain_set is left untouched.
    shuffled_retain = list(retain_set)
    random.shuffle(shuffled_retain)
    retain_data = " ".join(shuffled_retain)

    # The reference model only provides targets; disable dropout etc.
    infer_model.eval()

    for epoch in range(epochs):
        print(f"\n[FineTuning] Epoch {epoch + 1} ...")
        model.train()
        total_loss = 0.0

        for chunk_idx in range(num_chunks_forget):
            start_idx = chunk_idx * chunk_size
            end_idx = start_idx + chunk_size

            # Forget term: negate the LM loss so minimizing it performs gradient ascent.
            forget_ids, forget_mask = _tokenize_chunk(
                tokenizer, forget_data[start_idx:end_idx], chunk_size, device
            )
            forget_labels = forget_ids.masked_fill(forget_mask == 0, -100)
            forget_outputs = model(
                input_ids=forget_ids,
                attention_mask=forget_mask,
                labels=forget_labels,
            )
            loss_forget = -forget_outputs.loss

            # Retain term: keep the model's distribution close to the reference model's.
            retain_chunk_text = retain_data[start_idx:end_idx]
            if retain_chunk_text:
                retain_ids, retain_mask = _tokenize_chunk(
                    tokenizer, retain_chunk_text, chunk_size, device
                )
                # Only logits are needed here, so no labels (avoids computing an unused loss).
                retain_logits = model(
                    input_ids=retain_ids, attention_mask=retain_mask
                ).logits
                with torch.no_grad():
                    ref_logits = infer_model(
                        input_ids=retain_ids, attention_mask=retain_mask
                    ).logits
                loss_retain = kl_loss(
                    torch.softmax(retain_logits, dim=-1),
                    torch.softmax(ref_logits, dim=-1),
                )
            else:
                # Retain text is shorter than forget_data: nothing left to anchor on.
                loss_retain = torch.zeros_like(loss_forget)

            loss = loss_forget + loss_retain

            # Clear gradients every step; otherwise earlier chunks' gradients
            # would be re-applied at every subsequent optimizer.step().
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            total_loss += loss_retain.item()
            total_loss += loss_forget.item()

        # One combined (forget + retain) loss per chunk, so average over the chunk count.
        avg_loss = total_loss / max(num_chunks_forget, 1)
        print(f"[Epoch {epoch + 1}] avg_loss={avg_loss:.4f}")
