from typing import List

from data import *
from julia import Main
import math


def normalize(prob: List[float]) -> List[float]:
    """Scale non-negative weights so that they sum to one.

    Args:
        prob: Weights as any array-like (list, tuple or NumPy array); at least
            one entry must be positive.

    Returns:
        A NumPy array with the same shape as ``prob`` whose entries sum to 1.
    """
    # np.asarray lets plain Python sequences (as the annotation advertises)
    # work too; NumPy arrays pass through without a copy or dtype change.
    weights = np.asarray(prob)
    return weights / weights.sum()


def GSW(X):
    """Sample a Gram-Schmidt-Walk design for the rows of ``X`` via Julia.

    Delegates to ``GSWDesign.sample_gs_walk`` (a Julia module loaded from the
    current working directory) with the robustness/balance parameter fixed
    to 0.5.

    Args:
        X: Covariate matrix, one row per unit (anything ``np.array`` accepts).

    Returns:
        Whatever ``GSWDesign.sample_gs_walk`` returns after pyjulia conversion;
        presumably one assignment per row of ``X`` -- confirm the encoding
        (bool / 0-1 vs. +/-1) against the Julia module.
    """
    # Register "./" on Julia's LOAD_PATH only once: an unconditional push!
    # appends a duplicate entry on every call.
    Main.eval('"./" in LOAD_PATH || push!(LOAD_PATH, "./")')
    from julia import GSWDesign
    GSWDesign.X = np.array(X)
    GSWDesign.lamda = 0.5
    return GSWDesign.sample_gs_walk(GSWDesign.X, GSWDesign.lamda)


def rev_idx_map(X, part):
    """Select the rows of ``X`` listed in ``part`` and remember where each came from.

    Args:
        X: Indexable dataset (rows are accessed as ``X[j]``).
        part: Sequence of row indices into ``X``.

    Returns:
        ``(Xaug, rmap)`` where ``Xaug[k] == X[part[k]]`` and ``rmap[k] == part[k]``,
        i.e. ``rmap`` maps a position in the sub-dataset back to the original row.
    """
    Xaug = [X[orig_row] for orig_row in part]
    rmap = dict(enumerate(part))
    return Xaug, rmap


def unif_sampling(X, s):
    """Uniformly sample ``s`` distinct row indices of ``X`` (without replacement).

    Args:
        X: Array whose first axis indexes the rows (units) to sample from.
        s: Number of rows to draw; NumPy raises ValueError if ``s`` exceeds
            the number of rows.

    Returns:
        ``(sampled_rows, pi)``: a list of ``s`` distinct row indices and the
        length-n array of per-row sampling weights (all equal to 1/n).
    """
    n = X.shape[0]
    pi = np.ones(n) / n  # uniform weights, already normalized

    # replace=False: every drawn row index is distinct.
    sampled_rows = np.random.choice(range(n), s, p=pi, replace=False)

    # The set round-trip drops nothing (rows are already distinct); it is kept
    # so the returned ordering is unchanged for a fixed random seed.
    return list(set(sampled_rows)), pi


