"""
GPU加速的最小二乘求解器，用于attention map特征提取

该模块提供CUDA加速的M^T·M和M^T·S_T计算，避免显式构造大矩阵
"""

import torch
import os
from torch.utils.cpp_extension import load

# Lazily loaded CUDA extension module (None until the first successful load).
_cuda_lstsq = None
# Set after a failed load so the expensive JIT build is not retried on every call.
_cuda_lstsq_failed = False


def _load_cuda_extension():
    """
    Lazily JIT-compile and load the ``cuda_lstsq`` extension.

    On first use, ``cuda_lstsq_kernel.cu`` (looked up next to this file) is
    built with ``torch.utils.cpp_extension.load`` and the resulting module is
    cached. A failure is cached as well, so subsequent calls return ``None``
    immediately instead of re-running the build and re-printing the warning.

    Returns:
        The loaded extension module, or ``None`` if it could not be built/loaded.
    """
    global _cuda_lstsq, _cuda_lstsq_failed
    if _cuda_lstsq is not None:
        return _cuda_lstsq
    if _cuda_lstsq_failed:
        return None

    current_dir = os.path.dirname(os.path.abspath(__file__))
    cuda_file = os.path.join(current_dir, 'cuda_lstsq_kernel.cu')

    try:
        # JIT compilation (cached by torch across processes once it succeeds)
        _cuda_lstsq = load(
            name='cuda_lstsq',
            sources=[cuda_file],
            extra_cuda_cflags=['-O3', '--use_fast_math', '--extended-lambda'],
            verbose=True,
        )
    except Exception as e:
        # Any build or import error means: remember it and use the PyTorch path.
        _cuda_lstsq_failed = True
        print(f"Warning: Failed to load CUDA extension: {e}")
        print("Falling back to PyTorch implementation")
        return None

    print("CUDA extension loaded successfully")
    return _cuda_lstsq

def compute_mtm_pytorch(S_0, blocks_per_frame, regularization=1e-3):
    """
    Compute the normal matrix M^T·M of the least-squares design (PyTorch fallback).

    The design matrix M is never materialised. Each of its columns is the
    flattened 0/1 indicator mask of one feature over an (n, n) map:

    * columns ``[0, 2n-1)``: diagonals C_i with offset ``i - (n-1)`` (col - row);
    * columns ``[2n-1, 3n-1)``: vertical lines D_j (column j of the map);
    * column ``3n-1``: block feature E, the union of the ``n // blocks_per_frame``
      disjoint square blocks of side ``blocks_per_frame`` along the main
      diagonal (a trailing partial block, if any, is not part of E).

    Every entry of M^T·M is therefore the number of cells two masks share, so
    the result depends only on ``n`` and ``blocks_per_frame``. The values in
    S_0 are not read; it only supplies the head count, ``n``, device and dtype.

    Args:
        S_0: attention map of step 0, shape (head, n, n).
        blocks_per_frame: side length of each diagonal block.
        regularization: value added to the diagonal. The diagonal masks and the
            column masks both sum to the all-ones map, so M^T·M is singular
            without it.

    Returns:
        MTM of shape (head, total_features, total_features), total_features = 3n,
        identical for every head.
    """
    head_num = S_0.shape[0]
    n = S_0.shape[1]
    num_diags = 2 * n - 1
    total_features = 1 + num_diags + n

    device = S_0.device
    dtype = S_0.dtype

    d0 = num_diags      # index of the first vertical-line feature
    e = num_diags + n   # index of the block feature

    num_blocks = n // blocks_per_frame
    covered = num_blocks * blocks_per_frame  # rows/cols covered by full blocks

    offsets = torch.arange(num_diags, device=device) - (n - 1)  # col - row per diagonal
    abs_off = offsets.abs()
    cols = torch.arange(n, device=device)

    base = torch.zeros(total_features, total_features, device=device, dtype=dtype)

    # C_i·C_j: distinct diagonals are disjoint; diagonal i has n - |offset_i| cells.
    base[:d0, :d0] = torch.diag((n - abs_off).to(dtype))

    # D_i·D_j: distinct columns are disjoint; each column has n cells.
    base[d0:e, d0:e] = torch.eye(n, device=device, dtype=dtype) * n

    # C_i·D_j: diagonal i meets column j at row j - offset_i, if that row exists.
    rows = cols.unsqueeze(0) - offsets.unsqueeze(1)  # (num_diags, n)
    cd = ((rows >= 0) & (rows < n)).to(dtype)
    base[:d0, d0:e] = cd
    base[d0:e, :d0] = cd.T

    # E·E: the full blocks are disjoint, each with blocks_per_frame**2 cells.
    base[e, e] = num_blocks * (blocks_per_frame ** 2)

    # C_i·E: inside one block, a diagonal with |offset| < block side covers
    # (side - |offset|) cells; this repeats for every full block.
    ce = (blocks_per_frame - abs_off).clamp(min=0).to(dtype) * num_blocks
    base[:d0, e] = ce
    base[e, :d0] = ce

    # D_j·E: a column crosses exactly one full block (block-side cells) unless
    # it lies in the trailing region that no full block covers.
    de = (cols < covered).to(dtype) * blocks_per_frame
    base[d0:e, e] = de
    base[e, d0:e] = de

    MTM = base.unsqueeze(0).repeat(head_num, 1, 1)

    # Ridge regularisation on the diagonal for numerical stability
    MTM += torch.eye(total_features, device=device, dtype=dtype).unsqueeze(0) * regularization

    return MTM

