
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import truncnorm
import time

# Wall-clock reference taken at import time; the experiment loop further down
# prints elapsed seconds relative to this value.
start_time = time.time()

# Per-coordinate truncated-normal sampler (uses only the covariance diagonal)
def sample_truncated_normal(mean, cov, lower=-np.inf, upper=np.inf):
    """Draw one sample per coordinate from independent truncated normals.

    Only the diagonal of ``cov`` is used (as per-coordinate variances), so
    off-diagonal correlations are ignored. ``lower`` and ``upper`` are bounds
    on the sampled values themselves; they are standardised per coordinate
    before being handed to ``scipy.stats.truncnorm``.

    NOTE(review): a later module-level definition with the same name replaces
    this function once the file has finished loading.

    Args:
        mean: Array whose flattened form gives the per-coordinate means.
        cov: Square matrix; its diagonal supplies the variances.
        lower: Lower bound on every coordinate (default: unbounded).
        upper: Upper bound on every coordinate (default: unbounded).

    Returns:
        Column vector of shape (n, 1) holding one draw per coordinate.
    """
    center = mean.flatten()
    spread = np.sqrt(np.diag(cov))
    # truncnorm expects its bounds in standard-normal units
    std_lower = (lower - center) / spread
    std_upper = (upper - center) / spread
    draws = truncnorm.rvs(std_lower, std_upper, loc=center, scale=spread)
    return draws.reshape(-1, 1)

# Simplified BFAIPS loop: weighted random pulls, ridge estimates, empirical recommendation
def simplified_bfaips_with_positive_exponent(X, T, alpha, tau, sigma2, gamma2, eta_lambda, eta_p, theta_c, theta_r):
    """Run the simplified BFAIPS simulation and score its per-round recommendation.

    Every round pulls one arm drawn from the weight vector ``lambda_t``,
    simulates a noisy reward and a noisy constraint observation for it, and
    refreshes the identity-regularised least-squares estimates of both
    parameter vectors. The arm recommended from the estimates held *before*
    that refresh is compared with the hard-coded reference arm ``[1, 0]``.

    Args:
        X: Arm feature matrix, one arm per row (N x d).
        T: Number of rounds.
        alpha: Accepted for signature compatibility; not read by this function.
        tau: Threshold; an arm z counts as feasible when z @ theta <= tau.
        sigma2: Variance of the reward noise.
        gamma2: Variance of the constraint noise.
        eta_lambda: Scale inside the exponential weight update.
        eta_p: Multiplier applied to inv(V_t) to form the sampling covariance.
        theta_c: True constraint parameter, used only to simulate observations.
        theta_r: True reward parameter, used only to simulate observations.

    Returns:
        List of T ints: 1 when the round's recommended arm equals [1, 0], else 0.
    """
    num_arms, d = X.shape
    V_t = np.eye(d)
    S_t_r, S_t_c = np.zeros((d, 1)), np.zeros((d, 1))
    theta_hat_r, theta_hat_c = np.zeros((d, 1)), np.zeros((d, 1))

    uniform = np.ones(num_arms) / num_arms
    lambda_t = uniform  # initial uniform weights (never mutated in place)
    gains = np.zeros(num_arms)
    cumulative_regret = 0.0
    reward_std, constraint_std = np.sqrt(sigma2), np.sqrt(gamma2)
    best_arm = np.array([1, 0])

    accuracy = []

    for t in range(1, T + 1):
        idx = np.random.choice(num_arms, p=lambda_t)
        x_t = X[idx].reshape(-1, 1)

        # Everything computed here depends only on the current estimates and
        # V_t, none of which change inside the resampling loop below.
        F = [z.reshape(-1, 1) for z in X if z @ theta_hat_c <= tau]
        if F:
            z_t = F[np.argmax([z.T @ theta_hat_r for z in F])]
        sampling_cov = eta_p * np.linalg.inv(V_t)

        # Resample parameters until the sampled best feasible arm differs from
        # z_t. The loop has no iteration cap.
        # NOTE(review): neither the sampled parameters nor z_t_candidate is
        # read after the loop; the arm actually pulled (x_t) was already drawn
        # from lambda_t above. Confirm whether the challenger was meant to be
        # played.
        while True:
            if not F:
                # No arm is feasible under the estimate: a random arm is
                # redrawn on every pass.
                z_t = X[np.random.choice(num_arms)].reshape(-1, 1)

            # sample_truncated_normal is defined elsewhere in this file
            theta_r_t = sample_truncated_normal(theta_hat_r, sampling_cov)
            theta_c_t = sample_truncated_normal(theta_hat_c, sampling_cov)

            F_new = [z.reshape(-1, 1) for z in X if z @ theta_c_t <= tau]
            if not F_new:
                z_t_candidate = X[np.random.choice(num_arms)].reshape(-1, 1)
            else:
                z_t_candidate = F_new[np.argmax([z.T @ theta_r_t for z in F_new])]

            if not np.allclose(z_t_candidate, z_t):
                break

        # Observations are reduced to plain floats so that the per-arm update
        # of `gains` below never stores a size-1 array into a scalar slot
        # (an implicit conversion NumPy has deprecated).
        y_r_t = (x_t.T @ theta_r).item() + np.random.normal(0, reward_std)
        y_c_t = (x_t.T @ theta_c).item() + np.random.normal(0, constraint_std)

        V_t += x_t @ x_t.T
        S_t_r += y_r_t * x_t
        S_t_c += y_c_t * x_t
        theta_hat_r = np.linalg.solve(V_t, S_t_r)
        theta_hat_c = np.linalg.solve(V_t, S_t_c)

        # Gain bookkeeping for the weight update
        regret_t = y_r_t - (lambda_t @ gains)
        cumulative_regret += regret_t
        gains[idx] += regret_t

        # NOTE(review): the weights are built from the single running total
        # `cumulative_regret`, so `weights` always has exactly one element and
        # the size check below falls back to uniform weights whenever there is
        # more than one arm. Presumably the per-arm `gains` vector was meant
        # here; confirm before changing, since that alters the sampling rule.
        weights = np.array([np.exp(eta_lambda * cumulative_regret / (t + 1))])
        lambda_t = weights / weights.sum()
        if lambda_t.size != num_arms:
            lambda_t = uniform

        # Score the recommendation made from the pre-update estimates
        accuracy.append(1 if np.allclose(z_t.flatten(), best_arm) else 0)

    return accuracy

