"""
Common utilities for weak DFL experiments with AFIRO problem.

This module contains shared functions for:
- Random state management
- Dataset generation (with scaled/unscaled modes)
- Model creation
- Loss functions
- Regret calculation
"""

import os
import gc
import random
import numpy as np
import tensorflow as tf


# --- Random State Management ---
def reset_random_state(seed=42):
    """Reset all random states for reproducibility.

    Seeds Python's ``random`` module, NumPy's global legacy generator and
    TensorFlow, clears the Keras backend session, requests deterministic
    TensorFlow ops and finally runs a garbage-collection pass.

    Args:
        seed: Seed applied to every generator (default 42).

    Note:
        ``PYTHONHASHSEED`` is read by the interpreter at start-up, so writing
        it here only affects processes spawned afterwards, not string hashing
        in the current process.
    """
    import warnings

    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Drop Keras global state before seeding TensorFlow.
    tf.keras.backend.clear_session()
    tf.random.set_seed(seed)
    try:
        tf.config.experimental.enable_op_determinism()
    except Exception as exc:
        # Still best effort (the call is missing or may fail on some
        # TensorFlow builds), but tell the user that runs may not be
        # bit-for-bit reproducible instead of hiding the failure.
        warnings.warn(
            "Could not enable TensorFlow op determinism "
            f"({exc!r}); results may not be fully reproducible.",
            RuntimeWarning,
            stacklevel=2,
        )
    gc.collect()


# --- Dataset Generation ---
def _build_weight_matrix(c_base, feature_dim, rng, scale_factor=0.1, use_scaling=True):
    """
    Build weight matrix for dataset generation.
    
    Args:
        c_base: Base cost vector
        feature_dim: Number of features
        rng: Random number generator
        scale_factor: Scaling factor for perturbations (only used if use_scaling=True)
        use_scaling: If True, scale weights by |c_base| magnitude. If False, use unscaled weights.
    
    Returns:
        W_true: Weight matrix (cost_dim x feature_dim)
    """
    cost_dim = len(c_base)
    active_cost_indices = np.where(c_base != 0)[0]
    W_true = np.zeros((cost_dim, feature_dim))

    num_features_per_cost = feature_dim // len(active_cost_indices)
    if num_features_per_cost == 0:
        raise ValueError("feature_dim is too small for the number of active costs.")

    feature_idx = 0
    for cost_idx in active_cost_indices:
        end_idx = min(feature_idx + num_features_per_cost, feature_dim)
        if use_scaling:
            # Scaled version: perturbations proportional to |c_base|
            scale = abs(c_base[cost_idx]) * scale_factor
            W_true[cost_idx, feature_idx:end_idx] = rng.standard_normal(end_idx - feature_idx) * scale
        else:
            # Unscaled version: standard normal weights
            W_true[cost_idx, feature_idx:end_idx] = rng.standard_normal(end_idx - feature_idx)
        feature_idx = end_idx

    return W_true


