import time
import torch
import matplotlib.pyplot as plt

def create_image(features, positions, height, width):
    """Rasterise per-token features into a dense ``(height, width, 3)`` image.

    Args:
        features: tensor whose last dimension is the value written into a
            pixel, e.g. ``(N, 3)`` or ``(1, N, 3)``. Leading dimensions are
            flattened into a single token axis.
        positions: integer tensor whose last dimension holds the two pixel
            indices of each token, e.g. ``(N, 2)`` or ``(1, N, 2)``.
        height: size of the first axis of the output image.
        width: size of the second axis of the output image.

    Returns:
        A ``(height, width, 3)`` tensor that is zero wherever no token was
        written. When several tokens share a position the last one wins.

    Note:
        ``positions[..., 0]`` indexes the first image axis and
        ``positions[..., 1]`` the second.
        NOTE(review): the upsampling functions in this file treat positions
        as ``[x, y]``; fed with those, the picture comes out transposed --
        confirm that this is intended.
    """
    im = torch.zeros(height, width, 3)
    # Flatten instead of squeeze(0): squeeze would also collapse the token
    # axis of a single-token input such as (1, 3) and break the iteration.
    flat_feats = features.reshape(-1, features.shape[-1])
    flat_pos = positions.reshape(-1, positions.shape[-1])
    # Sequential writes keep "last token wins" deterministic for duplicates.
    for f, p in zip(flat_feats, flat_pos):
        im[p[0], p[1]] = f
    return im


def hierarchical_upsample_ordered(features, positions, tokens_per_scale, input_shape):
    """Expand multi-scale tokens to pixels, letting finer scales claim pixels first.

    Scales are processed from finest to coarsest. A token at scale ``s``
    covers a ``ps x ps`` patch with ``ps = 2 ** (n_scales - s - 1)`` whose
    top-left corner is the token position. A token is dropped as a whole if
    any pixel of its patch was already claimed by a finer scale; tokens of the
    same scale are never checked against each other.

    Args:
        features: ``(B, N, C)`` token features.
        positions: ``(B, N, 2)`` top-left ``[x, y]`` of each token's patch.
            Values are converted to ``long`` after the patch offsets are
            added. No clamping is done, so patches are expected to lie
            inside the image.
        tokens_per_scale: number of tokens per scale, ordered coarse -> fine;
            tokens are stored in ``features`` / ``positions`` in that order.
        input_shape: ``(H, W)`` of the pixel grid.

    Returns:
        ``(feats, pos)`` with shapes ``(B, M, C)`` and ``(B, M, 2)`` (``long``):
        one entry per pixel of every kept token, ordered finest scale first.
        ``M`` is 0 when no token is kept.

    Raises:
        ValueError: if, at some scale, the batch elements keep different
            numbers of tokens, because the result could not be stacked into
            a dense ``(B, M, C)`` tensor without mixing batch elements.
    """
    B, _, C = features.shape
    device = features.device
    H, W = input_shape
    n_scales = len(tokens_per_scale)

    # visibility[b, y, x] is True once a pixel has been claimed.
    visibility = torch.zeros((B, H, W), dtype=torch.bool, device=device)
    batch_col = torch.arange(B, device=device).view(B, 1)  # broadcasts over pixels

    # (start, end, patch_size) of every scale, in coarse -> fine storage order.
    scale_blocks = []
    start_id = 0
    for s, count in enumerate(tokens_per_scale):
        end_id = start_id + count
        scale_blocks.append((start_id, end_id, 2 ** (n_scales - s - 1)))
        start_id = end_id

    all_feats = []
    all_pos = []
    for start, end, patch_size in reversed(scale_blocks):  # fine -> coarse
        feats_s = features[:, start:end, :]        # (B, Ns, C)
        pos_s = positions[:, start:end, :]         # (B, Ns, 2)
        Ns = pos_s.shape[1]
        if Ns == 0:
            continue

        n_pix = patch_size ** 2
        steps = torch.arange(patch_size, device=device)
        dx, dy = torch.meshgrid(steps, steps, indexing='ij')
        offset = torch.stack([dx, dy], dim=-1).reshape(1, 1, n_pix, 2)

        # Pixel footprint of every token: (B, Ns, n_pix, 2) as [x, y].
        pos_exp = (pos_s.unsqueeze(2) + offset).long()
        xg = pos_exp[..., 0].reshape(B, Ns * n_pix)
        yg = pos_exp[..., 1].reshape(B, Ns * n_pix)

        # A token is claimed if any pixel of its footprint is already taken.
        claimed = visibility[batch_col, yg, xg].view(B, Ns, n_pix).any(dim=2)
        keep = ~claimed                                              # (B, Ns)

        # The output is dense, so every batch element must keep the same
        # number of tokens; otherwise reshaping would mix batch elements.
        counts = keep.sum(dim=1).tolist()
        n_keep = counts[0] if counts else 0
        if any(c != n_keep for c in counts):
            raise ValueError(
                "batch elements keep different numbers of tokens at patch size "
                "{}: {}; cannot build a dense (B, M, C) output".format(patch_size, counts)
            )
        if n_keep == 0:
            continue

        # Boolean-mask indexing enumerates kept tokens batch-major, so the
        # reshape below puts each batch element's tokens in its own row.
        pos_grid = pos_exp[keep].view(B, n_keep * n_pix, 2)
        feat_grid = (
            feats_s[keep]
            .view(B, n_keep, 1, C)
            .expand(B, n_keep, n_pix, C)
            .reshape(B, n_keep * n_pix, C)
        )
        all_feats.append(feat_grid)
        all_pos.append(pos_grid)

        # Mark the newly written pixels so coarser scales skip them.
        visibility[batch_col, pos_grid[..., 1], pos_grid[..., 0]] = True

    if not all_feats:
        # Nothing survived (or there were no tokens): return empty results.
        return features[:, :0, :], positions[:, :0, :].long()
    return torch.cat(all_feats, dim=1), torch.cat(all_pos, dim=1)


