import torch
import torch.nn.functional as F

def compute_saliency_only(x_patches):
    """Score every patch by how much softmax affinity it receives.

    Args:
        x_patches: 3-D tensor indexed as (batch, patch, feature); the
            ``permute(0, 2, 1)`` below requires exactly three dimensions.

    Returns:
        Tensor of shape (batch, patch). Within each batch row the scores
        are min-max rescaled, so the lowest-scoring patch maps to 0 and
        the highest to (just under or exactly) 1.
    """
    # Pairwise dot products between patches: (batch, patch, patch).
    similarity = x_patches @ x_patches.permute(0, 2, 1)
    # Each row i becomes a distribution over the patches j it attends to.
    attention = similarity.softmax(dim=2)
    # Summing over i gives the total attention each patch j receives.
    received = attention.sum(dim=1)

    lowest = received.min(dim=1, keepdim=True).values
    highest = received.max(dim=1, keepdim=True).values
    # The epsilon keeps the division finite when all patches tie.
    return (received - lowest) / (highest - lowest + 1e-8)

def calc_dynamic_mask_ratio(y, base_mask_ratio=0.5, L=196, *, max_deviation=0.15):
    """Shift a base mask ratio up or down according to the mean of ``y``.

    The mean of ``y`` (over all elements) is divided by ``L`` and mapped
    linearly so that 0 gives ``base_mask_ratio - max_deviation`` and 1
    gives ``base_mask_ratio + max_deviation``. The result is clamped to
    the valid ratio range [0, 1].

    Args:
        y: Tensor of any shape; it is cast to float and averaged.
        base_mask_ratio: Ratio at the midpoint of the linear mapping.
        L: Divisor applied to the mean of ``y``.
        max_deviation: Largest shift away from ``base_mask_ratio`` before
            clamping. Defaults to the previously hard-coded 0.15.

    Returns:
        0-dim float tensor in [0, 1].
    """
    y_normalized = y.float().mean() / L
    lower_bound = base_mask_ratio - max_deviation
    dynamic_mask_ratio = lower_bound + 2 * max_deviation * y_normalized
    return torch.clamp(dynamic_mask_ratio, 0.0, 1.0)

