"""Core CMI Loss computation functions."""

import torch
import torch.nn.functional as F
from typing import Optional, Tuple, List


# Sentinel label value: positions carrying this label are passed to
# F.cross_entropy as `ignore_index` and therefore contribute nothing to the loss.
IGNORE_INDEX = -100


def compute_cmi_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    cmi_lambda: float,
    thinking_start_tokens: Optional[List[int]] = None,
    thinking_end_tokens: Optional[List[int]] = None,
    thinking_weight: float = 0.1,
    normalize_losses: bool = True,
    sample_types: Optional[List[str]] = None,
    apply_to_harmful_only: bool = False,
) -> Tuple[torch.Tensor, dict]:
    """
    Compute CMI-enhanced loss following the theoretical framework.

    Total loss = L_main + λ * L_shortcut
    where:
    - L_main: standard NLL loss for P(y,c|x)
    - L_shortcut: NLL loss for shortcut head P(y|x), in which tokens of the
      thinking region are down-weighted by ``thinking_weight``
    - λ: negative regularization parameter

    Args:
        logits: Model output logits [batch_size, seq_len, vocab_size]
        labels: Target labels [batch_size, seq_len]
        cmi_lambda: Regularization strength (negative value)
        thinking_start_tokens: Token IDs marking thinking start
        thinking_end_tokens: Token IDs marking thinking end
        thinking_weight: Weight for thinking tokens in shortcut loss
            (0 = thinking tokens ignored, 1 = identical to the main loss)
        normalize_losses: Whether to rescale the shortcut loss so that its
            value matches the main loss (only its gradient direction differs)
        sample_types: List of sample types ('harmful' or 'benign')
        apply_to_harmful_only: When True and ``sample_types`` is given, the
            regularization term is multiplied by the fraction of 'harmful'
            entries in ``sample_types``. This is a batch-level scaling, not a
            per-sample mask.

    Returns:
        total_loss: Combined loss value
        metrics: Dictionary of loss components for logging. Always contains
            "main_loss", "shortcut_loss", "total_loss" and "cmi_lambda";
            additionally "harmful_ratio" when selective application is active.
    """
    # Main loss (standard cross-entropy over every supervised token).
    main_loss = compute_cross_entropy_loss(logits, labels)

    # Shortcut loss with thinking tokens down-weighted.
    shortcut_loss = _compute_shortcut_loss(
        logits,
        labels,
        thinking_start_tokens,
        thinking_end_tokens,
        thinking_weight,
    )

    if normalize_losses:
        # Rescale by the detached ratio of the two losses. Both the divisor
        # and the resulting ratio are clamped so that a zero main loss or a
        # zero shortcut loss cannot produce inf/NaN. The ratio is computed in
        # float32 because 1e-8 underflows in half precision.
        eps = 1e-8
        main_value = main_loss.detach().float()
        shortcut_value = shortcut_loss.detach().float()
        scale_factor = (shortcut_value / main_value.clamp_min(eps)).clamp_min(eps)
        regularizer = shortcut_loss / scale_factor
    else:
        regularizer = shortcut_loss

    selective = apply_to_harmful_only and sample_types is not None
    harmful_ratio = None
    if selective:
        # Computed in plain Python: no device round-trip, and an empty list
        # yields 0.0 instead of the NaN that a mean over no elements gives.
        harmful_count = sum(1 for st in sample_types if st == 'harmful')
        harmful_ratio = harmful_count / len(sample_types) if sample_types else 0.0
        regularizer = regularizer * harmful_ratio

    total_loss = main_loss + cmi_lambda * regularizer

    # Prepare metrics
    metrics = {
        "main_loss": main_loss.detach().cpu().item(),
        "shortcut_loss": shortcut_loss.detach().cpu().item(),
        "total_loss": total_loss.detach().cpu().item(),
        "cmi_lambda": cmi_lambda,
    }

    if selective:
        metrics["harmful_ratio"] = harmful_ratio

    return total_loss, metrics


def _compute_shortcut_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    thinking_start_tokens: Optional[List[int]],
    thinking_end_tokens: Optional[List[int]],
    thinking_weight: float,
) -> torch.Tensor:
    """Return the shortcut loss with thinking tokens weighted by ``thinking_weight``."""
    if thinking_start_tokens is None or thinking_end_tokens is None:
        # No thinking region configured: the shortcut loss is the plain loss.
        return compute_cross_entropy_loss(logits, labels)

    # Requesting weight 0.0 makes create_shortcut_labels overwrite the thinking
    # region with IGNORE_INDEX, which also tells us where that region is.
    fully_masked = create_shortcut_labels(
        labels,
        thinking_start_tokens,
        thinking_end_tokens,
        0.0,
    )
    if thinking_weight == 0.0:
        return compute_cross_entropy_loss(logits, fully_masked)

    # Weighted mean of the per-token losses. Dividing by the sum of weights
    # reproduces the masked mean at weight 0 and the plain mean at weight 1.
    per_token = compute_cross_entropy_loss(logits, labels, reduction="none")
    token_losses = per_token.float()

    # Same shift as in compute_cross_entropy_loss so positions line up.
    in_thinking = (fully_masked != labels)[..., 1:].reshape(-1).to(token_losses.dtype)
    supervised = (labels[..., 1:] != IGNORE_INDEX).reshape(-1).to(token_losses.dtype)
    weights = supervised * (1.0 - (1.0 - thinking_weight) * in_thinking)

    weighted_mean = (token_losses * weights).sum() / weights.sum()
    return weighted_mean.to(per_token.dtype)


