#!/usr/bin/env python3
"""
Target Vector Computation and Caching for Qwen3-30B-A3B MoE Knowledge Editing
Based on AlphaEdit's optimization approach with caching support
"""

import torch
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from pathlib import Path
import pickle
import hashlib
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

@dataclass
class EditRequest:
    """A single knowledge-edit request.

    ``prompt`` is a template whose ``{}`` placeholder is filled with
    ``subject``; ``target_new`` is the completion the edit should produce.
    """

    # Template with a "{}" placeholder, e.g. "The capital of {} is".
    prompt: str
    # Value substituted into the template, e.g. "France".
    subject: str
    # Desired new completion, e.g. " Berlin" (note the leading space).
    target_new: str
    # Request identifier; TargetVectorComputer uses it as the cache key.
    case_id: int

@dataclass
class AlphaEditHyperParams:
    """Hyperparameters for the AlphaEdit-style delta optimization.

    Only ``v_lr``, ``v_num_grad_steps``, ``v_weight_decay`` and
    ``clamp_norm_factor`` are read by the code in this file; the remaining
    fields are presumably consumed elsewhere or reserved (verify).
    """

    # Adam learning rate for the delta vector.
    v_lr: float = 0.1
    # Maximum number of gradient steps.
    v_num_grad_steps: int = 50
    # Loss layer placeholder (-1 = decided at call time; not read in this file).
    v_loss_layer: int = -1
    # Strength of the delta-norm penalty.
    v_weight_decay: float = 0.001
    # KL divergence factor (not read in this file).
    kl_factor: float = 0.0625
    # Delta norm is clamped to this multiple of the original hidden-state norm.
    clamp_norm_factor: float = 4.0
    # Target probability boost (not read in this file).
    target_boost: float = 3.0

@dataclass
class TargetVectorResult:
    """Outcome of optimizing one target vector."""

    # Desired hidden state (original layer output + optimized delta).
    target_vector: torch.Tensor
    # Probability of the first target token before optimization.
    initial_prob: float
    # Probability of the first target token after optimization.
    final_prob: float
    # How many optimization steps were run.
    optimization_steps: int
    # True when the loss dropped below the early-stopping threshold.
    converged: bool
    # Total loss recorded at every step.
    loss_history: List[float]

