"""Real-world dataset empirical study"""



import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import truncnorm
import time

# Load MovieLens context data.
# NOTE(review): X takes columns 0..19, and the bandit loops below read column 0
# of movie_raw as the movie id; if column 0 really is an id it is also used as a
# regression feature. Column 20 is not used here. Confirm this layout is intended.
movie_raw = np.load("movie_summary_with_context.npy")
X = movie_raw[:, :20]
mean_user_rating = movie_raw[:, 21]  # treated as the true mean reward per movie
imdb_rating = movie_raw[:, 22]       # treated as the true mean cost per movie

# Load reward data and group the observed ratings by (integer) movie id.
# Presumably each row of reward.npy is (user_id, movie_id, rating), as the
# unpacking below assumes — confirm against the data producer.
reward_raw = np.load("reward.npy")
movie_ratings = {}
for user_id, movie_id, rating in reward_raw:
    movie_id = int(movie_id)
    movie_ratings.setdefault(movie_id, []).append(rating)

def get_sampled_rating(movie_id, fallback_rating):
    """Return one observed rating for ``movie_id``, drawn uniformly at random.

    The ratings come from the module-level ``movie_ratings`` mapping. If the
    movie has no recorded ratings, ``fallback_rating`` is returned unchanged.
    """
    observed = movie_ratings.get(movie_id)
    if not observed:
        return fallback_rating
    return np.random.choice(observed)

def sample_truncated_normal(mean, cov, lower=-np.inf, upper=np.inf):
    """Draw one truncated-normal sample per coordinate.

    Coordinate ``i`` is drawn independently from a normal distribution with
    location ``mean[i]`` and standard deviation ``sqrt(cov[i, i])``, truncated
    to ``[lower, upper]``. Off-diagonal entries of ``cov`` are ignored.

    Args:
        mean: Array of locations; flattened before use.
        cov: Square matrix whose diagonal supplies the per-coordinate variances.
        lower: Lower truncation bound (default: unbounded).
        upper: Upper truncation bound (default: unbounded).

    Returns:
        Column vector of shape ``(k, 1)`` with one sample per coordinate.
    """
    centre = mean.flatten()
    spread = np.sqrt(np.diag(cov))
    # truncnorm takes its clip points in standard-normal units.
    a_std = (lower - centre) / spread
    b_std = (upper - centre) / spread
    draws = truncnorm.rvs(a_std, b_std, loc=centre, scale=spread)
    return draws.reshape(-1, 1)

# Experiment parameters
T = 5000         # rounds per run
num_runs = 50    # independent repetitions of every algorithm
sigma2 = 1 ** 2  # nominal reward-noise variance; not referenced by the functions defined here (peps has its own sigma2 argument)
gamma2 = 1 ** 2  # variance of the Gaussian noise added to observed costs by linear_ts_feasible and ttts
tau = 7.5        # feasibility threshold: an arm is feasible when its cost <= tau

# Ground truth: among arms whose true cost is within the threshold, the one
# with the highest true mean reward.
true_reward = mean_user_rating
true_cost = imdb_rating
feasible_indices = np.flatnonzero(true_cost <= tau)
true_best_idx = feasible_indices[true_reward[feasible_indices].argmax()]

