
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm


# ----------------------------
# 1. Basic Functions and Swap Rounding
# ----------------------------
def logistic(x, gamma0=0, gamma1=1):
    """Apply the logistic (sigmoid) function to a linear predictor, element-wise.

    Evaluates ``1 / (1 + exp(-(gamma0 + gamma1 * x)))``.

    Args:
        x: Scalar or NumPy array of inputs.
        gamma0: Intercept added to the scaled input.
        gamma1: Slope multiplying ``x``.

    Returns:
        The logistic transform of ``gamma0 + gamma1 * x``.
    """
    linear_predictor = gamma0 + gamma1 * x
    return 1.0 / (1.0 + np.exp(-linear_predictor))

def swap_rounding_with_probabilities_fixed_variance(p, Y):
    """
    Round a probability vector to a 0/1 assignment by repeated pairwise swaps.

    On every pass the two lowest-indexed entries that are still strictly
    between 0 and 1 are paired.  Both are shifted by the same amount ``alpha``
    (the smallest distance from either entry to 0 or 1) in opposite
    directions; the first entry moves up with probability
    ``p_i / (p_i + p_j)`` and down otherwise.  The loop stops once fewer than
    two fractional entries remain; the result is then thresholded at 0.5.

    NOTE(review): the step size is the same in both directions while the
    branch probability is generally not 1/2, so one step shifts the expected
    value of ``p_i`` by ``alpha * (2 * prob_i - 1)``.  Confirm this is the
    intended design before treating ``p`` as the marginal assignment
    probabilities.

    Args:
        p : 1-D NumPy array of assignment probabilities (copied, not modified).
        Y : Not used by this function.

    Returns:
        A         : Final (binary) assignments as an integer array.
        swaps     : List of dicts with keys 'i', 'j', 'alpha', 'prob_i',
                    one per swap performed.
        p_history : List of probability vectors before each swap.
    """
    working = p.copy()
    swap_log = []
    snapshots = []
    fractional = np.where((working > 0) & (working < 1))[0]
    while len(fractional) >= 2:
        first, second = fractional[0], fractional[1]
        snapshots.append(working.copy())
        p_first, p_second = working[first], working[second]
        step = min(p_first, 1 - p_second, p_second, 1 - p_first)
        prob_first = p_first / (p_first + p_second)
        # One uniform draw per swap decides the direction of the shift.
        signed_step = step if np.random.rand() < prob_first else -step
        working[first] += signed_step
        working[second] -= signed_step
        swap_log.append({'i': first, 'j': second, 'alpha': step, 'prob_i': prob_first})
        fractional = np.where((working > 0) & (working < 1))[0]
    assignment = (working >= 0.5).astype(int)
    return assignment, swap_log, snapshots

# ----------------------------
# 2. Additional Independent Assignment Methods
# ----------------------------
def random_assignment_based_on_total_probabilities(p):
    """Treat a uniformly random subset whose size matches the expected total.

    The number of treated units is ``sum(p)`` rounded with ``np.round``; that
    many distinct indices are drawn without replacement, so the individual
    entries of ``p`` only matter through their sum.

    Args:
        p: Sequence of assignment probabilities.

    Returns:
        Integer 0/1 array of length ``len(p)`` containing exactly
        ``int(np.round(np.sum(p)))`` ones.
    """
    n_units = len(p)
    n_treated = int(np.round(np.sum(p)))
    treated = np.random.choice(n_units, n_treated, replace=False)
    assignment = np.zeros(n_units, dtype=int)
    assignment[treated] = 1
    return assignment

def bernoulli_sampling_until_total_assignments(p):
    """Accumulate Bernoulli successes until the target number of units is treated.

    The target is ``sum(p)`` rounded with ``np.round``.  Each round draws one
    Bernoulli(p_i) per unit; units that succeed and are not yet treated become
    treated, scanning in index order and stopping as soon as the target is
    reached.  Rounds repeat until the target is met.

    Args:
        p: Sequence of per-unit success probabilities.

    Returns:
        Integer 0/1 array of length ``len(p)`` with exactly the target number
        of ones.
    """
    n_units = len(p)
    target = int(np.round(np.sum(p)))
    assignment = np.zeros(n_units, dtype=int)
    n_assigned = 0
    while n_assigned < target:
        draw = np.random.binomial(1, p, n_units)
        new_hits = np.where((draw == 1) & (assignment == 0))[0]
        # Keep only as many new successes (lowest indices first) as are
        # still needed to reach the target.
        accepted = new_hits[:target - n_assigned]
        assignment[accepted] = 1
        n_assigned += len(accepted)
    return assignment

