import argparse
import numpy as np
from functools import partial
from scipy.stats import levy_stable, iqr, binom
from scipy.special import gamma
from tqdm import tqdm
import pandas as pd

def parse_args(argv=None):
    """Parse the experiment's command-line options.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with the options declared below.

    Invalid values for the enumerated options (``--stats``, ``--dist``,
    ``--query_method``) are rejected here by argparse (SystemExit) instead of
    failing later inside the experiment.
    """
    parser = argparse.ArgumentParser(description="parse parameters")
    parser.add_argument('--stats', default="moments",
                        choices=['moments', 'log_norm', 'entropy'],
                        help="statistic to estimate")
    parser.add_argument('--domain_size', type=int, default=1000,
                        help="number of distinct keys")
    parser.add_argument('--value_bound', type=int, default=1,
                        help="synthetic values are drawn uniformly from [1, value_bound]")
    # TODO: change the size of the sketch
    parser.add_argument('--sketch_size', type=int, default=50,
                        help="number of stable projections (sketch rows)")
    parser.add_argument('--dist', default='binomial',
                        choices=['uniform', 'binomial', 'app'],
                        help="key distribution ('app' replays the recorded event log)")
    parser.add_argument('--iters', type=int, default=100,
                        help="number of independent trials")
    parser.add_argument('--stream_size', type=int, default=1000,
                        help="number of stream items per trial")
    parser.add_argument('--alpha', type=float, default=1.5,
                        help="order of the frequency moment")
    parser.add_argument('--query_method', default='geometric_mean',
                        choices=['iqr', 'median', 'geometric_mean', 'harmonic_mean'],
                        help="estimator used to read the sketch")
    # NOTE(review): nothing in this file passes args.sampling_rate to F_sketch;
    # the sketch falls back to its own default rate.
    parser.add_argument('--sampling_rate', type=float, default=0.1,
                        help="per-item sampling probability (currently not passed to the sketch)")
    return parser.parse_args(argv)

def frequency_moment(alpha, x):
    """Return the alpha-th frequency moment, i.e. the sum of x ** alpha."""
    powered = np.power(x, alpha)
    return powered.sum()

def entropy(x):
    """Return sum(x * log(x + 1e-7)) over the histogram x.

    The 1e-7 offset keeps log() finite for zero counts; those entries then
    contribute 0 to the sum.
    """
    logs = np.log(x+1e-7)
    return np.multiply(logs, x).sum()

def log_norm(x):
    """Return the sum of natural logs of x (-inf if any entry is 0)."""
    logs = np.log(x)
    return logs.sum()

def read_data(path="../data/app_events.csv"):
    """Load the app-event log and return each row's event_id as an integer code.

    The ``event_id`` column is converted to a pandas categorical and the
    categorical codes are returned, so every distinct event id maps to one
    integer in ``[0, n_distinct)``. pandas infers the categories, sorted when
    the ids are sortable.

    Args:
        path: CSV file containing an ``event_id`` column. The default is the
            location the script has always used.

    Returns:
        list[int]: one code per CSV row, in file order.

    NOTE(review): missing event ids get code -1 from pandas; confirm the
    data has none, because callers use the codes as array indices.
    """
    # Only event_id is used, so skip parsing the remaining columns.
    df = pd.read_csv(path, usecols=['event_id'])
    codes = df['event_id'].astype('category').cat.codes
    return codes.tolist()

