import os
import random
import shutil

import numpy as np
import pandas as pd
from pandarallel import pandarallel
from scipy.special import expit  # Sigmoid function
from scipy.stats import gamma
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

# Configure pandarallel once for the whole script; progress bars are disabled.
pandarallel.initialize(progress_bar=False)

# Single seed shared by every random number generator this file draws from.
SEED = 42
np.random.seed(SEED)  # NumPy global state; scipy.stats ``rvs`` draws use it by default
random.seed(SEED)  # stdlib generator behind the ``random.choice`` calls in this file


def generate_user_weights(num_features_, num_users_per_group_, c_, sigma_):
    """Draw a weight vector for every synthetic user.

    One group of users is created per feature. Inside a group each user gets
    a weight for the group's own feature drawn from Gamma(shape=c_,
    scale=sigma_) and weights for all other features drawn from
    Gamma(shape=0.1, scale=sigma_).

    Args:
        num_features_: Number of features (and therefore of user groups).
        num_users_per_group_: Number of users generated for each feature.
        c_: Gamma shape parameter for the preferred feature.
        sigma_: Gamma scale parameter shared by all draws.

    Returns:
        A list of ``(preferred_feature, weight_vector)`` tuples ordered by
        feature, where ``weight_vector`` has length ``num_features_``.
    """
    users_ = []
    for preferred in tqdm(range(num_features_), desc="Generating user weights"):
        for _ in range(num_users_per_group_):
            # The strong weight is drawn before the weak ones so that a fixed
            # seed keeps producing the same sequence of users.
            strong_weight = gamma.rvs(a=c_, scale=sigma_)
            weak_weights = gamma.rvs(a=0.1, scale=sigma_, size=num_features_ - 1)
            # Splice the strong weight into its slot; the weak weights keep
            # their relative order around it.
            users_.append((preferred, np.insert(weak_weights, preferred, strong_weight)))
    return users_


def generate_data_vectors(num_features_, num_data_points_, c_, sigma_):
    """Draw synthetic data vectors with exactly one positive ("active") entry.

    For every data point a feature index is picked uniformly at random. That
    entry receives a Gamma(shape=c_, scale=sigma_) draw; every other entry
    receives the negation of an independent draw from the same distribution,
    so the active entry is the only positive one.

    Args:
        num_features_: Length of each data vector.
        num_data_points_: Number of vectors to generate.
        c_: Gamma shape parameter.
        sigma_: Gamma scale parameter.

    Returns:
        A list of ``(active_feature, data_vector)`` tuples.
    """
    data_ = []
    for _ in tqdm(range(num_data_points_), desc="Generating data vectors"):
        # Draw order (index, active value, inactive values) is kept fixed so
        # that a given seed reproduces the same data set.
        hot_index = np.random.randint(num_features_)
        hot_value = gamma.rvs(a=c_, scale=sigma_)
        cold_values = -gamma.rvs(a=c_, scale=sigma_, size=num_features_ - 1)
        # Splice the positive value into its slot among the negative ones.
        data_.append((hot_index, np.insert(cold_values, hot_index, hot_value)))
    return data_


def compute_predictions(users_, data_):
    """Lazily score every user against every data vector.

    Args:
        users_: Sequence of ``(feature, weight_vector)`` pairs.
        data_: Iterable of ``(feature, data_vector)`` pairs.

    Yields:
        ``(data_feature, scores)`` for each data vector, where ``scores`` is a
        list of ``(user_feature, sigmoid(weight_vector . data_vector))``
        pairs in the same order as ``users_``.
    """
    user_features = [pair[0] for pair in users_]
    # Stack the weight vectors row-wise so one matrix-vector product scores
    # all users at once.
    weight_matrix = np.array([pair[1] for pair in users_])

    for data_feature, data_vector in tqdm(data_, desc="Processing predictions"):
        scores = expit(weight_matrix.dot(data_vector))
        yield data_feature, list(zip(user_features, scores))

def evaluate_performance(results_):
    """Compute the fraction of (user, data vector) scores on the expected side of 0.5.

    A score counts as correct when it is strictly above 0.5 for a user whose
    feature equals the data vector's feature, or strictly below 0.5
    otherwise. A score of exactly 0.5 is never counted as correct.

    Args:
        results_: Iterable of ``(data_feature, predictions)`` where
            ``predictions`` is an iterable of ``(user_feature, score)`` pairs
            (the shape yielded by ``compute_predictions``).

    Returns:
        The accuracy as a float in ``[0, 1]``; ``0.0`` when ``results_``
        contains no scores at all (instead of raising ``ZeroDivisionError``).
    """
    correct = 0
    total = 0
    for data_feature, predictions in tqdm(results_, desc="Evaluating performance"):
        for user_feature, prediction in predictions:
            if user_feature == data_feature:
                # Matching features: the score should be on the high side.
                is_correct = prediction > 0.5
            else:
                # Non-matching features: the score should be on the low side.
                is_correct = prediction < 0.5
            if is_correct:
                correct += 1
            total += 1
    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return correct / total


