from scipy.special import expit
from sklearn.metrics import roc_auc_score, accuracy_score
from tqdm import tqdm
from config import *
from utils import real_split, synthetic_split, generator, test_generator

def _pairwise_distances(anchors, pool, epsilon):
    """Return the ``[B, N]`` Euclidean distances between ``anchors`` ``[B, D]`` and ``pool`` ``[N, D]``.

    The squared distance is clamped at ``epsilon`` before the square root so the
    gradient stays finite for coinciding vectors; consequently the smallest
    distance returned is ``sqrt(epsilon)``, never 0.
    """
    differences = tf.expand_dims(anchors, axis=1) - tf.expand_dims(pool, axis=0)
    squared = tf.reduce_sum(tf.square(differences), axis=-1)
    return tf.sqrt(tf.maximum(squared, epsilon))


def _rank_bce_loss(point_pred, query_pred, query_mask, labels, loss_reg, epsilon):
    """Return the per-example loss ``loss_reg * rank_loss + bce_loss`` (shape ``[B]``).

    Args:
        point_pred: ``[B]`` predicted probability of the labelled item.
        query_pred: ``[B, Q]`` predicted probabilities of the query items; entries
            where ``query_mask`` is False are padding and contribute nothing.
        query_mask: ``[B, Q]`` bool, True for real (non-padding) query items.
        labels: ``[B]`` float label of the labelled item.
        loss_reg: weight of the rank term.
        epsilon: added inside the logarithms to avoid ``log(0)``.

    The rank term sums ``softplus(-(2 * label - 1) * (point_pred - query_pred))``
    over the real queries. The BCE term scores the labelled item against its label
    and every real query item against a target of 0.
    """
    # Zero the padding slots (they are also masked out of both sums below).
    query_pred = tf.where(query_mask, query_pred, tf.zeros_like(query_pred))
    labels_col = tf.expand_dims(labels, axis=1)
    point_col = tf.expand_dims(point_pred, axis=1)

    sign = 2.0 * labels_col - 1.0
    rank_terms = tf.nn.softplus(-sign * (point_col - query_pred))
    rank_loss = tf.reduce_sum(tf.where(query_mask, rank_terms, tf.zeros_like(rank_terms)), axis=1)

    y_pred = tf.concat([point_col, query_pred], axis=1)
    y_true = tf.concat([labels_col, tf.zeros_like(query_pred)], axis=1)
    keep = tf.concat([tf.ones((tf.shape(point_pred)[0], 1), dtype=tf.bool), query_mask], axis=1)
    bce = -(y_true * tf.math.log(y_pred + epsilon) + (1 - y_true) * tf.math.log(1 - y_pred + epsilon))
    bce_loss = tf.reduce_sum(tf.where(keep, bce, tf.zeros_like(bce)), axis=1)
    return loss_reg * rank_loss + bce_loss


def _neighbour_predictions(anchors, pool, u_id, batch_features, dense_queries, radius, epsilon,
                           stop_query_grad):
    """Inverse-distance-weighted predictions from each anchor's neighbours.

    Args:
        anchors: ``[B, D]`` vectors of the batch users.
        pool: ``[N, D]`` user matrix (or a stop-gradient view of it); neighbours
            are rows of ``pool`` strictly closer than ``radius``, excluding the
            anchor's own row ``u_id``.
        u_id: ``[B]`` int32 row index of each anchor in ``pool``.
        batch_features: ``[B, D]`` vector of the labelled item.
        dense_queries: ``[B, Q, D]`` zero-padded query item vectors.
        radius: neighbourhood radius.
        epsilon: numerical-stability constant.
        stop_query_grad: if True, block gradients through the neighbour vectors
            and weights used for the query predictions.

    Returns:
        ``(point_pred [B], query_pred [B, Q], neighbour_ids, neighbour_distances)``;
        the last two are ragged ``[B, None]``. Rows without neighbours give 0.
    """
    distances = _pairwise_distances(anchors, pool, epsilon)
    all_ids = tf.tile(tf.range(tf.shape(pool)[0])[None, :], [tf.shape(u_id)[0], 1])
    # Exclude each user from its own neighbourhood by index. A ``distances > epsilon``
    # test cannot do this: the clamp above keeps every distance >= sqrt(epsilon),
    # so the user's own row would always pass and dominate the inverse-distance weights.
    not_self = tf.not_equal(all_ids, tf.expand_dims(u_id, axis=1))
    in_radius = (distances < radius) & not_self
    neighbour_ids = tf.ragged.boolean_mask(all_ids, in_radius)
    neighbour_vectors = tf.ragged.map_flat_values(lambda idx: tf.gather(pool, idx), neighbour_ids)
    neighbour_distances = tf.ragged.boolean_mask(distances, in_radius)

    # Per-row normalised inverse-distance weights (ragged).
    inv_dist = 1.0 / (neighbour_distances + epsilon)
    weights = inv_dist / (tf.reduce_sum(inv_dist, axis=1, keepdims=True) + epsilon)

    # Zero-pad to dense [B, K, D] / [B, K]; padded slots get weight 0.
    dense_vectors = neighbour_vectors.to_tensor()
    dense_weights = weights.to_tensor()
    valid = tf.sequence_mask(neighbour_vectors.row_lengths(), tf.shape(dense_vectors)[1])

    point_scores = tf.math.sigmoid(tf.einsum('bnd,bd->bn', dense_vectors, batch_features)) * dense_weights
    point_pred = tf.reduce_sum(tf.where(valid, point_scores, tf.zeros_like(point_scores)), axis=1)

    if stop_query_grad:
        dense_vectors = tf.stop_gradient(dense_vectors)
        dense_weights = tf.stop_gradient(dense_weights)
    query_scores = tf.math.sigmoid(tf.einsum('bnd,bqd->bnq', dense_vectors, dense_queries))
    query_pred = tf.reduce_sum(query_scores * tf.expand_dims(dense_weights, axis=-1), axis=1)
    return point_pred, query_pred, neighbour_ids, neighbour_distances


