import numpy as np
import algorithm
import utils
import time
from loader import *

def evaluateSketch(data, sketch, estFuns):
    '''
    Evaluate a sketch on a dataset given as a frequency histogram.

    The whole histogram is fed to the sketch with `sketch.vectorUpdate(data)`.
    Afterwards every element i with a nonzero count is queried once via
    `sketch.estimate(i)`, and each estimator in `estFuns` is applied to that
    raw estimate.

    :param data: sequence of counts; data[i] is the true frequency of element i.
    :param sketch: object providing vectorUpdate(data) and estimate(i).
    :param estFuns: list of callables mapping a raw sketch estimate to a final
        frequency estimate.
    :returns: (uerr, werr), float arrays of length len(estFuns) where
        uerr[j] = sum over i with data[i] != 0 of |estFuns[j](estimate(i)) - data[i]|
        werr[j] = the same sum with each term multiplied by data[i].
    '''
    sketch.vectorUpdate(data)
    uerr = np.zeros(len(estFuns))  # unweighted
    werr = np.zeros(len(estFuns))  # weighted by true frequency
    for i, freq in enumerate(data):
        if freq == 0:
            # Elements that never appeared contribute no error, so skip the query.
            continue
        raw = sketch.estimate(i)  # query once; shared by every estimator
        for j, estFun in enumerate(estFuns):
            error = np.abs(estFun(raw) - freq)
            uerr[j] += error
            werr[j] += freq * error
    return uerr, werr

def createAndEvalMisraGries(data, ctrs, preds):
    '''
    Evaluate Misra-Gries for each counter budget in `ctrs`.

    Without predictions, a MisraGries sketch with k counters is run on a copy
    of `data`. With predictions, half the budget (k // 2) is spent on the
    k // 2 elements with the largest predicted values. Those elements are
    zeroed out of the data, i.e. treated as perfectly predicted, and the rest
    is sketched with k // 2 counters.

    :param data: frequency histogram (data[i] = count of element i).
    :param ctrs: iterable of counter budgets.
    :param preds: per-element predicted scores (flattened before use), or None.
    :returns: (all_uerrs, all_werrs), lists with one entry per budget holding
        the length-1 error arrays produced by evaluateSketch.
    :raises ValueError: if preds is given and a budget is < 2, which would
        leave zero counters (and argpartition(..., -0)[-0:] would select
        every element).
    '''
    flatPreds = None if preds is None else np.asarray(preds).ravel()
    all_uerrs = []
    all_werrs = []
    for budget in ctrs:
        # Copy so the caller's array is never touched by the zeroing below
        # or by the sketch update.
        sketchData = np.copy(data)
        numCounters = budget
        if flatPreds is not None:
            numCounters = budget // 2  # use half of the space for predicted heavy hitters
            if numCounters < 1:
                raise ValueError(
                    'counter budget {} is too small to split with predictions'.format(budget))
            topK = np.argpartition(flatPreds, -numCounters)[-numCounters:]
            sketchData[topK] = 0  # perfect prediction on these elements
        sketch = algorithm.MisraGries(numCounters)
        uerr, werr = evaluateSketch(sketchData, sketch, [lambda x: x])
        all_uerrs.append(uerr)
        all_werrs.append(werr)
    return all_uerrs, all_werrs

def createAndEvalCountSketch(data, W, L, estFuns, trials=1):
    '''
    Run `trials` independent CountSketch evaluations on `data`.

    Each trial builds a fresh algorithm.CountSketch(len(data), W, L) and
    scores it with evaluateSketch.

    :returns: (uerrs, werrs), arrays of shape (len(estFuns), trials); column t
        holds the unweighted / weighted errors of every estimator in trial t.
    '''
    universe = len(data)
    numEst = len(estFuns)
    unweighted = np.zeros((numEst, trials))
    weighted = np.zeros((numEst, trials))
    for t in range(trials):
        fresh = algorithm.CountSketch(universe, W, L)
        trialU, trialW = evaluateSketch(data, fresh, estFuns)
        unweighted[:, t] = trialU
        weighted[:, t] = trialW
    return unweighted, weighted

