
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import truncnorm
import time

# Wall-clock reference taken when the script starts. In the code visible in
# this file it is reassigned right before the experiment loop, so this first
# reading is not consumed by any timing printout.
start_time = time.time()

# Coordinate-wise truncated-normal sampler (SciPy based)
def sample_truncated_normal(mean, cov, lower=-np.inf, upper=np.inf):
    """Draw a column vector whose entries are independent truncated normals.

    Only the diagonal of ``cov`` is read, so every coordinate is sampled on
    its own with standard deviation ``sqrt(cov[i, i])``; off-diagonal
    entries have no effect.

    NOTE: this file contains a later module-level function with the same
    name; once the module has been fully executed, that later definition is
    the one the name refers to.

    Args:
        mean: Array of per-coordinate means (flattened internally).
        cov: Square matrix whose diagonal holds the per-coordinate variances.
        lower: Lower truncation bound, in the same units as ``mean``.
        upper: Upper truncation bound, in the same units as ``mean``.

    Returns:
        Array of shape ``(k, 1)`` with one draw per coordinate.
    """
    center = mean.flatten()
    spread = np.sqrt(np.diag(cov))
    # truncnorm takes its bounds in standard-deviation units around ``loc``.
    lo_std = (lower - center) / spread
    hi_std = (upper - center) / spread
    draws = truncnorm.rvs(lo_std, hi_std, loc=center, scale=spread)
    return draws.reshape(-1, 1)


def norm_sq(mat, A):
    """Return the quadratic form ``mat.T @ A @ mat`` as a Python float.

    Args:
        mat: Column vector of shape ``(d, 1)`` (a 1-D vector of length ``d``
            also works).
        A: Square ``(d, d)`` matrix defining the quadratic form.

    Returns:
        The scalar value of the quadratic form.

    Raises:
        ValueError: If the product does not reduce to a single element.
    """
    # ``.item()`` pulls out the single entry explicitly. Calling ``float`` on a
    # (1, 1) array relies on an implicit ndim>0 -> scalar conversion that NumPy
    # deprecated in 1.25.
    return float(np.asarray(mat.T @ A @ mat).item())


