import numpy as np
import algorithm
import utils
import time
from loader import *

def evaluateSketch(data, sketch, estFuns):
    '''
    Insert a frequency histogram into ``sketch`` and measure estimation error.

    ``data`` is a vector of counts: ``data[i]`` is the true frequency of
    element ``i``. The whole vector is inserted with ``sketch.vectorUpdate``.
    Then every element that actually appeared (non-zero count) is queried
    once, and each function in ``estFuns`` post-processes that raw estimate.

    Args:
        data: sequence of counts indexed by element id.
        sketch: object exposing ``vectorUpdate(data)`` and ``estimate(i)``.
        estFuns: sequence of callables mapping a raw estimate to a final one.

    Returns:
        (uerr, werr): float arrays of length ``len(estFuns)``. For each
        estimator they hold the sum of absolute errors (unweighted) and the
        sum of absolute errors weighted by the true frequency.
    '''
    sketch.vectorUpdate(data)
    uerr = np.zeros(len(estFuns))  # unweighted
    werr = np.zeros(len(estFuns))  # weighted by true frequency
    for i, freq in enumerate(data):
        if freq == 0:
            # Elements that never appeared do not contribute to the error,
            # so there is no need to query the sketch for them.
            continue
        # The raw estimate is the same for every estimator, so query once.
        # NOTE(review): assumes estimate(i) is a side-effect-free query.
        raw = sketch.estimate(i)
        for j, estFun in enumerate(estFuns):
            error = np.abs(estFun(raw) - freq)
            uerr[j] += error
            werr[j] += freq * error
    return uerr, werr

def createAndEvalSketch(data, W, L, estFuns, trials=1):
    '''
    Evaluate ``estFuns`` on ``trials`` freshly constructed Count-Sketches.

    Each trial builds a new ``algorithm.CountSketch(len(data), W, L)`` and
    scores it with ``evaluateSketch``. ``W`` is the sketch width; ``L`` is
    presumably the number of hash rows (TODO confirm against algorithm).

    Returns:
        (uerrs, werrs): arrays of shape ``(len(estFuns), trials)``. Column
        ``t`` holds the unweighted / weighted errors of trial ``t``.
    '''
    shape = (len(estFuns), trials)
    uerrs, werrs = np.zeros(shape), np.zeros(shape)
    size = len(data)
    for t in range(trials):
        u, w = evaluateSketch(data, algorithm.CountSketch(size, W, L), estFuns)
        uerrs[:, t] = u
        werrs[:, t] = w
    return uerrs, werrs

def _makeEstFuns(n, width, threshConsts):
    '''
    Build the list of estimators evaluated for one sketch width.

    Order: identity, clip-at-zero, then one threshold estimator per entry of
    ``threshConsts``. A threshold estimator keeps an estimate only if it is at
    least ``const * n / width``; otherwise it reports 0.

    The threshold is bound as a default argument, so each lambda captures the
    value for *this* width instead of closing over a loop variable.
    '''
    estFuns = [lambda x: x, lambda x: max(x, 0)]
    estFuns += [lambda x, t=const * n / width: x if x >= t else 0
                for const in threshConsts]
    return estFuns


def searchThresh(data, cols, threshConsts, trials=1, preds=None, rows=3):
    '''
    Evaluate Count-Sketch estimators for every sketch width in ``cols``.

    For each width the estimators are the raw estimate, the estimate clipped
    at zero, and one threshold estimator per constant in ``threshConsts``.

    When ``preds`` is given, half of each width's space is reserved for the
    predicted heavy hitters. The ``col // 2`` elements with the largest
    predicted scores get their counts zeroed (treated as perfectly
    predicted). The remaining data go into a sketch of width ``col // 2``.

    Args:
        data: vector of true counts, indexed by element id.
        cols: iterable of space budgets (sketch widths).
        threshConsts: constants for the threshold estimators.
        trials: number of sketches built per width.
        preds: optional per-element predicted scores; presumably the same
            length as ``data`` (not checked here).
        rows: forwarded as ``L`` to ``createAndEvalSketch`` (previously fixed
            at 3); presumably the number of sketch rows.

    Returns:
        Six lists with one entry per width: the unweighted errors of the raw,
        non-negative and threshold estimators, then the weighted errors in
        the same order. Raw/non-negative entries have shape ``(trials,)``;
        threshold entries have shape ``(len(threshConsts), trials)``.

    Raises:
        ValueError: if ``preds`` is given and a width halves to zero.
    '''
    n = len(data)
    scores = None if preds is None else np.asarray(preds).ravel()
    uerrsStd, uerrsNonNeg, uerrsOneTable = [], [], []
    werrsStd, werrsNonNeg, werrsOneTable = [], [], []
    for col in cols:
        if scores is None:
            width = col
            sketchData = data
        else:
            width = col // 2  # use half of the space for predicted heavy hitters
            if width < 1:
                # With k == 0, argpartition(...)[-0:] would select *every*
                # element and silently zero the whole dataset.
                raise ValueError(
                    'width {} leaves no room for the sketch once half is '
                    'reserved for predicted heavy hitters'.format(col))
            topK = np.argpartition(scores, -width)[-width:]  # top-k predicted elems
            sketchData = np.copy(data)
            sketchData[topK] = 0  # perfect prediction on these elements

        estFuns = _makeEstFuns(n, width, threshConsts)
        uerrs, werrs = createAndEvalSketch(sketchData, width, rows, estFuns, trials)

        uerrsStd.append(uerrs[0, :])
        werrsStd.append(werrs[0, :])
        uerrsNonNeg.append(uerrs[1, :])
        werrsNonNeg.append(werrs[1, :])
        uerrsOneTable.append(uerrs[2:, :])
        werrsOneTable.append(werrs[2:, :])

    return (uerrsStd, uerrsNonNeg, uerrsOneTable,
            werrsStd, werrsNonNeg, werrsOneTable)