def conditional_poission_sample(p):
    """Draw independent Bernoulli assignments, rejecting until the total matches.

    A vector of Bernoulli(p_i) draws is resampled until its sum equals
    ``sum(p)`` rounded with ``np.round``.

    Args:
        p: Array-like of per-unit success probabilities.

    Returns:
        The first 0/1 draw whose number of ones equals the target total.
    """
    target = int(np.round(np.sum(p)))
    draw = np.random.binomial(1, p)
    while draw.sum() != target:
        draw = np.random.binomial(1, p)
    return draw

# ----------------------------
# 3. Covariate-Ordered Swap Rounding
# ----------------------------
def nearest_neighbor_ordering(X):
    """Order the rows of X with a greedy nearest-neighbour walk.

    The walk starts at row 0 and repeatedly moves to the not-yet-visited row
    with the smallest Euclidean distance to the current row.  Ties go to the
    lowest row index.

    Args:
        X: 2-D array of shape (n, d), one row per unit.

    Returns:
        List of the n row indices in visiting order; an empty list when X has
        no rows.
    """
    n = X.shape[0]
    if n == 0:
        # Nothing to visit; popping from the empty index list would raise.
        return []
    remaining = list(range(1, n))
    current = 0
    order = [current]
    while remaining:
        distances = np.linalg.norm(X[remaining] - X[current], axis=1)
        # `remaining` stays in ascending order, so argmin's first-minimum
        # rule resolves ties towards the lowest row index.
        current = remaining.pop(int(np.argmin(distances)))
        order.append(current)
    return order

def _swap_pair_in_place(p_current, i, j):
    """Apply one randomized swap to entries i and j of p_current, in place.

    Both entries move by the same ``alpha`` (the smallest distance from either
    entry to 0 or 1) in opposite directions; entry i moves up with
    probability ``p_i / (p_i + p_j)`` and down otherwise.

    Returns:
        Dict describing the swap, with keys 'i', 'j', 'alpha', 'prob_i'.
    """
    alpha = min(p_current[i], 1 - p_current[j], p_current[j], 1 - p_current[i])
    prob_i = p_current[i] / (p_current[i] + p_current[j])
    if np.random.rand() < prob_i:
        p_current[i] += alpha
        p_current[j] -= alpha
    else:
        p_current[i] -= alpha
        p_current[j] += alpha
    return {'i': i, 'j': j, 'alpha': alpha, 'prob_i': prob_i}


def covariate_ordered_swap_rounding(p, X, max_iter=1000):
    """Swap-round p, pairing units that are adjacent in a covariate ordering.

    Units are ordered with ``nearest_neighbor_ordering(X)``.  Each sweep walks
    that ordering cyclically (the last unit is paired with the first) and
    swaps every adjacent pair whose probabilities are both strictly between
    0 and 1.  Sweeps repeat until one makes no swap or ``max_iter`` sweeps
    have run.  Any fractional entries still left are then paired in index
    order until fewer than two remain, and the result is thresholded at 0.5.

    Args:
        p: 1-D NumPy array of assignment probabilities (copied, not modified).
        X: Covariate matrix with one row per unit, used only for the ordering.
        max_iter: Maximum number of sweeps over the ordering.

    Returns:
        A            : Final (binary) assignments as an integer array.
        swap_details : List of dicts with keys 'i', 'j', 'alpha', 'prob_i',
                       one per swap performed.
    """
    n = len(p)
    p_current = p.copy()
    swap_details = []
    order = nearest_neighbor_ordering(X)
    for _ in range(max_iter):
        swap_made = False
        for k in range(n):
            i = order[k]
            j = order[(k + 1) % n]
            if i == j:
                # Only happens when n == 1: the cyclic neighbour is the unit
                # itself.  Swapping a unit with itself changes nothing, yet
                # it used to be logged as a swap on every sweep.
                continue
            if 0 < p_current[i] < 1 and 0 < p_current[j] < 1:
                swap_details.append(_swap_pair_in_place(p_current, i, j))
                swap_made = True
        if not swap_made:
            break
    # Clean-up pass: pair leftover fractional entries by index, ignoring the
    # covariate ordering.
    fractional = np.where((p_current > 0) & (p_current < 1))[0]
    while len(fractional) >= 2:
        i, j = fractional[0], fractional[1]
        swap_details.append(_swap_pair_in_place(p_current, i, j))
        fractional = np.where((p_current > 0) & (p_current < 1))[0]
    A = (p_current >= 0.5).astype(int)
    return A, swap_details

