from collections import namedtuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import InplaceFunction, Function

import math
import numpy as np
import scipy.io as sio

# import pytorch_minimax


# Record type bundling quantization parameters (fields listed below).
QParams = namedtuple(
    'QParams',
    ('range', 'zero_point', 'num_bits', 'max_value'),
)

# Presumably default (start_dim, end_dim) pairs for Tensor.flatten on
# activations and gradients -- not referenced in the visible code; confirm
# against callers.
_DEFAULT_FLATTEN = (1, -1)
_DEFAULT_FLATTEN_GRAD = (0, -1)


class Quantizer:
    """Row-wise min/max quantizer that works in place on a flattened tensor.

    Constructing an instance flattens ``x`` into rows (see ``flatten``),
    computes one ``scale`` / ``zero_point`` per row and quantizes the rows
    (see ``transform``).  ``forward()`` hands back the quantized codes and
    ``inverse()`` de-quantizes codes and restores the original layout.

    NOTE: all arithmetic uses in-place tensor ops, so whenever flattening
    yields a view of ``x`` rather than a copy, the caller's tensor is
    overwritten with the codes.
    """

    def __init__(self, x, num_bits, flatten_dims, signed, stochastic, permute=None, depermute=None, is_zero_point=True):
        """Flatten and quantize ``x``.

        Args:
            x: tensor to quantize.
            num_bits: bit width; the code range spans ``2 ** num_bits`` levels.
            flatten_dims: ``(start_dim, end_dim)`` forwarded to ``Tensor.flatten``.
            signed: if true the codes start at ``-2 ** (num_bits - 1)``, else at 0.
            stochastic: if true, uniform noise between -0.5 and 0.5 is added
                before rounding.
            permute: optional dimension order applied before flattening.
            depermute: optional dimension order applied when restoring the
                layout; ``permute`` is reused when it is omitted.
            is_zero_point: if true use an affine mapping with a per-row zero
                point; otherwise use a symmetric, scale-only mapping.
        """
        self.x_shape = x.shape
        # Overwritten by flatten() when a permutation is applied.
        self.x_shape_p = x.shape
        self.num_bits = num_bits
        self.flatten_dims = flatten_dims
        self.signed = signed
        self.stochastic = stochastic
        self.permute = permute
        self.depermute = depermute
        self.is_zero_point = is_zero_point

        self.x = self.flatten(x)
        self.Tx = self.transform(self.x)

    def flatten(self, x):
        """Return ``x`` (optionally permuted) collapsed into rows.

        Records ``x_shape_p`` (shape after permutation) and ``x_shape2``
        (shape right after ``Tensor.flatten``) for use by ``deflatten``.
        """
        if self.permute is not None:
            x = x.permute(self.permute)
            self.x_shape_p = x.shape
        if self.flatten_dims[0] == len(x.shape):
            # start_dim is one past the last axis: nothing to flatten.
            Tx = x
        else:
            Tx = x.flatten(*self.flatten_dims)
        self.x_shape2 = Tx.shape
        # Merge the leading axes that flatten() left untouched into one.
        if self.flatten_dims[0] == 2:
            Tx = Tx.reshape(x.shape[0] * x.shape[1], -1)
        elif self.flatten_dims[0] == 3:
            Tx = Tx.view(x.shape[0] * x.shape[1] * x.shape[2], -1)
        return Tx.contiguous()

    def deflatten(self, Tx):
        """Undo ``flatten``: restore the recorded shape and dimension order."""
        if self.flatten_dims[0] == 2:
            x = Tx.view(*self.x_shape2)
        else:
            x = Tx
        if self.permute is not None:
            x = x.view(*self.x_shape_p)
            if self.depermute is not None:
                x = x.permute(self.depermute)
            else:
                # Only restores the original order when ``permute`` is its
                # own inverse (e.g. a swap of two axes).
                x = x.permute(self.permute)
        else:
            x = x.view(*self.x_shape)
        return x

    def forward(self):
        """Return the quantized codes computed in the constructor."""
        return self.Tx

    def inverse(self, Tx):
        """De-quantize ``Tx`` (in place) and restore the original layout."""
        x = self.inverse_transform(Tx)
        return self.deflatten(x)

    def transform(self, x):
        """Quantize ``x`` in place and return it.

        Stores ``self.scale``, ``self.zero_point`` and ``self.qmin`` for
        ``inverse_transform``.  For tensors with two or more dimensions the
        statistics are taken along the last dimension; a 1-D tensor uses each
        element as its own min and max.
        """
        with torch.no_grad():
            if len(x.shape) >= 2:
                min_values = torch.min(x, -1).values.unsqueeze(1) - 1e-8
                max_values = torch.max(x, -1).values.unsqueeze(1) + 1e-8
            else:
                # Clone: ``x`` is overwritten in place below, and the zero
                # point must keep the original values for inverse_transform.
                min_values = x.clone()
                max_values = x.clone()

        if not self.is_zero_point:
            # Symmetric range centred on zero with the same width.
            max_values = (max_values - min_values) / 2
            min_values = - max_values

        qmin = -(2. ** (self.num_bits - 1)) if self.signed else 0.
        qmax = qmin + 2. ** self.num_bits - 1.
        self.qmin = qmin

        self.zero_point = min_values
        scale = (max_values - min_values) / (qmax - qmin)

        # Floor the scale to avoid dividing by zero on constant rows.  The
        # floor is created on the same device as ``scale`` so that CPU (and
        # non-default GPU) inputs work.
        min_scale = torch.tensor(1e-8, device=scale.device).expand_as(scale)
        self.scale = torch.max(scale, min_scale)

        with torch.no_grad():
            if self.is_zero_point:
                x.add_(qmin * self.scale - self.zero_point).div_(self.scale)
            else:
                x.div_(self.scale)

            if self.stochastic:
                noise = x.new(x.shape).uniform_(-0.5, 0.5)
                x.add_(noise)

        return x.clamp_(qmin, qmax).round_()

    def inverse_transform(self, x):
        """Map codes back to real values in place using the stored parameters."""
        if self.is_zero_point:
            x.mul_(self.scale).add_(
                self.zero_point - self.qmin * self.scale)
        else:
            x.mul_(self.scale)
        return x

