import torch
import torch.distributed as dist



def ddp_sample_scalar_uniform(low: float, high: float, device) -> float:
    """Draw one uniform scalar in [low, high) that is identical on every rank.

    Rank 0 (or the only process, when torch.distributed is unavailable or not
    initialized) fills a one-element buffer with ``uniform_(low, high)``. Under
    an initialized process group, that buffer is then broadcast from rank 0, so
    all ranks return the same value.

    Args:
        low: lower bound of the uniform range.
        high: upper bound of the uniform range.
        device: device on which the one-element buffer is allocated (it must be
            usable by the process-group backend when distributed).

    Returns:
        The sampled value as a Python float.
    """
    in_ddp = dist.is_available() and dist.is_initialized()
    buf = torch.empty(1, device=device)
    if not in_ddp or dist.get_rank() == 0:
        # Only the source rank draws; other ranks keep an uninitialized buffer
        # that the broadcast below overwrites.
        buf.uniform_(low, high)
    if in_ddp:
        dist.broadcast(buf, src=0)
    return float(buf.item())


import torch

def _lrows(meta: dict) -> int:
    """Number of lead rows after 2D patching."""
    L, pz_ch = int(meta["L"]), int(meta.get("pz_ch", 1))
    assert L % pz_ch == 0, f"L={L} not divisible by pz_ch={pz_ch}"
    return L // pz_ch


def tokens_crop_keep(tokens: torch.Tensor, meta: dict, remain_ratio: float) -> torch.Tensor:
    """Keep a random contiguous window of time tokens per sample and drop the rest.

    The window length is ``K = clamp(round(remain_ratio * Nt), 1, Nt)``; its start
    offset is drawn independently for each sample in ``[0, Nt - K]``. In the 2D
    (lead-wise) layout, the same window is applied to every lead row of a sample.

    Args:
        tokens: ``[B, N, C]`` tokens. ``N == Nt`` when ``meta["lead_wise"] == 0``;
            otherwise ``N == Lr * Nt`` laid out row-major as ``(Lr, Nt)``.
        meta: must contain ``"Nt"`` and ``"lead_wise"``; the lead-wise layout also
            needs ``"L"`` (and optionally ``"pz_ch"``) for ``_lrows``.
        remain_ratio: fraction of the ``Nt`` time positions to keep.

    Returns:
        ``[B, K, C]`` for ``lead_wise == 0``, else ``[B, Lr * K, C]`` flattened
        as (row, time).

    Raises:
        AssertionError: if ``N`` does not match the layout described by ``meta``.
    """
    B, N, C = tokens.shape
    Nt = int(meta["Nt"])

    if meta["lead_wise"] == 0:
        # Validate like tokens_crop_fill_end does: a silent mismatch would crop
        # only the first Nt tokens.
        assert N == Nt, f"N={N} != Nt={Nt}"
        Lr = 1  # treat the 1D layout as a single row of the 2D layout
    else:
        Lr = _lrows(meta)
        assert N == Lr * Nt, f"N={N} != Lr*Nt={Lr*Nt}"

    K = max(1, min(Nt, int(round(remain_ratio * Nt))))
    with torch.no_grad():
        if Nt - K > 0:
            s = torch.floor(torch.rand(B, device=tokens.device) * (Nt - K + 1)).long()  # (B,)
        else:
            s = torch.zeros(B, device=tokens.device, dtype=torch.long)
        offsets = torch.arange(K, device=tokens.device).view(1, K)
        # clamp_max guards against a float-rounding start of Nt - K + 1.
        idx = (s.view(B, 1) + offsets).clamp_max(Nt - 1)  # (B,K)

    # reshape (not view) so non-contiguous inputs are accepted.
    x = tokens.reshape(B, Lr, Nt, C)
    out = x.gather(dim=2, index=idx.view(B, 1, K, 1).expand(B, Lr, K, C))  # [B,Lr,K,C]
    return out.reshape(B, Lr * K, C)