def compute_mts_pytorch(S_T, S_0, blocks_per_frame):
    """
    Compute M^T·vec(S_T) for every head (PyTorch fallback).

    Uses the same feature layout as ``compute_mtm_pytorch``. Every column of M
    is a 0/1 mask, so each entry is the sum of S_T over that mask:

    * ``[0, 2n-1)``: sum of the diagonal with offset ``k - (n-1)`` (col - row);
    * ``[2n-1, 3n-1)``: sum of column k;
    * ``3n-1``: sum over the ``n // blocks_per_frame`` full diagonal blocks of
      side ``blocks_per_frame`` (a trailing partial block is ignored).

    Args:
        S_T: attention map of the current step, shape (head, n, n).
        S_0: attention map of step 0; only its head count ``S_0.shape[0]`` is used.
        blocks_per_frame: side length of each diagonal block.

    Returns:
        MTS of shape (S_0.shape[0], 3n), on S_T's device and in S_T's dtype.
    """
    n = S_T.shape[1]
    num_diags = 2 * n - 1
    total_features = 1 + num_diags + n
    num_heads = S_0.shape[0]

    device = S_T.device
    dtype = S_T.dtype

    MTS = torch.zeros(num_heads, total_features, device=device, dtype=dtype)

    # C_k^T · S_T: one reduction per diagonal instead of one add per element.
    for k in range(num_diags):
        offset = k - (n - 1)
        MTS[:, k] = torch.diagonal(S_T, offset=offset, dim1=1, dim2=2).sum(dim=-1)

    # D_k^T · S_T: column sums (reduce over the row axis).
    MTS[:, num_diags:num_diags + n] = S_T.sum(dim=1)

    # E^T · S_T: total mass inside the full diagonal blocks. Block starts are
    # plain Python ints, so there is no host-device sync per block.
    num_blocks = n // blocks_per_frame
    total = torch.zeros(num_heads, device=device, dtype=dtype)
    for start in range(0, num_blocks * blocks_per_frame, blocks_per_frame):
        stop = start + blocks_per_frame
        total += S_T[:, start:stop, start:stop].sum(dim=(1, 2))
    MTS[:, num_diags + n] = total

    return MTS