class KV_stream:
    """A replayable stream of (key, value) pairs plus its exact statistics.

    Keys depend on ``dist``:
      * ``'uniform'``  -- uniform integers in ``[0, domain_size)``;
      * ``'binomial'`` -- Binomial(domain_size - 1, 0.5) draws;
      * ``'app'``      -- the event codes returned by ``read_data()``.
    Synthetic streams draw values uniformly from ``[1, value_bound]``; the
    ``'app'`` stream uses value 1 for every item.

    Call ``prepare(size)`` before ``sample()``, ``refresh()`` or ``exact_*()``.
    """

    def __init__(self, domain_size, value_bound, dist='uniform'):
        self.domain_size = domain_size
        self.value_bound = value_bound
        self.dist = dist
        if dist == 'uniform':
            self.p_k = partial(np.random.randint, low=0, high=domain_size)
        elif dist == 'binomial':
            self.p_k = partial(binom.rvs, n=domain_size-1, p=0.5)
        elif dist == 'app':
            self.p_k = None  # keys come from read_data() in prepare()
        else:
            raise NotImplementedError
        if dist == 'app':
            self.p_v = lambda x: 1
        else:
            self.p_v = partial(np.random.randint, low=1, high=value_bound+1)
        self.prepared = None  # list of (key, value) pairs once prepare() has run
        self.size = None      # stream length; set by prepare()
        self.idx = -1         # index of the pair most recently returned by sample()
        # Optional precomputed answers: when set, the matching exact_* returns them.
        self.fm = None
        self.log_norm = None
        self.entropy = None

    def prepare(self, size):
        """Materialize the stream.

        Synthetic distributions draw ``size`` fresh pairs on every call. The
        ``'app'`` stream is loaded once, ignores ``size``, and records its
        actual length in ``self.size``.
        """
        if self.dist == 'app':
            if self.prepared is not None:
                return
            data = read_data()
            self.prepared = list(zip(data, np.ones_like(data)))
            self.size = len(self.prepared)
        else:
            self.size = size
            self.prepared = list(zip(self.p_k(size=size), self.p_v(size=size)))

    def sample(self):
        """Return the next (key, value) pair; IndexError past the end of the stream."""
        if self.prepared is None:
            raise NotImplementedError
        else:
            self.idx += 1
            return self.prepared[self.idx]

    def refresh(self):
        """Regenerate the stream (synthetic dists only) and rewind to its start.

        Raises:
            NotImplementedError: if ``prepare()`` has not been called yet, the
                same error ``sample()`` uses for an unprepared stream.
        """
        if self.size is None:
            raise NotImplementedError
        self.prepare(self.size)
        self.idx = -1

    def _histogram(self):
        """Return per-key value totals (length ``domain_size``) of the prepared stream."""
        if self.prepared is None:
            raise NotImplementedError
        hist = np.zeros(self.domain_size)
        if self.prepared:
            pairs = np.asarray(self.prepared)
            # np.add.at accumulates repeated keys; plain fancy-index += would not.
            np.add.at(hist, pairs[:, 0].astype(np.intp), pairs[:, 1])
        return hist

    def exact_fm(self, alpha=1.5):
        """Return the exact alpha-th frequency moment of the prepared stream."""
        if self.fm is not None:
            return self.fm
        return frequency_moment(alpha, self._histogram())

    def exact_entropy(self):
        """Return ``entropy()`` of the prepared stream's key histogram."""
        if self.entropy is not None:
            return self.entropy
        return entropy(self._histogram())

    def exact_log_norm(self):
        """Return ``log_norm()`` of the prepared stream's key histogram.

        The histogram covers every key in ``[0, domain_size)``, so this is -inf
        whenever some key never occurs in the stream.
        """
        if self.log_norm is not None:
            return self.log_norm
        return log_norm(self._histogram())

def median(alpha, x):
    """Return the median of |x| ** alpha.

    No stable-distribution normalizing constant is applied here.
    """
    powered = np.power(np.abs(x), alpha)
    return np.median(powered)

def interquantile(alpha, x):
    """Return the interquartile range (scipy.stats.iqr) of |x| ** alpha."""
    powered = np.power(np.abs(x), alpha)
    return iqr(powered)

# TODO: finish the estimators meant for smaller alpha: geometric_mean is implemented below, harmonic_mean is still a stub.
def geometric_mean(alpha, sketch_size, x):
    """Estimate sum|f|^alpha from ``sketch_size`` stable projections ``x``.

    With k = sketch_size, returns

        prod_j |x_j|^(alpha/k) / [ (2/pi) Gamma(alpha/k) Gamma(1 - 1/k) sin(pi*alpha/(2k)) ]^k

    which has the form of Li's geometric-mean estimator for stable random
    projections (confirm against the paper before relying on the constant).
    """
    exponent = alpha/sketch_size
    numerator = np.prod(np.power(np.abs(x), exponent))
    per_row_constant = 2*gamma(exponent)*gamma(1-1/sketch_size)*np.sin(np.pi*alpha/2/sketch_size)/np.pi
    return numerator/np.power(per_row_constant, sketch_size)