def linear_ts_feasible(X, T, tau, mean_user_rating, imdb_rating):
    """Run feasibility-constrained linear Thompson sampling for ``T`` rounds.

    Reward and cost are both modelled as linear in the arm features ``X``.
    They share one ridge Gram matrix ``V``, initialised to the identity.

    Each round:

    - Draw one Gaussian posterior sample for each parameter, with covariance
      ``beta_t**2 * V^{-1}``.
    - Pull the arm with the highest sampled reward among arms whose sampled
      cost is ``<= tau``. If none qualifies, pull a uniformly random arm.
    - Observe a reward from ``get_sampled_rating``, and a cost equal to
      ``imdb_rating[idx]`` plus Gaussian noise with the module-level variance
      ``gamma2``. Update the ridge estimates.
    - Compare the greedy feasible arm under the updated point estimates with
      the true best feasible arm.

    Args:
        X: Arm feature matrix, one row per arm.
        T: Number of rounds.
        tau: Cost threshold; an arm is feasible when its cost is ``<= tau``.
        mean_user_rating: True mean reward per arm, indexed like the rows of ``X``.
        imdb_rating: True mean cost per arm, indexed like the rows of ``X``.

    Returns:
        List of ``T`` ints: 1 where that round's recommendation is the true
        best feasible arm, else 0.

    Raises:
        ValueError: If no arm satisfies ``imdb_rating <= tau``.
    """
    n, d = X.shape

    # Derive the ground truth from the arguments, not from the module-level
    # ``true_best_idx``. This keeps the score correct for any tau or data passed in.
    truly_feasible = np.flatnonzero(imdb_rating <= tau)
    if truly_feasible.size == 0:
        raise ValueError("no arm satisfies imdb_rating <= tau")
    best_idx = truly_feasible[np.argmax(mean_user_rating[truly_feasible])]

    def greedy_feasible(theta_r, theta_c):
        """Return the row index maximising ``X @ theta_r`` subject to ``X @ theta_c <= tau``.

        Ties go to the lowest index. If no row is feasible, a uniformly random
        row index is returned.
        """
        feasible = np.flatnonzero((X @ theta_c).ravel() <= tau)
        if feasible.size == 0:
            return np.random.choice(n)
        return feasible[np.argmax((X[feasible] @ theta_r).ravel())]

    V = np.eye(d)
    S_r = np.zeros((d, 1))
    S_c = np.zeros((d, 1))
    theta_hat_r = np.zeros((d, 1))
    theta_hat_c = np.zeros((d, 1))
    acc = []
    for t in range(1, T + 1):
        beta_t = np.sqrt(9 * d * np.log(t + 1))
        # One inverse per round, shared by both posterior draws.
        cov = (beta_t ** 2) * np.linalg.inv(V)
        theta_r_t = np.random.multivariate_normal(theta_hat_r.flatten(), cov).reshape(-1, 1)
        theta_c_t = np.random.multivariate_normal(theta_hat_c.flatten(), cov).reshape(-1, 1)

        # Track the chosen row index directly. Recovering it by matching feature
        # vectors would cost O(n*d) and is ambiguous for duplicate rows.
        idx = greedy_feasible(theta_r_t, theta_c_t)
        x_t = X[idx].reshape(-1, 1)
        movie_id = int(movie_raw[idx, 0])
        y_r = get_sampled_rating(movie_id, mean_user_rating[idx])
        y_c = imdb_rating[idx] + np.random.normal(0, np.sqrt(gamma2))

        V += x_t @ x_t.T
        S_r += y_r * x_t
        S_c += y_c * x_t
        theta_hat_r = np.linalg.solve(V, S_r)
        theta_hat_c = np.linalg.solve(V, S_c)

        acc.append(int(greedy_feasible(theta_hat_r, theta_hat_c) == best_idx))
    return acc

