from .utils import basic_Q
from .flip import SQuant_Flip_Algorithm, Perturbation_Update_Algorithm
import numpy as np


def process_conv_weight(W: np.ndarray, scale: np.ndarray, bits: int) -> np.ndarray:
    """
    Run the SQuant procedure over a convolution weight tensor, channel by channel.

    For each output channel ``co`` this:

    1. quantizes every element with ``basic_Q`` using ``scale[co]`` and stores
       the perturbation ``Q(w) - w * scale[co]``;
    2. for every input channel ``ci``, calls ``SQuant_Flip_Algorithm`` on the
       (K1, K2) kernel slice. It also passes that slice's perturbations to
       ``Perturbation_Update_Algorithm``;
    3. calls ``SQuant_Flip_Algorithm`` once more on the whole (K1, K2, Ci)
       slab of the channel, using the results of step 2.

    Args:
        W: weight tensor, unpacked as (K1, K2, Ci, Co); it must be 4-D.
        scale: per-output-channel scale, indexed by ``co``.
        bits: bit width forwarded to both ``SQuant_Flip_Algorithm`` calls.

    Returns:
        A new float64 array of shape (K1, K2, Ci, Co) holding the
        channel-level flip results.
    """
    K1, K2, Ci, Co = W.shape
    dims = (K1, K2, Ci, Co)

    # Work buffers (all float64, independent of W's dtype).
    out = np.zeros(dims)
    q_vals = np.zeros(dims)       # element-wise quantized weights
    kernel_flipped = np.zeros(dims)  # output of the kernel-level flip
    q_err = np.zeros(dims)        # per-element quantization perturbation
    updated_err = np.zeros(dims)  # perturbations after the update step

    for co in range(Co):
        for ci in range(Ci):
            for i, j in np.ndindex(K1, K2):
                pos = (i, j, ci, co)
                q_vals[pos] = basic_Q(W=W[pos], s=scale[co])
                # Read the value back from the float64 buffer so the
                # perturbation is computed from what was actually stored.
                q_err[pos] = q_vals[pos] - W[pos] * scale[co]

            # Kernel-level flip for this (ci, co) pair.
            kernel_flipped[:, :, ci, co] = SQuant_Flip_Algorithm(
                weights=q_vals[:, :, ci, co],
                perturbations=q_err[:, :, ci, co],
                bits=bits,
            )
            updated_err[:, :, ci, co] = Perturbation_Update_Algorithm(
                perturbations=q_err[:, :, ci, co]
            )

        # Channel-level flip across all input channels of this output channel.
        out[:, :, :, co] = SQuant_Flip_Algorithm(
            weights=kernel_flipped[:, :, :, co],
            perturbations=updated_err[:, :, :, co],
            bits=bits,
        )

    return out
