import random
import math
import numpy as np
import torch


class RandomLeadsMask(object):
    """Randomly zero out whole leads (rows along dim 0) of a sample.

    Args:
        p: probability of applying the mask; otherwise the sample is cloned.
        mask_leads_selection: ``"random"`` masks each lead independently with
            probability ``mask_leads_prob``; ``"conditional"`` masks exactly
            ``n1`` of leads 0-5 and ``n2`` of leads 6-11, where
            ``(n1, n2) = mask_leads_condition``. Any other value yields an
            all-zero sample when the mask is applied.
        mask_leads_prob: per-lead masking probability for ``"random"`` mode.
        mask_leads_condition: ``(n1, n2)`` pair for ``"conditional"`` mode,
            each in ``[0, 6]``.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        p=1,
        mask_leads_selection="random",
        mask_leads_prob=0.5,
        mask_leads_condition=None,
        **kwargs,
    ):
        self.p = p
        self.mask_leads_prob = mask_leads_prob
        self.mask_leads_selection = mask_leads_selection
        self.mask_leads_condition = mask_leads_condition

    def __call__(self, sample):
        ''' Sample: torch.Tensor whose first dimension indexes leads; returns a float tensor. '''
        if self.p >= np.random.uniform(0, 1):
            new_sample = sample.new_zeros(sample.size())
            if self.mask_leads_selection == "random":
                # One draw per lead actually present (previously hard-coded to 12).
                survivors = torch.from_numpy(
                    np.random.uniform(0, 1, size=sample.shape[0]) >= self.mask_leads_prob
                )
                new_sample[survivors] = sample[survivors]
            elif self.mask_leads_selection == "conditional":
                # The (n1, n2) pair lives in mask_leads_condition, not in the
                # selection string.
                (n1, n2) = self.mask_leads_condition
                assert (
                    (0 <= n1 and n1 <= 6) and
                    (0 <= n2 and n2 <= 6)
                ), (n1, n2)
                # Keep 6 - n1 of leads 0-5 and 6 - n2 of leads 6-11.
                s1 = random.sample(range(6), 6 - n1)
                s2 = [i + 6 for i in random.sample(range(6), 6 - n2)]
                # Explicit long dtype so an empty selection is still a valid index.
                keep = torch.tensor(s1 + s2, dtype=torch.long)
                new_sample[keep] = sample[keep]
        else:
            new_sample = sample.clone()

        return new_sample.float()


def Tinterpolate(data, marker):
    """Fill every ``marker`` entry of a 2-D CPU tensor by linear interpolation.

    The tensor is flattened row-major and ``np.interp`` is run over the flat
    index. A gap at the start of one row is therefore interpolated from the
    last known value of the previous row. Markers before the first or after
    the last known entry take that nearest known value (``np.interp`` clamps).

    ``torch.flatten`` may return a view, so the caller's tensor can be filled
    in place as well as returned.

    Args:
        data: 2-D tensor (rows, cols); must support ``.numpy()``.
        marker: sentinel value marking the entries to fill.

    Returns:
        Tensor with the original 2-D shape and the markers replaced.
    """
    n_rows, n_cols = data.shape
    flat = data.flatten()
    flat_np = flat.numpy()
    is_missing = flat_np == marker
    missing_at = np.where(is_missing)[0]
    known_at = np.where(~is_missing)[0]
    known_vals = flat_np[~is_missing]
    filled = torch.from_numpy(np.interp(missing_at, known_at, known_vals))
    flat[flat == marker] = filled.type(flat.type())
    return flat.reshape(n_rows, n_cols)


class Transformation:
    """Minimal base for the ``T*`` transforms: records its constructor kwargs."""

    def __init__(self, *args, **kwargs):
        del args  # positional arguments are accepted but intentionally discarded
        self.params = kwargs

    def get_params(self):
        """Return the keyword-argument dict recorded at construction time."""
        recorded = self.params
        return recorded


class TRandomResizedCrop(Transformation):
    """Extract a crop at a random position and stretch it back to full length.

    ``data`` is expected as (channels, timesteps). A crop of
    ``int(crop_ratio * timesteps)`` samples (at least 2), with
    ``crop_ratio ~ U(crop_ratio_range)``, is taken with ``TRandomCrop``. Its
    samples are written, in order, onto a random increasing subset of output
    positions that always includes the first and last timestep. The remaining
    positions are filled by linear interpolation (``Tinterpolate``).
    """

    def __init__(self, crop_ratio_range=(0.5, 1.0)):
        # Register the parameter so get_params() reports it, as TTimeOut does.
        super().__init__(crop_ratio_range=crop_ratio_range)
        self.crop_ratio_range = crop_ratio_range

    def __call__(self, data):
        channels, timesteps = data.shape
        # +inf marks "not yet filled" positions; Tinterpolate replaces them.
        output = torch.full(data.shape, float("inf")).type(data.type())
        crop_ratio = random.uniform(*self.crop_ratio_range)
        # Two samples are the minimum needed to pin both endpoints below;
        # shorter crops previously produced an index/shape mismatch.
        crop_len = max(2, int(crop_ratio * timesteps))
        data = TRandomCrop(crop_len)(data)  # apply random crop
        cropped_timesteps = data.shape[1]
        # Random increasing interior positions, plus both endpoints.
        interior = torch.sort((torch.randperm(timesteps - 2) + 1)[:cropped_timesteps - 2])[0]
        indices = torch.cat([torch.tensor([0]), interior, torch.tensor([timesteps - 1])])
        output[:, indices] = data  # scatter the crop (in order) onto the output

        # Linear interpolation fills the gaps, i.e. resizes the crop.
        return Tinterpolate(output, float("inf"))

    def __str__(self):
        return "RandomResizedCrop"


class TRandomCrop(object):
    """Randomly crop a contiguous window of ``output_size`` timesteps.

    Expects ``data`` of shape (channels, timesteps). Every start position in
    ``[0, timesteps - output_size]`` is equally likely.
    """

    def __init__(self, output_size, annotation=False):
        self.output_size = output_size
        self.annotation = annotation  # stored for API compatibility; unused here

    def __call__(self, data):
        _, timesteps = data.shape
        assert timesteps >= self.output_size, (timesteps, self.output_size)
        if timesteps == self.output_size:
            start = 0
        else:
            # random.randint is inclusive on both ends; the previous upper
            # bound of ``timesteps - output_size - 1`` never picked the last
            # valid window.
            start = random.randint(0, timesteps - self.output_size)

        return data[:, start:start + self.output_size]

    def __str__(self):
        return "RandomCrop"


class TTimeOut(Transformation):
    """Replace a random contiguous time window with zeros.

    ``data`` is expected as (channels, timesteps), like the other ``T*``
    transforms here. The window length is ``int(crop_ratio * timesteps)``
    with ``crop_ratio ~ U(crop_ratio_range)``. All channels are zeroed over
    the same window. The input tensor is not modified.
    """

    def __init__(self, crop_ratio_range=(0.0, 0.5)):
        super(TTimeOut, self).__init__(crop_ratio_range=crop_ratio_range)
        self.crop_ratio_range = crop_ratio_range

    def __call__(self, data):
        data = data.clone()
        # Time is the last axis; the old (timesteps, channels) unpacking masked
        # channels instead of a time window for (C, T) input.
        _, timesteps = data.shape
        crop_ratio = random.uniform(*self.crop_ratio_range)
        crop_timesteps = int(crop_ratio * timesteps)
        # randint is inclusive: every start in [0, timesteps - crop_timesteps]
        # is valid, including crop_timesteps == timesteps (start = 0).
        start_idx = random.randint(0, timesteps - crop_timesteps)
        data[:, start_idx:start_idx + crop_timesteps] = 0
        return data

    def __str__(self):
        return "TimeOut"


def Tnoise_powerline(fs=100, N=1000, C=1, fn=50., K=3, channels=1):
    '''Powerline noise, inspired by https://ieeexplore.ieee.org/document/43620

    Sum of harmonics k = 1..K of the mains frequency ``fn``. Each harmonic has
    a random amplitude ~ U(0, 1); all harmonics share one random phase.

    Args:
        fs: sampling frequency (Hz)
        N: length of the signal (timesteps)
        C: relative scaling factor (default scale: 1)
        fn: base frequency of powerline noise (Hz)
        K: number of harmonics to be considered
        channels: number of output channels (each is the same waveform
            rescaled by a random gain ~ U(-1, 1) when channels > 1)

    Returns:
        Tensor of shape (N, channels).
    '''
    # Integer arange guarantees exactly N samples; a float-step arange over
    # [0, N/fs) can produce N+1 points through rounding.
    t = torch.arange(N) / fs

    signal = torch.zeros(N)
    phi1 = random.uniform(0, 2 * math.pi)
    for k in range(1, K + 1):
        ak = random.uniform(0, 1)
        signal += ak * torch.cos(2 * math.pi * k * fn * t + phi1)
    # Apply C exactly once (it used to be applied per harmonic and again here, i.e. C**2).
    signal = C * signal[:, None]
    if channels > 1:
        channel_gains = torch.empty(channels).uniform_(-1, 1)
        signal = signal * channel_gains[None]
    return signal

def Tnoise_baseline_wander(fs=100, N=1000, C=1.0, fc=0.5, fdelta=0.01, channels=1, independent_channels=False):
    '''Baseline wander as in https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5361052/

    C * sum_{k=1..K} a_k * cos(2*pi*k*fdelta*t + phi_k), with a_k ~ U(0, 1),
    phi_k ~ U(0, 2*pi) and K = round(fc / fdelta). The highest component is
    therefore at ~fc.

    Args:
        fs: sampling frequency (Hz)
        N: length of the signal (timesteps)
        C: relative scaling factor (default scale: 1)
        fc: cutoff frequency for the baseline wander (Hz)
        fdelta: lowest resolvable frequency (defaults to fs/N if None is passed)
        channels: number of output channels (currently ignored)
        independent_channels: currently ignored

    Returns:
        1-D tensor of length N.
    '''
    if fdelta is None:
        fdelta = fs / N

    K = int((fc / fdelta) + 0.5)
    # Integer arange guarantees exactly N samples (float-step arange may not).
    t = torch.arange(N) / fs                                   # (N,)
    # Components start at k = 1: k = 0 would add a random DC offset and drop fc.
    k = torch.arange(1, K + 1)                                 # (K,)
    phase_k = torch.empty(K).uniform_(0, 2 * math.pi)          # (K,)
    a_k = torch.empty(K).uniform_(0, 1)                        # (K,)
    pre_cos = 2 * math.pi * k[:, None] * fdelta * t[None, :] + phase_k[:, None]  # (K, N)
    res = (a_k[:, None] * torch.cos(pre_cos)).sum(dim=0)       # (N,)
    return C * res


class Compose():
    """Chain ``(wave, target) -> (wave, target)`` transforms, applied in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, wave, target):
        current_wave, current_target = wave, target
        for step in self.transforms:
            current_wave, current_target = step(current_wave, current_target)
        return current_wave, current_target