def upsample_overwrite_all(features, positions, tokens_per_scale, input_shape, fill_value=0.0):
    """
    Coarse->fine overwrite onto a (B,H,W,C) canvas, then return every grid (x,y)
    with its paired feature.

    Each token at scale ``s`` paints a ``ps x ps`` patch with
    ``ps = 2 ** (S - s - 1)`` starting at its top-left position. Scales are
    painted coarse -> fine, so finer tokens overwrite coarser ones. Pixel
    coordinates are clamped to the image, so pixels of a patch that overhang
    the border are written onto the border row/column instead.

    Args
    ----
    features:  (B, N, C)  float, requires_grad ok
    positions: (B, N, 2)  long/int, top-left (x,y) per token
    tokens_per_scale: list[int], counts per scale, ordered coarse -> fine
    input_shape: (H, W)
    fill_value: float, feature default for pixels never written

    Returns
    -------
    feats_all: (B, H*W, C)  canvas flattened row by row (index = y*W + x)
    pos_all:   (B, H*W, 2)  int64, [x,y] per pixel (x fastest), aligned with
               feats_all; it is an expanded view shared across the batch
    """
    B, _, C = features.shape
    H, W = input_shape
    device = features.device

    # patch sizes: s=0 (coarsest) has largest patch; last scale is 1x1
    S = len(tokens_per_scale)
    patch_sizes = [2 ** (S - s - 1) for s in range(S)]  # e.g. [32,16,8,4,2,1]

    canvas = features.new_full((B, H, W, C), fill_value)
    batch_col = torch.arange(B, device=device).view(B, 1)  # broadcasts over pixels

    start = 0
    for t_count, ps in zip(tokens_per_scale, patch_sizes):  # coarse -> fine
        end = start + t_count
        feats_s = features[:, start:end, :]    # (B, Ns, C)
        pos_s = positions[:, start:end, :]     # (B, Ns, 2)
        start = end
        if t_count == 0:
            continue
        Ns = t_count

        # offsets inside a ps x ps patch
        d = torch.arange(ps, device=device)
        dx, dy = torch.meshgrid(d, d, indexing='ij')
        offset = torch.stack([dx, dy], dim=-1).view(1, 1, ps * ps, 2)

        # absolute pixel coords for each token's pixels (clamped to image)
        pos_pix = pos_s.unsqueeze(2) + offset                          # (B,Ns,ps^2,2)
        x = pos_pix[..., 0].clamp(0, W - 1).to(torch.long).reshape(B, Ns * ps * ps)
        y = pos_pix[..., 1].clamp(0, H - 1).to(torch.long).reshape(B, Ns * ps * ps)

        # one feature copy per covered pixel
        feats_exp = feats_s.unsqueeze(2).expand(B, Ns, ps * ps, C).reshape(B, Ns * ps * ps, C)

        # Single indexed write for the whole batch; later (finer) scales
        # overwrite earlier ones. Differentiable w.r.t. features.
        canvas[batch_col, y, x] = feats_exp

    # Full-grid positions in the same row-major order as the flattened canvas:
    # entry y*W + x holds [x, y].
    yy, xx = torch.meshgrid(torch.arange(H, device=device),
                            torch.arange(W, device=device), indexing='ij')  # (H, W)
    pos_all = torch.stack([xx, yy], dim=-1).reshape(1, H * W, 2).expand(B, -1, -1)

    feats_all = canvas.view(B, H * W, C)
    return feats_all, pos_all