def assign_data_points(data_points, od_pairs_, min_size, max_size):
    """Deal shuffled data points out to OD pairs in randomly sized chunks.

    Note that ``data_points`` is shuffled in place. Each OD pair, in order,
    receives between ``min_size`` and ``max_size`` (inclusive) consecutive
    points from the shuffled list; the last chunk is truncated when the list
    runs out, and any OD pairs after that keep an empty assignment.

    Args:
        data_points: Mutable list of ``(vector_id, feature, data_vector)``
            tuples; only elements 0 and 2 are kept in the output.
        od_pairs_: Number of OD pairs (keys ``0 .. od_pairs_ - 1``).
        min_size: Smallest chunk size drawn.
        max_size: Largest chunk size drawn (inclusive).

    Returns:
        Dict mapping OD pair id to a list of ``(vector_id, data_vector)``.
    """
    np.random.shuffle(data_points)
    total_points = len(data_points)

    assignments = {od_id: [] for od_id in range(od_pairs_)}
    cursor = 0  # index of the first point that has not been handed out yet

    for od_id in tqdm(range(od_pairs_), desc="Assigning data points"):
        requested = np.random.randint(min_size, max_size + 1)
        # Never hand out more points than are left.
        take = min(requested, total_points - cursor)
        chunk = data_points[cursor: cursor + take]
        assignments[od_id] = [(point[0], point[2]) for point in chunk]
        cursor += take

        if cursor >= total_points:
            break

    return assignments


def assign_user_to_od(users_, max_size, min_size, od_pairs_):
    """Give every user a random set of distinct OD pair ids.

    Args:
        users_: Iterable of records whose first element is the user id.
        max_size: Exclusive upper bound on the number of OD pairs per user
            (``np.random.randint`` excludes its upper bound).
        min_size: Inclusive lower bound on the number of OD pairs per user.
        od_pairs_: OD pair ids are sampled without replacement from
            ``range(od_pairs_)``.

    Returns:
        Dict mapping user id to the list of OD pair ids drawn for that user.
    """
    user_od_dict_ = {}
    for record in tqdm(users_, desc="Assigning user ODs"):
        user_key = record[0]
        how_many = np.random.randint(min_size, max_size)
        sampled_ods = np.random.choice(range(od_pairs_), size=how_many, replace=False)
        user_od_dict_[user_key] = list(sampled_ods)
    return user_od_dict_


def generate_interaction_dataframe(user_od_dict_, od_assignment_, users_):
    """Sample one labelled item per (user, OD pair).

    For each OD pair assigned to a user, the pair's items are split by the
    sign of the data-vector entry at the user's preferred feature: positive
    entries are "matching" (feedback 1), non-positive ones "non-matching"
    (feedback 0). When both groups are non-empty one of them is picked with
    equal probability; when only one group exists it is used; when the OD
    pair has no items it is skipped. A single item is then drawn uniformly
    from the selected group.

    Args:
        user_od_dict_: Dict mapping user id to a list of OD pair ids.
        od_assignment_: Dict mapping OD pair id to a list of
            ``(vector_id, data_vector)`` tuples.
        users_: Sequence indexed by user id whose element 1 is the user's
            preferred feature index.

    Returns:
        A list of ``[user_id, OD_pair_id, vector_id, feedback]`` rows.
    """
    user_item_interaction_ = []

    for user_id, od_list in tqdm(user_od_dict_.items(), desc="Generating user item interactions"):
        preferred = users_[user_id][1]

        for od in od_list:
            # Split this OD pair's items by the sign of the preferred entry.
            hits, misses = [], []
            for item in od_assignment_[od]:
                value = item[1][preferred]
                if value > 0:
                    hits.append(item)
                elif value <= 0:
                    misses.append(item)

            # Decide which group to sample from; the coin is only tossed when
            # both groups are available.
            if hits and misses:
                take_hit = np.random.rand() <= 0.5
            elif hits:
                take_hit = True
            elif misses:
                take_hit = False
            else:
                continue  # nothing assigned to this OD pair

            chosen = random.choice(hits if take_hit else misses)
            user_item_interaction_.append([user_id, od, chosen[0], 1 if take_hit else 0])

    return user_item_interaction_