def searchThresh(data, cols, threshConsts, trials=1, preds=None, depth=3):
    '''
    Evaluate CountSketch estimators over a range of sketch widths.

    For every width in `cols`, a CountSketch with `depth` rows is evaluated
    `trials` times with these estimators:
      * the raw sketch estimate ("Std"),
      * the estimate clipped at 0 ("NonNeg"),
      * for each c in `threshConsts`, the estimate kept only if it is at
        least c * n / width and replaced by 0 otherwise ("OneTable").

    With `preds`, half of each width is reserved for predicted heavy hitters.
    The col // 2 elements with the largest predictions are zeroed out of a
    copy of the data (treated as perfectly predicted), and the remainder is
    sketched with width col // 2. The threshold also uses that width.

    :param depth: number of CountSketch rows (defaults to the previously
        hard-coded 3).
    :returns: six lists with one entry per width: uerrsStd, uerrsNonNeg,
        uerrsOneTable, werrsStd, werrsNonNeg, werrsOneTable. Std/NonNeg
        entries have shape (trials,); OneTable entries have shape
        (len(threshConsts), trials).
    :raises ValueError: if preds is given and a width is < 2 (it would halve
        to 0 columns).
    '''
    n = len(data)
    flatPreds = None if preds is None else np.asarray(preds).ravel()
    uerrsStd, uerrsNonNeg, uerrsOneTable = [], [], []
    werrsStd, werrsNonNeg, werrsOneTable = [], [], []
    for col in cols:
        width = col
        sketchData = data
        if flatPreds is not None:
            width = col // 2  # use half of the space for predicted heavy hitters
            if width < 1:
                raise ValueError(
                    'width {} is too small to split with predictions'.format(col))
            topK = np.argpartition(flatPreds, -width)[-width:]  # top predicted elems
            sketchData = np.copy(data)
            sketchData[topK] = 0  # perfect prediction on these data

        # Bind both the constant and the width as defaults so each lambda
        # keeps its own values instead of closing over loop variables.
        estFuns = [lambda x: x, lambda x: max(x, 0)]
        estFuns += [lambda x, c=const, w=width: x if x >= c * n / w else 0
                    for const in threshConsts]
        uerrs, werrs = createAndEvalCountSketch(sketchData, width, depth, estFuns, trials)

        uerrsStd.append(uerrs[0, :])
        werrsStd.append(werrs[0, :])
        uerrsNonNeg.append(uerrs[1, :])
        werrsNonNeg.append(werrs[1, :])
        uerrsOneTable.append(uerrs[2:, :])
        werrsOneTable.append(werrs[2:, :])

    return uerrsStd, uerrsNonNeg, uerrsOneTable, \
            werrsStd, werrsNonNeg, werrsOneTable

def writeErrs(errs, folder):
    '''
    Save each error series in `errs` to its own text file.

    `errs` is either the 6-tuple returned by searchThresh (CountSketch) or the
    2-tuple returned by createAndEvalMisraGries. Its length selects the file
    names. Entry i is written to folder + key + '.txt', so `folder` acts as a
    filename prefix (callers in this file pass '<outFolder><timestamp>-').

    3-D entries (a per-width list of (numConsts, trials) matrices) are
    flattened to 2-D by stacking the slice for each threshold constant:
    row j * numWidths + i holds data[i, j, :].

    :raises ValueError: if len(errs) matches neither layout.
    '''
    keys_cs = ['uerrsStd', 'uerrsNonNeg', 'uerrsOneTable',
               'werrsStd', 'werrsNonNeg', 'werrsOneTable']
    keys_mg = ['uerrsMG', 'werrsMG']
    if len(errs) == len(keys_cs):
        keys = keys_cs
    elif len(errs) == len(keys_mg):
        keys = keys_mg
    else:
        raise ValueError(
            'expected {} (CountSketch) or {} (Misra-Gries) error series, got {}'
            .format(len(keys_cs), len(keys_mg), len(errs)))
    for key, series in zip(keys, errs):
        data = np.array(series)
        if data.ndim == 3:
            # Same result as vstack((data[:, 0, :], data[:, 1, :], ...)).
            numWidths, numConsts, numTrials = data.shape
            data = data.transpose(1, 0, 2).reshape(numConsts * numWidths, numTrials)
        np.savetxt(folder + key + '.txt', data)