def upsample_select_winner(
    features: torch.Tensor,         # (B, N, C), requires_grad ok
    positions: torch.Tensor,        # (B, N, 2), long, top-left [x,y]
    tokens_per_scale,               # list[int], counts per scale (coarse -> fine)
    input_shape,                    # (H, W)
    fill_value: float = 0.0
):
    """Pick one winning token per pixel and gather its feature.

    Each token at scale ``s`` covers a ``ps x ps`` patch with
    ``ps = 2 ** (S - s - 1)`` starting at its top-left position (pixel
    coordinates are clamped to the image). Among all tokens covering a pixel
    the one from the finest scale wins; ties inside a scale go to the larger
    token index. Pixels covered by no token receive ``fill_value``.

    Returns:
        ``(feats_all, pos_all)``: ``feats_all`` is ``(B, H*W, C)`` in
        row-major pixel order (index ``y*W + x``); ``pos_all`` is the matching
        ``(B, H*W, 2)`` int64 grid of ``[x, y]`` (an expanded view shared
        across the batch).

    Raises:
        AssertionError: if ``tokens_per_scale`` does not sum to ``N``.
    """
    B, N, C = features.shape
    H, W = input_shape
    device = features.device
    assert sum(tokens_per_scale) == N, "tokens_per_scale must be COUNTS and sum to N"

    # Full grid positions (x,y), x fastest
    yy, xx = torch.meshgrid(torch.arange(H, device=device),
                            torch.arange(W, device=device), indexing='ij')
    pos_all = torch.stack([xx, yy], dim=-1).view(1, H * W, 2).expand(B, -1, -1)

    S = len(tokens_per_scale)
    # Keys pack (scale priority, token id) as priority * N_plus + token_id so
    # a single max per pixel selects the finest scale, then the largest id.
    N_plus = N + 1

    all_flat_idx = []
    all_keys = []
    st = 0
    for scale_idx, count in enumerate(tokens_per_scale):
        ed = st + count
        if count > 0:
            # Patch size from scale index (coarse largest, fine 1)
            ps = 2 ** (S - scale_idx - 1)
            n_cand = count * ps * ps

            pos_s = positions[:, st:ed, :]                               # (B, Ns, 2)
            d = torch.arange(ps, device=device)
            off = torch.stack(torch.meshgrid(d, d, indexing='ij'), dim=-1)  # (ps, ps, 2)

            # Absolute pixel coords: (B, Ns, ps, ps, 2), clamped to the image
            pos_pix = pos_s[:, :, None, None, :] + off[None, None, :, :, :]
            x = pos_pix[..., 0].clamp(0, W - 1).to(torch.long).reshape(B, n_cand)
            y = pos_pix[..., 1].clamp(0, H - 1).to(torch.long).reshape(B, n_cand)

            # Global token id of every candidate pixel; the scale index is
            # the priority (coarse = 0, finest = S - 1).
            tok_ids = torch.arange(st, ed, device=device).view(1, count, 1, 1) \
                        .expand(B, count, ps, ps).reshape(B, n_cand)

            all_flat_idx.append(y * W + x)
            all_keys.append(scale_idx * N_plus + tok_ids)
        st = ed

    if not all_flat_idx:
        # No tokens at all: every pixel is uncovered.
        return features.new_full((B, H * W, C), fill_value), pos_all

    idx = torch.cat(all_flat_idx, dim=1)                    # (B, P)
    key = torch.cat(all_keys, dim=1).to(torch.long)         # (B, P)

    # Pick winner per pixel via amax on packed key; -1 marks "no candidate".
    out_keys = torch.full((B, H * W), -1, dtype=torch.long, device=device)
    out_keys.scatter_reduce_(1, idx, key, reduce='amax', include_self=True)

    valid = out_keys >= 0
    # Clamp BEFORE the modulo: -1 % N_plus equals N (the result takes the sign
    # of the divisor), which would be an out-of-range gather index.
    winner_tok = out_keys.clamp_min(0) % N_plus                      # (B, H*W)

    # Gather features for winners; gradient flows to those tokens
    gather_idx = winner_tok.unsqueeze(-1).expand(-1, -1, C)          # (B, H*W, C)
    feats_all = torch.gather(features, 1, gather_idx)                # (B, H*W, C)
    # Uncovered pixels gathered token 0 as a placeholder; always replace them,
    # including when fill_value is 0.
    feats_all = feats_all.masked_fill(~valid.unsqueeze(-1), fill_value)

    return feats_all, pos_all


