from utils import genDataLinear, distDiscrepancy
from linear import linearModel
import torch
import numpy as np
from scipy.stats import norm

# Confidence-interval construction used by the coverage evaluation further
# down; the commented-out lines list the other supported choices.
method = 'pivotal'
# method = 'quantile'
# method = 'normal'

# Resampling scheme, forwarded as `method=` to LR.bootstrap in the simulation
# loop; the commented-out lines list the alternatives.
bootstrap_algo = 'standard'
# bootstrap_algo = 'bayesian'
# bootstrap_algo = 'residual'

def weighted_percentile(data, percents, weights=None):
    ''' Percentiles of `data`, optionally weighted.

        data:     1-D array-like of values (lists, numpy arrays and anything
                  `np.asarray` can convert are accepted).
        percents: percentile(s) to evaluate, in units of 1% (0..100).
        weights:  frequency (count) of each entry of `data`; must have as
                  many entries as `data`. When None the plain
                  `np.percentile` result is returned unchanged.

        With weights, the cumulative weight share (in percent) of the sorted
        data is linearly interpolated at `percents`, and a list with one
        value per requested percentile is returned.

        Raises ValueError if `data` and `weights` differ in size or if the
        weights do not sum to a positive number.
    '''
    if weights is None:
        return np.percentile(data, percents)

    values = np.asarray(data).ravel()
    freq = np.asarray(weights).ravel()
    if values.size != freq.size:
        raise ValueError(
            'data and weights must have the same number of entries, got '
            '%d and %d' % (values.size, freq.size))

    total = freq.sum()
    # Also rejects NaN totals; dividing by such a total would only yield NaNs.
    if not total > 0:
        raise ValueError('weights must sum to a positive number')

    order = np.argsort(values)
    sorted_values = values[order]
    # Cumulative weight share of the sorted data, expressed in percent.
    cum_percent = 1. * freq[order].cumsum() / total * 100

    # atleast_1d lets a scalar percentile through; any number of percentiles
    # is interpolated in one call.
    return list(np.interp(np.atleast_1d(percents), cum_percent, sorted_values))

# Experiment grid: particle counts compared, Monte-Carlo repetitions, and the
# training-set size handed to genDataLinear.
num_particles = [20, 50, 100, 200]
num_repeat = 1000
# NOTE(review): the meaning of the argument 4 is defined in linear.py --
# presumably the parameter dimension; confirm.
LR = linearModel(4)
n_train = 50

# Nominal coverage levels at which the confidence intervals are evaluated.
alphas = [0.8, 0.9, 0.95]

# Per particle count: one entry per repetition, appended by the simulation
# loop (b_* = bootstrap, c_* = centroid).
b_result = {k: [] for k in num_particles}
c_result = {k: [] for k in num_particles}

# Per particle count and alpha: one slot per repetition, zero-initialised.
b_cov = {k: {a: np.zeros(num_repeat) for a in alphas} for k in num_particles}
c_cov = {k: {a: np.zeros(num_repeat) for a in alphas} for k in num_particles}

# Value returned by LR.fit for each repetition.
MLE = []

# Simulation: for every repetition draw a fresh data set, fit it once, then
# run both resampling schemes for every particle count.
print('get bootstrap samples')
for rep in range(num_repeat):
    print('num repeat = ', rep)
    X, Y = genDataLinear(n_train)
    MLE.append(LR.fit(X, Y))

    for num_particle in num_particles:
        beta_bs = LR.bootstrap(X, Y, num_particle, method=bootstrap_algo)

        # Learning rate is 0.5 / num_particle, except that the 20-particle
        # run reuses the 50-particle rate.
        lr_use = 0.5 / (50. if num_particle == 20 else num_particle)

        beta_ct, ctw = LR.centoridBootstrap(
            X, Y, num_particle, epochs=2000, lr=lr_use)

        b_result[num_particle].append(beta_bs)
        c_result[num_particle].append((beta_ct, ctw))