def compute_cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    reduction: str = "mean"
) -> torch.Tensor:
    """
    Compute next-token cross-entropy loss.

    Args:
        logits: Model output logits; the last dimension is the vocabulary and
            the second-to-last is the sequence position
        labels: Target labels; the last dimension is the sequence position
        reduction: Reduction mode forwarded to ``F.cross_entropy``

    Returns:
        The cross-entropy between the logits at position t and the label at
        position t + 1, with IGNORE_INDEX labels excluded.
    """
    vocab_size = logits.size(-1)

    # Position t predicts token t + 1: drop the last logit and the first label,
    # then flatten every remaining position into one batch of predictions.
    predictions = logits[..., :-1, :].reshape(-1, vocab_size)
    targets = labels[..., 1:].reshape(-1)

    return F.cross_entropy(
        predictions,
        targets,
        ignore_index=IGNORE_INDEX,
        reduction=reduction,
    )


def create_shortcut_labels(
    labels: torch.Tensor,
    thinking_start_tokens: Optional[List[int]] = None,
    thinking_end_tokens: Optional[List[int]] = None,
    thinking_weight: float = 0.1
) -> torch.Tensor:
    """
    Create shortcut labels by masking the thinking region.

    The thinking region of a sample runs from the first occurrence of
    ``thinking_start_tokens`` through the first occurrence of
    ``thinking_end_tokens`` that follows it (both markers included).

    Args:
        labels: Original labels [batch_size, seq_len]
        thinking_start_tokens: Token IDs marking thinking start
        thinking_end_tokens: Token IDs marking thinking end
        thinking_weight: Weight for thinking tokens (0=mask completely, 1=no masking).
            Only the value 0.0 changes the labels; for any other value the
            labels are returned unchanged.

    Returns:
        shortcut_labels: A copy of ``labels``; when ``thinking_weight`` is 0.0
            the thinking region of each sample is set to IGNORE_INDEX.
    """
    shortcut_labels = labels.clone()

    if thinking_start_tokens is None or thinking_end_tokens is None:
        # No thinking tokens specified, return original labels
        return shortcut_labels

    if thinking_weight != 0.0:
        # Labels are only rewritten for complete masking, so skip the scan.
        return shortcut_labels

    for row, sample_labels in enumerate(labels):
        start_pos = find_token_sequence(sample_labels, thinking_start_tokens)
        if start_pos is None:
            continue

        # Look for the end marker only after the start marker, so an earlier
        # stray occurrence of the end marker cannot hide the real region end.
        search_from = start_pos + len(thinking_start_tokens)
        relative_end = find_token_sequence(
            sample_labels[search_from:], thinking_end_tokens
        )
        if relative_end is None:
            continue

        mask_end = search_from + relative_end + len(thinking_end_tokens)
        shortcut_labels[row, start_pos:mask_end] = IGNORE_INDEX

    return shortcut_labels


def find_token_sequence(
    labels: torch.Tensor,
    target_sequence: List[int]
) -> Optional[int]:
    """
    Find position of target token sequence in labels.

    Args:
        labels: One-dimensional tensor of label IDs for a single sample
        target_sequence: Token IDs to search for, in order

    Returns:
        Index of the first position where ``target_sequence`` occurs as a
        contiguous run in ``labels``, or None when the sequence is empty,
        every label equals IGNORE_INDEX, or there is no match.
    """
    target_len = len(target_sequence)
    if target_len == 0:
        return None

    # Nothing to find when the sample carries no supervised label at all.
    if not (labels != IGNORE_INDEX).any():
        return None

    if target_len > labels.shape[0]:
        return None

    # Compare every window of length `target_len` against the target in one
    # vectorised operation, instead of one scalar tensor comparison (and
    # host sync) per element.
    target = torch.as_tensor(target_sequence, device=labels.device)
    windows = labels.unfold(0, target_len, 1)
    match_positions = (windows == target).all(dim=1).nonzero(as_tuple=False)

    if match_positions.numel() == 0:
        return None
    return int(match_positions[0].item())


def get_cmi_lambda_scheduled(
    current_step: int,
    max_steps: int,
    warmup_ratio: float = 0.3,
    rampup_ratio: float = 0.5,
    lambda_start: float = -0.01,
    lambda_end: float = -0.1
) -> float:
    """
    Return the scheduled CMI lambda for a training step.

    The schedule has three consecutive phases:
    warmup (lambda is 0.0), rampup (linear interpolation from
    ``lambda_start`` to ``lambda_end``), and stable (``lambda_end``).

    Args:
        current_step: Current training step
        max_steps: Total training steps
        warmup_ratio: Fraction of ``max_steps`` spent in warmup
        rampup_ratio: Fraction of ``max_steps`` spent in rampup
        lambda_start: Lambda value at the first rampup step
        lambda_end: Lambda value once rampup has finished

    Returns:
        Lambda value for ``current_step``
    """
    warmup_len = int(max_steps * warmup_ratio)
    rampup_len = int(max_steps * rampup_ratio)
    rampup_stop = warmup_len + rampup_len

    # Warmup: plain SFT, regularizer switched off.
    if current_step < warmup_len:
        return 0.0

    # Stable: rampup finished (or has zero length), use the final value.
    if current_step >= rampup_stop:
        return lambda_end

    # Rampup: linear interpolation between the two lambda values.
    fraction_done = (current_step - warmup_len) / rampup_len
    return lambda_start + (lambda_end - lambda_start) * fraction_done