import torch

def fidelity(qs: torch.Tensor, ref_state: torch.Tensor) -> torch.Tensor:
    """Return the fidelity of ``qs`` with respect to the reference state ``ref_state``.

    * If ``qs`` is 1-d (a statevector), the result is ``|<qs|ref_state>|^2``;
      ``torch.vdot`` conjugates its first argument.
    * If ``qs`` is 2-d (a density matrix), the result is the real part of
      ``<ref_state| qs |ref_state>``.

    ``ref_state`` is used as a 1-d vector in both branches. Presumably it is
    a normalized pure state; this is not checked here.

    Raises:
        ValueError: if ``qs`` is neither 1-d nor 2-d.
    """
    rank = qs.dim()
    if rank not in (1, 2):
        raise ValueError('Quantum State must be a statevector (1-d) or a density matrix (2-d)')

    if rank == 1:
        overlap = torch.vdot(qs, ref_state)
        return overlap.abs() ** 2

    # Sandwich the density matrix between the reference bra and ket.
    sandwich = torch.einsum('i,ij,j->', ref_state.conj(), qs, ref_state)
    return sandwich.real

def purity(rho: torch.Tensor) -> torch.Tensor:
    """Return the purity ``Tr(rho^2)`` of a density matrix.

    The value is computed as ``sum_ij |rho_ij|^2`` (the squared Frobenius
    norm). This equals ``Tr(rho^2)`` when ``rho`` is Hermitian, as density
    matrices are; Hermiticity is not checked here.

    Args:
        rho: a 2-d tensor, presumably a (d, d) density matrix.

    Returns:
        0-d real tensor.

    Raises:
        ValueError: if ``rho`` is not 2-d.
    """
    # An explicit raise, not ``assert``: asserts are stripped under ``python -O``,
    # which would let a wrong-rank tensor silently produce a meaningless sum.
    if rho.dim() != 2:
        raise ValueError('`rho` must be a density matrix (2-d)')
    return torch.sum(torch.abs(rho) ** 2)

def normalized_mean_entropy(probs: torch.Tensor, *, epsilon: float = 1e-12) -> torch.Tensor:
    """Return the mean Shannon entropy of ``probs``, normalized by ``log(G)``.

    Each slice along the last dimension is treated as a distribution over
    ``G = probs.shape[-1]`` outcomes. Its natural-log entropy is divided by
    ``log(G)``, the largest entropy ``G`` outcomes can have. The normalized
    entropies are then averaged over all leading dimensions.

    Args:
        probs: probabilities; the last dimension indexes outcomes. Rows are
            presumably normalized to sum to 1, but this is not checked here.
        epsilon: lower clamp applied before the log so that zero
            probabilities do not produce ``log(0)``.

    Returns:
        0-d tensor. With a single outcome (``G == 1``) the distribution
        carries no uncertainty, so the result is 0. Without this case,
        ``entropy / log(1)`` would give NaN.
    """
    num_outcomes = probs.shape[-1]
    clamped = torch.clamp(probs, min=epsilon)  # avoid log(0)
    entropies = -torch.sum(clamped * torch.log(clamped), dim=-1)

    if num_outcomes < 2:
        # log(1) == 0, so normalizing would divide by zero. Multiplying by 0
        # keeps dtype, device and the autograd graph intact.
        return (entropies * 0).mean()

    max_entropy = torch.log(torch.tensor(num_outcomes, dtype=probs.dtype, device=probs.device))
    return (entropies / max_entropy).mean()

def anglePenalty(thetas: torch.Tensor, min_angle: float = -torch.pi, max_angle: float = torch.pi) -> torch.Tensor:
    """Return a quadratic penalty for entries of ``thetas`` outside ``[min_angle, max_angle]``.

    Each entry above ``max_angle`` contributes ``(theta - max_angle)**2``.
    Each entry below ``min_angle`` contributes ``(min_angle - theta)**2``.
    Entries inside the range contribute 0. The contributions are summed into
    a 0-d tensor.
    """
    overshoot = torch.relu(thetas - max_angle)   # > 0 only above the upper bound
    undershoot = torch.relu(min_angle - thetas)  # > 0 only below the lower bound
    return torch.sum(overshoot.square() + undershoot.square())
