import numpy as np
import pandas as pd
import random

def generate_synthetic_means(K, gap, bias, seed):
    """
    Generate the offline and online arm means for the synthetic environment.

    - mu_off: offline means, sorted ascending, so the best arm is the last
      entry and the second-best arm is exactly ``gap`` below it.
    - mu_on: online means = mu_off + s * bias, where s is an independent
      random sign (+1 or -1, each with probability 1/2) per arm, then
      clipped to [0, 1].

    Parameters:
    K: number of arms (must be at least 2)
    gap: gap between the best arm and the second best arm, in [0, 0.5]
    bias: magnitude of the online shift relative to the offline mean
    seed: random seed passed to np.random.seed (reseeds the global RNG)

    Returns:
    mu_off: offline mean array of length K
    mu_on: online mean array of length K

    Raises:
    ValueError: if K < 2 or gap is outside [0, 0.5].
    """
    if K < 2:
        raise ValueError(f"K must be at least 2 (best and second-best arm), got {K}")
    if not 0.0 <= gap <= 0.5:
        # The best mean is drawn from [0.5 + gap, 1.0]; a larger gap would
        # produce means above 1, and a negative gap breaks the ordering.
        raise ValueError(f"gap must be in [0, 0.5], got {gap}")

    np.random.seed(seed)
    # Keep the best mean in the upper half so it is not too small.
    mu_best = np.random.uniform(0.5 + gap, 1.0)
    mu_second_best = max(0.0, mu_best - gap)
    # Remaining arms are strictly worse than the second-best arm.
    mu_rest = np.random.uniform(0.0, mu_second_best, K - 2)
    mu_off = np.sort(np.append(mu_rest, [mu_second_best, mu_best]))

    # Independent random sign per arm: the online environment may be
    # shifted up or down by `bias`.
    random_choices = np.random.random(len(mu_off))
    signs = np.where(random_choices < 0.5, 1, -1)
    mu_on = np.clip(mu_off + signs * bias, 0.0, 1.0)

    print("mu_off:", mu_off)
    print("mu_on:", mu_on)
    return mu_off, mu_on

def generate_synthetic_relative_offline_data(mu, N_offline):
    """
    Synthetic environment: generate offline relative (pairwise duel) data.

    For each pair of arms (i, j) with i < j, draw N_offline outcomes
    ~ Bernoulli(p_ij) with the logistic preference probability

        p_ij = 1 / (1 + exp(mu[j] - mu[i])),

    so outcome 1 is more likely when mu[i] > mu[j].

    Parameters:
    mu: offline mean array
    N_offline: number of offline duels for each pair of arms

    Returns:
    offline_data: list of (i, j, outcome) tuples, pairs in row-major order
    mu: the input offline mean array, returned unchanged
    """
    K = len(mu)
    offline_data = []
    for i in range(K):
        for j in range(i + 1, K):
            # Depends only on the pair, so compute it once per pair.
            p_ij = 1.0 / (1.0 + np.exp(mu[j] - mu[i]))
            # One scalar draw per duel keeps the global RNG stream identical
            # to drawing them individually.
            for _ in range(N_offline):
                outcome = np.random.binomial(1, p_ij)
                offline_data.append((i, j, outcome))
    return offline_data, mu

