import random
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import copy
import umap
import gc
import hdbscan
from torch import nn as nn
from config import get_arguments
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, accuracy_score
from transformers import DistilBertTokenizer, DistilBertModel
from torch.utils.data import Dataset
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support


def set_seed(random_seed=11):
    """Seed every RNG this script relies on so runs are reproducible.

    Covers Python's ``random``, NumPy's global generator, PyTorch's CPU
    generator and the generators of all CUDA devices, in that order.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(random_seed)


class TargetDataset(Dataset):
    """Map-style dataset that tokenizes one text row per item.

    ``data`` must expose ``"text"``, ``"label"`` and ``"poisoned"`` columns as
    pandas Series (e.g. a DataFrame).  Rows are looked up by *position*, so the
    frame does not need a fresh 0..n-1 index.
    """

    def __init__(self, tokenizer, max_len, data):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.text = data["text"]
        self.targets = data["label"]
        self.flag = data["poisoned"]

    def __len__(self):
        return len(self.text)

    def __getitem__(self, item):
        # Positional lookup: label-based ``series[item]`` raises KeyError (or
        # returns the wrong row) whenever the index is not a RangeIndex.
        text = str(self.text.iloc[item])
        target = self.targets.iloc[item]
        flag = self.flag.iloc[item]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            # padding="max_length" already pads to max_len; the deprecated,
            # redundant ``pad_to_max_length=True`` flag is no longer passed.
            padding="max_length",
            truncation=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )

        return {
            "text": text,
            # return_tensors="pt" yields a leading batch dim of 1; flatten it away
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
            "label": torch.tensor(target, dtype=torch.long),
            "flag": flag
        }


class DistilBertClassification(nn.Module):
    """Linear classification head on top of the first-token (CLS) vector of ``bert``.

    Args:
        bert: encoder module called as ``bert(input_ids=..., attention_mask=...)``
            whose output exposes ``last_hidden_state``.
        label_num: number of output classes.
        hidden_size: width of the encoder's hidden states.  When ``None`` it is
            read from ``bert.config.hidden_size`` and falls back to 768 (the
            previously hard-coded value) if the encoder has no such config.
    """

    def __init__(self, bert, label_num=2, hidden_size=None):
        super(DistilBertClassification, self).__init__()
        self.bert = bert
        if hidden_size is None:
            hidden_size = getattr(getattr(bert, "config", None), "hidden_size", 768)
        self.classifier = nn.Linear(hidden_size, label_num)

    def forward(self, input_ids, attention_mask):
        # Only the last layer is needed: requesting all hidden states and
        # attention maps here would keep them alive for backward on every step.
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        cls_rep = output.last_hidden_state[:, 0, :]  # first token of each sequence
        return self.classifier(cls_rep)


def extract_cls_hidden(bert, input_ids, attn_mask, layer_idx=-1):
    """Return the first-token (CLS) vector from hidden-state layer ``layer_idx``.

    The encoder is run with ``output_hidden_states=True``; the selected layer
    is sliced at sequence position 0, giving a ``(batch, hidden)`` tensor.
    Autograd is not disabled here -- callers wrap this in ``torch.no_grad()``.
    """
    outputs = bert(input_ids=input_ids, attention_mask=attn_mask, output_hidden_states=True)
    selected_layer = outputs.hidden_states[layer_idx]
    return selected_layer[:, 0, :]


def rep_extract(bert, loader):
    """Collect last-layer CLS embeddings for every sample in ``loader``.

    Puts ``bert`` in eval mode, runs it without autograd, and returns a NumPy
    array with one row per sample, concatenated in loader order.
    """
    bert.eval()
    chunks = []

    with torch.no_grad():
        for batch in tqdm(loader):
            cls_batch = extract_cls_hidden(
                bert,
                batch["input_ids"].to(device),
                batch["attention_mask"].to(device),
                layer_idx=-1,
            )
            chunks.append(cls_batch.cpu())

    return torch.cat(chunks).numpy()


def train(model, dataloader, optimizer=None):
    """Train ``model`` for one epoch with cross-entropy and return the mean batch loss.

    Args:
        model: classifier called as ``model(input_ids=..., attention_mask=...)``
            that returns logits.
        dataloader: yields dicts with ``"input_ids"``, ``"attention_mask"`` and
            ``"label"`` tensors.
        optimizer: optional optimizer to reuse across epochs.  When omitted, a
            fresh ``AdamW`` with ``args.learning_rate`` is built on every call
            (the original behaviour), which discards AdamW's moment estimates
            between epochs.

    Returns:
        Average loss over this epoch's batches, or ``nan`` when ``dataloader``
        yields no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        return float("nan")

    loss_fn = torch.nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

    total_loss = 0.0
    model.train()  # nothing in the loop switches modes, so set it once
    for batch in tqdm(dataloader):
        model.zero_grad()

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = loss_fn(logits, labels)
        total_loss += loss.item()

        loss.backward()
        clip_grad_norm_(model.parameters(), 1.0)  # clip the gradient norm to 1.0
        optimizer.step()

    return total_loss / num_batches  # average batch loss for this epoch


