import numpy as np


def _flatten(t):
    return t.reshape(-1)


def compute_metrics(preds: np.ndarray,
                    targets: np.ndarray,
                    metrics=('MSE', 'RMSE', 'MAE', 'SI')):
    """Compute regression error metrics between predictions and targets.

    Both inputs are flattened, so arrays of different shapes are compared
    element-wise as long as they hold the same number of elements.

    Args:
        preds: Predicted values.
        targets: Ground-truth values; must contain as many elements as
            ``preds``.
        metrics: Names of the metrics to compute, any subset of ``'MSE'``,
            ``'RMSE'``, ``'MAE'`` and ``'SI'`` (scatter index: RMSE divided
            by the mean target). A single name may be given as a plain
            string. Unrecognised names are ignored.

    Returns:
        dict mapping each requested metric name to a Python ``float``,
        inserted in the order MSE, RMSE, MAE, SI.

    Raises:
        ValueError: if the inputs are empty or differ in element count.
    """
    # A bare string would otherwise be matched by substring
    # ('MSE' in 'RMSE' is True), silently adding extra metrics.
    if isinstance(metrics, str):
        metrics = (metrics,)

    # Work in float64: prevents wrap-around when subtracting unsigned/integer
    # arrays and keeps the means accurate for low-precision float inputs.
    preds = np.asarray(preds, dtype=np.float64).reshape(-1)
    targets = np.asarray(targets, dtype=np.float64).reshape(-1)
    # Without this check a size-1 input would silently broadcast.
    if preds.size != targets.size:
        raise ValueError(
            "preds and targets must have the same number of elements, "
            f"got {preds.size} and {targets.size}")
    if preds.size == 0:
        raise ValueError("preds and targets must not be empty")

    diff = preds - targets
    mse = np.mean(diff ** 2)
    rmse = np.sqrt(mse)

    out = {}
    if 'MSE' in metrics:
        out['MSE'] = float(mse)
    if 'RMSE' in metrics:
        out['RMSE'] = float(rmse)
    if 'MAE' in metrics:
        out['MAE'] = float(np.mean(np.abs(diff)))
    if 'SI' in metrics:
        # Small epsilon avoids dividing by exactly zero when the targets
        # average to 0.
        out['SI'] = float(rmse / (np.mean(targets) + 1e-8))
    return out


def compute_global_mean(train_loader):
    """Return the mean of every target value yielded by *train_loader*.

    Each item of the loader is unpacked as a 3-tuple; only the middle
    element (the targets) is used. ``targets`` must support ``.sum().item()``
    and ``.numel()``, as torch tensors do. Presumably the targets are shaped
    (B, S, 1); confirm against the dataset. The result is the mean over all
    elements, not a mean of per-batch means.

    Raises:
        ValueError: if the loader yields no target elements at all.
    """
    total = 0.0
    count = 0
    for _, targets, _ in train_loader:          # targets: (B,S,1) per original note
        total += targets.sum().item()
        count += targets.numel()
    if count == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError.
        raise ValueError(
            "train_loader yielded no target values; cannot compute a mean")
    return total / count