def apply_multi_channel_masking(x_channels, mask_ratios):
    """Build per-channel saliency masks with no token masked in every channel.

    Each channel masks its highest-saliency tokens. The per-channel budgets
    are then reconciled so that, where possible, every token position stays
    visible in at least one channel.

    Args:
        x_channels: Sequence of C tensors, each indexed (batch, token,
            feature). All must share the same shape because their
            saliencies are stacked.
        mask_ratios: One ratio per channel (number or 0-dim tensor). The
            number of masked tokens for a channel is ``int(L * ratio)``,
            clamped to [0, L].

    Returns:
        Tuple ``(masks, sal)``:
            masks: bool tensor (C, batch, token); True marks a masked token.
            sal: float tensor (C, batch, token) of per-channel saliency.
    """
    C = len(x_channels)
    N, L, D = x_channels[0].shape
    device = x_channels[0].device

    sal = torch.stack([compute_saliency_only(x) for x in x_channels], dim=0)

    # Keep the budgets as plain Python ints: torch.topk wants an integer k,
    # and integer arithmetic makes the rescaling below an exact floor.
    budgets = [min(L, max(0, int(L * r))) for r in mask_ratios]
    total_cap = (C - 1) * L
    total = sum(budgets)
    if total > total_cap:
        # More masked tokens requested than can be placed while leaving
        # every position visible somewhere; shrink proportionally.
        budgets = [k * total_cap // total for k in budgets]

    masks = torch.zeros((C, N, L), device=device, dtype=torch.bool)
    for c in range(C):
        topk_idx = torch.topk(sal[c], k=budgets[c], dim=1).indices
        masks[c].scatter_(1, topk_idx, True)

    all_masked = masks.all(dim=0)
    if all_masked.any():
        n_idx, l_idx = torch.nonzero(all_masked, as_tuple=True)
        for n, l in zip(n_idx.tolist(), l_idx.tolist()):
            # Release the token in the channel where it matters least.
            min_c = int(torch.argmin(sal[:, n, l]))
            masks[min_c, n, l] = False

            # Pick a replacement so the channel keeps its mask count. A
            # candidate must be currently unmasked in this channel and must
            # not already be masked by all the other channels, otherwise
            # masking it here would create a new fully-masked token. That
            # second condition also rules out the token just released
            # (the other C - 1 channels still mask it), which the highest-
            # saliency choice would otherwise immediately re-select.
            masked_elsewhere = masks[:, n].sum(dim=0)
            eligible = ~masks[min_c, n] & (masked_elsewhere < C - 1)
            if eligible.any():
                scores = sal[min_c, n].masked_fill(~eligible, float("-inf"))
                masks[min_c, n, int(torch.argmax(scores))] = True
            # With no eligible token the channel simply masks one fewer.

    return masks, sal

def adaptive_inter_channel_masking(x, base_mask_ratio=0.5, num_channels=2):
    """Mask a multi-channel token sequence using per-channel saliency.

    The first token of ``x`` is treated as a class token and always kept.
    The remaining tokens are split into ``num_channels`` equal, contiguous
    channel segments of length ``L = (total_len - 1) // num_channels``;
    any trailing tokens beyond ``1 + num_channels * L`` are dropped.

    Args:
        x: Tensor indexed (batch, token, feature).
        base_mask_ratio: Base ratio passed to ``calc_dynamic_mask_ratio``
            for every channel.
        num_channels: Number of channel segments. Defaults to 2, the
            previously hard-coded value.

    Returns:
        Tuple ``(x_masked, mask, ids_restore)``:
            x_masked: class token followed by the kept tokens of each
                channel, concatenated along the token axis.
            mask: float tensor (batch, num_channels * L); 1.0 marks a
                masked token, channels concatenated in order.
            ids_restore: long tensor (batch, num_channels * L); for each
                channel, the inverse of that channel's shuffle order.
                Indices are local to the channel (0 .. L - 1).

    Raises:
        ValueError: If ``num_channels`` is smaller than 1.
    """
    if num_channels < 1:
        raise ValueError("num_channels must be at least 1")

    N, total_len, D = x.shape
    cls_token = x[:, :1, :]
    C = num_channels
    L = (total_len - 1) // C

    x_channels = [x[:, 1 + c * L:1 + (c + 1) * L, :] for c in range(C)]

    # NOTE(review): the ratio is driven by mean saliency divided by L;
    # confirm that this is the intended signal for the dynamic ratio.
    mask_ratios = [
        calc_dynamic_mask_ratio(compute_saliency_only(ch), base_mask_ratio, L)
        for ch in x_channels
    ]

    masks, sal = apply_multi_channel_masking(x_channels, mask_ratios)

    x_masked_list, mask_list, ids_list = [], [], []
    for c in range(C):
        chan_mask = masks[c]
        chan_sal = sal[c]

        # Order tokens so every unmasked token precedes every masked one,
        # with ascending saliency inside each group. Sorting by saliency
        # alone can disagree with the mask (saliency ties, or masks that
        # were adjusted to avoid fully-masked tokens), which would make the
        # kept tokens inconsistent with the returned mask.
        offset = (chan_sal.max() - chan_sal.min()) + 1.0
        sort_key = torch.where(chan_mask, chan_sal + offset, chan_sal)
        ids_shuffle = torch.argsort(sort_key, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Samples may keep different counts; use the smallest so the batch
        # stays rectangular.
        len_keep = int((L - chan_mask.sum(dim=1)).min())
        ids_keep = ids_shuffle[:, :len_keep]
        gather_idx = ids_keep.unsqueeze(-1).expand(-1, -1, D)
        x_masked_list.append(torch.gather(x_channels[c], 1, gather_idx))
        mask_list.append(chan_mask.float())
        ids_list.append(ids_restore)

    x_masked = torch.cat([cls_token] + x_masked_list, dim=1)
    mask = torch.cat(mask_list, dim=1)
    ids_restore = torch.cat(ids_list, dim=1)

    return x_masked, mask, ids_restore