def ga_train(model, retain_loader, forget_loader, optimizer=None):
    """Run one epoch of gradient-ascent unlearning: retain loss minus forget loss.

    One optimisation step is taken per forget batch.  Each step pairs it with a
    retain batch drawn uniformly at random (``random.choice``) from
    ``retain_loader``, and minimises ``CE(retain) - CE(forget)``.

    Args:
        model: classifier called as ``model(input_ids=..., attention_mask=...)``.
        retain_loader: batches whose behaviour should be preserved.
        forget_loader: batches to unlearn.
        optimizer: optional optimizer to reuse across epochs; a fresh ``AdamW``
            with ``args.learning_rate`` is built when omitted.

    Returns:
        ``(mean retain loss, mean forget loss)`` over the forget batches, or
        ``(nan, nan)`` when ``forget_loader`` is empty (nothing to unlearn).

    Raises:
        ValueError: if there are forget batches but ``retain_loader`` is empty.
    """
    num_steps = len(forget_loader)
    if num_steps == 0:
        # e.g. detection flagged no samples -- avoid ZeroDivisionError below
        return float("nan"), float("nan")

    # Materialised once so a random retain batch can be drawn at every step
    batches_retain = list(retain_loader)
    if not batches_retain:
        raise ValueError("retain_loader yielded no batches; cannot sample retain data")

    loss_fn = torch.nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

    clean_loss = 0.0
    poison_loss = 0.0
    model.train()
    for batch_forget in tqdm(forget_loader):
        model.zero_grad()
        optimizer.zero_grad()

        # Sample retain batch
        batch_retain = random.choice(batches_retain)
        input_ids_retain = batch_retain["input_ids"].to(device)
        attention_mask_retain = batch_retain["attention_mask"].to(device)
        labels_retain = batch_retain["label"].to(device)

        # Forget batch
        input_ids_forget = batch_forget["input_ids"].to(device)
        attention_mask_forget = batch_forget["attention_mask"].to(device)
        labels_forget = batch_forget["label"].to(device)

        logits_retain = model(input_ids=input_ids_retain, attention_mask=attention_mask_retain)
        logits_forget = model(input_ids=input_ids_forget, attention_mask=attention_mask_forget)

        loss_retain = loss_fn(logits_retain, labels_retain)  # keep the model's clean behaviour
        loss_forget = loss_fn(logits_forget, labels_forget)  # ascend on flagged samples
        loss = loss_retain - loss_forget

        loss.backward()
        clip_grad_norm_(model.parameters(), 1.0)  # clip the gradient norm to 1.0
        optimizer.step()

        clean_loss += loss_retain.item()
        poison_loss += loss_forget.item()

    return clean_loss / num_steps, poison_loss / num_steps