def peps(X, T, tau, mean_user_rating, imdb_rating, sigma2=0.2**2, gamma2=0.2**2, eta_lambda=8000,
         max_resample=1000):
    """Run the PEPS allocation scheme with a linear cost constraint for ``T`` rounds.

    Each round:

    1. With probability ``t ** -0.25`` pull a uniformly random arm; otherwise
       draw the arm from the current allocation ``lam``.
    2. Compute the empirical best feasible arm ``z_t`` from the current point
       estimates, before this round's observation.
    3. Draw perturbed parameters with ``sample_truncated_normal`` (no bounds
       passed), centred on the estimates with covariance
       ``eta_lambda * V^{-1}``. Repeat until their greedy feasible arm differs
       from ``z_t``.
    4. Observe the pulled arm's reward and noisy cost, then update the ridge
       estimates.
    5. Set ``lam`` proportional to ``exp(etap * loss)``, where
       ``loss = -((x.dr)^2 / sigma2 + (x.dc)^2 / gamma2)`` and ``dr``/``dc`` are
       the perturbed minus the updated estimates. ``lam`` falls back to uniform
       when the weights are non-finite or sum to zero.

    Args:
        X: Arm feature matrix, one row per arm.
        T: Number of rounds; also enters the exponential-weights step size.
        tau: Cost threshold; an arm is feasible when its cost is ``<= tau``.
        mean_user_rating: True mean reward per arm, indexed like the rows of ``X``.
        imdb_rating: True mean cost per arm, indexed like the rows of ``X``.
        sigma2: Reward-variance scale used in the loss and the step size.
        gamma2: Cost-variance scale used in the loss. It is also the variance
            of the Gaussian noise added to observed costs.
        eta_lambda: Multiplier on ``V^{-1}`` for the perturbation covariance.
        max_resample: Maximum number of perturbation draws per round (at least
            one draw is always made). If every draw reproduces ``z_t``, the last
            draw is used instead of looping forever, which would otherwise
            happen when e.g. ``X`` has a single row.

    Returns:
        List of ``T`` ints: 1 where that round's ``z_t`` is the true best
        feasible arm, else 0.

    Raises:
        ValueError: If no arm satisfies ``imdb_rating <= tau``.
    """
    n, d = X.shape

    truly_feasible = np.flatnonzero(imdb_rating <= tau)
    if truly_feasible.size == 0:
        raise ValueError("no arm satisfies imdb_rating <= tau")
    best_idx = truly_feasible[np.argmax(mean_user_rating[truly_feasible])]

    def greedy_feasible(theta_r, theta_c):
        """Return the row index maximising ``X @ theta_r`` subject to ``X @ theta_c <= tau``.

        Ties go to the lowest index. If no row is feasible, a uniformly random
        row index is returned.
        """
        feasible = np.flatnonzero((X @ theta_c).ravel() <= tau)
        if feasible.size == 0:
            return np.random.choice(n)
        return feasible[np.argmax((X[feasible] @ theta_r).ravel())]

    uniform = np.full(n, 1.0 / n)
    V = np.eye(d)
    S_r = np.zeros((d, 1))
    S_c = np.zeros((d, 1))
    theta_hat_r = np.zeros((d, 1))
    theta_hat_c = np.zeros((d, 1))
    lam = uniform
    # Exponential-weights step size. It depends only on sigma2 and T, so it is
    # computed once rather than every round.
    etap = sigma2 * (np.log(20) / T / np.log(T)) ** 0.5
    acc = []

    for t in range(1, T + 1):
        # Forced exploration with probability t^(-1/4); otherwise follow lam.
        epsilon_t = t ** (-0.25)
        idx = np.random.choice(n) if np.random.rand() < epsilon_t else np.random.choice(n, p=lam)
        x_t = X[idx].reshape(-1, 1)

        # Empirical best feasible arm under the estimates from before this round's update.
        z_idx = greedy_feasible(theta_hat_r, theta_hat_c)

        # Draw perturbed parameters until they select an arm other than z_idx.
        # The loop is bounded so that a degenerate problem cannot hang.
        perturb_cov = eta_lambda * np.linalg.inv(V)
        for _ in range(max(1, max_resample)):
            theta_r_t = sample_truncated_normal(theta_hat_r, perturb_cov)
            theta_c_t = sample_truncated_normal(theta_hat_c, perturb_cov)
            if greedy_feasible(theta_r_t, theta_c_t) != z_idx:
                break

        movie_id = int(movie_raw[idx, 0])
        y_r = get_sampled_rating(movie_id, mean_user_rating[idx])
        # NOTE(review): cost noise here uses this function's gamma2 argument
        # (default 0.2**2). linear_ts_feasible and ttts use the module-level
        # gamma2 (1**2). Confirm whether the environment noise should be shared.
        y_c = imdb_rating[idx] + np.random.normal(0, np.sqrt(gamma2))

        V += x_t @ x_t.T
        S_r += y_r * x_t
        S_c += y_c * x_t
        theta_hat_r = np.linalg.solve(V, S_r)
        theta_hat_c = np.linalg.solve(V, S_c)

        # Per-arm loss, using the perturbation minus the *updated* estimates.
        # NOTE(review): confirm against the intended PEPS update:
        #  (1) with this sign, arms with the largest separation get the smallest
        #      weight, and only the current round's loss is used (no cumulative sum);
        #  (2) exp() can underflow to all zeros for large separations, in which
        #      case lam silently reverts to uniform.
        loss = -((X @ (theta_r_t - theta_hat_r)) ** 2 / sigma2
                 + (X @ (theta_c_t - theta_hat_c)) ** 2 / gamma2).ravel()
        w = np.exp(etap * loss)
        lam = w / np.sum(w) if np.all(np.isfinite(w)) and np.sum(w) > 0 else uniform

        acc.append(int(z_idx == best_idx))
    return acc