def simplified_bfaips_with_positive_exponent(
    X, T, alpha, tau, sigma2, gamma2, eta_lambda, eta_p, theta_c, theta_r
):
    """Simulate the simplified BFAIPS sampler and log recommendation accuracy.

    Every round the routine
      1. recommends the arm with the largest estimated reward among the arms
         whose estimated cost is at most ``tau`` (a uniformly random arm if
         none qualifies),
      2. pulls an arm drawn from the current allocation ``lambda_t``,
      3. refreshes the ridge estimates of the reward and cost parameters,
      4. updates ``lambda_t`` with an AdaHedge-style exponential-weights step.

    The per-arm losses in step 4 are computed from the *true* parameters
    ``theta_r`` / ``theta_c``, so this is a simulation aid rather than a
    learner that could run without knowing them.

    Args:
        X: Arm feature matrix of shape ``(n_arms, d)``.
        T: Number of rounds to simulate.
        alpha: Accepted for interface compatibility; it does not influence
            the simulation.
        tau: Cost threshold; an arm counts as feasible when its estimated
            cost is ``<= tau``.
        sigma2: Variance of the reward noise.
        gamma2: Variance of the cost noise.
        eta_lambda: Accepted for interface compatibility; unused.
        eta_p: Accepted for interface compatibility; unused.
        theta_c: True cost parameter, shape ``(d, 1)``.
        theta_r: True reward parameter, shape ``(d, 1)``.

    Returns:
        List of ``T`` ints: 1 when the round's recommendation equals the
        hard-coded reference arm ``[1, 0]``, else 0.
    """
    n_arms, d = X.shape

    # Scale factors applied to the inverse design matrix when drawing the
    # posterior samples.
    eta = min(sigma2 / (8 * 3 * 1), gamma2 / (8 * 3 * 1))
    eta_r = eta / sigma2
    eta_c = eta / gamma2

    # Ridge-regression state (identity regulariser).
    V_t = np.eye(d)
    S_r = np.zeros((d, 1))
    S_c = np.zeros((d, 1))
    theta_hat_r = np.zeros((d, 1))
    theta_hat_c = np.zeros((d, 1))

    # AdaHedge state: allocation, cumulative mixability gap, learning rate.
    lambda_t = np.full(n_arms, 1.0 / n_arms)
    Delta_t = 0.0
    eta_t = np.inf

    reward_sd = np.sqrt(sigma2)
    cost_sd = np.sqrt(gamma2)
    # Reference arm for the accuracy log; hard-coded, so d must equal 2 here.
    best_arm = np.array([1, 0])
    accuracy = []

    for _ in range(T):
        # Posterior sampling.
        # NOTE(review): the two draws are not used by anything below; they are
        # kept so the sequence of random numbers consumed per round is the
        # same as before. Confirm whether arm selection should use them.
        cov_inv = np.linalg.inv(V_t)
        sample_truncated_normal(theta_hat_r, eta_r * cov_inv)
        sample_truncated_normal(theta_hat_c, eta_c * cov_inv)

        # Recommendation: best estimated reward among estimated-feasible arms.
        feasible_mask = (X @ theta_hat_c).ravel() <= tau
        if feasible_mask.any():
            candidates = X[feasible_mask]
            z_t = candidates[np.argmax(candidates @ theta_hat_r)].reshape(-1, 1)
        else:
            z_t = X[np.random.choice(n_arms)].reshape(-1, 1)

        # Pull an arm according to the current allocation.
        idx = np.random.choice(n_arms, p=lambda_t)
        x_t = X[idx].reshape(-1, 1)

        # Noisy reward and cost; ``.item()`` avoids the deprecated implicit
        # conversion of a (1, 1) array to a scalar.
        y_r_t = (x_t.T @ theta_r).item() + np.random.normal(0, reward_sd)
        y_c_t = (x_t.T @ theta_c).item() + np.random.normal(0, cost_sd)

        # Update estimators.
        V_t += x_t @ x_t.T
        S_r += y_r_t * x_t
        S_c += y_c_t * x_t
        theta_hat_r = np.linalg.solve(V_t, S_r)
        theta_hat_c = np.linalg.solve(V_t, S_c)

        # Loss of arm x: -( (x.(theta_r - theta_hat_r))^2 / sigma2
        #                 + (x.(theta_c - theta_hat_c))^2 / gamma2 ).
        # The sign is flipped so that a larger estimation error along x makes
        # the arm more attractive to the hedge update.
        err_r = (X @ (theta_r - theta_hat_r)).ravel()
        err_c = (X @ (theta_c - theta_hat_c)).ravel()
        loss_t = -(err_r ** 2 / sigma2 + err_c ** 2 / gamma2)

        # Hedge (expected) loss under the current allocation.
        h_t = float(np.dot(lambda_t, loss_t))

        # Mix loss.
        if np.isinf(eta_t):
            # Limit of the mix loss as eta -> inf: the smallest loss on the
            # support. Evaluating the log-sum-exp formula with eta = inf
            # produces NaN, which used to poison Delta_t for the whole run.
            m_t = float(np.min(loss_t[lambda_t > 0]))
        else:
            log_weights = np.log(lambda_t + 1e-12) - eta_t * loss_t
            max_log = np.max(log_weights)
            m_t = -(max_log + np.log(np.sum(np.exp(log_weights - max_log)))) / eta_t

        # Mixability gap; non-negative in exact arithmetic, so clamp rounding
        # noise to keep Delta_t non-decreasing.
        Delta_t += max(0.0, h_t - m_t)

        # Learning-rate update.
        eta_t = np.log(n_arms) / max(1e-8, Delta_t)

        # Exponential-weights update of the allocation; reset to uniform if
        # the weights cannot be normalised (all zero, overflow or NaN).
        weights = lambda_t * np.exp(-eta_t * loss_t)
        total = np.sum(weights)
        if np.isfinite(total) and total > 0:
            lambda_t = weights / total
        else:
            lambda_t = np.full(n_arms, 1.0 / n_arms)

        # Accuracy log for the recommendation made at the top of the round.
        accuracy.append(1 if np.allclose(z_t.flatten(), best_arm) else 0)

    return accuracy







