import numpy as np
import torch as tc

def l2_err(true, pred, verbose=True, prefix="", eps=1e-30) -> None | float:
    assert true.shape == pred.shape, "Shape mismatch between true and pred."
    norm_true = np.linalg.norm(true)
    norm_true = max(norm_true, eps) # avoid zero-devision
    err = np.linalg.norm(true - pred) / norm_true
    if verbose: print(f"-> {prefix}{err:.2e}")
    else: return err.item()

def l2_err_TC(true: tc.Tensor, pred: tc.Tensor, verbose=True, prefix="L2_REL    ", eps=1e-30):
    assert true.shape == pred.shape, "Shape mismatch between true and pred."
    norm_true = tc.norm(true)
    norm_true = tc.maximum(norm_true, tc.tensor(eps, dtype=true.dtype, device=true.device))
    err = tc.norm(true - pred) / norm_true
    if verbose: print(f"-> {prefix}{err:.2e}")
    else: return err.item()

def mse(true, pred, verbose=True, prefix="") -> None | float:
    assert true.shape == pred.shape, "Shape mismatch between true and pred."
    err =  np.mean(np.square(true - pred))
    if verbose: print(f"-> {prefix}{err:.2e}")
    else: return err.item()

def mse_TC(true: tc.Tensor, pred: tc.Tensor, verbose=True, prefix="MSE      ") -> None | float:
    assert true.shape == pred.shape, "Shape mismatch between true and pred."
    err = (true - pred).pow(2).mean()
    if verbose: print(f"-> {prefix}{err:.2e}")
    else: return err.item()