def npo_train(model, ref_model, retain_loader, forget_loader, optimizer=None, beta=1.0):
    """Run one epoch of retain-loss + NPO (negative preference optimisation) unlearning.

    Per forget batch, with per-sample cross-entropies ``ce`` (policy ``model``)
    and ``ce_ref`` (frozen ``ref_model``), the minimised loss is::

        CE_mean(retain) - (2 / beta) * mean(logsigmoid(beta * (ce - ce_ref)))

    A retain batch is drawn uniformly at random from ``retain_loader`` each step.

    Args:
        model: policy classifier being unlearned.
        ref_model: reference classifier; must be a separate object from
            ``model`` (it is put in eval mode once and never trained here).
        retain_loader: batches whose behaviour should be preserved.
        forget_loader: batches to unlearn.
        optimizer: optional optimizer to reuse across epochs; a fresh ``AdamW``
            with ``args.learning_rate`` is built when omitted.
        beta: NPO inverse temperature (previously hard-coded to 1).

    Returns:
        ``(mean retain loss, mean policy forget CE)`` over the forget batches,
        or ``(nan, nan)`` when ``forget_loader`` is empty.

    Raises:
        ValueError: if there are forget batches but ``retain_loader`` is empty.
    """
    num_steps = len(forget_loader)
    if num_steps == 0:
        # e.g. detection flagged no samples -- avoid ZeroDivisionError below
        return float("nan"), float("nan")

    batches_retain = list(retain_loader)
    if not batches_retain:
        raise ValueError("retain_loader yielded no batches; cannot sample retain data")

    if optimizer is None:
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

    loss_fn_base = torch.nn.CrossEntropyLoss(reduction="none")  # per-sample, for the log-ratio
    loss_fn_mean = torch.nn.CrossEntropyLoss(reduction="mean")

    clean_loss = 0.0
    poison_loss = 0.0
    model.train()
    ref_model.eval()  # frozen reference; nothing below switches it back to train mode
    for batch_forget in tqdm(forget_loader):
        model.zero_grad()
        optimizer.zero_grad()

        # Sample retain batch
        batch_retain = random.choice(batches_retain)
        input_ids_retain = batch_retain["input_ids"].to(device)
        attention_mask_retain = batch_retain["attention_mask"].to(device)
        labels_retain = batch_retain["label"].to(device)

        # Forget batch
        input_ids_forget = batch_forget["input_ids"].to(device)
        attention_mask_forget = batch_forget["attention_mask"].to(device)
        labels_forget = batch_forget["label"].to(device)

        # Retain loss
        logits_retain = model(input_ids=input_ids_retain, attention_mask=attention_mask_retain)
        loss_retain = loss_fn_mean(logits_retain, labels_retain)

        # Reference model: no gradients needed
        with torch.no_grad():
            logits_forget_ref = ref_model(input_ids=input_ids_forget, attention_mask=attention_mask_forget)
            loss_forget_ref = loss_fn_base(logits_forget_ref, labels_forget)

        # Policy model
        logits_forget = model(input_ids=input_ids_forget, attention_mask=attention_mask_forget)
        loss_forget = loss_fn_base(logits_forget, labels_forget)

        # CE difference equals -log(pi_theta / pi_ref) for the labelled class
        negative_log_ratio = loss_forget - loss_forget_ref
        npo_loss = -F.logsigmoid(beta * negative_log_ratio).mean() * 2 / beta

        loss = loss_retain + npo_loss  # RT + NPO
        loss.backward()
        clip_grad_norm_(model.parameters(), 1.0)  # clip the gradient norm to 1.0
        optimizer.step()

        clean_loss += loss_retain.item()
        poison_loss += loss_forget.mean().item()

    return clean_loss / num_steps, poison_loss / num_steps