def sample_truncated_normal(mean, covariance):
    """Draw one multivariate-normal sample and return it as a column vector.

    Despite the name, no truncation is applied: this is a placeholder that
    samples from the plain Gaussian ``N(mean, covariance)``.

    Args:
        mean: Array holding the mean vector (flattened internally).
        covariance: Covariance matrix passed to the NumPy sampler.

    Returns:
        Array of shape ``(k, 1)`` containing the draw.
    """
    # Placeholder: swap in a real truncated sampler if truncation is needed.
    location = mean.flatten()
    draw = np.random.multivariate_normal(location, covariance)
    return draw.reshape(-1, 1)



def linear_thompson_sampling_feasible(X, T, tau, sigma2, theta_c, theta_r, gamma2=None):
    """Simulate linear Thompson sampling restricted to sampled-feasible arms.

    Each round draws reward and cost parameters around the current ridge
    estimates, pulls the arm with the highest sampled reward among the arms
    whose sampled cost is ``<= tau`` (a uniformly random arm if none
    qualifies), updates the estimates, and then recommends the best
    estimated-feasible arm under the refreshed estimates.

    Args:
        X: Arm feature matrix of shape ``(n_arms, d)``.
        T: Number of rounds to simulate.
        tau: Cost threshold defining feasibility.
        sigma2: Variance of the reward noise.
        theta_c: True cost parameter, shape ``(d, 1)``.
        theta_r: True reward parameter, shape ``(d, 1)``.
        gamma2: Variance of the cost noise. Defaults to ``sigma2``, which is
            what this function always used before the parameter existed.

    Returns:
        List of ``T`` ints: 1 when the round's recommendation equals the
        hard-coded reference arm ``[1, 0]``, else 0.
    """
    if gamma2 is None:
        gamma2 = sigma2

    n_arms, d = X.shape
    V_t = np.eye(d)
    S_t_r = np.zeros((d, 1))
    S_t_c = np.zeros((d, 1))
    theta_hat_r = np.zeros((d, 1))
    theta_hat_c = np.zeros((d, 1))
    reward_sd = np.sqrt(sigma2)
    cost_sd = np.sqrt(gamma2)
    # Reference arm for the accuracy log; hard-coded, so d must equal 2 here.
    best_arm = [1, 0]
    accuracy = []

    def best_feasible(theta_reward, theta_cost):
        """Return the reward-maximising arm with cost <= tau as a column."""
        feasible_mask = (X @ theta_cost).ravel() <= tau
        if not feasible_mask.any():
            # No arm looks feasible: fall back to a uniformly random arm.
            return X[np.random.choice(n_arms)].reshape(-1, 1)
        candidates = X[feasible_mask]
        return candidates[np.argmax(candidates @ theta_reward)].reshape(-1, 1)

    for _ in range(T):
        # V_t is fixed while sampling, so one inverse serves both draws.
        cov = np.linalg.inv(V_t)
        theta_r_t = sample_truncated_normal(theta_hat_r, cov)
        theta_c_t = sample_truncated_normal(theta_hat_c, cov)

        # Arm to pull, chosen under the sampled parameters.
        x_t = best_feasible(theta_r_t, theta_c_t)

        # Noisy reward and cost observations as plain scalars.
        y_r_t = (x_t.T @ theta_r).item() + np.random.normal(0, reward_sd)
        y_c_t = (x_t.T @ theta_c).item() + np.random.normal(0, cost_sd)

        # Ridge update.
        V_t += x_t @ x_t.T
        S_t_r += y_r_t * x_t
        S_t_c += y_c_t * x_t
        theta_hat_r = np.linalg.solve(V_t, S_t_r)
        theta_hat_c = np.linalg.solve(V_t, S_t_c)

        # Recommendation under the refreshed point estimates.
        z_t = best_feasible(theta_hat_r, theta_hat_c)
        accuracy.append(1 if np.allclose(z_t.flatten(), best_arm) else 0)

    return accuracy




