import numpy as np
import concurrent.futures
import numpy as np

import numpy as np

def eye_basis(n, rate_choose=1.0):
    """
    生成标准正交基（单位矩阵的行向量），并根据 rate_choose 保留前一部分。

    参数:
    - n: 信号长度
    - rate_choose: 保留前 rate_choose 比例的基底，范围 (0, 1]，默认保留全部

    返回:
    - basis: 形状为 (num_basis, n) 的基底矩阵
    """
    basis = np.eye(n, dtype=float)
    num_to_keep = int(np.floor(n * rate_choose))
    return basis[:num_to_keep]
    
def dft_basis_high2low(n, rate_choose=1.0):
    """
    生成按频率从高到低排序的实数 DFT 基底（余弦/正弦），并根据 rate_choose 保留前一部分高频基底。

    参数:
    - n: 信号长度
    - rate_choose: 保留比例 (0, 1]，<1 时优先保留高频

    返回:
    - basis: 形状 (num_basis, n) 的实数 DFT 基底矩阵，其中 num_basis = floor(n * rate_choose)
    """
    if not (0 < rate_choose <= 1):
        raise ValueError("rate_choose 必须在 (0, 1] 内。")

    t = np.arange(n)
    basis = []

    if n % 2 == 0:
        # 先放 Nyquist（k = n/2，仅余弦，归一化 1/sqrt(n)）
        k = n // 2
        ck = np.cos(2 * np.pi * k * t / n) / np.sqrt(n)
        basis.append(ck)

        # 再从 k = n/2-1 到 1，按频率从高到低，每个 k 先 cos 再 sin
        for k in range(n // 2 - 1, 0, -1):
            c = np.sqrt(2 / n) * np.cos(2 * np.pi * k * t / n)
            s = np.sqrt(2 / n) * np.sin(2 * np.pi * k * t / n)
            basis.extend([c, s])
    else:
        # 奇数 n：从最高频 k = (n-1)//2 递减到 1
        for k in range((n - 1) // 2, 0, -1):
            c = np.sqrt(2 / n) * np.cos(2 * np.pi * k * t / n)
            s = np.sqrt(2 / n) * np.sin(2 * np.pi * k * t / n)
            basis.extend([c, s])

    # 最后放直流分量 k = 0
    c0 = np.ones(n) / np.sqrt(n)
    basis.append(c0)

    B = np.vstack(basis)
    num_to_keep = int(np.floor(n * rate_choose))
    return B[:num_to_keep]

def dft_basis(n, rate_choose=1.0):
    """
    生成 DFT 基底的实数版本（余弦 + 正弦），并根据 rate_choose 保留前一部分基底。

    参数:
    - n: 信号长度
    - rate_choose: 保留前 rate_choose 比例的基底，范围 (0, 1]，默认保留全部

    返回:
    - basis: 形状为 (num_basis, n) 的实数 DFT 基底矩阵
    """
    basis = []
    c0 = np.ones(n) / np.sqrt(n)
    basis.append(c0)

    max_k = n // 2 if n % 2 == 0 else (n - 1) // 2
    for k in range(1, max_k + 1):
        if n % 2 == 0 and k == n // 2:
            ck = np.cos(2 * np.pi * k * np.arange(n) / n) / np.sqrt(n)
            basis.append(ck)
        else:
            ck = np.sqrt(2 / n) * np.cos(2 * np.pi * k * np.arange(n) / n)
            sk = np.sqrt(2 / n) * np.sin(2 * np.pi * k * np.arange(n) / n)
            basis.append(ck)
            basis.append(sk)

    basis = np.array(basis)
    num_to_keep = int(np.floor(n * rate_choose))
    return basis[:num_to_keep]

def seg_dft_basis(n, n_per_seg, rate_choose=1.0):
    """
    生成分段 DFT 基底（实数版本），并根据 rate_choose 保留每个分段的基底。

    参数:
    - n: 信号长度
    - n_per_seg: 每段的长度
    - rate_choose: 保留每段基底的比例，范围 (0, 1]，默认保留全部

    返回:
    - basis: 形状为 (num_basis, n) 的实数 DFT 基底矩阵
    """
    basis = []
    num_segments = (n + n_per_seg - 1) // n_per_seg  # 计算总的段数

    for seg_idx in range(num_segments):
        start = seg_idx * n_per_seg
        end = min(start + n_per_seg, n)
        seg_length = end - start

        # 每段基底的数量，根据 rate_choose 计算
        num_to_keep = int(np.floor(seg_length * rate_choose))

        # 生成当前分段的傅里叶基向量
        for k in range(num_to_keep):
            vec = np.zeros(n, dtype=complex)
            # 生成单位化的傅里叶基向量
            seg = np.exp(-2j * np.pi * k * np.arange(seg_length) / seg_length) / np.sqrt(seg_length)
            vec[start:end] = seg
            basis.append(vec)

    # 截取前 n 个基向量，确保返回的基底数量是 n
    basis = np.array(basis[:n])
    return basis