import numpy as np
import pickle
import concurrent.futures
import pandas as pd

def compute_mse_for_cluster(cluster_user_embedding, cluster_item_embedding, arbitrary_item_embedding, cluster_rating, cluster_mask, eval=False):
    """Return the mean excess squared error of a candidate item embedding.

    Despite the name, the result is not a plain MSE: it is the masked mean of
    ``(candidate error)**2 - (own error)**2``, where the candidate prediction
    is ``user @ arbitrary_item.T`` and the own prediction is
    ``user @ cluster_item.T``. The result is floored at ``1e-6``.

    Args:
        cluster_user_embedding: User embeddings of the cluster (left matmul
            operand).
        cluster_item_embedding: The cluster's own item embeddings.
        arbitrary_item_embedding: Candidate item embeddings to compare
            against the cluster's own.
        cluster_rating: Observed ratings; must be broadcast-compatible with
            the prediction matrices.
        cluster_mask: Index applied to the element-wise error before
            averaging (presumably a boolean mask of observed ratings --
            TODO confirm against the data file).
        eval: When false (default), Gaussian noise with standard deviation
            0.4 is added to the cluster's own predictions, drawn from the
            global NumPy RNG. When true, no noise is drawn.

    Returns:
        The larger of the masked mean error difference and ``1e-6``.
    """
    candidate_scores = np.matmul(cluster_user_embedding, arbitrary_item_embedding.T)
    own_scores = np.matmul(cluster_user_embedding, cluster_item_embedding.T)

    if not eval:
        # Only the cluster's own predictions are perturbed; the candidate's
        # predictions are always used as-is.
        own_scores = own_scores + np.random.normal(0, 0.4, candidate_scores.shape)

    candidate_sq_err = (candidate_scores - cluster_rating) ** 2
    own_sq_err = (own_scores - cluster_rating) ** 2
    excess = (candidate_sq_err - own_sq_err)[cluster_mask]

    # The floor keeps the value strictly positive.
    return max(np.mean(excess), 1e-6)

def initialize(cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m):
    """Pick ``m`` item embeddings by cost-proportional sampling and score them.

    The first half of the clusters (indices ``0 .. n-1`` with
    ``n = int(0.5 * n_tot)``) is used for selection. A first cluster is drawn
    uniformly; each further cluster is drawn with probability proportional to
    its running minimum noisy cost over the embeddings selected so far.

    Args:
        cluster_user_embeddings: Per-cluster user embeddings, indexable by
            integer cluster id; its length defines the number of clusters.
        cluster_item_embeddings: Per-cluster item embeddings.
        cluster_ratings: Per-cluster rating data.
        cluster_masks: Per-cluster masks passed through to the cost function.
        m: Number of item embeddings to select.

    Returns:
        The sum, over held-out clusters, of the smallest noise-free cost any
        selected embedding achieves on that cluster.
    """
    total_clusters = len(cluster_user_embeddings)
    train_count = int(0.5 * total_clusters)

    def cost(cluster, theta, held_out=False):
        # Cost of serving `cluster` with candidate item embedding `theta`.
        return compute_mse_for_cluster(
            cluster_user_embeddings[cluster],
            cluster_item_embeddings[cluster],
            theta,
            cluster_ratings[cluster],
            cluster_masks[cluster],
            eval=held_out,
        )

    first = np.random.randint(low=0, high=train_count)
    chosen = [cluster_item_embeddings[first]]
    weights = [cost(c, chosen[-1]) for c in range(train_count)]

    for _ in range(m - 1):
        newest = chosen[-1]
        # Running minimum: a cluster's weight can only shrink as more
        # embeddings are added. Costs are noisy, so each round re-samples.
        weights = np.array([min(cost(c, newest), weights[c]) for c in range(train_count)])
        picked = np.random.choice(train_count, p=weights / sum(weights))
        chosen.append(cluster_item_embeddings[picked])

    # NOTE(review): evaluation starts at train_count + 1, so the cluster at
    # index train_count is in neither the selection nor the held-out split --
    # confirm this is intentional.
    held_out_costs = [
        min([cost(c, theta, held_out=True) for theta in chosen])
        for c in range(train_count + 1, total_clusters, 1)
    ]
    return sum(held_out_costs)