def ridge_lev_score_sampling_two_outcome(X, s):
    """Two phase sampling for regression adjustment.

    Randomly partitions the n rows of ``X`` into groups 0, 1 and 2. It then
    draws, without replacement, a score-weighted sample from group 1 and one
    from group 0, plus a uniform sample (with replacement) from group 2.

    Args:
        X: Covariate matrix; unpacked as ``n, d = X.shape`` so it must be 2-D.
        s: Total sampling budget. Roughly half goes to the uniform sample and
            roughly a quarter to each score-weighted sample.

    Returns:
        Tuple ``(lev_samples_1, lev_samples_0, unif_samples, p_1, p_0)``:
        the index arrays drawn from groups 1 and 0, a de-duplicated list of
        uniform-sample indices, and ``p_1`` / ``p_0`` (capped at 1) evaluated
        at the sampled indices.
    """
    n, d = X.shape
    # Probability of putting a row in group 0, and likewise in group 1.
    # NOTE(review): np.log(1) == 0, so d == 1 leaves groups 0/1 empty and the
    # normalizations below become 0/0 (NaN probabilities) -- confirm d > 1.
    lev_sample_prob = min(10.0*d*np.log(d)/s, 1.0/3.0)
    partition = np.random.choice([0, 1, 2], size=n, p=[lev_sample_prob, lev_sample_prob, 1-2*lev_sample_prob])
    index_set = np.arange(n)

    # compute the maximum norm over rows of X
    zeta = np.max(np.linalg.norm(X, axis=1))

    # Compute the ridge leverage scores
    # These are the squared column norms of the n x n matrix
    # X (X^T X + 3*zeta^2*I)^{-1} X^T (note: materializes an n x n matrix).
    # NOTE(review): with a non-zero ridge term this matrix is not idempotent,
    # so these values differ from its diagonal x_i^T (X^T X + lambda*I)^{-1} x_i
    # (the usual ridge leverage score) -- confirm which quantity is intended.
    rlev_scores = np.linalg.norm(X@np.linalg.inv(X.T@X + 3 * zeta**2 * np.identity(d))@np.transpose(X), axis=0)**2

    # Sampling
    # Zero out scores outside each group so each draw only selects its own group.
    rlev_scores_1 = np.zeros(len(rlev_scores))
    rlev_scores_0 = np.zeros(len(rlev_scores))
    rlev_scores_1[partition==1] = rlev_scores[partition==1]
    rlev_scores_0[partition==0] = rlev_scores[partition==0]
    # Budget split: ~s/4 per score-weighted sample, ceil(s/2) uniform draws.
    lev_sample_num_1 = int(np.floor(float(np.floor(float(s) / 2)) / 2))
    lev_sample_num_0 = int(np.floor(float(np.ceil(float(s) / 2)) / 2))
    unif_sample_num = int(np.ceil(float(s) / 2))
    rlev_prob_1 = rlev_scores_1/sum(rlev_scores_1)
    rlev_prob_0 = rlev_scores_0/sum(rlev_scores_0)
    # replace=False: NumPy raises ValueError if a group has fewer rows with
    # non-zero probability than the requested sample size.
    lev_samples_1 = np.random.choice(index_set, lev_sample_num_1, p=rlev_prob_1, replace=False)
    lev_samples_0 = np.random.choice(index_set, lev_sample_num_0, p=rlev_prob_0, replace=False)
    # Uniform draw from group 2 WITH replacement (np.random.choice default);
    # duplicates are removed via set() in the return value.
    unif_samples = np.random.choice(index_set[partition==2], unif_sample_num)

    # Compute probabilities for weights of ridge leverage scores
    # NOTE(review): presumably an approximate per-row inclusion probability
    # (sample count * normalized score * group probability, capped at 1) -- verify.
    p_1 = np.minimum(float(lev_sample_num_1) * rlev_prob_1 * lev_sample_prob, 1.0)
    p_0 = np.minimum(float(lev_sample_num_0) * rlev_prob_0 * lev_sample_prob, 1.0)

    return lev_samples_1, lev_samples_0, list(set(unif_samples)), p_1[lev_samples_1], p_0[lev_samples_0]


def ridge_lev_score_sampling_two_outcome_2(X, s):
    """Leverage-score Bernoulli sampling for both outcome regressions, plus a uniform sample.

    Despite the name, no ridge term is applied: the scores are the ordinary
    statistical leverage scores ``tau_i = x_i^T (X^T X)^{-1} x_i``.

    Each row is kept in the treated-outcome regression sample with probability
    ``min(15 * log(d) * tau_i, 1)``. Independently, with the same probability,
    it is kept in the control-outcome regression sample. Separately, ``s`` rows
    are drawn uniformly without replacement.

    Args:
        X: (n, d) covariate matrix; ``X.T @ X`` must be invertible.
        s: Number of uniformly sampled rows (must not exceed n).

    Returns:
        ``(lev_samples_1, lev_samples_0, unif_samples, p_1, p_0)``: three lists of
        row indices and the inclusion probabilities of ``lev_samples_1`` /
        ``lev_samples_0``.

    Raises:
        numpy.linalg.LinAlgError: if ``X.T @ X`` is singular.
    """
    n, d = X.shape
    oversample_param = np.log(d) * 15.0

    # Leverage scores diag(X (X^T X)^{-1} X^T) via a d x d solve. Because the
    # hat matrix is symmetric and idempotent, these equal its squared column
    # norms, without materializing the n x n matrix (O(n*d^2) instead of
    # O(n^2*d) time and O(n^2) memory).
    gram_inv_xt = np.linalg.solve(X.T @ X, X.T)  # (d, n)
    lev_scores = np.sum(X * gram_inv_xt.T, axis=1)
    p = np.minimum(lev_scores * oversample_param, 1.0)

    # Two independent coins per row. rand(n, 2) fills in C order, i.e. the
    # same stream order as drawing (row 0 treated, row 0 control, row 1 ...).
    coins = np.random.rand(n, 2)
    lev_samples_1 = np.flatnonzero(coins[:, 0] <= p).tolist()
    lev_samples_0 = np.flatnonzero(coins[:, 1] <= p).tolist()

    unif_samples = np.random.choice(np.arange(n), s, replace=False)

    return lev_samples_1, lev_samples_0, list(set(unif_samples)), p[lev_samples_1], p[lev_samples_0]