class RandomCrop():
    """Crop the same random window of ``length`` samples from wave and target.

    The window start is drawn uniformly from ``[start, end - length]``
    (inclusive), so the window always lies inside ``[start, end)``. Both
    inputs are sliced along their second axis.
    """

    def __init__(self, length, start, end):
        self.length = length
        self.start = start
        self.end = end

    def __call__(self, wave, target):
        lo = random.randint(self.start, self.end - self.length)
        window = slice(lo, lo + self.length)
        return wave[:, window], target[:, window]


class ChannelResize():
    """Multiply each channel by its own random gain, log-uniform in ``magnitude_range``.

    ``wave`` is expected as (channels, timesteps); ``target`` passes through.
    """

    def __init__(self, magnitude_range=(0.5, 2)):
        self.log_magnitude_range = torch.log(torch.tensor(magnitude_range))

    def __call__(self, wave, target):
        n_channels, _ = wave.shape
        log_lo, log_hi = self.log_magnitude_range
        gains = torch.exp(torch.empty(n_channels).uniform_(log_lo, log_hi))
        # One gain per row, broadcast across all timesteps.
        return gains[:, None] * wave, target
    
class GaussianNoise():
    """With probability ``prob``, add i.i.d. Gaussian noise of std ``scale`` to ``wave``.

    ``target`` passes through unchanged.
    """

    def __init__(self, prob=1.0, scale=0.01):
        self.scale = scale
        self.prob = prob

    def __call__(self, wave, target):
        if random.random() < self.prob:
            # Out-of-place: ``wave +=`` used to modify the caller's tensor.
            # randn_like matches wave's device and dtype (CPU/float32 draws
            # are the same as before).
            wave = wave + self.scale * torch.randn_like(wave)
        return wave, target