def tokens_drop_random(tokens: torch.Tensor, meta: dict, drop_ratio: float) -> torch.Tensor:
    """Drop a random subset of time positions per sample, keeping temporal order.

    ``round(drop_ratio * Nt)`` positions (clamped to ``[0, Nt]``) are removed, but
    at least one position is always kept. Kept positions are returned in
    ascending time order. In the 2D (lead-wise) layout, the same positions are
    dropped from every lead row of a sample.

    Args:
        tokens: ``[B, N, C]`` tokens. ``N == Nt`` when ``meta["lead_wise"] == 0``;
            otherwise ``N == Lr * Nt`` laid out row-major as ``(Lr, Nt)``.
        meta: must contain ``"Nt"`` and ``"lead_wise"``; the lead-wise layout also
            needs ``"L"`` (and optionally ``"pz_ch"``) for ``_lrows``.
        drop_ratio: fraction of the ``Nt`` time positions to drop.

    Returns:
        ``[B, K, C]`` for ``lead_wise == 0``, else ``[B, Lr * K, C]``, where
        ``K = max(1, Nt - dropped)``.

    Raises:
        AssertionError: if ``N`` does not match the layout described by ``meta``.
    """
    B, N, C = tokens.shape
    Nt = int(meta["Nt"])

    if meta["lead_wise"] == 0:
        assert N == Nt, f"N={N} != Nt={Nt}"
        Lr = 1  # 1D layout == a single row of the 2D layout
    else:
        Lr = _lrows(meta)
        assert N == Lr * Nt, f"N={N} != Lr*Nt={Lr*Nt}"

    drop_tok = max(0, min(Nt, int(round(drop_ratio * Nt))))
    K = max(1, Nt - drop_tok)
    with torch.no_grad():
        # Top-K of iid uniform scores == uniform random K-subset; sort restores time order.
        scores = torch.rand(B, Nt, device=tokens.device)
        keep = scores.topk(k=K, dim=1).indices.sort(dim=1).values  # (B,K)

    x = tokens.reshape(B, Lr, Nt, C)
    out = x.gather(dim=2, index=keep.view(B, 1, K, 1).expand(B, Lr, K, C))  # [B,Lr,K,C]
    return out.reshape(B, Lr * K, C)

def tokens_row_drop(tokens: torch.Tensor, meta: dict, row_drop_ratio: float) -> torch.Tensor:
    """Drop a random subset of lead rows per sample in the 2D-patched layout.

    ``round(row_drop_ratio * Lr)`` rows (clamped to ``[0, Lr]``) are removed, but
    at least one row is always kept. Kept rows stay in ascending row order, and
    each kept row retains all of its ``Nt`` time tokens.

    Args:
        tokens: ``[B, N, C]`` tokens with ``N == Lr * Nt``, laid out row-major as
            ``(Lr, Nt)``.
        meta: must contain ``"Nt"`` and ``"L"`` (optionally ``"pz_ch"``), which
            ``_lrows`` uses to compute ``Lr``.
        row_drop_ratio: fraction of lead rows to drop.

    Returns:
        ``[B, keep_r * Nt, C]``.

    Raises:
        AssertionError: if ``N != Lr * Nt``.
    """
    B, N, C = tokens.shape
    Nt = int(meta["Nt"])
    Lr = _lrows(meta)  # rows after 2D patch
    assert N == Lr * Nt, f"N={N} != Lr*Nt={Lr*Nt}"

    drop_r = max(0, min(Lr, int(round(row_drop_ratio * Lr))))
    keep_r = max(1, Lr - drop_r)

    with torch.no_grad():
        # One batched draw instead of B separate randperm calls: argsort of iid
        # uniforms is a uniform random permutation per sample (and B == 0 works).
        perm = torch.rand(B, Lr, device=tokens.device).argsort(dim=1)
        keep_rows = perm[:, :keep_r].sort(dim=1).values  # (B,keep_r)

    x = tokens.reshape(B, Lr, Nt, C)  # [B,Lr,Nt,C]
    idxL = keep_rows.view(B, keep_r, 1, 1).expand(B, keep_r, Nt, C)
    out = x.gather(dim=1, index=idxL)  # [B,keep_r,Nt,C]
    return out.reshape(B, keep_r * Nt, C)
def _lrows(meta: dict) -> int:
    L, pz_ch = int(meta["L"]), int(meta.get("pz_ch", 1))
    assert L % pz_ch == 0
    return L // pz_ch