def ttts(X, T, tau, mean_user_rating, imdb_rating, max_resample=1000):
    """Run top-two Thompson sampling with a linear cost constraint for ``T`` rounds.

    Reward and cost are both modelled as linear in the arm features ``X``.
    They share one ridge Gram matrix ``V``, initialised to the identity.

    Each round:

    - Draw Gaussian posterior samples with covariance ``beta_t**2 * V^{-1}``.
      The *leader* is the arm with the highest sampled reward among arms whose
      sampled cost is ``<= tau`` (a uniformly random arm if none qualifies).
    - Draw fresh samples until they select a different arm, the *challenger*.
    - Pull the leader or the challenger, each with probability 1/2.
    - Observe a reward from ``get_sampled_rating``, and a cost equal to
      ``imdb_rating[idx]`` plus Gaussian noise with the module-level variance
      ``gamma2``. Update the ridge estimates.
    - Compare the greedy feasible arm under the updated point estimates with
      the true best feasible arm.

    Args:
        X: Arm feature matrix, one row per arm.
        T: Number of rounds.
        tau: Cost threshold; an arm is feasible when its cost is ``<= tau``.
        mean_user_rating: True mean reward per arm, indexed like the rows of ``X``.
        imdb_rating: True mean cost per arm, indexed like the rows of ``X``.
        max_resample: Maximum number of challenger draws per round. If no
            different arm is found within this budget, the leader is pulled
            instead of looping forever, which would otherwise happen when e.g.
            ``X`` has a single row.

    Returns:
        List of ``T`` ints: 1 where that round's recommendation is the true
        best feasible arm, else 0.

    Raises:
        ValueError: If no arm satisfies ``imdb_rating <= tau``.
    """
    n, d = X.shape

    # Derive the ground truth from the arguments, not from the module-level
    # ``true_best_idx``. This keeps the score correct for any tau or data passed in.
    truly_feasible = np.flatnonzero(imdb_rating <= tau)
    if truly_feasible.size == 0:
        raise ValueError("no arm satisfies imdb_rating <= tau")
    best_idx = truly_feasible[np.argmax(mean_user_rating[truly_feasible])]

    def greedy_feasible(theta_r, theta_c):
        """Return the row index maximising ``X @ theta_r`` subject to ``X @ theta_c <= tau``.

        Ties go to the lowest index. If no row is feasible, a uniformly random
        row index is returned.
        """
        feasible = np.flatnonzero((X @ theta_c).ravel() <= tau)
        if feasible.size == 0:
            return np.random.choice(n)
        return feasible[np.argmax((X[feasible] @ theta_r).ravel())]

    def draw(theta_hat, cov):
        """Return one multivariate-normal sample around ``theta_hat`` as a column vector."""
        return np.random.multivariate_normal(theta_hat.flatten(), cov).reshape(-1, 1)

    V = np.eye(d)
    S_r = np.zeros((d, 1))
    S_c = np.zeros((d, 1))
    theta_hat_r = np.zeros((d, 1))
    theta_hat_c = np.zeros((d, 1))
    acc = []
    for t in range(1, T + 1):
        beta_t = np.sqrt(9 * d * np.log(t + 1))
        # One inverse per round, shared by every posterior draw in this round.
        cov = (beta_t ** 2) * np.linalg.inv(V)

        theta_r_1 = draw(theta_hat_r, cov)
        theta_c_1 = draw(theta_hat_c, cov)
        leader = greedy_feasible(theta_r_1, theta_c_1)

        # Re-sample until a different arm wins (bounded so it cannot hang).
        challenger = leader
        for _ in range(max_resample):
            theta_r_2 = draw(theta_hat_r, cov)
            theta_c_2 = draw(theta_hat_c, cov)
            challenger = greedy_feasible(theta_r_2, theta_c_2)
            if challenger != leader:
                break

        idx = leader if np.random.rand() < 0.5 else challenger
        x_t = X[idx].reshape(-1, 1)
        movie_id = int(movie_raw[idx, 0])
        y_r = get_sampled_rating(movie_id, mean_user_rating[idx])
        y_c = imdb_rating[idx] + np.random.normal(0, np.sqrt(gamma2))

        V += x_t @ x_t.T
        S_r += y_r * x_t
        S_c += y_c * x_t
        theta_hat_r = np.linalg.solve(V, S_r)
        theta_hat_c = np.linalg.solve(V, S_c)

        acc.append(int(greedy_feasible(theta_hat_r, theta_hat_c) == best_idx))
    return acc