def epsilon_greedy_initialize(cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m):
    """Pick ``m`` item embeddings by a noise-perturbed argmax and score them.

    Selection uses the first half of the clusters. After a uniformly drawn
    first cluster, each round adds Gaussian noise (standard deviation equal
    to the current largest weight) to every cluster's running minimum noisy
    cost and selects the cluster with the largest perturbed value. Note that
    this is a perturbed argmax rather than a textbook epsilon-greedy rule.

    Args:
        cluster_user_embeddings: Per-cluster user embeddings, indexable by
            integer cluster id; its length defines the number of clusters.
        cluster_item_embeddings: Per-cluster item embeddings.
        cluster_ratings: Per-cluster rating data.
        cluster_masks: Per-cluster masks passed through to the cost function.
        m: Number of item embeddings to select.

    Returns:
        The sum, over held-out clusters, of the smallest noise-free cost any
        selected embedding achieves on that cluster.
    """
    # Multiplier applied to the largest weight to obtain the noise scale.
    noise_factor = 1.0

    cluster_total = len(cluster_user_embeddings)
    pool_size = int(0.5 * cluster_total)

    def cost(cluster, theta, held_out=False):
        # Cost of serving `cluster` with candidate item embedding `theta`.
        return compute_mse_for_cluster(
            cluster_user_embeddings[cluster],
            cluster_item_embeddings[cluster],
            theta,
            cluster_ratings[cluster],
            cluster_masks[cluster],
            eval=held_out,
        )

    start = np.random.randint(low=0, high=pool_size)
    selected = [cluster_item_embeddings[start]]
    weights = [cost(c, selected[-1]) for c in range(pool_size)]

    for _ in range(m - 1):
        latest = selected[-1]
        weights = np.array([min(cost(c, latest), weights[c]) for c in range(pool_size)])
        # Perturb every weight, then take the largest perturbed entry.
        spread = noise_factor * np.max(weights)
        jitter = np.random.normal(0, spread, weights.shape)
        winner = np.argmax(weights + jitter)
        selected.append(cluster_item_embeddings[winner])

    # NOTE(review): evaluation starts at pool_size + 1, so the cluster at
    # index pool_size is in neither split -- confirm this is intentional.
    held_out_costs = [
        min([cost(c, theta, held_out=True) for theta in selected])
        for c in range(pool_size + 1, cluster_total, 1)
    ]
    return sum(held_out_costs)

def greedy_initialize(cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m):
    """Pick ``m`` item embeddings greedily and score them on held-out clusters.

    Selection uses the first half of the clusters. After a uniformly drawn
    first cluster, each round selects the cluster whose running minimum noisy
    cost over the embeddings selected so far is largest.

    Args:
        cluster_user_embeddings: Per-cluster user embeddings, indexable by
            integer cluster id; its length defines the number of clusters.
        cluster_item_embeddings: Per-cluster item embeddings.
        cluster_ratings: Per-cluster rating data.
        cluster_masks: Per-cluster masks passed through to the cost function.
        m: Number of item embeddings to select.

    Returns:
        The sum, over held-out clusters, of the smallest noise-free cost any
        selected embedding achieves on that cluster.
    """
    n_tot = len(cluster_user_embeddings)
    n = int(0.5 * n_tot)

    def cost(cluster, theta, held_out=False):
        # Cost of serving `cluster` with candidate item embedding `theta`.
        return compute_mse_for_cluster(
            cluster_user_embeddings[cluster],
            cluster_item_embeddings[cluster],
            theta,
            cluster_ratings[cluster],
            cluster_masks[cluster],
            eval=held_out,
        )

    first = np.random.randint(low=0, high=n)
    Theta = [cluster_item_embeddings[first]]
    worst_case = [cost(c, Theta[-1]) for c in range(n)]

    for _ in range(m - 1):
        newest = Theta[-1]
        # Running minimum of the (noisy) cost per cluster.
        worst_case = np.array([min(cost(c, newest), worst_case[c]) for c in range(n)])
        # Greedy step: take the currently worst-served cluster.
        Theta.append(cluster_item_embeddings[np.argmax(worst_case)])

    # NOTE(review): evaluation starts at n + 1, so the cluster at index n is
    # in neither split -- confirm this is intentional.
    held_out_costs = [
        min([cost(c, theta, held_out=True) for theta in Theta])
        for c in range(n + 1, n_tot, 1)
    ]
    return sum(held_out_costs)