def sample_truncated_normal(mean, covariance):
    """Draw one multivariate-normal sample and return it as a column vector.

    Despite the name, no truncation is applied: this is a placeholder that
    samples from the full (untruncated) Gaussian with the given mean and
    covariance.

    Args:
        mean: Array whose flattened form is the mean vector.
        covariance: Covariance matrix passed to the sampler unchanged.

    Returns:
        Array of shape (n, 1) containing the draw.
    """
    flat_mean = mean.flatten()
    draw = np.random.multivariate_normal(flat_mean, covariance)
    return draw.reshape(-1, 1)



def linear_thompson_sampling_feasible(X, T, tau, sigma2, theta_c, theta_r):
    """Simulate linear Thompson sampling restricted to sampled-feasible arms.

    Each round samples a reward and a constraint parameter around the current
    least-squares estimates (covariance inv(V_t)), plays the arm with the
    highest sampled reward among the arms that are feasible under the sampled
    constraint parameter, observes noisy feedback and refreshes the estimates.
    The arm recommended from the refreshed estimates is then compared with the
    hard-coded reference arm ``[1, 0]``.

    Args:
        X: Arm feature matrix, one arm per row (N x d).
        T: Number of rounds.
        tau: Threshold; an arm z counts as feasible when z @ theta <= tau.
        sigma2: Noise variance, used for both the reward and the constraint
            observation.
        theta_c: True constraint parameter, used only to simulate observations.
        theta_r: True reward parameter, used only to simulate observations.

    Returns:
        List of T ints: 1 when the round's recommended arm equals [1, 0], else 0.
    """
    d = X.shape[1]
    num_arms = len(X)
    V_t = np.eye(d)
    S_t_r = np.zeros((d, 1))
    S_t_c = np.zeros((d, 1))
    theta_hat_r = np.zeros((d, 1))
    theta_hat_c = np.zeros((d, 1))
    noise_std = np.sqrt(sigma2)
    best_arm = [1, 0]
    accuracy = []

    def best_feasible_arm(constraint_param, reward_param):
        """Highest-reward arm among those feasible under the given parameters.

        Falls back to a uniformly random arm when no arm is feasible.
        """
        feasible = [z.reshape(-1, 1) for z in X if z @ constraint_param <= tau]
        if not feasible:
            return X[np.random.choice(num_arms)].reshape(-1, 1)
        return feasible[np.argmax([z.T @ reward_param for z in feasible])]

    for _ in range(T):
        # One inverse per round serves both parameter draws
        sampling_cov = np.linalg.inv(V_t)
        # sample_truncated_normal is defined elsewhere in this file
        theta_r_t = sample_truncated_normal(theta_hat_r, sampling_cov)
        theta_c_t = sample_truncated_normal(theta_hat_c, sampling_cov)

        # Play the best arm under the sampled parameters
        x_t = best_feasible_arm(theta_c_t, theta_r_t)

        # Observe reward and constraint
        y_r_t = x_t.T @ theta_r + np.random.normal(0, noise_std)
        y_c_t = x_t.T @ theta_c + np.random.normal(0, noise_std)

        # Refresh the least-squares estimates
        V_t += x_t @ x_t.T
        S_t_r += y_r_t * x_t
        S_t_c += y_c_t * x_t
        theta_hat_r = np.linalg.solve(V_t, S_t_r)
        theta_hat_c = np.linalg.solve(V_t, S_t_c)

        # Recommend from the point estimates and record whether it is correct
        z_t = best_feasible_arm(theta_hat_c, theta_hat_r)
        accuracy.append(1 if np.allclose(z_t.flatten(), best_arm) else 0)

    return accuracy




