"""
Baseline unlearning algorithms for classification tasks.
Implements NEGGRAD, SCRUB, Bad Teacher, SalUn, and DELETE.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from typing import Optional, Tuple, Dict, Any
import logging
import copy

logger = logging.getLogger(__name__)


class NEGGRADUnlearning:
    """
    Negative Gradient (NEGGRAD) unlearning method.

    Each optimisation step minimises ``-CE(forget) + alpha * CE(retain)``. This is
    gradient ascent on the forget-set cross-entropy, optionally anchored by ordinary
    training on retain data. Because ``-CE`` is unbounded below, large ``lr`` or
    many epochs can make the model diverge.

    Reference: "Eternal Sunshine of the Spotless Net: Selective Forgetting in Neural Networks"
    """

    def __init__(self, model: nn.Module, device: str = 'cuda'):
        self.model = model
        self.device = device
        # nn.Module.to moves the parameters in place; the caller's model is updated.
        self.model.to(device)

    @staticmethod
    def _next_retain_batch(retain_loader: DataLoader, retain_iter) -> Tuple[Any, Any]:
        """Return ``(batch, iterator)``, restarting ``retain_loader`` when exhausted.

        Raises:
            ValueError: if ``retain_loader`` yields no batches at all.
        """
        try:
            return next(retain_iter), retain_iter
        except StopIteration:
            retain_iter = iter(retain_loader)
        try:
            return next(retain_iter), retain_iter
        except StopIteration:
            raise ValueError("retain_loader yielded no batches") from None

    def unlearn(self,
                forget_loader: DataLoader,
                retain_loader: Optional[DataLoader] = None,
                epochs: int = 5,
                lr: float = 0.01,
                alpha: float = 1.0) -> nn.Module:
        """
        Perform NEGGRAD unlearning.

        Args:
            forget_loader: DataLoader of ``(input, label)`` forget batches
            retain_loader: DataLoader of ``(input, label)`` retain batches (optional).
                One retain batch is drawn per forget batch. The loader is cycled when
                it is shorter than ``forget_loader``, and the cycle continues across
                epochs.
            epochs: Number of passes over ``forget_loader``
            lr: Adam learning rate
            alpha: Weight for retain loss

        Returns:
            The unlearned model (the same object passed to ``__init__``).

        Raises:
            ValueError: if ``retain_loader`` is given but yields no batches.
        """
        logger.info("Starting NEGGRAD unlearning...")

        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        # A single retain iterator for the whole run, so it continues across epochs.
        retain_iter = iter(retain_loader) if retain_loader is not None else None

        for epoch in range(epochs):
            epoch_forget_loss = 0.0
            epoch_retain_loss = 0.0
            num_batches = 0

            for data, target in forget_loader:
                data, target = data.to(self.device), target.to(self.device)

                optimizer.zero_grad()

                # Negated cross-entropy: minimising it pushes predictions away from
                # the forget labels.
                forget_loss = -F.cross_entropy(self.model(data), target)
                total_loss = forget_loss

                if retain_loader is not None:
                    (retain_data, retain_target), retain_iter = self._next_retain_batch(
                        retain_loader, retain_iter)
                    retain_data = retain_data.to(self.device)
                    retain_target = retain_target.to(self.device)

                    retain_loss = F.cross_entropy(self.model(retain_data), retain_target)
                    total_loss = forget_loss + alpha * retain_loss
                    epoch_retain_loss += retain_loss.item()

                total_loss.backward()
                optimizer.step()

                epoch_forget_loss += forget_loss.item()
                num_batches += 1

            # max(..., 1) keeps an empty forget_loader from raising ZeroDivisionError.
            # epoch_retain_loss stays 0.0 when no retain loader is used.
            denom = max(num_batches, 1)
            avg_forget_loss = epoch_forget_loss / denom
            avg_retain_loss = epoch_retain_loss / denom

            if epoch % 2 == 0:
                logger.info(f"Epoch {epoch}: Forget Loss = {avg_forget_loss:.4f}, "
                          f"Retain Loss = {avg_retain_loss:.4f}")

        logger.info("NEGGRAD unlearning completed")
        return self.model


class SCRUBUnlearning:
    """
    SCRUB-style unlearning method (simplified).

    Each optimisation step minimises
        CE(retain) - H(p_forget) + lam * ||theta - theta_0||^2
    where H is the mean predictive entropy on the forget batch and theta_0 are the
    weights at the start of ``unlearn``. This fits the retain data, pushes forget
    predictions toward maximum entropy (low confidence), and stays near the start.

    NOTE(review): this is a simplified proxy objective. It does not implement the
    teacher/student KL max-min procedure of the reference SCRUB algorithm.

    Reference: Kurmanji et al., "Towards Unbounded Machine Unlearning" (NeurIPS 2023), which introduces SCRUB
    """

    def __init__(self, model: nn.Module, device: str = 'cuda'):
        self.model = model
        self.device = device
        # nn.Module.to moves the parameters in place; the caller's model is updated.
        self.model.to(device)

    @staticmethod
    def _next_retain_batch(retain_loader: DataLoader, retain_iter) -> Tuple[Any, Any]:
        """Return ``(batch, iterator)``, restarting ``retain_loader`` when exhausted.

        Raises:
            ValueError: if ``retain_loader`` yields no batches at all.
        """
        try:
            return next(retain_iter), retain_iter
        except StopIteration:
            retain_iter = iter(retain_loader)
        try:
            return next(retain_iter), retain_iter
        except StopIteration:
            raise ValueError("retain_loader yielded no batches") from None

    def unlearn(self,
                forget_loader: DataLoader,
                retain_loader: Optional[DataLoader] = None,
                epochs: int = 10,
                lr: float = 0.001,
                lam: float = 0.1) -> nn.Module:
        """
        Perform SCRUB-style unlearning.

        Args:
            forget_loader: DataLoader of ``(input, label)`` forget batches (labels unused)
            retain_loader: DataLoader of ``(input, label)`` retain batches (optional).
                One retain batch is drawn per forget batch. The loader restarts each
                epoch and cycles if it is shorter.
            epochs: Number of passes over ``forget_loader``
            lr: SGD learning rate (momentum 0.9)
            lam: Weight of the squared-L2 penalty toward the starting weights

        Returns:
            The unlearned model (the same object passed to ``__init__``).

        Raises:
            ValueError: if ``retain_loader`` is given but yields no batches.
        """
        logger.info("Starting SCRUB unlearning...")

        optimizer = optim.SGD(self.model.parameters(), lr=lr, momentum=0.9)

        # Frozen snapshot of the starting weights for the proximity penalty.
        original_params = {name: param.detach().clone()
                           for name, param in self.model.named_parameters()}

        for epoch in range(epochs):
            epoch_loss = 0.0
            num_batches = 0

            retain_iter = iter(retain_loader) if retain_loader is not None else None

            for forget_data, _ in forget_loader:
                forget_data = forget_data.to(self.device)

                optimizer.zero_grad()

                total_loss = 0.0

                # Retain samples loss (if available)
                if retain_loader is not None:
                    (retain_data, retain_target), retain_iter = self._next_retain_batch(
                        retain_loader, retain_iter)
                    retain_data = retain_data.to(self.device)
                    retain_target = retain_target.to(self.device)

                    retain_output = self.model(retain_data)
                    total_loss = total_loss + F.cross_entropy(retain_output, retain_target)

                # Negative entropy sum(p * log p) of the forget predictions.
                # Minimising it maximises entropy, lowering confidence on forget samples.
                log_probs = F.log_softmax(self.model(forget_data), dim=1)
                neg_entropy = torch.sum(log_probs.exp() * log_probs, dim=1).mean()
                total_loss = total_loss + neg_entropy

                # Squared L2 distance to the starting weights. This equals
                # torch.norm(...) ** 2 but skips the intermediate square root.
                l2_reg = sum((param - original_params[name]).pow(2).sum()
                             for name, param in self.model.named_parameters())
                total_loss = total_loss + lam * l2_reg

                total_loss.backward()
                optimizer.step()

                epoch_loss += total_loss.item()
                num_batches += 1

            # max(..., 1) keeps an empty forget_loader from raising ZeroDivisionError.
            avg_loss = epoch_loss / max(num_batches, 1)

            if epoch % 2 == 0:
                logger.info(f"Epoch {epoch}: Loss = {avg_loss:.4f}")

        logger.info("SCRUB unlearning completed")
        return self.model


class BadTeacherUnlearning:
    """
    Bad Teacher (BT) unlearning method.

    On forget samples, the model is distilled toward an "incompetent" teacher. The
    teacher is a deep copy of the model with randomly re-initialised parameters. On
    retain samples (optional), the model is trained with ordinary cross-entropy.

    Reference: "Can Bad Teaching Induce Forgetting? Unlearning in Deep Networks using an Incompetent Teacher"
    """

    def __init__(self, model: nn.Module, device: str = 'cuda'):
        self.model = model
        self.device = device
        self.model.to(device)

        # The teacher is deep-copied after .to(device), so it already lives on `device`.
        self.bad_teacher = copy.deepcopy(model)
        self._initialize_bad_teacher()
        # The teacher is only queried, never trained. Eval mode stops its BatchNorm
        # running stats from updating and disables dropout during distillation.
        # Turning off requires_grad keeps its weights out of any graph.
        self.bad_teacher.eval()
        self.bad_teacher.requires_grad_(False)

    def _initialize_bad_teacher(self):
        """Re-initialise every teacher parameter in place.

        Parameters with 2+ dims get Xavier-uniform init; all others get U(-0.1, 0.1).
        Buffers (e.g. BatchNorm running stats) are copied from the model unchanged.
        """
        for param in self.bad_teacher.parameters():
            if len(param.shape) >= 2:
                nn.init.xavier_uniform_(param)
            else:
                nn.init.uniform_(param, -0.1, 0.1)

    @staticmethod
    def _next_retain_batch(retain_loader: DataLoader, retain_iter) -> Tuple[Any, Any]:
        """Return ``(batch, iterator)``, restarting ``retain_loader`` when exhausted.

        Raises:
            ValueError: if ``retain_loader`` yields no batches at all.
        """
        try:
            return next(retain_iter), retain_iter
        except StopIteration:
            retain_iter = iter(retain_loader)
        try:
            return next(retain_iter), retain_iter
        except StopIteration:
            raise ValueError("retain_loader yielded no batches") from None

    def unlearn(self,
                forget_loader: DataLoader,
                retain_loader: Optional[DataLoader] = None,
                epochs: int = 10,
                lr: float = 0.001,
                alpha: float = 1.0,
                temperature: float = 4.0) -> nn.Module:
        """
        Perform Bad Teacher unlearning.

        Args:
            forget_loader: DataLoader of ``(input, label)`` forget batches (labels unused)
            retain_loader: DataLoader of ``(input, label)`` retain batches (optional).
                One retain batch is drawn per forget batch. The loader restarts each
                epoch and cycles if it is shorter.
            epochs: Number of passes over ``forget_loader``
            lr: Adam learning rate
            alpha: Weight for retain loss
            temperature: Temperature for knowledge distillation

        Returns:
            The unlearned model (the same object passed to ``__init__``).

        Raises:
            ValueError: if ``retain_loader`` is given but yields no batches.
        """
        logger.info("Starting Bad Teacher unlearning...")

        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        for epoch in range(epochs):
            epoch_loss = 0.0
            num_batches = 0

            retain_iter = iter(retain_loader) if retain_loader is not None else None

            for forget_data, _ in forget_loader:
                forget_data = forget_data.to(self.device)

                optimizer.zero_grad()

                student_output = self.model(forget_data)
                with torch.no_grad():
                    teacher_output = self.bad_teacher(forget_data)

                # Pull the student's forget-set predictions toward the random teacher.
                total_loss = self._knowledge_distillation_loss(
                    student_output, teacher_output, temperature
                )

                if retain_loader is not None:
                    (retain_data, retain_target), retain_iter = self._next_retain_batch(
                        retain_loader, retain_iter)
                    retain_data = retain_data.to(self.device)
                    retain_target = retain_target.to(self.device)

                    retain_loss = F.cross_entropy(self.model(retain_data), retain_target)
                    # Out-of-place add; do not mutate the KD loss tensor in place.
                    total_loss = total_loss + alpha * retain_loss

                total_loss.backward()
                optimizer.step()

                epoch_loss += total_loss.item()
                num_batches += 1

            # max(..., 1) keeps an empty forget_loader from raising ZeroDivisionError.
            avg_loss = epoch_loss / max(num_batches, 1)

            if epoch % 2 == 0:
                logger.info(f"Epoch {epoch}: Loss = {avg_loss:.4f}")

        logger.info("Bad Teacher unlearning completed")
        return self.model

    def _knowledge_distillation_loss(self,
                                   student_logits: torch.Tensor,
                                   teacher_logits: torch.Tensor,
                                   temperature: float) -> torch.Tensor:
        """Compute the temperature-scaled KL(teacher || student) distillation loss.

        Logits are softened by ``temperature`` along dim 1. The KL divergence is
        averaged over the batch and multiplied by ``temperature ** 2``.
        """
        student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=1)

        return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)


class SalUnUnlearning:
    """
    SalUn (Saliency Unlearning) method.

    1. Score each parameter entry by its accumulated absolute gradient of the
       forget-set cross-entropy.
    2. Keep only the most salient ``threshold`` fraction of entries in each tensor
       trainable.
    3. Run ``-CE(forget) + alpha * CE(retain)`` updates with the gradients of all
       other entries zeroed.

    Reference: "SalUn: Empowering Machine Unlearning via Gradient-based Weight Saliency in Both Image Classification and Generation"
    """

    def __init__(self, model: nn.Module, device: str = 'cuda'):
        self.model = model
        self.device = device
        self.model.to(device)

    @staticmethod
    def _next_retain_batch(retain_loader: DataLoader, retain_iter) -> Tuple[Any, Any]:
        """Return ``(batch, iterator)``, restarting ``retain_loader`` when exhausted.

        Raises:
            ValueError: if ``retain_loader`` yields no batches at all.
        """
        try:
            return next(retain_iter), retain_iter
        except StopIteration:
            retain_iter = iter(retain_loader)
        try:
            return next(retain_iter), retain_iter
        except StopIteration:
            raise ValueError("retain_loader yielded no batches") from None

    def unlearn(self,
                forget_loader: DataLoader,
                retain_loader: Optional[DataLoader] = None,
                epochs: int = 10,
                lr: float = 0.001,
                alpha: float = 1.0,
                threshold: float = 0.9) -> nn.Module:
        """
        Perform SalUn unlearning.

        Args:
            forget_loader: DataLoader of ``(input, label)`` forget batches
            retain_loader: DataLoader of ``(input, label)`` retain batches (optional).
                One retain batch is drawn per forget batch. The loader restarts each
                epoch and cycles if it is shorter.
            epochs: Number of passes over ``forget_loader``
            lr: Adam learning rate
            alpha: Weight for retain loss
            threshold: Fraction of entries per parameter tensor that stay trainable.
                These are the entries with the highest saliency. The count is
                ``int(threshold * numel)``, clamped to ``[0, numel]``.

        Returns:
            The unlearned model (the same object passed to ``__init__``).

        Raises:
            ValueError: if ``retain_loader`` is given but yields no batches.
        """
        logger.info("Starting SalUn unlearning...")

        # Step 1: Compute saliency scores (the model's train/eval mode is preserved)
        saliency_scores = self._compute_saliency(forget_loader)

        # Step 2: Create mask based on saliency
        mask = self._create_saliency_mask(saliency_scores, threshold)

        # Step 3: Fine-tune with masked parameters
        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        for epoch in range(epochs):
            epoch_loss = 0.0
            num_batches = 0

            retain_iter = iter(retain_loader) if retain_loader is not None else None

            for forget_data, forget_target in forget_loader:
                forget_data = forget_data.to(self.device)
                forget_target = forget_target.to(self.device)

                optimizer.zero_grad()

                # Negative gradient on forget samples
                total_loss = -F.cross_entropy(self.model(forget_data), forget_target)

                if retain_loader is not None:
                    (retain_data, retain_target), retain_iter = self._next_retain_batch(
                        retain_loader, retain_iter)
                    retain_data = retain_data.to(self.device)
                    retain_target = retain_target.to(self.device)

                    retain_loss = F.cross_entropy(self.model(retain_data), retain_target)
                    # Out-of-place add; do not mutate the forget-loss tensor in place.
                    total_loss = total_loss + alpha * retain_loss

                total_loss.backward()
                # Zero gradients of non-salient entries before the optimizer step.
                self._apply_saliency_mask(mask)
                optimizer.step()

                epoch_loss += total_loss.item()
                num_batches += 1

            # max(..., 1) keeps an empty forget_loader from raising ZeroDivisionError.
            avg_loss = epoch_loss / max(num_batches, 1)

            if epoch % 2 == 0:
                logger.info(f"Epoch {epoch}: Loss = {avg_loss:.4f}")

        logger.info("SalUn unlearning completed")
        return self.model

    def _compute_saliency(self, forget_loader: DataLoader) -> Dict[str, torch.Tensor]:
        """Compute per-entry saliency scores over the forget set.

        For each forget batch, the absolute gradient of the batch-mean cross-entropy
        is added to a running total. The total is then divided by the number of
        forget samples. Gradients are taken in eval mode. Afterwards the model's
        previous train/eval mode is restored and parameter gradients are cleared.

        Returns:
            Mapping from parameter name to a tensor shaped like that parameter.
            All scores are zero if ``forget_loader`` is empty.
        """
        saliency_scores = {name: torch.zeros_like(param)
                           for name, param in self.model.named_parameters()}

        was_training = self.model.training
        self.model.eval()
        num_samples = 0
        try:
            for data, target in forget_loader:
                data, target = data.to(self.device), target.to(self.device)

                self.model.zero_grad()
                loss = F.cross_entropy(self.model(data), target)
                loss.backward()

                for name, param in self.model.named_parameters():
                    if param.grad is not None:
                        saliency_scores[name] += param.grad.abs()

                num_samples += data.size(0)
        finally:
            # Leave no saliency gradients behind, and do not leak eval mode
            # into the fine-tuning loop.
            self.model.zero_grad()
            self.model.train(was_training)

        if num_samples > 0:
            for scores in saliency_scores.values():
                scores /= num_samples

        return saliency_scores

    def _create_saliency_mask(self,
                            saliency_scores: Dict[str, torch.Tensor],
                            threshold: float) -> Dict[str, torch.Tensor]:
        """Build a per-tensor binary mask that keeps the most salient entries.

        For each tensor, exactly ``int(threshold * numel)`` entries are set to 1,
        clamped to ``[0, numel]``. These are the entries with the largest scores.
        All other entries are 0. Ties at the cut-off are broken by ``torch.topk``.
        """
        mask = {}

        for name, scores in saliency_scores.items():
            flat_scores = scores.flatten()
            numel = flat_scores.numel()
            keep = min(max(int(threshold * numel), 0), numel)

            flat_mask = torch.zeros_like(flat_scores)
            if keep > 0:
                flat_mask[torch.topk(flat_scores, keep).indices] = 1.0
            mask[name] = flat_mask.reshape(scores.shape)

        return mask

    def _apply_saliency_mask(self, mask: Dict[str, torch.Tensor]):
        """Multiply each parameter's gradient in place by its mask (if both exist)."""
        for name, param in self.model.named_parameters():
            if param.grad is not None and name in mask:
                param.grad.mul_(mask[name])