def simplified_bfaips_with_ttts(
    X, T, alpha, tau, sigma2, gamma2, eta_lambda, theta_c, theta_r, best_w,
    max_resamples=None,
):
    """Simulate a top-two Thompson-sampling variant for feasible-arm search.

    Each round builds two candidates from independent posterior draws: the
    first from a single draw, the second by redrawing until it differs from
    the first. The first candidate is pulled with probability ``best_w[0]``
    and the second otherwise. After the ridge update, the best
    estimated-feasible arm is recommended and compared with the hard-coded
    reference arm ``[1, 0]``.

    Args:
        X: Arm feature matrix of shape ``(n_arms, d)``.
        T: Number of rounds to simulate.
        alpha: Accepted for interface compatibility; unused.
        tau: Cost threshold defining feasibility.
        sigma2: Variance of the reward noise.
        gamma2: Variance of the cost noise.
        eta_lambda: Accepted for interface compatibility; unused.
        theta_c: True cost parameter, shape ``(d, 1)``.
        theta_r: True reward parameter, shape ``(d, 1)``.
        best_w: Weight vector; only its first entry is read, as the
            probability of pulling the first candidate.
            NOTE(review): confirm that using ``best_w[0]`` as the mixing
            probability is intended.
        max_resamples: Optional cap on the number of redraws per round. With
            the default ``None`` the loop runs until the second candidate
            differs from the first, which never terminates if every draw
            yields the same arm (for example when ``X`` has a single row).
            When the cap is hit, the last drawn candidate is used even if it
            equals the first one.

    Returns:
        Tuple ``(accuracy, repetition_counts)``: the per-round 0/1 accuracy
        flags and the number of redraws needed in each round.
    """
    n_arms, d = X.shape
    V_t = np.eye(d)
    S_t_r = np.zeros((d, 1))
    S_t_c = np.zeros((d, 1))
    theta_hat_r = np.zeros((d, 1))
    theta_hat_c = np.zeros((d, 1))
    reward_sd = np.sqrt(sigma2)
    cost_sd = np.sqrt(gamma2)
    # Reference arm for the accuracy log; hard-coded, so d must equal 2 here.
    best_arm = np.array([1, 0])

    accuracy = []
    repetition_counts = []

    def best_feasible(theta_reward, theta_cost):
        """Return the reward-maximising arm with cost <= tau as a column."""
        feasible_mask = (X @ theta_cost).ravel() <= tau
        if not feasible_mask.any():
            # No arm looks feasible: fall back to a uniformly random arm.
            return X[np.random.choice(n_arms)].reshape(-1, 1)
        candidates = X[feasible_mask]
        return candidates[np.argmax(candidates @ theta_reward)].reshape(-1, 1)

    def draw_candidate(cov):
        """Sample reward/cost parameters and return the arm they favour."""
        theta_r_t = sample_truncated_normal(theta_hat_r, cov)
        theta_c_t = sample_truncated_normal(theta_hat_c, cov)
        return best_feasible(theta_r_t, theta_c_t)

    for _ in range(T):
        # V_t does not change until the update below, so invert it once per
        # round instead of once per posterior draw.
        cov = np.linalg.inv(V_t)

        # First candidate: a single posterior draw.
        constraint_candidate = draw_candidate(cov)

        # Second candidate: redraw until it differs from the first.
        repetition_count = 0
        while True:
            repetition_count += 1
            reward_candidate = draw_candidate(cov)
            if not np.allclose(constraint_candidate, reward_candidate):
                break
            if max_resamples is not None and repetition_count >= max_resamples:
                break
        repetition_counts.append(repetition_count)

        # Pick which of the two candidates to pull.
        if np.random.rand() < best_w[0]:
            x_t = constraint_candidate
        else:
            x_t = reward_candidate

        # Noisy reward and cost observations as plain scalars.
        y_r_t = (x_t.T @ theta_r).item() + np.random.normal(0, reward_sd)
        y_c_t = (x_t.T @ theta_c).item() + np.random.normal(0, cost_sd)

        # Ridge update.
        V_t += x_t @ x_t.T
        S_t_r += y_r_t * x_t
        S_t_c += y_c_t * x_t
        theta_hat_r = np.linalg.solve(V_t, S_t_r)
        theta_hat_c = np.linalg.solve(V_t, S_t_c)

        # Recommendation under the refreshed point estimates.
        z_t = best_feasible(theta_hat_r, theta_hat_c)
        accuracy.append(1 if np.allclose(z_t.flatten(), best_arm) else 0)

    return accuracy, repetition_counts













