import torch
import typing as tp
import math 
import torch.nn as nn
import torch.nn.functional as F

def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Stochastic depth: randomly zero whole samples of ``x`` along the batch axis.

    Intended for the main path of a residual block. The technique was formerly
    called "DropConnect" in some EfficientNet-style code, but that name belongs
    to a different dropout variant, hence "drop path" here
    (see https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956).

    Args:
        x: input tensor; the first dimension is treated as the per-sample axis.
        drop_prob: probability of dropping a sample.
        training: when False the input is returned untouched.
        scale_by_keep: when True (and the keep probability is positive) the
            surviving samples are divided by the keep probability.

    Returns:
        ``x`` itself when nothing is dropped, otherwise ``x`` multiplied by a
        per-sample mask that broadcasts over all remaining dimensions.
    """
    # Identity outside of training or when dropping is disabled.
    if not training or drop_prob == 0.:
        return x

    survival = 1 - drop_prob
    # One mask entry per sample, with singleton trailing dims so that the mask
    # broadcasts against tensors of any rank.
    trailing_ones = (1,) * (x.ndim - 1)
    mask = x.new_empty((x.shape[0],) + trailing_ones).bernoulli_(survival)
    if scale_by_keep and survival > 0.0:
        mask.div_(survival)
    return x * mask


class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (stochastic depth per sample).

    Meant to sit on the main path of a residual block. Dropping is only active
    while the module is in training mode.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        # Whether surviving samples are rescaled by the keep probability.
        self.scale_by_keep = scale_by_keep
        # Probability of dropping a sample.
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegate to the functional form, passing the module's training flag.
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        # Shown inside the module's repr; three decimals.
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
    
    
    
#################################################################################
#                      Rotary Positional Embedding Functions                    #
#################################################################################
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py 

def precompute_freqs_cis(
    seq_len: int,
    n_elem: int,
    base: int = 10000,
    rope_scaling: tp.Optional[dict] = None,
    train_seq_len: tp.Optional[int] = None
):
    """Precompute the rotary-embedding table of (cos, sin) pairs.

    Args:
        seq_len: number of positions to tabulate.
        n_elem: rotary dimension; ``n_elem // 2`` frequencies are produced.
        base: RoPE base used to build the inverse frequencies.
        rope_scaling: optional context-extension settings:
            {"type": "linear" | "ntk" | "yarn",
             "factor": s (>=1; smaller values are clamped to 1 in every mode),
             # for yarn:
             "alpha": 1.0, "beta": 32.0,
             "attn_temperature_t": float}
            Any other ``type`` (including a missing one) is handled as linear
            position scaling; with factor 1 that is unscaled RoPE.
        train_seq_len: context length used for the YaRN wavelength ramp;
            defaults to ``seq_len`` when omitted.

    Returns:
        float32 tensor of shape (seq_len, n_elem // 2, 2) where
        [:, :, 0] = amp * cos and [:, :, 1] = amp * sin. ``amp`` is 1 except
        in yarn mode.
    """
    half = n_elem // 2
    dim_index = torch.arange(0, n_elem, 2)[:half].float()
    device = dim_index.device
    t_dtype = torch.float32

    def build_inv_freq(_base: float) -> torch.Tensor:
        # Inverse frequency per rotary pair: _base ** -(2i / n_elem).
        inv = 1.0 / (_base ** (dim_index / n_elem))
        return inv.to(dtype=t_dtype, device=device)   # shape: [half]

    method = None
    s = 1.0
    alpha = 1.0
    beta = 32.0
    t_temp = None
    if rope_scaling is not None:
        method = rope_scaling.get("type", None)
        s = float(rope_scaling.get("factor", 1.0))
        alpha = float(rope_scaling.get("alpha", alpha))
        beta = float(rope_scaling.get("beta", beta))
        t_temp = rope_scaling.get("attn_temperature_t", None)

    # Clamp once so that every mode treats factor < 1 as "no scaling".
    # Previously only the linear and yarn branches clamped, and ntk silently
    # shrank the base for factors below 1.
    s = max(s, 1.0)

    positions = torch.arange(seq_len, dtype=t_dtype, device=device)
    if method == "ntk":
        # Enlarge the base instead of compressing positions. The exponent
        # denominator is floored at 2 so tiny n_elem cannot divide by <= 0.
        d = float(n_elem)
        exp = d / max(d - 2.0, 2.0)
        inv_freq = build_inv_freq(base * (s ** exp))
    elif method == "yarn":
        inv_freq_orig = build_inv_freq(float(base))
        L_train = float(train_seq_len) if train_seq_len is not None else float(seq_len)
        # r: number of full rotations each frequency completes over L_train.
        r = (L_train * inv_freq_orig) / (2.0 * math.pi)
        # Ramp: 0 below alpha (fully interpolated), 1 above beta (left as is),
        # linear in between.
        gamma = torch.where(
            r < alpha, torch.zeros_like(r),
            torch.where(r > beta, torch.ones_like(r), (r - alpha) / max(beta - alpha, 1e-6))
        )
        scale_d = gamma + (1.0 - gamma) / s
        inv_freq = inv_freq_orig * scale_d
    else:
        # default: linear scaling of position
        inv_freq = build_inv_freq(float(base))
        positions = positions / s

    freqs = torch.outer(positions, inv_freq)  # (seq_len, n_elem // 2)

    # Only yarn rescales the magnitude of the rotation table.
    amp = 1.0
    if method == "yarn":
        if t_temp is None:
            amp = 0.1 * math.log(s) + 1.0   # = 1/sqrt(t)
        else:
            amp = 1.0 / math.sqrt(float(t_temp))

    freqs_cis = torch.polar(torch.full_like(freqs, fill_value=amp), freqs)
    # Split the complex table into explicit (real, imag) = (cos, sin) channels.
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)  # (seq_len, n_elem // 2, 2)
    return cache

def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate consecutive channel pairs of ``x`` by the precomputed angles.

    Args:
        x: (bs, seq_len, n_head, head_dim)
        freqs_cis: (seq_len, head_dim // 2, 2) with [..., 0] the cos-like
            component and [..., 1] the sin-like component.

    Returns:
        Tensor with the shape and dtype of ``x``; the arithmetic itself is
        carried out in float32.
    """
    # Pair up adjacent channels: (bs, seq_len, n_head, head_dim // 2, 2)
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    # Insert batch and head axes so the table broadcasts:
    # (1, seq_len, 1, head_dim // 2, 2)
    table = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)

    x_re, x_im = pairs[..., 0], pairs[..., 1]
    f_re, f_im = table[..., 0], table[..., 1]

    # Complex multiplication (x_re + i*x_im) * (f_re + i*f_im), written out.
    out_re = x_re * f_re - x_im * f_im
    out_im = x_im * f_re + x_re * f_im

    rotated = torch.stack([out_re, out_im], dim=-1)
    # Merge the (head_dim // 2, 2) pair axes back into head_dim.
    return rotated.flatten(3).type_as(x)



