import numpy as np
import pandas as pd
import random

def generate_synthetic_means(K, gap, bias, seed):
    """
    Build a pair of synthetic mean vectors (offline and online) for K arms.

    NumPy's global RNG is seeded with `seed`, then:
      - mu_off: the best arm's mean is drawn uniformly between 0.5 + gap and
        1.0, the second-best mean is the best mean minus `gap` (floored at 0),
        and the remaining K - 2 means are drawn uniformly between 0 and the
        second-best mean. The result is sorted in ascending order.
      - mu_on: every offline mean is shifted by +bias or -bias (the sign is
        chosen independently per arm with probability 1/2) and the result is
        clipped to [0, 1].

    Both vectors are printed and returned as (mu_off, mu_on).
    """
    np.random.seed(seed)

    # The order of the RNG calls below determines the output for a given seed.
    top_mean = np.random.uniform(0.5 + gap, 1.0)
    runner_up_mean = max(0.0, top_mean - gap)
    other_means = np.random.uniform(0.0, runner_up_mean, K - 2)

    mu_off = np.sort(np.concatenate([other_means, [runner_up_mean, top_mean]]))

    # One coin flip per arm decides whether the bias is added or subtracted.
    coin_flips = np.random.random(len(mu_off))
    direction = np.where(coin_flips >= 0.5, -1, 1)
    mu_on = np.clip(mu_off + direction * bias, 0.0, 1.0)

    print("mu_off:", mu_off)
    print("mu_on:", mu_on)
    return mu_off, mu_on

def generate_synthetic_stochastic_offline_data(mu, N_offline, noise_std=1.0):
    """
    Synthetic environment: sample offline rewards for each arm.

    For every arm index a in range(len(mu)), N_offline rewards are drawn from
    a normal distribution with mean mu[a] and standard deviation `noise_std`,
    using NumPy's global RNG.

    Parameters:
      mu         -- per-arm means, indexable by 0 .. len(mu) - 1.
      N_offline  -- number of samples drawn per arm.
      noise_std  -- standard deviation of the reward noise. Defaults to 1.0,
                    the value this function has always used.

    Returns (offline_data, mu), where offline_data is a list of
    (arm, reward) tuples grouped by arm in ascending arm order, and mu is the
    input passed back unchanged.
    """
    K = len(mu)
    offline_data = [
        (arm, np.random.normal(loc=mu[arm], scale=noise_std))
        for arm in range(K)
        for _ in range(N_offline)
    ]
    return offline_data, mu

def _mean_normalized_ratings(df, movie_ids):
    """Return a Series (indexed by movieId) of mean rating / 5.0 for the given movies."""
    subset = df[df['movieId'].isin(movie_ids)]
    # Normalise into a separate Series instead of assigning a column on the
    # filtered frame, which would trigger pandas' SettingWithCopyWarning.
    return (subset['rating'] / 5.0).groupby(subset['movieId']).mean()


def compute_real_means_offline_online(offline_ratings_file, online_ratings_file, top_K, min_ratings, seed):
    """
    Real-world environment: compute offline/online mean ratings for a random
    subset of frequently rated movies.

    Steps:
      1. Read the offline CSV (columns 'movieId' and 'rating' are used).
      2. Keep movies with at least `min_ratings` offline ratings and take the
         (up to) 100 most-rated of them as the candidate pool.
      3. Randomly select `top_K` movies from the pool. The selection is
         driven by `seed`, so the same seed and data give the same movies.
      4. mu_off[movie] = mean of rating / 5.0 over the offline ratings.
      5. If an online CSV is given, mu_on[movie] is computed the same way from
         it; a movie with no online ratings falls back to its offline mean.

    Parameters:
      offline_ratings_file -- path (or anything pandas.read_csv accepts).
      online_ratings_file  -- same, or None to skip the online statistics.
      top_K                -- number of movies to select.
      min_ratings          -- minimum offline rating count for eligibility.
      seed                 -- seed for the movie selection; NumPy's global RNG
                              is also seeded with it.

    Returns (mu_off, mu_on, selected_ids): two dicts keyed by movie ID
    (mu_on is None when no online file is given) and the list of selected
    movie IDs. The three values are printed when the online file is given.

    Raises ValueError if fewer than `top_K` movies are eligible.
    """
    pool_size = 100

    # NumPy's global RNG is still seeded, as before, for code that runs after
    # this call. The selection itself uses a dedicated stdlib generator:
    # random.sample is not affected by np.random.seed, so without this the
    # chosen movies were not reproducible.
    np.random.seed(seed)
    rng = random.Random(seed)

    df_off = pd.read_csv(offline_ratings_file)
    rating_counts = df_off.groupby('movieId').size()

    # Apply the min_ratings filter BEFORE building the pool; otherwise a movie
    # with too few ratings could be selected and end up with a NaN mean.
    eligible_counts = rating_counts[rating_counts >= min_ratings]
    candidates = list(eligible_counts.sort_values(ascending=False).head(pool_size).index)
    if top_K > len(candidates):
        raise ValueError(
            f"Cannot select {top_K} movies: only {len(candidates)} have at least "
            f"{min_ratings} ratings"
        )
    selected_movies = rng.sample(candidates, top_K)

    off_means = _mean_normalized_ratings(df_off, selected_movies)
    mu_off = {mid: off_means.loc[mid] for mid in selected_movies}
    selected_ids = list(mu_off.keys())

    if online_ratings_file is None:
        return mu_off, None, selected_ids

    df_on = pd.read_csv(online_ratings_file)
    on_means = _mean_normalized_ratings(df_on, selected_ids)
    # Movies that never appear online inherit their offline mean.
    mu_on = {mid: on_means.get(mid, mu_off[mid]) for mid in selected_ids}

    print("selected_ids:", selected_ids)
    print("mu_off:", mu_off)
    print("mu_on:", mu_on)
    return mu_off, mu_on, selected_ids