def unif_estimator(X, Y1, Y0, s=0):
    """Horvitz-Thompson ATE estimate from a uniform subsample (baseline).

    Draws ``s`` rows with ``unif_sampling`` and flips a fair coin for each.
    Every observed outcome is weighted by ``2 / pi`` (``pi`` being that row's
    sampling weight), and the difference of the weighted sums is scaled by
    ``1 / (s * n)``.

    Raises ZeroDivisionError when ``s`` is 0 (the default).
    """
    n, _ = X.shape
    rows, pi = unif_sampling(X, s)

    treated_sum = 0.0
    control_sum = 0.0
    for row in rows:
        weight = 2.0 / pi[row]
        if np.random.rand() <= 0.5:
            treated_sum += weight * Y1[row]
        else:
            control_sum += weight * Y0[row]

    scale = 1.0 / float(s * n)
    return scale * (treated_sum - control_sum)


def ATE_gsw_estimator(X, Y1, Y0, s=0):
    """ATE estimate on a GSW-built coreset (the proposed estimator).

    ``coreset(X, s)`` returns the treated units, the control units and the
    number of rounds ``niter``. The treated Y1 total and the control Y0 total
    are each weighted by ``2 ** niter``, and their difference is divided by n.
    """
    n, _ = X.shape
    treated_units, control_units, rounds = coreset(X, s)
    # Presumably each coreset round keeps about half the units (not visible
    # here), hence the 2 ** rounds inverse-probability weight -- confirm.
    weight = 2 ** rounds

    treated_total = 0.0
    for unit in treated_units:
        treated_total += Y1[unit]

    control_total = 0.0
    for unit in control_units:
        control_total += Y0[unit]

    return (1.0 / float(n)) * (weight * treated_total - weight * control_total)


def ATE_gsw_pop_estimator(X, Y1, Y0):
    """ATE estimate from a single GSW design over the whole population (baseline).

    Units whose GSW assignment is positive (True / 1 / +1) are treated, and
    all others are controls. Returns the Horvitz-Thompson estimate
    ``(2 / n) * (sum of treated Y1 - sum of control Y0)``.
    """
    n, _ = X.shape
    sign = GSW(X)

    tau1, tau0 = 0.0, 0.0
    for idx in range(len(sign)):
        # Compare against 0 rather than using truthiness: a -1 (control) label
        # is truthy and would otherwise be counted as treatment. Bool and 0/1
        # encodings behave exactly as before.
        if sign[idx] > 0:
            tau1 += Y1[idx]
        else:
            tau0 += Y0[idx]

    return (1.0 / float(n)) * (2.0 * tau1 - 2.0 * tau0)


def ATE_rand_pop_estimator(X, Y1, Y0):
    """ATE estimate under complete Bernoulli(1/2) randomization of every unit (baseline).

    A unit is treated when its uniform draw is <= 0.5. The function returns the
    Horvitz-Thompson estimate ``(2 / n) * (sum of treated Y1 - sum of control Y0)``.
    No GSW design is involved.
    """
    n, _ = X.shape
    # Draw all coins up front, one per unit, in unit order.
    coin_flips = [np.random.rand() <= 0.5 for unit in range(n)]

    treated_total = 0.0
    control_total = 0.0
    for unit, is_treated in enumerate(coin_flips):
        if is_treated:
            treated_total += Y1[unit]
        else:
            control_total += Y0[unit]

    return (1.0 / float(n)) * (2.0 * treated_total - 2.0 * control_total)


def ATE(Y1, Y0):
    """Oracle ATE: the mean of ``Y1[i] - Y0[i]`` over all ``len(Y1)`` units.

    Assumes both potential outcomes are known for every unit. Raises
    ZeroDivisionError on empty input.
    """
    count = len(Y1)
    total = 0.0
    for i, treated_value in enumerate(Y1):
        total += treated_value - Y0[i]
    return (1.0 / float(count)) * total


def unif_gsw(X, Y1, Y0, s=0):
    """ATE estimate from a uniform subsample of ``s`` rows assigned with GSW.

    Draws ``s`` distinct rows with ``unif_sampling`` and runs ``GSW`` on them.
    Units with a positive assignment (True / 1 / +1) are treated. Returns
    ``(2 / s) * (sum of treated Y1 - sum of control Y0)``; ``s`` must be positive.
    """
    sampled_rows, _ = unif_sampling(X, s)
    sign = GSW(X[sampled_rows, :])

    tau1, tau0 = 0.0, 0.0
    for idx in range(len(sign)):
        # Compare against 0 rather than using truthiness: a -1 (control) label
        # is truthy and would otherwise be counted as treatment.
        if sign[idx] > 0:
            tau1 += Y1[sampled_rows[idx]]
        else:
            tau0 += Y0[sampled_rows[idx]]

    return (1.0 / float(s)) * (2.0 * tau1 - 2.0 * tau0)