class DELETEUnlearning:
    """
    DELETE: Decoupled Distillation unlearning method.

    Simplified hard-label variant. Each forget sample is retrained toward the
    model's own highest-scoring class other than its true label. Retain samples
    (optional) are trained with ordinary cross-entropy.

    Reference: "Decoupled Distillation to Erase: A General Unlearning Method for Any Model"
    """

    def __init__(self, model: nn.Module, device: str = 'cuda'):
        self.model = model
        self.device = device
        self.model.to(device)

    @staticmethod
    def _next_retain_batch(retain_loader: DataLoader, retain_iter) -> Tuple[Any, Any]:
        """Return ``(batch, iterator)``, restarting ``retain_loader`` when exhausted.

        Raises:
            ValueError: if ``retain_loader`` yields no batches at all.
        """
        try:
            return next(retain_iter), retain_iter
        except StopIteration:
            retain_iter = iter(retain_loader)
        try:
            return next(retain_iter), retain_iter
        except StopIteration:
            raise ValueError("retain_loader yielded no batches") from None

    def unlearn(self,
                forget_loader: DataLoader,
                retain_loader: Optional[DataLoader] = None,
                epochs: int = 10,
                lr: float = 0.001) -> nn.Module:
        """
        Perform DELETE unlearning using decoupled distillation.

        Args:
            forget_loader: DataLoader of ``(input, label)`` forget batches
            retain_loader: DataLoader of ``(input, label)`` retain batches (optional).
                One retain batch is drawn per forget batch. The loader restarts each
                epoch and cycles if it is shorter, so every forget step has a retain
                term.
            epochs: Number of passes over ``forget_loader``
            lr: Adam learning rate

        Returns:
            The unlearned model (the same object passed to ``__init__``).

        Raises:
            ValueError: if ``retain_loader`` is given but yields no batches.
        """
        logger.info("Starting DELETE unlearning...")

        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        for epoch in range(epochs):
            epoch_loss = 0.0
            num_batches = 0

            retain_iter = iter(retain_loader) if retain_loader is not None else None

            for forget_data, forget_target in forget_loader:
                forget_data = forget_data.to(self.device)
                forget_target = forget_target.to(self.device)

                optimizer.zero_grad()

                forget_output = self.model(forget_data)

                # Hard target: best non-true class, derived from the current logits.
                delete_labels = self._create_delete_labels(forget_output, forget_target)
                total_loss = F.cross_entropy(forget_output, delete_labels)

                if retain_loader is not None:
                    (retain_data, retain_target), retain_iter = self._next_retain_batch(
                        retain_loader, retain_iter)
                    retain_data = retain_data.to(self.device)
                    retain_target = retain_target.to(self.device)

                    retain_output = self.model(retain_data)
                    total_loss = total_loss + F.cross_entropy(retain_output, retain_target)

                total_loss.backward()
                optimizer.step()

                epoch_loss += total_loss.item()
                num_batches += 1

            # max(..., 1) keeps an empty forget_loader from raising ZeroDivisionError.
            avg_loss = epoch_loss / max(num_batches, 1)

            if epoch % 2 == 0:
                logger.info(f"Epoch {epoch}: Loss = {avg_loss:.4f}")

        logger.info("DELETE unlearning completed")
        return self.model

    def _create_delete_labels(self,
                            logits: torch.Tensor,
                            true_labels: torch.Tensor) -> torch.Tensor:
        """Return, per row, the argmax class after excluding the true label.

        Expects ``logits`` to be (batch, num_classes) and ``true_labels`` to be
        (batch,) class indices. The true-class logit is masked to ``-inf`` on a
        detached copy, so it can never be selected regardless of logit scale.
        The row-wise argmax is returned as a LongTensor of shape (batch,).
        """
        masked = logits.detach().scatter(1, true_labels.unsqueeze(1), float('-inf'))
        return masked.argmax(dim=1)