def generate_movielens_stochastic(ratings_file, select_id):
    """
    Real-world scheme 1: use the existing ratings directly as per-movie reward lists.

    Reads `ratings_file` as CSV (columns 'movieId' and 'rating' are used) and,
    for every movie in `select_id`, collects its ratings divided by 5.0 in
    file order.

    Returns a dict mapping each movie in `select_id` to a NumPy array of its
    normalized ratings; a movie with no ratings in the file maps to an empty
    array.
    """
    df = pd.read_csv(ratings_file)
    selected = df[df['movieId'].isin(select_id)]

    # Normalise into a separate Series rather than adding a column to the
    # filtered frame (which triggers pandas' SettingWithCopyWarning), and
    # split it per movie in one grouped pass instead of one scan per movie.
    normalized = selected['rating'] / 5.0
    by_movie = {
        movie: ratings.to_numpy()
        for movie, ratings in normalized.groupby(selected['movieId'])
    }
    no_ratings = normalized.iloc[:0].to_numpy()

    return {movie: by_movie.get(movie, no_ratings) for movie in select_id}

def compute_V_matrix(mu_off, mu_on, arm_ids=None):
    """
    Compute the pairwise matrix
        V(a_i, a_j) = 1/2 * (mu_off[i] - mu_off[j] + 1) - 1/2 * (mu_on[i] - mu_on[j] + 1).

    Parameters:
      mu_off  -- offline means: either a dict keyed by arm/movie ID or a
                 sequence/array indexed by arm position.
      mu_on   -- online means in the same form as mu_off, or None. With None
                 the online term is taken equal to the offline term, so the
                 result is all zeros.
      arm_ids -- only used when mu_off is a dict: the arm IDs, in the order
                 that defines the rows/columns of V. Defaults to the keys of
                 mu_off in insertion order.

    Returns a NumPy array of shape (K, K), where K is the number of arms.

    Raises KeyError if a dict is missing one of arm_ids, and IndexError if a
    sequence mu_on has fewer entries than mu_off.
    """
    alpha = 1  # scaling factor applied to every entry of V

    if isinstance(mu_off, dict):
        if arm_ids is None:
            arm_ids = list(mu_off.keys())
        off = np.array([mu_off[arm] for arm in arm_ids], dtype=float)
        on = None if mu_on is None else np.array([mu_on[arm] for arm in arm_ids], dtype=float)
    else:
        off = np.asarray(mu_off, dtype=float)
        on = None
        if mu_on is not None:
            # Only the first K online means are used, matching the arms of mu_off.
            on = np.asarray(mu_on, dtype=float)[:len(off)]
            if len(on) != len(off):
                raise IndexError("mu_on must provide a mean for every arm in mu_off")

    K = len(off)
    if on is None:
        # Online term equals the offline term, so every entry cancels to zero.
        return np.zeros((K, K))

    # Broadcasting a column against a row yields all pairwise differences.
    val_off = off[:, None] - off[None, :] + 1.0
    val_on = on[:, None] - on[None, :] + 1.0
    return alpha * (0.5 * val_off - 0.5 * val_on)