class DivisionQuantizer(Quantizer):
    """Row quantizer whose per-row ranges are replaced by shared group ranges.

    Rows are flattened with ``flatten_dims=(1, -1)`` after applying
    ``permute``.  Each row's ``[min, max]`` is then replaced by the global
    range scaled by a power of two chosen from the row's maximum (see
    ``get_transform``).
    """

    def __init__(self, x, num_bits, signed, stochastic, groups, permute):
        """Quantize ``x``.

        Args:
            x: tensor to quantize.
            num_bits: bit width of the codes.
            signed: if true the codes start at ``-2 ** (num_bits - 1)``, else at 0.
            stochastic: if true, uniform noise is added before rounding.
            groups: number of power-of-two range groups.
            permute: dimension order applied before flattening (required;
                it is applied unconditionally).
        """
        # Must be set before the base constructor, which calls transform().
        self.groups = groups
        super(DivisionQuantizer, self).__init__(x, num_bits, (1, -1), signed, stochastic, permute)

    def flatten(self, x):
        """Permute ``x`` and collapse it into rows, recording the shapes."""
        x = x.permute(self.permute)
        self.x_shape_p = x.shape
        Tx = x.flatten(*self.flatten_dims)
        self.x_shape2 = Tx.shape
        if self.flatten_dims[0] == 2:
            Tx = Tx.view(x.shape[0] * x.shape[1], -1)
        return Tx.contiguous()

    def deflatten(self, Tx):
        """Restore the permuted shape and apply ``permute`` again.

        NOTE: re-applying ``permute`` restores the original order only when
        the permutation is its own inverse.
        """
        if self.flatten_dims[0] == 2:
            x = Tx.view(*self.x_shape2)
        else:
            x = Tx
        x = x.view(*self.x_shape_p)
        x = x.permute(self.permute)
        return x

    def transform(self, x):
        """Quantize ``x`` in place with group-shared ranges and return it.

        Stores ``self.scale``, ``self.zero_point`` and ``self.qmin`` for the
        inherited ``inverse_transform``.
        """
        with torch.no_grad():
            min_values = torch.min(x, -1).values.unsqueeze(1) - 1e-8
            max_values = torch.max(x, -1).values.unsqueeze(1) + 1e-8
            min_values, max_values = self.get_transform(min_values, max_values)

        qmin = -(2. ** (self.num_bits - 1)) if self.signed else 0.
        qmax = qmin + 2. ** self.num_bits - 1.
        self.qmin = qmin

        self.zero_point = min_values
        scale = (max_values - min_values) / (qmax - qmin)

        # Floor the scale to avoid division by zero; build the floor on the
        # device of ``scale`` instead of forcing it onto the current GPU.
        min_scale = torch.tensor(1e-8, device=scale.device).expand_as(scale)
        self.scale = torch.max(scale, min_scale)

        with torch.no_grad():
            x.add_(qmin * self.scale - self.zero_point).div_(self.scale)
            if self.stochastic:
                noise = x.new(x.shape).uniform_(-0.5, 0.5)
                x.add_(noise)

        return x.clamp_(qmin, qmax).round_()

    def get_transform(self, min_values, max_values):
        """Replace per-row ranges by power-of-two fractions of the global range.

        With ``mn``/``mx`` the global min/max, rows whose maximum lies in
        ``[mx * 2**-(i+1), mx * 2**-i]`` get the range
        ``[mn * 2**-i, mx * 2**-i]``.  The last group extends down to the
        smallest row maximum.  Rows matching no group keep their own range;
        on a shared boundary the later group wins.

        Returns:
            ``(min_values, max_values)`` with the same shape as the inputs.
        """
        mn_divisionquant = min_values
        mx_divisionquant = max_values

        mn = min_values.min()
        mx = max_values.max()
        ones = torch.ones_like(max_values)

        for i in range(self.groups):
            if i == self.groups - 1:
                lower_bound = max_values.min()
            else:
                lower_bound = mx * 2 ** -(i + 1)
            up_bound = mx * 2 ** -i
            condition = (max_values >= lower_bound) & (max_values <= up_bound)
            mn_divisionquant = torch.where(condition, mn * 2 ** -i * ones, mn_divisionquant)
            mx_divisionquant = torch.where(condition, mx * 2 ** -i * ones, mx_divisionquant)

        return mn_divisionquant, mx_divisionquant

    def get_transform_rank(self, min_values, max_values, threshold=0.05):
        """Rank-based variant of ``get_transform`` (not called by ``transform``).

        Groups are formed from the descending order of ``max_values``: group
        ``i`` takes a band from the top of the ranking and, when the band is
        smaller than ``n // groups``, a band from the bottom as well.  Rows in
        group ``i`` get the range ``[mn * 2**-i, mx * 2**-i]``.

        NOTE(review): ``torch.sort`` sorts along the last dimension by
        default.  If callers pass the ``(N, 1)`` column tensors that
        ``transform`` builds, the sort would not reorder rows -- confirm the
        expected input shape.
        """
        mn_divisionquant = min_values
        mx_divisionquant = max_values

        mn = min_values.min()
        mx = max_values.max()

        rank, _ = torch.sort(max_values, descending=True)
        boundary = mx * threshold
        # Number of entries whose maximum is below ``threshold`` * global max.
        m = (rank < boundary).sum().item()
        n = len(max_values)
        n_ = n
        m_ = m
        num_elements_group = n // self.groups
        up_bound = 0
        low_bound = mx
        lup_bound = max_values.min()
        llow_bound = 0
        n_up = 0
        n_low = 0
        ones = torch.ones_like(max_values)

        for i in range(self.groups):
            # divide: size of the top band and of the complementary bottom band
            n_i = math.ceil((n_ - m_) / (self.groups - i))
            n_extra_i = num_elements_group - n_i
            n_ = n_ - m_
            m_ = n_i
            n_up = n_up + n_i
            n_low = n_low + n_extra_i
            # decide the boundaries from the ranking
            up_bound = low_bound
            low_bound = rank[n_up - 1]
            llow_bound = lup_bound
            lup_bound = rank[- n_low - 1]
            # calculate scales; the first band includes its lower edge
            if i == 0:
                condition_0 = (max_values >= low_bound) & (max_values <= up_bound)
            else:
                condition_0 = (max_values > low_bound) & (max_values <= up_bound)
            condition_1 = (max_values > llow_bound) & (max_values <= lup_bound)
            condition = condition_0 | condition_1
            mn_divisionquant = torch.where(condition, mn * 2 ** -i * ones, mn_divisionquant)
            mx_divisionquant = torch.where(condition, mx * 2 ** -i * ones, mx_divisionquant)

        return mn_divisionquant, mx_divisionquant