def simplified_bfaips_with_ttts(X, T, alpha, tau, sigma2, gamma2, eta_lambda, theta_c, theta_r, best_w):
    """Simulate a top-two style Thompson sampling variant with a feasibility filter.

    Each round draws a first candidate arm from one posterior sample, then
    keeps drawing fresh samples until the resulting arm differs from the first
    candidate. One of the two is played (the first with probability
    ``best_w[0]``), noisy feedback is observed and the least-squares estimates
    are refreshed. The arm recommended from the refreshed estimates is compared
    with the hard-coded reference arm ``[1, 0]``.

    The resampling loop has no iteration cap: it ends only when a differing
    arm is drawn.

    Args:
        X: Arm feature matrix, one arm per row (N x d).
        T: Number of rounds.
        alpha: Accepted for signature compatibility; not read by this function.
        tau: Threshold; an arm z counts as feasible when z @ theta <= tau.
        sigma2: Variance of the reward noise.
        gamma2: Variance of the constraint noise.
        eta_lambda: Accepted for signature compatibility; not read here.
        theta_c: True constraint parameter, used only to simulate observations.
        theta_r: True reward parameter, used only to simulate observations.
        best_w: Indexable; only ``best_w[0]`` is read, as the probability of
            playing the first candidate.

    Returns:
        Tuple ``(accuracy, repetition_counts)``: a list of T 0/1 indicators and
        a list with the number of resampling passes used in each round.
    """
    d = X.shape[1]
    num_arms = len(X)
    V_t = np.eye(d)
    S_t_r, S_t_c = np.zeros((d, 1)), np.zeros((d, 1))
    theta_hat_r, theta_hat_c = np.zeros((d, 1)), np.zeros((d, 1))
    reward_std, constraint_std = np.sqrt(sigma2), np.sqrt(gamma2)
    best_arm = np.array([1, 0])

    accuracy = []
    repetition_counts = []  # resampling passes needed in each round

    def best_feasible_arm(constraint_param, reward_param):
        """Highest-reward arm among those feasible under the given parameters.

        Falls back to a uniformly random arm when no arm is feasible.
        """
        feasible = [z.reshape(-1, 1) for z in X if z @ constraint_param <= tau]
        if not feasible:
            return X[np.random.choice(num_arms)].reshape(-1, 1)
        return feasible[np.argmax([z.T @ reward_param for z in feasible])]

    def sampled_candidate(mean_r, mean_c, covariance):
        """Best feasible arm under one fresh posterior sample (reward drawn first)."""
        # sample_truncated_normal is defined elsewhere in this file
        reward_param = sample_truncated_normal(mean_r, covariance)
        constraint_param = sample_truncated_normal(mean_c, covariance)
        return best_feasible_arm(constraint_param, reward_param)

    for _ in range(T):
        # V_t is fixed until the update at the end of the round, so a single
        # inverse serves every draw made below.
        sampling_cov = np.linalg.inv(V_t)

        # First candidate
        constraint_candidate = sampled_candidate(theta_hat_r, theta_hat_c, sampling_cov)

        # Second candidate: drawn exactly like the first, repeated until it differs
        repetition_count = 0
        while True:
            repetition_count += 1
            reward_candidate = sampled_candidate(theta_hat_r, theta_hat_c, sampling_cov)
            if not np.allclose(constraint_candidate, reward_candidate):
                break
        repetition_counts.append(repetition_count)

        # Play the first candidate with probability best_w[0], else the second
        if np.random.rand() < best_w[0]:
            x_t = constraint_candidate
        else:
            x_t = reward_candidate

        # Observe reward and constraint
        y_r_t = x_t.T @ theta_r + np.random.normal(0, reward_std)
        y_c_t = x_t.T @ theta_c + np.random.normal(0, constraint_std)

        # Refresh the least-squares estimates
        V_t += x_t @ x_t.T
        S_t_r += y_r_t * x_t
        S_t_c += y_c_t * x_t
        theta_hat_r = np.linalg.solve(V_t, S_t_r)
        theta_hat_c = np.linalg.solve(V_t, S_t_c)

        # Recommend from the point estimates and record whether it is correct
        z_t = best_feasible_arm(theta_hat_c, theta_hat_r)
        accuracy.append(1 if np.allclose(z_t.flatten(), best_arm) else 0)

    return accuracy, repetition_counts