def rga_train(model, model_base, logits_ref, retain_loader, forget_loader,
              optimizer=None, reg_coef=5e-2, scale_power=2.0):
    """Run one epoch of robustness gradient-ascent (RGA) unlearning.

    Per forget batch the minimised loss is::

        CE(retain) - w * CE(forget) + reg_coef * sum_i ||theta_i - theta_base_i||_2

    ``w = adaptive_weight(current forget logits, cached reference logits)`` is a
    constant (no gradient).  The penalty runs over paired parameters of
    ``model.bert`` and ``model_base``.

    Args:
        model: classifier with a ``.bert`` encoder, being unlearned.
        model_base: encoder whose weights anchor the regulariser; it is treated
            as a constant and receives no gradients.
        logits_ref: per-batch reference logits, aligned batch-for-batch with
            ``forget_loader`` (so the loader must iterate in the same order as
            when the logits were computed).
        retain_loader: batches whose behaviour should be preserved; one is drawn
            uniformly at random per step.
        forget_loader: batches to unlearn.
        optimizer: optional optimizer to reuse across epochs; a fresh ``AdamW``
            with ``args.learning_rate`` is built when omitted.
        reg_coef: weight of the parameter-distance penalty (previously fixed 5e-2).
        scale_power: exponent passed to ``adaptive_weight`` (previously fixed 2.0).

    Returns:
        ``(mean retain loss, mean forget loss)`` over the forget batches, or
        ``(nan, nan)`` when ``forget_loader`` is empty.

    Raises:
        ValueError: if ``logits_ref`` and ``forget_loader`` differ in length, or
            there are forget batches but ``retain_loader`` is empty.
    """
    num_steps = len(forget_loader)
    if len(logits_ref) != num_steps:
        # zip() would silently truncate and misalign batches
        raise ValueError(f"logits_ref has {len(logits_ref)} batches but forget_loader has {num_steps}")
    if num_steps == 0:
        # e.g. detection flagged no samples -- avoid ZeroDivisionError below
        return float("nan"), float("nan")

    batches_retain = list(retain_loader)
    if not batches_retain:
        raise ValueError("retain_loader yielded no batches; cannot sample retain data")

    if optimizer is None:
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    loss_fn = torch.nn.CrossEntropyLoss()

    # The anchor weights are constants: detaching keeps autograd from building a
    # graph into model_base and accumulating (never-used, never-zeroed) grads there.
    base_params = [p.detach() for p in model_base.parameters()]

    clean_loss = 0.0
    poison_loss = 0.0
    model.train()
    for batch_forget, logits_ref_batch in tqdm(zip(forget_loader, logits_ref), total=num_steps):
        model.zero_grad()
        optimizer.zero_grad()

        # Sample retain batch
        batch_retain = random.choice(batches_retain)
        input_ids_retain = batch_retain["input_ids"].to(device)
        attention_mask_retain = batch_retain["attention_mask"].to(device)
        labels_retain = batch_retain["label"].to(device)

        # Forget batch
        input_ids_forget = batch_forget["input_ids"].to(device)
        attention_mask_forget = batch_forget["attention_mask"].to(device)
        labels_forget = batch_forget["label"].to(device)

        logits_retain = model(input_ids=input_ids_retain, attention_mask=attention_mask_retain)
        logits_forget = model(input_ids=input_ids_forget, attention_mask=attention_mask_forget)

        # Adaptive re-weighting of the ascent term (a plain float, no gradient)
        with torch.no_grad():
            weight = adaptive_weight(logits_forget, logits_ref_batch.to(device), scale_power=scale_power).item()

        loss_retain = loss_fn(logits_retain, labels_retain)  # keep the model's clean behaviour
        loss_forget = loss_fn(logits_forget, labels_forget)  # ascend on flagged samples
        loss = loss_retain - weight * loss_forget

        # Sum of per-tensor L2 distances to the base encoder
        for param, base_param in zip(model.bert.parameters(), base_params):
            loss = loss + reg_coef * torch.norm(param - base_param)

        loss.backward()
        clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        clean_loss += loss_retain.item()
        poison_loss += loss_forget.item()

    return clean_loss / num_steps, poison_loss / num_steps


def compute_logits_ref(ref_model, forget_loader):
    """Cache ``ref_model``'s logits for every batch of ``forget_loader``.

    Runs without autograd and returns a list holding one CPU logits tensor per
    batch, in loader order.  It is later zipped batch-for-batch with the same
    loader, which must therefore iterate in the same order.
    """
    cached_logits = []

    with torch.no_grad():
        for batch in tqdm(forget_loader):
            ref_model.eval()
            batch_logits = ref_model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            cached_logits.append(batch_logits.detach().cpu())

    return cached_logits