class PersampleDivisionQuantizer(Quantizer):
    """Quantizer that shares ranges among groups of dim-1 slices per dim-0 row.

    The input is flattened with ``flatten_dims=(2, -1)``, so each row of the
    flattened tensor is one ``(dim0, dim1)`` slice.  The code names dim 1
    "channel"; slices are grouped by their channel-wise maximum and every
    group shares one range per dim-0 row (see ``get_transform``).
    """

    def __init__(self, x, num_bits, signed, stochastic, groups):
        """Quantize ``x``.

        Args:
            x: tensor to quantize (indexed on its first two dimensions).
            num_bits: bit width of the codes.
            signed: if true the codes start at ``-2 ** (num_bits - 1)``, else at 0.
            stochastic: if true, uniform noise is added before rounding.
            groups: number of channel groups.
        """
        # Must be set before the base constructor, which calls transform().
        self.groups = groups
        super(PersampleDivisionQuantizer, self).__init__(x, num_bits, (2, -1), signed, stochastic)

    def transform(self, x):
        """Quantize ``x`` in place with group-shared ranges and return it.

        Stores ``self.scale``, ``self.zero_point`` and ``self.qmin`` for the
        inherited ``inverse_transform``.
        """
        with torch.no_grad():
            # Per-slice statistics arranged as (dim0, dim1).
            min_values = torch.min(x, -1).values.reshape(self.x_shape[0], self.x_shape[1]).contiguous()
            max_values = torch.max(x, -1).values.reshape(self.x_shape[0], self.x_shape[1]).contiguous()
            # Per-channel statistics across dim 0.
            min_values_channel = torch.min(min_values.permute(1, 0).contiguous(), -1).values
            max_values_channel = torch.max(max_values.permute(1, 0).contiguous(), -1).values
            min_values, max_values = self.get_transform(min_values_channel, max_values_channel, min_values, max_values)

        qmin = -(2. ** (self.num_bits - 1)) if self.signed else 0.
        qmax = qmin + 2. ** self.num_bits - 1.
        self.qmin = qmin

        self.zero_point = min_values
        scale = (max_values - min_values) / (qmax - qmin)

        # Floor the scale to avoid division by zero; build the floor on the
        # device of ``scale`` instead of forcing it onto the current GPU.
        min_scale = torch.tensor(1e-8, device=scale.device).expand_as(scale)
        self.scale = torch.max(scale, min_scale)

        with torch.no_grad():
            x.add_(qmin * self.scale - self.zero_point).div_(self.scale)
            if self.stochastic:
                noise = x.new(x.shape).uniform_(-0.5, 0.5)
                x.add_(noise)

        return x.clamp_(qmin, qmax).round_()

    def get_transform(self, min_values_channel, max_values_channel, min_values, max_values):
        """Share one range per dim-0 row among channels of the same group.

        Channel ``c`` belongs to group ``i`` when its maximum lies in
        ``[mx * 2**-(i+1), mx * 2**-i]`` (``mx`` is the largest channel
        maximum; the last group extends down to the smallest channel
        maximum).  Within a group every entry of a row receives the row-wise
        min / max of that group's buffer, in which non-member entries are 0.

        NOTE(review): adjacent groups share a boundary value and the group
        buffers are summed at the end, so a channel whose maximum equals a
        boundary exactly contributes to two groups -- confirm this is
        intended.

        Returns:
            ``(mn, mx)`` column tensors of shape ``(rows * cols, 1)``.
        """
        rows, cols = min_values.shape[0], min_values.shape[1]
        mn_groups = torch.zeros([self.groups, rows, cols]).to(device=min_values.device)
        mx_groups = torch.zeros([self.groups, rows, cols]).to(device=min_values.device)

        mx = max_values_channel.max()
        # Channel maxima broadcast to every row (loop-invariant).
        channel_max = max_values_channel[None, :].repeat([rows, 1])

        for i in range(self.groups):
            if i == self.groups - 1:
                lower_bound = max_values_channel.min()
            else:
                lower_bound = mx * 2 ** -(i + 1)
            up_bound = mx * 2 ** -i
            condition = (channel_max >= lower_bound) & (channel_max <= up_bound)
            mn_groups[i] = torch.where(condition, min_values, mn_groups[i])
            mx_groups[i] = torch.where(condition, max_values, mx_groups[i])
            # torch.min/max with a dim return (values, indices); take .values.
            mn_groups[i] = torch.where(condition, torch.min(mn_groups[i], -1).values.unsqueeze(1).repeat([1, cols]), mn_groups[i])
            mx_groups[i] = torch.where(condition, torch.max(mx_groups[i], -1).values.unsqueeze(1).repeat([1, cols]), mx_groups[i])

        mn = torch.sum(mn_groups, 0).view(-1).unsqueeze(1)
        mx = torch.sum(mx_groups, 0).view(-1).unsqueeze(1)

        return mn, mx

    def get_transform_rank(self, min_values_channel, max_values_channel, min_values, max_values, threshold=0.0):
        """Rank-based variant of ``get_transform`` (not called by ``transform``).

        Groups are formed from the descending order of the channel maxima:
        group ``i`` takes a band from the top of the ranking and, when that
        band is smaller than ``n // groups``, a band from the bottom as well.
        Range sharing inside a group works as in ``get_transform``.

        Returns:
            ``(mn, mx)`` column tensors of shape ``(rows * cols, 1)``.
        """
        rows, cols = min_values.shape[0], min_values.shape[1]
        mn_groups = torch.zeros([self.groups, rows, cols]).to(device=min_values.device)
        mx_groups = torch.zeros([self.groups, rows, cols]).to(device=min_values.device)

        mx = max_values_channel.max()

        rank, _ = torch.sort(max_values_channel, descending=True)
        boundary = mx * threshold
        # Number of channels whose maximum is below ``threshold`` * global max.
        m = (rank < boundary).sum().item()
        n = len(max_values_channel)
        n_ = n
        m_ = m
        num_elements_group = n // self.groups
        up_bound = 0
        low_bound = mx
        lup_bound = max_values_channel.min()
        llow_bound = 0
        n_up = 0
        n_low = 0
        # Channel maxima broadcast to every row (loop-invariant).
        channel_max = max_values_channel[None, :].repeat([rows, 1])

        for i in range(self.groups):
            # divide: size of the top band and of the complementary bottom band
            n_i = math.ceil((n_ - m_) / (self.groups - i))
            n_extra_i = num_elements_group - n_i
            n_ = n_ - m_
            m_ = n_i
            n_up = n_up + n_i
            n_low = n_low + n_extra_i
            # decide the boundaries from the ranking
            up_bound = low_bound
            low_bound = rank[n_up - 1]
            llow_bound = lup_bound
            lup_bound = rank[- n_low - 1]
            # calculate scales; the first band includes its lower edge
            if i == 0:
                condition_0 = (channel_max >= low_bound) & (channel_max <= up_bound)
            else:
                condition_0 = (channel_max > low_bound) & (channel_max <= up_bound)
            condition_1 = (channel_max > llow_bound) & (channel_max <= lup_bound)
            condition = condition_0 | condition_1
            mn_groups[i] = torch.where(condition, min_values, mn_groups[i])
            mx_groups[i] = torch.where(condition, max_values, mx_groups[i])
            # torch.min/max with a dim return (values, indices); take .values.
            mn_groups[i] = torch.where(condition, torch.min(mn_groups[i], -1).values.unsqueeze(1).repeat([1, cols]), mn_groups[i])
            mx_groups[i] = torch.where(condition, torch.max(mx_groups[i], -1).values.unsqueeze(1).repeat([1, cols]), mx_groups[i])

        mn = torch.sum(mn_groups, 0).view(-1).unsqueeze(1)
        mx = torch.sum(mx_groups, 0).view(-1).unsqueeze(1)

        return mn, mx