def generate_correlated_afiro_dataset(c_base, num_samples, feature_dim, noise_std=0.1, *,
                                       W_true=None, rng=None, use_scaling=True, scale_factor=0.1):
    """
    Generates a dataset where features Z are correlated with perturbations
    to the base AFIRO cost vector c_base. This creates a meaningful
    predictive task for the DFL model.

    Costs are produced as ``C = c_base + Z @ W_true.T + noise`` with standard
    normal features ``Z``.

    Args:
        c_base: Base cost vector (array or plain sequence)
        num_samples: Number of samples to generate
        feature_dim: Number of features
        noise_std: Standard deviation of independent noise; no noise is added
                   unless it is > 0
        W_true: Pre-computed weight matrix (optional). When given, it is used
                as-is and ``scale_factor`` is ignored.
        rng: Random number generator. ``None`` creates an unseeded
             ``np.random.default_rng()``; a legacy ``np.random.RandomState``
             is used to seed a new ``Generator``.
        use_scaling: If True, use scaled perturbations (proportional to |c_base|).
                    If False, use unscaled perturbations (standard normal).
                    Also controls whether the noise is scaled by |c_base|.
        scale_factor: Scaling factor for perturbations (only used if use_scaling=True)

    Returns:
        Z: Feature matrix (num_samples x feature_dim), float32
        C: Cost matrix (num_samples x cost_dim), float32
        W_true: Weight matrix used for generation
    """
    if rng is None:
        rng = np.random.default_rng()
    elif isinstance(rng, np.random.RandomState):
        # Request a 64-bit draw explicitly: with the platform-default integer
        # type the upper bound 2**32 - 1 is rejected where that type is 32-bit
        # (e.g. Windows with NumPy < 2). On 64-bit-default platforms this
        # yields the same seed as the default dtype.
        rng = np.random.default_rng(int(rng.randint(0, 2 ** 32 - 1, dtype=np.int64)))

    # Normalise once so plain sequences behave like arrays everywhere below
    # (including the element-wise `!= 0` test done when building W_true).
    c_base = np.asarray(c_base)
    cost_dim = len(c_base)

    if W_true is None:
        W_true = _build_weight_matrix(c_base, feature_dim, rng, scale_factor, use_scaling)

    # Generate features Z
    Z = rng.standard_normal((num_samples, feature_dim))

    # Cost vectors with additive perturbations; c_base broadcasts over rows.
    C = c_base + Z @ W_true.T

    # Add independent noise
    if noise_std > 0:
        noise = rng.standard_normal((num_samples, cost_dim)) * noise_std
        if use_scaling:
            # Scaled noise: proportional to |c_base|
            noise = noise * np.abs(c_base)
        C += noise

    return Z.astype(np.float32), C.astype(np.float32), W_true


# --- Model Creation ---
def create_model(num_variables, feature_dim, *, seed=42):
    """
    Build the cost-prediction network: one dense layer followed by L2 normalisation.

    The dense layer maps the features to ``num_variables`` outputs; the result
    is then passed through ``tf.math.l2_normalize`` along axis 1, so the model
    output is the normalised dense output rather than the raw linear map.

    Args:
        num_variables: Number of decision variables (output dimension)
        feature_dim: Number of features (input dimension)
        seed: Seed given to both the kernel and the bias initializer

    Returns:
        tf.keras.Model: Model named "pref_linear"
    """
    # Both initializers receive the same seed value.
    glorot_normal = tf.keras.initializers.GlorotNormal(seed=seed)
    glorot_uniform = tf.keras.initializers.GlorotUniform(seed=seed)

    feature_input = tf.keras.Input(shape=(feature_dim,), dtype=tf.float32)

    raw_costs = tf.keras.layers.Dense(
        num_variables,
        kernel_initializer=glorot_normal,
        bias_initializer=glorot_uniform,
    )(feature_input)

    # Normalise each row of the dense output (axis 1).
    unit_costs = tf.keras.layers.Lambda(
        lambda x: tf.math.l2_normalize(x, axis=1),
        name='l2_normalization',
    )(raw_costs)

    return tf.keras.Model(inputs=feature_input, outputs=unit_costs, name="pref_linear")


# --- Loss Functions ---
@tf.function
def mpll_pref(c_hat, x_positive, f_label, q_negatives, maximise=False):
    """
    Maximum Pairwise Log-Likelihood loss for preference learning.
    Compiled with the @tf.function decorator.

    The positive example and every negative example are scored by their dot
    product with ``c_hat``; each (positive, negative) score gap is passed
    through a softplus and the results are summed.

    Args:
        c_hat: Predicted cost vector
        x_positive: Positive example (optimal solution)
        f_label: Label (+1 for good, -1 for bad)
        q_negatives: Negative examples (nearby vertices); scores are reduced
            along axis 1, so presumably one example per row - TODO confirm
        maximise: If True the gap is (positive - negative); otherwise
            (negative - positive), i.e. the minimisation convention

    Returns:
        Loss value
    """
    positive_score = tf.reduce_sum(x_positive * c_hat)
    negative_scores = tf.reduce_sum(q_negatives * c_hat, axis=1)

    if maximise:
        gap = positive_score - negative_scores
    else:
        # Q - X for minimisation
        gap = negative_scores - positive_score

    # softplus(z) is the numerically stable form of log(1 + exp(z)).
    pairwise_losses = tf.nn.softplus(-f_label * gap)
    return tf.reduce_sum(pairwise_losses)