def tokens_crop_fill_end(
    tokens: torch.Tensor,   # [B, N, C] (after patch, before PE/lead_emb)
    meta: dict,             # {'lead_wise', 'L', 'Nt', 'pz_ch'}
    remain_ratio: float,
    end_token: torch.Tensor,# [1,1,C] learnable END from model
    share_across_rows: bool = True,
) -> torch.Tensor:
    """Keep a random contiguous time window and overwrite everything else with ``end_token``.

    The sequence length is unchanged. The window length is
    ``clamp(round(remain_ratio * Nt), 1, Nt)``, and its start is drawn uniformly
    in ``[0, Nt - window]`` per sample. In the lead-wise layout, one start per
    sample is shared by all rows when ``share_across_rows``; otherwise each
    (sample, row) pair gets its own start.

    Args:
        tokens: ``[B, N, C]``; ``N == Nt`` for ``lead_wise == 0``, else ``N == Lr * Nt``.
        meta: provides ``"lead_wise"`` and ``"Nt"``, plus ``"L"``/``"pz_ch"`` for ``_lrows``.
        remain_ratio: fraction of the ``Nt`` time positions to keep.
        end_token: filler broadcast over the masked positions (cast to the
            tokens' device/dtype).
        share_across_rows: lead-wise only; whether all rows share one window.

    Returns:
        Tensor shaped like ``tokens``.
    """
    B, N, C = tokens.shape
    Nt = int(meta["Nt"])
    dev = tokens.device

    def draw_starts(span: int, *size: int) -> torch.Tensor:
        # Integer starts uniform in [0, Nt - span]; all zeros if the window spans Nt.
        if Nt - span > 0:
            return torch.floor(torch.rand(*size, device=dev) * (Nt - span + 1)).long()
        return torch.zeros(*size, device=dev, dtype=torch.long)

    if meta["lead_wise"] == 0:
        assert N == Nt, f"N={N} != Nt={Nt}"
        span = max(1, min(Nt, int(round(remain_ratio * Nt))))
        with torch.no_grad():
            first = draw_starts(span, B).view(B, 1)
            rel = torch.arange(Nt, device=dev).view(1, Nt) - first
            inside = (rel >= 0) & (rel < span)                          # (B,Nt)
        filler = end_token.to(dev, tokens.dtype).expand(B, 1, C).expand_as(tokens)
        return torch.where(inside.unsqueeze(-1), tokens, filler)

    # Lead-wise 2D layout: N = Lr * Nt.
    rows = _lrows(meta)
    assert N == rows * Nt, f"N={N} != Lr*Nt={rows * Nt}"
    grid = tokens.view(B, rows, Nt, C)
    span = max(1, min(Nt, int(round(remain_ratio * Nt))))
    with torch.no_grad():
        # A per-sample start is always drawn, even when rows get their own
        # starts, so the RNG stream is consumed the same way in both modes.
        shared = draw_starts(span, B)
        if share_across_rows:
            first = shared.view(B, 1, 1, 1)
        else:
            first = draw_starts(span, B, rows).view(B, rows, 1, 1)
        rel = torch.arange(Nt, device=dev).view(1, 1, Nt, 1) - first
        inside = (rel >= 0) & (rel < span)                              # (B,1|Lr,Nt,1)
    filler = end_token.to(dev, tokens.dtype).expand_as(grid)
    return torch.where(inside.expand_as(grid), grid, filler).reshape(B, N, C)
import torch

def tokens_row_mask_fill(
    tokens: torch.Tensor,
    meta: dict,
    row_mask_ratio: float,
    mask_token: torch.Tensor | float
) -> torch.Tensor:
    """Overwrite a random subset of lead rows with ``mask_token``, keeping the shape.

    ``round(row_mask_ratio * Lr)`` rows are masked, clamped to ``[0, Lr - 1]``, so
    at least one row per sample always stays unmasked. Every time token of a
    masked row is replaced.

    Args:
        tokens: ``[B, N, C]`` with ``N == Lr * Nt``, laid out row-major as ``(Lr, Nt)``.
        meta: must contain ``"Nt"`` and ``"L"`` (optionally ``"pz_ch"``) for ``_lrows``.
        row_mask_ratio: fraction of lead rows to mask.
        mask_token: a Python scalar, a 0-d tensor, a ``[C]`` tensor, or an
            ``[M, 1, ..., 1, C]`` tensor with ``M`` in ``(1, B)`` (e.g. ``[M, C]``
            or ``[1, 1, C]``). Tensors are cast to the tokens' device/dtype, so
            gradients flow back to a learnable token.

    Returns:
        ``[B, N, C]`` tensor.

    Raises:
        AssertionError: if ``N != Lr * Nt``.
        ValueError: if ``mask_token`` has an unsupported shape or batch size.
    """
    B, N, C = tokens.shape
    Nt = int(meta["Nt"])
    Lr = _lrows(meta)
    assert N == Lr * Nt, f"N({N}) must equal Lr({Lr}) * Nt({Nt})"

    mask_r = max(0, min(Lr - 1, int(round(row_mask_ratio * Lr))))
    keep_r = max(1, Lr - mask_r)

    device = tokens.device
    dtype = tokens.dtype

    # Normalize mask_token to an [M, C] tensor with M in (1, B).
    if isinstance(mask_token, (int, float)):
        mt = torch.full((1, C), float(mask_token), device=device, dtype=dtype)
    else:
        mt = mask_token.to(device=device, dtype=dtype)
        shape = tuple(mt.shape)
        if mt.dim() == 0:
            # A 0-d tensor is a scalar too; broadcast it across channels.
            mt = mt.reshape(1, 1).expand(1, C)
        elif mt.dim() == 1 and shape[0] == C:
            mt = mt.reshape(1, C)
        elif mt.dim() >= 2 and shape[-1] == C and all(d == 1 for d in shape[1:-1]):
            mt = mt.reshape(shape[0], C)
        else:
            raise ValueError(f"mask_token must be scalar or shape [C] / [*, C], got {shape}")
        if mt.shape[0] not in (1, B):
            raise ValueError(f"mask_token batch dim must be 1 or B; got {mt.shape[0]}")

    with torch.no_grad():
        # One batched draw instead of B randperm calls: argsort of iid uniforms
        # yields a uniform random permutation per sample; its first keep_r rows stay.
        keep_rows = torch.rand(B, Lr, device=device).argsort(dim=1)[:, :keep_r]  # [B, keep_r]
        row_mask = torch.ones((B, Lr), dtype=torch.bool, device=device)
        row_mask.scatter_(dim=1, index=keep_rows, value=False)  # True = masked row

    # torch.where allocates a fresh output, so no clone of the input is needed.
    x = tokens.reshape(B, Lr, Nt, C)
    out = torch.where(row_mask.view(B, Lr, 1, 1), mt.view(-1, 1, 1, C), x)
    return out.reshape(B, N, C)

