import EliasMap
import torch
import numpy as np
import math


####################     ELIAS DICTIONARY CREATION     ####################
def eliasBit(x):
    """Return the number of bits used to Elias-encode the positive value ``x``.

    The count starts at one bit and then repeatedly adds the width of the
    binary representation of the current value, continuing with that width
    minus one until the value drops to 1. This is the length recursion of
    the Elias omega code.

    NOTE: a second ``eliasBit`` defined further down in this file rebinds
    the name at import time; that later definition has no positivity check.

    Args:
        x: value to encode; must be greater than zero.

    Returns:
        The code length in bits as a Python int.

    Raises:
        AssertionError: if ``x`` is not greater than zero.
    """
    assert x > 0
    total = 1  # the single bit every codeword carries, even for x == 1
    remaining = x
    while remaining > 1:
        # Width of the binary representation of the current value.
        group = math.floor(math.log2(remaining)) + 1
        total += group
        # The next group encodes the width of this one, minus one.
        remaining = group - 1
    return total

def getWeightSize(string):
    """Look up the hard-coded size associated with a network type name.

    Args:
        string: one of ``"res"``, ``"full"``, ``"cnn"`` or ``"light"``.

    Returns:
        The integer size registered for that name. For any other name a
        message is printed and the string ``"error"`` is returned instead
        of raising, so callers must check the result before using it as a
        number.
    """
    # Presumably the number of trainable weights per architecture -- the
    # values are taken as-is from the original table; verify against models.
    sizes = {
        "res": 11400000,
        "full": 670000,
        "cnn": 697200,
        "light": 8000,
    }
    if string in sizes:
        return sizes[string]
    print(string+ " is not a valid nn type. Must be: res/full/cnn/light.")
    return "error"

def return_dict(string):
    """Build a lookup table of Elias code lengths for a network type.

    Args:
        string: network type name, forwarded to ``getWeightSize``.

    Returns:
        A dict mapping every integer ``i`` in ``1..getWeightSize(string)``
        (inclusive) to ``eliasBit(i)``.
    """
    upper = getWeightSize(string)
    return {value: eliasBit(value) for value in range(1, upper + 1)}


####################           BIT COUNTING             ####################
def XOR(vec, device):
    """XOR every element of ``vec`` with 1 after casting it to int64.

    For a 0/1 vector this is the complement, i.e. ``(vec + 1) mod 2``.

    Args:
        vec: tensor to flip.
        device: device on which the all-ones operand is created.

    Returns:
        An int64 tensor with the same shape as ``vec``.
    """
    as_int64 = vec.long()
    ones = torch.ones(vec.shape, dtype=torch.int64, device=device)
    return torch.bitwise_xor(as_int64, ones)

def getType(data, Q):
    """Count how often each level occurs in ``data``.

    Args:
        data: tensor of non-negative integer labels, as required by
            ``bincount``.
        Q: collection of levels; only its length is used, as the minimum
            number of bins in the result.

    Returns:
        A tensor whose entry ``k`` is the number of occurrences of ``k`` in
        ``data``, with at least ``len(Q)`` entries.
    """
    return data.bincount(minlength=len(Q))

def countRunLengths(supp, device):
    """
    Given a list of support vectors supp, returns the run-lengths within
    each support vector.

    For every support vector the positions equal to 1 are located; the
    first run-length is the position of the first 1 (the number of entries
    before it) and every following run-length is the number of entries
    strictly between two consecutive 1s.

    Args:
        supp: iterable of 1-D tensors.
        device: kept for interface compatibility. Results are produced on
            the device of each support vector, so no helper tensor has to
            be allocated on ``device``.

    Returns:
        A list with one int64 tensor of run-lengths per support vector. A
        support vector that contains no 1 yields an empty tensor rather
        than raising an IndexError.
    """
    rl_lvl = []
    for s in supp:
        ones_at = torch.nonzero(s == 1).flatten()
        # Start from the absolute positions: entry 0 is already the length
        # of the leading run, the rest are overwritten with the gaps.
        gaps = ones_at.clone()
        # Slice arithmetic is a no-op for zero or one position, which is
        # what makes the empty support vector safe.
        gaps[1:] = ones_at[1:] - ones_at[:-1] - 1
        rl_lvl.append(gaps)
    return rl_lvl