# ----------------------------
# 4. Re-randomization Assignment Using Vectorization
# ----------------------------
def compute_mahalanobis_threshold(X, num_randomizations=10, quantile=0.1):
    """Estimate an acceptance threshold for covariate imbalance.

    Draws ``num_randomizations`` reference assignments in which every unit is
    treated independently with probability 0.5, measures the Mahalanobis
    distance between the treated and control covariate means for each, and
    returns the requested quantile of those distances.  Draws that put all
    units in one arm are skipped.

    Args:
        X: Covariate matrix of shape (n, d); a single column (d == 1) is
            supported.
        num_randomizations: Number of reference assignments to draw.
        quantile: Quantile (in [0, 1]) of the reference distances to return.

    Returns:
        The distance threshold, or ``np.inf`` when no reference draw had
        units in both arms.
    """
    n = X.shape[0]
    # np.cov collapses a single covariate to a 0-d array, which pinv rejects;
    # promote it back to a 1x1 matrix.
    S = np.atleast_2d(np.cov(X, rowvar=False))
    invS = np.linalg.pinv(S)
    distances = []
    for _ in range(num_randomizations):
        A_rand = np.random.binomial(1, 0.5, n)
        n_treated = A_rand.sum()
        if n_treated == 0 or n_treated == n:
            # One arm is empty, so its covariate mean is undefined.
            continue
        mean_treated = np.mean(X[A_rand == 1], axis=0)
        mean_control = np.mean(X[A_rand == 0], axis=0)
        diff = mean_treated - mean_control
        distances.append(np.sqrt(np.dot(diff, np.dot(invS, diff))))
    if not distances:
        return np.inf
    return np.quantile(np.array(distances), quantile)

def re_randomization_assignment_vectorized(p, X, num_randomizations=10, quantile=0.1):
    """Draw a batch of Bernoulli(p) assignments and keep a well-balanced one.

    All ``num_randomizations`` candidate assignments are drawn in one call.
    Each candidate is scored by the Mahalanobis distance between its treated
    and control covariate means (``np.inf`` when one arm is empty).  A
    candidate is then chosen uniformly at random among those whose distance
    does not exceed the threshold from ``compute_mahalanobis_threshold``; if
    none qualifies, the candidate with the smallest distance is used.

    Args:
        p: Per-unit treatment probabilities, length n.
        X: Covariate matrix with one row per unit.
        num_randomizations: Number of candidates drawn; also forwarded to the
            threshold computation.
        quantile: Quantile forwarded to the threshold computation.

    Returns:
        Tuple ``(assignment, distance)`` for the chosen candidate.
    """
    n_units = len(p)
    threshold = compute_mahalanobis_threshold(
        X, num_randomizations=num_randomizations, quantile=quantile
    )
    precision = np.linalg.pinv(np.cov(X, rowvar=False))
    candidates = np.random.binomial(1, p, (num_randomizations, n_units))
    imbalance = np.empty(num_randomizations)
    for row, candidate in enumerate(candidates):
        n_treated = candidate.sum()
        if n_treated == 0 or n_treated == n_units:
            # An empty arm has no covariate mean; mark as maximally imbalanced.
            imbalance[row] = np.inf
            continue
        gap = np.mean(X[candidate == 1], axis=0) - np.mean(X[candidate == 0], axis=0)
        imbalance[row] = np.sqrt(np.dot(gap, np.dot(precision, gap)))
    acceptable = np.where(imbalance <= threshold)[0]
    if acceptable.size > 0:
        pick = np.random.choice(acceptable)
    else:
        pick = np.argmin(imbalance)
    return candidates[pick], imbalance[pick]

