from .utils import argmax_k, argmax_k_sorted
import numpy as np


def flip(W: np.ndarray, idx: "int | np.ndarray", e_positive: bool) -> np.ndarray:
    """
    Move the selected entries of ``W`` one step against the accumulated perturbation.

    When the accumulated perturbation is positive, 1 is subtracted from
    ``W[idx]`` (which decreases the accumulated perturbation); otherwise 1 is
    added (which increases a negative accumulated perturbation).

    Args:
        W: array whose entries are flipped. It is modified **in place**.
        idx: a single index or an index array selecting the entries to flip.
            With NumPy fancy indexing, a repeated index is only flipped once.
        e_positive: True when the accumulated perturbation is positive.

    Returns:
        ``W`` itself (the same object that was passed in).
    """
    # Explicit -=/+= (rather than adding a signed step) keeps unsigned/integer
    # dtypes working exactly as before.
    if e_positive:
        W[idx] -= 1
    else:
        W[idx] += 1
    return W


def SQuant_Flip_Algorithm(
    weights: np.ndarray, perturbations: np.ndarray, bits: int
) -> np.ndarray:
    """
    Apply the SQuant flip step (ICLR 2022) to a block of rounded weights.

    Instead of minimising each element's perturbation individually, SQuant
    minimises their sum. Let ``s = sum(perturbations)``. The
    ``k = round(|s|)`` elements (NumPy round-half-to-even) with the largest
    perturbation magnitude among those sharing the sign of ``s`` are flipped:
    1 is subtracted when ``s > 0``, and 1 is added otherwise. The flipped
    weights are then clipped to the symmetric signed range of ``bits`` bits.

    Args:
        weights: rounded weights. Must contain the same number of elements as
            ``perturbations``. They are reshaped to ``perturbations.shape``.
        perturbations: per-element perturbations associated with ``weights``.
        bits: bit width. Flipped weights are clipped to
            ``[-(2**(bits-1) - 1), 2**(bits-1) - 1]``.

    Returns:
        A new array with the shape of ``perturbations``. Neither input array
        is modified.

    Raises:
        ValueError: if ``weights`` and ``perturbations`` differ in size.
    """
    if np.size(weights) != np.size(perturbations):
        raise ValueError(
            f"weights has {np.size(weights)} elements but perturbations has "
            f"{np.size(perturbations)}"
        )
    quantization_range = (2 ** (bits - 1)) - 1
    original_shape = perturbations.shape
    # Copy: flip() mutates in place, and ravel() may return a view of the
    # caller's array.
    flat_weights = np.ravel(weights).copy()
    flat_perturbations = np.ravel(perturbations)
    accumulated_perturbations = np.sum(flat_perturbations)
    # Only elements whose perturbation shares the sign of the sum are flip
    # candidates. np.where builds a new array, so the caller's data is untouched.
    candidates = np.where(
        flat_perturbations * accumulated_perturbations < 0, 0, flat_perturbations
    )
    num_elements_to_flip = int(np.round(np.abs(accumulated_perturbations)))
    if num_elements_to_flip > 0:
        # argmax_k presumably returns the indices of the k largest values.
        indices_to_flip = argmax_k(array=np.abs(candidates), k=num_elements_to_flip)
        flat_weights = flip(
            W=flat_weights,
            idx=indices_to_flip,
            e_positive=(accumulated_perturbations > 0),
        )
        flat_weights = np.clip(flat_weights, -quantization_range, quantization_range)
    return np.reshape(flat_weights, original_shape)


def Perturbation_Update_Algorithm(perturbations: np.ndarray) -> np.ndarray:
    """
    Reduce a block's perturbations to the single candidate for channel-level flipping.

    Let ``s = sum(perturbations)`` and ``k = round(|s|)`` (NumPy
    round-half-to-even). Entries whose sign is opposite to ``s`` are ignored.
    If ``k == 0`` the result is all zeros. Otherwise exactly one entry is kept,
    with its original perturbation value:

    - if ``k > |s|`` (the sum was rounded up): the entry at rank ``k`` by
      magnitude;
    - otherwise: the entry at rank ``k + 1`` by magnitude.

    Args:
        perturbations: per-element perturbations of the block. Not modified.

    Returns:
        A new float64 array with the shape of ``perturbations``. At most one
        element is non-zero.
    """
    original_shape = perturbations.shape
    # Copy so that zeroing opposite-sign entries never touches the caller's
    # array (ravel/reshape may return a view).
    flat = np.ravel(perturbations).copy()
    accumulated_perturbations = np.sum(flat)
    flat[flat * accumulated_perturbations < 0] = 0
    num_elements_to_flip = int(np.round(np.abs(accumulated_perturbations)))
    updated = np.zeros(flat.shape)
    if num_elements_to_flip > 0:
        if num_elements_to_flip > np.abs(accumulated_perturbations):
            rank = num_elements_to_flip
        else:
            rank = num_elements_to_flip + 1
        # NOTE(review): assumes argmax_k_sorted returns the indices of the
        # `rank` largest values in descending order, so [-1] is the rank-th
        # largest. `rank` is also assumed not to exceed flat.size — confirm.
        i = argmax_k_sorted(array=np.abs(flat), k=rank)[-1]
        updated[i] = flat[i]
    return np.reshape(updated, original_shape)


def k_largest_index_argpartition_v2(a: np.ndarray, k: int) -> np.ndarray:
    """
    Return the N-d coordinates of the ``k`` largest entries of ``a``.

    Args:
        a: input array of any shape.
        k: number of entries to select. Values larger than ``a.size`` select
            every entry.

    Returns:
        An integer array of shape ``(k, a.ndim)``. Row ``j`` holds the
        coordinates of one selected entry. Rows are in no particular order
        (np.argpartition does not sort the selected elements).

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    k = min(k, a.size)
    if k == 0:
        # argpartition(kth=a.size) would be out of bounds, and [-0:] would
        # select everything, so handle the empty selection explicitly.
        return np.empty((0, a.ndim), dtype=np.intp)
    flat = a.ravel()
    idx = np.argpartition(flat, flat.size - k)[-k:]
    return np.column_stack(np.unravel_index(idx, a.shape))