def getSupport(data, L, t, device):
    """
    Given data, returns the support vectors for each of the L levels.

    Levels are processed in increasing order. For level ``q`` an indicator
    vector is built over the positions not yet claimed by an earlier level
    (1 where the label equals ``q``, 0 elsewhere). If ``t[q]`` exceeds half
    of the remaining positions (rounded up) the indicator is complemented
    via ``XOR``. The positions holding ``q`` are then removed from the
    remaining set.

    Args:
        data: 1-D tensor of integer labels.
        L: number of levels; levels ``0..L-1`` are visited.
        t: per-level occurrence counts, indexed by level (``calcBitsCum``
            passes the output of ``getType``). Entries must be tensors,
            since ``.item()`` is called on a value derived from them.
        device: device for the index tensor and for ``XOR``.

    Returns:
        A pair ``(sups, Ms)``. ``sups`` holds the indicator vectors that
        contain at least one nonzero entry. ``Ms`` holds one integer
        parameter (at least 1) for every level with ``t[q] > 0``. The two
        lists are filled under different conditions, so they do not
        necessarily have the same length.

    Raises:
        AssertionError: if a computed parameter is negative.
    """
    n_total = len(data)
    remaining = torch.arange(0, n_total, dtype=torch.int32, device=device)
    supports = []
    golomb_params = []
    consumed = 0  # running sum of t[q] over the levels handled so far
    for level in range(L):
        # Labels at the positions still unassigned.
        current = torch.index_select(data, 0, remaining)
        is_level = current == level
        indicator = current.clone()
        indicator[is_level] = 1
        indicator[~is_level] = 0

        # Translate hits from "index into remaining" to "index into data".
        local_hits = torch.flatten(torch.nonzero(indicator == 1))
        global_hits = torch.index_select(remaining, 0, local_hits)

        # Must look at the size of `remaining` before it is shrunk below.
        if t[level] > math.ceil(remaining.shape[0] / 2):
            indicator = XOR(indicator, device)

        # Set difference: indices that appear exactly once in the
        # concatenation are those not claimed by this level.
        merged = torch.cat((remaining, global_hits))
        values, counts = merged.unique(return_counts=True)
        remaining = values[counts == 1]

        if torch.count_nonzero(indicator) > 0:
            supports.append(indicator)

        consumed += t[level]
        if t[level] > 0:
            param = ((1 - (consumed / n_total)) * math.log(2)) / (t[level] / n_total)
            assert param >= 0
            golomb_params.append(max(1, round(param.item())))

    return supports, golomb_params

def golombBit(x, M, device):
    """Count the bits needed to Golomb-code the values in ``x``.

    Each value is split into a quotient and a remainder with respect to
    ``M``. The quotient costs ``quotient + 1`` bits. The remainder costs
    ``b`` bits when it is below ``2**(b + 1) - M`` and ``b + 1`` bits
    otherwise, where ``b = floor(log2(M))``.

    Args:
        x: 1-D tensor of non-negative values (run-lengths in this file).
        M: positive Golomb parameter.
        device: kept for interface compatibility. The work buffer is
            allocated on ``x``'s device so that the boolean mask and the
            buffer it indexes always live on the same device.

    Returns:
        A 0-dim floating-point tensor holding the total number of bits.
    """
    b = math.floor(math.log2(M))
    quotient = torch.floor(x / M)
    remainder = x - quotient * M
    unary_bits = torch.floor(1 + quotient)

    # Remainders below this threshold get the short (b-bit) encoding.
    short = remainder < 2 ** (b + 1) - M

    # Default to the long encoding, then overwrite the short ones. The
    # buffer follows x.device; a CPU buffer cannot be indexed by a mask
    # that lives on another device.
    remainder_bits = torch.full((x.shape[0],), float(b + 1), device=x.device)
    remainder_bits[short] = b

    return torch.sum(unary_bits) + torch.sum(remainder_bits)