# Baseline: uniformly random pulls + empirical-best recommendation among feasible arms
def random_empirical_best(X, T, tau, sigma2, theta_c, theta_r):
    """Simulate uniform random sampling with an empirical-best recommendation.

    Each round pulls a uniformly random arm, adds its noisy reward to that
    arm's running total, and recommends the feasible arm with the highest
    empirical mean reward. The recommendation is compared with the hard-coded
    reference arm ``[1, 0]``.

    NOTE(review): feasibility is decided with the true ``theta_c`` rather than
    an estimate, so this baseline has oracle knowledge of the feasible set.

    Args:
        X: Arm feature matrix, one arm per row.
        T: Number of rounds.
        tau: Threshold; arm x counts as feasible when x @ theta_c <= tau.
        sigma2: Noise variance.
        theta_c: True constraint parameter.
        theta_r: True reward parameter, used to simulate rewards.

    Returns:
        List of T ints: 1 when the round's recommended arm equals [1, 0], else 0.
    """
    num_arms = len(X)
    rewards = np.zeros(num_arms)  # cumulative observed reward per arm
    counts = np.zeros(num_arms)   # number of pulls per arm
    noise_std = np.sqrt(sigma2)

    # The feasible set depends only on X, theta_c and tau, so it is computed once.
    feasible_indices = [i for i, arm in enumerate(X) if arm.reshape(-1, 1).T @ theta_c <= tau]

    accuracy = []

    for _ in range(T):
        idx = np.random.choice(num_arms)
        x_t = X[idx].reshape(-1, 1)

        # .item() yields a plain float, so the per-arm accumulation below does
        # not rely on NumPy's deprecated size-1-array-to-scalar conversion.
        y_r_t = (x_t.T @ theta_r).item() + np.random.normal(0, noise_std)
        # The constraint observation is never used by this baseline; its noise
        # draw is kept so the random stream stays the same as before.
        np.random.normal(0, noise_std)

        rewards[idx] += y_r_t
        counts[idx] += 1

        if feasible_indices:
            # Arms never pulled get an empirical mean of 0 (division guarded by max(count, 1))
            empirical_means = rewards / np.maximum(counts, 1)
            best_idx = max(feasible_indices, key=lambda i: empirical_means[i])
            best_arm = X[best_idx].flatten()
        else:
            # No arm is feasible: recommend a random arm
            best_arm = X[np.random.choice(num_arms)].flatten()

        accuracy.append(1 if np.allclose(best_arm, [1, 0]) else 0)

    return accuracy


