
import torch


def _decode_fixed_point(bits, int_powers, frac_powers):
    """Return the value encoded by MSB-first fixed-point ``bits`` (last axis)."""
    n_int = int_powers.numel()
    int_part = torch.sum(bits[..., :n_int] * int_powers, dim=-1)
    frac_part = torch.sum(bits[..., n_int:] * frac_powers, dim=-1)
    return int_part + frac_part


def convert(
        pred_x_acti: torch.Tensor,
        pred_y_acti: torch.Tensor,
        n_bit_integer: int,
        n_bit_fractional: int,
        narrow_ratio: float = 0.25
    ):
    """Decode per-bit activations into (x, y) values plus a confidence score.

    The last axis of each input holds ``n_bit_integer`` integer bits followed
    by ``n_bit_fractional`` fractional bits, most significant bit first. An
    activation ``>= 0.5`` is read as bit 1, anything lower as bit 0.

    Args:
        pred_x_acti: Activations for the x value, shape
            ``(..., n_bit_integer + n_bit_fractional)``. Presumably sigmoid
            outputs in ``[0, 1]`` -- confirm against the caller.
        pred_y_acti: Activations for the y value, same layout as
            ``pred_x_acti``.
        n_bit_integer: Number of leading bits weighted ``2**(n-1) ... 2**0``.
        n_bit_fractional: Number of trailing bits weighted ``2**-1 ... 2**-n``.
        narrow_ratio: The decoded values are divided by this factor
            (presumably undoing a scaling applied when the targets were
            encoded -- confirm against the encoder).

    Returns:
        A ``(pred, prob)`` pair. ``pred`` has shape ``(..., 2)`` and holds the
        decoded x and y values. ``prob`` has shape ``(..., 1)`` and is the
        mean over all bits of ``sqrt(|2 * activation - 1|)``, averaged over x
        and y: 0 when every activation sits on the 0.5 threshold, 1 when
        every activation is exactly 0 or 1.

    Raises:
        ValueError: If a bit count is negative, or if the last axis of an
            input does not hold exactly ``n_bit_integer + n_bit_fractional``
            entries.
    """
    if n_bit_integer < 0 or n_bit_fractional < 0:
        raise ValueError(
            "n_bit_integer and n_bit_fractional must be non-negative, got "
            f"{n_bit_integer} and {n_bit_fractional}"
        )

    # Validate the bit axis explicitly: a slice of length 1 would otherwise
    # broadcast against the power vectors and silently decode a wrong value.
    n_bit_total = n_bit_integer + n_bit_fractional
    for name, acti in (("pred_x_acti", pred_x_acti), ("pred_y_acti", pred_y_acti)):
        if acti.dim() < 1 or acti.shape[-1] != n_bit_total:
            raise ValueError(
                f"{name} must have {n_bit_total} entries on its last axis "
                f"(n_bit_integer + n_bit_fractional), got shape {tuple(acti.shape)}"
            )

    device = pred_x_acti.device

    pred_x_acti = pred_x_acti.float()
    pred_y_acti = pred_y_acti.float()

    # Hard-threshold the activations into bits.
    pred_x_bin = (pred_x_acti >= 0.5).float()
    pred_y_bin = (pred_y_acti >= 0.5).float()

    # Positional weights: 2^(n_int-1) ... 2^0 followed by 2^-1 ... 2^-n_frac.
    int_powers = 2 ** torch.arange(
        n_bit_integer - 1, -1, -1, dtype=torch.float32, device=device
    )
    frac_powers = 2 ** -torch.arange(
        1, n_bit_fractional + 1, dtype=torch.float32, device=device
    )

    pred_h_x = _decode_fixed_point(pred_x_bin, int_powers, frac_powers)
    pred_h_y = _decode_fixed_point(pred_y_bin, int_powers, frac_powers)

    pred = torch.stack([pred_h_x, pred_h_y], dim=-1) / narrow_ratio

    # Per-bit confidence grows with the distance of the activation from the
    # 0.5 decision threshold; the square root lifts mid-range confidences.
    prob_x = torch.abs(2 * pred_x_acti - 1.0).pow(0.5).mean(dim=-1)
    prob_y = torch.abs(2 * pred_y_acti - 1.0).pow(0.5).mean(dim=-1)
    prob = ((prob_x + prob_y) * 0.5).unsqueeze(-1)

    return pred, prob