def uniform_initialize(cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m):
    """Pick ``m`` item embeddings uniformly at random and score them.

    Clusters are drawn with replacement from the first half of the clusters
    (indices ``0 .. n-1`` with ``n = int(0.5 * n_tot)``), so the same
    embedding may be selected more than once.

    Args:
        cluster_user_embeddings: Per-cluster user embeddings, indexable by
            integer cluster id; its length defines the number of clusters.
        cluster_item_embeddings: Per-cluster item embeddings.
        cluster_ratings: Per-cluster rating data.
        cluster_masks: Per-cluster masks passed through to the cost function.
        m: Number of item embeddings to draw.

    Returns:
        The sum, over held-out clusters, of the smallest noise-free cost any
        drawn embedding achieves on that cluster.
    """
    cluster_total = len(cluster_user_embeddings)
    pool_size = int(0.5 * cluster_total)

    drawn = np.random.choice(np.arange(pool_size), size=m, replace=True)
    candidates = [cluster_item_embeddings[pick] for pick in drawn]

    # NOTE(review): evaluation starts at pool_size + 1, so the cluster at
    # index pool_size is in neither split -- confirm this is intentional.
    held_out_costs = []
    for cluster in range(pool_size + 1, cluster_total, 1):
        per_candidate = [
            compute_mse_for_cluster(
                cluster_user_embeddings[cluster],
                cluster_item_embeddings[cluster],
                theta,
                cluster_ratings[cluster],
                cluster_masks[cluster],
                eval=True,
            )
            for theta in candidates
        ]
        held_out_costs.append(min(per_candidate))
    return sum(held_out_costs)

# Wrapper function for 'initialize'
def initialize_parallel(cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m, num_trials):
    """Run ``initialize`` ``num_trials`` times on a thread pool.

    Every trial receives the same arguments; trials differ only through the
    random draws made inside ``initialize``.

    Returns:
        A list with one ``initialize`` result per trial, in submission order.
        An exception raised by a trial propagates when its result is read.
    """
    trial_args = (cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m)

    # Leaving the ``with`` block waits for every submitted trial to finish.
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = [pool.submit(initialize, *trial_args) for _ in range(num_trials)]

    return [future.result() for future in pending]

# Wrapper function for 'greedy_initialize'
def initialize_greedy_parallel(cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m, num_trials):
    """Run ``greedy_initialize`` ``num_trials`` times on a thread pool.

    Every trial receives the same arguments; trials differ only through the
    random draws made inside ``greedy_initialize``.

    Returns:
        A list with one ``greedy_initialize`` result per trial, in submission
        order. An exception raised by a trial propagates when its result is
        read.
    """
    trial_args = (cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m)

    # Leaving the ``with`` block waits for every submitted trial to finish.
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = [pool.submit(greedy_initialize, *trial_args) for _ in range(num_trials)]

    return [future.result() for future in pending]

