"""
Merged utility functions and classes for MoE knowledge editing
"""

import torch
import torch.nn.functional as F
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple
from transformers import AutoTokenizer


# ============================================================================
# Data Classes
# ============================================================================

@dataclass
class EditRequest:
    """
    One knowledge edit to apply to the model.

    Attributes:
        prompt: Prompt template with a ``{}`` slot for the subject,
            e.g. ``"The capital of {} is"``.
        subject: Text placed into the template's slot, e.g. ``"France"``.
        target_new: Completion the edit should install, e.g. ``" Berlin"``
            (note the leading space in that example).
        case_id: Identifier of this edit case.
        prefixes: Optional list of prefixes to add before the prompt;
            ``None`` when not supplied.
    """

    prompt: str
    subject: str
    target_new: str
    case_id: int
    prefixes: Optional[List[str]] = None


@dataclass
class TargetVectorResult:
    """
    Output of a target-vector computation for one edit case.

    Attributes:
        target_vector: The computed target vector.
        case_id: Identifier of the edit case this vector belongs to.
        metadata: Optional free-form extra information; ``None`` by default.
    """

    target_vector: torch.Tensor
    case_id: int
    metadata: Optional[Dict[str, Any]] = None








# ============================================================================
# Model Loading
# ============================================================================




# ============================================================================
# Token Position Functions
# ============================================================================

def _subject_token_id_variants(tokenizer, subject: str) -> List[List[int]]:
    """Return the distinct, non-empty token-id sequences for ``subject`` as written and with one leading space."""
    variants: List[List[int]] = []
    # BPE tokenizers encode "France" and " France" to different ids. Inside a
    # prompt the subject usually follows a space, so both forms are tried.
    # The bare form goes first so earlier results are preserved.
    for text in (subject, " " + subject.lstrip()):
        ids = list(tokenizer(text, add_special_tokens=False)['input_ids'])
        if ids and ids not in variants:
            variants.append(ids)
    return variants


def _first_token_run(haystack: List[int], needle: List[int]) -> Optional[int]:
    """Return the first index at which ``needle`` occurs contiguously in ``haystack``, or None."""
    width = len(needle)
    for i in range(len(haystack) - width + 1):
        if haystack[i:i + width] == needle:
            return i
    return None