def get_ci(sample, alpha, method='quantile', pivotal=None, weight=None):
    ''' Two-sided confidence interval [low, high] at level `alpha`.

        sample:  1-D draws of the statistic (torch tensor; array-likes are
                 also accepted).
        alpha:   nominal coverage, e.g. 0.9 leaves (1 - alpha) / 2 in each
                 tail.
        method:  'quantile' - percentiles of `sample`;
                 'normal'   - `pivotal` +/- z * std(sample);
                 'pivotal'  - percentiles reflected around `pivotal`,
                              i.e. [2*pivotal - q_up, 2*pivotal - q_low].
        pivotal: point estimate the interval is centred on; required for
                 'normal' and 'pivotal', ignored for 'quantile'.
        weight:  optional per-draw weights; when given, weighted percentiles
                 / weighted standard deviation are used.

        Raises ValueError for an unknown `method`, or when `pivotal` is
        missing for a method that needs it.
    '''
    low = (1. - alpha) / 2.
    up = 1. - low
    bounds = [low * 100, up * 100]

    if method == 'quantile':
        if weight is None:
            return np.percentile(sample, bounds)
        # Honour the weights here as well; they used to be silently dropped.
        return weighted_percentile(sample, bounds, weights=weight)

    if method not in ('normal', 'pivotal'):
        raise ValueError(
            "unknown method %r; expected 'quantile', 'normal' or 'pivotal'"
            % (method,))
    if pivotal is None:
        raise ValueError("method %r requires `pivotal`" % (method,))

    if method == 'normal':
        draws = torch.as_tensor(sample)
        if weight is not None:
            w = torch.as_tensor(weight)
            b_mu = torch.sum(draws * w) / torch.sum(w)
            b_std = torch.sqrt(torch.sum((draws - b_mu) ** 2 * w) / torch.sum(w))
        else:
            b_std = torch.std(draws)

        # Plain floats in both cases so the return type does not depend on
        # whether weights were supplied.
        center = pivotal.item() if hasattr(pivotal, 'item') else float(pivotal)
        # norm.ppf(low) is negative for alpha > 0, hence the sign flip.
        half_width = -norm.ppf(low) * float(b_std)
        return [center - half_width, center + half_width]

    # method == 'pivotal'
    tmp = weighted_percentile(sample, bounds, weights=weight)
    return [2*pivotal-tmp[1], 2*pivotal-tmp[0]]

# Reset the coverage indicators: for both schemes, one zero-filled array of
# length num_repeat per (particle count, alpha) pair.
for cov in (b_cov, c_cov):
    for num_particle in num_particles:
        cov[num_particle] = {alpha: np.zeros(num_repeat) for alpha in alphas}

# Coverage evaluation: for every particle count, repetition and alpha, build
# the interval and record whether it contains the reference value.

# Any value of `method` other than 'normal'/'pivotal' falls back to
# 'quantile'.
ci_method = method if method in ('normal', 'pivotal') else 'quantile'

# NOTE(review): presumably the true value of the first coefficient produced
# by genDataLinear -- confirm against utils.
true_value = 1.

for num_particle in num_particles:
    for j in range(num_repeat):
        center = MLE[j][0]

        # bootstrap
        sample = b_result[num_particle][j][0, :].squeeze()
        for alpha in alphas:
            ci = get_ci(sample, alpha, method=ci_method, pivotal=center)
            if ci[0] <= true_value <= ci[1]:
                b_cov[num_particle][alpha][j] += 1.

        # centroid
        sample = c_result[num_particle][j][0][0, :].squeeze()
        weight = c_result[num_particle][j][1].squeeze()
        for alpha in alphas:
            # The centroid weights are passed for every CI method; before,
            # only the 'quantile' call received them, so 'normal' and
            # 'pivotal' intervals were computed as if unweighted.
            ci = get_ci(sample, alpha, method=ci_method, pivotal=center,
                        weight=weight)
            if ci[0] <= true_value <= ci[1]:
                c_cov[num_particle][alpha][j] += 1.

# Report: for each particle count and alpha print, per scheme, the gap
# between nominal and empirical coverage (labelled 'mean') and the standard
# error of the empirical coverage (labelled 'std').
for num_particle in num_particles:
    print('num particles: ', num_particle)
    for alpha in alphas:
        print('alpha: ', alpha)

        for mean_label, std_label, coverage in (
                ('bootstrap mean: ', 'bootstrap std: ', b_cov),
                ('centroid mean: ', 'centroid std: ', c_cov)):
            hits = coverage[num_particle][alpha]
            print(mean_label, alpha - np.mean(hits))
            print(std_label, np.std(hits) / np.sqrt(num_repeat))




