
import torch
import torch.nn.functional as F
from typing import Dict, Optional, Tuple, List, Any
from enum import Enum
import re

class PAPOVersion(str, Enum):
    """String-valued identifiers for the PAPO loss variants.

    Because the enum also subclasses ``str``, each member compares equal to
    its raw value (``PAPOVersion.V1 == "v1"``) and can be built from it
    (``PAPOVersion("v1")``).
    """

    # NOTE(review): the mapping from member to loss function is not visible in
    # this declaration; the notes below are inferred from the function names
    # compute_papo_loss_v0/v1/v2 in this file -- confirm against the caller.
    V0 = "v0"  # presumably the single masked-pass loss (compute_papo_loss_v0)
    V1 = "v1"  # presumably separate no-visual / no-audio passes (compute_papo_loss_v1)
    V2 = "v2"  # presumably v1 plus per-token modality routing (compute_papo_loss_v2)

# Substrings that mark a multimodal-input key as visual data. Matching is a
# case-insensitive substring test (the key is lowercased before comparison).
VISUAL_MASK_KEYWORDS = ["pixel", "image"]  
# Substrings that mark a multimodal-input key as audio data (same matching rule).
AUDIO_MASK_KEYWORDS = ["audio", "input_features"]  

# Substrings that exempt a key from masking. The skip test runs before the
# mask-keyword test, so a key matching both lists is passed through untouched.
SKIP_KEYWORDS = ["attention_mask", "lengths", "position", "grid", "second_per_grid"]

# Keys exempted from masking by exact, case-sensitive equality (not substring).
CONTROL_KEYS = ["use_audio_in_video", "rope_deltas"]

def _mask_tensor(tensor: torch.Tensor, mask_ratio: float = 0.9, noise: bool = False) -> torch.Tensor:
    
    if tensor is None:
        return None
    
    mask = torch.rand_like(tensor.float()) > mask_ratio
    
    if noise:
        
        noise_tensor = torch.randn_like(tensor.float()) * 0.1
        return torch.where(mask, tensor.float(), noise_tensor).to(tensor.dtype)
    else:
        
        return tensor * mask.to(tensor.dtype)

def _should_mask_key(key: str, mask_keywords: List[str]) -> bool:
    
    k_lower = key.lower()
    return any(kw in k_lower for kw in mask_keywords)

def _should_skip_key(key: str) -> bool:
    """Return True if ``key`` must be left untouched by the masking helpers.

    A key is skipped when its lowercased form contains any entry of
    ``SKIP_KEYWORDS``, or when it equals (case-sensitively) an entry of
    ``CONTROL_KEYS``.
    """
    lowered = key.lower()
    for keyword in SKIP_KEYWORDS:
        if keyword in lowered:
            return True
    return key in CONTROL_KEYS

def mask_all_multimodal(
    multimodal_inputs: Dict[str, Any],
    mask_ratio: float = 0.9,
    noise: bool = False
) -> Dict[str, Any]:
    """Return a copy of ``multimodal_inputs`` with visual and audio tensors masked.

    An entry is masked only when its key is not a skip/control key, matches a
    visual or audio keyword, and its value is a non-empty ``torch.Tensor``.
    Every other value is placed in the result as the same object (no copy).

    Args:
        multimodal_inputs: Mapping of input names to values.
        mask_ratio: Forwarded to ``_mask_tensor``.
        noise: Forwarded to ``_mask_tensor``.

    Returns:
        A new dict with the same keys, in the same order.
    """
    keywords = VISUAL_MASK_KEYWORDS + AUDIO_MASK_KEYWORDS

    def _transform(name: str, item: Any) -> Any:
        # Short-circuit order mirrors: skip check, keyword check, tensor checks.
        maskable = (
            not _should_skip_key(name)
            and _should_mask_key(name, keywords)
            and isinstance(item, torch.Tensor)
            and item.numel() > 0
        )
        return _mask_tensor(item, mask_ratio, noise) if maskable else item

    return {name: _transform(name, item) for name, item in multimodal_inputs.items()}