# ---------------------------------------------------------------------------
# Demo input: 36 tokens spread over five scales, stored coarse -> fine.
# With five scales the upsampling functions above use patch sizes
# 16, 8, 4, 2 and 1. All tokens of a scale share one colour so the scales can
# be told apart in the plot.
# ---------------------------------------------------------------------------
input_shape = (256,256)
tokens_per_scale = [4, 8, 16, 4, 4]

# One entry per scale: (colour, top-left position of every token in the scale).
_demo_scales = (
    # scale 0 -- 4 tokens, 16x16 patches
    ((0, 255, 255), ((0, 0), (0, 16), (16, 0), (16, 16))),
    # scale 1 -- 8 tokens, 8x8 patches
    ((255, 255, 0), ((0, 0), (0, 8), (8, 0), (8, 8),
                     (16, 16), (16, 24), (24, 16), (24, 24))),
    # scale 2 -- 16 tokens, 4x4 patches
    ((255, 0, 255), ((0, 0), (0, 4), (4, 0), (4, 4),
                     (0, 8), (0, 12), (4, 8), (4, 12),
                     (16, 16), (16, 20), (20, 16), (20, 20),
                     (24, 24), (24, 28), (28, 24), (28, 28))),
    # scale 3 -- 4 tokens, 2x2 patches
    ((255, 0, 0), ((20, 20), (20, 22), (22, 20), (22, 22))),
    # scale 4 -- 4 tokens, 1x1 patches
    ((0, 0, 255), ((22, 20), (22, 21), (23, 20), (23, 21))),
)

# Flatten the table into one row per token and copy it into the tensors.
features = torch.zeros(1, 36, 3)
features[0] = torch.tensor([rgb for rgb, coords in _demo_scales for _ in coords])
positions = torch.zeros(1,36, 2).long()
positions[0] = torch.tensor([xy for _, coords in _demo_scales for xy in coords])

# Duplicate the sample so the functions run on a batch of two.
features = features.repeat((2, 1, 1))
positions = positions.repeat((2, 1, 1))

# Larger random configuration, kept for manual benchmarking:
#features = torch.randn(512, 8000, 3)
#pos = torch.zeros(512, 8000, 2)
#tokens_per_scale = (256, 1024, 2048, 4672)
#input_shape = (512,512)

# Time the ordered upsampling and show the first batch element.
start = time.time()
feats, poss = hierarchical_upsample_ordered(features, positions, tokens_per_scale, input_shape)
print("Time for slow version: {}".format(time.time() - start))
plt.imshow(create_image(feats[0], poss[0], input_shape[0], input_shape[1]))
plt.show()
'''
start = time.time()
feats, poss = upsample_overwrite_all(features, positions, tokens_per_scale, input_shape)
print("Time for fast version: {}".format(time.time() - start))
plt.imshow(create_image(feats[0], poss[0], input_shape[0], input_shape[1]))
plt.show()
start = time.time()
feats, poss = upsample_select_winner(features, positions, tokens_per_scale, input_shape)
print("Time for fast version 2: {}".format(time.time() - start))
plt.imshow(create_image(feats[0], poss[0], input_shape[0], input_shape[1]))
plt.show()
'''