class DivisionShiftQuantizer2D(Quantizer):
    """Symmetric (scale-only) quantizer with power-of-two range snapping.

    The input is flattened with ``flatten_dims=(2, -1)``, so each row of the
    flattened tensor is one ``(dim0, dim1)`` slice of the (optionally
    permuted) input.  Each slice's span ``max - min`` is snapped to the
    largest span of its dim-0 row divided by a power of two (see
    ``get_transform_channel``).  No zero point is applied: values are only
    divided by the scale, and ``inverse_transform`` only multiplies.
    """

    def __init__(self, x, num_bits, signed, stochastic, groups_s, groups_c, permute=None):
        """Quantize ``x``.

        Args:
            x: tensor to quantize.
            num_bits: bit width of the codes.
            signed: if true the codes start at ``-2 ** (num_bits - 1)``, else at 0.
            stochastic: if true, uniform noise is added before rounding.
            groups_s: stored but not used by the visible methods of this class.
            groups_c: number of power-of-two span groups along dim 1.
            permute: optional dimension order applied before flattening.
        """
        # Must be set before the base constructor, which calls transform().
        self.groups_s = groups_s
        self.groups_c = groups_c
        super(DivisionShiftQuantizer2D, self).__init__(x, num_bits, (2, -1), signed, stochastic, permute)

    def transform(self, x):
        """Quantize ``x`` in place (scale only) and return it.

        Stores ``self.scale``, ``self.zero_point`` and ``self.qmin``; the zero
        point is recorded but not used by this class.
        """
        with torch.no_grad():
            # Per-slice statistics arranged as (dim0, dim1) of the permuted input.
            min_values = torch.min(x, -1).values.reshape(self.x_shape_p[0], self.x_shape_p[1]).contiguous()
            max_values = torch.max(x, -1).values.reshape(self.x_shape_p[0], self.x_shape_p[1]).contiguous()
            min_values, max_values = self.get_transform_channel(min_values, max_values)
            min_values, max_values = min_values.view(-1).unsqueeze(1), max_values.view(-1).unsqueeze(1)

        qmin = -(2. ** (self.num_bits - 1)) if self.signed else 0.
        qmax = qmin + 2. ** self.num_bits - 1.
        self.qmin = qmin

        self.zero_point = min_values
        scale = (max_values - min_values) / (qmax - qmin)

        # Floor the scale to avoid division by zero; build the floor on the
        # device of ``scale`` instead of forcing it onto the current GPU.
        min_scale = torch.tensor(1e-8, device=scale.device).expand_as(scale)
        self.scale = torch.max(scale, min_scale)

        with torch.no_grad():
            # Scale only: no zero-point shift is applied.
            x.div_(self.scale)
            if self.stochastic:
                noise = x.new(x.shape).uniform_(-0.5, 0.5)
                x.add_(noise)

        return x.clamp_(qmin, qmax).round_()

    def inverse_transform(self, x):
        """Map codes back to real values in place by multiplying by the scale."""
        x.mul_(self.scale)
        return x

    def get_transform_channel(self, min_values, max_values):
        """Snap each slice's span to a power-of-two fraction of its row's largest span.

        With ``span = max_values - min_values`` and ``top`` the largest span
        in each row (along dim 1):

        * group 0: ``span >= top / 2`` is replaced by ``top``;
        * the last group: ``span < top / 2**i`` is replaced by ``top / 2**i``;
        * groups in between: ``top / 2**(i+1) <= span < top / 2**i`` is
          replaced by ``top / 2**i``.

        Spans matching no group keep their own value.

        Returns:
            ``(-span_new / 2, span_new / 2)``: bounds symmetric around zero.
        """
        scale = max_values - min_values
        scale_sample = torch.max(scale, 1).values

        scale_new = scale

        for i in range(self.groups_c):
            up_state = (scale >= scale_sample[:, None] / 2 ** (i + 1))
            low_state = (scale < scale_sample[:, None] / 2 ** i)
            scale_bound = scale_sample[:, None] / 2 ** i
            if i == 0:
                state = up_state
            elif i == self.groups_c - 1:
                state = low_state
            else:
                state = up_state & low_state

            scale_new = torch.where(state, scale_bound, scale_new)

        temp_max = scale_new / 2
        temp_min = - temp_max

        return temp_min, temp_max