def _solve_spd(A, b):
    """
    Solve A·x = b for a regularised symmetric matrix A, most stable method first.

    Tries a Cholesky factorisation, then a general LU solve, then the
    pseudo-inverse. Each step falls through on ``RuntimeError`` (which
    ``torch.linalg.LinAlgError`` subclasses). Errors from the pseudo-inverse
    propagate to the caller.
    """
    try:
        L = torch.linalg.cholesky(A)
        return torch.cholesky_solve(b, L)
    except RuntimeError:
        pass
    try:
        return torch.linalg.solve(A, b)
    except RuntimeError:
        return torch.linalg.pinv(A) @ b


def solve_lstsq(warmup_state, S_T, S_0, step, blocks_per_frame, regularization=1e-5, use_cuda=True):
    """
    Fit S_T ≈ M·X in the least-squares sense and return the fitted coefficients.

    Solves the normal equations (M^T·M)·X = M^T·S_T head by head. The design
    matrix M is never built. It has one 0/1 mask column per feature:
    2n-1 diagonals, n vertical lines (columns) and one block feature.

    The work is split across calls through ``warmup_state``. M^T·M is computed
    once, on the call where ``step == warmup_state['warmup_steps'] - 2``, and
    stored in ``warmup_state['MTM']``. Every later step computes M^T·S_T and
    solves against the stored matrix.

    Args:
        warmup_state: mutable dict; reads ``'warmup_steps'`` and writes/reads ``'MTM'``.
        S_T: attention map of the current step, shape (head, n, n).
        S_0: attention map of step 0, shape (head, n, n).
        step: index of the current step.
        blocks_per_frame: side length of the diagonal blocks of the block feature.
        regularization: value added to the diagonal of M^T·M (default 1e-5).
        use_cuda: use the CUDA extension when S_T is a CUDA tensor and the
            extension loads; otherwise the PyTorch implementation is used.

    Returns:
        dict of tensors:
            'd':   (head, 2n-1) coefficients of the diagonal features;
            'c':   (head, n) coefficients of the vertical-line features, divided by 2;
            'b_d': (head,) the last coefficient (the block feature in the
                   layout built by ``compute_mtm_pytorch``).
        For ``step <= warmup_steps - 2`` there is nothing to solve yet and all
        entries are zeros of those shapes.

    Raises:
        KeyError: if a solving step runs before ``warmup_state['MTM']`` exists.
    """
    num_heads = S_T.shape[0]
    n = S_T.shape[1]
    S_0 = S_0.to(torch.float32)
    S_T = S_T.to(torch.float32)
    num_diags = 2 * n - 1
    mtm_step = warmup_state['warmup_steps'] - 2

    # The extension only handles CUDA tensors; don't trigger a JIT build for CPU inputs.
    want_ext = use_cuda and S_T.is_cuda

    if step == mtm_step:
        cuda_ext = _load_cuda_extension() if want_ext else None
        if cuda_ext is not None:
            try:
                MTM = cuda_ext.compute_mtm(S_0, blocks_per_frame, regularization)
            except Exception as e:
                print(f"CUDA execution failed: {e}, falling back to PyTorch")
                MTM = compute_mtm_pytorch(S_0, blocks_per_frame, regularization)
        else:
            MTM = compute_mtm_pytorch(S_0, blocks_per_frame, regularization)
        warmup_state['MTM'] = MTM

    if step <= mtm_step:
        # Nothing to solve yet (M^T·S_T would be discarded): return zero features.
        return {
            'd': torch.zeros((num_heads, 2 * n - 1), device=S_T.device, dtype=S_T.dtype),
            'c': torch.zeros((num_heads, n), device=S_T.device, dtype=S_T.dtype),
            'b_d': torch.zeros((num_heads), device=S_T.device, dtype=S_T.dtype),
        }

    MTM = warmup_state['MTM']

    cuda_ext = _load_cuda_extension() if want_ext else None
    if cuda_ext is not None:
        try:
            MTS = cuda_ext.compute_mts(S_T, S_0, blocks_per_frame)
        except Exception as e:
            print(f"CUDA execution failed: {e}, falling back to PyTorch")
            MTS = compute_mts_pytorch(S_T, S_0, blocks_per_frame)
    else:
        MTS = compute_mts_pytorch(S_T, S_0, blocks_per_frame)

    # Solve head by head so one ill-conditioned head cannot break the others.
    X_list = []
    for head_idx in range(MTM.shape[0]):
        try:
            MTM_head = MTM[head_idx]  # (total_features, total_features)

            # Select this head's right-hand side. The expected case is MTS of
            # shape (head, total_features); the other branches are defensive
            # fallbacks for differently shaped MTS.
            if MTS.dim() == 2 and MTS.shape[0] == MTM.shape[0]:
                MTS_head = MTS[head_idx].unsqueeze(1)  # (total_features, 1)
            else:
                MTS_head = MTS.unsqueeze(1) if MTS.dim() == 2 else MTS
                if MTS_head.shape[0] == MTM.shape[0]:
                    MTS_head = MTS_head[head_idx]
                else:
                    MTS_head = MTS_head[0]  # MTS carries a single head: reuse it

            # On a size mismatch, truncate both sides to the common size.
            if MTM_head.shape[0] != MTS_head.shape[0]:
                min_dim = min(MTM_head.shape[0], MTS_head.shape[0])
                MTM_head = MTM_head[:min_dim, :min_dim]
                MTS_head = MTS_head[:min_dim].unsqueeze(1)

            X_list.append(_solve_spd(MTM_head, MTS_head).squeeze(-1))

        except Exception as e:
            print(f"头{head_idx}求解失败: {e}")
            # Zero coefficients for a head that could not be solved; size taken
            # from MTM itself (MTM_head may be truncated or not yet assigned).
            X_list.append(torch.zeros(MTM.shape[-1], device=MTM.device, dtype=MTM.dtype))

    X = torch.stack(X_list)  # (head, total_features)

    # Clamp slice bounds so a smaller-than-expected solution cannot over-index.
    total_features = X.shape[1]
    idx_d_end = min(num_diags, total_features)
    idx_c_end = min(idx_d_end + n, total_features)

    return {
        'd': X[:, :idx_d_end],  # diagonal features
        # NOTE(review): vertical-line coefficients are halved; the reason is
        # not visible in this file — confirm with the consumer of 'c'.
        'c': X[:, idx_d_end:idx_c_end] / 2,
        'b_d': X[:, -1],  # last feature (block feature)
    }