@tf.function
def bce_loss_pref(c_hat, x_positive, f_label, q_negatives):
    """
    Binary Cross-Entropy loss for preference learning.
    Compiled with the @tf.function decorator.

    Scores (dot products with ``c_hat``) are squashed with a sigmoid and
    compared against targets: ``(f_label + 1) / 2`` for the positive example
    and 0 for every negative example.

    Args:
        c_hat: Predicted cost vector
        x_positive: Positive example (optimal solution)
        f_label: Label (+1 for good, -1 for bad)
        q_negatives: Negative examples (nearby vertices); its first dimension
            gives the number of negatives

    Returns:
        Loss value
    """
    positive_logit = tf.reduce_sum(x_positive * c_hat)
    negative_logits = tf.reduce_sum(q_negatives * c_hat, axis=1)

    # Positive example first, negatives after; sigmoid is element-wise.
    probabilities = tf.math.sigmoid(
        tf.concat([tf.reshape(positive_logit, [1]), negative_logits], axis=0)
    )

    # Map the +1/-1 label onto a 1/0 target; negatives always get target 0.
    positive_target = tf.reshape((f_label + 1.0) / 2.0, [1])
    negative_targets = tf.zeros(tf.shape(q_negatives)[0], dtype=tf.float32)
    targets = tf.concat([positive_target, negative_targets], axis=0)

    # NOTE(review): Keras' binary_crossentropy reduces over the last axis, so
    # for these 1-D inputs the value is presumably already a scalar and the
    # reduce_sum below is a no-op - confirm whether a sum over examples was intended.
    cross_entropy = tf.keras.losses.binary_crossentropy(
        y_true=targets, y_pred=probabilities, from_logits=False
    )
    return tf.reduce_sum(cross_entropy)


@tf.function
def hinge_loss_pref(c_hat, x_positive, f_label, q_negatives, margin=1.0):
    """
    Hinge loss for preference learning.
    Compiled with the @tf.function decorator.

    Scores (dot products with ``c_hat``) are divided by the largest absolute
    score, floored at 1, before the hinge terms are computed. The divisor is
    part of the graph (no stop-gradient is applied to it).

    Args:
        c_hat: Predicted cost vector
        x_positive: Positive example (optimal solution)
        f_label: Label (+1 for good, -1 for bad)
        q_negatives: Negative examples (nearby vertices)
        margin: Margin for hinge loss

    Returns:
        Loss value: hinge term of the positive example plus the summed hinge
        terms of the negatives
    """
    positive_score = tf.reduce_sum(x_positive * c_hat)
    negative_scores = tf.reduce_sum(q_negatives * c_hat, axis=1)

    # Divisor = max(largest |score|, 1): scores are only ever shrunk, never enlarged.
    largest_magnitude = tf.maximum(
        tf.reduce_max(tf.abs(negative_scores)), tf.abs(positive_score)
    )
    divisor = tf.maximum(largest_magnitude, tf.constant(1.0, dtype=tf.float32))

    positive_term = tf.maximum(0.0, margin - f_label * (positive_score / divisor))
    negative_terms = tf.maximum(0.0, margin + f_label * (negative_scores / divisor))

    return positive_term + tf.reduce_sum(negative_terms)


# --- Regret Calculation ---
def calculate_regret(c_true, x_star, x_hat, epsilon=1e-8):
    """
    Relative regret of a decision, guarded against a zero optimal cost.

    Computes ``|c_true.x_hat - c_true.x_star| / max(|c_true.x_star|, epsilon)``.

    Args:
        c_true: True cost vector
        x_star: Optimal solution for true costs
        x_hat: Solution from predicted costs
        epsilon: Lower bound for the denominator, avoiding division by zero

    Returns:
        Regret value (relative error) as a Python float
    """
    optimal_cost = np.dot(c_true, x_star)
    achieved_cost = np.dot(c_true, x_hat)

    gap = abs(achieved_cost - optimal_cost)

    # Floor the normaliser at epsilon so a zero optimal cost cannot divide by zero.
    normaliser = abs(optimal_cost)
    if epsilon > normaliser:
        normaliser = epsilon

    return float(gap / normaliser)