class BaselineShift():
    """With probability ``prob``, add one random constant offset to the whole wave.

    The offset is ``scale * N(0, 1)`` and is shared by all channels and
    timesteps. ``target`` passes through unchanged.
    """

    def __init__(self, prob=1.0, scale=1.0):
        self.prob = prob
        self.scale = scale

    def __call__(self, wave, target):
        if random.random() < self.prob:
            # Draw from the CPU generator as before, then move to wave's device:
            # a 1-element (non-0-dim) CPU tensor cannot be added to a CUDA tensor.
            shift = torch.randn(1).to(wave.device)
            wave = wave + self.scale * shift
        return wave, target


class BaselineWander():
    """With probability ``prob``, add one baseline-wander waveform to every channel.

    Args:
        prob: probability of applying the augmentation.
        freq: sampling frequency (Hz) passed to ``Tnoise_baseline_wander``.
    """

    def __init__(self, prob=1.0, freq=500):
        self.freq = freq
        self.prob = prob

    def __call__(self, wave, target):
        """``wave`` is (channels, timesteps); ``target`` passes through."""
        if random.random() < self.prob:
            _, len_wave = wave.shape
            wander = Tnoise_baseline_wander(fs=self.freq, N=len_wave)
            # Broadcast the (timesteps,) waveform over channels, and move it to
            # wave's device so GPU inputs work.
            wave = wave + wander.to(wave.device)[None, :]
        return wave, target


