import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

def weight_to_vector(binary_weight, vector_length=8):
    """Pack the non-zero entries of ``binary_weight`` into rows of ``vector_length``.

    Steps:
      1. Keep only the non-zero entries, in flattened (row-major) order.
      2. If their count is not a multiple of ``vector_length``, append an
         alternating +1, -1, +1, ... tail until it is.
      3. Reshape the result to ``(N, vector_length)``.

    Args:
        binary_weight: Weight tensor (described by the original author as a
            (rows, cols) binary matrix); zero entries are dropped.
        vector_length: Width of each output row.

    Returns:
        Tensor of shape ``(N, vector_length)`` with ``binary_weight``'s dtype
        and device. ``N`` is 0 when ``binary_weight`` has no non-zero entries.
    """
    kept = binary_weight[binary_weight != 0].flatten()

    # Filler values needed to reach the next multiple of vector_length
    # (0 when the count is already aligned).
    shortfall = (-kept.numel()) % vector_length
    if shortfall:
        filler = torch.tensor(
            [(-1) ** k for k in range(shortfall)],
            device=binary_weight.device,
            dtype=binary_weight.dtype,
        )
        kept = torch.cat([kept, filler])

    return kept.reshape(-1, vector_length)

def vector_to_weight(binary_matrix, binary_weight):
    """Scatter packed values back into the non-zero positions of a weight tensor.

    This reverses the packing done by ``weight_to_vector``:
      1. Flatten ``binary_matrix``.
      2. Keep only the first ``k`` values, where ``k`` is the number of non-zero
         entries in ``binary_weight``; any trailing values are padding.
      3. Write those values, in row-major order, into the positions where
         ``binary_weight`` is non-zero. All other positions are 0.

    Args:
        binary_matrix: Packed values, typically of shape ``(N, vector_length)``.
            Must contain at least ``k`` elements.
        binary_weight: Reference tensor whose non-zero entries define the target
            positions. Only its shape and zero pattern are used, not its values.

    Returns:
        Tensor with ``binary_weight``'s shape, on ``binary_matrix``'s device and
        with ``binary_matrix``'s dtype.
    """
    # Put the mask on the same device as the output; masked_scatter requires it.
    mask = (binary_weight != 0).to(binary_matrix.device)
    non_zero_count = int(mask.sum())

    # Drop the padding appended during vectorisation.
    values = binary_matrix.flatten()[:non_zero_count]

    # Match binary_matrix's dtype: masked_scatter requires self and source to
    # share a dtype, so a default-dtype buffer would reject e.g. int or float64.
    reconstructed = torch.zeros(
        binary_weight.shape, device=binary_matrix.device, dtype=binary_matrix.dtype
    )
    return reconstructed.masked_scatter(mask, values)



def compute_loss(binary, new_matrix, mask, refine_mean, new_alpha):
    """Return the summed squared error between ``new_matrix`` and its binary approximation.

    The approximation is ``(new_alpha * binary + refine_mean) * mask``. Both
    ``new_alpha`` and ``refine_mean`` get a new axis inserted at dim 1 before
    broadcasting, which gives a per-row scale and offset when they are 1-D.

    Note: wherever ``mask`` is zero the approximation is 0, so ``new_matrix``
    still contributes its own square there. Presumably callers pass an
    already-masked ``new_matrix`` — confirm.
    """
    scale = new_alpha.unsqueeze(1)
    offset = refine_mean.unsqueeze(1)
    approximation = (scale * binary + offset) * mask
    error = new_matrix - approximation
    return error.pow(2).sum()
    

def apply_codebook_quantization(new_binary, new_matrix, mask, refine_mean, new_alpha,
                               vector_length=8, num_centroids=256, max_iter=50):
    """Quantize a binary matrix with a learned codebook and return the results.

    This is a thin wrapper around ``optimize_codebook`` that renames two of its
    outputs.

    Args:
        new_binary: Binary matrix to quantize.
        new_matrix, mask, refine_mean, new_alpha: Extra inputs forwarded to the
            loss computation.
        vector_length: Length of each codebook vector.
        num_centroids: Requested codebook size.
        max_iter: Maximum number of optimisation iterations.

    Returns:
        dict with keys:
            'quantized_binary': the ``'reconstructed_binary'`` from ``optimize_codebook``.
            'codebook': the learned codebook.
            'indices': per-vector codebook assignment.
            'quantized_loss': the ``'loss'`` from ``optimize_codebook``.
    """
    outcome = optimize_codebook(
        new_binary, new_matrix, mask, refine_mean, new_alpha,
        vector_length, num_centroids, max_iter,
    )
    return {
        'quantized_binary': outcome['reconstructed_binary'],
        'codebook': outcome['codebook'],
        'indices': outcome['indices'],
        'quantized_loss': outcome['loss'],
    }

def _assign_to_nearest(vectors, codebook):
    """Return, for each row of ``vectors``, the index of the nearest codebook row.

    Distance is squared Euclidean. Ties go to the lowest index, as ``argmin`` does.
    """
    dist_matrix = torch.cdist(vectors, codebook, p=2.0) ** 2
    return torch.argmin(dist_matrix, dim=1)