def calculate_optimal_rate_with_full_weights(X, delta_c, delta_r, variance, x1, A1, A2, A3, num_samples=100):
    """Search random allocations for the one that maximises the rate Γ.

    Draws ``num_samples`` weight vectors from a flat Dirichlet distribution.
    For each one it builds the weighted design matrix V_w = Σ w_i x_i x_iᵀ,
    scores every arm listed in A1/A2/A3, and takes the smallest score as that
    allocation's Γ. The allocation with the largest Γ is returned. Reward and
    constraint noise are assumed to share the same variance.

    Args:
        X: Arm feature matrix (N x d).
        delta_c: Per-arm constraint gaps, indexable by arm.
        delta_r: Per-arm reward gaps, indexable by arm.
        variance: Common noise variance of rewards and constraints.
        x1: Feature vector of the best feasible arm (length d).
        A1: Indices of infeasible, super-optimal arms (constraint term only).
        A2: Indices of feasible, sub-optimal arms (reward term only).
        A3: Indices of infeasible, sub-optimal arms (sum of both terms).
        num_samples: Number of random weight vectors to try.

    Returns:
        Tuple ``(max_gamma, best_w)``; ``(-inf, None)`` when ``num_samples``
        is not positive.

    Raises:
        ValueError: If no arm index is contained in A1, A2 or A3.
    """
    N, _ = X.shape
    best_w = None
    max_gamma = float('-inf')

    # Classify each arm once. Sets give O(1) membership tests, and the
    # A1 > A2 > A3 precedence applies when an index appears in several lists.
    set_a1, set_a2, set_a3 = set(A1), set(A2), set(A3)
    arm_kinds = []
    for i in range(N):
        if i in set_a1:
            arm_kinds.append((i, "A1"))
        elif i in set_a2:
            arm_kinds.append((i, "A2"))
        elif i in set_a3:
            arm_kinds.append((i, "A3"))

    # NOTE(review): np.linalg.norm(v @ V_w_inv) ** 2 equals v·V_w⁻²·v. If the
    # intended quantity is the V_w⁻¹-weighted squared norm v·V_w⁻¹·v, both
    # terms below need changing; confirm against the derivation.
    def constraint_term(i, V_w_inv):
        return (delta_c[i] ** 2) / (2 * variance * np.linalg.norm(X[i] @ V_w_inv) ** 2)

    def reward_term(i, V_w_inv):
        return (delta_r[i] ** 2) / (2 * variance * np.linalg.norm((X[i] - x1) @ V_w_inv) ** 2)

    def arm_value(i, kind, V_w_inv):
        if kind == "A1":  # infeasible & super-optimal
            return constraint_term(i, V_w_inv)
        if kind == "A2":  # feasible & sub-optimal
            return reward_term(i, V_w_inv)
        # "A3": infeasible & sub-optimal
        return constraint_term(i, V_w_inv) + reward_term(i, V_w_inv)

    for _ in range(num_samples):
        if not arm_kinds:
            raise ValueError("no arm index is contained in A1, A2 or A3; the rate is undefined")

        # Random normalised weight vector over the N arms
        w = np.random.dirichlet(np.ones(N))

        # Weighted design matrix and its inverse
        V_w = sum(w[i] * np.outer(X[i], X[i]) for i in range(N))
        V_w_inv = np.linalg.inv(V_w)

        # The rate of this allocation is set by its hardest-to-separate arm
        current_gamma = min(arm_value(i, kind, V_w_inv) for i, kind in arm_kinds)

        if current_gamma > max_gamma:
            max_gamma = current_gamma
            best_w = w

    return max_gamma, best_w

# Baseline: pull arms according to a fixed weight vector, recommend the empirical best feasible arm