def get_prediction(row_):
    """Predict the binary feedback for one interaction row.

    Looks up the user's weight vector and the item's data vector in the
    module-level ``df_users`` / ``df_data`` frames (columns ``v_1..v_n`` and
    ``x_1..x_n`` with ``n = num_features``) and thresholds the sigmoid of
    their dot product at 0.5.

    Args:
        row_: Mapping with ``'user_id'`` and ``'vector_id'`` entries.

    Returns:
        ``1`` if the sigmoid score is strictly above 0.5, else ``0``. Also
        ``0`` when either id matches more than one row.
    """
    feature_numbers = range(1, num_features + 1)
    weight_cols = ['v_{}'.format(i) for i in feature_numbers]
    value_cols = ['x_{}'.format(i) for i in feature_numbers]

    user_rows = df_users.loc[df_users.user_id == row_['user_id'], weight_cols].values
    data_rows = df_data.loc[df_data.vector_id == row_['vector_id'], value_cols].values

    # Ambiguous ids cannot be scored; report the negative class.
    if len(user_rows) > 1 or len(data_rows) > 1:
        return 0

    score = expit(np.dot(user_rows[0], data_rows[0]))
    return int(score > 0.5)


def validate_vector_assignment(row_):
    """Check that an interaction's vector really belongs to its OD pair.

    Reads the module-level ``df_od_map`` frame, whose ``OD_pair_id`` column
    identifies the OD pair and whose remaining columns hold the assigned
    vector ids (NaN-padded).

    Args:
        row_: Mapping with ``'OD_pair_id'`` and ``'vector_id'`` entries.

    Returns:
        ``True`` when ``vector_id`` is among the vector ids stored for that
        OD pair, ``False`` otherwise.
    """
    source_od = row_['OD_pair_id']
    source_vector_id = row_['vector_id']
    od_rows = df_od_map[df_od_map.OD_pair_id == source_od]
    # Drop the key column: leaving it in would make the OD pair's own id look
    # like one of its assigned vector ids and produce false positives.
    candidate_ids = od_rows.drop(columns='OD_pair_id').values.astype(float)
    # NaN marks padding for OD pairs that received fewer vectors.
    assigned_ids = candidate_ids[~np.isnan(candidate_ids)].astype(int)
    return source_vector_id in assigned_ids


def naive_model(x_train_, x_test_, y_train_, y_test_):
    """Fit a class-balanced logistic regression baseline and score it.

    Intended for binary labels: the ROC AUC is computed from the predicted
    probability of the second class in ``model_.classes_``.

    Args:
        x_train_: Training features.
        x_test_: Test features.
        y_train_: Training labels.
        y_test_: Test labels.

    Returns:
        Tuple ``(model_score, accuracy, roc_auc, confusion_matrix)`` where
        ``model_score`` is the estimator's own ``score`` on the test set.
    """
    model_ = LogisticRegression(class_weight='balanced', solver='lbfgs')
    model_.fit(x_train_, y_train_)

    y_pred_ = model_.predict(x_test_)
    # ROC AUC needs a continuous ranking score; feeding it hard 0/1 labels
    # collapses the curve to a single operating point.
    y_score_ = model_.predict_proba(x_test_)[:, 1]

    model_score = model_.score(x_test_, y_test_)
    accuracy_ = accuracy_score(y_test_, y_pred_)
    roc_auc_ = roc_auc_score(y_test_, y_score_)
    cf_mat_ = confusion_matrix(y_test_, y_pred_)

    return model_score, accuracy_, roc_auc_, cf_mat_


def generate_user_preference_groups(k):
    """Build ``k`` distinct High/Low preference patterns of length ``k``.

    Pattern number ``g`` (for ``g`` in ``0 .. k-1``) is the ``k``-bit binary
    representation of ``g``, most significant bit first, with ``1`` written
    as ``'H'`` and ``0`` as ``'L'``.

    Args:
        k: Number of patterns and length of each pattern.

    Returns:
        A list of ``k`` lists, each holding ``k`` ``'H'``/``'L'`` strings.
    """
    # Bit positions from the most significant (k - 1) down to 0.
    shifts = range(k - 1, -1, -1)
    return [
        ['H' if (group_id >> shift) & 1 else 'L' for shift in shifts]
        for group_id in range(k)
    ]


