from numpy import mean
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


@torch.no_grad()
def part_mean(tensor, op="-"):
    """Return the mean over the last dimension as a column vector.

    The input is first multiplied by its own ``tensor != 0`` indicator.
    That leaves the values as they were (zeros stay zero, everything else
    is multiplied by one), so zero entries still count in the denominator
    of the mean.

    Args:
        tensor: Input tensor; the reduction runs over its last dimension.
        op: Unused by this function; kept in the signature for callers.

    Returns:
        The per-slice means reshaped to ``(-1, 1)``.
    """
    keep = tensor != 0
    row_means = (tensor * keep).mean(-1)
    return row_means.view(-1, 1)


@torch.no_grad()
def high_order_residual(x, mask, order=2):
    """Approximate the masked entries of ``x`` by a sum of binary terms.

    Each pass takes what is still unexplained (``x * mask`` minus the running
    approximation), and fits a term of the form
    ``mean + scale * sign(residual - mean)``. ``mean`` and ``scale`` are
    computed along dim 1 over the positions selected by ``mask`` only; the
    other positions are set to NaN and skipped via ``torch.nanmean``. Each
    term is multiplied by ``mask`` before it is added to the approximation.

    Args:
        x: Tensor to approximate. Statistics are reduced over dim 1 and
            broadcast back with ``[:, None]`` (assumes a 2-D input; confirm
            against callers).
        mask: Boolean tensor used as the ``torch.where`` condition and as a
            multiplicative mask on ``x`` and on every fitted term.
        order: Number of residual passes to accumulate.

    Returns:
        Tensor shaped like ``x`` holding the accumulated approximation.
    """

    def _nan_to_zero(stat):
        # Rows with nothing selected reduce to NaN; use 0 for them instead.
        return torch.where(torch.isnan(stat), torch.zeros_like(stat), stat)

    approx = torch.zeros_like(x)
    target = x.clone() * mask
    nan_fill = torch.tensor(float("nan"))

    for _ in range(order):
        # NaN marks the positions that must not influence the statistics.
        leftover = torch.where(mask, target - approx, nan_fill)

        row_mean = _nan_to_zero(torch.nanmean(leftover, dim=1))
        leftover -= row_mean[:, None]

        # Mean absolute deviation of the centred residual is the magnitude.
        row_scale = _nan_to_zero(torch.nanmean(torch.abs(leftover), dim=1))

        term = torch.sign(leftover)
        term *= row_scale[:, None]
        term += row_mean[:, None]
        approx = approx + term * mask

    return approx


@torch.no_grad()
def normal_quantize(x, scale, zero, maxq):
    """Round ``x`` onto a uniform grid and map it back to real values.

    Args:
        x: Values to quantize.
        scale: Step size of the grid.
        zero: Offset added to the rounded value before clamping.
        maxq: Largest allowed integer level; levels are clamped to
            ``[0, maxq]``.

    Returns:
        The de-quantized values ``scale * (level - zero)``.
    """
    levels = torch.round(x / scale) + zero
    levels = torch.clamp(levels, 0, maxq)
    return scale * (levels - zero)


class Binarization(nn.Module):
    """Weight quantizer that dispatches on a method name.

    Supported ``method`` values are ``"xnor"``, ``"braq"``, ``"sign"``,
    ``"rtn"``, ``"2bit"``, ``"3bit"``, ``"4bit"``, ``"prune"`` and
    ``"ternary"``.

    NOTE: ``self.scale`` is not created by ``__init__`` and ``self.mean`` is
    initialised to the integer ``0``. The ``"xnor"``, ``"sign"`` and ``"rtn"``
    methods index these by group, so they must be assigned indexable values
    by the caller before ``quantize`` is used with those methods.
    """

    def __init__(self, weight, method="2bit", groupsize=-1):
        """Record the grouping configuration for ``weight``.

        Args:
            weight: 2-D tensor; only its shape is read. The second dimension
                is the length that gets split into groups.
            method: Name of the quantization method used by ``quantize``.
            groupsize: Columns per group; ``-1`` means one group spanning
                the whole second dimension.
        """
        super().__init__()
        # Unpacking two values also rejects anything that is not 2-D.
        _, ic = weight.shape
        if groupsize == -1:
            groupsize = ic
        self.groupsize = groupsize
        self.n_groups = math.ceil(ic / groupsize)
        self.method = method
        self.mean = 0

    @staticmethod
    def _minmax_quantize(w, bits):
        """Quantize ``w`` to ``2**bits`` levels using a per-row min/max grid."""
        dev = w.device
        maxq = torch.tensor(2**bits - 1, device=dev)

        # One row per leading index; the grid is fitted per row.
        rows = w.flatten(1)
        zeros = torch.zeros(rows.shape[0], device=dev)
        # The range is widened, if necessary, so that it always contains 0.
        xmin = torch.minimum(rows.min(1)[0], zeros)
        xmax = torch.maximum(rows.max(1)[0], zeros)

        # An all-zero row would give a zero step size; use [-1, 1] instead.
        degenerate = (xmin == 0) & (xmax == 0)
        xmin[degenerate] = -1
        xmax[degenerate] = +1

        scale = (xmax - xmin) / maxq
        zero = torch.round(-xmin / scale)

        # Shape (-1, 1, ..., 1) so the per-row values broadcast against w.
        per_row = [-1] + [1] * (w.dim() - 1)
        return normal_quantize(
            w, scale.reshape(per_row), zero.reshape(per_row), maxq
        )

    def quantize(self, w, mask, order=2, groupi=0):
        """Quantize ``w`` with the method chosen at construction time.

        Args:
            w: Weight tensor to quantize.
            mask: Boolean selection mask; only read by the ``"braq"`` method.
            order: Number of residual passes; only read by ``"braq"``.
            groupi: Index into ``self.mean`` / ``self.scale``; only read by
                ``"xnor"``, ``"sign"`` and ``"rtn"``.

        Returns:
            The quantized tensor. If ``self.method`` is not one of the
            supported names, ``w`` is returned unchanged.
        """
        method = self.method
        if method == "xnor":
            # Sign of the centred weights, rescaled and shifted back.
            w_mean = self.mean[groupi]
            w = (w - w_mean).sign() * self.scale[groupi]
            w += w_mean
        elif method == "braq":
            w = high_order_residual(w, mask, order=order)
        elif method == "sign":
            # Positive entries become the group scale, the rest become 0.
            w = (w > 0).float()
            w *= self.scale[groupi]
        elif method == "rtn":
            # Negative entries are dropped, the rest are rounded to {0, scale}.
            step = self.scale[groupi]
            w = (F.relu(w) / step).round().clamp(0, 1) * step
        elif method in ("2bit", "3bit", "4bit"):
            # The leading character of the method name is the bit width.
            w = self._minmax_quantize(w, bits=int(method[0]))
        elif method == "prune":
            return torch.zeros_like(w)
        elif method == "ternary":
            # Entries whose magnitude is above 0.7 * mean(|w|) keep their
            # sign; all others become 0.
            threshold = 0.7 * torch.mean(torch.abs(w))
            w = torch.sign(w) * (torch.abs(w) > threshold).float()
        return w
