import matplotlib.pyplot as plt
import numpy as np


def compute_sampling_order(height=16, width=16, step=4):
    """Compute the two-pass sampling order of a ``height`` x ``width`` grid.

    Cells are numbered row-major (``i * width + j``). The first (sparse)
    pass visits every ``step``-th flat index starting at 0. The second pass
    visits all remaining cells in ascending index order.

    Args:
        height: Number of grid rows (must be positive).
        width: Number of grid columns (must be positive).
        step: Stride of the sparse pass over the flattened indices (>= 1).

    Returns:
        A pair ``(order_grid, sparse_mask)`` of ``(height, width)`` arrays.
        ``order_grid[i, j]`` is the 0-based position of cell ``(i, j)`` in the
        combined visiting order. ``sparse_mask[i, j]`` is True for cells
        visited in the sparse pass.

    Raises:
        ValueError: If ``height`` or ``width`` is not positive, or ``step`` < 1.
    """
    if height <= 0 or width <= 0:
        raise ValueError(f"height and width must be positive, got {height}x{width}")
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")

    n_cells = height * width
    flat_indices = np.arange(n_cells)

    # A boolean mask replaces list-membership tests, which cost O(n) each
    # and made the original O(n^2) overall.
    sparse_flat = np.zeros(n_cells, dtype=bool)
    sparse_flat[::step] = True

    # Sparse cells first, then the rest; each subsequence stays ascending.
    visit_order = np.concatenate([flat_indices[sparse_flat], flat_indices[~sparse_flat]])

    # Invert the permutation: rank[cell] = position of that cell in visit_order.
    rank = np.empty(n_cells, dtype=int)
    rank[visit_order] = flat_indices

    return rank.reshape(height, width), sparse_flat.reshape(height, width)


def visualize_sampling_order(height=16, width=16, step=4):
    """Plot each grid cell labelled with its position in the sampling order.

    Cells visited in the sparse first pass are shaded red and the remaining
    cells blue. The order comes from ``compute_sampling_order``. The defaults
    reproduce a 16x16 grid with stride 4.

    Args:
        height: Number of grid rows.
        width: Number of grid columns.
        step: Stride of the sparse first pass over the flattened indices.

    Returns:
        The ``(fig, ax)`` pair of the drawn figure.
    """
    order_grid, sparse_mask = compute_sampling_order(height, width, step)
    n_sparse = int(sparse_mask.sum())
    n_remaining = sparse_mask.size - n_sparse

    fig, ax = plt.subplots(figsize=(10, 10))
    for i in range(height):
        for j in range(width):
            cell_color = 'red' if sparse_mask[i, j] else 'blue'
            ax.fill_between([j - 0.5, j + 0.5], i - 0.5, i + 0.5,
                            color=cell_color, alpha=0.5)
            ax.text(j, i, int(order_grid[i, j]), ha='center', va='center', color='white')

    # Extents follow the grid size rather than a hard-coded 16x16.
    ax.set_xlim(-0.5, width - 0.5)
    # Reversed y-limits put row 0 at the top, matching array layout.
    ax.set_ylim(height - 0.5, -0.5)

    # Major ticks label cell centres. Grid lines sit on minor ticks at the
    # cell borders so they outline cells instead of crossing the numbers.
    ax.set_xticks(np.arange(width))
    ax.set_yticks(np.arange(height))
    ax.set_xticks(np.arange(-0.5, width, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, height, 1), minor=True)
    ax.tick_params(which='minor', length=0)
    ax.grid(which='minor', color='black', linestyle='-', linewidth=0.5)

    # Counts are derived from the actual split (stride 4 over 256 cells is
    # 64 sparse / 192 remaining, not 128 / 128).
    ax.set_title('Sampling Order Visualization\n'
                 f'(Red: First {n_sparse} sparse samples, '
                 f'Blue: Last {n_remaining} remaining samples)')
    ax.set_aspect('equal')
    plt.show()
    return fig, ax


# Draw the plot only when run as a script, not when this module is imported.
if __name__ == "__main__":
    visualize_sampling_order()