class TargetVectorComputer:
    """Compute and cache target vectors for knowledge editing using subject token positions.

    For each edit request, a delta vector is optimized and added to the output
    hidden state of a chosen decoder layer at the subject lookup position. The
    objective is for the model's prediction at the last prompt position to
    favour the first token of ``request.target_new``. The resulting vector
    (original hidden state + delta) is pickled to ``cache_dir`` as a
    ``TargetVectorResult``.

    The model is accessed through ``model.model.layers``, ``model.config`` and
    ``model.lm_head`` / ``model.get_output_embeddings()``.
    """

    def __init__(
        self,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        cache_dir: str = "./target_vector_cache",
        device: str = "cuda",
        fact_token_strategy: str = "subject_last",
        hparams: Optional[AlphaEditHyperParams] = None
    ):
        """Store the model and tokenizer and create the cache directory.

        Args:
            model: Causal LM exposing ``model.model.layers``.
            tokenizer: Tokenizer matching ``model``.
            cache_dir: Directory for pickled results. Created, including
                parents, if missing.
            device: Device that tokenized inputs and the delta vector are placed on.
            fact_token_strategy: Strategy name forwarded to ``find_fact_lookup_idx``.
            hparams: Optimization hyperparameters. Defaults to ``AlphaEditHyperParams()``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.cache_dir = Path(cache_dir)
        # parents=True so nested cache paths (e.g. "runs/exp1/cache") also work.
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.fact_token_strategy = fact_token_strategy

        # Use provided hyperparameters or default
        self.hparams = hparams or AlphaEditHyperParams()

        print("Target Vector Computer initialized (Subject Token Mode)")
        print(f"Cache directory: {self.cache_dir}")
        print(f"Device: {self.device}")

    # ------------------------------------------------------------------ helpers

    def _num_layers(self) -> Optional[int]:
        """Return ``len(model.model.layers)``, or None if the model has no such attribute."""
        try:
            return len(self.model.model.layers)
        except (AttributeError, TypeError):
            return None

    def _hidden_size(self) -> int:
        """Return the model's hidden size from ``config.hidden_size`` or ``config.n_embd``.

        Raises:
            ValueError: if neither attribute is present.
        """
        config = self.model.config
        if hasattr(config, 'hidden_size'):
            return config.hidden_size
        if hasattr(config, 'n_embd'):
            return config.n_embd
        raise ValueError("Could not determine model hidden size")

    def _target_token_id(self, target_str: str) -> torch.Tensor:
        """Return the first token id of ``target_str`` as a 0-d tensor on ``self.device``.

        Special tokens are not added. Otherwise, tokenizers that prepend BOS
        would make the "first target token" the BOS token.

        Raises:
            ValueError: if ``target_str`` tokenizes to an empty sequence.
        """
        target_ids = self.tokenizer(
            target_str, return_tensors="pt", add_special_tokens=False
        ).input_ids.to(self.device)
        if target_ids.shape[1] == 0:
            raise ValueError(f"Target string '{target_str}' could not be tokenized")
        return target_ids[0, 0]

    @staticmethod
    def _load_find_fact_lookup_idx():
        """Import ``find_fact_lookup_idx`` from the ``token_utils`` module.

        The package-relative import is tried first. If that fails, a top-level
        import is tried, so the module can also run as a script (``main()``
        uses top-level imports). The fallback assumes ``token_utils`` is on
        ``sys.path``; TODO confirm.
        """
        try:
            from .token_utils import find_fact_lookup_idx
        except ImportError:
            from token_utils import find_fact_lookup_idx
        return find_fact_lookup_idx

    # ------------------------------------------------------------ subject span

    def find_subject_token_positions(self, prompt: str, subject: str, inputs: Dict) -> Tuple[int, int]:
        """
        Find the start and end positions of subject tokens in the tokenized input.
        (Same implementation as in moe_edit.py.)

        Three strategies are tried in order:
        1. Exact match of the subject's standalone token ids inside
           ``inputs['input_ids'][0]``.
        2. The smallest, leftmost window of up to 5 consecutive tokens whose
           decoded, lower-cased text contains the lower-cased subject.
        3. A default single position near the middle of the sequence, with a
           warning printed.

        Args:
            prompt: Prompt template. Not used by the search; kept for API compatibility.
            subject: Subject string to locate.
            inputs: Tokenizer output containing ``'input_ids'``. Only the first row is searched.

        Returns:
            Inclusive ``(start, end)`` token positions of the subject.
        """
        # Tokenize the subject separately to understand its token structure
        subject_tokens = self.tokenizer(subject, add_special_tokens=False)['input_ids']

        # Get the full input token ids
        input_ids = inputs['input_ids'][0].cpu().tolist()

        subject_start_pos = None
        subject_end_pos = None

        # Exact search for the subject token sequence. Skipped for an empty
        # subject, where an empty slice would trivially "match" and give end = -1.
        if subject_tokens:
            n_sub = len(subject_tokens)
            for i in range(len(input_ids) - n_sub + 1):
                if input_ids[i:i + n_sub] == subject_tokens:
                    subject_start_pos = i
                    subject_end_pos = i + n_sub - 1
                    break

        # Fallback: if exact match fails, try to find subject string in decoded tokens.
        # (Tokenization of a word can differ in context, e.g. with a leading space.)
        if subject_start_pos is None:
            subject_lower = subject.lower()

            # Try different window sizes to find the subject (up to 5 tokens)
            for window_size in range(1, min(len(input_ids) + 1, 6)):
                for i in range(len(input_ids) - window_size + 1):
                    window_text = self.tokenizer.decode(input_ids[i:i+window_size], skip_special_tokens=True).strip().lower()

                    if subject_lower in window_text:
                        subject_start_pos = i
                        subject_end_pos = i + window_size - 1
                        print(f"   Found subject in window: '{window_text}' at positions [{i}, {i+window_size-1}]")
                        break

                if subject_start_pos is not None:
                    break

        # Final fallback: use a reasonable default
        if subject_start_pos is None:
            print(f"⚠ Warning: Could not locate subject '{subject}' in tokenized input")
            print(f"   Input tokens: {input_ids}")
            print(f"   Subject tokens: {subject_tokens}")
            print(f"   Decoded input: {self.tokenizer.decode(input_ids)}")
            # Use the middle of the sequence as a reasonable default
            seq_len = len(input_ids)
            subject_start_pos = max(0, seq_len // 2 - 1)
            subject_end_pos = subject_start_pos

        print(f"🎯 Subject '{subject}' found at positions [{subject_start_pos}, {subject_end_pos}]")
        print(f"   Subject tokens: {input_ids[subject_start_pos:subject_end_pos+1]}")
        print(f"   Decoded subject: {self.tokenizer.decode(input_ids[subject_start_pos:subject_end_pos+1])}")

        return subject_start_pos, subject_end_pos

    # ------------------------------------------------------------------ caching

    def _get_cache_key(self, request: EditRequest, layer_idx: int, loss_layer: int) -> str:
        """Return an MD5 fingerprint of the request content plus layer indices.

        MD5 is used only as a content fingerprint, not for security.
        """
        content = f"{request.prompt}|{request.subject}|{request.target_new}|{layer_idx}|{loss_layer}"
        return hashlib.md5(content.encode()).hexdigest()

    def _get_cache_path(self, cache_key: str, layer_idx: int) -> Path:
        """Return the pickle path for ``cache_key`` (callers currently pass ``case_id``) at ``layer_idx``."""
        return self.cache_dir / f"target_vector_layer_{layer_idx}_{cache_key}.pkl"

    # ------------------------------------------------------------- optimization

    def _compute_single_target_vector(
        self,
        request: EditRequest,
        layer_idx: int,
        loss_layer: Optional[int] = None
    ) -> TargetVectorResult:
        """
        Compute the target vector for a single request.

        Optimized quantity: the output hidden state of layer ``layer_idx`` (the
        input of layer ``layer_idx + 1``) at the subject lookup position.
        Intervention: a delta vector is added to that layer's output at that position.
        Objective: maximize the probability of the first token of
        ``request.target_new`` at the last prompt position. The objective also
        penalizes the delta's norm relative to the original hidden state.

        Args:
            request: The edit request.
            layer_idx: Index of the layer to edit (e.g. 13).
            loss_layer: Accepted for API compatibility. The loss is always computed
                from the model's final logits, so this value does not affect the result.

        Returns:
            TargetVectorResult whose ``target_vector`` is the desired output of layer
            ``layer_idx`` at the subject position (original output + optimized delta).
            ``final_prob`` is measured in the forward pass of the last step, i.e.
            before that step's optimizer update.

        Raises:
            ValueError: if the target cannot be tokenized or the hidden size is unknown.
            RuntimeError: if the intervention hook never ran (e.g. zero grad steps).
        """
        # Format prompt with subject - use AlphaEdit standard approach
        full_prompt = request.prompt.format(request.subject)
        inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.device)

        # Find the subject lookup position using the AlphaEdit-style strategy
        find_fact_lookup_idx = self._load_find_fact_lookup_idx()
        subject_end_pos = find_fact_lookup_idx(
            prompt=request.prompt,
            subject=request.subject,
            tok=self.tokenizer,
            fact_token_strategy=self.fact_token_strategy,
            verbose=False
        )

        target_id = self._target_token_id(request.target_new)
        d_model = self._hidden_size()

        print(f"Computing target vector for '{request.target_new}' (token_id: {target_id})")
        print(f"Prompt: {full_prompt}")
        print(f"Subject token position: {subject_end_pos}")
        print(f"Optimization: {self.hparams.v_num_grad_steps} steps, lr={self.hparams.v_lr}")

        # Initial probability, evaluated at the last token (next-token prediction)
        with torch.no_grad():
            initial_logits = self.model(**inputs).logits[0, -1, :]
            initial_prob = F.softmax(initial_logits, dim=-1)[target_id].item()

        print(f"Initial target probability: {initial_prob:.6f}")

        delta = torch.zeros(d_model, requires_grad=True, device=self.device)
        target_init = None  # filled by the hook on the first intervened forward pass

        intervene_pos = subject_end_pos
        print(f"🎯 Intervention will be applied at subject token position: {intervene_pos}")

        target_layer = self.model.model.layers[layer_idx]

        def intervention_hook(module, input_args, output):
            """Add ``delta`` to layer ``layer_idx``'s output at ``intervene_pos``."""
            _ = module, input_args
            nonlocal target_init

            # Decoder layers may return a tuple whose first element is the hidden state
            actual_output = output[0] if isinstance(output, tuple) else output

            # Record the unmodified output of layer layer_idx at the subject position
            if target_init is None:
                target_init = actual_output[0, intervene_pos, :].detach().clone()
                print(f"Initial target vector (layer {layer_idx} output at subject pos {intervene_pos}) norm: {target_init.norm().item():.4f}")

            # Clone so the in-place add does not modify a tensor autograd may still need
            modified_output = actual_output.clone()
            modified_output[0, intervene_pos, :] += delta

            # Returning a value from a forward hook replaces the module's output
            if isinstance(output, tuple):
                return (modified_output,) + output[1:]
            return modified_output

        optimizer = torch.optim.Adam([delta], lr=self.hparams.v_lr)

        # Freeze model parameters so only delta receives gradients
        original_requires_grad = {}
        for name, param in self.model.named_parameters():
            original_requires_grad[name] = param.requires_grad
            param.requires_grad = False

        loss_history = []
        final_prob = initial_prob
        converged = False

        hook_handle = target_layer.register_forward_hook(intervention_hook)
        try:
            for step in range(self.hparams.v_num_grad_steps):
                optimizer.zero_grad()

                outputs = self.model(**inputs)

                # Evaluate at the last token (next-token prediction)
                target_logits = outputs.logits[0, -1, :]
                log_probs = F.log_softmax(target_logits, dim=-1)

                # Negative log likelihood loss (maximize target probability)
                nll_loss = -log_probs[target_id]

                # Norm penalty on delta, scaled by the original hidden-state norm when known
                if target_init is not None:
                    weight_decay = self.hparams.v_weight_decay * (
                        torch.norm(delta) / torch.norm(target_init) ** 2
                    )
                else:
                    weight_decay = self.hparams.v_weight_decay * torch.norm(delta) ** 2

                loss = nll_loss + weight_decay
                loss_history.append(loss.item())

                loss.backward()
                optimizer.step()

                # Keep ||delta|| <= clamp_norm_factor * ||target_init||
                if target_init is not None:
                    max_norm = self.hparams.clamp_norm_factor * target_init.norm()
                    if delta.norm() > max_norm:
                        with torch.no_grad():
                            delta.data = delta.data * max_norm / delta.norm()

                # Probability from this step's forward pass (before the update above)
                with torch.no_grad():
                    final_prob = F.softmax(target_logits, dim=-1)[target_id].item()

                print(f"Step {step:2d}: loss={loss.item():.4f}, "
                        f"nll_loss={nll_loss.item():.4f}, "
                        f"weight_decay={weight_decay.item():.4f}, "
                        f"target_prob={final_prob:.6f}, "
                        f"delta_norm={delta.norm().item():.4f}")

                if loss.item() < 0.01:
                    print(f"Early stopping at step {step}")
                    converged = True
                    break
        finally:
            hook_handle.remove()
            # Restore model parameters' original requires_grad flags
            for name, param in self.model.named_parameters():
                param.requires_grad = original_requires_grad.get(name, True)

        if target_init is None:
            raise RuntimeError("Could not compute target vector")

        # target_init is layer layer_idx's original output at the subject position;
        # target_vector is the desired output.
        target_vector = target_init + delta.detach()
        # Guard: initial_prob can be exactly 0.0 (softmax underflow).
        improvement = f"{final_prob/initial_prob:.2f}x" if initial_prob > 0 else "N/A"
        print("Final results:")
        print(f"  Layer {layer_idx} subject token position {intervene_pos} optimization:")
        print(f"  Initial probability: {initial_prob:.6f}")
        print(f"  Final probability: {final_prob:.6f}")
        print(f"  Improvement: {improvement}")
        print(f"  Target vector (layer {layer_idx} output at subject pos) norm: {target_vector.norm().item():.4f}")
        print(f"  Delta norm: {delta.norm().item():.4f}")
        print(f"  Target vector represents the desired output of layer {layer_idx} at subject token position {intervene_pos}")

        return TargetVectorResult(
            target_vector=target_vector,
            initial_prob=initial_prob,
            final_prob=final_prob,
            optimization_steps=len(loss_history),
            converged=converged,
            loss_history=loss_history
        )

    def compute_target_vectors(
        self,
        requests: List[EditRequest],
        layer_idx: int,
        loss_layer: Optional[int] = None,
        force_recompute: bool = False
    ) -> Dict[int, TargetVectorResult]:
        """
        Compute target vectors for multiple requests, with on-disk caching.

        Args:
            requests: List of edit requests.
            layer_idx: Layer to edit.
            loss_layer: Layer to compute loss. Defaults to ``layer_idx + 5``,
                clamped to the valid layer range.
            force_recompute: Ignore existing cache files and recompute.

        Returns:
            Dictionary mapping ``case_id`` to ``TargetVectorResult``. If
            optimization fails for a request, its entry is a fallback built from
            the target token's output embedding, with zero probabilities and no steps.
        """
        n_layers = self._num_layers()
        if loss_layer is None:
            # Prefer a few layers after the edit layer, clamped to the last layer
            total_layers = n_layers if n_layers is not None else int(
                getattr(getattr(self.model, 'config', object()), 'num_hidden_layers', 48)
            )
            loss_layer = max(0, min(total_layers - 1, layer_idx + 5))
        if n_layers is not None:
            # Final clamp (in case caller provided an out-of-range index)
            loss_layer = max(0, min(n_layers - 1, loss_layer))

        results = {}

        print(f"Computing target vectors for {len(requests)} requests")
        print(f"Layer: {layer_idx}, Loss layer: {loss_layer}")

        for request in tqdm(requests, desc="Computing target vectors"):
            # NOTE(review): the cache key is only (case_id, layer_idx). The prompt,
            # subject, target, loss_layer and hparams are not part of it, so changing
            # any of them with force_recompute=False reuses a stale vector.
            # _get_cache_key provides a content hash if that matters.
            cache_key = request.case_id
            cache_path = self._get_cache_path(cache_key, layer_idx)

            # Try to load from cache. pickle.load can execute arbitrary code, so
            # cache_dir must only point at trusted locations.
            if cache_path.exists() and not force_recompute:
                try:
                    with open(cache_path, 'rb') as f:
                        result = pickle.load(f)
                    print(f"✓ Loaded cached target vector for case {request.case_id}")
                    print(f"  Cached result: {result.initial_prob:.6f} -> {result.final_prob:.6f}")
                    results[request.case_id] = result
                    continue
                except Exception as e:
                    # Best-effort: a corrupt or incompatible cache just triggers recomputation
                    print(f"✗ Failed to load cache for case {request.case_id}: {e}")

            try:
                result = self._compute_single_target_vector(request, layer_idx, loss_layer)
            except Exception as e:
                print(f"✗ Failed to compute target vector for case {request.case_id}: {e}")
                # Fallback: a fixed-norm output embedding of the target token
                results[request.case_id] = TargetVectorResult(
                    target_vector=self._compute_simple_target_vector(request.target_new),
                    initial_prob=0.0,
                    final_prob=0.0,
                    optimization_steps=0,
                    converged=False,
                    loss_history=[]
                )
                continue

            # Caching is best-effort. A write failure must not discard a valid result.
            try:
                with open(cache_path, 'wb') as f:
                    pickle.dump(result, f)
                print(f"✓ Cached target vector for case {request.case_id}")
            except (OSError, pickle.PicklingError) as e:
                print(f"✗ Failed to cache target vector for case {request.case_id}: {e}")

            results[request.case_id] = result

        return results

    def _compute_simple_target_vector(self, target_str: str) -> torch.Tensor:
        """Fallback: return the target token's output embedding, L2-normalized and scaled to norm 3.0.

        The result is detached from autograd.

        Raises:
            ValueError: if the target cannot be tokenized or no output embeddings exist.
        """
        target_id = self._target_token_id(target_str)

        if hasattr(self.model, 'lm_head'):
            output_embeddings = self.model.lm_head.weight
        elif hasattr(self.model, 'get_output_embeddings'):
            output_embeddings = self.model.get_output_embeddings().weight
        else:
            raise ValueError("Could not find output embeddings")

        with torch.no_grad():
            # int() index works even when the embedding matrix lives on a different
            # device than target_id (e.g. with CPU offload).
            target_embedding = output_embeddings[int(target_id)]
            target_vector = F.normalize(target_embedding, dim=0) * 3.0

        print(f"Using simple target vector for '{target_str}' (norm: {target_vector.norm().item():.4f})")
        return target_vector

    def _cache_files(self) -> List[Path]:
        """Return all target-vector pickle files in ``cache_dir``."""
        return list(self.cache_dir.glob("target_vector_*.pkl"))

    def clear_cache(self):
        """Delete every cached target vector in ``cache_dir``, printing each file and the total size."""
        cache_files = self._cache_files()

        if not cache_files:
            print("No cached target vectors found")
            return

        print(f"Found {len(cache_files)} cached target vectors:")
        total_size = 0
        for cache_file in cache_files:
            size_mb = cache_file.stat().st_size / (1024 * 1024)
            total_size += size_mb
            print(f"  {cache_file.name} ({size_mb:.2f} MB)")
            cache_file.unlink()
            print(f"Deleted cache file: {cache_file}")

        print(f"Total cache size cleared: {total_size:.2f} MB")

    def list_cache(self):
        """Print every cached target vector in ``cache_dir`` with its size and the total size."""
        cache_files = self._cache_files()

        if not cache_files:
            print("No cached target vectors found")
            return

        print(f"Found {len(cache_files)} cached target vectors:")
        total_size = 0
        for cache_file in cache_files:
            size_mb = cache_file.stat().st_size / (1024 * 1024)
            total_size += size_mb
            print(f"  {cache_file.name} ({size_mb:.2f} MB)")

        print(f"Total cache size: {total_size:.2f} MB")