def harmonic_mean(alpha, sketch_size, x):
    """Placeholder for a harmonic-mean estimator; not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError

class F_sketch:
    """Linear sketch of a (key, value) stream built from symmetric stable projections.

    For ``alpha != -1`` the sketch is one vector ``sketch`` of length
    ``sketch_size``. Inserting ``(k, v)`` adds ``v * pm[:, k]``, where ``pm`` is a
    ``(sketch_size, domain_size)`` matrix drawn from ``levy_stable(alpha, 0)``.
    For ``alpha == -1`` (entropy mode) it keeps two such sketches, with
    stability indices 1.05 and 0.95.

    Every insert is subsampled: with probability ``1 - sampling_rate`` the
    value is replaced by 0. A single draw is shared by all rows. The
    geometric-mean queries divide by ``sampling_rate ** alpha`` to compensate.

    Args:
        domain_size: number of keys; keys index columns of the projection matrix.
        sketch_size: number of projections (rows).
        alpha: stability index, or -1 for entropy mode.
        sampling_rate: per-insert keep probability in (0, 1]. ``None`` picks
            ``1 / sketch_size``, or ``1 / (0.15 * sketch_size)`` in entropy
            mode, capped at 1.

    Raises:
        ValueError: if ``sampling_rate`` lies outside (0, 1].
    """

    def __init__(self, domain_size, sketch_size, alpha=-1, sampling_rate=None):
        self.alpha = alpha
        self.domain_size = domain_size
        self.sketch_size = sketch_size
        # Resolve the rate up front so query() works even before any insert().
        if sampling_rate is None:
            sampling_rate = self._default_sampling_rate()
        if not 0 < sampling_rate <= 1:
            raise ValueError("sampling_rate must be in (0, 1], got %r" % (sampling_rate,))
        self.sampling_rate = sampling_rate
        if self.alpha == -1:
            self.rv_upper = levy_stable(1.05, 0)
            self.rv_lower = levy_stable(0.95, 0)
        else:
            self.rv = levy_stable(alpha, 0)
        # Draw the projection matrices, zero the sketch(es) and reset idx.
        self.refresh()

    def _default_sampling_rate(self):
        """Return the sampling rate used when none is given, capped at 1."""
        if self.alpha == -1:
            rate = 1/0.15/self.sketch_size
        else:
            rate = 1./self.sketch_size
        # Small sketches would otherwise give a keep probability above 1,
        # which np.random.choice rejects.
        return min(1.0, rate)

    def insert(self, k, v):
        """Add item ``(k, v)`` to the sketch, subject to subsampling.

        Args:
            k: key, used as a column index into the projection matrix.
            v: value; multiplied by 0 when the item is sampled out.
        """
        if self.sampling_rate != 1:
            v = np.random.choice(2, p=[1-self.sampling_rate, self.sampling_rate]) * v
        if self.alpha == -1:
            self.sketch_upper += np.multiply(self.pm_upper[:, k], v)
            self.sketch_lower += np.multiply(self.pm_lower[:, k], v)
        else:
            self.sketch += np.multiply(self.pm[:, k], v)
        self.idx = self.idx + 1  # number of inserts since the last refresh()

    def query(self, alpha, method='median'):
        """Read an estimate from the sketch.

        Args:
            alpha: moment order; -1 selects the entropy estimate.
            method: 'iqr', 'median', 'geometric_mean' or 'harmonic_mean'.
                Ignored when ``alpha == -1``.

        Returns:
            For ``alpha == -1``: ``10 * (F_1.05 - F_0.95)``, a central finite
            difference ``(F_1.05 - F_0.95) / (2 * 0.05)`` approximating the
            derivative of F_p at p = 1, i.e. sum f*log f. Otherwise the chosen
            estimator applied to ``self.sketch``.

        Raises:
            ValueError: for an unknown ``method``.
        """
        if alpha == -1:
            upper = geometric_mean(1.05, self.sketch_size, self.sketch_upper) / np.power(self.sampling_rate, 1.05)
            lower = geometric_mean(0.95, self.sketch_size, self.sketch_lower) / np.power(self.sampling_rate, 0.95)
            return 10 * (upper - lower)
        # NOTE(review): unlike the geometric-mean branches, 'iqr' and 'median'
        # apply neither the 1 / sampling_rate**alpha correction nor a
        # stable-distribution normalizing constant -- confirm intent.
        if method == 'iqr':
            return interquantile(alpha, self.sketch)
        if method == 'median':
            return median(alpha, self.sketch)
        if method == 'geometric_mean':
            return geometric_mean(alpha, self.sketch_size, self.sketch) / np.power(self.sampling_rate, alpha)
        if method == 'harmonic_mean':
            # TODO: assert self.alpha < 1
            return harmonic_mean(alpha, self.sketch_size, self.sketch) / np.power(self.sampling_rate, alpha)
        raise ValueError("unknown query method: %r" % (method,))

    def refresh(self):
        """Redraw the projection matrices and reset the sketch(es) to zero."""
        self.idx = 0
        if self.alpha == -1:
            self.pm_upper = self.rv_upper.rvs(size=[self.sketch_size, self.domain_size])
            self.pm_lower = self.rv_lower.rvs(size=[self.sketch_size, self.domain_size])
            self.sketch_upper = np.zeros(self.sketch_size)
            self.sketch_lower = np.zeros(self.sketch_size)
        else:
            self.pm = self.rv.rvs(size=[self.sketch_size, self.domain_size])
            self.sketch = np.zeros(self.sketch_size)

def estimate_fm(stream, args):
    """Compare exact and sketched frequency moments over ``args.iters`` trials.

    Each trial regenerates the stream, records its exact F_alpha, feeds
    ``args.stream_size`` items into a freshly redrawn sketch, and queries it
    with ``args.query_method``.

    Returns:
        (exact_values, estimates): two lists of length ``args.iters``.
    """
    sketch = F_sketch(args.domain_size, args.sketch_size, args.alpha)

    exact_values, estimates = [], []
    for _trial in range(args.iters):
        stream.refresh()
        exact_values.append(stream.exact_fm(args.alpha))
        sketch.refresh()
        for _item in range(args.stream_size):
            sketch.insert(*stream.sample())
        estimates.append(sketch.query(args.alpha, method=args.query_method))

    return exact_values, estimates

def estimate_entropy(stream, args):
    """Compare exact and sketched entropy over ``args.iters`` trials.

    Each trial regenerates the stream, records ``stream.exact_entropy()``,
    feeds ``args.stream_size`` items into a freshly redrawn sketch, and reads
    the sketch's entropy estimate via ``query(-1)``.

    Returns:
        (exact_values, estimates): two lists of length ``args.iters``.
    """
    sketch = F_sketch(args.domain_size, args.sketch_size, args.alpha)

    exact_values, estimates = [], []
    for _trial in range(args.iters):
        stream.refresh()
        exact_values.append(stream.exact_entropy())
        sketch.refresh()
        for _item in range(args.stream_size):
            sketch.insert(*stream.sample())
        estimates.append(sketch.query(-1))

    return exact_values, estimates

def mul_error(res, acc):
    """Return the elementwise relative error |acc / res - 1|.

    Args:
        res: reference (exact) value(s).
        acc: estimated value(s), same shape as ``res``.
    """
    ratio = np.divide(acc, res)
    return np.abs(ratio - 1)

def _estimate_log_norm(stream, args):
    """Compare exact and sketched log-norms (sum of log counts) over ``args.iters`` trials.

    Each trial regenerates the stream, records ``stream.exact_log_norm()``,
    feeds ``args.stream_size`` items into a freshly redrawn F_alpha sketch, and
    converts the queried moment F via ``domain_size * log(F / domain_size) / alpha``.

    Returns:
        (exact_values, estimates): two lists of length ``args.iters``.
    """
    sketch = F_sketch(args.domain_size, args.sketch_size, args.alpha)

    stats = []
    results = []
    for _trial in range(args.iters):
        stream.refresh()
        stats.append(stream.exact_log_norm())
        sketch.refresh()
        for _item in range(args.stream_size):
            k, v = stream.sample()
            sketch.insert(k, v)
        fm = sketch.query(args.alpha, method=args.query_method)
        results.append(args.domain_size * np.log(fm / args.domain_size) / args.alpha)

    return stats, results

def evaluate(args):
    """Run the experiment selected by ``args.stats`` and summarize its errors.

    Builds a KV_stream from ``args``, runs ``args.iters`` trials of the chosen
    statistic, and computes the per-trial relative error |estimate / exact - 1|.

    Args:
        args: parsed options (see ``parse_args``). The caller's object is not
            modified; a shallow copy is used where ``alpha`` is overridden.

    Returns:
        (median, iqr) of the per-trial relative errors.

    Raises:
        ValueError: if ``args.stats`` is not 'moments', 'log_norm' or 'entropy'.
    """
    import copy
    args = copy.copy(args)  # the branches below override args.alpha

    stream = KV_stream(args.domain_size, args.value_bound, args.dist)
    stream.prepare(args.stream_size)

    # TODO: fix the random seed so the data stream is reproducible.
    # TODO: gradually increase the stream size and show the space/error and
    #       privacy/error trade-offs.
    if args.stats == 'moments':
        stats, results = estimate_fm(stream, args)

    elif args.stats == 'log_norm':
        # NOTE(review): n * log(F_alpha / n) / alpha approximates sum(log f)
        # only as alpha -> 0, and the exact log-norm is -inf when any key has
        # a zero count -- confirm the choice of alpha = 1.05.
        args.alpha = 1.05
        stats, results = _estimate_log_norm(stream, args)

    # TODO: the entropy estimate is not correct yet.
    elif args.stats == 'entropy':
        args.alpha = -1
        stats, results = estimate_entropy(stream, args)
        print(stats, results)

    else:
        raise ValueError("unknown stats: %r" % (args.stats,))

    errors = mul_error(stats, results)
    median_error = np.median(errors)
    iqr_error = iqr(errors)

    return median_error, iqr_error

if __name__ == '__main__':

    # Each result file has a header "x, y, err" and one row per setting:
    # x = alpha ('app') or stream size (synthetic), y = median relative error,
    # err = interquartile range of the relative errors (both from evaluate()).
    args = parse_args()

    if args.dist == 'app':
        # Fixed sizes for the recorded event log (presumably its row count and
        # number of distinct event ids -- confirm against the data).
        args.stream_size = 32473067
        args.domain_size = 1488096
        file_template = "../results/app.txt"
        with open(file_template, "w") as f:
            f.write("x, y, err\n")
            for alpha in [0.25, 0.5, 1.0, 1.25, 1.5, 1.75, 2.0]:
                args.alpha = alpha
                median_error, iqr_error = evaluate(args)
                print(median_error, iqr_error)
                f.write("%f, %f, %f\n"%(alpha, median_error, iqr_error))
    else:
        file_template = "../results/%s_%.2f.txt"

        # Alternative sweep: [0.25, 0.5, 1.0, 1.25, 1.5, 1.75, 2.0]
        for alpha in [args.alpha]:
            print("Deal with the case when alpha = %.2f"%(alpha))
            with open(file_template%(args.dist, alpha), "w") as f:
                f.write("x, y, err\n")
                args.alpha = alpha
                # 10 stream sizes, log-spaced from 1e4 to 1e7.
                for stream_size in tqdm(np.ceil(np.logspace(4, 7, num=10)).astype(int)):
                    args.stream_size = stream_size
                    median_error, iqr_error = evaluate(args)
                    print(median_error, iqr_error)
                    f.write("%i, %f, %f\n"%(stream_size, median_error, iqr_error))

    # TODO: include privacy parameters and error-bound calculation.