def writeErrs(errs, folder):
    '''
    Save each of the six error collections as a text file.

    ``folder`` is a path *prefix*: each collection is written to
    ``folder + key + '.txt'`` by plain concatenation. It may therefore be a
    directory ending in a separator, or a directory plus a file-name prefix.

    3-D collections (for the OneTable errors from ``searchThresh`` the shape
    is ``(widths, consts, trials)``) are flattened to 2-D. The rows are grouped
    by constant: all widths for consts[0], then all widths for consts[1], ...

    Args:
        errs: 6-sequence ordered like ``keys`` below.
        folder: output path prefix.

    Raises:
        ValueError: if ``errs`` does not contain exactly six collections.
    '''
    keys = ['uerrsStd', 'uerrsNonNeg', 'uerrsOneTable',
            'werrsStd', 'werrsNonNeg', 'werrsOneTable']
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(errs) != len(keys):
        raise ValueError('expected {} error collections, got {}'.format(
            len(keys), len(errs)))
    for key, err in zip(keys, errs):
        data = np.array(err)
        if data.ndim == 3:
            nWidths, nConsts, nTrials = data.shape
            # Same row order as stacking data[:, j, :] for j = 0, 1, ...
            data = data.transpose(1, 0, 2).reshape(nConsts * nWidths, nTrials)
        np.savetxt(folder + key + '.txt', data)

def runExperiment(data, cols, threshConsts, trials, preds=None, outFolder='logs/'):
    '''
    Run ``searchThresh`` once and write all results under a timestamped prefix.

    The prefix is ``outFolder`` followed by ``str(time.time())``, which is
    also printed. It is followed by ``-widths.txt`` (the ``cols`` used),
    ``-consts.txt`` (the threshold constants) and one ``-<key>.txt`` file per
    error collection written by ``writeErrs``.
    '''
    stamp = str(time.time())
    print(stamp)
    prefix = outFolder + stamp
    results = searchThresh(data, cols, threshConsts, trials=trials, preds=preds)
    np.savetxt(prefix + '-widths.txt', cols)
    np.savetxt(prefix + '-consts.txt', threshConsts)
    writeErrs(results, prefix + '-')

def syntheticLarge():
    '''
    Run the large synthetic benchmark.

    The data come from ``generateZipf(500000, 2)``. The experiment uses five
    widths from 50 to 1000, thresholds {1, 2, 5} and 10 trials.
    '''
    n = 500000
    data = generateZipf(n, 2)
    runExperiment(data,
                  cols=[50, 100, 250, 500, 1000],
                  threshConsts=np.array([1, 2, 5]),
                  trials=10)

def syntheticSmall():
    '''
    Run a quick synthetic smoke test.

    The data come from ``generateZipf(1000)``. The experiment uses widths
    50 and 100, thresholds {1, 2, 5} and 5 trials.
    '''
    n = 1000
    data = generateZipf(n)
    runExperiment(data,
                  cols=np.arange(50, 101, 50),
                  threshConsts=np.array([1, 2, 5]),
                  trials=5)

def realExperiment(dataset='CAIDA', withPredictions=False, days=range(49, 50)):
    '''
    Run the threshold experiment on a real dataset, one day at a time.

    Args:
        dataset: 'CAIDA' (results under logs/ip/) or 'AOL' (logs/aol/).
        withPredictions: if True, pass the loader's per-day predictions to
            the experiment. Otherwise every day runs without predictions.
        days: indices into the loaded per-day data to run (defaults to
            day 49 only, as before).

    Raises:
        ValueError: if ``dataset`` is not recognised. The previous code raised
            via an undefined name ``Error``, which surfaced as a NameError.
    '''
    if dataset == 'CAIDA':
        loadDataset, prefix = loadCAIDA, 'logs/ip/'
    elif dataset == 'AOL':
        loadDataset, prefix = loadAOL, 'logs/aol/'
    else:
        raise ValueError('Unknown dataset: {!r}'.format(dataset))
    data, preds = loadDataset()
    if not withPredictions:
        # Keep one entry per day so preds[i] below stays valid.
        preds = [None] * len(preds)
    cols = [50, 100, 250, 500, 1000]
    threshConsts = np.array([1, 2, 5])
    trials = 10
    for i in days:
        print(i)
        runExperiment(data[i], cols, threshConsts, trials, preds=preds[i],
                      outFolder=prefix + 'day{}-pred{}-'.format(i, withPredictions))

def main():
    '''Run the CAIDA experiment with predicted heavy hitters enabled.'''
    # Swap in syntheticLarge() here to run the synthetic benchmark instead.
    realExperiment('CAIDA', withPredictions=True)


if __name__ == "__main__":
    # Only run when executed as a script, not on import.
    main()