def _update_centroids(vectors, indices, codebook):
    """Return a new codebook where each non-empty cluster is replaced by the sign of its mean.

    Empty clusters keep their previous centroid. A zero mean component (a tie)
    becomes +1, so every updated entry is strictly +1 or -1.
    """
    num_centroids, vector_length = codebook.shape

    # Per-cluster sums of the assigned vectors, computed with one scatter_add.
    index_grid = indices.unsqueeze(1).expand(-1, vector_length)
    cluster_sums = torch.zeros_like(codebook).scatter_add_(0, index_grid, vectors)
    cluster_counts = torch.bincount(indices, minlength=num_centroids)
    non_empty = cluster_counts > 0

    # Start from the old centroids so empty clusters keep their value.
    new_codebook = codebook.clone()
    means = cluster_sums[non_empty] / cluster_counts[non_empty].unsqueeze(1).float()
    signed = torch.sign(means)
    signed[signed == 0] = 1.0
    new_codebook[non_empty] = signed
    return new_codebook


def optimize_codebook(new_binary, new_matrix, mask, refine_mean, new_alpha,
                      vector_length=8, num_centroids=256, max_iter=50):
    """Optimize a {-1, +1} codebook for a binary matrix with an EM (k-means-like) loop.

    The non-zero entries of ``new_binary`` are packed into vectors of length
    ``vector_length`` (see ``weight_to_vector``). Each iteration alternates two steps:
      - E-step: assign every vector to its nearest centroid (squared Euclidean).
      - M-step: replace each centroid by the elementwise sign of its cluster
        mean (majority vote; ties become +1).
    The loop stops when assignments stop changing or after ``max_iter`` iterations.

    If there are at most ``num_centroids`` distinct vectors, the codebook is just
    those distinct vectors. Assignment is then exact, and at most one M-step runs.

    Args:
        new_binary: The {-1, +1} binary matrix to quantize (zeros are ignored).
        new_matrix, mask, refine_mean, new_alpha: Parameters for ``compute_loss``.
        vector_length: Length of codebook vectors.
        num_centroids: Desired number of centroids (codebook size); must be >= 1
            when there is anything to quantize.
        max_iter: Maximum number of iterations. With ``max_iter <= 0``, vectors
            are still assigned to the initial codebook.

    Returns:
        A dictionary containing:
            'codebook': Float32 tensor ``[K, vector_length]``, where
                K = min(num_centroids, number of distinct vectors).
            'indices': Long tensor ``[num_vectors]`` of codebook assignments.
            'reconstructed_binary': Float32 tensor with ``new_binary``'s shape,
                rebuilt from the codebook.
            'loss': The loss returned by ``compute_loss``.

    Raises:
        ValueError: If ``num_centroids < 1`` and there are vectors to quantize.
    """
    device = new_binary.device
    dtype = torch.float32

    # 1. Vectorize input.
    vectors = weight_to_vector(new_binary, vector_length).float()
    num_vectors = vectors.shape[0]

    if num_vectors == 0:
        # The input had no non-zero entries, so there is nothing to quantize.
        print("Warning: Input matrix resulted in zero vectors for quantization.")
        codebook = torch.empty((0, vector_length), device=device, dtype=dtype)
        indices = torch.empty((0,), dtype=torch.long, device=device)
        # Use float32 to match the dtype returned on the normal path.
        reconstructed_binary = torch.zeros_like(new_binary, dtype=dtype)
        final_loss = compute_loss(reconstructed_binary, new_matrix, mask, refine_mean, new_alpha)
        return {
            'codebook': codebook,
            'indices': indices,
            'reconstructed_binary': reconstructed_binary,
            'loss': final_loss
        }

    if num_centroids < 1:
        raise ValueError(f"num_centroids must be >= 1, got {num_centroids}")

    # 2. Deduplicate vectors.
    unique_vectors, inverse_indices = torch.unique(vectors, dim=0, return_inverse=True)
    num_unique_vectors = unique_vectors.shape[0]

    # 3. Initialize the codebook and an initial assignment.
    if num_unique_vectors <= num_centroids:
        # Each distinct vector becomes its own centroid, so assignment is exact.
        codebook = unique_vectors.clone()
        indices = inverse_indices
        exact_assignment = True
    else:
        # NOTE(review): torch.unique sorts rows, so this picks the
        # lexicographically smallest vectors, which is a biased initialization.
        codebook = unique_vectors[torch.arange(num_centroids, device=device)].clone()
        # Assign before the loop so `indices` is always valid, even if max_iter <= 0.
        indices = _assign_to_nearest(vectors, codebook)
        exact_assignment = False

    # --- EM loop ---
    for i in range(max_iter):
        if i > 0:
            new_indices = _assign_to_nearest(vectors, codebook)
            if torch.equal(new_indices, indices):
                break  # Converged: assignments are unchanged.
            indices = new_indices

        codebook = _update_centroids(vectors, indices, codebook)

        if exact_assignment:
            # One M-step is enough when every centroid is a distinct input vector.
            break

    # 4. Rebuild the binary matrix from the final codebook and assignments.
    quantized_vectors = codebook[indices]
    reconstructed_binary = vector_to_weight(quantized_vectors.float(), new_binary)

    # 5. Compute the final loss.
    final_loss = compute_loss(reconstructed_binary, new_matrix, mask, refine_mean, new_alpha)

    return {
        'codebook': codebook,
        'indices': indices,
        'reconstructed_binary': reconstructed_binary,
        'loss': final_loss
    }
