import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt

from datasets.load_datasets import load_dataset
from utils import set_seed, setup_logger, setup_tensorboard, print_log_info, summarize_unlearning_results


class OurUnlearning:
    """Sensitivity-guided class unlearning by in-place parameter damping.

    For each class to forget, the method:

    1. estimates a per-parameter sensitivity on "forget" inputs and on
       "retain" inputs (a sample-weighted average of squared cross-entropy
       gradients), using real training data, synthetic noise optimized
       towards a class, or a mix of both (see ``sens_source``);
    2. bisects a threshold ``alpha`` so that the fraction of target-layer
       parameters with ``S_forget > alpha * S_retain`` falls inside
       ``[target_min, target_max]``;
    3. multiplies each selected parameter by
       ``min((alpha / lambda) * S_retain / S_forget, 1)``.

    Classes can be forgotten sequentially on the same model; indices of
    already-forgotten classes are kept in ``forgotten_classes`` and excluded
    from the retain data and from the "active" evaluation split.
    """

    # Number of synthetic samples generated for the forget class / per retain class.
    FORGET_NOISE_SAMPLES = 500
    RETAIN_NOISE_SAMPLES = 200

    def __init__(self, model, dataset_name, checkpoint_path=None, batch_size=128,
                 device=None, log_dir="logs/our", target_layers=None,
                 sens_source="noise", seed=42, alpha_range=None, lambda_value=10):
        """
        Initialize the unlearning framework

        Args:
            model: Pre-trained model instance
            dataset_name: Dataset name ('mnist', 'cifar10', 'cifar100', 'svhn')
            checkpoint_path: Model weight file path (a warning is printed and the
                model is used as-is when the file does not exist)
            batch_size: Batch size
            device: Computing device (None for automatic selection)
            log_dir: Log save directory
            target_layers: Target layer name list, auto-select based on model if None.
                A parameter is targeted when any entry is a substring of its name.
            sens_source: Sensitivity computation method ('noise', 'sample', 'hybrid').
                'noise': synthetic noise for forget and retain sensitivity;
                'hybrid': synthetic noise for forget, real data for retain;
                anything else: real data for both.
            seed: Random seed
            alpha_range: Alpha coverage range (min, max), auto-select based on dataset if None
            lambda_value: Default dampening strength used by ``apply_unlearning``
                when no explicit ``lambda_`` is passed (larger -> smaller scale factors)
        """
        set_seed(seed)

        self.device = device or torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        self.model = model.to(self.device)
        self.model_name = model.__class__.__name__
        self.checkpoint_path = checkpoint_path
        if checkpoint_path:
            if os.path.exists(checkpoint_path):
                self._load_checkpoint(checkpoint_path)
            else:
                # The logger does not exist yet; mirror _load_checkpoint's print-based reporting.
                print(f"Warning: checkpoint not found, using model weights as given: {checkpoint_path}")

        self.dataset_name = dataset_name
        self.batch_size = batch_size
        self.train_loader, self.test_loader, self.num_classes = self._load_dataset()

        self.sens_source = sens_source
        self.skip_noise = False

        self.forgotten_classes = set()

        if target_layers is None:
            self.target_layers = self._auto_select_target_layers()
        else:
            self.target_layers = target_layers

        if alpha_range is None:
            self.target_min, self.target_max = self._get_default_alpha_range()
        else:
            self.target_min, self.target_max = alpha_range

        self.log_dir = os.path.join(log_dir, dataset_name, self.model_name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.logger, self.log_file = setup_logger(
            f"our_{self.dataset_name}", self.log_dir)
        self.tb_writer = setup_tensorboard(self.log_dir)

        self.logger.info("Initializing OUR unlearning method")
        self.logger.info(f"Model: {self.model_name}")
        self.logger.info(f"Dataset: {self.dataset_name}")
        self.logger.info(f"Device: {self.device}")
        self.logger.info(f"Target layers: {self.target_layers}")
        self.logger.info(f"Sensitivity source: {self.sens_source}")
        if self.sens_source not in ("noise", "sample", "hybrid"):
            self.logger.warning(
                f"Unknown sensitivity source '{self.sens_source}'; real data will be used")
        self.logger.info(
            f"Alpha coverage range: {self.target_min:.4f}-{self.target_max:.4f}")

        self.classes = self._get_class_names()

        self.lambda_value = lambda_value

    def _load_checkpoint(self, checkpoint_path):
        """Load model weights from ``checkpoint_path`` into ``self.model``.

        Accepts a raw state dict or a dict wrapping it under
        ``'model_state_dict'`` or ``'state_dict'``. Re-raises any load error.
        """
        try:
            checkpoint = torch.load(
                checkpoint_path, map_location=self.device, weights_only=True)

            if isinstance(checkpoint, dict):
                if 'model_state_dict' in checkpoint:
                    self.model.load_state_dict(checkpoint['model_state_dict'])
                elif 'state_dict' in checkpoint:
                    self.model.load_state_dict(checkpoint['state_dict'])
                else:
                    self.model.load_state_dict(checkpoint)
            else:
                self.model.load_state_dict(checkpoint)

            print(f"Successfully loaded model weights: {checkpoint_path}")
        except Exception as e:
            print(f"Failed to load model weights: {e}")
            raise

    def _load_dataset(self):
        """Return ``(train_loader, test_loader, num_classes)`` from the project loader."""
        train_loader, test_loader, num_classes = load_dataset(
            self.dataset_name, batch_size=self.batch_size, num_workers=4
        )
        return train_loader, test_loader, num_classes

    def _get_class_names(self):
        """Return human-readable class names used in logs and TensorBoard tags."""
        if self.dataset_name in ('mnist', 'svhn'):
            return [str(i) for i in range(10)]
        if self.dataset_name == 'cifar10':
            return ['airplane', 'automobile', 'bird', 'cat', 'deer',
                    'dog', 'frog', 'horse', 'ship', 'truck']
        if self.dataset_name == 'cifar100':
            return [f"class_{i}" for i in range(100)]
        if self.dataset_name == 'tinyimagenet':
            return [f"class_{i}" for i in range(200)]
        return [f"class_{i}" for i in range(self.num_classes)]

    def _auto_select_target_layers(self):
        """Pick target layer names from the model class name and module names.

        Returns fixed lists for ResNet/LeNet-style models; otherwise the last
        ~30% of modules whose names contain 'layer', 'conv', 'fc' or
        'classifier' (or ``['fc', 'classifier']`` if none match).
        """
        model_name = self.model.__class__.__name__.lower()

        if 'resnet' in model_name:
            if hasattr(self.model, 'res1'):
                return ['res1', 'conv3', 'res2', 'conv4', 'classifier']
            return ['layer3', 'layer4', 'fc']
        if 'lenet' in model_name:
            return ['conv2', 'fc']

        keywords = ('layer', 'conv', 'fc', 'classifier')
        candidates = [name for name, _ in self.model.named_modules()
                      if any(key in name for key in keywords)]
        if candidates:
            # Keep only the deepest ~30% of matching modules.
            return candidates[int(len(candidates) * 0.7):]
        return ['fc', 'classifier']

    def _get_default_alpha_range(self):
        """Return the default (min, max) parameter-coverage range for this dataset/model."""
        model_name = self.model.__class__.__name__.lower()
        if self.dataset_name == 'mnist':
            return 0.0410, 0.0410
        if self.dataset_name == 'cifar10':
            return 0.05, 0.06
        if self.dataset_name == 'cifar100':
            return 0.01, 0.03
        if self.dataset_name == 'tinyimagenet':
            return 0.004, 0.005
        if self.dataset_name == 'svhn' and 'resnet9' in model_name:
            return 0.043, 0.043
        if self.dataset_name == 'svhn':
            return 0.043, 0.05
        return 0.043, 0.043

    @staticmethod
    def _dataset_labels(dataset):
        """Return the integer label of every sample of ``dataset`` in index order.

        When the dataset exposes a ``targets`` or ``labels`` attribute of the
        right length and has no ``target_transform``, that attribute is used
        directly, which avoids loading and transforming every image. Otherwise
        the dataset is iterated and the second element of each item is used.
        """
        if getattr(dataset, 'target_transform', None) is None:
            for attr in ('targets', 'labels'):
                stored = getattr(dataset, attr, None)
                if stored is None:
                    continue
                values = np.asarray(stored).reshape(-1).tolist()
                if len(values) == len(dataset):
                    return [int(v) for v in values]
        return [int(label) for _, label in dataset]

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader with this instance's batch size."""
        return DataLoader(dataset, batch_size=self.batch_size,
                          shuffle=shuffle, num_workers=4)

    def get_data_loaders(self, class_idx):
        """Create data loaders for specific class, excluding already forgotten classes.

        Returns:
            dict with loaders:
              'forget'         - training samples of ``class_idx`` (shuffled)
              'retain'         - training samples of all other non-forgotten classes (shuffled)
              'class_test'     - test samples of ``class_idx``
              'non_class_test' - test samples of every other class (forgotten ones included)
              'active_test'    - test samples of other, not-yet-forgotten classes
        """
        train_dataset = self.train_loader.dataset
        forget_indices = []
        retain_indices = []
        for i, label in enumerate(self._dataset_labels(train_dataset)):
            if label == class_idx:
                forget_indices.append(i)
            elif label not in self.forgotten_classes:
                retain_indices.append(i)

        test_dataset = self.test_loader.dataset
        class_test_indices = []
        non_class_test_indices = []
        active_classes_test_indices = []
        for i, label in enumerate(self._dataset_labels(test_dataset)):
            if label == class_idx:
                class_test_indices.append(i)
                continue
            non_class_test_indices.append(i)
            if label not in self.forgotten_classes:
                active_classes_test_indices.append(i)

        return {
            'forget': self._make_loader(Subset(train_dataset, forget_indices), shuffle=True),
            'retain': self._make_loader(Subset(train_dataset, retain_indices), shuffle=True),
            'class_test': self._make_loader(Subset(test_dataset, class_test_indices), shuffle=False),
            'non_class_test': self._make_loader(Subset(test_dataset, non_class_test_indices), shuffle=False),
            'active_test': self._make_loader(Subset(test_dataset, active_classes_test_indices), shuffle=False),
        }

    def evaluate(self, data_loader, dataset_name=""):
        """Evaluate model performance on given dataset.

        Args:
            data_loader: Iterable of ``(inputs, targets)`` batches.
            dataset_name: If non-empty, a summary line is logged with this label.

        Returns:
            ``{'accuracy': percent, 'loss': mean per-sample cross-entropy}``.
            Both values are NaN when the loader yields no samples.
        """
        self.model.eval()
        correct = 0
        total = 0
        loss_sum = 0.0
        # Sum per-sample losses so the average is not skewed by a short last batch.
        criterion = nn.CrossEntropyLoss(reduction='sum')

        with torch.no_grad():
            for inputs, targets in data_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss_sum += criterion(outputs, targets).item()

                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        if total == 0:
            acc = avg_loss = float('nan')
            self.logger.warning(f"Evaluation set '{dataset_name}' is empty; reporting NaN")
        else:
            acc = 100.0 * correct / total
            avg_loss = loss_sum / total

        if dataset_name:
            self.logger.info(
                f"{dataset_name} - Accuracy: {acc:.2f}%, Loss: {avg_loss:.4f}")

        return {
            'accuracy': acc,
            'loss': avg_loss
        }

    def get_active_classes(self):
        """Get list of currently non-forgotten classes"""
        return [c for c in range(self.num_classes) if c not in self.forgotten_classes]

    def _infer_input_shape(self):
        """Return the (channels, height, width) of one test input.

        Reads the first test sample so synthetic noise matches the size the
        model is actually fed. Falls back to 1x32x32 for MNIST and 3x32x32
        otherwise when the sample is not a 3-D tensor or cannot be read.
        """
        try:
            sample = self.test_loader.dataset[0][0]
        except (IndexError, KeyError, TypeError, AttributeError):
            sample = None
        if isinstance(sample, torch.Tensor) and sample.dim() == 3:
            return tuple(sample.shape)
        if self.dataset_name == 'mnist':
            return 1, 32, 32
        return 3, 32, 32

    def generate_synthetic_noise(self, num_samples, target_label, epochs=200, save_prefix="",
                                 visualize=False):
        """Generate synthetic noise samples for parameter sensitivity computation.

        Starting from small Gaussian noise, the inputs are optimized with Adam so
        the (frozen, eval-mode) model classifies them as ``target_label``, plus a
        diversity term that keeps samples apart. Values are kept in [-1, 1].

        Args:
            num_samples: Number of samples to synthesize.
            target_label: Class index the samples are optimized towards.
            epochs: Number of optimization steps.
            save_prefix: Prefix for progress log lines.
            visualize: If True, save a preview image of the samples to ``log_dir``.

        Returns:
            Detached tensor of shape ``(num_samples, C, H, W)`` on ``self.device``.
        """
        self.logger.info(f"Starting synthetic noise generation for class {target_label} (samples: {num_samples})")
        start_time = time.time()

        channels, height, width = self._infer_input_shape()

        noise_samples = torch.randn(
            num_samples, channels, height, width, device=self.device) * 0.1
        noise_samples = torch.clamp(noise_samples, -1.0, 1.0)
        noise_samples.requires_grad_(True)

        optimizer = optim.Adam([noise_samples], lr=0.005)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs, eta_min=0.00005)
        target_labels = torch.full(
            (num_samples,), target_label, dtype=torch.long, device=self.device)
        criterion = nn.CrossEntropyLoss()

        self.model.eval()
        for epoch in range(epochs):
            optimizer.zero_grad()

            outputs = self.model(noise_samples)
            adv_loss = criterion(outputs, target_labels)
            div_loss = self._diversity_regularization(noise_samples)
            loss = adv_loss + div_loss

            # Only the noise is optimized; don't accumulate gradients into model weights.
            loss.backward(inputs=[noise_samples])
            optimizer.step()
            scheduler.step()

            # Clamp after the step so every evaluated AND the returned sample stays in range.
            with torch.no_grad():
                noise_samples.clamp_(-1.0, 1.0)

            if (epoch + 1) % 50 == 0:
                preds = outputs.argmax(dim=1)
                acc = (preds == target_labels).float().mean() * 100.0
                self.logger.info(
                    f"{save_prefix} Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}, Accuracy: {acc:.2f}%")

        result = noise_samples.detach()
        if visualize:
            self._visualize_noise(result, target_label)

        elapsed = time.time() - start_time
        self.logger.info(f"Synthetic noise generation completed (time: {elapsed:.2f}s)")
        return result

    def _total_variation_loss(self, noise, weight=0.1):
        """Total-variation penalty (squared neighbour differences), normalized by N*H*W."""
        h_tv = torch.pow(noise[:, :, 1:, :] - noise[:, :, :-1, :], 2).sum()
        w_tv = torch.pow(noise[:, :, :, 1:] - noise[:, :, :, :-1], 2).sum()
        return weight * (h_tv + w_tv) / (noise.size(0) * noise.size(2) * noise.size(3))

    def _smoothness_regularization(self, noise, weight=0.05):
        """Smoothness penalty: mean absolute difference between adjacent pixels."""
        dx = torch.mean(torch.abs(noise[:, :, :, 1:] - noise[:, :, :, :-1]))
        dy = torch.mean(torch.abs(noise[:, :, 1:, :] - noise[:, :, :-1, :]))
        return weight * (dx + dy)

    def _diversity_regularization(self, noise, weight=0.1):
        """Penalty inversely proportional to the mean pairwise L2 distance of samples.

        Returns a zero tensor when fewer than two samples are given.
        """
        noise_flat = noise.view(noise.size(0), -1)
        pairwise_dist = torch.cdist(noise_flat, noise_flat)
        # Strict upper triangle: each unordered pair once, diagonal (self-distance) excluded.
        mask = torch.triu(torch.ones_like(pairwise_dist), diagonal=1).bool()
        dist = pairwise_dist[mask]
        if dist.numel() == 0:
            return torch.tensor(0.0, device=self.device)
        return weight * (1.0 / (dist.mean() + 1e-6))

    def _feature_consistency_loss(self, model, noise, target_label):
        """Negative mean softmax confidence of ``model`` for ``target_label`` on ``noise``.

        NOTE: computed under ``torch.no_grad()``, so the result carries no
        gradient and can only serve as a metric, not as a training loss.
        """
        with torch.no_grad():
            logits = model(noise)
            probs = torch.softmax(logits, dim=1)
            confidence = probs[:, target_label].mean()
        return -confidence

    def _visualize_noise(self, noise_samples, target_label, num_samples=10):
        """Save a grid preview of up to ``num_samples`` noise samples (best effort).

        Failures are logged as warnings and never raised.
        """
        fig = None
        try:
            samples = noise_samples[:num_samples].detach().cpu()
            count = len(samples)
            if count == 0:
                self.logger.warning("Noise sample visualization failed: no samples")
                return

            cols = min(5, count)
            rows = (count + cols - 1) // cols
            fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
            axes = axes.flatten()
            for ax in axes:
                ax.axis('off')

            for ax, sample in zip(axes, samples):
                img = sample.permute(1, 2, 0)
                if img.size(2) == 1:
                    ax.imshow(img.squeeze(2), cmap='gray')
                else:
                    # Map [-1, 1] into the [0, 1] range imshow expects for RGB floats.
                    ax.imshow((img + 1) / 2)
                ax.set_title(f"Class {target_label}")

            save_path = os.path.join(
                self.log_dir, f"noise_class_{target_label}.png")
            fig.tight_layout()
            fig.savefig(save_path)

            self.logger.info(f"Noise sample visualization saved to: {save_path}")
        except Exception as e:
            self.logger.warning(f"Noise sample visualization failed: {e}")
        finally:
            if fig is not None:
                plt.close(fig)

    def compute_param_sensitivity(self, data_source, labels=None, is_synthetic=False):
        """Compute per-parameter sensitivity for parameters of the target layers.

        Sensitivity is the sample-weighted average, over batches, of the squared
        gradient of the batch-mean cross-entropy loss.

        Args:
            data_source: A tensor of inputs when ``is_synthetic`` is True,
                otherwise an iterable of ``(inputs, labels)`` batches.
            labels: Label tensor aligned with ``data_source`` (synthetic mode only).
            is_synthetic: Selects between the two ``data_source`` forms.

        Returns:
            dict mapping parameter name -> tensor shaped like the parameter.
            All zeros (with a warning) when ``data_source`` holds no samples.
        """
        start_time = time.time()
        param_sens_dict = {name: torch.zeros_like(p)
                           for name, p in self.model.named_parameters()
                           if any(layer in name for layer in self.target_layers)}

        self.model.eval()
        total_samples = 0

        if is_synthetic:
            batch_size = 100
            for start in range(0, data_source.size(0), batch_size):
                batch_data = data_source[start:start + batch_size]
                batch_labels = labels[start:start + batch_size]
                self._update_param_sensitivity(
                    param_sens_dict, batch_data, batch_labels)
                total_samples += batch_data.size(0)
        else:
            for batch_inputs, batch_labels in data_source:
                self._update_param_sensitivity(param_sens_dict, batch_inputs, batch_labels)
                total_samples += batch_inputs.size(0)

        if total_samples == 0:
            # Dividing by zero would silently produce NaN sensitivities.
            self.logger.warning(
                "No samples available for parameter sensitivity computation; returning zero sensitivities")
        else:
            for name in param_sens_dict:
                param_sens_dict[name] /= total_samples

        # Gradients were only needed to accumulate sensitivities; free them.
        self.model.zero_grad(set_to_none=True)

        elapsed = time.time() - start_time
        self.logger.info(f"Parameter sensitivity computation completed (time: {elapsed:.2f}s)")
        return param_sens_dict

    def _update_param_sensitivity(self, param_sens_dict, inputs, labels):
        """Add ``grad**2 * batch_size`` of one batch to each tracked parameter's entry."""
        self.model.zero_grad()
        inputs, labels = inputs.to(self.device), labels.to(self.device)
        loss = nn.CrossEntropyLoss()(self.model(inputs), labels)
        loss.backward()

        batch_size = inputs.size(0)
        for name, param in self.model.named_parameters():
            if param.grad is not None and name in param_sens_dict:
                param_sens_dict[name] += param.grad.detach().pow(2) * batch_size

    def apply_unlearning(self, param_sens_forget, param_sens_retain, lambda_=None):
        """Dampen parameters that matter more for the forget set than the retain set.

        Parameters with ``S_forget > alpha * S_retain`` are multiplied in place by
        ``min((alpha / lambda_) * S_retain / (S_forget + 1e-10), 1)``; all others
        are left unchanged. ``alpha`` comes from ``_adaptive_alpha``.

        Args:
            param_sens_forget: Sensitivities on forget data (name -> tensor).
            param_sens_retain: Sensitivities on retain data (name -> tensor).
            lambda_: Dampening strength; ``self.lambda_value`` when None. Must be non-zero.
        """
        start_time = time.time()

        alpha = self._adaptive_alpha(param_sens_forget, param_sens_retain)

        if lambda_ is None:
            lambda_ = self.lambda_value

        self.logger.info(f"Adaptive alpha: {alpha:.6f}, lambda: {lambda_:.6f}")

        total_params = 0
        modified_params = 0

        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name not in param_sens_forget or name not in param_sens_retain:
                    continue
                S_forget = param_sens_forget[name]
                S_retain = param_sens_retain[name]

                mask = S_forget > alpha * S_retain
                ones = torch.ones_like(S_forget)
                damped = torch.minimum((alpha / lambda_) * S_retain / (S_forget + 1e-10), ones)
                scale = torch.where(mask, damped, ones)

                param_count = param.numel()
                modified_count = mask.sum().item()
                total_params += param_count
                modified_params += modified_count

                param.mul_(scale)

                self.logger.info(
                    f"Layer {name}: Modified {modified_count}/{param_count} parameters "
                    f"({100.0 * modified_count / param_count:.2f}%)")

        elapsed = time.time() - start_time
        modification_ratio = (100.0 * modified_params / total_params) if total_params > 0 else 0
        self.logger.info(
            f"Total parameter modifications: {modified_params}/{total_params} ({modification_ratio:.2f}%) [time: {elapsed:.2f}s]")

    def _adaptive_alpha(self, param_sens_forget, param_sens_retain, max_iter=30):
        """Bisect ``alpha`` so the selected-parameter fraction lands in the target range.

        Coverage is the fraction of parameters (over layers present in both
        dicts) with ``S_forget > alpha * S_retain``; it does not increase as
        alpha grows. Bisection runs on ``[1e-7, 1e4]`` starting at the midpoint.

        Returns:
            The first alpha whose coverage is in ``[target_min, target_max]``,
            otherwise the last alpha tried (1.0 if ``max_iter <= 0`` or no layer
            is shared between the two dicts).
        """
        shared_names = [name for name in param_sens_forget if name in param_sens_retain]
        if not shared_names:
            self.logger.warning("Cannot compute alpha due to no overlapping target layers. Returning default value 1.0")
            return 1.0

        def get_coverage(alpha_val):
            total_params = 0
            modified_params = 0
            for name in shared_names:
                S_forget = param_sens_forget[name]
                mask = S_forget > alpha_val * param_sens_retain[name]
                total_params += S_forget.numel()
                modified_params += mask.sum().item()
            if total_params == 0:
                return 0.0
            return modified_params / total_params

        alpha_low, alpha_high = 1e-7, 1e4
        alpha = 1.0
        # Defined up front so the final log line works even if the loop never runs.
        coverage = float('nan')

        for i in range(max_iter):
            alpha = (alpha_low + alpha_high) / 2
            coverage = get_coverage(alpha)

            self.logger.debug(
                f"Binary search iteration {i+1}: alpha={alpha:.6f}, coverage={coverage:.6f}, "
                f"target range=[{self.target_min:.4f}, {self.target_max:.4f}]")

            if self.target_min <= coverage <= self.target_max:
                self.logger.info(
                    f"Found suitable alpha value: {alpha:.6f}, coverage: {coverage:.6f}")
                return alpha

            # Too many parameters selected -> raise the threshold, and vice versa.
            if coverage > self.target_max:
                alpha_low = alpha
            else:
                alpha_high = alpha

            if abs(alpha_high - alpha_low) < 1e-8:
                self.logger.warning(f"Alpha search converged but target coverage not reached. "
                                    f"Final alpha: {alpha:.6f}, coverage: {coverage:.6f}")
                break

        self.logger.info(f"Adaptive alpha search completed. "
                         f"Final alpha: {alpha:.6f}, coverage: {coverage:.6f}")
        return alpha

    def save_model(self, class_idx):
        """Save the current state dict under ``checkpoints/forget/our`` and return its path.

        The filename records previously forgotten classes, then ``class_idx``.
        """
        if not self.forgotten_classes:
            save_suffix = f"forget_{class_idx}"
        else:
            prior_forgotten = "_".join(map(str, sorted(self.forgotten_classes)))
            save_suffix = f"forget_{prior_forgotten}_then_{class_idx}"

        save_dir = 'checkpoints/forget/our'
        os.makedirs(save_dir, exist_ok=True)

        filename = f"{self.model_name}_{self.dataset_name}_{save_suffix}.pth"
        save_path = os.path.join(save_dir, filename)

        torch.save(self.model.state_dict(), save_path)
        self.logger.info(f"Model after unlearning saved to: {save_path}")

        return save_path

    def _evaluate_splits(self, loaders, class_name, stage):
        """Evaluate the forget-class, other-class and active-class test splits.

        ``stage`` prefixes every log label (e.g. "Initial", "After unlearning").
        """
        return {
            'class_test': self.evaluate(loaders['class_test'], f"{stage} - Test set {class_name} class"),
            'non_class_test': self.evaluate(loaders['non_class_test'], f"{stage} - Test set non-{class_name} class"),
            'active_test': self.evaluate(loaders['active_test'], f"{stage} - Test set unforgotten classes"),
        }

    def _generate_retain_noise(self, class_idx, active_classes):
        """Synthesize noise for every active class other than ``class_idx``.

        Returns:
            ``(samples, labels)`` concatenated over classes, or None when there
            is no other active class.
        """
        num_samples = self.RETAIN_NOISE_SAMPLES
        all_noise_samples = []
        all_noise_labels = []
        for other_class in active_classes:
            if other_class == class_idx:
                continue
            self.logger.info(f"Generating noise samples for retain class {other_class}...")
            all_noise_samples.append(self.generate_synthetic_noise(
                num_samples, other_class, save_prefix=f"noise_retain_class_{other_class}"))
            all_noise_labels.append(torch.full(
                (num_samples,), other_class, dtype=torch.long, device=self.device))

        if not all_noise_samples:
            return None
        return torch.cat(all_noise_samples, dim=0), torch.cat(all_noise_labels, dim=0)

    def unlearn_class(self, class_idx, save=True):
        """Unlearn specific class.

        Evaluates the model, computes forget/retain sensitivities according to
        ``sens_source``, dampens parameters, re-evaluates, logs to TensorBoard,
        optionally saves the model and finally marks ``class_idx`` as forgotten.

        Args:
            class_idx: Index of the class to forget.
            save: Whether to write the resulting state dict to disk.

        Returns:
            Result dict with 'class_name', 'class_idx', 'initial', 'final',
            'forget_acc_change', 'retain_acc_change' and 'saved_model_path'
            (None when ``save`` is False); or a dict with 'skipped': True and a
            'reason' when retain noise cannot be generated.
        """
        class_name = self.classes[class_idx]
        self.logger.info(f"Starting to unlearn class {class_name} (index {class_idx})")

        forgotten_count = len(self.forgotten_classes)
        active_classes = self.get_active_classes()

        self.logger.info(f"Currently forgotten classes: {list(sorted(self.forgotten_classes))}")
        self.logger.info(f"Currently active classes: {active_classes}")

        loaders = self.get_data_loaders(class_idx)

        self.logger.info("Initial model evaluation...")
        initial_results = self._evaluate_splits(loaders, class_name, "Initial")

        # 'noise' and 'hybrid' both use synthetic noise for the forget class.
        if self.sens_source in ("noise", "hybrid") and not self.skip_noise:
            self.logger.info("Generating synthetic noise samples...")
            num_forget = self.FORGET_NOISE_SAMPLES
            noise_samples = self.generate_synthetic_noise(
                num_forget, class_idx, save_prefix=f"noise_{class_name}")
            noise_labels = torch.full((num_forget,), class_idx, dtype=torch.long, device=self.device)

            self.logger.info("Computing parameter sensitivity for forget class...")
            param_sens_forget = self.compute_param_sensitivity(
                noise_samples, noise_labels, is_synthetic=True)
        else:
            self.logger.info("Computing parameter sensitivity for forget class...")
            param_sens_forget = self.compute_param_sensitivity(
                loaders['forget'], is_synthetic=False)

        # Only pure 'noise' mode also synthesizes the retain set.
        self.logger.info("Computing parameter sensitivity for retain classes...")
        if self.sens_source == "noise" and not self.skip_noise:
            retain_noise = self._generate_retain_noise(class_idx, active_classes)
            if retain_noise is None:
                self.logger.warning("No remaining classes for generating retain noise samples!")
                return {
                    'class_name': class_name,
                    'class_idx': class_idx,
                    'skipped': True,
                    'reason': "No remaining classes for generating retain noise samples"
                }

            combined_noise_samples, combined_noise_labels = retain_noise
            self.logger.info(
                f"Using combined {len(combined_noise_labels)} retain class noise samples for sensitivity computation")
            param_sens_retain = self.compute_param_sensitivity(
                combined_noise_samples, combined_noise_labels, is_synthetic=True)
        else:
            param_sens_retain = self.compute_param_sensitivity(
                loaders['retain'], is_synthetic=False)

        self.logger.info("Applying unlearning...")
        self.apply_unlearning(param_sens_forget, param_sens_retain)

        self.logger.info("Model evaluation after unlearning...")
        final_results = self._evaluate_splits(loaders, class_name, "After unlearning")

        forget_acc_change = final_results['class_test']['accuracy'] - \
            initial_results['class_test']['accuracy']
        retain_acc_change = final_results['active_test']['accuracy'] - \
            initial_results['active_test']['accuracy']

        self.logger.info(
            f"Forget class accuracy change: {forget_acc_change:.2f}% ({initial_results['class_test']['accuracy']:.2f}% → {final_results['class_test']['accuracy']:.2f}%)")
        self.logger.info(
            f"Retain class accuracy change: {retain_acc_change:.2f}% ({initial_results['active_test']['accuracy']:.2f}% → {final_results['active_test']['accuracy']:.2f}%)")

        # The number of classes forgotten before this one serves as the TensorBoard step.
        self.tb_writer.add_scalar(
            f'Accuracy/{class_name}/Forget', final_results['class_test']['accuracy'], forgotten_count)
        self.tb_writer.add_scalar(
            f'Accuracy/{class_name}/Retain', final_results['active_test']['accuracy'], forgotten_count)
        self.tb_writer.add_scalar(
            f'AccChange/{class_name}/Forget', forget_acc_change, forgotten_count)
        self.tb_writer.add_scalar(
            f'AccChange/{class_name}/Retain', retain_acc_change, forgotten_count)

        # Saved before updating forgotten_classes so the filename lists only prior classes.
        saved_path = self.save_model(class_idx) if save else None

        self.forgotten_classes.add(class_idx)
        self.logger.info(f"Class {class_name} (index {class_idx}) added to forgotten classes set")
        self.logger.info(f"Currently forgotten classes: {list(sorted(self.forgotten_classes))}")

        return {
            'class_name': class_name,
            'class_idx': class_idx,
            'initial': initial_results,
            'final': final_results,
            'forget_acc_change': forget_acc_change,
            'retain_acc_change': retain_acc_change,
            'saved_model_path': saved_path
        }

    def unlearn_classes(self, class_indices=None, save_models=True):
        """Unlearn multiple classes sequentially on the current model.

        Args:
            class_indices: Class indices to forget in order; all classes if None.
            save_models: Whether each intermediate model is written to disk.

        Returns:
            List of result dicts from ``unlearn_class`` (skipped classes omitted).
            Stops early when at most one active class remains.
        """
        if class_indices is None:
            class_indices = list(range(self.num_classes))
            self.logger.info(f"Will unlearn all {len(class_indices)} classes")
        else:
            class_names = [self.classes[idx] for idx in class_indices]
            self.logger.info(
                f"Will sequentially unlearn {len(class_indices)} specified classes: {class_names}")

        results = []
        saved_models = []

        # The model itself is NOT reset; only the bookkeeping of forgotten classes.
        self.forgotten_classes = set()

        for i, class_idx in enumerate(class_indices):
            self.logger.info(f"========== Starting to unlearn {i+1}/{len(class_indices)} class: {self.classes[class_idx]} ==========")

            if class_idx in self.forgotten_classes:
                self.logger.warning(f"Class {class_idx} has already been forgotten, skipping")
                continue

            active_classes = self.get_active_classes()
            if len(active_classes) <= 1:
                self.logger.warning(f"Only one active class remaining {active_classes}, cannot continue unlearning")
                break

            # Sequential unlearning: each step starts from the previous step's weights.
            result = self.unlearn_class(class_idx, save=save_models)

            if result.get('skipped', False):
                self.logger.warning(f"Unlearning class {class_idx} was skipped: {result.get('reason', 'Unknown reason')}")
                continue

            results.append(result)

            if save_models and result.get('saved_model_path'):
                saved_models.append(result['saved_model_path'])

            self.logger.info("-" * 50)

        self._summarize_results(results)

        if saved_models:
            self.logger.info(f"Total {len(saved_models)} models saved after unlearning:")
            for path in saved_models:
                self.logger.info(f"  - {path}")

        return results

    def _summarize_results(self, results):
        """Log the summary, write average accuracy changes to TensorBoard and close it."""
        summarize_unlearning_results(results, self.logger)

        if results:
            avg_forget_change = np.mean([r['forget_acc_change'] for r in results])
            avg_retain_change = np.mean([r['retain_acc_change'] for r in results])
            self.tb_writer.add_scalar(
                'Summary/AvgForgetChange', avg_forget_change, 0)
            self.tb_writer.add_scalar(
                'Summary/AvgRetainChange', avg_retain_change, 0)

        self.tb_writer.close()
        print_log_info(self.log_file, self.log_dir)