# Wrapper function for 'epsilon_greedy_initialize'
def initialize_epsilon_greedy_parallel(cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m, num_trials):
    """Run ``epsilon_greedy_initialize`` ``num_trials`` times on a thread pool.

    Every trial receives the same arguments; trials differ only through the
    random draws made inside ``epsilon_greedy_initialize``.

    Returns:
        A list with one ``epsilon_greedy_initialize`` result per trial, in
        submission order. An exception raised by a trial propagates when its
        result is read.
    """
    trial_args = (cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m)

    # Leaving the ``with`` block waits for every submitted trial to finish.
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = [pool.submit(epsilon_greedy_initialize, *trial_args) for _ in range(num_trials)]

    return [future.result() for future in pending]

# Wrapper function for 'uniform_initialize'
def uniform_initialize_parallel(cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m, num_trials):
    """Run ``uniform_initialize`` ``num_trials`` times on a thread pool.

    Every trial receives the same arguments; trials differ only through the
    random draws made inside ``uniform_initialize``.

    Returns:
        A list with one ``uniform_initialize`` result per trial, in
        submission order. An exception raised by a trial propagates when its
        result is read.
    """
    trial_args = (cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m)

    # Leaving the ``with`` block waits for every submitted trial to finish.
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = [pool.submit(uniform_initialize, *trial_args) for _ in range(num_trials)]

    return [future.result() for future in pending]

# Experiment configuration: the input pickle is selected by the cluster count.
num_clusters = 1000
file_name = "test_cluster_data_{}.pkl".format(num_clusters)

# Load the per-cluster data.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only point this at data files that come from a trusted source.
with open(file_name, 'rb') as file:
    loaded_data = pickle.load(file)

# Unpack the per-cluster containers (the mask is stored under the singular
# key "cluster_mask").
cluster_user_embeddings = loaded_data["cluster_user_embeddings"]
cluster_item_embeddings = loaded_data["cluster_item_embeddings"]
cluster_ratings = loaded_data["cluster_ratings"]
cluster_masks = loaded_data["cluster_mask"]

# Sanity check: length of the user-embedding entry for cluster 1.
print(len(cluster_user_embeddings[1]))

# Number of independent trials per strategy and per value of m.
num_trials = 250

# Numbers of selected item embeddings ("services") to sweep over.
m_values = range(5, 15, 1)

for m in m_values:
    print("number of services: ", m)

    # Despite the "Theta_" prefix, each of these holds the list of per-trial
    # scores returned by the corresponding parallel wrapper.
    Theta_initialize = initialize_parallel(cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m, num_trials)
    Theta_greedy_initialize = initialize_greedy_parallel(cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m, num_trials)
    Theta_epsilon_greedy_initialize = initialize_epsilon_greedy_parallel(cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m, num_trials)
    Theta_uniform_initialize = uniform_initialize_parallel(cluster_user_embeddings, cluster_item_embeddings, cluster_ratings, cluster_masks, m, num_trials)

    # Report the mean score of each strategy for this m.
    average_initialize = np.mean(Theta_initialize)
    print("AcQUIre : ", average_initialize)

    average_greedy_initialize = np.mean(Theta_greedy_initialize)
    print("Greedy : ", average_greedy_initialize)

    average_epsilon_greedy_initialize = np.mean(Theta_epsilon_greedy_initialize)
    print("Epsilon Greedy : ", average_epsilon_greedy_initialize)

    average_uniform_initialize = np.mean(Theta_uniform_initialize)
    print("Random : ", average_uniform_initialize)

    # Raw per-trial scores for this m, keyed by strategy label.
    results = {
        "AcQUIre": Theta_initialize,
        "Random": Theta_uniform_initialize,
        "Greedy Initialize": Theta_greedy_initialize,
        "Epsilon Greedy Initialize": Theta_epsilon_greedy_initialize,
    }

    # np.save stores the dict as a pickled object array; read it back with
    # np.load(path, allow_pickle=True).item().
    np.save("icml_{}.npy".format(m), results)