def re_randomization_assignment(p, X, num_randomizations=10, quantile=0.1, max_iter=50):
    """Redraw Bernoulli(p) assignments until one is balanced enough.

    Up to ``max_iter`` assignments are drawn one at a time.  The first one
    whose Mahalanobis distance between treated and control covariate means
    is at most the threshold from ``compute_mahalanobis_threshold`` is
    returned.  Draws that leave one arm empty are never accepted.  If no draw
    is accepted, the last draw is returned.

    Args:
        p: Per-unit treatment probabilities, length n.
        X: Covariate matrix with one row per unit.
        num_randomizations: Forwarded to the threshold computation.
        quantile: Forwarded to the threshold computation.
        max_iter: Maximum number of draws; must be at least 1.

    Returns:
        Tuple ``(A, d)``: the assignment and its distance.  ``d`` is
        ``np.inf`` when the returned assignment has an empty arm.

    Raises:
        ValueError: If ``max_iter`` is smaller than 1.
    """
    if max_iter < 1:
        # With no draws there is nothing to return.
        raise ValueError("max_iter must be at least 1")
    T = compute_mahalanobis_threshold(X, num_randomizations=num_randomizations, quantile=quantile)
    n = len(p)
    S = np.cov(X, rowvar=False)
    invS = np.linalg.pinv(S)
    for _ in range(max_iter):
        A = np.random.binomial(1, p)
        n_treated = A.sum()
        if n_treated == 0 or n_treated == n:
            # Keep d in step with A: an empty arm has no defined distance.
            # This also guarantees d is bound if every draw is degenerate.
            d = np.inf
            continue
        mean_treated = np.mean(X[A == 1], axis=0)
        mean_control = np.mean(X[A == 0], axis=0)
        diff = mean_treated - mean_control
        d = np.sqrt(np.dot(diff, np.dot(invS, diff)))
        if d <= T:
            return A, d
    # No draw met the threshold: fall back to the last draw.
    return A, d

# ----------------------------
# 4a. Monte Carlo Estimate of the Effective Propensity Score
# ----------------------------
def simulate_effective_propensity(p, X, num_randomizations=10, quantile=0.1, M=10):
    """Estimate per-unit treatment frequencies under re-randomization.

    Calls ``re_randomization_assignment_vectorized`` repeatedly, discards
    assignments that put every unit in the same arm, and averages the first
    ``M`` retained assignments unit by unit.  The loop only ends once ``M``
    such assignments have been collected.

    Args:
        p: Per-unit treatment probabilities, length n.
        X: Covariate matrix with one row per unit.
        num_randomizations: Forwarded to the re-randomization routine.
        quantile: Forwarded to the re-randomization routine.
        M: Number of retained assignments to average.

    Returns:
        Array of length n with the fraction of retained assignments in which
        each unit was treated.
    """
    n_units = len(p)
    kept = np.zeros((M, n_units))
    filled = 0
    while filled < M:
        assignment, _ = re_randomization_assignment_vectorized(
            p, X, num_randomizations=num_randomizations, quantile=quantile
        )
        if assignment.sum() in (0, n_units):
            # Everyone landed in one arm; draw again.
            continue
        kept[filled, :] = assignment
        filled += 1
    return kept.mean(axis=0)

# 5. Variance Estimator Functions (IPW)
# ----------------------------
def compute_rho_ij(p0_i, p0_j):
    """Return the pairwise weight for two units with probabilities p0_i and p0_j.

    The value is ``-p0_i * p0_j`` when the two probabilities sum to at most 1
    and ``-(1 - p0_i) * (1 - p0_j)`` otherwise.

    NOTE(review): presumably the covariance of the two assignment indicators
    when the pair is rounded jointly -- confirm against the derivation.
    """
    total = p0_i + p0_j
    if total <= 1:
        return -p0_i * p0_j
    return -(1 - p0_i) * (1 - p0_j)