import torch

def _lrows(meta: dict) -> int:
    L, pz_ch = int(meta["L"]), int(meta.get("pz_ch", 1))
    assert L % pz_ch == 0
    return L // pz_ch

def tokens_drop_random_fill(
    tokens: torch.Tensor,
    meta: dict,
    drop_ratio: float,
    *,
    end_token: torch.Tensor,
    share_across_rows: bool = True,
) -> torch.Tensor:
    """Replace a random subset of time positions with ``end_token``, keeping the shape.

    Exactly ``clamp(round(drop_ratio * Nt), 0, Nt)`` time positions are replaced
    per sample; with ``drop_ratio`` rounding to ``Nt``, every position is
    replaced. When nothing is dropped, ``tokens`` itself is returned. In the
    lead-wise layout, the dropped positions are shared by all rows if
    ``share_across_rows``; otherwise they are drawn per (sample, row).

    Args:
        tokens: ``[B, N, C]``; ``N == Nt`` for ``lead_wise == 0``, else ``N == Lr * Nt``.
        meta: provides ``"lead_wise"`` and ``"Nt"``, plus ``"L"``/``"pz_ch"`` for ``_lrows``.
        drop_ratio: fraction of time positions to replace.
        end_token: filler broadcast over replaced positions (cast to the
            tokens' device/dtype).
        share_across_rows: lead-wise only; whether rows share one drop pattern.

    Returns:
        Tensor shaped like ``tokens``.
    """
    B, N, C = tokens.shape
    Nt = int(meta["Nt"])
    n_drop = max(0, min(Nt, int(round(drop_ratio * Nt))))
    dev = tokens.device

    def random_mask(n_rows: int) -> torch.Tensor:
        # (n_rows, Nt) bool mask with exactly n_drop True entries per row.
        if n_drop == Nt:
            return torch.ones(n_rows, Nt, device=dev, dtype=torch.bool)
        picked = torch.rand(n_rows, Nt, device=dev).topk(k=n_drop, dim=1).indices
        mask = torch.zeros(n_rows, Nt, device=dev, dtype=torch.bool)
        return mask.scatter_(dim=1, index=picked, value=True)

    if meta["lead_wise"] == 0:
        assert N == Nt, f"N={N} != Nt={Nt}"
        if n_drop == 0:
            return tokens
        with torch.no_grad():
            drop = random_mask(B)                                        # (B,Nt)
        filler = end_token.to(dev, tokens.dtype).expand(B, Nt, C)
        return torch.where(drop.unsqueeze(-1), filler, tokens)

    # Lead-wise 2D layout: N = Lr * Nt.
    Lr = _lrows(meta)
    assert N == Lr * Nt, f"N={N} != Lr*Nt={Lr*Nt}"
    grid = tokens.view(B, Lr, Nt, C)
    if n_drop == 0:
        return tokens
    with torch.no_grad():
        if share_across_rows:
            drop = random_mask(B).view(B, 1, Nt, 1)
        else:
            drop = random_mask(B * Lr).view(B, Lr, Nt, 1)
    filler = end_token.to(dev, tokens.dtype).expand_as(grid)
    return torch.where(drop.expand_as(grid), filler, grid).reshape(B, N, C)
