import numpy as np
import math
import torch
from scipy.spatial import cKDTree

def mono_to_big_batch(color, original_order):
    """
    Batch-convert coordinates on a single-color sub-matrix into coordinates on the big matrix.

    Args:
        color: sub-lattice the coordinates belong to; one of
            "white", "gray", "black", "light_gray", "dark_gray".
        original_order: tensor (or nested sequence, converted to a long tensor) of shape
            (nb_sample, n, 2) with [..., 0] = i' (row) and [..., 1] = j' (column) on the
            sub-matrix. Only the trailing size-2 dimension is indexed, so any leading shape works.

    Returns:
        Tensor with the same leading shape, [..., 0] = i (row) and [..., 1] = j (column) on the
        big matrix.

    Raises:
        ValueError: if ``color`` is not one of the supported names.
    """
    if not isinstance(original_order, torch.Tensor):
        original_order = torch.tensor(original_order, dtype=torch.long)
    i_prime = original_order[..., 0]
    j_prime = original_order[..., 1]

    # Striped colors: big column j = j' * scale + shift, big row i = 4 * i' + offsets[j % 4].
    striped = {
        "light_gray": (2, 0, (1, 0, 3, 2)),
        "dark_gray": (2, 1, (1, 0, 3, 2)),
        "gray": (1, 0, (1, 0, 3, 2)),
        "black": (1, 0, (3, 2, 1, 0)),
    }

    if color in striped:
        scale, shift, offsets = striped[color]
        j = j_prime * scale + shift
        base_map = torch.tensor(offsets, device=original_order.device)
        i = base_map[j % 4] + 4 * i_prime
    elif color == "white":
        j = j_prime
        # Row = 4 * (i' // 2) + 2 * (i' % 2) + (j' % 2); for integers the first two terms sum to 2 * i'.
        i = 2 * i_prime + j_prime % 2
    else:
        raise ValueError(
            f"Invalid color: {color}. Must be one of 'white', 'gray', 'black', "
            f"'light_gray' or 'dark_gray'."
        )
    return torch.stack([i, j], dim=-1)

def generate_full_autoregressive_order(nb_sample=50):
    """
    Return the full autoregressive generation order for a 16 x 16 grid.

    Currently this is exactly ``generate_2D_sample_order((16, 16), nb_sample)``: a 2-D Halton
    ordering of the whole grid, randomly rolled per sample. An earlier multi-stage scheme is
    disabled. That scheme ordered separate color sub-lattices and mapped them back with
    ``mono_to_big_batch``.

    Args:
        nb_sample: number of independent orders to generate.

    Returns:
        Long tensor of shape (nb_sample, 256, 2) of grid coordinates, in generation order.
    """
    grid_size = (16, 16)
    return generate_2D_sample_order(grid_size, nb_sample)

def reorder_samples_vectorized(sample_orders, width):
    """
    Stable-partition each row so that odd-parity cells come first (vectorized).

    Every entry is a flat index into a grid with ``width`` columns, i.e. cell
    (index // width, index % width). Within each row, cells with (row + col) odd move to the
    front. The rest follow, and the relative order inside each group is preserved.

    Args:
        sample_orders: 2-D tensor (bsz, sample_lens) of flat indices.
        width: grid width used to split flat indices into (row, col).

    Returns:
        Tensor of the same shape, dtype and device as ``sample_orders``.
    """
    batch, length = sample_orders.shape

    rows = sample_orders // width
    cols = sample_orders % width
    # Key 0 for odd-parity cells (front), key 1 otherwise (back).
    primary = ((rows + cols) % 2 != 1).long()

    # Tie-break on the original position so the partition keeps relative order.
    tiebreak = torch.arange(length, device=sample_orders.device).unsqueeze(0).expand(batch, -1)
    order = torch.argsort(primary * length + tiebreak, dim=1, stable=True)

    return sample_orders.gather(1, order)