def _blended_loss(anchors, pool, u_id, batch_features, batch_labels, dense_queries, query_mask,
                  radius, alpha, loss_reg, epsilon, stop_query_grad):
    """Per-example loss of predictions blended between each user and its neighbours.

    Every prediction is ``(1 - alpha) * own + alpha * neighbours`` and the blend is
    scored with ``_rank_bce_loss``.

    Returns:
        ``(per_example_loss [B], neighbour_ids, neighbour_distances)``.
    """
    user_pred = tf.math.sigmoid(tf.einsum('bd,bd->b', anchors, batch_features))
    # 'bd,bqd->bq' pairs user row b with its own queries (one dot product each).
    user_query_pred = tf.math.sigmoid(tf.einsum('bd,bqd->bq', anchors, dense_queries))
    neigh_pred, neigh_query_pred, neighbour_ids, neighbour_distances = _neighbour_predictions(
        anchors, pool, u_id, batch_features, dense_queries, radius, epsilon, stop_query_grad)
    point_pred = (1 - alpha) * user_pred + alpha * neigh_pred
    query_pred = (1 - alpha) * user_query_pred + alpha * neigh_query_pred
    per_example = _rank_bce_loss(point_pred, query_pred, query_mask, batch_labels, loss_reg, epsilon)
    return per_example, neighbour_ids, neighbour_distances


def _apply_row_gradients(optimizer, variable, grad, row_ids):
    """Apply ``grad`` (dense or ``tf.IndexedSlices``) to ``variable`` on rows ``row_ids`` only.

    ``row_ids`` is deduplicated first: the densified gradient already sums repeated
    occurrences, so gathering it once per duplicate would over-count.
    """
    unique_ids = tf.unique(row_ids).y
    rows = tf.gather(tf.convert_to_tensor(grad), unique_ids)
    optimizer.apply_gradients([
        (tf.IndexedSlices(rows, unique_ids, dense_shape=variable.shape), variable)
    ])


def _neighbours_outside_batch(neighbour_ids, u_id):
    """Return the distinct neighbour ids that are not themselves batch users (sorted, 1-D)."""
    candidates = tf.unique(neighbour_ids.values).y
    return tf.reshape(
        tf.sets.difference(tf.expand_dims(candidates, 0), tf.expand_dims(u_id, 0)).values,
        [-1]
    )


def _cold_start_step(user_matrix, optimizer, u_id, batch_features, batch_labels, dense_queries,
                     query_mask, loss_reg, epsilon):
    """One update using only the batch users' own vectors; returns the batch-mean loss."""
    with tf.GradientTape() as tape:
        user_vectors = tf.gather(user_matrix, u_id)
        point_pred = tf.math.sigmoid(tf.einsum('bd,bd->b', user_vectors, batch_features))
        query_pred = tf.math.sigmoid(tf.einsum('bd,bqd->bq', user_vectors, dense_queries))
        loss = tf.reduce_mean(
            _rank_bce_loss(point_pred, query_pred, query_mask, batch_labels, loss_reg, epsilon))
    grad = tape.gradient(loss, user_matrix)
    _apply_row_gradients(optimizer, user_matrix, grad, u_id)
    return loss


