"""
Baseline unlearning algorithms for LLM tasks.
Implements Gradient Ascent (GA), Negative Preference Optimization (NPO), In-Context Unlearning (ICU),
and a DataOpt-enhanced Gradient Ascent variant.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
from typing import List, Dict, Optional, Any, Tuple
import logging

logger = logging.getLogger(__name__)


class GradientAscentUnlearning:
    """
    Gradient Ascent (GA) unlearning for LLMs.

    For every forget sample the cross-entropy over the response tokens is
    negated and minimized, which pushes the model away from reproducing the
    forget response. When retain data is supplied, the language-modeling loss
    of one randomly drawn retain sample is added to each step, weighted by
    ``alpha``.

    Samples are processed one at a time (effective batch size of 1).
    """

    def __init__(self, model: nn.Module, tokenizer: Any, device: str = 'cuda'):
        """
        Args:
            model: Model to unlearn; it is moved to ``device`` in place. Its
                forward output must expose ``.logits`` (and ``.loss`` when
                called with ``labels``).
            tokenizer: Callable tokenizer returning an object with
                ``input_ids`` and a ``.to(device)`` method.
            device: Device the model and inputs are placed on.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.to(device)

    def unlearn(self,
                forget_data: List[Dict[str, str]],
                retain_data: Optional[List[Dict[str, str]]] = None,
                epochs: int = 5,
                lr: float = 1e-5,
                alpha: float = 1.0,
                max_length: int = 512) -> nn.Module:
        """
        Perform Gradient Ascent unlearning.

        Args:
            forget_data: List of forget samples with 'prompt' and 'response' keys
            retain_data: List of retain samples (optional)
            epochs: Number of training epochs
            lr: Learning rate
            alpha: Weight for retain loss
            max_length: Maximum sequence length

        Returns:
            Unlearned model (the same object as ``self.model``, updated in place).
            If ``forget_data`` is empty the model is returned unchanged.
        """
        logger.info("Starting Gradient Ascent unlearning...")

        if not forget_data:
            # Nothing to forget; returning early also avoids dividing by
            # len(forget_data) == 0 when averaging the epoch losses.
            logger.warning("No forget samples provided; skipping Gradient Ascent unlearning")
            return self.model

        optimizer = optim.AdamW(self.model.parameters(), lr=lr)

        for epoch in range(epochs):
            epoch_forget_loss = 0.0
            epoch_retain_loss = 0.0

            for sample in forget_data:
                optimizer.zero_grad()

                forget_loss = self._compute_forget_loss(sample, max_length)
                total_loss = forget_loss
                epoch_forget_loss += forget_loss.item()

                # Add one randomly drawn retain sample if available
                if retain_data:
                    # Index directly instead of np.random.choice(list): no
                    # per-step list -> ndarray conversion of the whole dataset.
                    retain_sample = retain_data[np.random.randint(len(retain_data))]
                    retain_loss = self._compute_retain_loss(retain_sample, max_length)
                    # Out-of-place add: `+=` would mutate the forget_loss tensor.
                    total_loss = total_loss + alpha * retain_loss
                    epoch_retain_loss += retain_loss.item()

                total_loss.backward()
                optimizer.step()

            avg_forget_loss = epoch_forget_loss / len(forget_data)
            # One retain sample is drawn per forget sample, hence the same divisor.
            avg_retain_loss = epoch_retain_loss / len(forget_data) if retain_data else 0.0

            logger.info(f"Epoch {epoch}: Forget Loss = {avg_forget_loss:.4f}, "
                       f"Retain Loss = {avg_retain_loss:.4f}")

        logger.info("Gradient Ascent unlearning completed")
        return self.model

    def _tokenize(self, text: str, max_length: int) -> Any:
        """Tokenize one string into tensors on ``self.device``, truncated to ``max_length``."""
        return self.tokenizer(
            text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(self.device)

    def _compute_forget_loss(self,
                             sample: Dict[str, str],
                             max_length: int) -> torch.Tensor:
        """
        Compute the negated response-token cross-entropy for one forget sample.

        Args:
            sample: Dict with 'prompt' and 'response' keys.
            max_length: Maximum sequence length used for truncation.

        Returns:
            Scalar tensor; minimizing it performs gradient ascent on the
            response likelihood loss.
        """
        prompt = sample['prompt']
        response = sample['response']

        inputs = self._tokenize(f"{prompt} {response}", max_length)
        input_ids = inputs['input_ids']

        # Only the token count of the prompt is needed, so it stays on CPU.
        # NOTE(review): assumes the tokenizer adds no trailing special tokens
        # to the prompt and tokenizes it identically inside the full text;
        # otherwise the prompt/response boundary is off by a token — confirm.
        prompt_length = self.tokenizer(
            prompt,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=max_length
        )['input_ids'].shape[1]

        # No `labels=` here: the loss is computed manually on response tokens
        # only, so the model's built-in full-sequence loss would be wasted work.
        outputs = self.model(**inputs)

        response_mask = torch.zeros_like(input_ids)
        response_mask[:, prompt_length:] = 1

        # Shift so that position t predicts token t + 1.
        logits = outputs.logits[:, :-1, :]
        labels = input_ids[:, 1:]
        response_mask = response_mask[:, 1:]

        return -self._compute_masked_loss(logits, labels, response_mask)

    def _compute_masked_loss(self,
                           logits: torch.Tensor,
                           labels: torch.Tensor,
                           mask: torch.Tensor) -> torch.Tensor:
        """
        Compute the mean cross-entropy over positions where ``mask`` is non-zero.

        Args:
            logits: Tensor whose last dimension is the vocabulary.
            labels: Target token ids, one per logits position.
            mask: Per-position weights (0/1) with the same shape as ``labels``.

        Returns:
            Scalar tensor; 0 when the mask selects no position.
        """
        loss_fct = nn.CrossEntropyLoss(reduction='none')
        # reshape (not view): the shifted slices are non-contiguous for batch > 1.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_labels = labels.reshape(-1)
        flat_mask = mask.reshape(-1)

        losses = loss_fct(flat_logits, flat_labels)
        masked_losses = losses * flat_mask

        # clamp avoids 0/0 when the response was truncated away entirely.
        return masked_losses.sum() / flat_mask.sum().clamp(min=1)

    def _compute_retain_loss(self,
                           retain_sample: Dict[str, str],
                           max_length: int) -> torch.Tensor:
        """
        Compute the model's own language-modeling loss on a retain sample.

        The loss covers the whole "prompt response" sequence (not only the
        response tokens), as returned by the model when given ``labels``.
        """
        prompt = retain_sample['prompt']
        response = retain_sample['response']

        inputs = self._tokenize(f"{prompt} {response}", max_length)
        outputs = self.model(**inputs, labels=inputs['input_ids'])
        return outputs.loss


class NPOUnlearning:
    """
    Negative Preference Optimization (NPO) for LLM unlearning.

    Reference: "Negative Preference Optimization: From Catastrophic Collapse to Effective Unlearning"

    A frozen deep copy of the model taken at construction time serves as the
    reference policy. Each step contrasts the forget response against a canned
    refusal sentence using length-normalized (mean per-token) log-probabilities.

    NOTE(review): the objective implemented here is a DPO-style preference
    loss (refusal preferred over forget response). The objective in the
    referenced paper appears to use only the forget-response term — confirm
    that this variant is the intended baseline.
    """

    def __init__(self, model: nn.Module, tokenizer: Any, device: str = 'cuda'):
        """
        Args:
            model: Model to unlearn; moved to ``device`` in place. Its forward
                output must expose ``.logits`` (and ``.loss`` with ``labels``).
            tokenizer: Callable tokenizer returning an object with
                ``input_ids`` and a ``.to(device)`` method.
            device: Device the models and inputs are placed on.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.to(device)

        # Create reference model (snapshot of the weights before unlearning)
        self.ref_model = self._create_reference_model()

    def _create_reference_model(self) -> nn.Module:
        """Return a frozen, eval-mode deep copy of the current model."""
        import copy
        ref_model = copy.deepcopy(self.model)
        ref_model.eval()
        for param in ref_model.parameters():
            param.requires_grad = False
        return ref_model

    def unlearn(self,
                forget_data: List[Dict[str, str]],
                retain_data: Optional[List[Dict[str, str]]] = None,
                epochs: int = 5,
                lr: float = 1e-5,
                beta: float = 0.1,
                alpha: float = 1.0,
                max_length: int = 512) -> nn.Module:
        """
        Perform NPO unlearning.

        Args:
            forget_data: List of forget samples with 'prompt' and 'response' keys
            retain_data: List of retain samples (optional)
            epochs: Number of training epochs
            lr: Learning rate
            beta: NPO regularization parameter
            alpha: Weight for retain loss
            max_length: Maximum sequence length

        Returns:
            Unlearned model (the same object as ``self.model``, updated in place).
            If ``forget_data`` is empty the model is returned unchanged.
        """
        logger.info("Starting NPO unlearning...")

        if not forget_data:
            # Nothing to forget; also avoids dividing by zero in the epoch average.
            logger.warning("No forget samples provided; skipping NPO unlearning")
            return self.model

        optimizer = optim.AdamW(self.model.parameters(), lr=lr)

        for epoch in range(epochs):
            epoch_loss = 0.0

            for sample in forget_data:
                optimizer.zero_grad()

                total_loss = self._compute_npo_loss(sample, beta, max_length)

                # Add one randomly drawn retain sample if available
                if retain_data:
                    # Index directly instead of np.random.choice(list): no
                    # per-step list -> ndarray conversion of the whole dataset.
                    retain_sample = retain_data[np.random.randint(len(retain_data))]
                    retain_loss = self._compute_retain_loss(retain_sample, max_length)
                    # Out-of-place add keeps the NPO loss tensor untouched.
                    total_loss = total_loss + alpha * retain_loss

                total_loss.backward()
                optimizer.step()

                epoch_loss += total_loss.item()

            avg_loss = epoch_loss / len(forget_data)
            logger.info(f"Epoch {epoch}: Loss = {avg_loss:.4f}")

        logger.info("NPO unlearning completed")
        return self.model

    def _compute_npo_loss(self,
                         sample: Dict[str, str],
                         beta: float,
                         max_length: int) -> torch.Tensor:
        """
        Compute the preference loss for one forget sample.

        The loss is ``-logsigmoid(beta * (alt_ratio - forget_ratio))`` where
        each ratio is the policy log-prob minus the reference log-prob of the
        respective response.

        Args:
            sample: Dict with 'prompt' and 'response' keys.
            beta: Scale applied to the log-ratio difference.
            max_length: Maximum sequence length used for truncation.

        Returns:
            Scalar loss tensor.
        """
        prompt = sample['prompt']
        forget_response = sample['response']

        # Preferred alternative: a randomly chosen canned refusal
        alternative_response = self._generate_alternative_response(prompt)

        forget_inputs = self._prepare_inputs(prompt, forget_response, max_length)
        alt_inputs = self._prepare_inputs(prompt, alternative_response, max_length)

        # Policy log probabilities (carry gradients)
        forget_logprob = self._compute_log_prob(forget_inputs, prompt)
        alt_logprob = self._compute_log_prob(alt_inputs, prompt)

        # Reference log probabilities (constants)
        with torch.no_grad():
            ref_forget_logprob = self._compute_log_prob_ref(forget_inputs, prompt)
            ref_alt_logprob = self._compute_log_prob_ref(alt_inputs, prompt)

        forget_ratio = forget_logprob - ref_forget_logprob
        alt_ratio = alt_logprob - ref_alt_logprob

        npo_loss = -F.logsigmoid(beta * (alt_ratio - forget_ratio))

        return npo_loss.mean()

    def _generate_alternative_response(self, prompt: str) -> str:
        """Return a randomly chosen canned refusal; ``prompt`` is currently unused."""
        alternatives = [
            "I don't have information about that.",
            "I cannot provide details on that topic.",
            "That's not something I can discuss.",
            "I'm not able to answer that question."
        ]
        # np.random.choice yields numpy.str_; normalize to a plain str.
        return str(np.random.choice(alternatives))

    def _prepare_inputs(self, prompt: str, response: str, max_length: int) -> Dict:
        """Tokenize "prompt response" into tensors on ``self.device``."""
        full_text = f"{prompt} {response}"
        return self.tokenizer(
            full_text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(self.device)

    def _response_log_prob(self,
                           model: nn.Module,
                           inputs: Dict,
                           prompt: str) -> torch.Tensor:
        """
        Mean per-token log-probability of the response tokens under ``model``.

        Args:
            model: Model to score with (policy or reference).
            inputs: Tokenized "prompt response" sequence from ``_prepare_inputs``.
            prompt: Prompt text; only its token count is used to locate the
                start of the response.

        Returns:
            Scalar tensor; 0 when no response token survives truncation.
        """
        # Only the prompt's token count is needed, so it is not moved to the device.
        # NOTE(review): assumes the tokenizer adds no trailing special tokens
        # to the prompt — otherwise the boundary is off by a token; confirm.
        prompt_length = self.tokenizer(
            prompt,
            return_tensors='pt',
            padding=True,
            truncation=True
        )['input_ids'].shape[1]

        # Shift so that position t predicts token t + 1.
        logits = model(**inputs).logits[:, :-1, :]
        labels = inputs['input_ids'][:, 1:]

        log_probs = F.log_softmax(logits, dim=-1)
        token_log_probs = log_probs.gather(2, labels.unsqueeze(-1)).squeeze(-1)

        # In shifted coordinates the first response token sits at index
        # prompt_length - 1. Clamp at 0 so an empty prompt selects every
        # token instead of wrapping around to "last token only".
        response_mask = torch.zeros_like(labels)
        response_mask[:, max(prompt_length - 1, 0):] = 1

        masked_log_probs = token_log_probs * response_mask
        return masked_log_probs.sum() / response_mask.sum().clamp(min=1)

    def _compute_log_prob(self, inputs: Dict, prompt: str) -> torch.Tensor:
        """Compute log probability of response given prompt under the trainable model."""
        return self._response_log_prob(self.model, inputs, prompt)

    def _compute_log_prob_ref(self, inputs: Dict, prompt: str) -> torch.Tensor:
        """Compute log probability of response given prompt under the frozen reference model."""
        with torch.no_grad():
            return self._response_log_prob(self.ref_model, inputs, prompt)

    def _compute_retain_loss(self, retain_sample: Dict[str, str], max_length: int) -> torch.Tensor:
        """
        Compute the model's own language-modeling loss on a retain sample.

        The loss covers the whole "prompt response" sequence, as returned by
        the model when given ``labels``.
        """
        prompt = retain_sample['prompt']
        response = retain_sample['response']

        inputs = self._prepare_inputs(prompt, response, max_length)
        outputs = self.model(**inputs, labels=inputs['input_ids'])
        return outputs.loss


class ICUUnlearning:
    """
    In-Context Unlearning (ICU) for LLMs.

    Reference: "In-Context Unlearning: Language Models as Few Shot Unlearners"

    No weights are updated. ``unlearn`` builds a textual context of refusal
    examples from the forget prompts, and ``generate_response`` prepends that
    context to every query.
    """

    def __init__(self, model: nn.Module, tokenizer: Any, device: str = 'cuda'):
        """
        Args:
            model: Generative model exposing ``generate``; moved to ``device`` in place.
            tokenizer: Tokenizer used to encode prompts and decode generations.
            device: Device the model and inputs are placed on.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.to(device)
        # Populated by unlearn(); None means no context has been prepared yet.
        self.unlearning_context: Optional[str] = None

    def unlearn(self,
                forget_data: List[Dict[str, str]],
                retain_data: Optional[List[Dict[str, str]]] = None,
                num_examples: int = 3) -> None:
        """
        Perform ICU unlearning (inference-time method).

        Args:
            forget_data: List of forget samples with a 'prompt' key
            retain_data: List of retain samples (optional; currently unused,
                accepted so the signature matches the other unlearners)
            num_examples: Number of examples to include in context
        """
        logger.info("Preparing ICU unlearning context...")

        # Create unlearning context
        self.unlearning_context = self._create_unlearning_context(
            forget_data, num_examples
        )

        logger.info("ICU unlearning context prepared")

    def generate_response(self,
                         prompt: str,
                         max_length: int = 100,
                         temperature: float = 0.7) -> str:
        """
        Generate response with unlearning context.

        Args:
            prompt: Input prompt
            max_length: Maximum number of newly generated tokens
            temperature: Sampling temperature

        Returns:
            Generated response (context and prompt stripped, whitespace trimmed)

        Raises:
            RuntimeError: If ``unlearn`` has not been called yet, so there is
                no unlearning context to prepend.
        """
        context = getattr(self, 'unlearning_context', None)
        if context is None:
            # Fail loudly: answering without the context would silently skip
            # the unlearning behaviour this class exists to provide.
            raise RuntimeError(
                "ICU context has not been prepared; call unlearn() before generate_response()"
            )

        # Prepend unlearning context
        contextualized_prompt = f"{context}\n\nUser: {prompt}\nAssistant:"

        # NOTE(review): truncation uses the tokenizer's default limit and side;
        # with a very long context the user prompt itself may be cut — confirm.
        inputs = self.tokenizer(
            contextualized_prompt,
            return_tensors='pt',
            padding=True,
            truncation=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the tokens generated after the input sequence
        input_length = inputs['input_ids'].shape[1]
        response = self.tokenizer.decode(
            outputs[0][input_length:],
            skip_special_tokens=True
        )

        return response.strip()

    def _create_unlearning_context(self,
                                  forget_data: List[Dict[str, str]],
                                  num_examples: int) -> str:
        """
        Build the few-shot refusal context from the first forget prompts.

        Args:
            forget_data: Forget samples; only the 'prompt' key is read.
            num_examples: Maximum number of leading samples to include.

        Returns:
            Newline-joined context text.
        """
        context_parts = [
            "The following are examples of questions I should not answer:",
            ""
        ]

        # Slicing already copes with lists shorter than num_examples.
        selected_examples = forget_data[:num_examples]

        for i, example in enumerate(selected_examples, 1):
            prompt = example['prompt']

            # Pair each forget prompt with a canned refusal
            unlearning_response = self._generate_unlearning_response(prompt)

            context_parts.extend([
                f"Example {i}:",
                f"User: {prompt}",
                f"Assistant: {unlearning_response}",
                ""
            ])

        context_parts.extend([
            "I should respond similarly to questions about these topics.",
            "I should decline to provide specific information about sensitive or private topics."
        ])

        return "\n".join(context_parts)

    def _generate_unlearning_response(self, prompt: str) -> str:
        """Return a randomly chosen canned refusal; ``prompt`` is currently unused."""
        unlearning_templates = [
            "I don't have information about that topic.",
            "I can't provide details on that subject.",
            "That's not something I can discuss.",
            "I'm not able to answer questions about that.",
            "I don't have access to that information."
        ]

        # np.random.choice yields numpy.str_; normalize to a plain str.
        return str(np.random.choice(unlearning_templates))


class DataOptLLMUnlearning:
    """
    DataOpt-enhanced LLM unlearning.

    Combines DataOpt label optimization with Gradient Ascent: every forget
    response is first replaced by the response produced by the DataOpt
    framework, then ``GradientAscentUnlearning`` is run on the rewritten data.
    """

    def __init__(self, model: nn.Module, tokenizer: Any, device: str = 'cuda'):
        """
        Args:
            model: Model to unlearn; moved to ``device`` in place.
            tokenizer: Tokenizer forwarded to the gradient-ascent stage.
            device: Device the model is placed on.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.to(device)

        # Imported lazily so the other baselines in this module can be used
        # without the DataOpt package being importable.
        from ..src.dataopt import DataOptFramework
        self.dataopt = DataOptFramework(model, device)

    def unlearn(self,
                forget_data: List[Dict[str, str]],
                retain_data: Optional[List[Dict[str, str]]] = None,
                epochs: int = 5,
                lr: float = 1e-5,
                alpha: float = 1.0,
                max_length: int = 512) -> nn.Module:
        """
        Perform DataOpt-enhanced LLM unlearning.

        Args:
            forget_data: List of forget samples with 'prompt' and 'response' keys
            retain_data: List of retain samples (optional)
            epochs: Number of training epochs
            lr: Learning rate
            alpha: Weight for retain loss
            max_length: Maximum sequence length

        Returns:
            Unlearned model
        """
        logger.info("Starting DataOpt LLM unlearning...")

        # Step 1: Generate optimal responses for forget samples
        optimized_forget_data = []
        for sample in forget_data:
            prompt = sample['prompt']
            original_response = sample['response']

            # Extract sensitive information (simple keyword extraction)
            sensitive_info = self._extract_sensitive_info(original_response)

            optimal_response = self.dataopt.generate_llm_forget_response(
                self.model, prompt, original_response, sensitive_info
            )

            optimized_forget_data.append({
                'prompt': prompt,
                'response': optimal_response
            })

        # Step 2: Apply gradient ascent with optimized responses
        # NOTE(review): gradient ascent lowers the likelihood of the responses
        # it is given; confirm generate_llm_forget_response returns text the
        # model should move away from, not a target it should learn.
        ga_unlearner = GradientAscentUnlearning(self.model, self.tokenizer, self.device)

        return ga_unlearner.unlearn(
            optimized_forget_data,
            retain_data,
            epochs=epochs,
            lr=lr,
            alpha=alpha,
            max_length=max_length
        )

    def _extract_sensitive_info(self, response: str) -> List[str]:
        """
        Extract up to five candidate sensitive keywords from a response.

        A keyword is a whitespace-separated token that, after stripping
        surrounding punctuation, is longer than two characters, starts with an
        uppercase letter and is not a common sentence-initial word. Keywords
        are returned in order of first appearance without duplicates.

        Args:
            response: Text to scan.

        Returns:
            At most five keywords.
        """
        # Simple keyword-based extraction
        # In practice, this would be more sophisticated
        import string

        common_words = {'The', 'This', 'That', 'And', 'But', 'For', 'You', 'All', 'Any'}
        max_keywords = 5

        keywords: List[str] = []
        seen = set()
        for token in response.split():
            # Strip punctuation so "Paris." and "Paris" are the same keyword
            # and "The," is still recognized as a common word.
            word = token.strip(string.punctuation)
            if len(word) <= 2 or not word[0].isupper():
                continue
            if word in common_words or word in seen:
                continue
            seen.add(word)
            keywords.append(word)
            if len(keywords) == max_keywords:
                break

        return keywords