# Baseline: uniform random sampling + empirical best feasible arm
def random_empirical_best(X, T, tau, sigma2, theta_c, theta_r):
    """Pull arms uniformly at random and recommend the empirical best one.

    After every pull the recommendation is the feasible arm with the largest
    empirical mean reward; arms that have not been pulled yet have empirical
    mean 0. Feasibility is decided with the *true* cost parameter
    ``theta_c`` (an arm is feasible when ``x @ theta_c <= tau``).

    NOTE(review): feasibility is never estimated from observations here;
    confirm that using the true ``theta_c`` is intended for this baseline.

    Args:
        X: Arm feature matrix of shape ``(n_arms, d)``.
        T: Number of rounds to simulate.
        tau: Cost threshold defining feasibility.
        sigma2: Variance of the observation noise.
        theta_c: True cost parameter, shape ``(d, 1)``.
        theta_r: True reward parameter, shape ``(d, 1)``.

    Returns:
        List of ``T`` ints: 1 when the round's recommendation equals the
        hard-coded reference arm ``[1, 0]``, else 0.
    """
    n_arms = len(X)
    rewards = np.zeros(n_arms)  # cumulative reward per arm
    counts = np.zeros(n_arms)   # number of pulls per arm
    noise_sd = np.sqrt(sigma2)

    # X, theta_c and tau never change, so the feasible set is computed once
    # instead of being rebuilt every round.
    feasible_indices = np.flatnonzero((X @ theta_c).ravel() <= tau)

    accuracy = []

    for _ in range(T):
        # Uniformly random pull.
        idx = np.random.choice(n_arms)

        # Scalar reward observation; ``.item()`` avoids assigning a (1, 1)
        # array into a single array slot, which NumPy has deprecated.
        y_r_t = (X[idx] @ theta_r).item() + np.random.normal(0, noise_sd)
        # The original also drew noise for a cost observation it never used.
        # The draw is kept so seeded runs consume the same random numbers.
        np.random.normal(0, noise_sd)

        rewards[idx] += y_r_t
        counts[idx] += 1

        # Empirical means; the max(., 1) guard avoids dividing by zero for
        # arms that have not been pulled.
        empirical_means = rewards / np.maximum(counts, 1)
        if feasible_indices.size:
            # First maximiser wins ties, matching built-in ``max``.
            best_idx = feasible_indices[np.argmax(empirical_means[feasible_indices])]
            best_arm = X[best_idx].flatten()
        else:
            # No feasible arm: recommend a uniformly random arm.
            best_arm = X[np.random.choice(n_arms)].flatten()

        accuracy.append(1 if np.allclose(best_arm, [1, 0]) else 0)

    return accuracy


