import numpy as np
from joblib import Parallel, delayed, cpu_count

def matmul_dot(start, end, X, R):
    """Multiply the row block ``X[start:end]`` by ``R``.

    Worker for ``parallel_dot``: each call handles one contiguous range of
    rows so the partial products can be joined back in row order.
    """
    row_block = X[slice(start, end)]
    partial = row_block @ R
    return partial

def parallel_dot(X, R, n_jobs=None):
    """Compute ``X @ R`` by splitting the rows of ``X`` across joblib workers.

    Parameters
    ----------
    X : array
        Left operand; split into contiguous row blocks along axis 0.
    R : array
        Right operand, shared by every block.
    n_jobs : int or None
        Number of workers. ``None`` uses every core reported by
        ``cpu_count()``; negative values follow the joblib convention
        (``-1`` = all cores, ``-2`` = all but one, ...). The value is capped
        at the number of rows so no worker receives an empty block.

    Returns
    -------
    ndarray
        The per-block products of ``X @ R`` concatenated along axis 0.

    Raises
    ------
    ValueError
        If ``n_jobs == 0``.
    """
    n = X.shape[0]

    # Resolve the worker count (joblib semantics for negative values).
    if n_jobs is None:
        n_jobs = cpu_count()
    elif n_jobs < 0:
        n_jobs = max(cpu_count() + 1 + n_jobs, 1)
    elif n_jobs == 0:
        raise ValueError("n_jobs == 0 has no meaning")
    n_jobs = min(n_jobs, n)

    if n_jobs == 0:
        # Empty X: nothing to split; the plain product already has the
        # correct (0, ...) shape and dtype.
        return X @ R

    # Ceil division so that n_jobs blocks always cover all n rows.
    chunk_size = (n + n_jobs - 1) // n_jobs
    # Iterate only over real block starts: with ceil-sized chunks, fewer
    # than n_jobs blocks may be needed, and empty tasks are pure overhead.
    starts = range(0, n, chunk_size)

    results = Parallel(n_jobs=len(starts))(
        delayed(matmul_dot)(s, min(s + chunk_size, n), X, R) for s in starts
    )
    # concatenate (not vstack) so 1-D partial results are joined, not stacked
    # into a 2-D array.
    return np.concatenate(results, axis=0)


def _chunk_bounds(n, n_jobs):
    if not n_jobs or n_jobs <= 0:
        n_jobs = cpu_count()
    n_jobs = min(n_jobs, n)
    q, r = divmod(n, n_jobs)
    bounds = []
    start = 0
    for i in range(n_jobs):
        end = start + q + (1 if i < r else 0)
        if start < end:
            bounds.append((start, end))
        start = end
    return bounds

def parallel_dot_2d(X, R, n_jobs=None):
    """Batched matrix product in parallel: X:(n, M, K), R:(K, N) -> (n, M, N).

    ``X`` is split along axis 0 into blocks given by ``_chunk_bounds``. Each
    block is multiplied by ``R`` in its own joblib task, and the results are
    concatenated back along axis 0.

    Parameters
    ----------
    X : ndarray, shape (n, M, K)
    R : ndarray, shape (K, N)
    n_jobs : int or None
        Forwarded to ``_chunk_bounds`` to decide how many blocks (and
        workers) to use.

    Returns
    -------
    ndarray, shape (n, M, N)

    Raises
    ------
    ValueError
        If the operands do not have the shapes above.
    """
    if not (X.ndim == 3 and R.ndim == 2 and X.shape[-1] == R.shape[0]):
        raise ValueError(
            f"expected X:(n, M, K) and R:(K, N), got X.shape={X.shape}, "
            f"R.shape={R.shape}"
        )
    if X.shape[0] == 0:
        # Nothing to parallelize (and joblib cannot run zero workers); the
        # plain product already has the correct (0, M, N) shape.
        return X @ R

    bounds = _chunk_bounds(X.shape[0], n_jobs)
    # Send each worker only its slice of X rather than closing over all of X.
    parts = Parallel(n_jobs=len(bounds))(
        delayed(np.matmul)(X[s:e], R) for s, e in bounds
    )
    return np.concatenate(parts, axis=0)

def parallel_dot_3d(X, R, n_jobs=None):
    """Tensor contraction over K: X:(n, M, K), R:(P, K, N) -> (n, M, P, N).

    ``X`` is split along axis 0 into blocks given by ``_chunk_bounds``. Each
    block is contracted with ``R`` over the shared K axis (axis 2 of ``X``,
    axis 1 of ``R``) in its own joblib task. The results are concatenated
    back along axis 0.

    Parameters
    ----------
    X : ndarray, shape (n, M, K)
    R : ndarray, shape (P, K, N)
    n_jobs : int or None
        Forwarded to ``_chunk_bounds`` to decide how many blocks (and
        workers) to use.

    Returns
    -------
    ndarray, shape (n, M, P, N)

    Raises
    ------
    ValueError
        If the operands do not have the shapes above.
    """
    if not (X.ndim == 3 and R.ndim == 3 and X.shape[-1] == R.shape[1]):
        raise ValueError(
            f"expected X:(n, M, K) and R:(P, K, N), got X.shape={X.shape}, "
            f"R.shape={R.shape}"
        )
    # Contract axis 2 of X (K) with axis 1 of R (K).
    axes = ([2], [1])
    if X.shape[0] == 0:
        # Nothing to parallelize (and joblib cannot run zero workers); the
        # direct contraction already has the correct (0, M, P, N) shape.
        return np.tensordot(X, R, axes=axes)

    bounds = _chunk_bounds(X.shape[0], n_jobs)
    # Send each worker only its slice of X rather than closing over all of X.
    parts = Parallel(n_jobs=len(bounds))(
        delayed(np.tensordot)(X[s:e], R, axes=axes) for s, e in bounds
    )
    return np.concatenate(parts, axis=0)

def parallel_dot_general(X, R, n_jobs=None):
    """Dispatch to the parallel product matching the rank of ``R``.

    A 2-D ``R`` goes to ``parallel_dot_2d`` and a 3-D ``R`` goes to
    ``parallel_dot_3d``. ``X`` and ``n_jobs`` are forwarded unchanged.

    Raises
    ------
    ValueError
        If ``R`` is neither 2-D nor 3-D.
    """
    rank = R.ndim
    if rank not in (2, 3):
        raise ValueError(f"R.ndim must be 2 or 3, got {R.ndim}")
    impl = parallel_dot_2d if rank == 2 else parallel_dot_3d
    return impl(X, R, n_jobs)