def mask_visual_inputs(
    multimodal_inputs: Dict[str, Any],
    mask_ratio: float = 0.9,
    noise: bool = False
) -> Dict[str, Any]:
    """Return a copy of ``multimodal_inputs`` with only visual tensors masked.

    Keys matching ``VISUAL_MASK_KEYWORDS`` (and not exempted by the skip
    rules) have their non-empty tensor values passed through ``_mask_tensor``.
    All other entries, including audio ones, are carried over unchanged.

    Args:
        multimodal_inputs: Mapping of input names to values.
        mask_ratio: Forwarded to ``_mask_tensor``.
        noise: Forwarded to ``_mask_tensor``.

    Returns:
        A new dict with the same keys, in the same order.
    """
    result: Dict[str, Any] = {}

    for name, item in multimodal_inputs.items():
        # Start from pass-through; overwrite below only if masking applies.
        result[name] = item

        if _should_skip_key(name):
            continue
        if not _should_mask_key(name, VISUAL_MASK_KEYWORDS):
            continue
        if isinstance(item, torch.Tensor) and item.numel() > 0:
            result[name] = _mask_tensor(item, mask_ratio, noise)

    return result

def mask_audio_inputs(
    multimodal_inputs: Dict[str, Any],
    mask_ratio: float = 0.9,
    noise: bool = False
) -> Dict[str, Any]:
    """Return a copy of ``multimodal_inputs`` with only audio tensors masked.

    Keys matching ``AUDIO_MASK_KEYWORDS`` (and not exempted by the skip
    rules) have their non-empty tensor values passed through ``_mask_tensor``.
    All other entries, including visual ones, are carried over unchanged.

    Args:
        multimodal_inputs: Mapping of input names to values.
        mask_ratio: Forwarded to ``_mask_tensor``.
        noise: Forwarded to ``_mask_tensor``.

    Returns:
        A new dict with the same keys, in the same order.
    """

    def _is_audio_payload(name: str, item: Any) -> bool:
        # Guard clauses in the same order as the checks they replace.
        if _should_skip_key(name):
            return False
        if not _should_mask_key(name, AUDIO_MASK_KEYWORDS):
            return False
        return isinstance(item, torch.Tensor) and item.numel() > 0

    return {
        name: (_mask_tensor(item, mask_ratio, noise) if _is_audio_payload(name, item) else item)
        for name, item in multimodal_inputs.items()
    }

def compute_kl_divergence(
    log_probs: torch.Tensor,
    log_probs_masked: torch.Tensor,
    kl_penalty: str = "kl"
) -> torch.Tensor:
    """Compute an element-wise divergence measure between two log-prob tensors.

    Both inputs are cast to float32 first. Supported ``kl_penalty`` values:

    * ``"kl"``: ``log_probs - log_probs_masked``
    * ``"abs"``: absolute value of that difference
    * ``"mse"``: half the squared difference
    * ``"low_var_kl"``: with ``r = clamp(log_probs_masked - log_probs, -20, 20)``,
      returns ``exp(r) - r - 1`` clamped to ``[-10, 10]``

    Args:
        log_probs: Log-probabilities from the unmasked pass.
        log_probs_masked: Log-probabilities from the masked pass.
        kl_penalty: Which measure to compute.

    Returns:
        A float32 tensor of per-element values (no reduction is applied).

    Raises:
        ValueError: If ``kl_penalty`` is not one of the supported names.
    """
    full = log_probs.float()
    masked = log_probs_masked.float()

    if kl_penalty == "low_var_kl":
        # Clamp the log-ratio before exponentiating to avoid overflow.
        log_ratio = (masked - full).clamp(-20.0, 20.0)
        estimate = (log_ratio.exp() - log_ratio - 1).contiguous()
        return torch.clamp(estimate, min=-10.0, max=10.0)

    if kl_penalty not in ("kl", "abs", "mse"):
        raise ValueError(f"Unknown KL penalty type: {kl_penalty}")

    delta = full - masked
    if kl_penalty == "kl":
        return delta
    if kl_penalty == "abs":
        return delta.abs()
    return 0.5 * delta.square()