def eliasBit(x):
    """Count the number of bits when using Elias coding for ``x``.

    Starts from one bit and, while the current value is greater than 1,
    adds the width of its binary representation and continues with that
    width minus one.

    NOTE: this redefines the ``eliasBit`` declared near the top of the
    file and, unlike that one, does not check that ``x`` is positive: any
    ``x <= 1`` (including 0 and negative values) returns 1.

    Args:
        x: value to encode.

    Returns:
        The code length in bits as a Python int.
    """
    total_bits = 1
    while x > 1:
        width = math.floor(math.log2(x)) + 1
        # Account for this group, then move on to encoding its width.
        total_bits, x = total_bits + width, width - 1
    return total_bits

def getLabels(data, L):
    """Return integer labels for ``data``.

    Args:
        data: tensor of values to label.
        L: scale factor applied to every value before conversion.

    Returns:
        ``data * L`` converted to a 32-bit integer tensor.
    """
    return torch.mul(data, L).to(torch.int32)

def calcBitsCum(q, L, device='cuda', elias=None):
    """
    Counts the number of bits required to encode labels q.

    The total is the Golomb cost of the run-lengths of every support
    vector, plus the Elias cost of the per-level counts, plus a constant
    64.

    Args:
        q: 1-D tensor of integer labels.
        L: number of levels.
        device: device handed to the helper routines and used for the
            64-bit constant.
        elias: lookup table indexed by a tensor of integers and giving the
            code length of each value. Despite the ``None`` default it
            must be supplied, because it is indexed unconditionally.

    Returns:
        A one-element tensor with the total bit count.
    """
    counts = getType(q, list(range(L)))
    supports, golomb_params = getSupport(q, L, counts, device=device)
    run_lengths = countRunLengths(supports, device=device)

    bits = 0
    for idx, runs in enumerate(run_lengths):
        bits += golombBit(runs, golomb_params[idx], device=device)

    # Cost of transmitting how often each level occurs.
    bits += torch.sum(elias[torch.abs(counts)])

    header = torch.tensor([64], device=device)
    return bits + header

def calcBitsQSGD(q, L, device='cuda', elias=None):
    """
    Counts the number of bits required to encode labels q using QSGD
    encoding.

    The total is a constant 32, plus one bit per nonzero entry, plus the
    Elias cost of the gaps between nonzero entries, plus the Elias cost of
    the absolute value of every nonzero entry.

    Args:
        q: 1-D tensor of labels.
        L: unused; kept so the signature matches ``calcBitsCum``.
        device: device on which the returned tensor is created. This now
            also applies to the all-zero early return, which previously
            always produced a CPU tensor.
        elias: lookup table indexed by a tensor of integers and giving the
            code length of each value. It is only consulted when ``q`` has
            at least one nonzero entry, and must be supplied in that case.

    Returns:
        A one-element tensor with the total bit count.
    """
    header = torch.tensor([32], device=device)

    nz_pos = torch.nonzero(q).flatten()
    n_nonzero = nz_pos.shape[0]
    if n_nonzero == 0:
        return header

    # Gap before the first nonzero is its position; later gaps are the
    # number of zeros strictly between consecutive nonzero entries.
    gaps = nz_pos.clone()
    gaps[1:] = nz_pos[1:] - nz_pos[:-1] - 1
    position_bits = torch.sum(elias[torch.abs(gaps)])

    magnitudes = torch.abs(q[nz_pos])
    value_bits = torch.sum(elias[magnitudes.long()])

    return header + (n_nonzero + position_bits + value_bits)