def compute_real_means_offline_online(offline_ratings_file, online_ratings_file, top_K, min_ratings, top_n=100):
    """
    Real data environment: read offline/online ratings from CSV, keep movies
    with at least ``min_ratings`` offline ratings, take the ``top_n``
    (default 100) of those with the most ratings, randomly select ``top_K``
    of them, and compute the offline mean mu_off and online mean mu_on.

    Parameters:
    offline_ratings_file: offline rating CSV file path ('movieId', 'rating' columns)
    online_ratings_file: online rating CSV file path, or None to skip the online side
    top_K: number of arms selected
    min_ratings: minimum number of offline ratings required for a movie
    top_n: size of the most-rated candidate pool sampled from (default 100)

    Returns:
    mu_off: offline mean dictionary (movieId -> mean rating, raw scale)
    mu_on: online mean dictionary (None if online_ratings_file is None); a
           movie with no online ratings falls back to its offline mean
    selected_ids: selected movie ID list

    Raises:
    ValueError: raised by random.sample if fewer than top_K movies are eligible.

    Note: the selection uses Python's ``random`` module, so reproducibility
    requires ``random.seed`` (``np.random.seed`` has no effect on it).
    """
    df_off = pd.read_csv(offline_ratings_file)
    movie_counts_off = df_off.groupby('movieId').size()
    # Only movies meeting min_ratings are eligible; sampling from all movies
    # could pick one whose rows are filtered out, giving a NaN mean. Sorting
    # before filtering keeps the candidate order (and thus the seeded sample)
    # unchanged whenever no top movie is excluded.
    sorted_counts = movie_counts_off.sort_values(ascending=False)
    top_movies = sorted_counts[sorted_counts >= min_ratings].head(top_n).index
    selected_movies = random.sample(list(top_movies), top_K)

    # Read 'rating' directly rather than writing a copy column into a
    # filtered slice (which triggers pandas' SettingWithCopyWarning).
    ratings_off = df_off.loc[df_off['movieId'].isin(selected_movies), ['movieId', 'rating']]

    mu_off = {}
    for mid in selected_movies:
        arr = ratings_off.loc[ratings_off['movieId'] == mid, 'rating'].to_numpy()
        mu_off[mid] = np.mean(arr)
    selected_ids = list(mu_off.keys())

    if online_ratings_file is None:
        return mu_off, None, selected_ids

    df_on = pd.read_csv(online_ratings_file)
    ratings_on = df_on.loc[df_on['movieId'].isin(selected_ids), ['movieId', 'rating']]

    mu_on = {}
    for mid in selected_ids:
        arr_on = ratings_on.loc[ratings_on['movieId'] == mid, 'rating'].to_numpy()
        # Fall back to the offline mean when the online file lacks the movie.
        mu_on[mid] = np.mean(arr_on) if len(arr_on) > 0 else mu_off[mid]

    print("selected_ids:", selected_ids)
    print("mu_off:", mu_off)
    print("mu_on:", mu_on)
    return mu_off, mu_on, selected_ids

def generate_movielens_relative_offline_data(ratings_file, select_id, N_offline_per_pair, feedback_mode, sample_size=10):
    """
    Real data environment: generate offline relative (pairwise duel) data
    from MovieLens-style ratings.

    For every pair of selected movies (i_idx, j_idx) with i_idx < j_idx,
    N_offline_per_pair outcomes are generated:
    - "bernoulli": outcome ~ Bernoulli(1 / (1 + exp(mu_off[j] - mu_off[i]))).
    - "data": draw sample_size ratings for each movie (without replacement
      when enough ratings exist, otherwise with replacement). The outcome is
      1 if movie i's sample average is higher, and 0 if it is lower. Ties are
      broken uniformly at random via Python's ``random`` module. If either
      movie has no ratings, outcome 1 is recorded for every duel of the pair.

    Parameters:
    ratings_file: ratings CSV file path ('movieId', 'rating' columns)
    select_id: list of selected movie IDs; outputs use positions in this list
    N_offline_per_pair: number of offline duels for each pair of movies
    feedback_mode: "bernoulli" or "data"
    sample_size: ratings sampled per movie per duel in "data" mode (default 10)

    Returns:
    offline_data: list of (i_idx, j_idx, outcome) tuples
    mu_off: offline mean dictionary (movieId -> mean rating, 0.5 if unrated)

    Raises:
    ValueError: if feedback_mode is not "bernoulli" or "data".
    """
    # Validate before doing any I/O.
    if feedback_mode not in ("bernoulli", "data"):
        raise ValueError("Invalid feedback_mode: choose 'bernoulli' or 'data'")

    df = pd.read_csv(ratings_file)
    # Read 'rating' directly rather than writing a copy column into a
    # filtered slice (which triggers pandas' SettingWithCopyWarning).
    df_selected = df.loc[df['movieId'].isin(select_id), ['movieId', 'rating']]

    mu_off = {}
    movie_ratings = {}
    for movie in select_id:
        ratings = df_selected.loc[df_selected['movieId'] == movie, 'rating'].to_numpy()
        movie_ratings[movie] = ratings
        mu_off[movie] = np.mean(ratings) if len(ratings) > 0 else 0.5

    offline_data = []
    K = len(select_id)

    if feedback_mode == "bernoulli":
        for i_idx in range(K):
            for j_idx in range(i_idx + 1, K):
                i_id = select_id[i_idx]
                j_id = select_id[j_idx]
                p_ij = 1.0 / (1.0 + np.exp(mu_off[j_id] - mu_off[i_id]))
                for _ in range(N_offline_per_pair):
                    outcome = np.random.binomial(1, p_ij)
                    offline_data.append((i_idx, j_idx, outcome))
        return offline_data, mu_off

    # feedback_mode == "data"
    for i_idx in range(K):
        for j_idx in range(i_idx + 1, K):
            i_id = select_id[i_idx]
            j_id = select_id[j_idx]
            ratings_i = movie_ratings[i_id]
            ratings_j = movie_ratings[j_id]

            # Every selected id is a key of movie_ratings, so "missing data"
            # means an empty ratings array; np.random.choice cannot sample
            # from one. Both movies are treated as the default 0.5, and the
            # tie (0.5 >= 0.5) is recorded as a win for movie i.
            if len(ratings_i) == 0 or len(ratings_j) == 0:
                print(f"data mode: i_id: {i_id}, j_id: {j_id}, missing data, using default 0.5")
                offline_data.extend([(i_idx, j_idx, 1)] * N_offline_per_pair)
                continue

            for _ in range(N_offline_per_pair):
                # Sample with replacement only when there are too few ratings.
                sampled_ratings_i = np.random.choice(
                    ratings_i, size=sample_size, replace=len(ratings_i) < sample_size)
                sampled_ratings_j = np.random.choice(
                    ratings_j, size=sample_size, replace=len(ratings_j) < sample_size)

                avg_rating_i = np.mean(sampled_ratings_i)
                avg_rating_j = np.mean(sampled_ratings_j)
                if avg_rating_i == avg_rating_j:
                    outcome = random.choice([0, 1])
                else:
                    outcome = 1 if avg_rating_i > avg_rating_j else 0
                offline_data.append((i_idx, j_idx, outcome))

    return offline_data, mu_off