def calculate_optimal_rate_with_full_weights(X, delta_c, delta_r, variance, x1, A1, A2, A3, num_samples=100):
    """Random search for the allocation with the largest worst-case rate.

    Draws ``num_samples`` weight vectors from a flat Dirichlet distribution.
    For each one it builds ``V_w = sum_i w[i] * x_i x_i^T``, scores every
    classified arm, takes the minimum score, and keeps the weight vector
    whose minimum is largest. The same noise variance is used for rewards
    and costs.

    Per-arm scores, with ``c_i = delta_c[i]^2 / (2 * variance * ||x_i V_w^-1||^2)``
    and ``r_i = delta_r[i]^2 / (2 * variance * ||(x_i - x1) V_w^-1||^2)``:
      * arm in ``A1``: ``c_i``
      * arm in ``A2``: ``r_i``
      * arm in ``A3``: ``c_i + r_i``
    Arms in none of the sets are skipped; an arm listed in several sets is
    scored by the first of ``A1``, ``A2``, ``A3`` that contains it.

    NOTE(review): ``||.||`` above is the plain Euclidean norm of the row
    vector times ``V_w^-1``. If the intended quantity is the
    ``V_w^-1``-weighted norm ``x^T V_w^-1 x``, these expressions differ;
    confirm.

    Args:
        X: Arm feature matrix of shape ``(N, d)``.
        delta_c: Per-arm cost gaps, indexable by arm.
        delta_r: Per-arm reward gaps, indexable by arm.
        variance: Noise variance shared by rewards and costs.
        x1: Feature vector of the best feasible arm, length ``d``.
        A1: Indices of the arms scored with the cost term only.
        A2: Indices of the arms scored with the reward term only.
        A3: Indices of the arms scored with both terms.
        num_samples: Number of random weight vectors to try.

    Returns:
        Tuple ``(rate, weights)`` for the best sampled weight vector;
        ``(-inf, None)`` when ``num_samples`` is 0.

    Raises:
        ValueError: If no arm belongs to ``A1``, ``A2`` or ``A3`` (minimum of
            an empty list).
    """
    N, _ = X.shape

    def cost_term(i, V_inv):
        return (delta_c[i] ** 2) / (2 * variance * np.linalg.norm(X[i] @ V_inv) ** 2)

    def reward_term(i, V_inv):
        return (delta_r[i] ** 2) / (2 * variance * np.linalg.norm((X[i] - x1) @ V_inv) ** 2)

    def both_terms(i, V_inv):
        return cost_term(i, V_inv) + reward_term(i, V_inv)

    # Classify every arm once; membership does not depend on the sampled w.
    scored_arms = []
    for i in range(N):
        if i in A1:
            scored_arms.append((i, cost_term))
        elif i in A2:
            scored_arms.append((i, reward_term))
        elif i in A3:
            scored_arms.append((i, both_terms))

    top_rate = float('-inf')
    top_weights = None

    for _ in range(num_samples):
        # One random point on the probability simplex.
        w = np.random.dirichlet(np.ones(N))

        # Weighted design matrix and its inverse.
        V_w = sum(w_i * np.outer(x_i, x_i) for w_i, x_i in zip(w, X))
        V_w_inv = np.linalg.inv(V_w)

        # Worst-case score over all classified arms for this w.
        scores = [score(i, V_w_inv) for i, score in scored_arms]
        candidate_rate = min(scores)

        # Keep the best allocation seen so far.
        if candidate_rate > top_rate:
            top_rate = candidate_rate
            top_weights = w

    return top_rate, top_weights

# New algorithm: empirical best arm under the optimal-rate allocation

def optimal_rate_empirical_best(X, T, tau, theta_c, theta_r, variance, best_w):
    """Sample arms from a fixed allocation and log recommendation accuracy.

    Every round pulls an arm drawn with probabilities ``best_w``, observes a
    noisy reward and cost, refreshes the ridge estimates of both parameters,
    and recommends the arm with the highest estimated reward among the arms
    whose estimated cost is ``<= tau`` (a uniformly random arm if none
    qualifies).

    Args:
        X: Arm feature matrix of shape ``(N, d)``; ``X[0]`` is treated as the
            true best arm when scoring accuracy.
        T: Number of rounds to simulate.
        tau: Cost threshold defining feasibility.
        theta_c: True cost parameter, shape ``(d, 1)``.
        theta_r: True reward parameter, shape ``(d, 1)``.
        variance: Noise variance shared by rewards and costs.
        best_w: Sampling probabilities over the ``N`` arms.

    Returns:
        List of ``T`` ints: 1 when the round's recommendation equals
        ``X[0]``, else 0.
    """
    N, d = X.shape
    noise_sd = np.sqrt(variance)

    # Ridge-regression state (identity regulariser).
    gram = np.eye(d)
    reward_sum = np.zeros((d, 1))
    cost_sum = np.zeros((d, 1))

    hits = []
    for _ in range(T):
        # Pull an arm according to the fixed allocation.
        pulled = X[np.random.choice(N, p=best_w)].reshape(-1, 1)

        # Noisy observations (kept as (1, 1) arrays; they broadcast below).
        reward_obs = pulled.T @ theta_r + np.random.normal(0, noise_sd)
        cost_obs = pulled.T @ theta_c + np.random.normal(0, noise_sd)

        # Accumulate sufficient statistics and re-solve for the estimates.
        gram += pulled @ pulled.T
        reward_sum += reward_obs * pulled
        cost_sum += cost_obs * pulled
        est_r = np.linalg.solve(gram, reward_sum)
        est_c = np.linalg.solve(gram, cost_sum)

        # Recommendation: best estimated reward among estimated-feasible arms.
        admissible = [row.reshape(-1, 1) for row in X if row @ est_c <= tau]
        if admissible:
            scores = [col.T @ est_r for col in admissible]
            recommended = admissible[np.argmax(scores)]
        else:
            recommended = X[np.random.choice(len(X))].reshape(-1, 1)

        # Accuracy flag against the reference arm X[0].
        hits.append(int(np.allclose(recommended.flatten(), X[0])))

    return hits