class PowerlineNoise():
    """With probability ``prob``, add powerline noise with a random gain per channel.

    Args:
        prob: probability of applying the augmentation.
        freq: sampling frequency (Hz) passed to ``Tnoise_powerline``.
    """

    def __init__(self, prob=1.0, freq=500):
        self.freq = freq
        self.prob = prob

    def __call__(self, wave, target):
        """``wave`` is (channels, timesteps); ``target`` passes through."""
        if random.random() < self.prob:
            channels, len_wave = wave.shape
            # Tnoise_powerline returns (timesteps, channels); transpose to match wave.
            noise = Tnoise_powerline(fs=self.freq, N=len_wave, channels=channels).T
            # Noise is generated on the CPU; move it so GPU inputs work.
            wave = wave + noise.to(wave.device)
        return wave, target




import torch
import torch.nn.functional as F
from typing import Tuple

@torch.no_grad()
def random_time_crop(
    x: torch.Tensor, 
    ratio: Tuple[float, float] | float = (0.6, 0.9),
    *, 
    resize_back: bool = True,
    align_to: int | None = 40
) -> torch.Tensor:
    """
    Randomly crop a contiguous sub-sequence per sample, optionally resize back to original T.
    Args:
        x: (B, C, T)
        ratio: crop length ratio in [low, high] or a float
        resize_back: if True, linearly interpolate the cropped view back to length T
        align_to: if not None, crop length is rounded to a multiple of align_to (>= align_to)
    """
    assert x.dim() == 3, f"expected (B,C,T), got {tuple(x.shape)}"
    B, C, T = x.shape
    dev = x.device

    def _sample_L() -> int:
        if isinstance(ratio, (tuple, list)):
            a, b = float(ratio[0]), float(ratio[1])
            r = torch.empty((), device=dev).uniform_(a, b).item()
        else:
            r = float(ratio)
        L = max(2, int(round(T * r)))
        if align_to and align_to > 1:
            L = max(align_to, int(round(L / align_to)) * align_to)
        return min(L, T)

    Ls = [ _sample_L() for _ in range(B) ]
    outs = []
    for b in range(B):
        L = Ls[b]
        max_start = max(0, T - L)
        s = int(torch.randint(0, max_start + 1, (1,), device=dev).item())
        v = x[b, :, s:s+L]  # (C,L)
        if resize_back and v.shape[-1] != T:
            v = F.interpolate(v[None], size=T, mode="linear", align_corners=False)[0]
        outs.append(v)
    return torch.stack(outs, dim=0)