class DivisionQuantizer2D(Quantizer):
    """Affine quantizer that shares ranges among groups of dim-1 slices.

    The input is flattened with ``flatten_dims=(2, -1)``, so each row of the
    flattened tensor is one ``(dim0, dim1)`` slice of the (optionally
    permuted) input.  The code names dim 0 "sample" and dim 1 "channel".
    ``transform`` uses ``get_transform_channel``; the other ``get_transform_*``
    methods are alternative grouping strategies that it does not call.
    """

    def __init__(self, x, num_bits, signed, stochastic, groups_s, groups_c, permute=None):
        """Quantize ``x``.

        Args:
            x: tensor to quantize.
            num_bits: bit width of the codes.
            signed: if true the codes start at ``-2 ** (num_bits - 1)``, else at 0.
            stochastic: if true, uniform noise is added before rounding.
            groups_s: number of groups used by ``get_transform_sample``.
            groups_c: number of groups used by the channel-based strategies.
            permute: optional dimension order applied before flattening.
        """
        # Must be set before the base constructor, which calls transform().
        self.groups_s = groups_s
        self.groups_c = groups_c
        super(DivisionQuantizer2D, self).__init__(x, num_bits, (2, -1), signed, stochastic, permute)

    def transform(self, x):
        """Quantize ``x`` in place with group-shared ranges and return it.

        Stores ``self.scale``, ``self.zero_point`` and ``self.qmin`` for the
        inherited ``inverse_transform``.
        """
        with torch.no_grad():
            # Per-slice statistics arranged as (dim0, dim1) of the permuted input.
            min_values = torch.min(x, -1).values.reshape(self.x_shape_p[0], self.x_shape_p[1]).contiguous()
            max_values = torch.max(x, -1).values.reshape(self.x_shape_p[0], self.x_shape_p[1]).contiguous()
            # Per-channel statistics across dim 0.
            min_values_channel = torch.min(min_values.permute(1, 0).contiguous(), -1).values
            max_values_channel = torch.max(max_values.permute(1, 0).contiguous(), -1).values
            # channel division
            min_values, max_values = self.get_transform_channel(min_values_channel, max_values_channel, min_values, max_values)
            min_values, max_values = min_values.view(-1).unsqueeze(1), max_values.view(-1).unsqueeze(1)

        qmin = -(2. ** (self.num_bits - 1)) if self.signed else 0.
        qmax = qmin + 2. ** self.num_bits - 1.
        self.qmin = qmin

        self.zero_point = min_values
        scale = (max_values - min_values) / (qmax - qmin)

        # Floor the scale to avoid division by zero; build the floor on the
        # device of ``scale`` instead of forcing it onto the current GPU.
        min_scale = torch.tensor(1e-8, device=scale.device).expand_as(scale)
        self.scale = torch.max(scale, min_scale)

        with torch.no_grad():
            x.add_(qmin * self.scale - self.zero_point).div_(self.scale)
            if self.stochastic:
                noise = x.new(x.shape).uniform_(-0.5, 0.5)
                x.add_(noise)

        return x.clamp_(qmin, qmax).round_()

    def get_transform_channel(self, min_values_channel, max_values_channel, min_values, max_values):
        """Share one range per row among channels in the same power-of-two band.

        Channel ``c`` belongs to group ``i`` when its maximum lies in
        ``[mx * 2**-(i+1), mx * 2**-i]`` (``mx`` is the largest channel
        maximum; the last group extends down to the smallest channel
        maximum).  Within a group every entry of a row receives the row-wise
        min / max of that group's buffer, in which non-member entries are 0.
        The group buffers are summed to form the result.

        Returns:
            ``(mn, mx)`` with the same 2-D shape as ``min_values``.
        """
        rows, cols = min_values.shape[0], min_values.shape[1]
        mn_groups = torch.zeros([self.groups_c, rows, cols]).to(device=min_values.device)
        mx_groups = torch.zeros([self.groups_c, rows, cols]).to(device=min_values.device)

        mx = max_values_channel.max()
        # Channel maxima broadcast to every row (loop-invariant).
        channel_max = max_values_channel[None, :].repeat([rows, 1])

        for i in range(self.groups_c):
            if i == self.groups_c - 1:
                lower_bound = max_values_channel.min()
            else:
                lower_bound = mx * 2 ** -(i + 1)
            up_bound = mx * 2 ** -i
            condition = (channel_max >= lower_bound) & (channel_max <= up_bound)
            mn_groups[i] = torch.where(condition, min_values, mn_groups[i])
            mx_groups[i] = torch.where(condition, max_values, mx_groups[i])
            mn_groups[i] = torch.where(condition, torch.min(mn_groups[i], -1).values.unsqueeze(1).repeat([1, cols]), mn_groups[i])
            mx_groups[i] = torch.where(condition, torch.max(mx_groups[i], -1).values.unsqueeze(1).repeat([1, cols]), mx_groups[i])

        mn = torch.sum(mn_groups, 0)
        mx = torch.sum(mx_groups, 0)

        return mn, mx

    def get_transform_sample(self, min_values_sample, max_values_sample, min_values, max_values):
        """Same banding as ``get_transform_channel`` but grouping along dim 0.

        ``min_values`` / ``max_values`` are transposed so that the grouped
        axis becomes the last one, processed, and transposed back.

        Returns:
            ``(mn, mx)`` with the same 2-D shape as the inputs.
        """
        min_values = min_values.permute(1, 0).contiguous()
        max_values = max_values.permute(1, 0).contiguous()
        rows, cols = min_values.shape[0], min_values.shape[1]
        mn_groups = torch.zeros([self.groups_s, rows, cols]).to(device=min_values.device)
        mx_groups = torch.zeros([self.groups_s, rows, cols]).to(device=min_values.device)

        mx = max_values_sample.max()
        # Sample maxima broadcast to every (transposed) row (loop-invariant).
        sample_max = max_values_sample[None, :].repeat([rows, 1])

        for i in range(self.groups_s):
            if i == self.groups_s - 1:
                lower_bound = max_values_sample.min()
            else:
                lower_bound = mx * 2 ** -(i + 1)
            up_bound = mx * 2 ** -i
            condition = (sample_max >= lower_bound) & (sample_max <= up_bound)
            mn_groups[i] = torch.where(condition, min_values, mn_groups[i])
            mx_groups[i] = torch.where(condition, max_values, mx_groups[i])
            # Row-wise reductions via torch, matching get_transform_channel
            # (the pytorch_minimax import is disabled in this file).
            mn_groups[i] = torch.where(condition, torch.min(mn_groups[i], -1).values.unsqueeze(1).repeat([1, cols]), mn_groups[i])
            mx_groups[i] = torch.where(condition, torch.max(mx_groups[i], -1).values.unsqueeze(1).repeat([1, cols]), mx_groups[i])

        mn = torch.sum(mn_groups, 0)
        mx = torch.sum(mx_groups, 0)

        return mn.permute(1, 0).contiguous(), mx.permute(1, 0).contiguous()

    def get_transform_equal_channel(self, min_values_channel, max_values_channel, min_values, max_values):
        """Variant of ``get_transform_channel`` with equal-width bands.

        Group ``i`` covers channel maxima in
        ``[mx * (g - i - 1) / g, mx * (g - i) / g]`` with ``g = groups_c``;
        the last group extends down to the smallest channel maximum.

        Returns:
            ``(mn, mx)`` with the same 2-D shape as ``min_values``.
        """
        rows, cols = min_values.shape[0], min_values.shape[1]
        mn_groups = torch.zeros([self.groups_c, rows, cols]).to(device=min_values.device)
        mx_groups = torch.zeros([self.groups_c, rows, cols]).to(device=min_values.device)

        mx = max_values_channel.max()
        # Channel maxima broadcast to every row (loop-invariant).
        channel_max = max_values_channel[None, :].repeat([rows, 1])

        for i in range(self.groups_c):
            if i == self.groups_c - 1:
                lower_bound = max_values_channel.min()
            else:
                lower_bound = mx / self.groups_c * (self.groups_c - i - 1)
            up_bound = mx / self.groups_c * (self.groups_c - i)
            condition = (channel_max >= lower_bound) & (channel_max <= up_bound)
            mn_groups[i] = torch.where(condition, min_values, mn_groups[i])
            mx_groups[i] = torch.where(condition, max_values, mx_groups[i])
            # Row-wise reductions via torch, matching get_transform_channel
            # (the pytorch_minimax import is disabled in this file).
            mn_groups[i] = torch.where(condition, torch.min(mn_groups[i], -1).values.unsqueeze(1).repeat([1, cols]), mn_groups[i])
            mx_groups[i] = torch.where(condition, torch.max(mx_groups[i], -1).values.unsqueeze(1).repeat([1, cols]), mx_groups[i])

        mn = torch.sum(mn_groups, 0)
        mx = torch.sum(mx_groups, 0)

        return mn, mx

    def get_transform_rank_channel(self, min_values_channel, max_values_channel, min_values, max_values, threshold=0.0):
        """Rank-based variant of ``get_transform_channel``.

        Groups are formed from the descending order of the channel maxima:
        group ``i`` takes a band from the top of the ranking and, when that
        band is smaller than ``n // groups_c``, a band from the bottom as
        well.  Range sharing inside a group works as in
        ``get_transform_channel``.

        Returns:
            ``(mn, mx)`` column tensors of shape ``(rows * cols, 1)``.
        """
        rows, cols = min_values.shape[0], min_values.shape[1]
        mn_groups = torch.zeros([self.groups_c, rows, cols]).to(device=min_values.device)
        mx_groups = torch.zeros([self.groups_c, rows, cols]).to(device=min_values.device)

        mx = max_values_channel.max()

        rank, _ = torch.sort(max_values_channel, descending=True)
        boundary = mx * threshold
        # Number of channels whose maximum is below ``threshold`` * global max.
        m = (rank < boundary).sum().item()
        n = len(max_values_channel)
        n_ = n
        m_ = m
        num_elements_group = n // self.groups_c
        up_bound = 0
        low_bound = mx
        lup_bound = max_values_channel.min()
        llow_bound = 0
        n_up = 0
        n_low = 0
        # Channel maxima broadcast to every row (loop-invariant).
        channel_max = max_values_channel[None, :].repeat([rows, 1])

        for i in range(self.groups_c):
            # divide: size of the top band and of the complementary bottom band
            n_i = math.ceil((n_ - m_) / (self.groups_c - i))
            n_extra_i = num_elements_group - n_i
            n_ = n_ - m_
            m_ = n_i
            n_up = n_up + n_i
            n_low = n_low + n_extra_i
            # decide the boundaries from the ranking
            up_bound = low_bound
            low_bound = rank[n_up - 1]
            llow_bound = lup_bound
            lup_bound = rank[- n_low - 1]
            # calculate scales; the first band includes its lower edge
            if i == 0:
                condition_0 = (channel_max >= low_bound) & (channel_max <= up_bound)
            else:
                condition_0 = (channel_max > low_bound) & (channel_max <= up_bound)
            condition_1 = (channel_max > llow_bound) & (channel_max <= lup_bound)
            condition = condition_0 | condition_1
            mn_groups[i] = torch.where(condition, min_values, mn_groups[i])
            mx_groups[i] = torch.where(condition, max_values, mx_groups[i])
            # Row-wise reductions via torch, matching get_transform_channel
            # (the pytorch_minimax import is disabled in this file).
            mn_groups[i] = torch.where(condition, torch.min(mn_groups[i], -1).values.unsqueeze(1).repeat([1, cols]), mn_groups[i])
            mx_groups[i] = torch.where(condition, torch.max(mx_groups[i], -1).values.unsqueeze(1).repeat([1, cols]), mx_groups[i])

        mn = torch.sum(mn_groups, 0).view(-1).unsqueeze(1)
        mx = torch.sum(mx_groups, 0).view(-1).unsqueeze(1)

        return mn, mx