def compute_entropy_loss(log_probs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Return the negated mean of ``log_probs`` over the positions selected by ``mask``.

    The sum of ``log_probs * mask`` is divided by ``mask.sum()`` and negated.
    When the mask sums to zero (or less), a fresh scalar ``0.0`` tensor on
    the device and in the dtype of ``log_probs`` is returned instead.

    Args:
        log_probs: Per-token log-probabilities.
        mask: Weights multiplied element-wise with ``log_probs``.

    Returns:
        A scalar tensor.
    """
    weighted = log_probs * mask
    token_count = mask.sum()

    if not (token_count > 0):
        # Nothing selected: avoid a division by zero.
        return torch.tensor(0.0, device=log_probs.device, dtype=log_probs.dtype)

    return -(weighted.sum()) / token_count

def compute_papo_loss_v0(
    log_probs_full: torch.Tensor,
    log_probs_masked: torch.Tensor,
    completion_mask: torch.Tensor,
    kl_coef: float = 0.01,
    entropy_coef: float = 0.03,
    kl_penalty: str = "kl",
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Compute the PAPO loss from one full pass and one masked pass.

    The loss is ``-kl_coef * KL + entropy_coef * (H_full + H_masked)`` where
    ``KL`` is the completion-mask-weighted mean of ``compute_kl_divergence``
    and each ``H`` comes from ``compute_entropy_loss``. The KL enters with a
    negative sign, so minimising the loss pushes the KL measure up.

    Args:
        log_probs_full: Log-probabilities with all inputs present.
        log_probs_masked: Log-probabilities with multimodal inputs masked.
        completion_mask: Weights selecting the tokens to average over.
        kl_coef: Weight of the KL term.
        entropy_coef: Weight of the entropy term.
        kl_penalty: Forwarded to ``compute_kl_divergence``.

    Returns:
        ``(loss, metrics)`` where ``metrics`` maps ``"papo/..."`` names to
        Python scalars.
    """
    per_token_kl = compute_kl_divergence(log_probs_full, log_probs_masked, kl_penalty)
    # Clamp the denominator so an all-zero mask yields 0 rather than NaN.
    token_total = completion_mask.sum().clamp(min=1)
    kl_loss = (per_token_kl * completion_mask).sum() / token_total

    entropy_full = compute_entropy_loss(log_probs_full, completion_mask)
    entropy_masked = compute_entropy_loss(log_probs_masked, completion_mask)
    entropy_total = entropy_full + entropy_masked

    kl_term = -kl_coef * kl_loss
    entropy_term = entropy_coef * entropy_total
    loss = kl_term + entropy_term

    # Insertion order fixes the order of the metrics dict.
    tracked = [
        ("papo/kl_loss", kl_loss),
        ("papo/entropy_full", entropy_full),
        ("papo/entropy_masked", entropy_masked),
        ("papo/entropy_total", entropy_total),
        ("papo/kl_term", kl_term),
        ("papo/entropy_term", entropy_term),
        ("papo/loss", loss),
    ]
    metrics = {name: value.detach().item() for name, value in tracked}

    return loss, metrics

def compute_papo_loss_v1(
    log_probs_full: torch.Tensor,
    log_probs_no_v: torch.Tensor,
    log_probs_no_a: torch.Tensor,
    completion_mask: torch.Tensor,
    kl_coef: float = 0.01,
    entropy_coef: float = 0.03,
    kl_penalty: str = "kl",
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Compute the PAPO loss from a full pass plus two ablated passes.

    Two KL measures are taken against the full pass (one for
    ``log_probs_no_v``, one for ``log_probs_no_a``), each averaged over
    ``completion_mask`` and summed. The loss is
    ``-kl_coef * (KL_v + KL_a) + entropy_coef * (H_full + H_no_v + H_no_a)``.

    Args:
        log_probs_full: Log-probabilities with all inputs present.
        log_probs_no_v: Log-probabilities from the pass named "no_v"
            (presumably visual inputs masked -- confirm with the caller).
        log_probs_no_a: Log-probabilities from the pass named "no_a"
            (presumably audio inputs masked -- confirm with the caller).
        completion_mask: Weights selecting the tokens to average over.
        kl_coef: Weight of the summed KL term.
        entropy_coef: Weight of the summed entropy term.
        kl_penalty: Forwarded to ``compute_kl_divergence``.

    Returns:
        ``(loss, metrics)`` where ``metrics`` maps ``"papo/..."`` names to
        Python scalars.
    """
    per_token_kl_v = compute_kl_divergence(log_probs_full, log_probs_no_v, kl_penalty)
    per_token_kl_a = compute_kl_divergence(log_probs_full, log_probs_no_a, kl_penalty)

    # Shared denominator, clamped so an all-zero mask does not divide by zero.
    token_total = completion_mask.sum().clamp(min=1)

    def _completion_mean(values: torch.Tensor) -> torch.Tensor:
        return (values * completion_mask).sum() / token_total

    kl_v_loss = _completion_mean(per_token_kl_v)
    kl_a_loss = _completion_mean(per_token_kl_a)

    entropy_full, entropy_no_v, entropy_no_a = (
        compute_entropy_loss(current, completion_mask)
        for current in (log_probs_full, log_probs_no_v, log_probs_no_a)
    )
    entropy_total = entropy_full + entropy_no_v + entropy_no_a

    kl_total = kl_v_loss + kl_a_loss
    kl_term = -kl_coef * kl_total
    entropy_term = entropy_coef * entropy_total
    loss = kl_term + entropy_term

    tracked = [
        ("papo/kl_v_loss", kl_v_loss),
        ("papo/kl_a_loss", kl_a_loss),
        ("papo/kl_total", kl_total),
        ("papo/entropy_full", entropy_full),
        ("papo/entropy_no_v", entropy_no_v),
        ("papo/entropy_no_a", entropy_no_a),
        ("papo/entropy_total", entropy_total),
        ("papo/kl_term", kl_term),
        ("papo/entropy_term", entropy_term),
        ("papo/loss", loss),
    ]
    metrics = {name: value.detach().item() for name, value in tracked}

    return loss, metrics

def extract_think_text(text: str) -> str:
    """Return the stripped content of the first ``<think>...</think>`` block.

    If ``text`` has no such block, the whole stripped text is returned.
    Non-string input yields an empty string.
    """
    if not isinstance(text, str):
        return ""

    # DOTALL lets the reasoning span multiple lines; the match is non-greedy.
    found = re.search(r"<think>(.*?)</think>", text, flags=re.DOTALL)
    if found is None:
        return text.strip()
    return found.group(1).strip()

def split_sentences(text: str) -> List[str]:
    """Split ``text`` into trimmed, non-empty fragments.

    Fragments are delimited by runs of ``!``, ``?``, ``.`` or ``,`` (plus any
    whitespace that follows) or by runs of newlines. Note that commas and
    periods inside numbers also act as delimiters.

    Returns:
        The fragments in order; an empty list for falsy input.
    """
    if not text:
        return []

    sentences = []
    for fragment in re.split(r"[!?\.,]+\s*|\n+", text):
        trimmed = fragment.strip()
        if trimmed:
            sentences.append(trimmed)
    return sentences

def compute_modality_routing_mask(
    sim_matrix_v: Optional[torch.Tensor],
    sim_matrix_a: Optional[torch.Tensor],
    threshold: float = 0.5,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Decide, per row, whether the row is routed to the visual or audio modality.

    Each matrix is reduced with a max over its last dimension to get one
    score per row. A row is visual when its visual score exceeds both
    ``threshold`` and its audio score; it is audio in the symmetric case.
    Rows that meet neither condition (including exact ties) are flagged as
    neither.

    A missing or empty matrix contributes ``-inf`` scores, so its modality is
    never selected.

    Args:
        sim_matrix_v: Visual similarity scores, or ``None``.
            # assumes shape (num_sentences, num_visual_refs) -- TODO confirm
        sim_matrix_a: Audio similarity scores, or ``None``.
            # assumes shape (num_sentences, num_audio_refs) -- TODO confirm
        threshold: Minimum score for a modality to be selected.

    Returns:
        ``(is_visual, is_audio)`` boolean tensors, or ``(None, None)`` when
        the visual matrix is missing/empty and the audio matrix is ``None``.
    """
    has_visual = sim_matrix_v is not None and sim_matrix_v.numel() > 0
    has_audio = sim_matrix_a is not None and sim_matrix_a.numel() > 0

    if has_visual:
        score_v = sim_matrix_v.max(dim=-1).values
    elif sim_matrix_a is not None:
        # No usable visual scores: size the placeholder from the audio matrix.
        score_v = torch.full(
            (sim_matrix_a.size(0),), float("-inf"), device=sim_matrix_a.device
        )
    else:
        # Nothing to route on at all.
        return None, None

    if has_audio:
        score_a = sim_matrix_a.max(dim=-1).values
    else:
        score_a = torch.full_like(score_v, float("-inf"))

    is_visual = (score_v > threshold) & (score_v > score_a)
    is_audio = (score_a > threshold) & (score_a > score_v)

    return is_visual, is_audio

def sentence_mask_to_token_mask(
    sentence_mask: torch.Tensor,
    sentence_spans: List[Tuple[int, int]],
    seq_len: int,
    device: torch.device,
) -> torch.Tensor:
    """Expand a per-sentence selection into a per-token 0/1 float mask.

    For every span whose index has a truthy entry in ``sentence_mask``, the
    token positions ``[start, end)`` are set to 1.0. Spans with no
    corresponding mask entry (index beyond the mask's length) are ignored.

    Args:
        sentence_mask: One flag per sentence.
        sentence_spans: ``(start, end)`` token positions per sentence, used
            as a half-open slice.
        seq_len: Length of the returned mask.
        device: Device on which the mask is allocated.

    Returns:
        A float32 tensor of shape ``(seq_len,)``.
    """
    token_mask = torch.zeros(seq_len, dtype=torch.float32, device=device)

    for idx, (begin, finish) in enumerate(sentence_spans):
        selected = idx < len(sentence_mask) and bool(sentence_mask[idx])
        if not selected:
            continue
        token_mask[begin:finish] = 1.0

    return token_mask

def compute_papo_loss_v2(
    log_probs_full: torch.Tensor,
    log_probs_no_v: torch.Tensor,
    log_probs_no_a: torch.Tensor,
    completion_mask: torch.Tensor,
    visual_token_mask: torch.Tensor,
    audio_token_mask: torch.Tensor,
    kl_coef: float = 0.01,
    entropy_coef: float = 0.03,
    kl_penalty: str = "kl",
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Compute the PAPO loss with per-token modality routing of the KL terms.

    The KL against ``log_probs_no_v`` is averaged only over tokens selected by
    ``completion_mask * visual_token_mask``; the KL against ``log_probs_no_a``
    only over ``completion_mask * audio_token_mask``. The entropy terms are
    averaged over the full ``completion_mask``. The loss is
    ``-kl_coef * (KL_v + KL_a) + entropy_coef * (H_full + H_no_v + H_no_a)``.

    Args:
        log_probs_full: Log-probabilities with all inputs present.
        log_probs_no_v: Log-probabilities from the pass named "no_v"
            (presumably visual inputs masked -- confirm with the caller).
        log_probs_no_a: Log-probabilities from the pass named "no_a"
            (presumably audio inputs masked -- confirm with the caller).
        completion_mask: Weights selecting completion tokens.
        visual_token_mask: Weights selecting tokens routed to the visual KL.
        audio_token_mask: Weights selecting tokens routed to the audio KL.
        kl_coef: Weight of the summed KL term.
        entropy_coef: Weight of the summed entropy term.
        kl_penalty: Forwarded to ``compute_kl_divergence``.

    Returns:
        ``(loss, metrics)`` where ``metrics`` maps ``"papo/..."`` names to
        Python scalars. ``papo/visual_tokens`` and ``papo/audio_tokens`` hold
        the actual number of routed tokens (0 when none were routed).
    """
    kl_v = compute_kl_divergence(log_probs_full, log_probs_no_v, kl_penalty)
    kl_a = compute_kl_divergence(log_probs_full, log_probs_no_a, kl_penalty)

    visual_mask_combined = completion_mask * visual_token_mask
    audio_mask_combined = completion_mask * audio_token_mask

    # Keep the true counts for logging; clamp only the denominators so that an
    # empty routing mask gives a zero KL term rather than a division by zero.
    visual_count = visual_mask_combined.sum()
    audio_count = audio_mask_combined.sum()
    visual_denominator = visual_count.clamp(min=1)
    audio_denominator = audio_count.clamp(min=1)

    kl_v_loss = (kl_v * visual_mask_combined).sum() / visual_denominator
    kl_a_loss = (kl_a * audio_mask_combined).sum() / audio_denominator

    entropy_full = compute_entropy_loss(log_probs_full, completion_mask)
    entropy_no_v = compute_entropy_loss(log_probs_no_v, completion_mask)
    entropy_no_a = compute_entropy_loss(log_probs_no_a, completion_mask)
    entropy_total = entropy_full + entropy_no_v + entropy_no_a

    kl_total = kl_v_loss + kl_a_loss
    # Negative sign: minimising the loss increases the KL measure.
    kl_term = -kl_coef * kl_total
    entropy_term = entropy_coef * entropy_total

    loss = kl_term + entropy_term

    # Share of completion tokens routed to neither modality.
    routed = (visual_mask_combined + audio_mask_combined).sum()
    neutral_ratio = 1 - routed / completion_mask.sum().clamp(min=1)

    metrics = {
        "papo/kl_v_loss": kl_v_loss.detach().item(),
        "papo/kl_a_loss": kl_a_loss.detach().item(),
        "papo/kl_total": kl_total.detach().item(),

        "papo/visual_tokens": visual_count.detach().item(),
        "papo/audio_tokens": audio_count.detach().item(),
        "papo/neutral_ratio": neutral_ratio.detach().item(),

        "papo/entropy_full": entropy_full.detach().item(),
        "papo/entropy_no_v": entropy_no_v.detach().item(),
        "papo/entropy_no_a": entropy_no_a.detach().item(),
        "papo/entropy_total": entropy_total.detach().item(),

        "papo/kl_term": kl_term.detach().item(),
        "papo/entropy_term": entropy_term.detach().item(),
        "papo/loss": loss.detach().item(),
    }

    return loss, metrics

class PAPOConfig:
    """Plain container for the PAPO loss settings.

    Attributes mirror the constructor arguments one-to-one; ``version`` is
    stored as a ``PAPOVersion`` member rather than the raw string.
    """

    def __init__(
        self,
        enabled: bool = False,
        version: str = "v1",
        mask_ratio: float = 0.9,
        use_noise: bool = False,
        kl_coef: float = 0.001,
        entropy_coef: float = 0.03,
        kl_penalty: str = "kl",

        routing_threshold: float = 0.5,
    ):
        """Store the settings.

        Args:
            enabled: Whether PAPO is switched on.
            version: Loss variant; must be a valid ``PAPOVersion`` value.
            mask_ratio: Masking threshold for the input-masking helpers.
            use_noise: Whether masked elements are replaced by noise.
            kl_coef: Weight of the KL term.
                # NOTE(review): default is 0.001 here but 0.01 in the
                # compute_papo_loss_* functions -- confirm which is intended.
            entropy_coef: Weight of the entropy term.
            kl_penalty: Name of the KL measure.
            routing_threshold: Score threshold for modality routing.

        Raises:
            ValueError: If ``version`` is not a valid ``PAPOVersion`` value.
        """
        self.enabled = enabled
        self.version = PAPOVersion(version)
        self.mask_ratio = mask_ratio
        self.use_noise = use_noise
        self.kl_coef = kl_coef
        self.entropy_coef = entropy_coef
        self.kl_penalty = kl_penalty
        self.routing_threshold = routing_threshold

    def __repr__(self):
        return (
            f"PAPOConfig(enabled={self.enabled}, version={self.version.value}, "
            f"mask_ratio={self.mask_ratio}, kl_coef={self.kl_coef}, "
            f"entropy_coef={self.entropy_coef})"
        )

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "PAPOConfig":
        """Build a config from a dict, ignoring keys the constructor does not accept.

        Only keys naming a real keyword-passable constructor parameter are
        forwarded. The receiver (``self``) is excluded, so a stray ``"self"``
        key is ignored instead of raising ``TypeError``.

        Args:
            config_dict: Mapping of setting names to values; unknown keys are
                silently dropped.

        Returns:
            A new instance of ``cls``.
        """
        import inspect

        # Skip the first parameter (the bound instance) and anything that
        # cannot be passed by keyword (*args / **kwargs / positional-only).
        parameters = list(inspect.signature(cls.__init__).parameters.values())[1:]
        keyword_kinds = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        accepted = {p.name for p in parameters if p.kind in keyword_kinds}

        return cls(**{k: v for k, v in config_dict.items() if k in accepted})