def data_split_gen_test(result_df_):
    """Split interactions into train/test halves, per user and per label.

    For every user, half of the positive rows (rounded down) and half of the
    negative rows (rounded down) are sampled without replacement into the
    training set; the remaining rows of that user go to the test set. Users
    with fewer than two positive or fewer than two negative rows are dropped
    from both sets and counted in the printed summary.

    Args:
        result_df_: DataFrame with at least ``user_id`` and ``feedback``
            columns (``feedback`` is compared against 1 and 0).

    Returns:
        ``(train_df, test_df)`` with the same columns as ``result_df_``.
    """
    train_rows, test_rows = [], []
    skipped_users = 0

    for uid in tqdm(result_df_.user_id.unique(), desc="Generating data splits"):
        user_df = result_df_[result_df_.user_id == uid]
        pos_index = user_df[user_df.feedback == 1].index.values
        neg_index = user_df[user_df.feedback == 0].index.values

        # Both labels need at least one row on each side of the split.
        if len(pos_index) < 2 or len(neg_index) < 2:
            skipped_users += 1
            continue

        # Positives are sampled before negatives to keep the draw order fixed.
        pos_train = np.random.choice(pos_index.tolist(), size=len(pos_index) // 2, replace=False)
        pos_test = np.setdiff1d(pos_index, pos_train)
        neg_train = np.random.choice(neg_index.tolist(), size=len(neg_index) // 2, replace=False)
        neg_test = np.setdiff1d(neg_index, neg_train)

        train_part = user_df.loc[np.concatenate((pos_train, neg_train))]
        test_part = user_df.loc[np.concatenate((pos_test, neg_test))]
        train_rows.extend(train_part.values.tolist())
        test_rows.extend(test_part.values.tolist())

    _df_train_ = pd.DataFrame(train_rows, columns=result_df_.columns)
    _df_test_ = pd.DataFrame(test_rows, columns=result_df_.columns)
    print('{}/{} users missing'.format(skipped_users, result_df_.user_id.nunique()))
    return _df_train_, _df_test_


if __name__ == '__main__':
    # NOTE: everything below stays at module level on purpose. get_prediction
    # and validate_vector_assignment read num_features, df_users, df_data and
    # df_od_map as module globals.

    # ---- Configuration -------------------------------------------------
    c = 5.0  # Gamma shape parameter
    sigma = 2.0  # Gamma scale parameter
    num_features = 5
    num_users_per_group = 100
    od_pairs = 1000
    min_assignment, max_assignment = 6, 8  # Data vectors per OD pair (inclusive bounds)
    num_data_points = max_assignment * od_pairs  # Derived total data points
    save_files = True
    # OD pairs per user; the upper bound is exclusive in assign_user_to_od.
    min_interaction, max_interaction = 100, 101

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if max_interaction > od_pairs:
        raise ValueError('Decrease max_interaction OR increase od_pairs')

    # ---- Output folder -------------------------------------------------
    folder_name = 'synthetic_data'
    if os.path.isdir(folder_name):
        choice = input('Folder {} already exists: Remove?'.format(folder_name))
        # input() always returns a string, so only string answers are listed.
        if choice.strip().lower() in ('y', 'yes', '1'):
            shutil.rmtree(folder_name)
            os.mkdir(folder_name)
        else:
            raise SystemExit(1)
    else:
        os.mkdir(folder_name)

    # ---- Raw users and data vectors --------------------------------------
    users = generate_user_weights(num_features, num_users_per_group, c, sigma)
    data = generate_data_vectors(num_features, num_data_points, c, sigma)

    # ---- Sanity check: how separable are the feature groups? -------------
    # Group the user weight vectors by preferred feature.
    u_dict = {}
    np.random.seed(SEED)
    for idx, uv in tqdm(users, desc='Users'):
        u_dict.setdefault(idx, []).append(uv)

    count = 0
    for idx, dp in tqdm(data, desc='Data'):
        neighbour_pred = []
        for group, group_vectors in u_dict.items():
            # Score the data vector against one random user of each group.
            sampled_user = group_vectors[np.random.choice(len(group_vectors))]
            prediction = expit(np.dot(sampled_user, dp))
            if group == idx:
                true_prediction = prediction
            else:
                neighbour_pred.append(prediction)
        if (true_prediction > 0.5 and any(x > 0.5 for x in neighbour_pred)) or (
                true_prediction < 0.5 and any(x < 0.5 for x in neighbour_pred)):
            count += 1

    print(round(count * 100 / len(data), 2), '% errors')

    # Attach sequential ids: (id, feature, vector).
    u_users = [(i, x[0], x[1]) for i, x in enumerate(users)]
    u_data = [(i, x[0], x[1]) for i, x in enumerate(data)]

    # ---- Interactions ----------------------------------------------------
    od_assignment = assign_data_points(u_data, od_pairs, min_assignment, max_assignment)
    user_od_dict = assign_user_to_od(u_users, max_interaction, min_interaction, od_pairs)
    user_item_interaction = generate_interaction_dataframe(user_od_dict, od_assignment, u_users)

    # ---- Tabular views ---------------------------------------------------
    user_weight_cols = ['v_{}'.format(i) for i in range(1, num_features + 1)]
    data_value_cols = ['x_{}'.format(i) for i in range(1, num_features + 1)]

    flat_users = [[x[0], x[1], *x[2]] for x in u_users]
    df_users = pd.DataFrame(flat_users, columns=['user_id', 'FeatureGroup_id'] + user_weight_cols)
    flat_data = [[x[0], x[1], *x[2]] for x in u_data]
    df_data = pd.DataFrame(flat_data, columns=['vector_id', 'FeatureGroup_id'] + data_value_cols)

    od_assignment_vector_map = []
    for od_id, od_items in tqdm(od_assignment.items(), desc='OD Assignment'):
        vector_ids = [x[0] for x in od_items]
        # Pad with NaN so that every OD pair row has max_assignment slots.
        padding = [np.nan] * (max_assignment - len(vector_ids))
        od_assignment_vector_map.append([od_id, *vector_ids, *padding])
    df_od_map = pd.DataFrame(od_assignment_vector_map,
                             columns=['OD_pair_id'] + ['v_id_{}'.format(i) for i in range(max_assignment)])
    df_interaction_map = pd.DataFrame(user_item_interaction, columns=['user_id', 'OD_pair_id', 'vector_id', 'feedback'])

    # ---- Validation ------------------------------------------------------
    # ``False in series`` tests the index labels of a Series, not its values,
    # so the outcome is checked with ``.all()`` instead.
    result = df_interaction_map.parallel_apply(lambda x: get_prediction(x) == x['feedback'], axis=1)
    if not result.all():
        print('Validation 1 Failed: Check User to Vector Mapping')
    else:
        print('Test 1 Successful!')
    result = df_interaction_map.parallel_apply(validate_vector_assignment, axis=1)
    if not result.all():
        print('Validation 2 Failed: Check Vector to OD Mapping')
    else:
        print('Test 2 Successful!')

    # ---- Keep only users that survive the train/test split ---------------
    # NOTE(review): the original referenced an undefined ``df_train`` here.
    # data_split_gen_test is the only producer of a training frame in this
    # file, so it is used to build it; confirm this is the intended source.
    # The split itself is not written to disk, matching the original outputs.
    df_train, df_test = data_split_gen_test(df_interaction_map)
    valid_uids = df_train.user_id.unique()

    # Step 1: Filter the DataFrames to only include rows with valid user IDs.
    # ``.copy()`` makes the filtered frames independent so the column
    # assignments below do not trigger SettingWithCopyWarning.
    df_users = df_users[df_users['user_id'].isin(valid_uids)].copy()
    df_interaction_map = df_interaction_map[df_interaction_map['user_id'].isin(valid_uids)].copy()

    # Step 2: Create a mapping from the old user_id to new sequential user_id.
    # Sorting ensures the new user_ids are assigned in order (0, 1, 2, ...).
    sorted_valid_uids = sorted(valid_uids)
    user_id_map = {old_id: new_id for new_id, old_id in enumerate(sorted_valid_uids)}

    # Step 3: Update the user_id column in each DataFrame that has this column.
    df_users['user_id'] = df_users['user_id'].map(user_id_map)
    df_interaction_map['user_id'] = df_interaction_map['user_id'].map(user_id_map)

    # df_od_map and df_data are built above without a 'user_id' column; the
    # guards stay in case such a column is added later.
    if 'user_id' in df_od_map.columns:
        df_od_map['user_id'] = df_od_map['user_id'].map(user_id_map)
    if 'user_id' in df_data.columns:
        df_data['user_id'] = df_data['user_id'].map(user_id_map)

    if save_files:
        # Step 4: Save the updated DataFrames to CSV.
        df_users.to_csv('./{}/df_users.csv'.format(folder_name), index=False)
        df_interaction_map.to_csv('./{}/df_interaction_map.csv'.format(folder_name), index=False)
        df_od_map.to_csv('./{}/df_od_map.csv'.format(folder_name), index=False)
        df_data.to_csv('./{}/df_data.csv'.format(folder_name), index=False)