def compute_IPW_variance_swap(A, Y, p0, tau_hat, swaps):
    """Compute two IPW variance estimates, with and without a swap cross term.

    Args:
        A: 0/1 assignment array.
        Y: Observed outcomes, same length as ``A``.
        p0: Assignment probabilities, same length as ``A``.
        tau_hat: Point estimate whose square (times n) is subtracted.
        swaps: Iterable of dicts with keys 'i' and 'j' naming swapped pairs.

    Returns:
        Tuple ``(variance_est, variance_est2)``.  ``variance_est2`` uses only
        the per-unit squared-outcome term and ``n * tau_hat**2``;
        ``variance_est`` additionally adds twice the sum, over the listed
        pairs, of ``compute_rho_ij(p0[i], p0[j])`` times a pairwise outcome
        product.
    """
    n_units = len(p0)
    y_squared = Y**2
    treated_denominator = p0**2
    control_denominator = (1 - p0)**2
    per_unit_term = np.sum(
        A * y_squared / treated_denominator
        + (1 - A) * y_squared / control_denominator
    )
    centering_term = n_units * (tau_hat**2)

    cross_term = 0.0
    for swap in swaps:
        i, j = swap['i'], swap['j']
        a_i, a_j = A[i], A[j]
        y_i, y_j = Y[i], Y[j]
        treated_i, control_i = p0[i]**2, (1 - p0[i])**2
        treated_j, control_j = p0[j]**2, (1 - p0[j])**2
        # One summand per combination of the two units' arms.
        pair_product = (a_i * a_j * y_i * y_j / (treated_i * treated_j) +
                        a_i * (1 - a_j) * y_i * y_j / (treated_i * control_j) +
                        (1 - a_i) * a_j * y_i * y_j / (control_i * treated_j) +
                        (1 - a_i) * (1 - a_j) * y_i * y_j / (control_i * control_j))
        cross_term += compute_rho_ij(p0[i], p0[j]) * pair_product

    with_cross = (per_unit_term - centering_term + 2 * cross_term) / (n_units**2)
    without_cross = (per_unit_term - centering_term) / (n_units**2)
    return with_cross, without_cross

# 6. Simulation Study Across Different n (with tqdm progress bars)
# ----------------------------
# Sample sizes studied; one point per size appears in the final plot.
n_values = [50, 100, 250, 500, 1000, 5000]
# For a quick smoke test, restrict the study to a single size, e.g. n_values = [50].
default_n_iter = 100  # simulated datasets per sample size (reduced for speed testing)
default_S = 100       # assignment draws per dataset (reduced for speed testing)

# Display names of the compared designs; used as dictionary keys throughout.
methods = ["No Swap (IPW)", "Swap Rounding", "Covariate Matched Swap",
           "Random Assignment", "Limited Bernoulli", "Re-randomization", "Self-Normalized IPW"]

# Per method: one entry per sample size, in the order of n_values.
avg_variance = {method: [] for method in methods}
ci_lower = {method: [] for method in methods}
ci_upper = {method: [] for method in methods}

# Average re-randomization tau_hat for each n, keyed by n.
avg_tau_rerand_by_n = {}