def _neighbour_step(user_matrix, optimizer_user, optimizer_meta, u_id, batch_features, batch_labels,
                    dense_queries, query_mask, radius, alpha, loss_reg, lambda_reg, epsilon):
    """One post-cold-start update; returns ``(main_loss, meta_loss)`` (batch sums).

    The main pass trains the batch users against frozen neighbours (the neighbour
    weights are also frozen for the query predictions). The meta pass freezes the
    batch users and trains the neighbours, adding ``lambda_reg`` times the summed
    neighbour distances. Both gradients come from the same pre-update matrix;
    neighbours that are themselves batch users receive only the main update.
    """
    with tf.GradientTape() as tape_main:
        main_per_example, _, _ = _blended_loss(
            tf.gather(user_matrix, u_id), tf.stop_gradient(user_matrix), u_id, batch_features,
            batch_labels, dense_queries, query_mask, radius, alpha, loss_reg, epsilon,
            stop_query_grad=True)
        main_loss = tf.reduce_sum(main_per_example)
    grad_main = tape_main.gradient(main_loss, user_matrix)

    with tf.GradientTape() as tape_neigh:
        meta_per_example, neighbour_ids, neighbour_distances = _blended_loss(
            tf.stop_gradient(tf.gather(user_matrix, u_id)), user_matrix, u_id, batch_features,
            batch_labels, dense_queries, query_mask, radius, alpha, loss_reg, epsilon,
            stop_query_grad=False)
        meta_loss = tf.reduce_sum(meta_per_example) + lambda_reg * tf.reduce_sum(neighbour_distances)
    grad_neigh = tape_neigh.gradient(meta_loss, user_matrix)

    _apply_row_gradients(optimizer_user, user_matrix, grad_main, u_id)
    _apply_row_gradients(optimizer_meta, user_matrix, grad_neigh,
                         _neighbours_outside_batch(neighbour_ids, u_id))
    return main_loss, meta_loss


def _evaluate_synthetic(df_test, user_matrix):
    """Per-user accuracy and AUC of ``sigmoid(user . item)`` on ``df_test``.

    Users whose metrics raise (e.g. ``roc_auc_score`` with a single class present)
    are printed and skipped. Returns ``(acc_list, auc_list)``.
    """
    acc_list, auc_list = [], []
    weights = user_matrix.numpy()
    for user_id, user_rows in df_test.groupby('user_id'):
        try:
            items = np.array(user_rows.main_vector.apply(lambda x: list(x)).tolist())
            y_pred = expit(np.dot(weights[user_id], items.T))
            auc_list.append(roc_auc_score(user_rows.feedback.values, y_pred))
            acc_list.append(accuracy_score(user_rows.feedback.values, np.where(y_pred > 0.5, 1, 0)))
        except Exception as e:
            print(user_id, e)
    return acc_list, auc_list


def _evaluate_real(df_test, user_matrix, output_signature):
    """Global accuracy and AUC over the whole test set in one batch; AUC is -1.0 if it fails."""
    test_dataset = tf.data.Dataset.from_generator(lambda: test_generator(df_test),
                                                  output_signature=output_signature).batch(len(df_test))
    test_u, test_batch, test_label, _ = next(iter(test_dataset))
    test_matrix = tf.gather(user_matrix, test_u)
    test_predictions = tf.math.sigmoid(tf.einsum('ij,ij->i', test_matrix, test_batch))
    test_pred_class = tf.cast(test_predictions > 0.5, tf.int32)
    accuracy_metric = tf.keras.metrics.BinaryAccuracy()
    accuracy_metric.update_state(test_label, test_pred_class)
    acc = accuracy_metric.result().numpy()
    try:
        auc_metric = tf.keras.metrics.AUC()
        auc_metric.update_state(test_label, test_predictions)
        auc = auc_metric.result().numpy()
    except Exception:
        auc = -1.0
    return acc, auc