# Run experiments. Each algorithm returns a per-round 0/1 indicator of
# "recommendation is the true best feasible arm". Dividing its cumulative sum
# by the round number gives a running-average accuracy curve.
steps = np.arange(1, T + 1)
acc_ts, acc_peps, acc_ttts = [], [], []
for run in range(num_runs):
    start = time.time()
    acc_ts.append(np.cumsum(linear_ts_feasible(X, T, tau, mean_user_rating, imdb_rating)) / steps)
    acc_peps.append(np.cumsum(peps(X, T, tau, mean_user_rating, imdb_rating)) / steps)
    acc_ttts.append(np.cumsum(ttts(X, T, tau, mean_user_rating, imdb_rating)) / steps)
    print(f"Run {run + 1}/{num_runs} completed in {time.time() - start:.2f} seconds")

acc_ts, acc_peps, acc_ttts = map(np.array, [acc_ts, acc_peps, acc_ttts])

# Compute the mean and the standard error of the mean across runs. The SEM
# uses the sample standard deviation (ddof=1), because the population std
# (ddof=0) understates it. With a single run there is no spread to estimate,
# so ddof=0 is used to avoid a NaN.
sem_ddof = 1 if num_runs > 1 else 0
mean_ts = acc_ts.mean(axis=0)
err_ts = acc_ts.std(axis=0, ddof=sem_ddof) / np.sqrt(num_runs)
mean_peps = acc_peps.mean(axis=0)
err_peps = acc_peps.std(axis=0, ddof=sem_ddof) / np.sqrt(num_runs)
mean_ttts = acc_ttts.mean(axis=0)
err_ttts = acc_ttts.std(axis=0, ddof=sem_ddof) / np.sqrt(num_runs)

# Save the raw per-run accuracy curves, each of shape (num_runs, T).
np.save("acc_ts_ML5000.npy", acc_ts)
np.save("acc_peps_ML5000.npy", acc_peps)
np.save("acc_ttts_ML5000.npy", acc_ttts)

# Plot mean accuracy with +/- one standard-error bands against rounds 1..T.
# NOTE(review): the title names "BLFAIPS" while the legend says "PEPS" —
# confirm these refer to the same algorithm.
plt.figure(figsize=(10, 6))
for curve, band, label, color in (
    (mean_ts, err_ts, 'Linear TS (Feasible)', 'green'),
    (mean_peps, err_peps, 'PEPS', 'purple'),
    (mean_ttts, err_ttts, 'TTTS', 'orange'),
):
    plt.plot(steps, curve, label=label, color=color)
    plt.fill_between(steps, curve - band, curve + band, color=color, alpha=0.2)
plt.xlabel("Time Step")
plt.ylabel("Accuracy")
plt.title("MovieLens Accuracy: BLFAIPS vs TTTS vs Linear TS")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

