import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple

# ---------- Small, reusable helpers ----------

def safe_to_device(x, device):
    """Move ``x`` (or the values of a dict ``x``) to ``device`` when possible.

    Args:
        x: ``None``, a dict, or any other object.
        device: Target passed verbatim to each object's ``.to(...)``.

    Returns:
        ``None`` if ``x`` is ``None``. For a dict, a new plain ``dict`` with
        the same keys whose values were moved when they expose a ``to``
        attribute and left untouched otherwise (only the top level is
        visited). For anything else, ``x.to(device)`` if ``x`` has a ``to``
        attribute, else ``x`` unchanged.
    """
    def _relocate(obj):
        # Objects without a ``to`` attribute are passed through as-is.
        if hasattr(obj, "to"):
            return obj.to(device)
        return obj

    if x is None:
        return None
    if not isinstance(x, dict):
        return _relocate(x)
    return {key: _relocate(value) for key, value in x.items()}

def present(x) -> bool:
    """Report whether ``x`` carries any usable data.

    Returns:
        ``False`` for ``None``, for a tensor with zero elements, and for a
        dict in which no value has a ``numel`` attribute reporting more than
        zero elements (an empty dict included). ``True`` for a non-empty
        tensor, for a dict with at least one such value, and for any other
        non-``None`` object.
    """
    if x is None:
        return False
    if isinstance(x, torch.Tensor):
        return x.numel() > 0
    if not isinstance(x, dict):
        return True
    # A dict counts as present as soon as one tensor-like value is non-empty.
    for value in x.values():
        if hasattr(value, "numel") and value.numel() > 0:
            return True
    return False

def check_num(t: Optional[torch.Tensor], name: str) -> Optional[torch.Tensor]:
    """Replace NaN and +/-Inf entries of ``t`` with zeros.

    Args:
        t: Tensor to sanitize, or ``None``.
        name: Label used in the warning message so the offending tensor can
            be identified.

    Returns:
        ``None`` if ``t`` is ``None``; ``t`` itself (same object) when every
        entry is finite; otherwise a new tensor in which each non-finite
        entry has been replaced by zero.
    """
    if t is None:
        return None
    finite = torch.isfinite(t)
    if finite.all():
        return t
    # The original called print(f) with an undefined name, which raised
    # NameError exactly when bad values were found; emit a real message.
    bad = int((~finite).sum())
    print(f"[check_num] '{name}' contains {bad} NaN/Inf value(s); replacing with zeros")
    return torch.where(finite, t, torch.zeros_like(t))

def ensure_1d_positions(pos: Optional[torch.Tensor], B: int, T: int, device) -> torch.Tensor:
    """Return [B, T] 1D positions; if pos is None or wrong shape, use arange(T).

    Args:
        pos: Optional positions, either 1-D ``[P]`` or 2-D ``[b, P]``.
        B: Batch size of the result.
        T: Sequence length of the result.
        device: Device for the ``arange(T)`` fallback.

    Returns:
        * Fallback (``pos`` is not a tensor, or is neither 1-D nor 2-D):
          ``arange(T)`` as float on ``device``, expanded to ``[B, T]``.
        * 1-D ``pos`` and 2-D ``pos`` with a single row are expanded to ``B``
          rows; a 2-D ``pos`` with any other row count keeps its rows.
        * Rows longer than ``T`` are cropped to the first ``T`` entries.
        * Rows shorter than ``T`` are extended with the index values
          ``P, P+1, ..., T-1``, created with ``pos``'s own dtype and device.
    """
    if not isinstance(pos, torch.Tensor) or pos.dim() not in (1, 2):
        return torch.arange(T, device=device).float().unsqueeze(0).expand(B, -1)

    if pos.dim() == 1:
        pos = pos.unsqueeze(0).expand(B, -1)
    elif pos.size(0) == 1 and B != 1:
        # A single shared row is broadcast so the result is really [B, T].
        pos = pos.expand(B, -1)

    have = pos.size(1)
    if have > T:
        pos = pos[:, :T]
    elif have < T:
        # The pad must match pos in batch size, device and dtype; the
        # original built a [1, n] pad on `device`, which made torch.cat fail
        # for B > 1 or when pos lived on a different device.
        tail = torch.arange(have, T, device=pos.device, dtype=pos.dtype)
        pos = torch.cat([pos, tail.unsqueeze(0).expand(pos.size(0), -1)], dim=1)
    return pos

# ---------- Modality-aware fusion (no zero imputation) ----------

class ModalFusion(nn.Module):
    """Softmax-gated weighted sum over whichever modalities are available.

    ``forward`` takes a dict mapping modality name to a feature tensor (or
    ``None`` for a missing modality). 3-D features are averaged over dim 1
    before fusion; the per-modality features are then stacked on a new dim 1,
    so they must agree in shape at that point. Each available modality gets a
    scalar score from its own gate network, the scores are softmax-normalized
    across the available modalities only, and the features are summed with
    those weights.
    """

    def __init__(self, dim: int, modalities=("img","txt","ts","static")):
        """Create one gate MLP (dim -> dim // 2 -> 1) per modality name."""
        super().__init__()
        self.modalities = list(modalities)
        gate_nets = {}
        for name in self.modalities:
            gate_nets[name] = nn.Sequential(
                nn.Linear(dim, dim // 2),
                nn.ReLU(),
                nn.Linear(dim // 2, 1),
            )
        self.gates = nn.ModuleDict(gate_nets)

    def forward(self, feats: Dict[str, Optional[torch.Tensor]]) -> torch.Tensor:
        """Fuse the available modality features into one tensor.

        Args:
            feats: Mapping of modality name to features; entries for which
                ``present`` is false are ignored.

        Returns:
            The gate-weighted sum of the available features. When no
            modality is available, a ``[1, 1]`` zero tensor placed on the
            device of this module's first parameter.
        """
        available = [(name, feat) for name, feat in feats.items() if present(feat)]
        if not available:
            # Nothing to fuse: hand back a tiny zero tensor instead of crashing.
            target = next(self.parameters()).device
            return torch.zeros(1, 1, device=target)

        pooled = []
        logits = []
        for name, feat in available:
            if feat.dim() == 3:
                # Collapse the token axis: [B, N, D] -> [B, D].
                feat = feat.mean(1)
            pooled.append(feat)
            logits.append(self.gates[name](feat))  # one score per sample

        weights = torch.softmax(torch.cat(logits, dim=1), dim=1)  # [B, M]
        stacked = torch.stack(pooled, dim=1)  # [B, M, D]
        return (weights.unsqueeze(-1) * stacked).sum(1)  # [B, D]