def adaptive_weight(input_logits, ref_logits, scale_power=1.0):
    """Return ``exp(-KL) ** scale_power`` as a similarity weight between two logit batches.

    ``KL`` comes from ``F.kl_div`` with ``reduction="batchmean"``.  The target
    distribution is ``softmax(ref_logits)`` and the input is
    ``log_softmax(input_logits)``, both over dim 1.  Identical logits give a
    weight of 1, and the weight decays toward 0 as the two diverge.
    """
    log_p_input = torch.log_softmax(input_logits, dim=1)
    p_ref = torch.softmax(ref_logits, dim=1)
    divergence = F.kl_div(log_p_input, p_ref, reduction="batchmean")
    return torch.exp(-divergence).pow(scale_power)


def evaluate(model, dataloader, data_name):
    """Print accuracy, confusion matrix and per-class label-flip rates (LFR) of ``model``.

    LFR for class ``i`` is the fraction of samples with true label ``i`` that are
    predicted as any other class.  That is the off-diagonal mass of confusion
    matrix row ``i`` divided by the row total; a class with no samples prints
    ``nan``.

    Args:
        model: classifier called as ``model(input_ids=..., attention_mask=...)``.
        dataloader: yields dicts with ``"input_ids"``, ``"attention_mask"`` and
            ``"label"`` tensors.
        data_name: ``"SST-2"`` or ``"HSOL"`` (2 classes), or ``"AG"`` (4 classes).

    Returns:
        None; all results are printed.

    Raises:
        ValueError: for any other ``data_name`` (raised after accuracy and the
            confusion matrix have been printed).
    """
    all_predictions = []
    all_targets = []

    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            logits = model(input_ids=input_ids, attention_mask=attention_mask)
            all_predictions.extend(torch.argmax(logits, dim=1).cpu().numpy())
            all_targets.extend(batch["label"].numpy())

    num_labels = {"SST-2": 2, "HSOL": 2, "AG": 4}.get(data_name)
    # Pin the label set so row/column i is always class i.  Otherwise a class
    # missing from both targets and predictions shrinks the matrix, shifting
    # rows (wrong LFRs) or making the per-class indexing overrun.
    label_set = list(range(num_labels)) if num_labels is not None else None

    accuracy = accuracy_score(all_targets, all_predictions)
    cm = confusion_matrix(all_targets, all_predictions, labels=label_set)

    print(f"Accuracy on {data_name} data: {accuracy}")
    print(f"Confusion matrix on {data_name} data:\n {cm}")

    if num_labels is None:
        raise ValueError("Invalid data name")

    row_totals = cm.sum(axis=1)
    with np.errstate(divide="ignore", invalid="ignore"):
        lfr = (row_totals - np.diag(cm)) / row_totals  # 0/0 -> nan for empty classes

    for class_idx, rate in enumerate(lfr):
        print(f"LFR for class {class_idx}: {rate}")

    if data_name == "AG":
        # Class 0 is left out of the average -- presumably the attack's target
        # label. NOTE(review): confirm against the poisoning setup.
        print(f"Average LFR:{sum(lfr[1:]) / (len(lfr) - 1)}")

    return None