def train_loop(num_features, num_users, tepoch, meta_lr, main_lr, neighbour_margin, lambda_reg, loss_reg, alpha,
               batch_size, cold_start, increase_after, increase_factor, increase_limit, dynamic_neighbour,
               training_points, init_weight, dataset_type, result_df, return_weight=False):
    """Train per-user vectors with a neighbour-blended rank + BCE objective, then evaluate them.

    Each training example from ``generator(df_train)`` is a user id (int32), the
    vector of a labelled item, its float label, and a ragged list of query item
    vectors. A user's score for an item is ``sigmoid(user . item)``.

    * Epochs ``0..cold_start`` (inclusive) update only the batch users' own
      vectors, minimising the batch mean of ``loss_reg * rank_loss + bce_loss``.
    * Later epochs blend every score with an inverse-distance-weighted score from
      the user's neighbours (other users closer than the current radius):
      ``(1 - alpha) * own + alpha * neighbours``. An Adam optimiser (``main_lr``)
      updates the batch users with neighbours frozen; a second Adam optimiser
      (``meta_lr``) updates the neighbours outside the batch with the batch users
      frozen, adding ``lambda_reg * sum(neighbour distances)``. These losses are
      summed over the batch.

    Args:
        num_features: dimensionality of user and item vectors.
        num_users: number of rows in the user matrix; user ids index its rows.
        tepoch: number of training epochs.
        meta_lr: learning rate of the neighbour (meta) optimiser.
        main_lr: learning rate of the batch-user optimiser.
        neighbour_margin: initial neighbourhood radius.
        lambda_reg: weight of the neighbour-distance penalty in the meta loss.
        loss_reg: weight of the pairwise rank term.
        alpha: weight of the neighbour prediction in the blend.
        batch_size: training batch size (also used as the shuffle buffer size).
        cold_start: last epoch index of the cold-start phase.
        increase_after, increase_factor, increase_limit, dynamic_neighbour: when
            ``dynamic_neighbour`` is true, at every post-cold-start epoch divisible
            by ``increase_after`` the radius is multiplied by ``1 + increase_factor``
            if it is still below ``increase_limit``.
        training_points: forwarded to ``synthetic_split`` / ``real_split``.
        init_weight: ``'normal'`` for standard-normal initial vectors, anything
            else for zeros.
        dataset_type: ``'synthetic'`` (per-user metrics) or ``'real'`` (global metrics).
        result_df: data frame forwarded to the split function.
        return_weight: also return the trained user matrix as a NumPy array.

    Returns:
        ``(acc_list, auc_list, total_loss_list, meta_loss_list)``, plus the user
        matrix when ``return_weight`` is true. The loss lists hold the loss of the
        last batch of each epoch; ``meta_loss_list`` covers only post-cold-start epochs.

    Raises:
        ValueError: if ``dataset_type`` is neither ``'synthetic'`` nor ``'real'``.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    tf.random.set_seed(SEED)
    acc_list, auc_list = [], []
    total_loss_list, meta_loss_list = [], []
    output_signature = (
        tf.TensorSpec(shape=(), dtype=tf.int32),
        tf.TensorSpec(shape=(num_features,), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.float32),
        tf.RaggedTensorSpec(shape=(None, num_features), dtype=tf.float32)
    )
    if dataset_type == 'synthetic':
        df_train, df_test = synthetic_split(result_df, training_points)
    elif dataset_type == 'real':
        df_train, df_test = real_split(result_df, training_points)
    else:
        raise ValueError("dataset_type must be 'synthetic' or 'real', got {!r}".format(dataset_type))
    optimizer_user = tf.keras.optimizers.Adam(learning_rate=main_lr)
    optimizer_meta = tf.keras.optimizers.Adam(learning_rate=meta_lr)
    epsilon = 1e-8
    radius = neighbour_margin
    num_batches = (len(df_train) + batch_size - 1) // batch_size  # ceil division
    with tf.device('/CPU:0'):
        if init_weight == 'normal':
            initial_weights = tf.random.normal([num_users, num_features])
        else:
            initial_weights = tf.zeros([num_users, num_features])
        user_matrix = tf.Variable(initial_weights, name='user_matrix')

        dataset = tf.data.Dataset.from_generator(lambda: generator(df_train), output_signature=output_signature)
        dataset = dataset.shuffle(buffer_size=batch_size, seed=SEED).batch(batch_size)
        for epoch in range(tepoch):
            if (dynamic_neighbour and epoch > cold_start and epoch % increase_after == 0
                    and radius < increase_limit):
                radius += increase_factor * radius
            for u_id, batch_features, batch_labels, query_responses in tqdm(
                    dataset, leave=False, total=num_batches, desc='Epoch: {} \t Batch'.format(epoch)):
                # Queries are ragged per example: pad to [B, Q, D] and keep a validity mask.
                dense_queries = query_responses.to_tensor()
                query_mask = tf.sequence_mask(query_responses.row_lengths(), maxlen=tf.shape(dense_queries)[1])
                if epoch <= cold_start:
                    total_loss = _cold_start_step(
                        user_matrix, optimizer_user, u_id, batch_features, batch_labels,
                        dense_queries, query_mask, loss_reg, epsilon)
                else:
                    total_loss, meta_loss = _neighbour_step(
                        user_matrix, optimizer_user, optimizer_meta, u_id, batch_features, batch_labels,
                        dense_queries, query_mask, radius, alpha, loss_reg, lambda_reg, epsilon)
            # Only the loss of the epoch's final batch is recorded.
            total_loss_list.append(total_loss.numpy())
            if epoch > cold_start:
                meta_loss_list.append(meta_loss.numpy())

        if dataset_type == 'synthetic':  # per user accuracy for synthetic data
            acc_list, auc_list = _evaluate_synthetic(df_test, user_matrix)
        else:  # global accuracy for real data due to per-user sparsity
            acc, auc = _evaluate_real(df_test, user_matrix, output_signature)
            acc_list.append(acc)
            auc_list.append(auc)

    if return_weight:
        return acc_list, auc_list, total_loss_list, meta_loss_list, user_matrix.numpy()
    return acc_list, auc_list, total_loss_list, meta_loss_list

