"""
DataOpt: Data-Centric Unlearning Framework
Implementing the methods from "Data-Centric Unlearning: Optimizing Labels and Retain Data via Learning Dynamics"
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Tuple, Dict, Optional, Union, Any
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import NearestNeighbors
import logging

# NOTE(review): calling basicConfig at import time configures the root logger
# for every program that imports this module; consider moving it to the
# entry-point script.
logging.basicConfig(level=logging.INFO)
# Module-level logger for progress messages and warnings emitted by this module.
logger = logging.getLogger(__name__)


class DataOptFramework:
    """
    Main DataOpt framework for data-centric unlearning optimization.
    Implements both label assignment optimization and retain set optimization.

    Label assignment (``assign_forget_labels``) rewrites each forget sample's
    predicted distribution ``f`` so that ``k`` other classes outrank the
    originally predicted class by at least ``epsilon``. It lowers that class
    and lifts the smallest of the chosen classes to a common level, keeps
    every other entry unchanged, and preserves the row sum.

    Retain-set optimization (``construct_optimized_retain_set``) gathers
    neighbourhood samples, decision-boundary samples and FGSM-perturbed
    copies around the forget set.
    """

    def __init__(self, model: nn.Module, device: str = 'cuda'):
        """
        Args:
            model: Classifier whose predictions drive label assignment and
                retain-set construction. ``Module.to`` moves it to ``device``
                in place.
            device: Device string used for every forward pass.
        """
        self.model = model
        self.device = device
        self.model.to(device)

    def _predict_probs(self, samples: torch.Tensor, batch_size: int = 64) -> torch.Tensor:
        """Return softmax class probabilities for ``samples`` on ``self.device``.

        The model runs in eval mode without gradients, in chunks of
        ``batch_size``, so large inputs do not have to fit on the device at once.
        """
        self.model.eval()
        chunks = [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]
        if not chunks:
            # Empty input: one forward pass still yields a correctly shaped [0, C] result.
            chunks = [samples]
        with torch.no_grad():
            probs = [F.softmax(self.model(chunk.to(self.device)), dim=1) for chunk in chunks]
        return torch.cat(probs, dim=0)

    def assign_forget_labels(self,
                             forget_samples: torch.Tensor,
                             k: int,
                             epsilon: float = 1e-6) -> torch.Tensor:
        """
        Assign optimal labels for forget samples according to Eq. (9) in the paper.

        For each sample, the original top class ``c*`` is demoted so that ``k``
        other classes exceed it by at least ``epsilon``. Each row still sums to one.

        Args:
            forget_samples: Input samples to forget [N, ...]
            k: Unlearning degree (1 <= k <= C-1)
            epsilon: Small positive margin; it should be much smaller than 1/C,
                otherwise the demoted probability can become negative.

        Returns:
            Optimal soft labels [N, C] on ``self.device``.

        Raises:
            ValueError: If the model has fewer than two classes or ``k`` is
                outside ``[1, C-1]``.
        """
        f = self._predict_probs(forget_samples)  # original model predictions
        _, C = f.shape
        if C < 2:
            raise ValueError(f"label assignment needs at least 2 classes, got C={C}")
        if not 1 <= k <= C - 1:
            raise ValueError(f"unlearning degree k must satisfy 1 <= k <= C-1 = {C - 1}, got k={k}")

        optimal_labels = torch.zeros_like(f)
        for i, f_i in enumerate(f):
            c_star = torch.argmax(f_i).item()  # original top class
            if k == 1:
                # Special case k=1 (Eq. 10)
                optimal_labels[i] = self._solve_k1(f_i, c_star, epsilon)
            elif k == C - 1:
                # Special case k=C-1 (Eq. 11, 12)
                optimal_labels[i] = self._solve_kc_minus_1(f_i, c_star, epsilon)
            else:
                optimal_labels[i] = self._solve_general(f_i, c_star, k, epsilon)
        return optimal_labels

    def _solve_k1(self, f: torch.Tensor, c_star: int, epsilon: float) -> torch.Tensor:
        """Solve special case k=1 (Eq. 10).

        Swaps the lead between ``c_star`` and the runner-up class ``c_hat``.
        Both receive the mean of their two probabilities, shifted by -/+ epsilon/2,
        so their combined mass (and hence the row sum) is preserved.
        """
        y = f.clone()
        f_masked = f.clone()
        f_masked[c_star] = -float('inf')  # exclude c_star when searching for the runner-up
        c_hat = torch.argmax(f_masked).item()

        pair_mass = f[c_star] + f[c_hat]
        y[c_star] = (pair_mass - epsilon) / 2
        y[c_hat] = (pair_mass + epsilon) / 2
        return y

    def _solve_kc_minus_1(self, f: torch.Tensor, c_star: int, epsilon: float) -> torch.Tensor:
        """Solve special case k=C-1 (Eq. 11, 12): every other class must outrank ``c_star``."""
        others = [i for i in range(len(f)) if i != c_star]
        return self._water_fill(f, c_star, others, epsilon)

    def _find_optimal_m(self, f: torch.Tensor, c_star: int, epsilon: float) -> int:
        """Return how many non-``c_star`` classes must be lifted when k = C-1.

        The result lies in ``[0, C-1]``. The m smallest other classes are the
        ones raised to ``a_m + epsilon`` by ``_solve_kc_minus_1``.
        """
        others = [i for i in range(len(f)) if i != c_star]
        sorted_values = torch.sort(f[others])[0].tolist()
        return self._active_count(float(f[c_star]), sorted_values, epsilon)

    def _solve_general(self, f: torch.Tensor, c_star: int, k: int, epsilon: float) -> torch.Tensor:
        """Solve the general case 1 < k < C-1.

        The k largest classes other than ``c_star`` are chosen to outrank it,
        since they need the least lifting. The same water-filling step as the
        k=C-1 case is then applied to them.
        """
        others = [i for i in range(len(f)) if i != c_star]
        _, top_k = torch.topk(f[others], k)
        selected = [others[j] for j in top_k.tolist()]
        return self._water_fill(f, c_star, selected, epsilon)

    def _water_fill(self,
                    f: torch.Tensor,
                    c_star: int,
                    promote: List[int],
                    epsilon: float) -> torch.Tensor:
        """Demote ``c_star`` below every class in ``promote`` by at least ``epsilon``.

        Sort the promoted values ascending as s_1 <= ... <= s_K. The m smallest
        are lifted to a common level ``a_m + epsilon``, and ``c_star`` is lowered
        to ``a_m``, where

            a_m = (f[c_star] + s_1 + ... + s_m - m * epsilon) / (m + 1)

        This choice keeps the total mass unchanged. All other classes keep
        their original probability.
        """
        y = f.clone()
        if not promote:
            return y
        idx = torch.as_tensor(promote, dtype=torch.long, device=f.device)
        sorted_vals, order = torch.sort(f[idx])
        sorted_list = sorted_vals.tolist()

        f_top = float(f[c_star])
        m = self._active_count(f_top, sorted_list, epsilon)
        a_m = (f_top + sum(sorted_list[:m]) - m * epsilon) / (m + 1)

        y[c_star] = a_m
        y[idx[order[:m]]] = a_m + epsilon
        return y

    @staticmethod
    def _active_count(f_top: float, sorted_values: List[float], epsilon: float) -> int:
        """Return how many of the ascending ``sorted_values`` must be lifted.

        A count m is valid when every lifted value lies below the new level
        ``a_m + epsilon`` and every untouched value is already at or above it.
        If no count qualifies (floating-point edge cases), all values are
        lifted, which always satisfies the margin.
        """
        n_values = len(sorted_values)
        prefix = 0.0
        for m in range(n_values + 1):
            if m > 0:
                prefix += sorted_values[m - 1]
            level = (f_top + prefix - m * epsilon) / (m + 1) + epsilon
            lifted_ok = m == 0 or sorted_values[m - 1] < level
            rest_ok = m == n_values or level <= sorted_values[m]
            if lifted_ok and rest_ok:
                return m
        return n_values

    def generate_llm_forget_response(self,
                                     llm: Any,
                                     original_prompt: str,
                                     original_response: str,
                                     sensitive_info: List[str],
                                     num_candidates: int = 5) -> str:
        """
        Generate optimal forget response for LLM using two-stage heuristic.

        Args:
            llm: Language model
            original_prompt: Original input prompt
            original_response: Original model response
            sensitive_info: List of sensitive information to remove
            num_candidates: Number of candidate responses to generate

        Returns:
            Optimal sanitized response
        """
        # Stage 1: Generate multiple candidates
        candidates = self._generate_candidate_responses(
            llm, original_prompt, original_response, sensitive_info, num_candidates
        )

        # Stage 2: Select optimal candidate
        return self._select_optimal_response(
            llm, candidates, original_response, original_prompt, sensitive_info
        )

    def _generate_candidate_responses(self,
                                      llm: Any,
                                      original_prompt: str,
                                      original_response: str,
                                      sensitive_info: List[str],
                                      num_candidates: int) -> List[str]:
        """Generate candidate responses that exclude sensitive information.

        Calls ``llm.generate(prompt, num_return_sequences=...)``. If it returns
        a list, that list is used as-is. Otherwise the output is parsed as a
        numbered list. Any exception, or an empty result, falls back to a
        single canned refusal so callers always receive at least one candidate.
        """
        sensitive_str = ", ".join(sensitive_info)
        generation_prompt = f"""
        Original prompt: {original_prompt}
        Original response: {original_response}
        
        Please generate {num_candidates} alternative responses that:
        1. Do not mention any of these sensitive terms: {sensitive_str}
        2. Remain helpful and coherent
        3. Are semantically related to the original prompt
        
        Generate {num_candidates} different alternatives:
        """
        fallback = [f"I cannot provide information about {', '.join(sensitive_info)}. " +
                    "Please ask about other topics I can help with."]

        try:
            # The exact generation API depends on the LLM wrapper in use.
            response = llm.generate(generation_prompt, num_return_sequences=num_candidates)
            if isinstance(response, list):
                candidates = response
            else:
                candidates = self._parse_multiple_responses(response, num_candidates)
        except Exception as e:
            # Deliberate best-effort: log and fall back rather than abort unlearning.
            logger.warning(f"Error generating candidates: {e}")
            return fallback

        return candidates if candidates else fallback

    def _parse_multiple_responses(self, response: str, num_candidates: int) -> List[str]:
        """Split one generation into numbered responses ("1. ...", "2. ...", ...).

        Continuation lines are joined with spaces onto the current response.
        Text before the first numbered line becomes its own candidate. The
        result is padded with a canned reply, or truncated, to exactly
        ``num_candidates`` entries.
        """
        prefixes = [f"{i}." for i in range(1, num_candidates + 1)]
        candidates: List[str] = []
        current_response: List[str] = []

        for raw_line in response.split('\n'):
            line = raw_line.strip()
            prefix = next((p for p in prefixes if line.startswith(p)), None)
            if prefix is not None:
                if current_response:
                    candidates.append(' '.join(current_response))
                    current_response = []
                # Strip the whole "N." marker so multi-digit numbers ("10.") are handled.
                line = line[len(prefix):].strip()
            if line:
                current_response.append(line)

        if current_response:
            candidates.append(' '.join(current_response))

        while len(candidates) < num_candidates:
            candidates.append("I don't have information about that topic.")

        return candidates[:num_candidates]

    def _select_optimal_response(self,
                                 llm: Any,
                                 candidates: List[str],
                                 original_response: str,
                                 original_prompt: str,
                                 sensitive_info: List[str],
                                 alpha: float = 1.0,
                                 gamma1: float = 1.0,
                                 gamma2: float = 1.0,
                                 tau: float = 0.8,
                                 delta: float = 0.7) -> str:
        """Select the candidate with the lowest score (lower is better).

        score = alpha * distance(candidate, original_response)
                + gamma1 * max(0, tau - fluency)
                + gamma2 * max(0, delta - relevance)

        Candidates that contain any ``sensitive_info`` term (case-insensitive
        substring match) are excluded. If every candidate contains one, a
        warning is logged and all candidates are scored. Ties keep the
        earliest candidate.

        Raises:
            ValueError: If ``candidates`` is empty.
        """
        if not candidates:
            raise ValueError("candidates must contain at least one response")

        terms = [term.lower() for term in sensitive_info if term]
        pool = [c for c in candidates if not any(term in c.lower() for term in terms)]
        if not pool:
            logger.warning("All candidate responses mention sensitive information; "
                           "selecting among them anyway")
            pool = candidates

        best_score = float('inf')
        best_candidate = pool[0]
        for candidate in pool:
            distance = self._calculate_distance(candidate, original_response)
            fluency = self._calculate_fluency(llm, candidate)
            relevance = self._calculate_relevance(llm, candidate, original_prompt)

            score = (alpha * distance +
                     gamma1 * max(0, tau - fluency) +
                     gamma2 * max(0, delta - relevance))

            if score < best_score:
                best_score = score
                best_candidate = candidate

        return best_candidate

    def _calculate_distance(self, response1: str, response2: str) -> float:
        """Return 1 - Jaccard similarity of the lower-cased whitespace token sets.

        Two empty responses have distance 0. Exactly one empty response gives distance 1.
        """
        tokens1 = set(response1.lower().split())
        tokens2 = set(response2.lower().split())

        if not tokens1 and not tokens2:
            return 0.0
        if not tokens1 or not tokens2:
            return 1.0

        union = len(tokens1 | tokens2)
        jaccard_sim = len(tokens1 & tokens2) / union if union > 0 else 0
        return 1 - jaccard_sim

    def _calculate_fluency(self, llm: Any, response: str) -> float:
        """Heuristic fluency in [0, 1]: word count / 20, plus 0.1 for terminal punctuation.

        ``llm`` is currently unused.
        """
        fluency_score = min(1.0, len(response.split()) / 20)  # normalize by expected length
        if response.endswith(('.', '!', '?')):
            fluency_score += 0.1
        return min(1.0, fluency_score)

    def _calculate_relevance(self, llm: Any, response: str, prompt: str) -> float:
        """Fraction of lower-cased prompt tokens that also appear in the response.

        Returns 1.0 for an empty prompt. ``llm`` is currently unused.
        """
        prompt_words = set(prompt.lower().split())
        response_words = set(response.lower().split())

        if not prompt_words:
            return 1.0

        relevance = len(prompt_words & response_words) / len(prompt_words)
        return min(1.0, relevance)

    def construct_optimized_retain_set(self,
                                       forget_samples: torch.Tensor,
                                       retain_pool: torch.Tensor,
                                       k1: int = 10,
                                       k2: int = 5) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Construct optimized retain set with neighborhood, boundary, and adversarial samples.

        Args:
            forget_samples: Samples to be forgotten [N_f, ...]
            retain_pool: Pool of potential retain samples [N_r, ...]
            k1: Number of neighborhood samples per forget sample
            k2: Number of boundary samples per forget sample

        Returns:
            Tuple of (neighborhood_samples, boundary_samples, adversarial_samples).
            The first two are rows of ``retain_pool`` (on its device). The
            adversarial samples are on the CPU.
        """
        logger.info("Constructing optimized retain set...")

        neighborhood_samples = self.find_neighborhood_samples(forget_samples, retain_pool, k1)
        boundary_samples = self.find_boundary_samples(forget_samples, retain_pool, k2)
        adversarial_samples = self.generate_adversarial_samples(forget_samples)

        return neighborhood_samples, boundary_samples, adversarial_samples

    def find_neighborhood_samples(self,
                                  forget_samples: torch.Tensor,
                                  retain_pool: torch.Tensor,
                                  k1: int) -> torch.Tensor:
        """Return the union of each forget sample's ``k1`` cosine-nearest retain samples.

        Features come from ``_get_features``. ``k1`` is clamped to the pool
        size. Duplicates are removed while keeping first-seen order (forget
        sample by forget sample, nearest first). Empty inputs or ``k1 <= 0``
        give an empty slice of ``retain_pool``.
        """
        logger.info(f"Finding {k1} neighborhood samples for each forget sample...")

        n_neighbors = min(k1, len(retain_pool))
        if n_neighbors <= 0 or len(forget_samples) == 0:
            return retain_pool[:0]

        forget_features_np = self._get_features(forget_samples).cpu().numpy()
        retain_features_np = self._get_features(retain_pool).cpu().numpy()

        nbrs = NearestNeighbors(n_neighbors=n_neighbors, metric='cosine').fit(retain_features_np)
        # One batched query; rows are per forget sample, nearest neighbour first.
        _, indices = nbrs.kneighbors(forget_features_np)

        unique_indices = list(dict.fromkeys(indices.ravel().tolist()))
        return retain_pool[torch.as_tensor(unique_indices, dtype=torch.long)]

    def find_boundary_samples(self,
                              forget_samples: torch.Tensor,
                              retain_pool: torch.Tensor,
                              k2: int) -> torch.Tensor:
        """Find boundary samples for classification tasks.

        For each forget sample with predicted class ``c``, consider its ``k2``
        next most likely classes. For each such class ``t``, pick the retain
        sample predicted as ``t`` that gives the highest probability to ``c``,
        i.e. the one nearest the c/t decision boundary by that measure. The
        result has up to ``len(forget_samples) * k2`` rows, and duplicates are
        possible. If nothing is found, a random subset of ``retain_pool`` is
        returned.
        """
        logger.info(f"Finding {k2} boundary samples for each forget sample...")

        # Predictions are moved to CPU and rows are fetched by integer index,
        # so this works whichever device retain_pool lives on.
        forget_probs = self._predict_probs(forget_samples).cpu()
        retain_probs = self._predict_probs(retain_pool).cpu()
        retain_preds = torch.argmax(retain_probs, dim=1)

        boundary_samples = []
        for forget_prob in forget_probs:
            forget_class = torch.argmax(forget_prob).item()
            _, sorted_classes = torch.sort(forget_prob, descending=True)
            confusable_classes = sorted_classes[1:k2 + 1].tolist()  # exclude the top class

            for target_class in confusable_classes:
                members = torch.nonzero(retain_preds == target_class, as_tuple=True)[0]
                if members.numel() == 0:
                    continue
                best = members[torch.argmax(retain_probs[members, forget_class])]
                boundary_samples.append(retain_pool[int(best)])

        if boundary_samples:
            return torch.stack(boundary_samples)

        logger.warning("No boundary samples found, using random samples")
        num_samples = min(len(forget_samples) * k2, len(retain_pool))
        indices = torch.randperm(len(retain_pool))[:num_samples]
        return retain_pool[indices]

    def generate_adversarial_samples(self,
                                     forget_samples: torch.Tensor,
                                     *,
                                     epsilon: float = 0.01,
                                     clamp_range: Tuple[float, float] = (-2.0, 2.0)) -> torch.Tensor:
        """Generate one FGSM-perturbed copy of each forget sample.

        Each sample is moved by ``epsilon * sign(grad)`` of the cross-entropy
        loss against the model's own prediction, then clamped to
        ``clamp_range``. The default range assumes normalized inputs. Gradients
        are taken with respect to the input only, so the model's parameter
        ``.grad`` fields are not modified. Works even when called inside a
        ``torch.no_grad()`` block.

        Returns:
            CPU tensor with the same shape as ``forget_samples``.
        """
        logger.info("Generating adversarial samples...")
        self.model.eval()

        if len(forget_samples) == 0:
            return forget_samples.detach().cpu().clone()

        low, high = clamp_range
        adversarial_samples = []
        with torch.enable_grad():
            for sample in forget_samples:
                x = sample.unsqueeze(0).to(self.device).detach().clone().requires_grad_(True)
                output = self.model(x)
                target_class = torch.argmax(output, dim=1)
                loss = F.cross_entropy(output, target_class)

                (grad,) = torch.autograd.grad(loss, x)
                adv_sample = torch.clamp(x.detach() + epsilon * torch.sign(grad), low, high)
                adversarial_samples.append(adv_sample.squeeze(0).cpu())

        return torch.stack(adversarial_samples)

    def _get_features(self, samples: torch.Tensor) -> torch.Tensor:
        """Return a 2-D CPU feature matrix [N, D] for ``samples``, computed in batches of 64.

        If the model has a ``features`` submodule, its output is used. 4-D
        outputs are global-average-pooled first, and any output is then
        flattened per sample. Otherwise the raw inputs are flattened, which is
        only a crude fallback; a model-specific penultimate-layer hook would be
        better.
        """
        self.model.eval()
        batch_size = 64
        features = []

        with torch.no_grad():
            for start in range(0, len(samples), batch_size):
                batch = samples[start:start + batch_size].to(self.device)
                if hasattr(self.model, 'features'):
                    feats = self.model.features(batch)
                    if feats.dim() == 4:
                        feats = F.adaptive_avg_pool2d(feats, 1)
                else:
                    feats = batch
                features.append(feats.reshape(feats.size(0), -1).cpu())

        return torch.cat(features, dim=0)