if __name__ == '__main__':
    args = get_arguments().parse_args()
    if torch.cuda.is_available():
        # Keep the preference for the second GPU, but fall back to cuda:0 so
        # single-GPU hosts do not fail with an invalid device ordinal.
        device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
    else:
        device = torch.device("cpu")

    for current_seed in args.seed:
        print("Phase 1: Training the backdoor model")
        set_seed(current_seed)

        # Dataset and Dataloader
        tokenizer = DistilBertTokenizer.from_pretrained(args.tokenizer)

        #####
        data_dir = f"Data_Poisoning/{args.attack_mode}/{args.dataset}"
        train_df_poisoned = pd.read_csv(f"{data_dir}/train_poisoned.csv")
        train_df_clean = pd.read_csv(f"{data_dir}/train_clean.csv")

        test_df_clean = pd.read_csv(f"{data_dir}/test_clean.csv")
        test_df_poisoned = pd.read_csv(f"{data_dir}/test_poisoned_all.csv")

        ######
        train_dataset_poisoned = TargetDataset(tokenizer=tokenizer, max_len=args.max_len, data=train_df_poisoned)
        train_dataset_clean = TargetDataset(tokenizer=tokenizer, max_len=args.max_len, data=train_df_clean)

        test_dataset_clean = TargetDataset(tokenizer=tokenizer, max_len=args.max_len, data=test_df_clean)
        test_dataset_poisoned = TargetDataset(tokenizer=tokenizer, max_len=args.max_len, data=test_df_poisoned)

        ######
        # NOTE(review): the training loaders use shuffle=False, so every epoch sees
        # batches in CSV row order; confirm this is intended.
        train_loader_poisoned = DataLoader(train_dataset_poisoned, batch_size=args.batch_size, shuffle=False)
        train_loader_clean = DataLoader(train_dataset_clean, batch_size=args.batch_size, shuffle=False)

        test_loader_clean = DataLoader(test_dataset_clean, batch_size=args.batch_size, shuffle=False)
        test_loader_poisoned = DataLoader(test_dataset_poisoned, batch_size=args.batch_size, shuffle=False)

        # Load the Bert model
        num_class = 4 if args.dataset == "AG" else 2
        target_model = DistilBertClassification(DistilBertModel.from_pretrained(args.victim_model), label_num=num_class).to(device)

        print("--------------- Poisoning Phase ---------------")
        for epoch in range(args.poisoning_epoch):
            print("\n-------------------")
            train_loss = train(target_model, train_loader_poisoned)
            print(f"Epoch {epoch + 1}/{args.poisoning_epoch} | Train loss {train_loss}")

        print("-------------------------------------------------")
        print("1. Clean Performance: Evaluation on Clean Test Set:")
        evaluate(target_model, test_loader_clean, args.dataset)

        print("2. Posioning Performance: Evaluation on Poisoned (ALL) Test Set:")
        evaluate(target_model, test_loader_poisoned, args.dataset)
        print("-------------------------------------------------")

        print("Phase 2: Poisoned Samples Detection")
        print("--------------- Dataset Statistics ---------------")
        clean_df = train_df_poisoned[train_df_poisoned["poisoned"] == 0].reset_index(drop=True)
        poisoned_df = train_df_poisoned[train_df_poisoned["poisoned"] == 1].reset_index(drop=True)

        print(f"Dataset: {args.dataset}")
        print(f"Clean Dataset Size: {len(clean_df)}")
        print(f"Poisoned Dataset Size: {len(poisoned_df)}")

        if args.dataset == "SST-2":
            print(f"Clean Positive Samples: {len(clean_df[clean_df['label'] == 1])}")
            print(f"Clean Negative Samples: {len(clean_df[clean_df['label'] == 0])}")

        elif args.dataset == "HSOL":
            print(f"Clean Non-toxic Samples: {len(clean_df[clean_df['label'] == 0])}")
            print(f"Clean Toxic Samples: {len(clean_df[clean_df['label'] == 1])}")

        elif args.dataset == "AG":
            print(f"Clean World Samples: {len(clean_df[clean_df['label'] == 0])}")
            print(f"Clean Sports Samples: {len(clean_df[clean_df['label'] == 1])}")
            print(f"Clean Business Samples: {len(clean_df[clean_df['label'] == 2])}")
            print(f"Clean Science Samples: {len(clean_df[clean_df['label'] == 3])}")

        else:
            raise ValueError("Invalid dataset")

        clean_loader_visualize = DataLoader(
            TargetDataset(tokenizer, args.max_len, clean_df), batch_size=args.batch_size, shuffle=False,
        )

        poisoned_loader_visualize = DataLoader(
            TargetDataset(tokenizer, args.max_len, poisoned_df), batch_size=args.batch_size, shuffle=False,
        )

        print("--------------- Dimension Reduction ---------------")
        poisoned_rep = rep_extract(target_model.bert, poisoned_loader_visualize)  # Extract poisoned rep
        clean_rep = rep_extract(target_model.bert, clean_loader_visualize)  # Extract clean rep
        # Row order: all poisoned samples first, then all clean samples
        overall_rep = np.concatenate((poisoned_rep, clean_rep), axis=0)

        reducer = umap.UMAP(n_components=4,
                            n_neighbors=args.umap_n_neighbors,
                            min_dist=args.umap_min_dist,
                            metric="cosine",
                            random_state=current_seed
                            )
        projected_rep = reducer.fit_transform(overall_rep)

        print("--------------- Clustering ---------------")
        clusterer = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
                                    min_samples=args.min_samples,
                                    metric="euclidean").fit(projected_rep)
        labels = clusterer.labels_

        # HDBSCAN labels are >= -1 (-1 = noise), so every label is counted directly
        cluster_sizes = Counter(labels)
        for cid, n in cluster_sizes.most_common():
            print(f"Cluster {cid} size: {n}")

        # The num_class largest non-noise clusters are treated as clean; every
        # other point (smaller clusters and noise) is flagged as poisoned.
        clean_cids = [cid for cid, _ in cluster_sizes.most_common() if cid != -1][:num_class]
        pred_clean_mask = np.isin(labels, clean_cids)
        pred_poison_mask = ~pred_clean_mask

        true_poison = np.concatenate([
            np.ones(len(poisoned_df), dtype=bool),
            np.zeros(len(clean_df), dtype=bool)
        ])

        prec, rec, f1, _ = precision_recall_fscore_support(true_poison, pred_poison_mask, average="binary")
        print(f" Precision: {prec:.4f} | Recall: {rec:.4f} | F1: {f1:.4f}")

        # print("--------------- Visualization ---------------")
        # poisoned_rep_2 = projected_rep[:len(poisoned_df)]
        # clean_rep_2 = projected_rep[len(poisoned_df):]
        #
        # plt.figure(figsize=(10, 8))
        # ax = plt.gca()
        # for s in ["top", "bottom", "left", "right"]:
        #     ax.spines[s].set_linewidth(4)
        #
        # plt.xticks(fontsize=30)
        # plt.yticks(fontsize=30)
        # plt.tick_params(axis='both', labelsize=30, width=4, length=10)
        #
        # plt.scatter(poisoned_rep_2[:, 0], poisoned_rep_2[:, 1], s=20, c=(0, 0.5, 0), label="Poisoned", alpha=0.5, marker="o")
        # plt.scatter(clean_rep_2[:, 0], clean_rep_2[:, 1], s=20, c=(0.5, 0.2, 0.8), label="Clean", alpha=0.5, marker="o")
        # plt.legend(frameon=False, fontsize=25, markerscale=3)
        # plt.tight_layout()
        # plt.savefig(f"Clustering/visualization_{args.dataset}_{args.attack_mode}_{current_seed}.png")
        # plt.show()
        # plt.close()

        # Split the predicted mask back into its poisoned/clean halves (same row
        # order as overall_rep) to rebuild the detected subsets.
        n_poisoned = len(poisoned_df)
        flagged_in_poisoned = pred_poison_mask[:n_poisoned]
        flagged_in_clean = pred_poison_mask[n_poisoned:]

        poisoned_df_detected = pd.concat([
            poisoned_df[flagged_in_poisoned],
            clean_df[flagged_in_clean]
        ]).reset_index(drop=True)

        clean_df_detected = pd.concat([
            poisoned_df[~flagged_in_poisoned],
            clean_df[~flagged_in_clean]
        ]).reset_index(drop=True)

        print(f"Detected Poisoned Samples: {len(poisoned_df_detected)}")
        print(f"Detected Clean Samples: {len(clean_df_detected)}")

        print("Phase 3: Unlearning Phase")
        # Use the detected poisoned samples to unlearn
        retain_loader = DataLoader(
            TargetDataset(tokenizer, args.max_len, clean_df_detected),
            batch_size=args.batch_size, shuffle=False
        )

        # Must stay unshuffled: RGA zips it with logits cached in the same order
        forget_loader = DataLoader(
            TargetDataset(tokenizer, args.max_len, poisoned_df_detected),
            batch_size=args.batch_size, shuffle=False
        )

        print(f"Unlearning Method:{args.unlearning_method}")
        if args.unlearning_method == "RT":
            # Retrain from scratch on the ground-truth clean data (oracle baseline)
            print("Defender: Retraining")
            initial_model = DistilBertClassification(DistilBertModel.from_pretrained(args.victim_model), label_num=num_class).to(device)
            for epoch in range(args.poisoning_epoch):
                print("\n-------------------")
                train_loss = train(initial_model, train_loader_clean)
                print(f"Epoch {epoch + 1}/{args.poisoning_epoch} | Train loss {train_loss}")

            print("-------------------------------------------------")
            print("---------------ReTrain Performance---------------")
            print("Clean Performance: Evaluation on Clean Test Set...")
            evaluate(initial_model, test_loader_clean, args.dataset)

            print("Poisoned Performance: Evaluation on Poisoned Test Set...")
            evaluate(initial_model, test_loader_poisoned, args.dataset)

        elif args.unlearning_method == "GA":
            # Gradient ascent on the detected poisoned samples
            print("Defender: Gradient Ascent")
            for epoch in range(args.unlearning_epoch):
                print("\n-------------------")
                clean_loss, poison_loss = ga_train(target_model, retain_loader, forget_loader)
                print(f"Epoch {epoch + 1}/{args.unlearning_epoch} | Clean loss {clean_loss} | Poison loss {poison_loss}")

                # Report every 10 epochs and always after the final epoch
                if (epoch + 1) % 10 == 0 or epoch + 1 == args.unlearning_epoch:
                    print("-------------------------------------------------")
                    print("-----------Gradient Ascent Performance-----------")
                    print(f"Unlearning Epoch: {epoch + 1}")
                    print("Clean Performance: Evaluation on Clean Test Set...")
                    evaluate(target_model, test_loader_clean, args.dataset)

                    print("Poisoned Performance: Evaluation on Poisoned Test Set...")
                    evaluate(target_model, test_loader_poisoned, args.dataset)

        elif args.unlearning_method == "NPO":
            # NPO against a frozen copy of the backdoored model
            print("Defender: NPO")
            ref_model = copy.deepcopy(target_model)
            for epoch in range(args.unlearning_epoch):
                print("\n-------------------")
                clean_loss, poison_loss = npo_train(target_model, ref_model, retain_loader, forget_loader)
                print(f"Epoch {epoch + 1}/{args.unlearning_epoch} | Clean loss {clean_loss} | Poison loss {poison_loss}")

                # Report every 10 epochs and always after the final epoch
                if (epoch + 1) % 10 == 0 or epoch + 1 == args.unlearning_epoch:
                    print("-------------------------------------------------")
                    print("------------------NPO Performance----------------")
                    print(f"Unlearning Epoch: {epoch + 1}")
                    print("Clean Performance: Evaluation on Clean Test Set...")
                    evaluate(target_model, test_loader_clean, args.dataset)

                    print("Poisoned Performance: Evaluation on Poisoned Test Set...")
                    evaluate(target_model, test_loader_poisoned, args.dataset)

        elif args.unlearning_method == "RGA":
            # RGA: logits are cached once from the backdoored model before unlearning
            print("Defender: RGA")
            logits_ref = compute_logits_ref(target_model, forget_loader)
            model_base = DistilBertModel.from_pretrained(args.victim_model).to(device)

            for epoch in range(args.unlearning_epoch):
                print("\n-------------------")
                clean_loss, poison_loss = rga_train(target_model, model_base, logits_ref, retain_loader, forget_loader)
                print(f"Epoch {epoch + 1}/{args.unlearning_epoch} | Clean loss {clean_loss} | Poison loss {poison_loss}")

                # Report every 10 epochs and always after the final epoch
                if (epoch + 1) % 10 == 0 or epoch + 1 == args.unlearning_epoch:
                    print("-------------------------------------------------")
                    print("------------------RGA Performance----------------")
                    print(f"Unlearning Epoch: {epoch + 1}")
                    print("Clean Performance: Evaluation on Clean Test Set...")
                    evaluate(target_model, test_loader_clean, args.dataset)

                    print("Poisoned Performance: Evaluation on Poisoned Test Set...")
                    evaluate(target_model, test_loader_poisoned, args.dataset)

        else:
            raise ValueError(f"Unknown unlearning method: {args.unlearning_method}")

        # Drop this seed's models before the next seed builds new ones.
        # gc/empty_cache cannot release GPU memory that is still referenced.
        # Assigning None (rather than del) also works for names that were never bound.
        target_model = initial_model = ref_model = model_base = logits_ref = None
        gc.collect()
        torch.cuda.empty_cache()