@torch.no_grad()
def time_mask(
    x: torch.Tensor, 
    p: float = 1.0, 
    max_len: int | None = None
) -> torch.Tensor:
    """
    Contiguous time-window masking (zero-out) per sample. Operates on time axis only.
    Args:
        x: (B, C, T)
        p: probability to apply masking for each sample
        max_len: max masked length (in samples). Default = T // 10
    """
    assert x.dim() == 3
    B, C, T = x.shape
    if max_len is None:
        max_len = max(1, T // 10)
    out = x.clone()
    for b in range(B):
        if torch.rand(()) < p and max_len > 0:
            L = int(torch.randint(1, min(max_len, T) + 1, (1,)))
            s = int(torch.randint(0, max(1, T - L + 1), (1,)))
            out[b, :, s:s+L] = 0
    return out


@torch.no_grad()
def scale(x: torch.Tensor, sigma: float = 0.1) -> torch.Tensor:
    """
    Multiply each (sample, channel) series by its own factor drawn from N(1, sigma^2).

    Args:
        x: (B, C, T)
    Returns:
        Tensor of the same shape; factors are broadcast along T.
    """
    assert x.dim() == 3
    n_batch, n_chan, _ = x.shape
    draws = torch.randn(n_batch, n_chan, 1, device=x.device, dtype=x.dtype)
    return x * (1.0 + draws * sigma)


@torch.no_grad()
def jitter(x: torch.Tensor, sigma: float = 0.02, eps: float = 1e-6) -> torch.Tensor:
    """
    Add Gaussian noise whose std is sigma times each series' own (population) std.

    Args:
        x: (B, C, T)
        sigma: relative noise level
        eps: added to the std so constant series still get a tiny amount of noise
    """
    assert x.dim() == 3
    series_std = x.std(dim=-1, keepdim=True, unbiased=False)  # (B, C, 1)
    amplitude = sigma * (series_std + eps)
    return x + torch.randn_like(x) * amplitude


@torch.no_grad()
def time_warp_smooth(
    x: torch.Tensor, 
    max_warp: float = 0.05, 
    knots: int = 4
) -> torch.Tensor:
    """
    Smooth time warping using grid_sample; same warp applied to all channels of a sample.

    Random displacements at ``knots`` control points are linearly upsampled to T
    and added to the identity sampling grid. The sampling positions are clamped to
    [-1, 1] (border padding) and read back with linear interpolation.

    Args:
        x: (B, C, T), floating point
        max_warp: maximum displacement in normalized grid units (the full series
            spans [-1, 1], so this is about max_warp * (T - 1) / 2 samples)
        knots: number of control points to define a smooth displacement curve
    Returns:
        (B, C, T); x itself if max_warp <= 0 or knots < 2.
    """
    assert x.dim() == 3
    if max_warp <= 0 or knots < 2:
        return x
    B, C, T = x.shape
    dev, dt = x.device, x.dtype

    # Identity sampling positions in normalized [-1, 1] coordinates.
    base = torch.linspace(-1, 1, T, device=dev, dtype=dt).view(1, 1, T)
    # Random displacement on a coarse grid per sample. F.interpolate's "linear"
    # mode only accepts 3-D (N, C, L) input; the old 4-D tensor raised.
    disp_ctrl = torch.empty(B, 1, knots, device=dev, dtype=dt).uniform_(-max_warp, max_warp)
    disp = F.interpolate(disp_ctrl, size=T, mode="linear", align_corners=True)  # (B,1,T)
    grid_t = (base + disp).clamp(-1, 1)  # (B,1,T)

    # View each series as a 1-pixel-high image (B, C, 1, T) so time is the width
    # axis. grid_sample reads grid[..., 0] as the width (x) coordinate and
    # grid[..., 1] as the height (y) coordinate. The old layout put time on the
    # height axis while warping x, which sampled a single time index.
    grid = torch.stack([grid_t, torch.zeros_like(grid_t)], dim=-1)  # (B,1,T,2)
    out = F.grid_sample(x.unsqueeze(2), grid, mode='bilinear', padding_mode='border', align_corners=True)
    return out.squeeze(2)  # (B,C,T)


@torch.no_grad()
def permute_segments(x: torch.Tensor, K: int = 4, p: float = 0.8) -> torch.Tensor:
    """
    Cut the time axis into K segments and, with probability p per sample, shuffle them.

    The first K-1 segments have length T // K; the last one takes the remainder.
    All channels of a sample use the same permutation.

    Args:
        x: (B, C, T)
    Returns:
        (B, C, T); x itself when K <= 1.
    """
    assert x.dim() == 3
    if K <= 1:
        return x
    B, C, T = x.shape
    seg = T // K
    sizes = [seg] * (K - 1)
    sizes.append(T - seg * (K - 1))
    result = []
    for b in range(B):
        sample = x[b]
        if torch.rand(()) < p:
            pieces = torch.split(sample, sizes, dim=-1)
            perm = torch.randperm(K, device=x.device).tolist()
            result.append(torch.cat([pieces[j] for j in perm], dim=-1))
        else:
            result.append(sample)
    return torch.stack(result, dim=0)


@torch.no_grad()
def bandpass_perturb(
    x: torch.Tensor, 
    p: float = 0.5, 
    delta_hz: float = 0.6, 
    fs: float = 128.0,
    bumps: int = 2,
    amp: float = 0.05
) -> torch.Tensor:
    """
    Lightly perturb the frequency response by scaling the rFFT with a smooth random envelope.

    For each selected sample, the envelope is a product of ``max(1, bumps)``
    factors ``1 + u * amp * gauss(f; f0, delta_hz)``. Here f0 is a random rFFT
    bin frequency and u ~ U(-1, 1). The envelope is then clamped to
    [1 - amp, 1 + amp]. All channels of a sample share the envelope.

    Args:
        x: (B, C, T)
        p: probability to apply per sample
        delta_hz: width of gaussian bumps in Hz (floored at 1e-6)
        fs: sampling rate
        bumps: number of random bumps in the envelope
        amp: max amplitude of envelope deviations (±amp)
    """
    assert x.dim() == 3
    B, C, T = x.shape
    dev, dt = x.device, x.dtype
    freqs = torch.fft.rfftfreq(T, d=1.0 / fs).to(dev)
    n_freqs = len(freqs)
    width = max(1e-6, delta_hz)
    n_bumps = max(1, bumps)
    perturbed = []
    for b in range(B):
        xb = x[b]
        if not (torch.rand(()) < p):
            perturbed.append(xb)
            continue
        spectrum = torch.fft.rfft(xb, dim=-1)
        env = torch.ones_like(freqs)
        for _ in range(n_bumps):
            center = freqs[torch.randint(0, n_freqs, (1,))].item()
            bump = torch.exp(-0.5 * ((freqs - center) / width) ** 2)
            gain = torch.rand((), device=dev) * 2 - 1
            env = env * (1.0 + gain * amp * bump)
        env = env.clamp(1.0 - amp, 1.0 + amp)
        spectrum = spectrum * env.view(1, -1)  # (C, F)
        perturbed.append(torch.fft.irfft(spectrum, n=T, dim=-1).to(dtype=dt))
    return torch.stack(perturbed, dim=0)


@torch.no_grad()
def spec_time_freq_mask(
    x: torch.Tensor,
    fs: float = 128.0,
    win: int = 256,
    hop: int = 128,
    time_mask_p: float = 0.6,
    freq_mask_p: float = 0.6,
    t_max: int = 8,
    f_max: int = 6
) -> torch.Tensor:
    """
    SpecAugment-style masking on STFT magnitude, then iSTFT back to time domain.

    With probability ``time_mask_p`` one band of up to ``t_max`` STFT frames is
    zeroed. With probability ``freq_mask_p`` one band of up to ``f_max`` frequency
    bins is zeroed. The same bands are used for every sample and channel in the
    batch, and the phase is kept. The STFT uses a rectangular (all-ones) window of
    length ``win``, which is torch's default when no window is passed; it is made
    explicit here and passed to both stft and istft.

    Args:
        x: (B, C, T) floating tensor; need not be contiguous. With center=True
            (reflect padding) T must exceed win // 2.
        fs: currently unused.
    Returns:
        (B, C, T)
    """
    assert x.dim() == 3
    B, C, T = x.shape
    window = torch.ones(win, device=x.device, dtype=x.dtype)
    X = torch.stft(
        x.reshape(B * C, T),  # reshape (not view) also accepts non-contiguous x
        n_fft=win, hop_length=hop, win_length=win, window=window,
        return_complex=True, center=True
    )  # (B*C, F, TT)
    mag, ph = X.abs(), X.angle()
    # time mask
    if torch.rand(()) < time_mask_p:
        t_len = mag.shape[-1]
        L = int(torch.randint(1, min(t_max, t_len) + 1, (1,)))
        s = int(torch.randint(0, max(1, t_len - L + 1), (1,)))
        mag[..., s:s+L] = 0
    # freq mask
    if torch.rand(()) < freq_mask_p:
        f_len = mag.shape[-2]
        L = int(torch.randint(1, min(f_max, f_len) + 1, (1,)))
        s = int(torch.randint(0, max(1, f_len - L + 1), (1,)))
        mag[..., s:s+L, :] = 0
    X2 = torch.polar(mag, ph)
    y = torch.istft(
        X2, n_fft=win, hop_length=hop, win_length=win, window=window, length=T, center=True
    )
    return y.reshape(B, C, T)


@torch.no_grad()
def channel_dropout(
    x: torch.Tensor, 
    drop_prob: float = 0.2, 
    min_keep: int = 1
) -> torch.Tensor:
    """
    Zero whole channels independently (per sample, per channel) with probability drop_prob.

    If a sample ends up with fewer than ``min_keep`` active channels, ``min_keep``
    randomly chosen channels of that sample are switched back on. Those may
    overlap channels already kept, so at least ``min_keep`` remain.

    Args:
        x: (B, C, T)
    """
    assert x.dim() == 3
    B, C, T = x.shape
    keep_mask = (torch.rand(B, C, 1, device=x.device, dtype=x.dtype) > drop_prob).to(x.dtype)
    n_kept = keep_mask.sum(dim=(1, 2))  # (B,)
    for b in torch.where(n_kept < min_keep)[0].tolist():
        restore = torch.randperm(C, device=x.device)[:min_keep]
        keep_mask[b, restore, 0] = 1.0
    return x * keep_mask
# ===== ContraWR-style augmentations (time-domain; FFT helpers) =====
import torch
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional

def _rfft_band_mask(T: int, fs: float, lo: float, hi: float, device) -> torch.Tensor:
    """Return a float32 [T // 2 + 1] mask over rfft bins: 1 where lo <= f <= hi (Hz), else 0."""
    bin_freqs = torch.fft.rfftfreq(T, d=1.0/fs, device=device)
    in_band = torch.logical_and(bin_freqs >= lo, bin_freqs <= hi)
    return in_band.to(torch.float32)

@torch.no_grad()
def bandpass_fft(
    x: torch.Tensor,                   # (B,C,T)
    fs: float,
    per_channel_bands: Optional[Dict[int, Tuple[float, float]]] = None,
    default_band: Optional[Tuple[float, float]] = None,
    jitter_pct: float = 0.0
) -> torch.Tensor:
    """
    Brick-wall FFT bandpass: per channel, keep rfft bins with lo <= f <= hi and zero the rest.

    per_channel_bands: {c_idx: (lo, hi)}; channels not in the dict use default_band,
        and are left untouched if default_band is None.
    jitter_pct: if > 0, each channel's band edges are independently scaled by
        U(1 - jitter_pct, 1 + jitter_pct); the edges are re-ordered if they cross and
        lo is floored at 0.
    Returns a tensor with x's shape and dtype.
    """
    B, C, T = x.shape
    device = x.device
    spectrum = torch.fft.rfft(x, dim=-1)          # (B,C,F)
    bands = per_channel_bands or {}
    for c in range(C):
        if c in bands:
            lo, hi = bands[c]
        elif default_band is not None:
            lo, hi = default_band
        else:
            continue
        if jitter_pct > 0:
            factors = 1.0 + (torch.rand(2, device=device) * 2 - 1.0) * jitter_pct
            lo, hi = sorted((lo * float(factors[0]), hi * float(factors[1])))
            lo = max(lo, 0.0)
        band_mask = _rfft_band_mask(T, fs, lo, hi, device)  # (F,)
        spectrum[:, c, :] = spectrum[:, c, :] * band_mask
    return torch.fft.irfft(spectrum, n=T, dim=-1).to(x.dtype)

@torch.no_grad()
def colored_noise(
    x: torch.Tensor,        # (B,C,T)
    fs: float,
    kind: str = "high",     # "high" or "low"
    cutoff_hz: float = 15.0,
    snr_scale: float = 0.05 # noise amplitude = snr_scale * per-channel std
) -> torch.Tensor:
    """
    Add band-limited Gaussian noise, scaled per (sample, channel) by the signal's std.

    White noise is brick-wall filtered in the rFFT domain: ``kind="high"`` keeps bins
    with f >= cutoff_hz and ``kind="low"`` keeps f <= cutoff_hz. The result is scaled by
    ``snr_scale * std(x)`` (std over time, floored at 1e-6).

    Raises:
        ValueError: if ``kind`` is neither "high" nor "low". Previously any other
            string silently behaved as "low".
    """
    if kind not in ("high", "low"):
        raise ValueError(f"kind must be 'high' or 'low', got {kind!r}")
    B, C, T = x.shape
    device = x.device
    # make white noise then filter in freq domain
    eps = torch.randn_like(x)
    E = torch.fft.rfft(eps, dim=-1)                # (B,C,F)
    freqs = torch.fft.rfftfreq(T, d=1.0/fs, device=device) # (F,)
    if kind == "high":
        mask = (freqs >= cutoff_hz).to(torch.float32)
    else:
        mask = (freqs <= cutoff_hz).to(torch.float32)
    E = E * mask                                   # (broadcast on last dim)
    n = torch.fft.irfft(E, n=T, dim=-1)
    # scale per-channel
    std = x.float().std(dim=-1, keepdim=True).clamp_min(1e-6)
    y = x + n * (snr_scale * std)
    return y.to(x.dtype)

@torch.no_grad()
def lr_channel_swap(
    x: torch.Tensor,                  # (B,C,T)
    pairs: List[Tuple[int, int]],     # e.g., [(c3_idx,c4_idx), (f3_idx,f4_idx)]
    p: float = 0.5
) -> torch.Tensor:
    """
    With probability p, swap every (left, right) channel pair for the whole batch.

    One coin flip decides for the entire batch. Values are always read from the
    unmodified input ``x``. Returns ``x`` itself when not applied, otherwise a copy.
    """
    coin = torch.rand(1).item()
    if coin > p:
        return x
    swapped = x.clone()
    for left, right in pairs:
        swapped[:, left, :] = x[:, right, :]
        swapped[:, right, :] = x[:, left, :]
    return swapped

@torch.no_grad()
def time_shift_roll(
    x: torch.Tensor,        # (B,C,T)
    max_shift_samples: int  # e.g., int(0.5 * fs)
) -> torch.Tensor:
    """
    Circularly shift each sample along time by a random integer in [-max, max].

    All channels of a sample share the shift; output[..., t] = x[..., (t - shift) % T].
    Returns ``x`` unchanged when max_shift_samples <= 0.
    """
    if max_shift_samples <= 0:
        return x
    B, C, T = x.shape
    shifts = torch.randint(low=-max_shift_samples, high=max_shift_samples+1, size=(B,), device=x.device)
    positions = torch.arange(T, device=x.device)
    src = torch.remainder(positions[None, :] - shifts[:, None], T)  # (B, T)
    return torch.gather(x, -1, src.unsqueeze(1).expand(B, C, T))
# ==== MAE-style 2D token masking over (lead x time) for (B,C,T) ====
import torch
from einops import rearrange

@torch.no_grad()
def patchify2d(x: torch.Tensor, pL: int, pT: int) -> torch.Tensor:
    """
    Split [B, C, T] into non-overlapping (pL leads x pT timesteps) patches.

    Returns tokens [B, N, P] with N = (C/pL)*(T/pT) and P = pL*pT. Patches are
    ordered row-major over (lead block, time block), and each patch is flattened
    row-major over (lead, time).
    """
    B, C, T = x.shape
    assert C % pL == 0 and T % pT == 0, f"shape {(C,T)} not divisible by {(pL,pT)}"
    n_lead_blocks, n_time_blocks = C // pL, T // pT
    # [B, h, pL, w, pT] -> [B, h, w, pL, pT] -> [B, h*w, pL*pT]
    blocks = x.reshape(B, n_lead_blocks, pL, n_time_blocks, pT).permute(0, 1, 3, 2, 4)
    return blocks.reshape(B, n_lead_blocks * n_time_blocks, pL * pT)

@torch.no_grad()
def unpatchify2d(tokens: torch.Tensor, pL: int, pT: int, C: int, T: int) -> torch.Tensor:
    """
    Inverse of the patch layout used by ``patchify2d``.

    tokens: [B, N, pL*pT] with N = (C // pL) * (T // pT); returns [B, C, T].
    """
    lead_blocks = C // pL
    time_blocks = T // pT
    return rearrange(
        tokens,
        'b (h w) (pL pT) -> b (h pL) (w pT)',
        h=lead_blocks, w=time_blocks, pL=pL, pT=pT,
    )

@torch.no_grad()
def random_token_mask2d(
    x: torch.Tensor,                # [B,C,T]
    pL: int, pT: int,
    mask_ratio: float = 0.5,        # fraction of patches to mask
    mode: str = "zero",             # {"zero","noise","mean"}
    noise_scale: float = 0.0,       # noise std multiplier (used when mode="noise")
    same_mask_for_batch: bool = False
) -> torch.Tensor:
    """
    MAE-like masking: replace round(mask_ratio * N) random (pL x pT) patches, then fold
    back to [B, C, T].

    Each sample gets its own random subset of patches unless same_mask_for_batch=True.
    Replacement by mode:
        "zero":  masked patches become 0.
        "mean":  masked patches become the mean over all patches of that sample
                 (one [P] vector per sample).
        "noise": masked patches become Gaussian noise with std
                 noise_scale * (std over time, averaged over channels), i.e. one
                 scalar per sample. With the default noise_scale=0.0 this equals "zero".
    Returns x itself when no patch would be masked. An unknown mode raises ValueError,
    but only once at least one patch is to be masked.
    """
    B, C, T = x.shape
    tokens = patchify2d(x, pL, pT)          # [B, N, P]
    _, N, P = tokens.shape
    n_masked = int(round(mask_ratio * N))
    if n_masked <= 0:
        return x

    dev = x.device
    if same_mask_for_batch:
        chosen = torch.randperm(N, device=dev)[:n_masked]
        shared = torch.zeros(N, device=dev, dtype=torch.bool)
        shared[chosen] = True
        mask = shared.unsqueeze(0).expand(B, -1)     # [B,N]
    else:
        ranks = torch.argsort(torch.rand(B, N, device=dev), dim=1)
        mask = torch.zeros(B, N, device=dev, dtype=torch.bool).scatter_(1, ranks[:, :n_masked], True)  # [B,N]

    out_tokens = tokens.clone()
    if mode == "zero":
        out_tokens[mask] = 0
    elif mode == "mean":
        sample_mean = tokens.mean(dim=1, keepdim=True)               # [B,1,P]
        out_tokens[mask] = sample_mean.expand_as(tokens)[mask]
    elif mode == "noise":
        # std over time per channel, then averaged over channels -> [B,1,1]
        sample_std = x.float().std(dim=-1, keepdim=True).mean(dim=1, keepdim=True).clamp_min(1e-6)
        noise = torch.randn_like(tokens) * (noise_scale * sample_std)
        out_tokens[mask] = noise[mask]
    else:
        raise ValueError(f"unknown mode {mode}")

    return unpatchify2d(out_tokens, pL, pT, C, T).to(x.dtype)
def random_token_mask2d_fill(
    x: torch.Tensor,                # [B,C,T]
    pL: int, pT: int,
    mask_ratio: float,              # 0~1
    fill: str | torch.Tensor,       # "zero" | "mean" | learnable tensor of shape [P]
    same_mask_for_batch: bool = False,
) -> torch.Tensor:

    B, C, T = x.shape
    toks = patchify2d(x, pL, pT)                 # [B,N,P]
    B, N, P = toks.shape
    K = int(round(mask_ratio * N))
    if K <= 0:  
        return x


    if same_mask_for_batch:
        idx = torch.randperm(N, device=x.device)[:K]
        mask = torch.zeros(N, device=x.device, dtype=torch.bool)
        mask[idx] = True
        mask = mask.unsqueeze(0).expand(B, -1)   # [B,N]
    else:
        idx = torch.argsort(torch.rand(B, N, device=x.device), dim=1)[:, :K]
        mask = torch.zeros(B, N, device=x.device, dtype=torch.bool).scatter_(1, idx, True)

    if isinstance(fill, torch.Tensor):
        assert fill.numel() == P, f"mask token dim {fill.numel()} != P={P}"
        fill_tok = fill.view(1, 1, P).expand(B, N, P)
    elif fill == "zero":
        fill_tok = torch.zeros_like(toks)
    elif fill == "mean":
        fill_tok = toks.mean(dim=1, keepdim=True).expand(B, N, P)
    else:
        raise ValueError(f"Unknown fill={fill}")


    mask3 = mask.unsqueeze(-1)                   # [B,N,1]
    toks = torch.where(mask3, fill_tok, toks)    

    y = unpatchify2d(toks, pL, pT, C, T)         
    return y
    