import random
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import copy
import umap
import gc
import hdbscan
from torch import nn as nn
from config_1 import get_arguments
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, accuracy_score
from transformers import LlamaTokenizer, LlamaModel
from torch.utils.data import Dataset
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support
from peft import get_peft_model, LoraConfig, TaskType
import time


def set_seed(random_seed=11):
    """Seed every random number generator this script relies on.

    Python's ``random``, NumPy's global RNG, the PyTorch CPU generator and the
    generators of all CUDA devices are seeded, in that order, with the same value.

    Args:
        random_seed: seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(random_seed)


class TargetDataset(Dataset):
    """Map-style dataset that tokenizes one text sample per ``__getitem__`` call.

    ``data`` must expose "text", "label" and "poisoned" columns (e.g. a pandas
    DataFrame). The columns are copied into plain lists so that item ``i`` is
    always the i-th row, regardless of the DataFrame's index. Indexing a pandas
    Series with an int is label-based and raises KeyError on filtered frames
    whose index was not reset.
    """

    def __init__(self, tokenizer, max_len, data):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.text = list(data["text"])
        self.targets = list(data["label"])
        self.flag = list(data["poisoned"])

    def __len__(self):
        return len(self.text)

    def __getitem__(self, item):
        """Tokenize row ``item`` (padded/truncated to ``max_len``).

        Returns:
            dict with the raw "text", flattened "input_ids" and "attention_mask"
            tensors, the "label" as a long tensor and the raw "poisoned" "flag".
        """
        text = str(self.text[item])
        target = self.targets[item]
        flag = self.flag[item]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )

        return {
            "text": text,
            # encode_plus returns (1, max_len) tensors for a single text; drop the batch dim.
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
            "label": torch.tensor(target, dtype=torch.long),
            "flag": flag,
        }


class LlamaClassification(nn.Module):
    """
    Llama backbone with a linear classification head.

    Pooling options (``pooling`` argument):
    1) "first": hidden state of the first token (<s>)
    2) "last": hidden state of the last non-padding token, located through
       ``attention_mask`` (this assumes right padding)
    """

    def __init__(self, llama_model, label_num=2, pooling="last"):
        super().__init__()
        self.llama = llama_model
        self.hidden_size = self.llama.config.hidden_size
        self.config = self.llama.config
        self.pooling = pooling
        # Create the head on the backbone's reported device and dtype so that
        # pooled bf16 hidden states can be projected without a dtype mismatch.
        self.classifier = nn.Linear(self.hidden_size, label_num).to(
            device=llama_model.device, dtype=llama_model.dtype
        )

    def _pool(self, hidden_states, attention_mask):
        """Reduce (batch, seq, hidden) states to one (batch, hidden) vector per row.

        Raises:
            ValueError: if ``self.pooling`` is neither "first" nor "last".
        """
        if self.pooling == "first":
            return hidden_states[:, 0, :]  # <s>

        if self.pooling == "last":
            if attention_mask is None:
                # No padding information: every position is real, take the final one.
                return hidden_states[:, -1, :]
            # Index of the last attended token per row (right padding assumed).
            # Keep the index tensors on the hidden states' device.
            seq_lengths = (attention_mask.sum(1) - 1).to(hidden_states.device)  # (B,)
            batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
            return hidden_states[batch_idx, seq_lengths]

        raise ValueError(f"Unknown pooling method: {self.pooling}")

    def forward(self,
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                inputs_embeds=None,
                past_key_values=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=False,
                **kwargs):
        """Run the backbone, pool its last hidden state and return (batch, label_num) logits.

        Extra ``kwargs`` are accepted for API compatibility and ignored.
        """
        out = self.llama(input_ids=input_ids,
                         attention_mask=attention_mask,
                         position_ids=position_ids,
                         inputs_embeds=inputs_embeds,
                         past_key_values=past_key_values,
                         use_cache=use_cache,
                         output_attentions=output_attentions,
                         output_hidden_states=output_hidden_states)
        sent_rep = self._pool(out.last_hidden_state, attention_mask)
        # With a backbone sharded over several devices the final hidden state may
        # not live on the head's device; move it before the projection.
        sent_rep = sent_rep.to(self.classifier.weight.device)
        return self.classifier(sent_rep)  # (batch, label_num)


def extract_cls_hidden(llama, input_ids, attn_mask, layer_idx=-1, pooling="last"):
    """Return one pooled vector per sequence from hidden layer ``layer_idx``.

    Callers are expected to wrap this in ``torch.no_grad()``.

    Args:
        llama: backbone called with ``output_hidden_states=True``; its output
            must expose ``hidden_states`` indexable by ``layer_idx``.
        input_ids: token ids.
        attn_mask: attention mask matching ``input_ids``.
        layer_idx: which entry of ``hidden_states`` to pool (default: last).
        pooling: "first" (hidden state of the leading <s> token) or "last"
            (hidden state of the last attended token; assumes right padding).

    Returns:
        Tensor of shape (batch, hidden_dim).

    Raises:
        ValueError: for an unknown ``pooling``. The check runs before the forward pass.
    """
    if pooling not in ("first", "last"):
        raise ValueError(f"Unknown pooling method: {pooling}")

    hs = llama(
        input_ids=input_ids,
        attention_mask=attn_mask,
        output_hidden_states=True,
    ).hidden_states[layer_idx]

    if pooling == "first":
        return hs[:, 0, :]

    seq_lengths = (attn_mask.sum(1) - 1).to(hs.device)  # (B,)
    batch_idx = torch.arange(hs.size(0), device=hs.device)
    return hs[batch_idx, seq_lengths]


def rep_extract(llama, loader, layer_idx=-1, pooling="last", device=None):
    """Extract one pooled hidden-state vector per sample for every batch of ``loader``.

    Args:
        llama: backbone passed to ``extract_cls_hidden``; switched to eval mode.
        loader: iterable of batches with "input_ids" and "attention_mask".
        layer_idx: hidden layer to pool (default: last).
        pooling: "first" or "last" (see ``extract_cls_hidden``).
        device: device for the input tensors. Defaults to the device of the
            model's first parameter. The original read a module-level ``device``
            global that only exists when the file runs as a script.

    Returns:
        float32 numpy array of shape (num_samples, hidden_dim), rows in loader order.
    """
    if device is None:
        device = next(llama.parameters()).device

    llama.eval()
    rep_vec = []

    with torch.no_grad():
        for batch in tqdm(loader):
            input_ids = batch["input_ids"].to(device)
            attn_mask = batch["attention_mask"].to(device)

            cls_vec = extract_cls_hidden(llama, input_ids, attn_mask,
                                         layer_idx=layer_idx, pooling=pooling)
            # Upcast (e.g. from bf16) on CPU so the result converts to numpy.
            rep_vec.append(cls_vec.cpu().float())

    return torch.cat(rep_vec).numpy()


def train(model, dataloader, optimizer=None):
    """Run one epoch of supervised fine-tuning with cross-entropy loss.

    Args:
        model: classifier called as ``model(input_ids=..., attention_mask=...)``
            that returns logits.
        dataloader: sized iterable of batches with "input_ids",
            "attention_mask" and "label" tensors.
        optimizer: optional optimizer to step. Pass the same instance on every
            epoch to keep its state (e.g. AdamW moments) across epochs. When
            omitted, a fresh ``AdamW(model.parameters(), lr=args.learning_rate)``
            is created, which resets optimizer state on every call.

    Returns:
        Mean per-batch training loss, or ``nan`` if ``dataloader`` is empty.
    """
    if optimizer is None:
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    loss_fn = torch.nn.CrossEntropyLoss()
    # Put inputs on the device of the model's first parameter. The original
    # relied on a module-level `device` global that only exists when run as a script.
    device = next(model.parameters()).device

    total_loss = 0.0
    model.train()
    for batch in tqdm(dataloader):
        model.zero_grad(set_to_none=True)

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = loss_fn(logits, labels)
        total_loss += loss.item()

        loss.backward()
        clip_grad_norm_(model.parameters(), 1.0)  # clip the global gradient norm to 1.0
        optimizer.step()

    # Release the last batch's gradient buffers before returning.
    model.zero_grad(set_to_none=True)

    num_batches = len(dataloader)
    return total_loss / num_batches if num_batches else float("nan")


def ga_train(model, retain_loader, forget_loader, optimizer=None):
    """Gradient Ascent Unlearning: ReTain - Forget.

    For every forget batch a random retain batch is drawn (``random.choice``).
    One optimizer step is then taken on ``CE(retain) - CE(forget)``: the loss
    on retain (clean) data is minimized while the loss on forget (poisoned)
    data is pushed up.

    Args:
        model: classifier returning logits for ``input_ids``/``attention_mask``.
        retain_loader: iterable of retain batches with "input_ids",
            "attention_mask" and "label". It is fully materialized once per call
            so batches can be sampled.
        forget_loader: sized iterable of forget batches (same format).
        optimizer: optional optimizer to step. Reuse one instance across epochs
            to keep its state. Defaults to a fresh
            ``AdamW(model.parameters(), lr=args.learning_rate)``.

    Returns:
        (mean retain loss, mean forget loss) over the forget batches, or
        ``(nan, nan)`` if ``forget_loader`` is empty.

    Raises:
        ValueError: if ``retain_loader`` yields no batches.
    """
    num_forget = len(forget_loader)
    if num_forget == 0:
        return float("nan"), float("nan")

    batches_retain = list(retain_loader)
    if not batches_retain:
        raise ValueError("retain_loader is empty: no retain batches to sample from")

    if optimizer is None:
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    loss_fn = torch.nn.CrossEntropyLoss()
    # Inputs go to the device of the model's first parameter (no module-level global).
    device = next(model.parameters()).device

    clean_loss = 0.0
    poison_loss = 0.0
    model.train()
    for batch_forget in forget_loader:
        model.zero_grad(set_to_none=True)

        # Sample retain batch
        batch_retain = random.choice(batches_retain)
        input_ids_retain = batch_retain["input_ids"].to(device)
        attention_mask_retain = batch_retain["attention_mask"].to(device)
        labels_retain = batch_retain["label"].to(device)

        # Forget batch
        input_ids_forget = batch_forget["input_ids"].to(device)
        attention_mask_forget = batch_forget["attention_mask"].to(device)
        labels_forget = batch_forget["label"].to(device)

        logits_retain = model(input_ids=input_ids_retain, attention_mask=attention_mask_retain)
        logits_forget = model(input_ids=input_ids_forget, attention_mask=attention_mask_forget)

        loss_retain = loss_fn(logits_retain, labels_retain)  # clean samples: keep the functioning of the model
        loss_forget = loss_fn(logits_forget, labels_forget)  # poisoned samples: forget the backdoor of the model
        loss = loss_retain - loss_forget  # gradient ascent on the forget term

        loss.backward()
        clip_grad_norm_(model.parameters(), 1.0)  # clip the global gradient norm to 1.0
        optimizer.step()

        clean_loss += loss_retain.item()
        poison_loss += loss_forget.item()

    # Release the last step's gradient buffers (they would otherwise persist
    # through evaluation and the next epoch).
    model.zero_grad(set_to_none=True)

    return clean_loss / num_forget, poison_loss / num_forget


def npo_train(model, ref_model, retain_loader, forget_loader, optimizer=None, beta=1):
    """Using RT + NPO: one epoch of retain training plus NPO unlearning.

    For every forget batch a random retain batch is drawn (``random.choice``).
    One step is then taken on::

        CE_mean(retain) - (2 / beta) * mean(logsigmoid(beta * (CE_model - CE_ref)))

    ``CE_model`` / ``CE_ref`` are the per-sample forget cross-entropies of
    ``model`` and of ``ref_model``. ``ref_model`` runs in eval mode under
    ``torch.no_grad``.

    Args:
        model: classifier being unlearned.
        ref_model: reference classifier.
        retain_loader: retain batches; fully materialized once so they can be sampled.
        forget_loader: sized iterable of forget batches.
        optimizer: optional optimizer to step. Reuse one instance across epochs
            to keep its state. Defaults to a fresh
            ``AdamW(model.parameters(), lr=args.learning_rate)``.
        beta: scale applied to the loss difference inside the sigmoid
            (previously hard-coded to 1).

    Returns:
        (mean retain loss, mean per-sample forget loss) over the forget batches,
        or ``(nan, nan)`` if ``forget_loader`` is empty.

    Raises:
        ValueError: if ``retain_loader`` yields no batches.
    """
    num_forget = len(forget_loader)
    if num_forget == 0:
        return float("nan"), float("nan")

    batches_retain = list(retain_loader)
    if not batches_retain:
        raise ValueError("retain_loader is empty: no retain batches to sample from")

    if optimizer is None:
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    # Inputs go to the device of the model's first parameter (no module-level global).
    device = next(model.parameters()).device

    loss_fn_base = torch.nn.CrossEntropyLoss(reduction="none")
    loss_fn_mean = torch.nn.CrossEntropyLoss(reduction="mean")

    clean_loss = 0.0
    poison_loss = 0.0
    ref_model.eval()
    model.train()
    for batch_forget in tqdm(forget_loader):
        model.zero_grad(set_to_none=True)

        # Sample retain batch
        batch_retain = random.choice(batches_retain)
        input_ids_retain = batch_retain["input_ids"].to(device)
        attention_mask_retain = batch_retain["attention_mask"].to(device)
        labels_retain = batch_retain["label"].to(device)

        # Forget batch
        input_ids_forget = batch_forget["input_ids"].to(device)
        attention_mask_forget = batch_forget["attention_mask"].to(device)
        labels_forget = batch_forget["label"].to(device)

        # Retain loss
        logits_retain = model(input_ids=input_ids_retain, attention_mask=attention_mask_retain)
        loss_retain = loss_fn_mean(logits_retain, labels_retain)

        # NPO loss -- reference model (no gradient)
        with torch.no_grad():
            logits_forget_ref = ref_model(input_ids=input_ids_forget, attention_mask=attention_mask_forget)
            loss_forget_ref = loss_fn_base(logits_forget_ref, labels_forget)

        # NPO loss -- policy model
        logits_forget = model(input_ids=input_ids_forget, attention_mask=attention_mask_forget)
        loss_forget = loss_fn_base(logits_forget, labels_forget)

        # Each CE is -log p(label), so this difference is -log(p_model / p_ref) per sample.
        negative_log_ratio = loss_forget - loss_forget_ref
        npo_loss = -F.logsigmoid(beta * negative_log_ratio).mean() * 2 / beta

        loss = loss_retain + npo_loss  # RT + NPO
        loss.backward()
        clip_grad_norm_(model.parameters(), 1.0)  # clip the global gradient norm to 1.0
        optimizer.step()

        clean_loss += loss_retain.item()
        poison_loss += loss_forget.mean().item()

    # Release the last step's gradient buffers.
    model.zero_grad(set_to_none=True)

    return clean_loss / num_forget, poison_loss / num_forget


def rga_train(model, model_base, logits_ref, retain_loader, forget_loader,
              optimizer=None, reg_coef=5e-2, scale_power=2.0):
    """Robustness Gradient Ascent Unlearning: one epoch.

    For every forget batch a random retain batch is drawn (``random.choice``)
    and one step is taken on::

        CE(retain) - w * CE(forget) + reg_coef * sum_p ||p - p_base||_2

    - ``w = adaptive_weight(current forget logits, reference logits, scale_power)``
      is computed without gradient.
    - The sum runs pairwise over ``model.llama`` and ``model_base.llama``
      parameters, using the unsquared L2 norm per tensor.

    Args:
        model: classifier being unlearned (must expose ``.llama``).
        model_base: model whose ``.llama`` parameters anchor the regularizer.
        logits_ref: one reference-logit tensor per forget batch, in the same
            order as ``forget_loader`` (so that loader must not shuffle).
        retain_loader: retain batches; fully materialized once so they can be sampled.
        forget_loader: sized iterable of forget batches.
        optimizer: optional optimizer to step. Reuse one instance across epochs
            to keep its state. Defaults to a fresh
            ``AdamW(model.parameters(), lr=args.learning_rate)``.
        reg_coef: weight of the parameter-distance regularizer (was hard-coded 5e-2).
        scale_power: exponent passed to ``adaptive_weight`` (was hard-coded 2.0).

    Returns:
        (mean retain loss, mean forget loss) over the forget batches, or
        ``(nan, nan)`` if ``forget_loader`` is empty.

    Raises:
        ValueError: if ``logits_ref`` and ``forget_loader`` differ in length
            (``zip`` would otherwise silently drop batches), or if
            ``retain_loader`` yields no batches.
    """
    num_forget = len(forget_loader)
    if len(logits_ref) != num_forget:
        raise ValueError(
            f"logits_ref has {len(logits_ref)} batches but forget_loader has {num_forget}"
        )
    if num_forget == 0:
        return float("nan"), float("nan")

    batches_retain = list(retain_loader)
    if not batches_retain:
        raise ValueError("retain_loader is empty: no retain batches to sample from")

    if optimizer is None:
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    loss_fn = torch.nn.CrossEntropyLoss()
    # Inputs go to the device of the model's first parameter (no module-level global).
    device = next(model.parameters()).device

    clean_loss = 0.0
    poison_loss = 0.0
    model.train()
    for batch_forget, logits_ref_batch in tqdm(zip(forget_loader, logits_ref), total=num_forget):
        model.zero_grad(set_to_none=True)

        # Sample retain batch
        batch_retain = random.choice(batches_retain)
        input_ids_retain = batch_retain["input_ids"].to(device)
        attention_mask_retain = batch_retain["attention_mask"].to(device)
        labels_retain = batch_retain["label"].to(device)

        # Forget batch
        input_ids_forget = batch_forget["input_ids"].to(device)
        attention_mask_forget = batch_forget["attention_mask"].to(device)
        labels_forget = batch_forget["label"].to(device)

        logits_retain = model(input_ids=input_ids_retain, attention_mask=attention_mask_retain)
        logits_forget = model(input_ids=input_ids_forget, attention_mask=attention_mask_forget)

        # Adaptive re-weighting of the ascent term (treated as a constant).
        with torch.no_grad():
            weight = adaptive_weight(logits_forget, logits_ref_batch.to(device),
                                     scale_power=scale_power).item()

        loss_retain = loss_fn(logits_retain, labels_retain)  # clean samples: keep the functioning of the model
        loss_forget = loss_fn(logits_forget, labels_forget)  # poisoned samples: forget the backdoor of the model
        loss = loss_retain - weight * loss_forget

        # Regularization: unsquared L2 distance of every backbone tensor to its base copy.
        for param, param_base in zip(model.llama.parameters(), model_base.llama.parameters()):
            loss = loss + reg_coef * torch.norm(param - param_base).to(device)

        loss.backward()
        clip_grad_norm_(model.parameters(), 1.0)  # clip the global gradient norm to 1.0
        optimizer.step()

        clean_loss += loss_retain.item()
        poison_loss += loss_forget.item()

    # Release the last step's gradient buffers.
    model.zero_grad(set_to_none=True)

    return clean_loss / num_forget, poison_loss / num_forget


def compute_logits_ref(ref_model, forget_loader):
    """Collect ``ref_model``'s logits for every batch of ``forget_loader``.

    Args:
        ref_model: classifier evaluated in eval mode under ``torch.no_grad``.
        forget_loader: iterable of batches with "input_ids" and "attention_mask".

    Returns:
        List with one detached CPU logits tensor per batch, in loader order.
        ``rga_train`` pairs these with forget batches by position, so the
        loader must iterate in a fixed order (no shuffling).
    """
    # Inputs go to the device of the model's first parameter; the original
    # read a module-level `device` global that only exists when run as a script.
    device = next(ref_model.parameters()).device
    ref_model.eval()

    logits_ref = []
    with torch.no_grad():
        for batch in tqdm(forget_loader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            logits = ref_model(input_ids=input_ids, attention_mask=attention_mask)
            logits_ref.append(logits.detach().cpu())

    return logits_ref


def adaptive_weight(input_logits, ref_logits, scale_power=1.0):
    """Return ``exp(-KL(ref || input)) ** scale_power`` as a scalar tensor.

    Both logit tensors are normalized over dim 1. The KL divergence is averaged
    over the batch (``reduction="batchmean"``). The weight is 1 when the two
    distributions agree and decays towards 0 as they diverge.

    Args:
        input_logits: logits of the model being trained.
        ref_logits: logits of the reference model (same shape).
        scale_power: exponent applied to ``exp(-KL)``.
    """
    log_p_model = F.log_softmax(input_logits, dim=1)
    p_ref = F.softmax(ref_logits, dim=1)
    divergence = F.kl_div(log_p_model, p_ref, reduction="batchmean")
    return torch.exp(-divergence).pow(scale_power)


def evaluate(model, dataloader, data_name):
    """Print accuracy, confusion matrix and label flip rates (LFR) of ``model``.

    The LFR of class c is the fraction of samples whose true label is c but
    whose prediction is some other class: ``(cm[c].sum() - cm[c][c]) / cm[c].sum()``.

    Args:
        model: classifier returning logits for ``input_ids``/``attention_mask``.
        dataloader: iterable of batches with "input_ids", "attention_mask", "label".
        data_name: "SST-2" or "HSOL" (2 classes) or "AG" (4 classes).

    Returns:
        None; results are printed.

    Raises:
        ValueError: if ``data_name`` is unsupported. The check runs before inference.
    """
    num_classes_by_name = {"SST-2": 2, "HSOL": 2, "AG": 4}
    if data_name not in num_classes_by_name:
        raise ValueError("Invalid data name")
    num_classes = num_classes_by_name[data_name]
    # Inputs go to the device of the model's first parameter (no module-level global).
    device = next(model.parameters()).device

    all_predictions = []
    all_targets = []

    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            logits = model(input_ids=input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(logits, dim=1)

            all_predictions.extend(predictions.cpu().numpy())
            all_targets.extend(batch["label"].numpy())

    accuracy = accuracy_score(all_targets, all_predictions)
    # Pin the label set. Without `labels=`, sklearn sizes the matrix from the
    # labels that actually occur. A single-class test set (or a model that never
    # predicts some class) would then give a smaller matrix, and the per-class
    # indexing below would raise IndexError or read the wrong class.
    cm = confusion_matrix(all_targets, all_predictions, labels=list(range(num_classes)))

    print(f"Accuracy on {data_name} data: {accuracy}")
    print(f"Confusion matrix on {data_name} data:\n {cm}")

    lfrs = []
    for c in range(num_classes):
        row_total = cm[c].sum()
        # A class that never occurs in the targets has no defined flip rate.
        lfrs.append((row_total - cm[c][c]) / row_total if row_total else float("nan"))
        print(f"LFR for class {c}: {lfrs[c]}")

    if data_name == "AG":
        # NOTE(review): class 0 is excluded from the average, presumably because
        # it is the attack's target label -- confirm.
        print(f"Average LFR:{(lfrs[1] + lfrs[2] + lfrs[3]) / 3}")

    return None


if __name__ == '__main__':
    args = get_arguments().parse_args()
    print("Arguments:")
    print(args.poisoning_epoch)
    print(args.learning_rate)
    print(args.attack_mode)
    print(args.dataset)

    # Fail fast on unsupported options. Otherwise these errors only surface
    # after the model has been loaded, trained and clustered.
    if args.dataset not in ("SST-2", "HSOL", "AG"):
        raise ValueError("Invalid dataset")
    if args.unlearning_method not in ("RT", "GA", "NPO", "RGA"):
        raise ValueError(f"Unknown unlearning method: {args.unlearning_method}")

    for current_seed in args.seed:
        print(f"Current Seed: {current_seed}")
        print("Phase 1: Training the backdoor model")
        set_seed(current_seed)

        # Models that only some unlearning methods create. They are bound up
        # front so the cleanup at the end of each iteration is always valid
        # (the original `del model_base` raised NameError unless the method was RGA).
        model_base = None
        ref_model = None
        initial_model = None

        # Dataset and Dataloader
        tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer)
        # Reuse EOS as the pad token. Right padding is what "last" pooling
        # (attention_mask.sum(1) - 1) relies on.
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

        data_dir = f"Data_Poisoning/{args.attack_mode}/{args.dataset}"
        train_df_poisoned = pd.read_csv(f"{data_dir}/train_poisoned.csv")
        train_df_clean = pd.read_csv(f"{data_dir}/train_clean.csv")

        test_df_clean = pd.read_csv(f"{data_dir}/test_clean.csv")
        test_df_poisoned = pd.read_csv(f"{data_dir}/test_poisoned_all.csv")

        train_dataset_poisoned = TargetDataset(tokenizer=tokenizer, max_len=args.max_len, data=train_df_poisoned)
        train_dataset_clean = TargetDataset(tokenizer=tokenizer, max_len=args.max_len, data=train_df_clean)

        test_dataset_clean = TargetDataset(tokenizer=tokenizer, max_len=args.max_len, data=test_df_clean)
        test_dataset_poisoned = TargetDataset(tokenizer=tokenizer, max_len=args.max_len, data=test_df_poisoned)

        train_loader_poisoned = DataLoader(train_dataset_poisoned, batch_size=args.batch_size, shuffle=False)
        train_loader_clean = DataLoader(train_dataset_clean, batch_size=args.batch_size, shuffle=False)

        test_loader_clean = DataLoader(test_dataset_clean, batch_size=args.batch_size, shuffle=False)
        test_loader_poisoned = DataLoader(test_dataset_poisoned, batch_size=args.batch_size, shuffle=False)

        # Load the Llama2 model
        num_class = 4 if args.dataset == "AG" else 2

        backbone = LlamaModel.from_pretrained(args.victim_model, torch_dtype=torch.bfloat16, device_map="auto")
        target_model = LlamaClassification(backbone, label_num=num_class, pooling="last")
        target_model.llama.config.pad_token_id = target_model.llama.config.eos_token_id
        device = target_model.llama.device

        if args.unlearning_method == "RGA":
            # NOTE(review): this snapshot is taken BEFORE the checkpoint below is
            # loaded. rga_train therefore regularizes towards the backbone weights
            # loaded from args.victim_model (with a randomly initialized head),
            # not towards the backdoored model. Confirm this is intended.
            model_base = copy.deepcopy(target_model)

            for p in model_base.parameters():
                p.requires_grad_(False)

        target_model.load_state_dict(torch.load(f"model/{args.attack_mode}_{args.dataset}_{current_seed}.pt"))

        print("--------------- Poisoning Phase ---------------")
        for epoch in range(args.poisoning_epoch):
            print("\n-------------------")
            train_loss = train(target_model, train_loader_poisoned)
            print(f"Epoch {epoch + 1}/{args.poisoning_epoch} | Train loss {train_loss}")

        print("-------------------------------------------------")
        print("1. Clean Performance: Evaluation on Clean Test Set:")
        evaluate(target_model, test_loader_clean, args.dataset)

        print("2. Posioning Performance: Evaluation on Poisoned (ALL) Test Set:")
        evaluate(target_model, test_loader_poisoned, args.dataset)
        print("-------------------------------------------------")

        print("Phase 2: Poisoned Samples Detection")
        print("--------------- Dataset Statistics ---------------")
        clean_df = train_df_poisoned[train_df_poisoned["poisoned"] == 0].reset_index(drop=True)
        poisoned_df = train_df_poisoned[train_df_poisoned["poisoned"] == 1].reset_index(drop=True)

        print(f"Dataset: {args.dataset}")
        print(f"Clean Dataset Size: {len(clean_df)}")
        print(f"Poisoned Dataset Size: {len(poisoned_df)}")

        if args.dataset == "SST-2":
            print(f"Clean Positive Samples: {len(clean_df[clean_df['label'] == 1])}")
            print(f"Clean Negative Samples: {len(clean_df[clean_df['label'] == 0])}")

        elif args.dataset == "HSOL":
            print(f"Clean Non-toxic Samples: {len(clean_df[clean_df['label'] == 0])}")
            print(f"Clean Toxic Samples: {len(clean_df[clean_df['label'] == 1])}")

        elif args.dataset == "AG":
            print(f"Clean World Samples: {len(clean_df[clean_df['label'] == 0])}")
            print(f"Clean Sports Samples: {len(clean_df[clean_df['label'] == 1])}")
            print(f"Clean Business Samples: {len(clean_df[clean_df['label'] == 2])}")
            print(f"Clean Science Samples: {len(clean_df[clean_df['label'] == 3])}")

        else:
            raise ValueError("Invalid dataset")

        clean_loader_visualize = DataLoader(
            TargetDataset(tokenizer, args.max_len, clean_df), batch_size=args.batch_size, shuffle=False,
        )

        poisoned_loader_visualize = DataLoader(
            TargetDataset(tokenizer, args.max_len, poisoned_df), batch_size=args.batch_size, shuffle=False,
        )

        print("--------------- Dimension Reduction ---------------")
        # Rows of overall_rep: all poisoned samples first, then all clean samples.
        poisoned_rep = rep_extract(target_model.llama, poisoned_loader_visualize)
        clean_rep = rep_extract(target_model.llama, clean_loader_visualize)
        overall_rep = np.concatenate((poisoned_rep, clean_rep), axis=0)

        reducer = umap.UMAP(n_components=4,
                            n_neighbors=args.umap_n_neighbors,
                            min_dist=args.umap_min_dist,
                            metric="cosine",
                            random_state=current_seed
                            )
        projected_rep = reducer.fit_transform(overall_rep)

        print("--------------- Clustering ---------------")
        clusterer = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster,
                                    min_samples=args.min_samples,
                                    metric="euclidean").fit(projected_rep)
        labels = clusterer.labels_

        # The `num_class` largest non-noise (cid != -1) clusters are treated as
        # clean; everything else, including noise, is flagged as poisoned.
        cnt = Counter(labels)
        num_clusters = num_class
        clean_cids = [cid for cid, _ in cnt.most_common() if cid != -1][:num_clusters]

        if args.dataset == "HSOL":
            # CUBE does not work well for HSOL with Llama2 only: with some seeds the
            # big clean cluster is split into several sub-clusters. Assuming the
            # poisoned samples are always the minority, every remaining cluster
            # (noise included) larger than 10% of either class is merged into the
            # clean clusters.
            class_0_size = len(train_df_poisoned[train_df_poisoned["label"] == 0])
            class_1_size = len(train_df_poisoned[train_df_poisoned["label"] == 1])
            for cid, n in cnt.most_common():
                print(f"Cluster ID: {cid} | Number of Samples: {n}")
                if cid not in clean_cids and (n > class_0_size * 0.10 or n > class_1_size * 0.10):
                    clean_cids.append(cid)

        pred_clean_mask = np.isin(labels, clean_cids)
        pred_poison_mask = ~pred_clean_mask

        true_poison = np.concatenate([
            np.ones(len(poisoned_df), dtype=bool),
            np.zeros(len(clean_df), dtype=bool)
        ])

        prec, rec, f1, _ = precision_recall_fscore_support(true_poison, pred_poison_mask, average="binary")
        print(f" Precision: {prec:.4f} | Recall: {rec:.4f} | F1: {f1:.4f}")

        # print("--------------- Visualization ---------------")
        # poisoned_rep_2 = projected_rep[:len(poisoned_df)]
        # clean_rep_2 = projected_rep[len(poisoned_df):]

        # plt.figure(figsize=(10, 8))
        # ax = plt.gca()
        # for s in ["top", "bottom", "left", "right"]:
        #     ax.spines[s].set_linewidth(4)

        # plt.xticks(fontsize=30)
        # plt.yticks(fontsize=30)
        # plt.tick_params(axis='both', labelsize=30, width=4, length=10)

        # plt.scatter(poisoned_rep_2[:, 0], poisoned_rep_2[:, 1], s=20, c=(0, 0.5, 0), label="Poisoned", alpha=0.5, marker="o")
        # plt.scatter(clean_rep_2[:, 0], clean_rep_2[:, 1], s=20, c=(0.5, 0.2, 0.8), label="Clean", alpha=0.5, marker="o")
        # plt.legend(frameon=False, fontsize=25, markerscale=3)
        # plt.tight_layout()
        # plt.savefig(f"Clustering/visualization_{args.dataset}_{args.attack_mode}_{current_seed}.png")
        # plt.show()
        # plt.close()

        # torch.save(target_model.state_dict(), f"model/{args.attack_mode}_{args.dataset}_{current_seed}.pt" )

        # Split both ground-truth frames by the predicted masks (same row order
        # as overall_rep: poisoned rows first, then clean rows).
        n_poisoned = len(poisoned_df)
        poisoned_df_detected = pd.concat([
            poisoned_df[pred_poison_mask[:n_poisoned]],
            clean_df[pred_poison_mask[n_poisoned:]]
        ]).reset_index(drop=True)

        clean_df_detected = pd.concat([
            poisoned_df[pred_clean_mask[:n_poisoned]],
            clean_df[pred_clean_mask[n_poisoned:]]
        ]).reset_index(drop=True)

        print(f"Detected Poisoned Samples: {len(poisoned_df_detected)}")
        print(f"Detected Clean Samples: {len(clean_df_detected)}")

        print("Phase 3: Unlearning Phase")
        # Use the detected poisoned samples to unlearn
        retain_loader = DataLoader(
            TargetDataset(tokenizer, args.max_len, clean_df_detected),
            batch_size=args.batch_size, shuffle=False
        )

        forget_loader = DataLoader(
            TargetDataset(tokenizer, args.max_len, poisoned_df_detected),
            batch_size=args.batch_size, shuffle=False
        )

        print(f"Unlearning Method:{args.unlearning_method}")
        if args.unlearning_method == "RT":
            # Retrain the model from scratch on the clean data as an oracle baseline.
            print("Defender: Retraining")
            # Pool the last real token, like target_model. With "first" pooling every
            # input reduces to the hidden state of the leading <s> token. Assuming
            # the tokenizer prepends <s>, Llama's causal attention lets that token
            # see nothing else, so the classifier would get the same feature for
            # every sample.
            initial_model = LlamaClassification(LlamaModel.from_pretrained(args.victim_model, torch_dtype=torch.bfloat16),
                                                label_num=num_class,
                                                pooling="last").to(device)
            initial_model.llama.config.pad_token_id = initial_model.llama.config.eos_token_id
            for epoch in range(args.poisoning_epoch):
                print("\n-------------------")
                train_loss = train(initial_model, train_loader_clean)
                print(f"Epoch {epoch + 1}/{args.poisoning_epoch} | Train loss {train_loss}")

            print("-------------------------------------------------")
            print("---------------ReTrain Performance---------------")
            print("Clean Performance: Evaluation on Clean Test Set...")
            evaluate(initial_model, test_loader_clean, args.dataset)

            print("Poisoned Performance: Evaluation on Poisoned Test Set...")
            evaluate(initial_model, test_loader_poisoned, args.dataset)

        elif args.unlearning_method == "GA":
            # Use Gradient Ascent to unlearn the poisoned samples.
            print("Defender: Gradient Ascent")
            for epoch in range(args.unlearning_epoch):
                print("\n-------------------")
                clean_loss, poison_loss = ga_train(target_model, retain_loader, forget_loader)
                print(f"Epoch {epoch + 1}/{args.unlearning_epoch} | Clean loss {clean_loss} | Poison loss {poison_loss}")

                if (epoch + 1) % 10 == 0:
                    print("-------------------------------------------------")
                    print("-----------Gradient Ascent Performance-----------")
                    print(f"Unlearning Epoch: {epoch + 1}")
                    print("Clean Performance: Evaluation on Clean Test Set...")
                    evaluate(target_model, test_loader_clean, args.dataset)

                    print("Poisoned Performance: Evaluation on Poisoned Test Set...")
                    evaluate(target_model, test_loader_poisoned, args.dataset)

        elif args.unlearning_method == "NPO":
            # Use NPO to unlearn the poisoned samples; the reference is a frozen
            # copy of the model as it is right before unlearning.
            print("Defender: NPO")
            ref_model = copy.deepcopy(target_model)
            for epoch in range(args.unlearning_epoch):
                print("\n-------------------")
                clean_loss, poison_loss = npo_train(target_model, ref_model, retain_loader, forget_loader)
                print(f"Epoch {epoch + 1}/{args.unlearning_epoch} | Clean loss {clean_loss} | Poison loss {poison_loss}")

                if (epoch + 1) % 10 == 0:
                    print("-------------------------------------------------")
                    print("------------------NPO Performance----------------")
                    print(f"Unlearning Epoch: {epoch + 1}")
                    print("Clean Performance: Evaluation on Clean Test Set...")
                    evaluate(target_model, test_loader_clean, args.dataset)

                    print("Poisoned Performance: Evaluation on Poisoned Test Set...")
                    evaluate(target_model, test_loader_poisoned, args.dataset)

        elif args.unlearning_method == "RGA":
            # Use RGA to unlearn the poisoned samples. The reference logits come
            # from the model right before unlearning, one tensor per forget batch.
            print("Defender: RGA")
            logits_ref = compute_logits_ref(target_model, forget_loader)

            for epoch in range(args.unlearning_epoch):
                print("\n-------------------")
                clean_loss, poison_loss = rga_train(target_model, model_base, logits_ref, retain_loader, forget_loader)
                print(f"Epoch {epoch + 1}/{args.unlearning_epoch} | Clean loss {clean_loss} | Poison loss {poison_loss}")

                if (epoch + 1) % 10 == 0:
                    print("-------------------------------------------------")
                    print("------------------RGA Performance----------------")
                    print(f"Unlearning Epoch: {epoch + 1}")
                    print("Clean Performance: Evaluation on Clean Test Set...")
                    evaluate(target_model, test_loader_clean, args.dataset)

                    print("Poisoned Performance: Evaluation on Poisoned Test Set...")
                    evaluate(target_model, test_loader_poisoned, args.dataset)

        else:
            raise ValueError(f"Unknown unlearning method: {args.unlearning_method}")

        # Release this seed's models before the next iteration loads new ones.
        # `backbone` must be dropped too: it references the same Llama weights as
        # target_model and would otherwise keep them alive while the next seed
        # loads its backbone.
        del target_model, backbone
        model_base = ref_model = initial_model = None

        gc.collect()
        torch.cuda.empty_cache()