def main():
    """Demo: load the model, compute (or load cached) target vectors, then print a summary and the cache status."""
    from moe_edit import load_model_with_cpu_offload

    checkpoint_dir = "/root/autodl-tmp/Qwen3-30B-A3B"

    model, tokenizer = load_model_with_cpu_offload(
        model_name=checkpoint_dir,
        gpu_memory_limit="96GiB",
        cpu_memory_limit="100GiB",
    )

    vector_computer = TargetVectorComputer(
        model=model,
        tokenizer=tokenizer,
        cache_dir="./target_vector_cache",
        device="cuda",
    )

    demo_requests = [
        EditRequest(
            prompt="The CEO of {} is",
            subject="Apple",
            target_new=" Elon Musk",
            case_id=2,
        ),
    ]

    print("\n=== Computing Target Vectors ===")
    computed = vector_computer.compute_target_vectors(
        requests=demo_requests,
        layer_idx=13,
        loss_layer=47,
        force_recompute=False,
    )

    print("\n=== Results Summary ===")
    for cid, outcome in computed.items():
        print(f"Case {cid}:")
        print(f"  Initial probability: {outcome.initial_prob:.6f}")
        print(f"  Final probability: {outcome.final_prob:.6f}")
        # Ratio is only meaningful when the starting probability is non-zero
        if outcome.initial_prob > 0:
            print(f"  Improvement: {outcome.final_prob/outcome.initial_prob:.2f}x")
        else:
            print("  Improvement: N/A")
        print(f"  Converged: {outcome.converged}")
        print(f"  Steps: {outcome.optimization_steps}")
        print(f"  Target vector norm: {outcome.target_vector.norm().item():.4f}")
        print()

    print("=== Cache Status ===")
    vector_computer.list_cache()

if __name__ == "__main__":
    main()