"""
File for the individual nodes of the
UniQuant Algorithm and the ChocoSGD algorithm
"""
import NueralNet as OurNNUtils
import torchvision.models as models
import numpy as np
import torch
import copy
import math
import torch.nn as nn

class MQNode:

    def __init__(self, data_loader=None, test_dl=None, lr=.001, id=-1,
                 nn_type="Full", q_scalar=.9999, optim="sgd", bits=0,
                 dith=0, threshold=0, reg='l1', quantize_loop=None, l=1,
                 alpha=1,eps=10e-5,lr_const=True, device_str='cuda', E=1,tr_decay=1):

        self.lr_const=lr_const
        # can pass a quantization method
        self.quantize_loop = quantize_loop
        self.tr_decay=tr_decay
        self.E=E 
        self.reg = reg
        self.loss_v =0

        self.train_loss =0 

        # Dataloader for the Training phase
        self.data_loader = data_loader

        # Learning rate for sgd
        self.lr = lr

        # id identifying the individual node
        self.id = id

        # underlying nn structure
        self.nn = self.createNeuralNetwork(nn_type)
        self.device=torch.device('cpu')
        if device_str=='mps':
            if not torch.backends.mps.is_available():

                if not torch.backends.mps.is_built():
                    print("MPS not available because the current PyTorch install was not "
                        "built with MPS enabled.")
                else:
                    print("MPS not available because the current MacOS version is not 12.3+ "
                        "and/or you do not have an MPS-enabled device on this machine.")
                mps_device = torch.device("cpu")
                self.device=mps_device
                self.nn.to(mps_device)
            else:
                mps_device = torch.device("mps")
                self.device=mps_device
                self.nn.to(mps_device)
                print("Model moved to MPS")
        elif device_str=='cuda':
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
                self.nn.to(self.device)
                print("Model moved to GPU")
            else:
                print("GPU is not available, using CPU instead")

        # dummy_input = torch.randn(10, 3, 32, 32, device = self.device)

        # # Pass the dummy input through the model to get the output tensor
        # self.nn.eval()  # Set the model to evaluation mode
        # with torch.no_grad():  # Disable gradient computation for inference
        #     output = self.nn(dummy_input)

        # # Inspect the shape of the output tensor
        # output_shape = output.shape
        # total_elements = torch.prod(torch.tensor(output_shape,device=self.device)).item()
        # print("Output shape:", total_elements)

        # data loader for the testing set
        self.test_dl = test_dl

        # scalar for decreseasing the range of the quantizer at
        # each round of sharing
        self.q_scalar = q_scalar

        # Maximum initial magnitude of wieghts
        self.q_range = 0.005

        # optimizer to be used for minimization
        self.optimizer = self.createOptimizer(optim)

        # number of bits allowed for quantization
        self.bits = bits

        #bucket=2**(bits-1)
        # varence for the dithering
        self.dither = 1/(bits)
     

        # threshold for treating weight as 0
        self.threshold = threshold
        self.threshold_og =threshold

        # the previous full weigths
        self.prev_w = self.weightDictTo1D()

        # the previous full difference between this and last round
        self.prev_dif = np.zeros(self.prev_w.shape)

        # dictionrary repersenting the weights of the entire network.
        self.nbd_weights_out = {}

        self.nbd_weights_in = {}


        # repersents the support that this has for each node in the network
        self.nbd_indices_out = {}

        self.nbd_indices_in = {}

        # total bits used
        self.total_bits = 0
        
        # used for merging the weights together, ie returning the 1-dim vector
        # back into seperate layers and vice versa 
        self.total_dim =0
        self.layers = []
        count = 0


        # for learning rate updates
        self.l=l
        self.eps=eps 
        self.alpha =alpha
        self.t =1
        self.updateLr()

    def initZero(self):
        zero=torch.zeros(self.prev_w.shape)
        self.nn.load_state_dict(self.reconstructW(zero))

    def computeQNoScale(self,id):
        """
        Quantizes, through quantizeNoScale, the gap between the current
        weights and what is recorded as sent to neighbour `id`, then adds
        q*b[0]+b[1] to that record.

        Returns:
        q, b: exactly what quantizeNoScale returned
        """
        # first contact with this neighbour: start from an all-zero record
        if id not in self.nbd_weights_out:
            self.initWeightandIdxOut(id)

        residual = self.weightDictTo1D() - self.nbd_weights_out[id]
        q, b = self.quantizeNoScale(residual)

        step = copy.deepcopy(torch.Tensor(q)) * b[0] + b[1]
        self.nbd_weights_out[id] += step
        return q, b

    def computeQ(self,id):
        #if id not in intialize 
        if id not in self.nbd_weights_out:
            self.initWeightandIdxOut(id)

        q,b = self.quantize(self.weightDictTo1D()-self.nbd_weights_out[id])
        self.nbd_weights_out[id] += copy.deepcopy(torch.Tensor(q)*b[0]+b[1]).detach()
        return copy.deepcopy(q).detach(), b
    
    def computeQ_C(self, id):
        if id not in self.nbd_weights_out:
            self.initWeightandIdxOut(id)

        q,b,l = self.quantize(self.weightDictTo1D()-self.nbd_weights_out[id], True)

        self.nbd_weights_out[id] += copy.deepcopy(torch.Tensor(q)*b[0]+b[1]).detach()
        return copy.deepcopy(q), b, l

    def gradThreshold(self):
        """
        Runs one local training pass, soft-thresholds the resulting weights
        (threshold_w also loads them back into self.nn) and advances the
        round counter, the learning rate and the threshold.

        Input: None
        Returns: None
        """
        if self.prev_w is not None:
            self.prev_w = self.weightDictTo1D()

        self.nn, self.train_loss = OurNNUtils.trainLoop(self.data_loader, self.nn,
                                       self.optimizer, id=self.id, reg=self.reg,
                                       device=self.device, E=self.E)

        curr_weights = self.threshold_w(self.weightDictTo1D())

        if self.prev_w is None:
            # keep the fallback on the same device/dtype as the weights
            self.prev_w = torch.zeros_like(curr_weights)

        self.prev_dif = curr_weights-self.prev_w

        # The block below used to appear twice, which advanced t (and hence
        # the learning-rate schedule) by two per call, unlike
        # gradNoThreshold. One round is one step.
        self.t+=1
        self.updateLr()
        self.updateThreshold()
        
    def updateYnoScale(self,id,q,rm):
        if id not in self.nbd_weights_in:
            self.initWeightandIdxIn(id)

        self.nbd_weights_in[id]+=(copy.deepcopy(torch.Tensor(q))) 

    def updateYtest(self,id,q,rm, og):
        if id not in self.nbd_weights_in:
            self.initWeightandIdxIn(id)
        for i in range(len(q)):
            if q[i]!=2:
                q[i]=q[i]*rm[0] +rm[1]
            else:
                q[i]=0
        self.nbd_weights_in[id]+=copy.deepcopy(torch.Tensor(q)).detach() 

    def updateY(self,id,q,rm):
        if id not in self.nbd_weights_in:
            self.initWeightandIdxIn(id)

        q=q*rm[0]+rm[1]
        # for i in range(len(q)):
        #     if q[i]!=2:
        #         q[i]= q[i]*rm[0] +rm[1]
        #     else:
        #         q[i]=0

        self.nbd_weights_in[id]+=copy.deepcopy(torch.Tensor(q)).detach() 

    def aggregateThres(self,id, weight=1):
        avg_weights = (self.weightDictTo1D()+weight*self.nbd_weights_in[id])/2
        avg_weights= self.threshold_w(avg_weights)
        return avg_weights
    
    def aggregateNbdThres(self, nbd):
        avg_weights=self.weightDictTo1D()
        for id in nbd:
            avg_weights+=self.nbd_weights_in[id]
        avg_weights=avg_weights/(len(nbd)+1)
        return self.threshold_w(avg_weights)
    
    def aggregateNbd(self,nbd): 
        scale=1/(len(nbd)+1)
        avg_weights=self.weightDictTo1D()
        for id in nbd:
            avg_weights-= scale*self.nbd_weights_out[id]
            avg_weights+= scale*self.nbd_weights_in[id]

        self.nn.load_state_dict(self.reconstructW(avg_weights))


    def aggregate(self,id, choco=False):
        if not choco:
            avg_weights = (self.weightDictTo1D()+self.nbd_weights_in[id])/2
            self.nn.load_state_dict(self.reconstructW(avg_weights))
        else:
            avg_weights = self.weightDictTo1D()-.5*self.nbd_weights_out[id]
            avg_weights+=.5*self.nbd_weights_in[id]
            self.nn.load_state_dict(self.reconstructW(avg_weights))

        return avg_weights

    def createOptimizer(self, optim, mo=.99):
        """
        Returns the optimizer specified by the string optim
        """
        if optim == "adam":
            return torch.optim.Adam(self.nn.parameters(), lr=self.lr)
        elif optim == "mo":
            return torch.optim.SGD(self.nn.parameters(), lr=self.lr, momentum=.9)
        return torch.optim.SGD(self.nn.parameters(), lr=self.lr)

    def createNeuralNetwork(self, nn_type):
        """
        Returns a freshly constructed network for `nn_type`:
        "light" -> OurNNUtils.LightMnist, "full" -> OurNNUtils.FullMnist,
        "res" -> torchvision resnet18 with a 10-output final layer,
        anything else -> OurNNUtils.CNNMnist.
        """
        # NOTE(review): the match is case-sensitive, so the constructor's
        # default "Full" ends up in the CNNMnist fallback - confirm intended.
        if nn_type == "res":
            net = models.resnet18(pretrained=False)
            net.fc = nn.Linear(net.fc.in_features, 10)
            return net

        if nn_type == "light":
            return OurNNUtils.LightMnist()

        if nn_type == "full":
            return OurNNUtils.FullMnist()

        return OurNNUtils.CNNMnist()

    def test(self):
        """
        Evaluates the current model with OurNNUtils.testLoop on the test
        data loader and remembers the loss it reports in self.loss_v.

        Inputs:
        None
        Returns:
        accuracy: first value returned by testLoop
        """
        result = OurNNUtils.testLoop(self.test_dl, self.nn, queit=True, device=self.device)
        acc, self.loss_v = result
        return acc
    
    def loss(self):
        """
        Returns the loss stored by the most recent call to test()
        (0 before any call).
        """
        return self.loss_v

    def getTrainLoss(self):
        """
        Returns the training loss stored by the most recent training call
        (0 before any call).
        """
        return self.train_loss
    
    def computeGrad(self, z_i):
        """
        Runs one local training pass, treats the resulting weight change as
        a gradient step, applies that step to `z_i` and thresholds the
        result (threshold_w also loads it into self.nn).

        Inputs:
        z_i: weight vector the step is applied to
        Returns: None
        """
        self.prev_w = self.weightDictTo1D()

        # NOTE(review): unlike gradThreshold/gradNoThreshold this call does
        # not forward device= or E= to trainLoop - confirm that is intended.
        self.nn, self.train_loss = OurNNUtils.trainLoop(self.data_loader, self.nn,
                                       self.optimizer, id=self.id, reg=self.reg)

        # weights before training minus weights after training
        grad = self.prev_w - self.weightDictTo1D()

        curr_weights = self.threshold_w(z_i-grad)

        # prev_w was assigned at the top of this method, so the former
        # "prev_w is None" fallback could never run and has been removed.
        self.prev_dif = curr_weights-self.prev_w

    def gradNoThreshold(self):
        """
        Runs one local training pass without thresholding, records the
        weight change in self.prev_dif and advances the round counter, the
        learning rate and the threshold.

        Input: None
        Returns: None
        """
        # snapshot the weights before training (skipped while prev_w is None)
        if self.prev_w is not None:
            self.prev_w = self.weightDictTo1D()

        trained, loss_value = OurNNUtils.trainLoop(self.data_loader, self.nn,
                                       self.optimizer, id=self.id, reg=self.reg, 
                                       device=self.device, E=self.E)
        self.nn, self.train_loss = trained, loss_value

        after = self.weightDictTo1D()

        if self.prev_w is None:
            self.prev_w = torch.zeros(after.shape)

        self.prev_dif = after-self.prev_w

        self.t += 1
        self.updateLr()
        self.updateThreshold()


    def updateLr(self):
        if self.lr_const==False:
            denom=(self.t+self.alpha)**self.eps
            new_lr = 1/(self.l*denom)
    
            for g in self.optimizer.param_groups:
                g['lr'] = new_lr


    def train(self):
        """
        Trains on the local dataset, thresholds the resulting weights
        (threshold_w also loads them into self.nn) and records the change
        since the previous snapshot in self.prev_dif.

        Input: None
        Returns: None

        """
        # snapshot the weights before training (skipped while prev_w is None)
        if self.prev_w is not None:
            self.prev_w = self.weightDictTo1D()

        # NOTE(review): device= and E= are not forwarded here, unlike in
        # gradThreshold/gradNoThreshold - confirm that is intended.
        trained, loss_value = OurNNUtils.trainLoop(self.data_loader, self.nn,
                                       self.optimizer, id=self.id, reg=self.reg)
        self.nn, self.train_loss = trained, loss_value

        after = self.threshold_w(self.weightDictTo1D())

        if self.prev_w is None:
            self.prev_w = torch.zeros(after.shape)

        self.prev_dif = after-self.prev_w

    def weightDictTo1D(self):
        """
        This really should only be called once/twice per iteration 
        and stored as a field.
        """

        weights_to_flatten = []
        count = 0

        for p in self.nn.parameters():
            weights_to_flatten.extend(p.view(-1).tolist())

            count += 1

        return copy.deepcopy(torch.tensor(weights_to_flatten, device=self.device)).detach() #weights_to_flatten[0]


    def quantize(self, w_dif,qsgd=False):
        """
        Generic Method for quantization. It takes in the delta for the weight along
        with the indices.

        Input:
        w_diff: the diffrence between this weight vector and last rounds weight vector
        """

        # potentially herendeously inefficent.
        if self.bits == -1:
            return w_dif, (1,0)
        #-1*self.dither,self.dither,
        dither_vec = (self.dither)*torch.rand(w_dif.shape, device=self.device)
        has_nan = torch.isnan(w_dif).any()
        #if has_nan:
        #    raise Exception("quantize recived nan!") 

        if qsgd==True:
            q, b, l= self.qsgdQuant(w_dif,self.bits,dither_vec)
            return q,b,l
        
        q, range_x, min_x = self.specialQuant(w_dif, self.bits, dither_vec)
        b= (range_x,min_x)

        has_nan = torch.isnan(q).any()
        #if has_nan:
        #    raise Exception("quantize returned a nan!")        

        return q, b
    

    def quantizeNoScale(self, w_dif):
        """
        Adds uniform noise drawn from [-dither, dither) to w_dif and hands
        the result to uniformQuantBitFast together with self.bits.

        Returns:
        q, b: the two values returned by uniformQuantBitFast
        """
        rng = np.random.default_rng()
        noise = rng.uniform(-1*self.dither,self.dither,w_dif.shape)
        dither_vec = torch.from_numpy(noise)

        q, b = self.uniformQuantBitFast(w_dif+dither_vec, self.bits, zizo=False)
        return q, b

    def initWeightandIdxOut(self, i):
        """
        Given id=i, sets the self.nbd_weights[i] equal to the 
        zeros vector and the self.nbd_indices[i] equal to a list 
        that ranges from 0 to the size of the weights vector. i.e. 
        full support. 
        """
        weights = torch.zeros(self.prev_w.shape, device=self.device)
        self.nbd_weights_out[i] = weights
        self.nbd_indices_out[i] = list(range(weights.shape[0]))
        # if i == self.id:
        #     #potentially the wrong one here
        #     self.nbd_weights[i]=self.weightDictTo1D()

    def initWeightandIdxIn(self, i):
        """
        No comment. JK its for the weights and Indicies we 
        recieve from neighbors
        """
        weights = torch.zeros(self.prev_w.shape, device=self.device)
        self.nbd_weights_in[i] = weights
        self.nbd_indices_in[i] = list(range(weights.shape[0]))
 

    def expandWeights(self, weights, indices):
       #expand weights
       #using self only to get the weight dimensions
       full_weights = torch.zeros(len(self.prev_w))
       for i in range(len(indices)):
           full_weights[indices[i]] = weights[i]

       return full_weights
    

    def reconstructW(self, weights):
        """
        Reconstrcts the dictionary based on a provided 1-D
        weight tensor
        """
        param_index = 0
        for param in self.nn.parameters():
            param_shape = param.shape
            param_size = np.prod(param_shape)
            
            # Slice the appropriate number of elements from params_array
            sliced_params = copy.deepcopy(weights[param_index:param_index + param_size])
            
            # Reshape and assign the sliced_params to the parameter
            # param.data = torch.tensor(sliced_params).reshape(param_shape)
            param.data = sliced_params.clone().detach().reshape(param_shape).to(self.device)

            # Update the index for the next parameter
            param_index += param_size


        return self.nn.state_dict()

    def getDifSup(self, id):
        if id  not in self.nbd_weights_out:

            # might need to double check about if id == self.id
 
            self.initWeightandIdxOut(id)

        curr_y = self.weightDictTo1D() #X^(t+1/2)
        nz_delta = []
        nz_indices = []
        local_delta = curr_y-self.nbd_weights_out[id]

        sup_diff = []

        for i in range(curr_y.shape[0]):
            if np.abs(local_delta[i])>= 0: #np.abs(local_delta[i])>= 0
                if i not in self.nbd_indices_out[id]:
                    # potentially remove, since we assume that local_delta 
                    # anf curr_y are equal on the first data exchange.
                    nz_delta.append(curr_y[i])
                    nz_indices.append(i)
                    sup_diff.append(i)
                else:
                    nz_delta.append(local_delta[i])
                    nz_indices.append(i)

        t = torch.tensor(nz_delta, dtype=torch.float64)
        
        return t, nz_indices #, sup_diff

    def updateLocalnbdVals(self, deltanonzero, deltanonzero_i, id):
        delta = self.expandWeights(deltanonzero, deltanonzero_i)
        # if below threshold it is equal
        reconstructed_weights = self.nbd_weights_out[id]+delta
        #potentially remove this....
        # for i in deltanonzero_i:
        #     if i not in self.nbd_indices_out[id]:
        #         reconstructed_weights[i] -= self.nbd_weights_out[id][i]

        self.nbd_indices_out[id] = deltanonzero_i
        self.nbd_weights_out[id] = reconstructed_weights

    def updateOut(self,id ,delta):
        self.nbd_weights_out[id]+=delta

    def getSupport(self):
        w = self.weightDictTo1D()
        return torch.sum(torch.abs(w)>torch.tensor(1e-7, device=self.device)).item()

    def getDif(self, i):
        """
        Get dif of what this sent to i last time
        """
        if i  not in self.nbd_weights_out:

            # might need to double check about if id == self.id
            self.initWeightandIdxOut(i)
        curr_w = self.weightDictTo1D()
        local_delta=curr_w-self.nbd_indices_out[i]
        return torch.tensor(local_delta, dtype=torch.float64)
    
    def updateThreshold(self, const=True):
        """
        Sets self.threshold to the original threshold times self.tr_decay.
        `const` is accepted but not used.
        """
        # NOTE(review): the product always starts from threshold_og, so
        # repeated calls do not compound the decay - confirm intended.
        decayed = self.threshold_og*self.tr_decay
        self.threshold = decayed

    def threshold_w(self, weights, hard=False):
        if (hard):
            for i in range(weights.shape[0]):
                if np.abs(weights[i])<self.threshold:
                    weights[i]=0
        else:
            weights=torch.max(torch.abs(weights)-self.threshold,torch.tensor(0, device=self.device))*torch.sign(weights)
        
        self.nn.load_state_dict(self.reconstructW(weights))
        return weights

    def computeQandX(self, i):
        #delta=self.getDif(i)
        delta, delta_idx = self.getDifSup(i)
        q_delta, bits = self.quantize(delta)
        self.updateLocalnbdVals(q_delta, delta_idx, i)
        #self.updateLocalnbdVals(delta, delta_idx, i)
        #self.updateOut(i, delta)

        if i != self.id:
            self.total_bits += bits

        #return delta, delta_idx, bits, zb_perc
        return q_delta, delta_idx, bits

    def mergeSimple(self,q_delta_j,id_j):
        x_hat_j=q_delta_j+self.nbd_weights_in[id_j]
        local_weights = self.weightDictTo1D()
        avg_weights = (local_weights+x_hat_j)/2
        self.nbd_weights_in[id]=x_hat_j
        self.nn.load_state_dict(self.reconstructW(avg_weights))


    def mergeNoUpdate(self, q, q_idx, id, choco=False):
        delta = self.expandWeights(q, q_idx)

        # choco
        local_weights = self.weightDictTo1D()
        if choco:
            local_weights *= 2
            for i in range(local_weights.shape[0]):
                # this might be wrong since nbd_weights[self.id]
                # should have already been moved to the t+1 time interval.
                # local_weights[i] -= (self.q_local_dif[i] +
                #                      self.nbd_weights[self.id][i])
                
                local_weights[i] -= (self.nbd_weights_out[self.id][i])

        # if below threshold it is equal
        if id not in self.nbd_weights_in:
            self.initWeightandIdxIn(id)

        reconstructed_weights = self.nbd_weights_in[id]+delta
        for i in q_idx:
            if i not in self.nbd_indices_in[id]:
                reconstructed_weights[i] -= self.nbd_weights_in[id][i]

        self.nbd_indices_in[id] = q_idx
        self.nbd_weights_in[id] = reconstructed_weights

        avg_weights = (local_weights+reconstructed_weights)/2
        self.nbd_weights_out[self.id] = avg_weights
        # z_i=self.reconstructW(avg_weights)
        #self.nn.load_state_dict(z_i)
        return avg_weights

    def mergeWeights(self, q, q_idx, id, choco=False):
        delta = self.expandWeights(q, q_idx)

        # choco
        local_weights = self.weightDictTo1D()
        if choco:
            local_weights *= 2
            for i in range(local_weights.shape[0]):
                # this might be wrong since nbd_weights[self.id]
                # should have already been moved to the t+1 time interval.
                # local_weights[i] -= (self.q_local_dif[i] +
                #                      self.nbd_weights[self.id][i])
                
                local_weights[i] -= (self.nbd_weights_out[self.id][i])

        # if below threshold it is equal
        if id not in self.nbd_weights_in:
            self.initWeightandIdxIn(id)

        reconstructed_weights = self.nbd_weights_in[id]+delta
        for i in q_idx:
            if i not in self.nbd_indices_in[id]:
                reconstructed_weights[i] -= self.nbd_weights_in[id][i]

        self.nbd_indices_in[id] = q_idx
        self.nbd_weights_in[id] = reconstructed_weights

        avg_weights = (local_weights+reconstructed_weights)/2
        self.nbd_weights_out[self.id] = avg_weights
        self.nn.load_state_dict(self.reconstructW(avg_weights))
        return reconstructed_weights
    
    def mergeThres(self, q, q_idx, id, choco=False):
        delta = self.expandWeights(q, q_idx)

        # choco
        local_weights = self.weightDictTo1D()
        if choco:
            local_weights *= 2
            for i in range(local_weights.shape[0]):
                # this might be wrong since nbd_weights[self.id]
                # should have already been moved to the t+1 time interval.
                # local_weights[i] -= (self.q_local_dif[i] +
                #                      self.nbd_weights[self.id][i])
                
                local_weights[i] -= (self.nbd_weights_out[self.id][i])

        # if below threshold it is equal
        if id not in self.nbd_weights_in:
            self.initWeightandIdxIn(id)

        reconstructed_weights = self.nbd_weights_in[id]+delta
        for i in q_idx:
            if i not in self.nbd_indices_in[id]:
                reconstructed_weights[i] -= self.nbd_weights_in[id][i]

        self.nbd_indices_in[id] = q_idx
        self.nbd_weights_in[id] = reconstructed_weights
 
        avg_weights = (local_weights+reconstructed_weights)/2
        self.nbd_weights_out[self.id] = avg_weights
        avg_weights= self.threshold_w(avg_weights)
        return avg_weights
    

    def getBins(self, x, bit_num):
        L = 2 + bit_num - 1 
        width = (2*self.q_range)/(2*L - 1) 
        bins_count = {}

        for scalar in x:
            k = round(scalar.item()/width)
  
            if abs(k) > L:
                k_sign = k / abs(k)
                k = k_sign * L 
            assert abs(k) <= L
            if k in bins_count:
                bins_count[k] += 1
            else:
                bins_count[k] = 1
        return bins_count

    
    def specialQuant(self, x, bit_num, dither_vec,eps=10e-7):

        min_x=torch.min(x).item()#math.floor()
        max_x=torch.max(x).item()#math.ceil())
        range_x=max_x-min_x
 
        bucket=bit_num
        x=(x-min_x)/range_x
        x+=dither_vec
        x=torch.floor(x*bucket)/bucket

        return x, range_x,min_x #round(bucket*max_norm)/bucket,  0#round(bucket*min_x)/bucket #max_norm,

    def qsgdQuant(self,x, bit_num, dith_vec, eps=1e-5):
        min_x=torch.min(x).item()#math.floor()
        max_x=torch.max(x).item()#math.ceil())
        range_x=max_x-min_x
        if range_x< eps:
            return torch.zeros_like(x,device=self.device), (0,0), torch.zeros_like(x,device=self.device, dtype=torch.int64) 
        bucket=bit_num
        x=(x-min_x)/range_x
        x+=dith_vec
        levels=torch.floor((x)*bucket)
        x=levels/bucket
  
        return x, (range_x, min_x), levels

    def qsgd_quant(self, x, bit_num, eps=10e-7):

        s = torch.sign(x)
        #print(s)
        norm = torch.max(torch.abs(x))
        x = torch.abs(x) / norm

        q = x.clone()
        xi_l = torch.floor(x.clone()*bit_num)
  
        prob = 1 - (x.clone()*bit_num - xi_l)
        p = torch.rand(size=prob.shape, device=self.device)
        p_bool = p < prob

        q = q.where(~p_bool,(xi_l+1)/bit_num)
        q = q.where(p_bool, (xi_l)/bit_num)

        q = q*s
        q = q*norm

        return q, norm
        #return p_bool, norm
        
    def uniformQuantBit(self, x, bit_num=14, zizo=False):
        """
        Performs the standard uniform quantization on x, where x is a vector like structure
        and where delta is the step size for the quantization levels.
        Note: we don't have a max scalar, so we might want to put an upper and lower bound
        on our uniform quantizer.

        Input:
        x: vector reperseting the data to be quantized
        delta: the step/level size.
        """
        num_bits=0
        new = []
        #bucket=2**(bit_num-1)
        L = 2 + bit_num - 1 #this is dumb, but its just to show the current assumption lb = 2
        width = (2*self.q_range)/(2*L - 1) 

        #zeros = 0
        for scalar in x:
            k = round(scalar.item()/width)
            q = width*k
   
            if abs(k) > L:
                k_sign = k / abs(k)
                k = k_sign * L 
            assert abs(k) <= L 
            if abs(k) > L - 2: #if its L or L-1
                num_bits += bit_num
            else:
                num_bits += abs(k) + 1 #(there's a zero bin here)
                #if abs(k) == 0:
                #    zeros += 1
            new.append(q)

        return new, num_bits