# Example input
# NOTE: X is assigned three times in this section; only the last, hard-coded
# array is the one the experiment below runs on.
X = np.array([
    [1, 0],
    [0, 0.15],
    [1.2, 1.2],
    [0, 1],
    [np.cos(0.1), np.sin(0.1)]
])


# Arm [1, 0] plus num_arms random unit vectors (angles uniform on [0, 2*pi)).
num_arms = 5
angles = np.random.uniform(0, 2 * np.pi, num_arms)
random_arms = np.array([np.cos(angles), np.sin(angles)]).T
X = np.vstack([[1, 0], *random_arms])

# Save the updated arm set to a file
# (this writes the randomly generated set, not the fixed array assigned next)
np.savetxt('updated_arm_set.txt', X, fmt='%.6f', header='Updated Arm Set X (x, y)', comments='')

# Fixed six-arm set used by the experiment; the first row is [1, 0].
X=np.array([[ 1.        ,  0.        ],
       [-0.95937098,  0.28214769],
       [ 0.97827546, -0.20730926],
       [ 0.99560616,  0.09363956],
       [-0.72308869, -0.6907552 ],
       [ 0.79404379, -0.60786055]])



# Experiment configuration.
T = 2000
alpha=0
# With an infinite threshold every finite cost satisfies cost <= tau.
tau = np.inf
sigma2, gamma2 = 0.1, 0.1
eta_lambda = 1
eta_p = 3
theta_c_new = np.array([[0], [1]])
theta_r_new = np.array([[1], [0]])

# Compute the optimal rate and the matching weight vector
# delta_c: cost minus threshold per arm (all -inf here because tau is inf);
# delta_r: reward gap of each arm relative to a reward of 1.
delta_c = np.array([x @ theta_c_new - tau for x in X])
delta_r = np.array([1 - x @ theta_r_new for x in X])
x1 = X[0]
# A1: cost gap > 0 and reward gap <= 0; A2: cost gap <= 0 and reward gap > 0;
# A3: both gaps > 0.
A1 = [i for i in range(len(X)) if delta_c[i] > 0 and delta_r[i] <= 0]
A2 = [i for i in range(len(X)) if delta_c[i] <= 0 and delta_r[i] > 0]
A3 = [i for i in range(len(X)) if delta_c[i] > 0 and delta_r[i] > 0]

# Only the weight vector is kept; the rate itself is discarded.
_, best_w = calculate_optimal_rate_with_full_weights(X, delta_c, delta_r, sigma2, x1, A1, A2, A3)

# Run the algorithm multiple times
# Each list collects one running-accuracy curve (length T) per run.
num_runs = 10
cumulative_accuracies_feasible = []
cumulative_accuracies_bfaips = []
cumulative_accuracies_ttts = []
cumulative_accuracies_optimal = []
cumulative_accuracies_bfaips_ew = []  # BFAIPS (EW); never appended to in the code visible here