def optimal_rate_empirical_best(X, T, tau, theta_c, theta_r, variance, best_w):
    """Simulate sampling from a fixed allocation and score the recommendation.

    Every round pulls an arm drawn with probabilities ``best_w``, observes a
    noisy reward and a noisy constraint value, refreshes the
    identity-regularised least-squares estimates, and recommends the arm with
    the highest estimated reward among the arms estimated to be feasible (a
    random arm when none is). The recommendation is compared with ``X[0]``,
    which this function treats as the true best arm.

    Args:
        X: Arm feature matrix (N x d).
        T: Number of rounds.
        tau: Constraint threshold.
        theta_c: True constraint parameter (d x 1).
        theta_r: True reward parameter (d x 1).
        variance: Noise variance of both the reward and the constraint.
        best_w: Sampling probabilities over the N arms.

    Returns:
        List of T ints: 1 when the round's recommendation equals X[0], else 0.
    """
    arm_count, dim = X.shape
    gram = np.eye(dim)
    reward_moment = np.zeros((dim, 1))
    constraint_moment = np.zeros((dim, 1))

    hits = []

    for _ in range(T):
        # Pull an arm according to the fixed allocation
        pulled = X[np.random.choice(arm_count, p=best_w)].reshape(-1, 1)

        # Noisy feedback (reward noise is drawn before constraint noise)
        reward_obs = pulled.T @ theta_r + np.random.normal(0, np.sqrt(variance))
        constraint_obs = pulled.T @ theta_c + np.random.normal(0, np.sqrt(variance))

        # Accumulate sufficient statistics and re-solve for the estimates
        gram += pulled @ pulled.T
        reward_moment += reward_obs * pulled
        constraint_moment += constraint_obs * pulled
        est_reward = np.linalg.solve(gram, reward_moment)
        est_constraint = np.linalg.solve(gram, constraint_moment)

        # Recommend the best arm among those estimated to satisfy the constraint
        feasible = [row.reshape(-1, 1) for row in X if row @ est_constraint <= tau]
        if feasible:
            scores = [col.T @ est_reward for col in feasible]
            recommended = feasible[np.argmax(scores)]
        else:
            recommended = X[np.random.choice(len(X))].reshape(-1, 1)

        hits.append(int(np.allclose(recommended.flatten(), X[0])))

    return hits

# Example problem instance: five arms in two dimensions
X = np.array([
    [1, 0],
    [0, 0.15],
    [1.2, 1.2],
    [0, 1],
    [np.cos(0.2), np.sin(0.2)],
])
T = 1000
alpha = 0
tau = 0.5
sigma2 = 0.1
gamma2 = 0.1
eta_lambda = 1
eta_p = 4
theta_c_new = np.array([[0], [1]])
theta_r_new = np.array([[1], [0]])

# Per-arm gaps: constraint value minus the threshold, and reward shortfall
# relative to 1 (the reward of X[0] under theta_r_new).
delta_c = np.array([arm @ theta_c_new - tau for arm in X])
delta_r = np.array([1 - arm @ theta_r_new for arm in X])
x1 = X[0]

# Arm classes by the sign of the two gaps:
#   A1: infeasible and not worse than x1, A2: feasible but worse, A3: infeasible and worse
A1 = [i for i, (gap_c, gap_r) in enumerate(zip(delta_c, delta_r)) if gap_c > 0 and gap_r <= 0]
A2 = [i for i, (gap_c, gap_r) in enumerate(zip(delta_c, delta_r)) if gap_c <= 0 and gap_r > 0]
A3 = [i for i, (gap_c, gap_r) in enumerate(zip(delta_c, delta_r)) if gap_c > 0 and gap_r > 0]

# Best allocation found by random search; the rate value itself is not needed here
best_w = calculate_optimal_rate_with_full_weights(X, delta_c, delta_r, sigma2, x1, A1, A2, A3)[1]

# Run every algorithm num_runs times and collect its running-accuracy curve
num_runs = 5
cumulative_accuracies_feasible = []
cumulative_accuracies_bfaips = []
cumulative_accuracies_ttts = []
cumulative_accuracies_optimal = []

# 1..T: divisor that turns cumulative hit counts into running accuracies
step_index = np.arange(1, T + 1)