# Outer loop over sample sizes; the per-dataset loop inside is wrapped in tqdm.
for n in n_values:
    print(f"\nProcessing sample size n = {n}")
    # Larger samples get fewer datasets and fewer draws per dataset.
    if n == 5000:
        curr_n_iter = 15
        curr_S = 10
    elif n == 1000:
        curr_n_iter = 50
        curr_S = 20
    elif n == 500:
        curr_n_iter = 100
        curr_S = 25
    else:
        curr_n_iter = default_n_iter
        curr_S = default_S

    # Per method: one empirical variance of tau_hat per simulated dataset.
    iter_variances = {method: [] for method in methods}
    # This list will store the average tau_hat for re-randomization for each outer iteration.
    tau_rerand_outer = []

    for iteration in tqdm(range(curr_n_iter), desc="Iterations for n = " + str(n)):
        # Fresh dataset: three standard-normal covariates and unit-variance noise.
        X = np.random.normal(0, 1, (n, 3))
        epsilon = np.random.normal(0, 1, n)
        gamma_assignment = np.random.uniform(-1, 1, size=3)
        p0 = logistic(X.dot(gamma_assignment))
        p0 = np.clip(p0, 0.01, 0.99)
        # NOTE(review): the covariate-dependent propensities computed in the two
        # lines above are discarded here -- p0 is replaced by independent
        # Uniform(0.05, 0.95) draws, so gamma_assignment does not affect the
        # results.  Confirm which propensity model is intended.
        p0 = np.random.uniform(.05,.95,n)
        beta0 = 0
        true_beta = np.array([1, 1, 1])
        tau_true = 2
        # Potential outcomes: the noise term enters the treated outcome only.
        Y0 = beta0 + X.dot(true_beta)
        Y1 = Y0 + tau_true + epsilon
        
        # Compute effective propensity once per outer iteration: the per-unit
        # treatment frequency over M=25 re-randomized assignments, clipped to
        # [0.05, 0.95].
        effective_ps = simulate_effective_propensity(p0, X, num_randomizations=100, quantile=0.1, M=25)
        effective_ps = np.clip(effective_ps, 0.05, 0.95)
        
        # tau_hat draws for this dataset, one list per method.
        tau_no_swap = []
        tau_swap = []
        tau_cov = []
        tau_random = []
        tau_bernoulli = []
        tau_rerand = []
        tau_selfnorm = []
        
        for s in range(curr_S):
            # Method 1: Standard IPW (No Swap) -- independent Bernoulli(p0) assignment.
            A_no_swap = np.random.binomial(1, p0, n)
            Y_no_swap = A_no_swap * Y1 + (1 - A_no_swap) * Y0
            tau_hat_no = np.mean(A_no_swap * Y_no_swap / p0 - (1 - A_no_swap) * Y_no_swap / (1 - p0))
            tau_no_swap.append(tau_hat_no)
            
            # Method 2: Swap Rounding (the second argument is not used by the function).
            A_swap, swaps, _ = swap_rounding_with_probabilities_fixed_variance(p0, Y1)
            Y_swap = A_swap * Y1 + (1 - A_swap) * Y0
            tau_hat_swap = np.mean(A_swap * Y_swap / p0 - (1 - A_swap) * Y_swap / (1 - p0))
            tau_swap.append(tau_hat_swap)
            
            # Method 3: Covariate Matched Swap Rounding
            A_cov, cov_swaps = covariate_ordered_swap_rounding(p0, X)
            Y_cov = A_cov * Y1 + (1 - A_cov) * Y0
            tau_hat_cov = np.mean(A_cov * Y_cov / p0 - (1 - A_cov) * Y_cov / (1 - p0))
            tau_cov.append(tau_hat_cov)
            
            # Method 4: Random Assignment -- a uniformly random subset of size
            # round(sum(p0)), weighted with the common probability p_uniform.
            A_random = random_assignment_based_on_total_probabilities(p0)
            p_uniform = int(np.round(np.sum(p0))) / n
            Y_random = A_random * Y1 + (1 - A_random) * Y0
            tau_hat_random = np.mean(A_random * Y_random / p_uniform - (1 - A_random) * Y_random / (1 - p_uniform))
            tau_random.append(tau_hat_random)
            
            # Method 5: Limited Bernoulli Assignment
            # NOTE(review): this uses conditional_poission_sample (Bernoulli draws
            # rejected until the treated count equals round(sum(p0))), not
            # bernoulli_sampling_until_total_assignments; the weights are still p0.
            A_bernoulli = conditional_poission_sample(p0)
            Y_bernoulli = A_bernoulli * Y1 + (1 - A_bernoulli) * Y0
            tau_hat_bernoulli = np.mean(A_bernoulli * Y_bernoulli / p0 - (1 - A_bernoulli) * Y_bernoulli / (1 - p0))
            tau_bernoulli.append(tau_hat_bernoulli)
            
            # Method 6: Re-randomization with Effective Propensity Adjustment
            # NOTE(review): effective_ps was estimated with
            # re_randomization_assignment_vectorized (num_randomizations=100),
            # while the assignment here comes from re_randomization_assignment
            # (num_randomizations=50, max_iter=50).  Confirm that the two
            # procedures are meant to share the same effective propensities.
            A_rerand, d_val = re_randomization_assignment(p0, X, num_randomizations=50, quantile=0.1, max_iter=50)
            Y_rerand = A_rerand * Y1 + (1 - A_rerand) * Y0
            tau_hat_rerand = np.mean(A_rerand * Y_rerand / effective_ps - (1 - A_rerand) * Y_rerand / (1 - effective_ps))
            tau_rerand.append(tau_hat_rerand)
            
            # Method 7: Self-Normalized IPW (No Swaps) -- Bernoulli(p0) assignment,
            # with each arm's weighted outcome sum divided by its weight sum.
            A_sn = np.random.binomial(1, p0, n)
            Y_sn = A_sn * Y1 + (1 - A_sn) * Y0
            numer_treat = np.sum(A_sn * Y_sn / p0)
            denom_treat = np.sum(A_sn / p0)
            numer_control = np.sum((1 - A_sn) * Y_sn / (1 - p0))
            denom_control = np.sum((1 - A_sn) / (1 - p0))
            tau_hat_sn = numer_treat / denom_treat - numer_control / denom_control
            tau_selfnorm.append(tau_hat_sn)
        
        # Save the sample variance (ddof=1) of tau_hat over the curr_S draws
        # for each method (for later analysis).
        iter_variances["No Swap (IPW)"].append(np.var(tau_no_swap, ddof=1))
        iter_variances["Swap Rounding"].append(np.var(tau_swap, ddof=1))
        iter_variances["Covariate Matched Swap"].append(np.var(tau_cov, ddof=1))
        iter_variances["Random Assignment"].append(np.var(tau_random, ddof=1))
        iter_variances["Limited Bernoulli"].append(np.var(tau_bernoulli, ddof=1))
        iter_variances["Re-randomization"].append(np.var(tau_rerand, ddof=1))
        iter_variances["Self-Normalized IPW"].append(np.var(tau_selfnorm, ddof=1))
        
        # For re-randomization, record the average tau_hat from this outer iteration
        avg_tau_iter = np.mean(tau_rerand)
        tau_rerand_outer.append(avg_tau_iter)
    
    # After finishing all iterations for a given n, compute the mean variance
    # per method and a normal-approximation 95% interval (mean +/- 1.96
    # standard errors across datasets).
    for method in methods:
        vals = np.array(iter_variances[method])
        mean_var = np.mean(vals)
        std_err = np.std(vals, ddof=1) / np.sqrt(curr_n_iter)
        ci_lower_val = mean_var - 1.96 * std_err
        ci_upper_val = mean_var + 1.96 * std_err
        avg_variance[method].append(mean_var)
        ci_lower[method].append(ci_lower_val)
        ci_upper[method].append(ci_upper_val)
    
    # Compute the overall average tau_hat for re-randomization for this sample size n
    overall_avg_tau_rerand = np.mean(tau_rerand_outer)
    avg_tau_rerand_by_n[n] = overall_avg_tau_rerand
    print(f"For n = {n}, the average tau_hat (Re-randomization) = {overall_avg_tau_rerand:.4f}")
    print(f"Finished sample size n = {n}")
    