def convert_order(original_order: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """
    Flatten a (bsz, seq_len, 2) coordinate order into (bsz, seq_len) linear indices.

    Each pair becomes ``original_order[..., 1] * w + original_order[..., 0]``. In other words,
    [..., 0] is treated as the column (x) and [..., 1] as the row (y) of a row-major grid with
    ``w`` columns.

    NOTE(review): ``build_halton_mask`` / ``generate_2D_sample_order`` in this file put the row
    in [..., 0] (it is the coordinate scaled by height). For square grids this therefore yields
    the transposed order, which is still a valid permutation. For non-square grids the indices
    can collide or exceed ``h * w - 1``. Confirm which convention callers rely on.

    Args:
        original_order: coordinate tensor of shape (bsz, seq_len, 2).
        w: grid width (number of columns), used as the row stride.
        h: grid height (number of rows); currently unused.

    Returns:
        Linear indices of shape (bsz, seq_len), same dtype and device as ``original_order``.
    """
    flat = original_order[..., 1] * w + original_order[..., 0]
    return flat.to(dtype=original_order.dtype, device=original_order.device)

def mask_by_order(mask_len, order, bsz, seq_len):
    """
    Build a boolean mask that marks the first ``mask_len`` positions listed in each row of ``order``.

    Args:
        mask_len: number of positions to mark per row. Either a Python number or a scalar
            tensor; it is truncated to an int.
        order: (bsz, >= mask_len) tensor of indices into [0, seq_len). Row b lists the positions
            of sample b in order. Indices are cast to int64 for the scatter.
        bsz: batch size.
        seq_len: length of each mask row.

    Returns:
        Bool tensor of shape (bsz, seq_len) on ``order``'s device, True at ``order[b, :mask_len]``.
    """
    n_masked = int(mask_len)
    index = order[:, :n_masked].long()
    # Allocate on order's device instead of hard-coding .cuda(), so this works on CPU
    # and on non-default GPUs (scatter requires index and target on the same device).
    masking = torch.zeros(bsz, seq_len, device=order.device)
    masking = torch.scatter(masking, dim=-1, index=index, src=torch.ones_like(index, dtype=masking.dtype))
    return masking.bool()

def generate_2D_sample_order(input_size, nb_sample=50):
    """
    Generate ``nb_sample`` randomly rolled copies of a 2-D Halton ordering of a grid.

    A single Halton order of all cells is built with ``build_halton_mask``. Each sample is that
    order cyclically rolled along the sequence axis by an independent random shift in
    [0, height*width).

    Args:
        input_size: int for a square grid, or a (height, width) pair.
        nb_sample: number of orders to produce.

    Returns:
        Long tensor of shape (nb_sample, height*width, 2).
    """
    height, width = (input_size, input_size) if isinstance(input_size, int) else input_size
    n_cells = height * width

    base_order = build_halton_mask((height, width))
    shifts = torch.randint(0, n_cells, (nb_sample,))

    # torch.roll(base, s, 0)[p] == base[(p - s) % n]; apply every sample's shift in one gather.
    positions = torch.arange(n_cells)
    gather_idx = (positions.unsqueeze(0) - shifts.unsqueeze(1)) % n_cells
    return base_order[gather_idx].long()

def build_halton_mask(input_size, nb_point=50_000):
    """
    Order every cell of a grid by a 2-D Halton sequence.

    The Halton stream uses base 2 for the row coordinate and base 3 for the column coordinate.
    Each point in [0, 1)^2 is scaled to the grid and floored to a (row, col) cell, and cells are
    kept in order of first appearance. Cells the stream never hits are appended afterwards by
    rejection sampling with ``torch.randint``.

    Args:
        input_size: int for a square grid, or a (height, width) pair.
        nb_point: minimum number of Halton points to generate; at least 10 * height * width are used.

    Returns:
        Long tensor of shape (height * width, 2) with [:, 0] = row and [:, 1] = column. Each
        cell appears exactly once.
    """
    height, width = (input_size, input_size) if isinstance(input_size, int) else input_size
    n_cells = height * width
    n_draw = max(nb_point, n_cells * 10)  # oversample so the stream has a chance to hit every cell

    def radical_inverse_stream(base, count):
        # Incremental Halton generator: updates the fraction num/den in place each step
        # instead of re-expanding the index digits.
        num, den = 0, 1
        values = []
        for _ in range(count):
            gap = den - num
            if gap == 1:
                num, den = 1, den * base
            else:
                q = den // base
                while gap <= q:
                    q //= base
                num = (base + 1) * q - gap
            values.append(num / den)
        return values

    row_f = torch.tensor(radical_inverse_stream(2, n_draw)) * height
    col_f = torch.tensor(radical_inverse_stream(3, n_draw)) * width
    cells = torch.stack([row_f, col_f], dim=1).floor().long()

    in_grid = (cells[:, 0] >= 0) & (cells[:, 0] < height) & (cells[:, 1] >= 0) & (cells[:, 1] < width)
    cells = cells[in_grid]

    # Deduplicate while keeping stream order (np.unique sorts, so re-sort its first-hit indices).
    _, first_hit = np.unique(cells.numpy(), axis=0, return_index=True)
    cells = cells[torch.from_numpy(np.sort(first_hit))]

    shortfall = n_cells - len(cells)
    if shortfall <= 0:
        return cells[:n_cells]

    seen = {tuple(cell.tolist()) for cell in cells}
    extra = []
    while len(extra) < shortfall:
        candidate = (torch.randint(0, height, (1,)).item(), torch.randint(0, width, (1,)).item())
        if candidate not in seen:
            seen.add(candidate)
            extra.append(list(candidate))
    return torch.cat([cells, torch.tensor(extra)])

def build_sobol_mask(input_size, nb_point=50_000):
    """
    Order every cell of a (possibly rectangular) grid by a scrambled 2-D Sobol sequence.

    Points are drawn from ``SobolEngine(dimension=2, scramble=True)`` with no explicit seed.
    Each point is scaled to the grid, floored to a (row, col) cell, and cells are kept in order
    of first appearance. Cells the stream never hits are appended afterwards by rejection
    sampling with ``torch.randint``.

    Args:
        input_size: int for a square grid, or a (height, width) pair.
        nb_point: minimum number of Sobol points to draw; at least 10 * height * width are drawn.

    Returns:
        Long tensor of shape (height * width, 2) with [:, 0] = row and [:, 1] = column. Each
        cell appears exactly once.
    """
    from torch.quasirandom import SobolEngine

    if isinstance(input_size, int):
        height = width = input_size
    else:
        height, width = input_size
    n_cells = height * width
    n_draw = max(nb_point, n_cells * 10)

    # Scrambled Sobol points in [0, 1)^2, mapped onto integer grid cells.
    unit_pts = SobolEngine(dimension=2, scramble=True).draw(n_draw)
    rows = (unit_pts[:, 0] * height).floor().long()
    cols = (unit_pts[:, 1] * width).floor().long()
    keep = (rows >= 0) & (rows < height) & (cols >= 0) & (cols < width)
    cells = torch.stack((rows[keep], cols[keep]), dim=1)

    # np.unique sorts rows; re-sorting the first-occurrence indices restores stream order.
    _, first_seen = np.unique(cells.numpy(), axis=0, return_index=True)
    cells = cells[torch.from_numpy(np.sort(first_seen))]

    missing = n_cells - cells.shape[0]
    if missing <= 0:
        return cells[:n_cells]

    taken = {tuple(c.tolist()) for c in cells}
    fill = []
    while len(fill) < missing:
        r = torch.randint(0, height, (1,)).item()
        c = torch.randint(0, width, (1,)).item()
        if (r, c) in taken:
            continue
        taken.add((r, c))
        fill.append([r, c])
    return torch.cat([cells, torch.tensor(fill)])

def build_kronecker_mask(input_size, nb_point=50_000):
    """
    Order every cell of a grid by a 2-D Kronecker (additive-recurrence) sequence.

    Point n (n = 1, 2, ...) is ({n * a1}, {n * a2}), where {.} is the fractional part and
    (a1, a2) = (1/g, 1/g**2). Here g is the plastic number, the real root of x**3 = x + 1 and
    the 2-D generalization of the golden ratio. Each point is scaled to the grid and floored
    to a (row, col) cell, and cells are kept in order of first appearance. Cells never hit are
    appended by rejection sampling with ``torch.randint``.

    Args:
        input_size: int for a square grid, or a (height, width) pair.
        nb_point: minimum number of sequence points; at least 10 * height * width are generated.

    Returns:
        Long tensor of shape (height * width, 2) with [:, 0] = row and [:, 1] = column. Each
        cell appears exactly once.
    """
    if isinstance(input_size, int):
        height, width = input_size, input_size
    else:
        height, width = input_size

    total_pixels = height * width
    nb_point = max(nb_point, total_pixels * 10)  # oversample so every cell is likely hit

    # The steps must be irrational and independent. The previous pair (phi, phi**2) was
    # degenerate: phi**2 = phi + 1, so {n*phi**2} == {n*phi} and every point fell on the
    # diagonal. With g the plastic number (a cubic irrational), 1, 1/g and 1/g**2 are
    # linearly independent over the rationals.
    sqrt69 = math.sqrt(69)
    g = ((9 + sqrt69) / 18) ** (1 / 3) + ((9 - sqrt69) / 18) ** (1 / 3)
    alpha = torch.tensor([1 / g, 1 / g ** 2], dtype=torch.float64)

    # float64: in float32, n * alpha for n ~ 5e4 keeps only ~7 fractional bits.
    n_indices = torch.arange(1, nb_point + 1, dtype=torch.float64).unsqueeze(1)
    unit_pts = torch.frac(n_indices * alpha)  # (nb_point, 2), values in [0, 1)

    # Scale to the grid and floor to integer cells.
    scale = torch.tensor([height, width], dtype=torch.float64)
    mask = torch.floor(unit_pts * scale).long()

    # Keep only in-grid cells (guards against rounding up to the upper edge).
    valid_mask = (mask[:, 0] >= 0) & (mask[:, 0] < height) & \
                 (mask[:, 1] >= 0) & (mask[:, 1] < width)
    mask = mask[valid_mask]

    # Deduplicate, keeping the first occurrence of each cell in sequence order.
    _, unique_indices = np.unique(mask.numpy(), axis=0, return_index=True)
    mask = mask[torch.from_numpy(np.sort(unique_indices))]

    # Fill any cells the sequence never hit with random unused cells.
    if len(mask) < total_pixels:
        existing_set = set(tuple(p.tolist()) for p in mask)
        complement = []
        while len(complement) < total_pixels - len(mask):
            point = (torch.randint(0, height, (1,)).item(),
                     torch.randint(0, width, (1,)).item())
            if point not in existing_set:
                complement.append(point)
                existing_set.add(point)
        mask = torch.cat([mask, torch.tensor(complement)])
    else:
        mask = mask[:total_pixels]

    return mask

def build_random_mask(input_size, nb_point=50_000):
    """
    Order every cell of a grid uniformly at random.

    Args:
        input_size: int for a square grid, or a (height, width) pair.
        nb_point: unused; accepted so the signature matches the other ``build_*_mask`` builders.

    Returns:
        Long tensor of shape (height * width, 2) with [:, 0] = row and [:, 1] = column, in a
        random order given by one ``torch.randperm`` call. Previously this returned float32,
        unlike the other builders and unusable directly as indices.
    """
    if isinstance(input_size, int):
        height, width = input_size, input_size
    else:
        height, width = input_size

    # All cells in row-major order.
    rows = torch.arange(height).repeat_interleave(width)
    cols = torch.arange(width).repeat(height)
    grid = torch.stack([rows, cols], dim=1)

    return grid[torch.randperm(grid.size(0))]

def build_low_discrepancy_mask(input_size, nb_point=50_000, candidate_size=1000, rebuild_interval=100):
    """
    Order the cells of a grid by greedy (approximate) farthest-point sampling.

    The process starts from a random cell. Each step draws up to ``candidate_size`` random
    unvisited cells and keeps the one whose distance to the nearest already-selected cell is
    largest. The first ``min(nb_point, height*width)`` cells are chosen this way, and all
    remaining cells are appended in random order.

    Args:
        input_size: int for a square grid, or a (height, width) pair.
        nb_point: how many cells to place greedily before falling back to random order.
        candidate_size: number of random unvisited candidates scored per step.
        rebuild_interval: newly selected points are buffered and merged into the KD-tree once
            the buffer reaches this size. Buffered points are scored with ``torch.cdist``.

    Returns:
        Long tensor of shape (height * width, 2) with [:, 0] = row and [:, 1] = column. Each
        cell appears exactly once.
    """
    if isinstance(input_size, int):
        height, width = input_size, input_size
    else:
        height, width = input_size
    total_points = height * width

    # All grid cells in row-major order, as float (row, col) pairs for distance computations.
    i = torch.arange(height).unsqueeze(1).repeat(1, width).flatten()
    j = torch.arange(width).unsqueeze(0).repeat(height, 1).flatten()
    all_points = torch.stack([i, j], dim=1).float()

    # Only the first min(nb_point, total_points) cells are chosen greedily.
    k = min(nb_point, total_points)
    if k <= 0:
        # Empty grid or nb_point <= 0: there is no seed point for the greedy loop,
        # so return every cell in random order.
        return all_points[torch.randperm(total_points)].long()

    visited = torch.zeros(total_points, dtype=torch.bool)
    selected = torch.zeros((k, 2), dtype=torch.float)

    # Random starting cell.
    start_idx = torch.randint(total_points, (1,)).item()
    visited[start_idx] = True
    selected[0] = all_points[start_idx]

    # KD-tree over merged selected points; newer selections wait in `buffer`.
    tree_data = [selected[0].numpy()]
    kd_tree = cKDTree(tree_data)
    buffer = []

    for idx in range(1, k):
        # Periodically fold the buffer into the KD-tree instead of rebuilding every step.
        if len(buffer) >= rebuild_interval:
            tree_data = np.vstack([tree_data] + [p.numpy() for p in buffer])
            kd_tree = cKDTree(tree_data)
            buffer = []

        # Random subset of unvisited cells as candidates.
        unvisited = torch.where(~visited)[0]
        if len(unvisited) == 0:
            break
        sample_size = min(candidate_size, len(unvisited))
        candidate_indices = unvisited[torch.randperm(len(unvisited))[:sample_size]]
        candidates = all_points[candidate_indices]

        # Distance from each candidate to its nearest selected point: KD-tree for merged
        # points, brute force for the buffered ones.
        dists, _ = kd_tree.query(candidates.numpy())
        min_dists = torch.from_numpy(dists).float()
        if buffer:
            buffer_tensor = torch.stack(buffer)
            buffer_dists = torch.cdist(candidates, buffer_tensor).min(dim=1)[0]
            min_dists = torch.min(min_dists, buffer_dists)

        # Keep the candidate farthest from all selected points.
        best_idx = torch.argmax(min_dists).item()
        best_point = candidates[best_idx]
        best_global_idx = candidate_indices[best_idx].item()

        selected[idx] = best_point
        visited[best_global_idx] = True
        buffer.append(best_point)

    # Remaining cells follow in random order.
    remaining_indices = torch.where(~visited)[0]
    remaining_points = all_points[remaining_indices]
    shuffled_indices = torch.randperm(len(remaining_points))
    mask = torch.cat([selected, remaining_points[shuffled_indices]])

    # Integer (row, col) indices, matching the dtype of the other mask builders.
    return mask.long()

def generate_halton_like_sequence(input_size, r=3, alpha=0.85):
    """
    Generate a spread-out random ordering of all grid cells (a "Halton-like" sequence).

    Cells are drawn one at a time with probability proportional to a weight map (initially
    all ones). After a draw, the drawn cell's weight is set to 0. The weights inside the
    (2r+1) x (2r+1) window around it are then multiplied by a decay factor that is smallest
    next to the drawn cell, so later draws tend to land away from earlier ones.

    Args:
        input_size: int for a square grid, or a (height, width) pair.
        r: radius of the decay window.
        alpha: decay strength, intended to lie in (0, 1). A neighbour at normalized distance
            d = dist / (r * 1.414) is scaled by 1 - alpha * (1 - d) when d <= 1, and left
            unchanged otherwise.

    Returns:
        torch.long tensor of [y, x] (row, column) pairs in draw order, one per cell.
    """
    if isinstance(input_size, int):
        height = width = input_size
    else:
        height, width = input_size

    weights = torch.ones((height, width), dtype=torch.float32)

    # Multiplicative decay kernel centred on the drawn cell.
    side = 2 * r + 1
    reach = r * 1.414  # approximately r * sqrt(2), the corner distance of the window

    def _factor(dy, dx):
        nd = ((dy - r) ** 2 + (dx - r) ** 2) ** 0.5 / reach
        return 1.0 - alpha * (1.0 - nd) if nd <= 1.0 else 1.0

    kernel = torch.ones(side, side)
    kernel.copy_(torch.tensor([[_factor(dy, dx) for dx in range(side)] for dy in range(side)]))

    drawn = []
    for _ in range(height * width):
        flat = weights.view(-1)
        flat /= flat.sum()  # normalizes `weights` in place, since `flat` is a view of it
        cell = torch.multinomial(flat, 1).item()
        y, x = divmod(cell, width)
        drawn.append([y, x])

        weights[y, x] = 0.0

        # Decay the clipped neighbourhood; already-drawn cells stay at 0.
        y0, y1 = max(0, y - r), min(height, y + r + 1)
        x0, x1 = max(0, x - r), min(width, x + r + 1)
        weights[y0:y1, x0:x1] *= kernel[r - (y - y0):r + (y1 - y), r - (x - x0):r + (x1 - x)]

    return torch.tensor(drawn, dtype=torch.long)

def shuffle_sequences(x, a, b):
    """
    Independently shuffle the slice ``[a, b)`` of every row of ``x``.

    Args:
        x: tensor of shape (bsz, seq_len).
        a: start index of the slice (inclusive).
        b: end index of the slice (exclusive).

    Returns:
        A new tensor of shape (bsz, seq_len). Outside ``[a, b)`` it equals ``x``; inside, each
        row's slice is permuted by argsorting uniform random keys. If the range is invalid
        (a >= b, a < 0 or b > seq_len), an unshuffled copy is returned.
    """
    out = x.clone()
    if not (0 <= a < b <= x.size(1)):
        return out

    chunk = x[:, a:b]
    # One random key per element; argsort along each row gives an independent permutation.
    keys = torch.rand(chunk.size(), device=x.device)
    out[:, a:b] = chunk.gather(1, keys.argsort(dim=1))
    return out