def generate_movielens_stochastic_online_data(ratings_file, select_id):
    """
    Real data environment: load per-movie rating samples used as stochastic
    online rewards.

    Ratings are taken as-is from the 'rating' column; no rescaling is applied.

    Parameters:
    ratings_file: CSV file path ('movieId', 'rating' columns)
    select_id: selected movieId list

    Returns:
    movie_rewards: dictionary of movie ID to array of its ratings (file order;
                   empty if the movie has no ratings)
    mu_on: online mean dictionary (mean of all ratings; 0.5 if the movie has none)
    """
    df = pd.read_csv(ratings_file)
    # Read 'rating' directly rather than writing a copy column into a
    # filtered slice (which triggers pandas' SettingWithCopyWarning).
    df_selected = df.loc[df['movieId'].isin(select_id), ['movieId', 'rating']]

    movie_rewards = {}
    mu_on = {}
    for movie in select_id:
        ratings = df_selected.loc[df_selected['movieId'] == movie, 'rating'].to_numpy()
        movie_rewards[movie] = ratings
        # 0.5 is the fallback mean for movies absent from the file.
        mu_on[movie] = np.mean(ratings) if len(ratings) > 0 else 0.5

    return movie_rewards, mu_on

def compute_V_matrix(mu_off, mu_on, arm_ids=None, alpha=1):
    """
    Calculate the V matrix according to the formula:
    V(a_i, a_j) = alpha * (1/2 * (mu_off[i] - mu_off[j] + 1) - 1/2 * (mu_on[i] - mu_on[j] + 1)).

    Parameters:
    mu_off: offline mean (array-like or dictionary)
    mu_on: online mean (same kind as mu_off); if None, mu_off is used and V is all zeros
    arm_ids: list of arm IDs fixing the row/column order (dictionary input only;
             defaults to the key order of mu_off)
    alpha: scale factor applied to the whole matrix (default 1)

    Returns:
    V: numpy float array of shape (K, K)
    """
    if isinstance(mu_off, dict):
        if arm_ids is None:
            arm_ids = list(mu_off.keys())
        off = np.array([mu_off[a] for a in arm_ids], dtype=float)
        on = off if mu_on is None else np.array([mu_on[a] for a in arm_ids], dtype=float)
    else:
        off = np.asarray(mu_off, dtype=float)
        # Honour the documented "None means use mu_off" contract for arrays
        # too (previously this branch raised TypeError).
        on = off if mu_on is None else np.asarray(mu_on, dtype=float)

    # Same elementwise operations as the scalar formula, broadcast over (i, j).
    val_off = off[:, None] - off[None, :] + 1.0
    val_on = on[:, None] - on[None, :] + 1.0
    return alpha * (0.5 * val_off - 0.5 * val_on)