# ----------------------------
# 7. Plotting the Results
# ----------------------------
# One figure: mean empirical variance per method against sample size, with
# asymmetric error bars spanning the stored confidence limits.
plt.figure(figsize=(12, 8))

# (method name, line colour, marker) for every plotted series.
series_styles = [
    ("No Swap (IPW)", "blue", "o"),
    ("Swap Rounding", "red", "s"),
    ("Covariate Matched Swap", "green", "D"),
    ("Random Assignment", "orange", "^"),
    ("Limited Bernoulli", "purple", "v"),
    ("Re-randomization", "brown", "P"),
    ("Self-Normalized IPW", "magenta", "X"),
]
colors = {name: color for name, color, _ in series_styles}
markers = {name: marker for name, _, marker in series_styles}

for method in methods:
    mean_curve = np.array(avg_variance[method])
    below = mean_curve - np.array(ci_lower[method])
    above = np.array(ci_upper[method]) - mean_curve
    plt.errorbar(
        n_values,
        avg_variance[method],
        yerr=[below, above],
        # The plain IPW baseline is shown under a friendlier legend name.
        label="Standard IPW" if method == "No Swap (IPW)" else method,
        marker=markers[method],
        color=colors[method],
        capsize=5,
        linestyle='-',
        linewidth=2,
        markersize=8,
    )

# Log-scaled x axis with one labelled tick per studied sample size.
plt.xscale('log')
plt.xticks(n_values, list(map(str, n_values)), fontsize=18)
plt.xlabel('Sample Size (n)', fontsize=20)
plt.ylabel('Empirical Variance of τ̂', fontsize=20)
plt.yticks(fontsize=18)
plt.title('Uniform Distributed p-values: Average Variance with 95% CI for Different Assignment Methods', fontsize=20)
plt.legend(fontsize=20)
plt.grid(True, which="both", ls="--", lw=0.75)
plt.tight_layout()
# Write the figure to disk before showing it.
plt.savefig('uniform.pdf', bbox_inches="tight", pad_inches=0.0)
plt.show()