def dataopt_algorithm(model: nn.Module,
                      forget_set: torch.Tensor,
                      retain_pool: torch.Tensor,
                      k: int,
                      k1: int = 10,
                      k2: int = 5,
                      epsilon: float = 1e-6,
                      device: str = 'cuda') -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Main DataOpt algorithm implementation (Algorithm 1 from paper).

    Args:
        model: Pre-trained model
        forget_set: Samples to forget
        retain_pool: Pool of potential retain samples
        k: Unlearning degree
        k1: Number of neighborhood samples
        k2: Number of boundary samples
        epsilon: Small constant
        device: Device to run on

    Returns:
        Tuple of (optimized_forget_set, optimized_forget_labels,
                 optimized_retain_set, optimized_retain_labels).
        The retain set is placed on ``retain_pool``'s device. Both label
        tensors are computed on ``device``.
    """
    framework = DataOptFramework(model, device)

    # Step 1: Label assignment for forget samples
    forget_labels = framework.assign_forget_labels(forget_set, k, epsilon)

    # Step 2: Construct optimized retain set
    neighborhood_samples, boundary_samples, adversarial_samples = framework.construct_optimized_retain_set(
        forget_set, retain_pool, k1, k2
    )

    # Neighbourhood and boundary rows are indexed out of retain_pool, while the
    # adversarial samples come back on the CPU. Align devices before concatenating.
    retain_samples = torch.cat(
        [part.to(retain_pool.device)
         for part in (neighborhood_samples, boundary_samples, adversarial_samples)],
        dim=0,
    )

    # Step 3: Label assignment for retain samples (soft labels = model predictions).
    # Chunked so a large retain set does not need to fit on the device in one pass.
    batch_size = 64
    chunks = [retain_samples[i:i + batch_size] for i in range(0, len(retain_samples), batch_size)]
    if not chunks:
        chunks = [retain_samples]  # empty retain set: one pass keeps the [0, C] shape
    framework.model.eval()
    with torch.no_grad():
        retain_labels = torch.cat(
            [F.softmax(framework.model(chunk.to(device)), dim=1) for chunk in chunks],
            dim=0,
        )

    # Step 4: Data augmentation for forget samples (optional; currently a pass-through)
    augmented_forget_set = forget_set

    return augmented_forget_set, forget_labels, retain_samples, retain_labels