def find_subject_token_positions(
    tokenizer: AutoTokenizer,
    prompt: str,
    subject: str,
    inputs: Dict
) -> Tuple[int, int]:
    """
    Find the start and end positions of subject tokens in the tokenized input

    Strategies, tried in order:

    1. Exact token-id match of the subject tokenized on its own. If that
       fails, the subject with one leading space is tried.
    2. Decode sliding windows of the input and take the smallest (then
       earliest) window whose lower-cased, stripped text contains the
       lower-cased, stripped subject. The window width is capped at
       ``max(5, longest subject tokenization + 2)``.
    3. Otherwise return a single position near the middle of the sequence.
       This also applies to an empty or blank subject.

    Args:
        tokenizer: The tokenizer that produced ``inputs``
        prompt: The prompt template (e.g., "The capital of {} is"); not used
            by the search, kept for interface compatibility
        subject: The subject string (e.g., "France")
        inputs: Tokenized inputs; only the first row,
            ``inputs['input_ids'][0]``, is searched. It may be a tensor (any
            device) or a plain list.

    Returns:
        Tuple of (start_pos, end_pos), both inclusive, where end_pos is the
        last subject token position. start_pos <= end_pos always holds.
    """
    first_row = inputs['input_ids'][0]
    input_ids = first_row.tolist() if hasattr(first_row, 'tolist') else list(first_row)
    seq_len = len(input_ids)

    # Guard: an empty subject tokenizes to [] and would "match" at index 0,
    # yielding the inverted span (0, -1).
    needle_text = subject.strip().lower() if subject else ""

    if needle_text:
        variants = _subject_token_id_variants(tokenizer, subject)

        # 1) Exact token-id match.
        for sub_ids in variants:
            start = _first_token_run(input_ids, sub_ids)
            if start is not None:
                return start, start + len(sub_ids) - 1

        # 2) Decoded-window match. The cap scales with the subject's length,
        # so long subjects are not silently skipped. The +2 slack is a
        # heuristic for in-context tokenization that splits differently.
        longest = max((len(v) for v in variants), default=0)
        max_window = min(seq_len, max(5, longest + 2))
        for width in range(1, max_window + 1):
            for start in range(seq_len - width + 1):
                window_text = tokenizer.decode(
                    input_ids[start:start + width], skip_special_tokens=True
                ).strip().lower()
                if needle_text in window_text:
                    return start, start + width - 1

    # 3) Nothing matched: fall back to a single token near the middle.
    start = max(0, seq_len // 2 - 1)
    return start, start




# ============================================================================
# Evaluation Functions
# ============================================================================

def pre_edit_prob(model, tokenizer, request: EditRequest, device: str = "cuda") -> float:
    """
    Get target token probability before editing.

    The prompt is ``request.prompt`` with ``request.subject`` formatted in
    (when the prompt contains a ``{}`` slot). It is run through ``model`` once
    without gradients. The softmax of the last position's logits is read at the
    id of the first token of ``request.target_new``.

    Args:
        model: Causal LM whose output exposes ``.logits``. It is indexed as
            ``[0, -1, :]``; presumably (batch, seq, vocab), so confirm with
            the caller.
        tokenizer: Tokenizer matching ``model``.
        request: The edit request to score.
        device: Device the prompt tensors are moved to.

    Returns:
        Probability of the first target token. Returns ``0.0`` if
        ``target_new`` is None, blank, or tokenizes to nothing.
    """
    target_text = str(request.target_new) if request.target_new is not None else ""
    if not target_text.strip():
        print("Empty target_new for request, using default")
        return 0.0

    # add_special_tokens=False is required. Tokenizers that prepend BOS (e.g.
    # Llama) would otherwise make the "first target token" the BOS token. It
    # also matches post_edit_prob, so pre/post numbers are comparable.
    target_ids = tokenizer(target_text, add_special_tokens=False)["input_ids"]
    if not target_ids:
        return 0.0
    target_id = target_ids[0]

    prompt_text = request.prompt.format(request.subject) if '{}' in request.prompt else request.prompt
    inputs = tokenizer(prompt_text, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits[0, -1, :]  # Last token logits
        # Softmax in float32: a half-precision softmax over a large vocabulary
        # loses accuracy on small probabilities.
        probs = F.softmax(logits.float(), dim=-1)

    return probs[target_id].item()


def post_edit_prob(model, tokenizer, request: EditRequest, device: str = "cuda") -> float:
    """
    Get target token probability after editing.

    Same measurement as before editing, run on the (edited) ``model``. The
    prompt has ``request.subject`` formatted in when it contains a ``{}``
    slot. The softmax of the last position's logits is read at the id of the
    first token of ``request.target_new``.

    Args:
        model: Causal LM whose output exposes ``.logits``. It is indexed as
            ``[0, -1, :]``; presumably (batch, seq, vocab), so confirm with
            the caller.
        tokenizer: Tokenizer matching ``model``.
        request: The edit request to score.
        device: Device the prompt tensors are moved to.

    Returns:
        Probability of the first target token. Returns ``0.0`` if
        ``target_new`` is None, blank, or tokenizes to nothing.
    """
    # Previously a None target crashed inside the tokenizer. It is now handled
    # the same way as in pre-edit scoring.
    target_text = str(request.target_new) if request.target_new is not None else ""
    if not target_text.strip():
        return 0.0

    # No special tokens, so a BOS token is never mistaken for the target.
    target_ids = tokenizer(target_text, add_special_tokens=False)["input_ids"]
    if not target_ids:
        return 0.0
    target_id = target_ids[0]

    prompt_text = request.prompt.format(request.subject) if '{}' in request.prompt else request.prompt
    inputs = tokenizer(prompt_text, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits[0, -1, :]  # Last token logits
        # Softmax in float32 for accuracy under half-precision models.
        probs = torch.softmax(logits.float(), dim=-1)

    return probs[target_id].item()









# ============================================================================
# AlphaEdit-style Token Positioning Functions
# ============================================================================




def get_last_subject_token_position_alphaedit(
    tokenizer: AutoTokenizer,
    prompt: str,
    subject: str,
    inputs: Dict,
    fact_token_strategy: str = "subject_last",
    verbose: bool = False
) -> int:
    """
    AlphaEdit-compatible entry point for locating the fact token.

    Every argument is forwarded unchanged, by keyword, to
    ``get_last_subject_token_position`` from the sibling ``token_utils``
    module. That function is imported at call time, and its result is
    returned as-is.

    Args:
        tokenizer: The tokenizer
        prompt: Template string with a {} placeholder
        subject: Subject string
        inputs: Tokenized inputs (passed through for compatibility)
        fact_token_strategy: Token positioning strategy name
        verbose: Whether the helper should print debug info

    Returns:
        The token index chosen by the helper for ``fact_token_strategy``.
    """
    # The import is deferred to call time, not module load.
    # NOTE(review): presumably this avoids an import cycle; confirm.
    from .token_utils import get_last_subject_token_position as _locate

    forwarded = dict(
        tokenizer=tokenizer,
        prompt=prompt,
        subject=subject,
        inputs=inputs,
        fact_token_strategy=fact_token_strategy,
        verbose=verbose,
    )
    return _locate(**forwarded)