for zz in range(num_runs):

    # Feasible Linear Thompson Sampling
    accuracy_feasible = linear_thompson_sampling_feasible(X, T, tau, sigma2, theta_c_new, theta_r_new)
    # BFAIPS
    accuracy_bfaips = simplified_bfaips_with_positive_exponent(X, T, alpha, tau, sigma2, gamma2, eta_lambda, eta_p, theta_c_new, theta_r_new)
    # TTTS (the per-round repetition counts are not used here)
    accuracy_ttts, _ = simplified_bfaips_with_ttts(X, T, alpha, tau, sigma2, gamma2, eta_lambda, theta_c_new, theta_r_new, best_w)
    # Fixed-allocation baseline
    accuracy_optimal = optimal_rate_empirical_best(X, T, tau, theta_c_new, theta_r_new, sigma2, best_w)

    # Running accuracy after each time step
    cumulative_accuracy_feasible = np.cumsum(accuracy_feasible) / step_index
    cumulative_accuracy_bfaips = np.cumsum(accuracy_bfaips) / step_index
    cumulative_accuracy_ttts = np.cumsum(accuracy_ttts) / step_index
    cumulative_accuracy_optimal = np.cumsum(accuracy_optimal) / step_index

    # Store results
    cumulative_accuracies_feasible.append(cumulative_accuracy_feasible)
    cumulative_accuracies_bfaips.append(cumulative_accuracy_bfaips)
    cumulative_accuracies_ttts.append(cumulative_accuracy_ttts)
    cumulative_accuracies_optimal.append(cumulative_accuracy_optimal)

    # Progress: fraction of runs finished so far (reaches 1.0 after the last
    # run) and elapsed wall-clock seconds
    timenow = time.time()
    print((zz + 1) / num_runs, timenow - start_time)
# Stack the per-run curves into 2-D arrays (one row per run)
cumulative_accuracies_feasible = np.array(cumulative_accuracies_feasible)
cumulative_accuracies_bfaips = np.array(cumulative_accuracies_bfaips)
cumulative_accuracies_ttts = np.array(cumulative_accuracies_ttts)
cumulative_accuracies_optimal = np.array(cumulative_accuracies_optimal)

# Persist each stacked array to its own .npy file in the working directory
for out_path, curves in (
    ('cumulative_accuracies_feasible.npy', cumulative_accuracies_feasible),
    ('cumulative_accuracies_bfaips.npy', cumulative_accuracies_bfaips),
    ('cumulative_accuracies_ttts.npy', cumulative_accuracies_ttts),
    ('cumulative_accuracies_optimal.npy', cumulative_accuracies_optimal),
):
    np.save(out_path, curves)



# Across-run mean and standard error of each algorithm's running accuracy
def _mean_and_standard_error(curves):
    """Return (mean, std / sqrt(num_runs)) computed over axis 0 of ``curves``.

    ``std`` is called with NumPy's default settings.
    """
    return curves.mean(axis=0), curves.std(axis=0) / np.sqrt(num_runs)


mean_accuracy_feasible, std_error_feasible = _mean_and_standard_error(cumulative_accuracies_feasible)
mean_accuracy_bfaips, std_error_bfaips = _mean_and_standard_error(cumulative_accuracies_bfaips)
mean_accuracy_ttts, std_error_ttts = _mean_and_standard_error(cumulative_accuracies_ttts)
mean_accuracy_optimal, std_error_optimal = _mean_and_standard_error(cumulative_accuracies_optimal)

# Plot the mean running accuracy of each algorithm with a ±1 standard-error band
plt.figure(figsize=(10, 6))

time_steps = range(1, T + 1)
plot_series = [
    (mean_accuracy_feasible, std_error_feasible, "Linear TS (Feasible)", "green"),
    (mean_accuracy_bfaips, std_error_bfaips, "BFAIPS", "blue"),
    (mean_accuracy_ttts, std_error_ttts, "TTTS", "orange"),
    (mean_accuracy_optimal, std_error_optimal, "Oracle", "red"),
]

for series_mean, series_error, series_label, series_color in plot_series:
    plt.plot(time_steps, series_mean, label=series_label, color=series_color)
    plt.fill_between(
        time_steps,
        series_mean - series_error,
        series_mean + series_error,
        color=series_color,
        alpha=0.2,
    )

plt.xlabel("Time Step")
plt.ylabel("Accuracy")
plt.title("Accuracy Comparison: Linear TS (Feasible) vs BFAIPS vs LinTTTS (Feasible) vs Oracle")
plt.legend()
plt.grid(True)

# Save before showing: once an interactive window opened by show() has been
# closed, the figure is gone and a later savefig() would write an empty page.
plt.savefig('example_plot_1.pdf', bbox_inches='tight')
plt.show()