import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt

from datasets.load_datasets import load_dataset
from utils import set_seed, setup_logger, setup_tensorboard, print_log_info, summarize_unlearning_results


class PerturbationUnlearning:
    def __init__(self, model, dataset_name, checkpoint_path=None, batch_size=128,
                 device=None, log_dir="logs/ablation/perturbation", target_layers=None,
                 seed=42, k_percent=5.0, noise_data_dir=None):
        """
        Initialize the unlearning framework for ablation study.

        Args:
            model: Pre-trained model instance
            dataset_name: Dataset name ('mnist', 'cifar10', 'cifar100', 'svhn')
            checkpoint_path: Model weight file path. When given and the file
                exists, its weights are loaded immediately; when given but
                missing, a warning is logged and the model keeps its
                current weights.
            batch_size: Batch size
            device: Computing device (None for automatic selection)
            log_dir: Log save directory
            target_layers: Target layer name list, auto-select based on model if None
            seed: Random seed
            k_percent: Percentage of top sensitive parameters to prune.
            noise_data_dir: Directory to store or load pre-generated noise.
                Defaults to "<log_dir>/<dataset_name>/<model class>/pregenerated_noise".
        """
        set_seed(seed)

        self.device = device or torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        self.model = model.to(self.device)
        self.model_name = model.__class__.__name__
        self.checkpoint_path = checkpoint_path
        # The logger is created further down, so remember a configured-but-
        # missing checkpoint here and report it once logging is available.
        checkpoint_missing = bool(checkpoint_path) and not os.path.exists(
            checkpoint_path)
        if checkpoint_path and not checkpoint_missing:
            self._load_checkpoint(checkpoint_path)

        self.dataset_name = dataset_name
        self.batch_size = batch_size
        self.train_loader, self.test_loader, self.num_classes = self._load_dataset()

        self.sens_source = "noise"  # Fixed for this ablation study
        self.k_percent = k_percent

        if target_layers is None:
            self.target_layers = self._auto_select_target_layers()
        else:
            self.target_layers = target_layers

        # Logs are grouped per dataset and per model class.
        self.log_dir = os.path.join(log_dir, dataset_name, self.model_name)
        os.makedirs(self.log_dir, exist_ok=True)

        if noise_data_dir:
            self.noise_data_dir = noise_data_dir
        else:
            self.noise_data_dir = os.path.join(self.log_dir, "pregenerated_noise")
        os.makedirs(self.noise_data_dir, exist_ok=True)

        self.logger, self.log_file = setup_logger(
            f"ablation_perturbation_{self.dataset_name}", self.log_dir)
        self.tb_writer = setup_tensorboard(self.log_dir)

        self.logger.info(
            "Initializing Perturbation Ablation unlearning method")
        self.logger.info(f"Model: {self.model_name}")
        self.logger.info(f"Dataset: {self.dataset_name}")
        self.logger.info(f"Device: {self.device}")
        self.logger.info(f"Target layers: {self.target_layers}")
        self.logger.info(f"Top-k pruning percentage: {self.k_percent}%")
        self.logger.info(f"Noise data directory: {self.noise_data_dir}")
        if checkpoint_missing:
            # Methods that reload the checkpoint (e.g. unlearn_class) will
            # fail later; surface the problem as early as possible.
            self.logger.warning(
                f"Checkpoint not found: {checkpoint_path}. "
                "Continuing with the model's current weights.")

        self.classes = self._get_class_names()

    def _load_checkpoint(self, checkpoint_path):
        """Load model weights from ``checkpoint_path`` into ``self.model``.

        The file is read with ``weights_only=True`` and mapped onto
        ``self.device``. Three layouts are accepted: a dict holding the
        weights under ``'model_state_dict'``, a dict holding them under
        ``'state_dict'``, or the state dict itself.

        Args:
            checkpoint_path: Path of the checkpoint file to load.

        Raises:
            ValueError: If ``checkpoint_path`` is empty or None.
        """
        if not checkpoint_path:
            # Fail with a clear message instead of an obscure error from
            # torch.load when no checkpoint was configured.
            raise ValueError(
                "checkpoint_path is not set; cannot load model weights")

        checkpoint = torch.load(
            checkpoint_path, map_location=self.device, weights_only=True)

        # Unwrap the weights when they are nested under a known key;
        # otherwise the loaded object is treated as the state dict.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for key in ('model_state_dict', 'state_dict'):
                if key in checkpoint:
                    state_dict = checkpoint[key]
                    break

        self.model.load_state_dict(state_dict)

    def _load_dataset(self):
        """Fetch the train loader, test loader and class count for the dataset.

        Delegates to ``load_dataset`` with this instance's dataset name and
        batch size, using 4 worker processes.

        Returns:
            Tuple ``(train_loader, test_loader, num_classes)``.
        """
        train, test, n_classes = load_dataset(
            self.dataset_name,
            batch_size=self.batch_size,
            num_workers=4,
        )
        return train, test, n_classes

    def _get_class_names(self):
        """Return the list of human-readable class names for the dataset.

        'mnist' and 'svhn' use the digit strings "0".."9", 'cifar10' uses its
        ten object names, and 'cifar100' uses "class_0".."class_99". Any
        other dataset falls back to "class_<i>" for ``self.num_classes``
        classes.
        """
        dataset = self.dataset_name

        if dataset in ('mnist', 'svhn'):
            return [str(digit) for digit in range(10)]

        if dataset == 'cifar10':
            return ['airplane', 'automobile', 'bird', 'cat', 'deer',
                    'dog', 'frog', 'horse', 'ship', 'truck']

        # Generic placeholder names; self.num_classes is only consulted for
        # datasets that are not listed above.
        count = 100 if dataset == 'cifar100' else self.num_classes
        return [f"class_{k}" for k in range(count)]

    def _auto_select_target_layers(self):
        """Pick default target layer names from the model's class name.

        Models whose class name contains "resnet" get one of two fixed
        lists depending on whether the model has a ``res1`` attribute;
        "lenet" models get ``['conv2', 'fc']``. For any other model, module
        names containing 'layer', 'conv', 'fc' or 'classifier' are collected
        and the last ~30% of them (from index ``int(0.7 * count)``) are
        returned, falling back to ``['fc', 'classifier']`` when none match.
        """
        kind = self.model.__class__.__name__.lower()

        if 'resnet' in kind:
            if hasattr(self.model, 'res1'):
                return ['res1', 'conv3', 'res2', 'conv4', 'classifier']
            return ['layer3', 'layer4', 'fc']

        if 'lenet' in kind:
            return ['conv2', 'fc']

        keywords = ('layer', 'conv', 'fc', 'classifier')
        candidates = [
            module_name
            for module_name, _ in self.model.named_modules()
            if any(keyword in module_name for keyword in keywords)
        ]
        if not candidates:
            return ['fc', 'classifier']

        # Keep only the tail of the matching modules.
        first_kept = int(len(candidates) * 0.7)
        return candidates[first_kept:]

    def get_data_loaders(self, class_idx):
        """Build the four loaders used to unlearn one class independently.

        Both the training and the test dataset are scanned sample by sample
        and split by whether the label equals ``class_idx``.

        Args:
            class_idx: Index of the class to forget.

        Returns:
            Dict with the keys 'forget' and 'retain' (shuffled training
            subsets) and 'class_test' and 'active_test' (unshuffled test
            subsets). In independent unlearning 'active_test' is simply every
            test sample that is not of the forget class.
        """
        def split_by_class(dataset):
            # Labels may be plain ints or objects exposing .item().
            matching, remaining = [], []
            for position, (_, label) in enumerate(dataset):
                value = label if isinstance(label, int) else label.item()
                bucket = matching if value == class_idx else remaining
                bucket.append(position)
            return matching, remaining

        def build_loader(dataset, indices, shuffle):
            return DataLoader(Subset(dataset, indices), batch_size=self.batch_size,
                              shuffle=shuffle, num_workers=4)

        # For independent unlearning, there are no previously forgotten classes
        train_set = self.train_loader.dataset
        forget_ids, retain_ids = split_by_class(train_set)

        test_set = self.test_loader.dataset
        class_test_ids, other_test_ids = split_by_class(test_set)

        loaders = {}
        loaders['forget'] = build_loader(train_set, forget_ids, True)
        loaders['retain'] = build_loader(train_set, retain_ids, True)
        loaders['class_test'] = build_loader(test_set, class_test_ids, False)
        loaders['active_test'] = build_loader(test_set, other_test_ids, False)
        return loaders

    def evaluate(self, data_loader, dataset_name=""):
        """Measure accuracy and cross-entropy loss of the model on a loader.

        The model is switched to eval mode and run without gradients.

        Args:
            data_loader: Iterable of ``(inputs, targets)`` batches that also
                supports ``len()``.
            dataset_name: Optional label; when non-empty the result is logged
                under this name.

        Returns:
            Dict with 'accuracy' (percentage of correct predictions) and
            'loss' (mean of the per-batch losses). Both are 0 when there is
            nothing to evaluate.
        """
        self.model.eval()
        loss_fn = nn.CrossEntropyLoss()
        n_correct, n_seen, running_loss = 0, 0, 0

        with torch.no_grad():
            for batch_x, batch_y in data_loader:
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)

                logits = self.model(batch_x)
                batch_loss = loss_fn(logits, batch_y)

                # Index of the largest logit is the predicted class.
                predictions = logits.max(1)[1]
                n_seen += batch_y.size(0)
                n_correct += (predictions == batch_y).sum().item()
                running_loss += batch_loss.item()

        n_batches = len(data_loader)
        acc = (100.0 * n_correct / n_seen) if n_seen > 0 else 0
        avg_loss = (running_loss / n_batches) if n_batches > 0 else 0

        if dataset_name:
            self.logger.info(
                f"{dataset_name} - Accuracy: {acc:.2f}%, Loss: {avg_loss:.4f}")

        return {'accuracy': acc, 'loss': avg_loss}

    def generate_synthetic_noise(self, num_samples, target_label, epochs=200, save_prefix=""):
        """Optimize random noise images towards a target class.

        Starts from small Gaussian noise (32x32, one channel for 'mnist' and
        three otherwise) and runs Adam with a cosine-annealed learning rate
        on the noise itself, minimizing the model's cross-entropy loss
        against ``target_label``. The model stays in eval mode and is not
        updated by the optimizer.

        Args:
            num_samples: Number of noise images to generate.
            target_label: Class index the noise is optimized towards.
            epochs: Number of optimization steps.
            save_prefix: Prefix used in the progress log lines.

        Returns:
            Detached tensor of shape ``(num_samples, channels, 32, 32)``.
        """
        self.logger.info(
            f"Starting synthetic noise generation for class {target_label} (samples: {num_samples})")
        started = time.time()

        img_size = 32
        channels = 1 if self.dataset_name == 'mnist' else 3

        initial = torch.randn(
            num_samples, channels, img_size, img_size, device=self.device) * 0.1
        noise_samples = initial.clamp(-1.0, 1.0)
        noise_samples.requires_grad_(True)

        # Only the noise tensor is handed to the optimizer.
        optimizer = optim.Adam([noise_samples], lr=0.005)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs, eta_min=0.00005)
        criterion = nn.CrossEntropyLoss()
        target_labels = torch.full(
            (num_samples,), target_label, device=self.device, dtype=torch.long)

        for epoch in range(epochs):
            self.model.eval()
            optimizer.zero_grad()
            # Keep the noise inside [-1, 1] before each forward pass.
            noise_samples.data.clamp_(-1.0, 1.0)
            outputs = self.model(noise_samples)
            loss = criterion(outputs, target_labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            report_now = (epoch + 1) % 50 == 0
            if report_now:
                hits = outputs.argmax(dim=1) == target_labels
                acc = hits.float().mean() * 100.0
                self.logger.info(
                    f"{save_prefix} Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}, Accuracy: {acc:.2f}%")

        elapsed = time.time() - started
        self.logger.info(
            f"Synthetic noise generation completed (time: {elapsed:.2f}s)")
        return noise_samples.detach()

    def pregenerate_and_save_all_noise(self, num_samples=500):
        """Generate and store one noise file per class.

        The checkpoint is reloaded first so that the noise is generated
        against the checkpoint weights. Classes whose ``noise_<idx>.pt``
        file already exists in ``self.noise_data_dir`` are skipped.

        Args:
            num_samples: Number of noise samples to generate per class.

        Returns:
            The directory containing the noise files.
        """
        self.logger.info(
            "Starting pre-generation of noise for all classes...")
        self._load_checkpoint(self.checkpoint_path)
        self.model.eval()

        for class_idx in range(self.num_classes):
            noise_path = os.path.join(
                self.noise_data_dir, f"noise_{class_idx}.pt")

            if not os.path.exists(noise_path):
                self.logger.info(f"Generating noise for class {class_idx}...")
                generated = self.generate_synthetic_noise(
                    num_samples, class_idx, save_prefix=f"pregen_noise_{class_idx}")
                torch.save(generated, noise_path)
                self.logger.info(
                    f"Saved noise for class {class_idx} to {noise_path}")
            else:
                self.logger.info(
                    f"Noise for class {class_idx} already exists. Skipping generation.")

        self.logger.info("Finished pre-generating all noise.")
        return self.noise_data_dir

    def compute_param_sensitivity(self, data_source, labels=None, is_synthetic=False):
        """Accumulate a per-parameter sensitivity score over a data source.

        Only parameters whose name contains one of ``self.target_layers`` are
        tracked. Each batch is handed to ``_update_param_sensitivity`` and
        the accumulated values are finally divided by the number of samples
        seen.

        Args:
            data_source: A tensor of samples when ``is_synthetic`` is True
                (processed in chunks of 100), otherwise an iterable of
                ``(inputs, labels)`` batches.
            labels: Labels aligned with ``data_source``; used only when
                ``is_synthetic`` is True.
            is_synthetic: Selects between the two input forms above.

        Returns:
            Dict mapping parameter name to a sensitivity tensor shaped like
            that parameter.
        """
        started = time.time()

        def is_targeted(param_name):
            return any(layer in param_name for layer in self.target_layers)

        param_sens_dict = {}
        for param_name, param in self.model.named_parameters():
            if is_targeted(param_name):
                param_sens_dict[param_name] = torch.zeros_like(param)

        self.model.eval()

        if is_synthetic:
            chunk = 100
            batches = (
                (data_source[start:start + chunk], labels[start:start + chunk])
                for start in range(0, data_source.size(0), chunk)
            )
        else:
            batches = data_source

        seen = 0
        for batch_inputs, batch_labels in batches:
            self._update_param_sensitivity(
                param_sens_dict, batch_inputs, batch_labels)
            seen += batch_inputs.size(0)

        # Normalize by the sample count; skipped when nothing was processed.
        if seen > 0:
            for accumulated in param_sens_dict.values():
                accumulated /= seen

        elapsed = time.time() - started
        self.logger.info(
            f"Parameter sensitivity computation completed (time: {elapsed:.2f}s)")
        return param_sens_dict

    def _update_param_sensitivity(self, param_sens_dict, inputs, labels):
        """Add one batch's contribution to the sensitivity accumulators.

        Runs a forward/backward pass with cross-entropy loss on the batch and
        adds the squared gradient of every tracked parameter, multiplied by
        the batch size, to its entry in ``param_sens_dict`` (in place).

        Args:
            param_sens_dict: Dict of parameter name -> accumulator tensor.
            inputs: Batch of model inputs.
            labels: Class indices for the batch.
        """
        self.model.zero_grad()
        batch_x = inputs.to(self.device)
        batch_y = labels.to(self.device)

        criterion = nn.CrossEntropyLoss()
        criterion(self.model(batch_x), batch_y).backward()

        batch_n = batch_x.size(0)
        for param_name, param in self.model.named_parameters():
            if param_name not in param_sens_dict or param.grad is None:
                continue
            # Weight by batch size so the caller can normalize per sample.
            param_sens_dict[param_name] += param.grad.data.pow(2) * batch_n

    def apply_unlearning(self, param_sens_forget):
        """
        Apply model unlearning by pruning parameters with top-k sensitivity in the forget set.

        A single global threshold is taken at the ``(1 - k_percent / 100)``
        quantile of all sensitivity values, and every tracked parameter
        entry whose sensitivity is greater than or equal to that threshold
        is set to zero in place.

        Args:
            param_sens_forget: Dict mapping parameter name to a sensitivity
                tensor shaped like that parameter. Parameters that are not
                in the dict are left untouched.

        Returns:
            None. Nothing is modified when no sensitivities are supplied or
            when ``self.k_percent`` is outside ``(0, 100]``.
        """
        self.logger.info(
            f"Applying ablation unlearning using top {self.k_percent}% parameter selection.")
        start_time = time.time()

        # Only parameters that exist in the model and have a sensitivity
        # entry take part in the pruning.
        tracked = [(name, param) for name, param in self.model.named_parameters()
                   if name in param_sens_forget]

        if not tracked:
            self.logger.warning(
                "No target layer sensitivities found. No parameters will be modified.")
            return

        with torch.no_grad():
            all_sens_values = torch.cat(
                [param_sens_forget[name].flatten() for name, _ in tracked])

        k_val = self.k_percent / 100.0
        num_values = all_sens_values.numel()
        if not (0 < k_val <= 1) or num_values == 0:
            self.logger.warning(
                f"Invalid k_percent ({self.k_percent}) or no sensitivities. No parameters modified.")
            return

        # torch.quantile refuses very large inputs ("input tensor is too
        # large"); stay conservatively below that limit and otherwise fall
        # back to the same linear interpolation built from two order
        # statistics.
        quantile_numel_limit = 16_000_000
        q = 1 - k_val
        with torch.no_grad():
            if num_values <= quantile_numel_limit:
                threshold = torch.quantile(all_sens_values, q)
            else:
                position = q * (num_values - 1)
                lower_rank = int(position)  # floor, position is never negative
                upper_rank = min(lower_rank + 1, num_values - 1)
                fraction = position - lower_rank
                # kthvalue is 1-indexed.
                lower = torch.kthvalue(all_sens_values, lower_rank + 1).values
                upper = torch.kthvalue(all_sens_values, upper_rank + 1).values
                threshold = lower + (upper - lower) * fraction
        self.logger.info(
            f"Ablation: Using top {self.k_percent}% sensitivity threshold: {threshold:.6f}")
        if threshold <= 0:
            # With a threshold of zero or below, a ">= threshold" mask selects
            # every non-negative sensitivity, so far more than k_percent of
            # the parameters get pruned.
            self.logger.warning(
                "Sensitivity threshold is not positive; more than the requested "
                "share of parameters will be pruned.")

        total_params, modified_params = 0, 0
        with torch.no_grad():
            for name, param in tracked:
                mask = param_sens_forget[name] >= threshold
                param.masked_fill_(mask, 0.0)
                modified_count = mask.sum().item()
                param_count = param.numel()
                total_params += param_count
                modified_params += modified_count
                if param_count > 0:
                    self.logger.info(
                        f"Layer {name}: Modified {modified_count}/{param_count} params ({100.0 * modified_count / param_count:.2f}%)")

        elapsed = time.time() - start_time
        ratio = 100.0 * modified_params / total_params if total_params > 0 else 0
        self.logger.info(
            f"Total modifications: {modified_params}/{total_params} ({ratio:.2f}%) [time: {elapsed:.2f}s]")

    def save_model(self, class_idx):
        """Write the current model weights to the per-dataset checkpoint folder.

        The state dict is stored under
        ``checkpoints/forget/ablation_perturbation/<dataset>/`` (relative to
        the working directory) as ``<model>_<dataset>_forget_<class_idx>.pth``.

        Args:
            class_idx: Index of the class that was unlearned.

        Returns:
            Path of the file that was written.
        """
        target_dir = f'checkpoints/forget/ablation_perturbation/{self.dataset_name}'
        os.makedirs(target_dir, exist_ok=True)

        suffix = f"forget_{class_idx}"
        save_path = os.path.join(
            target_dir, f"{self.model_name}_{self.dataset_name}_{suffix}.pth")

        torch.save(self.model.state_dict(), save_path)
        self.logger.info(f"Model after unlearning saved to: {save_path}")
        return save_path

    def unlearn_class(self, class_idx):
        """Unlearn a single class independently.

        The checkpoint is reloaded first, so each call starts from the
        checkpoint weights rather than from a previously pruned model. The
        pre-generated noise for the class is then used to compute parameter
        sensitivities, the most sensitive parameters are pruned, and the
        model is evaluated before and after and finally saved.

        Args:
            class_idx: Index of the class to forget.

        Returns:
            Dict with 'class_name', 'class_idx', the 'initial' and 'final'
            evaluation results (each holding 'class_test' and 'active_test'),
            'forget_acc_change', 'retain_acc_change' and 'saved_model_path'.

        Raises:
            FileNotFoundError: If ``noise_<class_idx>.pt`` does not exist in
                ``self.noise_data_dir``.
        """
        # 1. Reload the original model to ensure independent unlearning
        self._load_checkpoint(self.checkpoint_path)
        class_name = self.classes[class_idx]
        self.logger.info(
            f"Starting to unlearn class {class_name} (index {class_idx}) independently")

        # 2. Get data loaders for the specific class
        loaders = self.get_data_loaders(class_idx)

        # 3. Initial evaluation
        self.logger.info("Initial model evaluation...")
        initial_results = {
            'class_test': self.evaluate(loaders['class_test'], f"Initial - Test set {class_name} class"),
            'active_test': self.evaluate(loaders['active_test'], "Initial - Test set retain classes")
        }

        # 4. Load pre-generated noise for the forget class
        self.logger.info(
            f"Loading pre-generated noise for forget class {class_idx}...")
        noise_path = os.path.join(
            self.noise_data_dir, f"noise_{class_idx}.pt")
        if not os.path.exists(noise_path):
            raise FileNotFoundError(f"Noise file not found: {noise_path}")

        # The noise file holds a plain tensor, so restrict unpickling the
        # same way the checkpoint loader does.
        noise_samples = torch.load(
            noise_path, map_location=self.device, weights_only=True)
        noise_labels = torch.full(
            (noise_samples.size(0),), class_idx, device=self.device, dtype=torch.long)

        # 5. Compute sensitivity for the forget class
        self.logger.info("Computing parameter sensitivity for forget class...")
        param_sens_forget = self.compute_param_sensitivity(
            noise_samples, noise_labels, is_synthetic=True)

        # 6. Apply unlearning
        self.logger.info("Applying unlearning (Ablation)...")
        self.apply_unlearning(param_sens_forget)

        # 7. Final evaluation
        self.logger.info("Model evaluation after unlearning...")
        final_results = {
            'class_test': self.evaluate(loaders['class_test'], f"After unlearning - Test set {class_name} class"),
            'active_test': self.evaluate(loaders['active_test'], "After unlearning - Test set retain classes")
        }

        # 8. Log results
        initial_forget_acc = initial_results['class_test']['accuracy']
        final_forget_acc = final_results['class_test']['accuracy']
        initial_retain_acc = initial_results['active_test']['accuracy']
        final_retain_acc = final_results['active_test']['accuracy']
        forget_acc_change = final_forget_acc - initial_forget_acc
        retain_acc_change = final_retain_acc - initial_retain_acc

        self.logger.info(
            f"Forget class accuracy change: {forget_acc_change:.2f}% ({initial_forget_acc:.2f}% → {final_forget_acc:.2f}%)")
        self.logger.info(
            f"Retain class accuracy change: {retain_acc_change:.2f}% ({initial_retain_acc:.2f}% → {final_retain_acc:.2f}%)")

        # 9. Save the unlearned model
        saved_path = self.save_model(class_idx)

        return {
            'class_name': class_name,
            'class_idx': class_idx,
            'initial': initial_results,
            'final': final_results,
            'forget_acc_change': forget_acc_change,
            'retain_acc_change': retain_acc_change,
            'saved_model_path': saved_path
        }