import torch
import numpy as np


def sparse_code_loss(signal, reconstruction, lb, sparse_codes):
    """
    Sparse-coding objective: reconstruction error plus an L1 sparsity penalty.

    Both norms are taken over the last (feature) axis and averaged over all
    leading axes, so inputs of shape (..., F) / (..., H) are accepted,
    e.g. (T,B,F) and (T,B,H).

    Note: the reconstruction term is the mean Euclidean (L2) norm of the
    per-vector residual, not a squared-error MSE.

    Args:
        signal: Original signal tensor (..., F), e.g. (T,B,F)
        reconstruction: Reconstructed signal tensor, same shape as ``signal``
        lb: Lambda weight for the L1 regularization term
        sparse_codes: Sparse code tensor (..., H), e.g. (T,B,H)

    Returns:
        Scalar tensor equal to
        ``mean ||signal - reconstruction||_2 + lb * mean ||sparse_codes||_1``.
    """
    # Reduce over the last axis (instead of a hard-coded dim=2) so both terms
    # use the same axis and inputs without a leading T axis also work.
    recon_loss = torch.norm(signal - reconstruction, p=2, dim=-1).mean()

    # L1 sparsity penalty on the codes, per vector, averaged.
    l1_loss = torch.norm(sparse_codes, p=1, dim=-1).mean()

    return lb * l1_loss + recon_loss


def sparsity_metric(sparse_codes):
    """
    Fraction of entries in ``sparse_codes`` that are exactly zero.

    Equivalent to 1 - (L0 norm / number of elements): 1.0 means every entry
    is zero (completely sparse), 0.0 means no entry is zero.

    Args:
        sparse_codes: Sparse code tensor of any shape

    Returns:
        A float32 scalar tensor in [0.0, 1.0] (NaN for an empty tensor).
    """
    is_zero = sparse_codes.eq(0)
    zero_count = is_zero.sum().float()
    return zero_count / is_zero.numel()


def psnr(ref, reconstruction, normalized=False):
    """
    Calculate Peak Signal-to-Noise Ratio (PSNR) between reference and reconstructed images.

    The value range is selected by ``normalized``; it is NOT auto-detected.

    - ``normalized=True``: inputs are clipped to [-1, 1], and the peak is 2.0
      (the width of that range).
    - ``normalized=False``: inputs are rounded and clipped to [0, 255], and
      the peak is 255.

    Integer tensors (e.g. uint8 images) are promoted to float32 before any
    arithmetic.

    Args:
        ref: Reference image tensor
        reconstruction: Reconstructed image tensor (broadcastable with ``ref``)
        normalized: If True treat inputs as [-1, 1], otherwise as [0, 255]

    Returns:
        Scalar tensor with the PSNR in dB. When the clipped images are
        identical, it is 100.0, with the same dtype/device as the regular
        result.
    """
    # Integer inputs would wrap around on subtraction (uint8) and cannot be
    # averaged by torch.mean, so promote them; float inputs keep their dtype.
    if not torch.is_floating_point(ref):
        ref = ref.float()
    if not torch.is_floating_point(reconstruction):
        reconstruction = reconstruction.float()

    if normalized:
        max_pixel_val = 2.0  # Range is 2.0 for [-1,1]
        reconstruction = torch.clamp(reconstruction, -1.0, 1.0)
        ref = torch.clamp(ref, -1.0, 1.0)
    else:
        max_pixel_val = 255.0
        # Snap to integer intensities, then clip to the valid [0,255] range.
        reconstruction = torch.clamp(torch.round(reconstruction), 0, 255)
        ref = torch.clamp(torch.round(ref), 0, 255)

    mse = torch.mean((ref - reconstruction) ** 2)
    if mse == 0:
        # Sentinel for a perfect match (PSNR is infinite); built from ``mse``
        # so dtype and device agree with the regular return path.
        return torch.full_like(mse, 100.0)
    return 20 * torch.log10(max_pixel_val / torch.sqrt(mse))


def psnr_numpy(ref, reconstructed):
    """
    Calculate Peak Signal-to-Noise Ratio (PSNR) between reference and reconstructed images.

    The value range is inferred from ``ref``:

    - If ``max(|ref|) <= 1`` the inputs are treated as normalized [-1, 1].
      They are clipped to that range, and the peak is 2.0.
    - Otherwise they are treated as [0, 255]. They are rounded and clipped,
      and the peak is 255.

    Caveat: an unnormalized image whose values are all 0 or 1 is classified
    as normalized.

    Args:
        ref: Reference image (numpy array or array-like)
        reconstructed: Reconstructed image (numpy array or array-like)

    Returns:
        PSNR value in dB; 100.0 when the clipped images are identical.
    """
    # Compute in float64: with uint8 inputs the subtraction and squaring below
    # would wrap around modulo 256 and silently yield a wrong MSE.
    ref = np.asarray(ref, dtype=np.float64)
    reconstructed = np.asarray(reconstructed, dtype=np.float64)

    # Detect if input is normalized [-1,1] or unnormalized [0,255]
    if np.max(np.abs(ref)) <= 1.0:
        max_pixel_val = 2.0  # Range is 2.0 for [-1,1]
        reconstructed = np.clip(reconstructed, -1.0, 1.0)
        ref = np.clip(ref, -1.0, 1.0)
    else:
        max_pixel_val = 255.0
        # Snap to integer intensities, then clip to the valid [0,255] range.
        reconstructed = np.clip(np.round(reconstructed), 0, 255)
        ref = np.clip(np.round(ref), 0, 255)

    mse = np.mean((ref - reconstructed) ** 2)
    if mse == 0:
        # Sentinel for a perfect match (PSNR is infinite).
        return 100.0
    return 20 * np.log10(max_pixel_val / np.sqrt(mse))


def mse(ref, reconstructed):
    """
    Mean squared error between two tensors, averaged over all elements.

    Args:
        ref: Reference tensor (passed as the ``input`` argument)
        reconstructed: Reconstructed tensor (passed as the ``target`` argument)

    Returns:
        Scalar tensor equal to ``mean((ref - reconstructed) ** 2)``.
    """
    return torch.nn.functional.mse_loss(ref, reconstructed, reduction="mean")
