import torch
import numpy as np

from rfcutils2.utils import get_demod_soi
import utils.data_transforms as data_transforms 


def cut_to_sync_offset(x, offsets, ber_sync):
    """Trim each row of ``x`` so it starts on a ``ber_sync`` boundary of the source signal.

    Row ``i`` of ``x`` is taken to begin at sample ``offsets[i]`` of the original
    signal from the dataset. Each row is shifted forward by
    ``ber_sync - offsets[i] % ber_sync`` samples (a shift between 1 and
    ``ber_sync``). Its new first sample therefore sits at a multiple of
    ``ber_sync`` in the original signal. All rows are cut to the same length,
    ``x.shape[1] - ber_sync``, which is the longest length that fits for every
    possible shift.

    Note: the output length is itself a multiple of ``ber_sync`` only when
    ``x.shape[1]`` is.

    Args:
        x: Tensor of shape ``[B, s]`` (complex samples in this module's usage).
        offsets: Per-row starting positions, length ``B``. A tensor, NumPy array
            or sequence of integers is accepted; it is converted to an int64
            tensor on ``x.device``.
        ber_sync: Alignment granularity; must satisfy ``1 <= ber_sync <= s``.

    Returns:
        Tensor of shape ``[B, s - ber_sync]`` gathered from ``x``.

    Raises:
        ValueError: if ``ber_sync`` is outside ``[1, s]``.
    """
    num_samples = x.shape[1]
    if not 1 <= ber_sync <= num_samples:
        raise ValueError(
            f"ber_sync must be in [1, {num_samples}] for rows of length "
            f"{num_samples}, got {ber_sync}"
        )
    # Accept lists/arrays and offsets living on another device (e.g. CPU offsets
    # with a CUDA signal); the gather index must be int64 on x's device.
    offsets = torch.as_tensor(offsets, dtype=torch.long, device=x.device)

    new_len = num_samples - ber_sync
    # Shift lies in [1, ber_sync]. An already-aligned row still drops a full
    # ber_sync prefix, which keeps every row the same length so that a single
    # gather can be used.
    start_pos = ber_sync - torch.remainder(offsets, ber_sync)
    slice_idx = start_pos[:, None] + torch.arange(new_len, device=x.device)
    return torch.take_along_dim(x, slice_idx, dim=1)


def compute_ber(pred, target, offsets, ber_sync, soi_type):
    """Compute the bit error rate between predicted and reference SOI signals.

    Both inputs are converted with ``data_transforms.stacked_to_complex``
    (presumably real/imag parts stacked along some axis -- confirm in
    utils.data_transforms). They are optionally aligned with
    ``cut_to_sync_offset``, demodulated with the demodulator for ``soi_type``,
    and the demodulated bits are compared.

    Args:
        pred: Predicted signal in the representation expected by
            ``stacked_to_complex``.
        target: Ground-truth signal in the same representation as ``pred``.
        offsets: Per-example starting position of each window in the original
            signal; only used when ``ber_sync != 1``.
        ber_sync: Alignment granularity forwarded to ``cut_to_sync_offset``.
        soi_type: Identifier passed to ``get_demod_soi`` to select the demodulator.

    Returns:
        Fraction of demodulated bits of ``pred`` that differ from those of
        ``target`` (mean over every bit in the batch).
    """
    pred = data_transforms.stacked_to_complex(pred)
    target = data_transforms.stacked_to_complex(target)

    # With ber_sync == 1 every position is already aligned; skipping the cut
    # avoids cut_to_sync_offset's unconditional shift, which would drop a sample.
    if ber_sync != 1:
        pred = cut_to_sync_offset(pred, offsets, ber_sync)
        target = cut_to_sync_offset(target, offsets, ber_sync)

    demod_soi = get_demod_soi(soi_type)

    # detach() so this also works on tensors still attached to the autograd
    # graph -- Tensor.numpy() raises for tensors that require grad.
    bit_est, _ = demod_soi(pred.detach().cpu().numpy())
    bit_gt, _ = demod_soi(target.detach().cpu().numpy())
    # The demodulator's bit outputs are themselves converted with .numpy();
    # presumably they are framework tensors -- confirm against rfcutils2.
    return np.mean(bit_est.numpy() != bit_gt.numpy())