def runCountSketchExperiment(data, cols, threshConsts, trials, preds=None, outFolder='logs/cs-'):
    '''
    Run the CountSketch threshold search and save everything to disk.

    Prints a time.time() timestamp, evaluates via searchThresh, then writes
    <outFolder><timestamp>-widths.txt and <outFolder><timestamp>-consts.txt,
    followed by the error files (via writeErrs) under the prefix
    <outFolder><timestamp>-.
    '''
    stamp = str(time.time())
    print(stamp)
    results = searchThresh(data, cols, threshConsts, trials=trials, preds=preds)
    base = outFolder + stamp
    for suffix, values in (('-widths.txt', cols), ('-consts.txt', threshConsts)):
        np.savetxt(base + suffix, values)
    writeErrs(results, base + '-')

def runMisraGriesExperiment(data, ctrs, preds=None, outFolder='logs/mg-'):
    '''
    Run the Misra-Gries evaluation for every budget in `ctrs` and save it.

    Prints a time.time() timestamp, evaluates via createAndEvalMisraGries,
    then writes <outFolder><timestamp>-counters.txt and the error files
    (via writeErrs) under the prefix <outFolder><timestamp>-.
    '''
    stamp = str(time.time())
    print(stamp)
    results = createAndEvalMisraGries(data, ctrs, preds=preds)
    base = outFolder + stamp
    np.savetxt(base + '-counters.txt', ctrs)
    writeErrs(results, base + '-')

def syntheticLarge():
    '''
    Run the CountSketch and Misra-Gries experiments on a synthetic dataset
    of n = 500000 elements produced by generateZipf(n, 2).
    '''
    n = 500000
    cols = [50, 100, 250, 500, 1000]
    ctrs = [c * 3 for c in cols]
    threshConsts = np.array([1, 2, 5])
    trials = 10
    data = generateZipf(n, 2)
    runCountSketchExperiment(data, cols, threshConsts, trials)
    # runMisraGriesExperiment takes no `trials`; its third positional
    # parameter is `preds`, so `trials` must not be passed here.
    runMisraGriesExperiment(data, ctrs)

def syntheticSmall():
    '''
    Run the CountSketch and Misra-Gries experiments on a small synthetic
    dataset of n = 1000 elements produced by generateZipf(n).
    '''
    n = 1000
    cols = np.arange(50, 101, 50)
    ctrs = [c * 3 for c in cols]
    threshConsts = np.array([1, 2, 5])
    trials = 5
    data = generateZipf(n)
    runCountSketchExperiment(data, cols, threshConsts, trials)
    # runMisraGriesExperiment takes no `trials`; its third positional
    # parameter is `preds`, so `trials` must not be passed here.
    runMisraGriesExperiment(data, ctrs)

def realExperiment(dataset='CAIDA', withPredictions=False):
    '''
    Run the Misra-Gries experiment on each of the first days of a real dataset.

    :param dataset: 'CAIDA' (loadCAIDA, logs under 'logs/ip/', 50 days) or
        'AOL' (loadAOL, logs under 'logs/aol/', 80 days).
    :param withPredictions: if False, the loaded predictions are discarded
        and every day is evaluated with preds=None.
    :raises ValueError: for an unknown dataset name.
    '''
    if dataset == 'CAIDA':
        data, preds = loadCAIDA()
        prefix = 'logs/ip/'
        numDays = 50
    elif dataset == 'AOL':
        data, preds = loadAOL()
        prefix = 'logs/aol/'
        numDays = 80
    else:
        raise ValueError('Unknown dataset: {}'.format(dataset))
    if not withPredictions:
        preds = [None] * len(preds)
    cols = [50, 100, 250, 500, 1000]
    threshConsts = np.array([1, 2, 5])  # used only by the (disabled) CountSketch run
    ctrs = [c * 3 // 2 for c in cols]
    trials = 10  # used only by the (disabled) CountSketch run
    for i in range(numDays):
        print(i)
        # runCountSketchExperiment(data[i], cols, threshConsts, trials, preds=preds[i],
        #                          outFolder=prefix + 'cs-day{}-pred{}-'.format(i, withPredictions))
        runMisraGriesExperiment(data[i], ctrs, preds=preds[i],
                                outFolder=prefix + 'mg-day{}-pred{}-'.format(i, withPredictions))

def main():
    '''Entry point: enable or disable experiments by editing the calls below.'''
    # syntheticSmall()
    for dataset in ('CAIDA', 'AOL'):
        realExperiment(dataset, True)
    # for dataset in ('CAIDA', 'AOL'):
    #     realExperiment(dataset, False)


if __name__ == "__main__":
    main()