def reg_adj_with_rlev_score(X, Y1, Y0, s=0):
    """Regression-adjusted ATE estimate with leverage-score-sampled outcome regressions.

    Steps:
      1. Center the columns of ``X``.
      2. Draw a leverage-score sample per outcome and a uniform sample of ``s``
         units with ``ridge_lev_score_sampling_two_outcome_2``.
      3. Fit one linear model per arm by least squares on its leverage sample,
         dividing rows and targets by their inclusion probabilities.
      4. Put each uniform-sample unit in treatment with probability 1/2 and
         take the difference of the arm means.
      5. Add the mean of ``x_i^T (beta_1 - beta_0)`` over the uniform sample.

    Raises ZeroDivisionError if either arm of the uniform sample is empty.
    """
    n_rows, n_cols = X.shape  # requires a 2-D array
    centered = X - np.mean(X, axis=0)

    lev_1, lev_0, unif, w_1, w_0 = ridge_lev_score_sampling_two_outcome_2(centered, s)

    def _weighted_fit(rows, weights, outcome):
        # Least-squares coefficients after dividing rows/targets by the weights.
        design = centered[rows, :] / weights[:, np.newaxis]
        target = outcome.T[rows] / weights
        return np.linalg.lstsq(design, target)[0]

    beta_1 = _weighted_fit(lev_1, w_1, Y1)
    beta_0 = _weighted_fit(lev_0, w_0, Y0)

    treated_sum, control_sum = 0.0, 0.0
    n_treated = n_control = 0
    for unit in unif:
        if np.random.rand() <= 0.5:
            treated_sum += Y1[unit]
            n_treated += 1
        else:
            control_sum += Y0[unit]
            n_control += 1
    treated_mean = treated_sum / float(n_treated)
    control_mean = control_sum / float(n_control)

    adjustment = (1.0 / float(len(unif))) * sum(centered[unif, :] @ (beta_1 - beta_0))
    return (treated_mean - control_mean) + adjustment


def raht_gsw(X, Y1, Y0, s=0):
    """Regression-adjusted Horvitz-Thompson ATE estimate with a GSW design on a uniform subsample.

    Steps:
      1. Append an all-ones (intercept) column to ``X``.
      2. Draw a leverage-score sample per outcome and a uniform sample of ``s``
         units with ``ridge_lev_score_sampling_two_outcome_2``.
      3. Fit one linear model per arm by least squares on its leverage sample,
         dividing rows and targets by their inclusion probabilities.
      4. Assign the uniform-sample units with ``GSW``. Positive assignments
         (True / 1 / +1) are treated; the rest are controls.
      5. Return ``(1/m) * (2*sum treated Y1 - 2*sum control Y0
         + sum_i x_i^T (beta_1 - beta_0))``, where m is the uniform-sample size.
    """
    n, _ = X.shape
    X = np.concatenate((X, np.ones((n, 1))), axis=1)

    lev_samples_1, lev_samples_0, unif_samples, p_1, p_0 = ridge_lev_score_sampling_two_outcome_2(X, s)

    beta_1 = np.linalg.lstsq(X[lev_samples_1, :] / p_1[:, np.newaxis], Y1.T[lev_samples_1] / p_1)[0]
    beta_0 = np.linalg.lstsq(X[lev_samples_0, :] / p_0[:, np.newaxis], Y0.T[lev_samples_0] / p_0)[0]

    sign = GSW(X[unif_samples, :])

    tau1, tau0 = 0.0, 0.0
    for idx in range(len(sign)):
        # Compare against 0 rather than using truthiness: a -1 (control) label
        # is truthy and would otherwise be counted as treatment.
        if sign[idx] > 0:
            tau1 += Y1[unif_samples[idx]]
        else:
            tau0 += Y0[unif_samples[idx]]

    # NOTE(review): the regression term is added to the raw HT contrast, but the
    # fitted values are not subtracted from the observed outcomes. A standard
    # regression-adjusted HT estimator would use residuals -- confirm intended.
    m = float(len(unif_samples))
    return (1.0 / m) * ((2.0 * tau1 - 2.0 * tau0) + sum(X[unif_samples, :] @ (beta_1 - beta_0)))
