import logging

import torch
import numpy as np

from .admm_core import ADMM

# Name of the optional sub-object in the ADMM configuration that carries the
# "N" (group size) and "M" (entries pruned per group) settings.
JSON_CONFIG_KEY: str = "nxm_config"

# Fallback pattern used when the configuration has no JSON_CONFIG_KEY entry:
# groups of 4 consecutive weights, 2 of which are zeroed by the projection.
DEFAULT_N: int = 4
DEFAULT_M: int = 2

# Per-module logger. NOTE(review): its level is pinned to INFO here,
# independently of whatever the importing application configures.
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

class NxMADMM(ADMM):
    """
    ADMM optimizer for N:M semi-structured sparsity.

    The input dimension of every projected weight matrix is split into
    consecutive groups of ``N`` weights. Within each group the ``N - M``
    entries of largest magnitude (of weight + dual variable) are kept and the
    remaining ``M`` are zeroed. With the defaults (N=4, M=2) this keeps 2 of
    every 4 weights.

    NOTE: ``N`` is the group size and ``M`` is the number of *pruned*
    entries per group.
    """

    def __init__(
            self,
            model: torch.nn.Module,
            filename: str,
            rho=0.001
        ):
        """
        Args:
            model (torch.nn.Module): model to optimize; forwarded to ``ADMM``.
            filename (str): configuration file forwarded to ``ADMM``. If the
                parsed configuration (``self._raw_dict``, provided by the
                base class) contains a ``JSON_CONFIG_KEY`` entry, its ``"N"``
                and ``"M"`` fields override ``DEFAULT_N`` / ``DEFAULT_M``.
            rho (float): forwarded unchanged to ``ADMM``.

        Raises:
            KeyError: the ``JSON_CONFIG_KEY`` entry is missing ``"N"`` or ``"M"``.
            ValueError: N/M are not integers, ``N <= 0``, ``M < 0`` or ``M > N``.
        """
        logger.info("Initializing NxM ADMM instance")

        super().__init__(
            model,
            filename,
            rho=rho)

        if JSON_CONFIG_KEY in self._raw_dict:
            try:
                self._N = self._raw_dict[JSON_CONFIG_KEY]["N"]
                self._M = self._raw_dict[JSON_CONFIG_KEY]["M"]
            except KeyError:
                logger.error("JSON configuration key found, but M/N improperly formatted")
                raise
        else:
            self._N = DEFAULT_N
            self._M = DEFAULT_M

        # bool is a subclass of int in Python; reject it explicitly.
        if not all(isinstance(v, int) and not isinstance(v, bool)
                   for v in (self._N, self._M)):
            msg = "N and M must be integers, got N={!r}, M={!r}".format(self._N, self._M)
            logger.error(msg)
            raise ValueError(msg)

        if self._N <= 0 or self._M < 0:
            msg = "N must be positive and M non-negative, got N={}, M={}".format(self._N, self._M)
            logger.error(msg)
            raise ValueError(msg)

        if self._M > self._N:
            logger.error("M should not be larger than N")
            raise ValueError("M should not be larger than N")

        logger.info("Using N as %s and M as %s", self._N, self._M)

    def project(self, weight: torch.Tensor, name: str):
        """
        N:M semi-structured projection of ``weight + U[name]``.

        Each row's input dimension is split into consecutive groups of ``N``
        entries. Exactly ``N - M`` entries of largest magnitude are kept per
        group (ties are broken arbitrarily by ``torch.topk``); the rest are
        zeroed. All work happens on the device of ``weight``.

        Args:
            weight (torch.Tensor): weight matrix to project (together with the
                dual variable ``self._u[name]``), shape
                [out_neurons, in_neurons]; ``in_neurons`` must be divisible
                by N.
            name (str): name of the parameter; key into ``self._u``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                - float32 0/1 mask of the kept positions, used for retraining;
                - the projected values, i.e. ``(weight + U) * mask``.
            Both have the shape and device of ``weight``.

        Raises:
            ValueError: ``weight`` is not 2-D, or ``in_neurons`` is not
                divisible by N.
        """
        # Projection is of W[k+1] + U[k]. Detach so autograd does not record
        # the projection, and align the dual variable with the weight's device.
        weight = weight.detach()
        dual_variable = self._u[name].detach().to(weight.device)
        raw_values = weight + dual_variable

        if raw_values.dim() != 2:
            raise ValueError(
                "Expected a 2-D weight for {}, got shape {}".format(
                    name, tuple(raw_values.shape)))
        out_neurons, in_neurons = raw_values.shape
        if in_neurons % self._N != 0:
            raise ValueError(
                "Parameter {} has {} input features, which is not divisible "
                "by the group size N={}".format(name, in_neurons, self._N))

        # Top-k selection keeps exactly N - M entries per group. A strict
        # "> percentile" threshold would keep fewer when magnitudes tie, and
        # would still drop one entry per group when M == 0.
        keep = self._N - self._M
        groups = raw_values.reshape(out_neurons, in_neurons // self._N, self._N)
        group_mask = torch.zeros_like(groups, dtype=torch.float32)
        if keep > 0:
            top_idx = groups.abs().topk(keep, dim=-1).indices
            group_mask.scatter_(-1, top_idx, 1.0)
        weight_mask = group_mask.reshape(out_neurons, in_neurons)

        weight_values = raw_values * weight_mask

        return weight_mask, weight_values