# Timing reference for the progress printout below.
start_time = time.time()
for zz in range(num_runs):
    # Feasible Linear Thompson Sampling
    accuracy_feasible = linear_thompson_sampling_feasible(X, T, tau, sigma2, theta_c_new, theta_r_new)

    # BFAIPS
    accuracy_bfaips = simplified_bfaips_with_positive_exponent(
        X, T, alpha, tau, sigma2, gamma2, eta_lambda, eta_p, theta_c_new, theta_r_new
    )

    # TTTS (the per-round resampling counts are discarded)
    accuracy_ttts, _ = simplified_bfaips_with_ttts(
        X, T, alpha, tau, sigma2, gamma2, eta_lambda, theta_c_new, theta_r_new, best_w
    )

    # Optimal Rate Algorithm
    accuracy_optimal = optimal_rate_empirical_best(
        X, T, tau, theta_c_new, theta_r_new, sigma2, best_w
    )


    # Compute cumulative accuracy
    # (running mean of the 0/1 flags: cumulative sum divided by the step index)
    cumulative_accuracy_feasible = np.cumsum(accuracy_feasible) / np.arange(1, T + 1)
    cumulative_accuracy_bfaips = np.cumsum(accuracy_bfaips) / np.arange(1, T + 1)
    cumulative_accuracy_ttts = np.cumsum(accuracy_ttts) / np.arange(1, T + 1)
    cumulative_accuracy_optimal = np.cumsum(accuracy_optimal) / np.arange(1, T + 1)


    # Store results
    cumulative_accuracies_feasible.append(cumulative_accuracy_feasible)
    cumulative_accuracies_bfaips.append(cumulative_accuracy_bfaips)
    cumulative_accuracies_ttts.append(cumulative_accuracy_ttts)
    cumulative_accuracies_optimal.append(cumulative_accuracy_optimal)


    # Progress: fraction of runs completed before this one, and elapsed
    # seconds since the loop started.
    timenow = time.time()
    print(zz / num_runs, timenow - start_time)

# Convert to arrays
# Each array has one row per run and one column per time step.
cumulative_accuracies_feasible = np.array(cumulative_accuracies_feasible)
cumulative_accuracies_bfaips = np.array(cumulative_accuracies_bfaips)
cumulative_accuracies_ttts = np.array(cumulative_accuracies_ttts)
cumulative_accuracies_optimal = np.array(cumulative_accuracies_optimal)


# Save
# The raw per-run curves are written to .npy files in the working directory.
np.save('cumulative_accuracies_feasible.npy', cumulative_accuracies_feasible)
np.save('cumulative_accuracies_bfaips.npy', cumulative_accuracies_bfaips)
np.save('cumulative_accuracies_ttts.npy', cumulative_accuracies_ttts)
np.save('cumulative_accuracies_optimal.npy', cumulative_accuracies_optimal)


# Calculate mean and standard error for all algorithms
# Standard error = standard deviation across runs (NumPy default, ddof=0)
# divided by sqrt(num_runs).
mean_accuracy_feasible = cumulative_accuracies_feasible.mean(axis=0)
std_error_feasible = cumulative_accuracies_feasible.std(axis=0) / np.sqrt(num_runs)

mean_accuracy_bfaips = cumulative_accuracies_bfaips.mean(axis=0)
std_error_bfaips = cumulative_accuracies_bfaips.std(axis=0) / np.sqrt(num_runs)

mean_accuracy_ttts = cumulative_accuracies_ttts.mean(axis=0)
std_error_ttts = cumulative_accuracies_ttts.std(axis=0) / np.sqrt(num_runs)

mean_accuracy_optimal = cumulative_accuracies_optimal.mean(axis=0)
std_error_optimal = cumulative_accuracies_optimal.std(axis=0) / np.sqrt(num_runs)



# Plot results
plt.figure(figsize=(10, 6))

# One (mean curve, standard error, legend label, colour) entry per algorithm,
# in drawing order.
time_steps = range(1, T + 1)
curves = [
    (mean_accuracy_feasible, std_error_feasible, "Linear TS (Feasible)", "green"),
    (mean_accuracy_bfaips, std_error_bfaips, "BFAIPS", "blue"),
    (mean_accuracy_ttts, std_error_ttts, "TTTS", "orange"),
    (mean_accuracy_optimal, std_error_optimal, "Oracle", "red"),
]

for mean_curve, std_error, curve_label, curve_color in curves:
    plt.plot(time_steps, mean_curve, label=curve_label, color=curve_color)
    # Shaded band: mean +/- one standard error across runs.
    plt.fill_between(
        time_steps,
        mean_curve - std_error,
        mean_curve + std_error,
        color=curve_color,
        alpha=0.2
    )

plt.xlabel("Time Step")
plt.ylabel("Accuracy")
# The title lists exactly the curves drawn above; the previous title also
# named a "BFAIPS (EW)" curve that is never plotted.
plt.title("Accuracy Comparison: Linear TS (Feasible) vs BFAIPS vs TTTS vs Oracle")
plt.legend()
plt.grid(True)
plt.show()