if __name__ == "__main__":
    import time
    
    print("="*70)
    print("Testing CUDA LSTSQ solver with multi-head attention")
    print("="*70)
    
    # 测试配置
    head_num = 8
    n = 52
    blocks_per_frame = 4  # 52/13 = 4 blocks
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    print(f"\nConfiguration:")
    print(f"  Device: {device}")
    print(f"  Number of heads: {head_num}")
    print(f"  Sequence length: {n}")
    print(f"  Blocks per frame: {blocks_per_frame}")
    print(f"  Number of blocks: {n // blocks_per_frame}")
    
    # 生成测试数据
    torch.manual_seed(42)
    S_0 = torch.rand(head_num, n, n, device=device, dtype=torch.float32)
    S_T = 0.5 * S_0 + 0.3 * torch.eye(n, device=device).unsqueeze(0) + 0.2 * torch.rand(head_num, n, n, device=device)
    
    # 预先加载CUDA扩展（不计入测试时间）
    if device.type == 'cuda':
        print("\n" + "="*70)
        print("Preloading CUDA extension (compilation time not counted)...")
        print("="*70)
        _ = _load_cuda_extension()
        if _ is not None:
            print("CUDA extension ready.")
        else:
            print("Warning: CUDA extension not available, will use PyTorch only")
    
    # ============================================================
    # 测试1: PyTorch实现
    # ============================================================
    print("\n" + "="*70)
    print("Test 1: PyTorch implementation")
    print("="*70)
    
    warmup_state_torch = {}
    warmup_state_torch['warmup_steps'] = 2
    # Step 0: 计算MTM
    print("\nStep 0: Computing MTM...")
    start = time.time()
    _ = solve_lstsq(warmup_state_torch, S_T, S_0, step=0, blocks_per_frame=blocks_per_frame, use_cuda=False)
    mtm_time_torch = time.time() - start
    print(f"  PyTorch MTM computation time: {mtm_time_torch*1000:.4f} ms")
    print(f"  MTM shape: {warmup_state_torch['MTM'].shape}")
    
    # Step 1: 求解
    print("\nStep 1: Solving linear system...")
    start = time.time()
    features_torch = solve_lstsq(warmup_state_torch, S_T, S_0, step=1, blocks_per_frame=blocks_per_frame, use_cuda=False)
    solve_time_torch = time.time() - start
    print(f"  PyTorch solve time: {solve_time_torch*1000:.4f} ms")
    print(f"  Total time: {(mtm_time_torch + solve_time_torch)*1000:.4f} ms")
    
    print(f"\n  Feature statistics:")
    # print(f"    p shape: {features_torch['p'].shape}, mean: {features_torch['p'].mean():.6f}")
    print(f"    c shape: {features_torch['c'].shape}, mean: {features_torch['c'].mean():.6f}")
    print(f"    d shape: {features_torch['d'].shape}, mean: {features_torch['d'].mean():.6f}")
    print(f"    b_d shape: {features_torch['b_d'].shape}, mean: {features_torch['b_d'].mean():.6f}")
    
    # ============================================================
    # 测试2: CUDA实现（如果可用）
    # ============================================================
    if device.type == 'cuda' and _load_cuda_extension() is not None:
        print("\n" + "="*70)
        print("Test 2: CUDA implementation")
        print("="*70)
        
        warmup_state_cuda = {}
        warmup_state_cuda['warmup_steps'] = 2
        # Warmup
        print("\nWarming up CUDA kernels...")
        for _ in range(3):
            warmup_state_tmp = {}
            warmup_state_tmp['warmup_steps'] = 2
            _ = solve_lstsq(warmup_state_tmp, S_T, S_0, step=0, blocks_per_frame=blocks_per_frame, use_cuda=True)
            _ = solve_lstsq(warmup_state_tmp, S_T, S_0, step=1, blocks_per_frame=blocks_per_frame, use_cuda=True)
        torch.cuda.synchronize()
        
        # Step 0: 计算MTM
        print("\nStep 0: Computing MTM...")
        torch.cuda.synchronize()
        start = time.time()
        _ = solve_lstsq(warmup_state_cuda, S_T, S_0, step=0, blocks_per_frame=blocks_per_frame, use_cuda=True)
        torch.cuda.synchronize()
        mtm_time_cuda = time.time() - start
        print(f"  CUDA MTM computation time: {mtm_time_cuda*1000:.4f} ms")
        print(f"  MTM shape: {warmup_state_cuda['MTM'].shape}")
        
        # Step 1: 求解
        print("\nStep 1: Solving linear system...")
        torch.cuda.synchronize()
        start = time.time()
        features_cuda = solve_lstsq(warmup_state_cuda, S_T, S_0, step=1, blocks_per_frame=blocks_per_frame, use_cuda=True)
        torch.cuda.synchronize()
        solve_time_cuda = time.time() - start
        print(f"  CUDA solve time: {solve_time_cuda*1000:.4f} ms")
        print(f"  Total time: {(mtm_time_cuda + solve_time_cuda)*1000:.4f} ms")
        
        print(f"\n  Feature statistics:")
        # print(f"    p shape: {features_cuda['p'].shape}, mean: {features_cuda['p'].mean():.6f}")
        print(f"    c shape: {features_cuda['c'].shape}, mean: {features_cuda['c'].mean():.6f}")
        print(f"    d shape: {features_cuda['d'].shape}, mean: {features_cuda['d'].mean():.6f}")
        print(f"    b_d shape: {features_cuda['b_d'].shape}, mean: {features_cuda['b_d'].mean():.6f}")
        
        # ============================================================
        # 测试3: 性能对比和精度验证
        # ============================================================
        print("\n" + "="*70)
        print("Test 3: Performance comparison and accuracy verification")
        print("="*70)
        
        print(f"\n{'Performance Comparison':^70}")
        print("-"*70)
        print(f"  {'Stage':<20} {'PyTorch (ms)':<15} {'CUDA (ms)':<15} {'Speedup':<10}")
        print("-"*70)
        print(f"  {'MTM computation':<20} {mtm_time_torch*1000:>12.4f}   {mtm_time_cuda*1000:>12.4f}   {mtm_time_torch/mtm_time_cuda:>7.2f}x")
        print(f"  {'Linear solve':<20} {solve_time_torch*1000:>12.4f}   {solve_time_cuda*1000:>12.4f}   {solve_time_torch/solve_time_cuda:>7.2f}x")
        total_torch = mtm_time_torch + solve_time_torch
        total_cuda = mtm_time_cuda + solve_time_cuda
        print(f"  {'Total':<20} {total_torch*1000:>12.4f}   {total_cuda*1000:>12.4f}   {total_torch/total_cuda:>7.2f}x")
        print("-"*70)
        
        print(f"\n{'Accuracy Verification':^70}")
        print("-"*70)
        # p_diff = (features_torch['p'] - features_cuda['p']).abs()
        c_diff = (features_torch['c'] - features_cuda['c']).abs()
        d_diff = (features_torch['d'] - features_cuda['d']).abs()
        b_d_diff = (features_torch['b_d'] - features_cuda['b_d']).abs()
        
        print(f"  {'Feature':<15} {'Max Diff':<15} {'Mean Diff':<15} {'Rel Error':<15}")
        print("-"*70)
        # print(f"  {'p':<15} {p_diff.max().item():<15.2e} {p_diff.mean().item():<15.2e} {(p_diff.mean()/features_torch['p'].abs().mean()).item():<15.2e}")
        print(f"  {'c':<15} {c_diff.max().item():<15.2e} {c_diff.mean().item():<15.2e} {(c_diff.mean()/features_torch['c'].abs().mean()).item():<15.2e}")
        print(f"  {'d':<15} {d_diff.max().item():<15.2e} {d_diff.mean().item():<15.2e} {(d_diff.mean()/features_torch['d'].abs().mean()).item():<15.2e}")
        print(f"  {'b_d':<15} {b_d_diff.max().item():<15.2e} {b_d_diff.mean().item():<15.2e} {(b_d_diff.mean()/features_torch['b_d'].abs().mean()).item():<15.2e}")
        print("-"*70)
        
        # 计算总体相对误差
        all_torch = torch.cat([features_torch['c'].flatten(), 
                               features_torch['d'].flatten(), features_torch['b_d'].flatten()])
        all_cuda = torch.cat([features_cuda['c'].flatten(), 
                              features_cuda['d'].flatten(), features_cuda['b_d'].flatten()])
        rel_error = ((all_torch - all_cuda).norm() / all_torch.norm()).item()
        print(f"\n  Overall relative error: {rel_error:.2e}")
        
        if rel_error < 1e-4:
            print("  ✓ Accuracy test PASSED (relative error < 1e-4)")
        elif rel_error < 1e-3:
            print("  ⚠⚠⚠ Accuracy test WARNING (1e-4 < relative error < 1e-3)")
        else:
            print("  ✗✗ Accuracy test FAILED (relative error > 1e-3)")
    
    print("\n" + "="*70)
    print("Test completed!")
    print("="*70)