import logging

import torch
import numpy as np

from .admm_core import ADMM

# Per-module logger.
# NOTE(review): forcing INFO here, at import time, replaces any level the
# application already set on this logger; level configuration is normally
# left to the application rather than done inside a library module.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)

class UnstructuredADMM(ADMM):
    """
    ADMM optimizer for unstructured (element-wise magnitude) sparsity.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        filename: str,
        rho: float = 0.001,
    ):
        """
        Args:
            model (torch.nn.Module): forwarded unchanged to ``ADMM.__init__``.
            filename (str): forwarded unchanged to ``ADMM.__init__``.
            rho (float): forwarded to ``ADMM.__init__`` as ``rho``
                (presumably the ADMM penalty parameter; see ``ADMM``).
        """
        logger.info("Initializing Unstructured ADMM instance")

        super().__init__(
            model,
            filename,
            rho=rho)

    def project(self, weight: torch.Tensor, name: str):
        """
        Unstructured weight pruning of ``weight + u`` (u = ``self._u[name]``).

        Elements of ``weight + u`` whose magnitude is at or below the
        ``prune_ratio`` percentile of all magnitudes are zeroed; the rest are
        kept unchanged. Works on tensors of any shape (e.g. a linear layer's
        ``[out_neurons, in_neurons]``).

        Args:
            weight (torch.Tensor): weight tensor; it is not modified.
            name (str): name of parameter; selects the prune ratio (via
                ``self.get_prune_ratio``) and the dual variable ``self._u[name]``.

        Returns:
            1 (torch.Tensor): float32 mask, 1.0 for kept and 0.0 for pruned
                weights, used for retraining.
            2 (torch.Tensor): real-valued tensor ``(weight + u) * mask``.
            Both tensors are placed on ``weight``'s device.
        """
        prune_ratio = self.get_prune_ratio(name)
        # Return results where the weight lives instead of a hard-coded
        # ``.cuda()``, which fails on CPU-only hosts and ignores multi-GPU
        # placement.
        device = weight.device

        weight_np = weight.detach().cpu().numpy()
        dual_variable = self._u[name].detach().cpu().numpy()

        shifted = weight_np + dual_variable
        z_temp = np.absolute(shifted)

        if prune_ratio <= 0:
            # np.percentile(x, 0) == min(x); the strict ">" below would then
            # prune every minimum-magnitude element although a ratio of 0
            # asks for no pruning at all.
            z_mask = np.ones_like(z_temp, dtype=np.float32)
        else:
            percentile = np.percentile(z_temp, prune_ratio * 100)
            z_mask = (z_temp > percentile).astype(np.float32)
        z = shifted * z_mask

        z_mask_tensor = torch.from_numpy(z_mask).to(device)
        z_tensor = torch.from_numpy(z).to(device)

        return z_mask_tensor, z_tensor