def interleave_tokens(x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, int]:
    """
    Merge two token sequences so that each x token is followed by its group of y tokens.

    Args:
        x (torch.Tensor): e.g., video tokens, shape [bsz, t1, c].
        y (torch.Tensor): e.g., audio tokens, shape [bsz, t2, c], where t2
            must be an integer multiple of t1.

    Returns:
        merged (torch.Tensor): interleaved sequence of shape [bsz, t1*(r+1), c],
            laid out as x_0, y_0..y_{r-1}, x_1, y_r..y_{2r-1}, ...
        r (int): number of y tokens per x token (t2 // t1).
    """
    n_batch, n_x, _ = x.shape
    _, n_y, width = y.shape
    assert x.size(2) == y.size(2), "Channel dimensions must match"
    assert n_y % n_x == 0, "Audio token count must be an integer multiple of video token count"

    r = n_y // n_x

    # Put each x token in front of its r consecutive y tokens:
    # [bsz, t1, 1, c] ++ [bsz, t1, r, c] -> [bsz, t1, 1 + r, c]
    grouped = torch.cat(
        (x.unsqueeze(2), y.view(n_batch, n_x, r, width)),
        dim=2,
    )

    # Collapse the (t1, 1 + r) axes back into one sequence axis.
    return grouped.view(n_batch, n_x * (r + 1), width), r

def noise_augment(h: torch.Tensor, k_max: float) -> torch.Tensor:
    """
    Blend a latent representation with Gaussian noise, one random weight per token.
    Modified from https://arxiv.org/pdf/2411.18447

    Args:
        h (torch.Tensor): latent representation with shape [B, T, C]
        k_max (float): upper bound of the mixing weight; for every (batch, time)
            position a weight k is drawn as a uniform [0, 1) sample times k_max

    Returns:
        torch.Tensor: k * noise + (1 - k) * h, same shape as h.
    """
    n_batch, n_steps, _ = h.shape
    # Draw order matters for reproducibility: Gaussian noise first, then weights.
    eps = torch.randn_like(h)
    # One weight per (batch, time) position, shared across channels.
    k = torch.rand(n_batch, n_steps, 1, device=h.device, dtype=h.dtype) * k_max
    return k * eps + (1 - k) * h

def generate_attention_mask(modalities: torch.Tensor) -> torch.Tensor:
    """
    Generates an attention mask for multi-head attention based on a sequence of modality IDs.

    Contiguous runs of equal modality IDs form blocks. Attention is causal
    across blocks and bidirectional inside a block.

    Parameters:
        modalities (torch.Tensor): A 1D tensor of shape (seq_len,), where each element represents
                                     the modality ID for a token (e.g., 0 for modality A and 1 for modality B).
                                     May be empty.

    Returns:
        torch.Tensor: A 2D boolean tensor of shape (seq_len, seq_len) representing the attention mask.
                      If mask[i, j] is True, then token at position i is allowed to attend to token at position j.
                      An empty input yields an empty (0, 0) mask.
    """
    seq_len = modalities.size(0)
    if seq_len == 0:
        # No tokens: there is no first element to seed the block ids with.
        return torch.zeros((0, 0), dtype=torch.bool)

    # Group contiguous tokens of the same modality by assigning block IDs.
    # A new block begins wherever a token's modality differs from its
    # predecessor, so the block id is the running count of such boundaries.
    block_ids = torch.zeros(seq_len, dtype=torch.long)
    boundaries = modalities[1:] != modalities[:-1]
    block_ids[1:] = torch.cumsum(boundaries.to(dtype=torch.long), dim=0).to(block_ids.device)

    # Create a causal mask: a lower triangular matrix where mask[i, j] is True if j <= i.
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

    # Create a mask that allows full attention within the same block.
    same_block_mask = block_ids.unsqueeze(0) == block_ids.unsqueeze(1)

    # Tokens can attend if they are in the same block OR if the target is at
    # an earlier or the same position.
    return causal_mask | same_block_mask
