import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt


# add ../utils folder as a module
from pathlib import Path
import sys

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "utils"))

from minimisation_utils import analyze_lp
from solve_afiro_lp import load_lp_problem, solve_with_scipy
from common_utils import (
    reset_random_state,
    generate_correlated_afiro_dataset,
    create_model,
    mpll_pref,
    bce_loss_pref,
    hinge_loss_pref,
    calculate_regret
)
from tqdm import tqdm
# Hide every GPU device from TensorFlow so the whole script runs on the CPU.
tf.config.set_visible_devices([], 'GPU')
# Seed the random state once at import time.
# NOTE(review): presumably this seeds the Python/NumPy/TensorFlow RNGs — confirm
# against common_utils.reset_random_state.
reset_random_state(42)
def decision_relevant_projection(A, c):
    """Return the component of ``c`` orthogonal to the row space of ``A``.

    The projector onto Col(A^T) is assembled as ``A^T (A A^T)^+ A``; using the
    pseudoinverse keeps it well defined even when ``A A^T`` is singular.
    Subtracting that projector from the identity removes from ``c`` everything
    that lies in Col(A^T).

    Args:
        A: 2-D array; its column count must match the length of ``c``.
        c: Vector to project.

    Returns:
        ``(I - A^T (A A^T)^+ A) @ c``.
    """
    n_cols = A.shape[1]
    gram_pinv = np.linalg.pinv(A @ A.T)
    row_space_projector = A.T @ (gram_pinv @ A)
    return (np.eye(n_cols) - row_space_projector) @ c

def gen_plots(model_pref_v, regret_history_pref, obj_values_history, bias_to_cbase_metrics,
              lambda_value, tau, c_base, regret_history_spo=None):
    """Plot the regret curve(s), save the figure and print final statistics.

    A single axis is drawn: the preference-based regret per epoch and, when
    provided, the SPO+ regret per epoch. The figure is written to
    ``weak_dfl_tau_{tau}_results.png`` in the current working directory and
    then shown.

    Args:
        model_pref_v: Trained model (not used by the plot; kept for interface
            compatibility)
        regret_history_pref: List of regret values per epoch
        obj_values_history: List of objective value differences per epoch
            (not used by the plot; kept for interface compatibility)
        bias_to_cbase_metrics: Dictionary containing bias recovery metrics;
            only the ``'cosine_sim'`` history is read here
        lambda_value: Lambda hyperparameter value (shown in the legend)
        tau: Tau threshold value (shown in the title and the file name)
        c_base: True base cost vector (not used by the plot; kept for
            interface compatibility)
        regret_history_spo: Optional list of SPO+ regret values per epoch;
            drawn as a second curve when not None
    """
    # Single panel: regret over epochs.
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    ax.plot(regret_history_pref, label=f'Pref (λ={lambda_value:.2e})', marker='o', linewidth=2)
    if regret_history_spo is not None:
        ax.plot(regret_history_spo, label='SPO+', marker='x', linewidth=2, color='red')
    ax.set_xlabel("Epoch", fontsize=12)
    ax.set_ylabel("Average Relative Regret", fontsize=12)
    ax.set_title(f"Regret Convergence (tau={tau})", fontsize=14, fontweight='bold')
    ax.axhline(0, color='grey', linestyle='--', linewidth=0.8, label="Zero Regret")
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'weak_dfl_tau_{tau}_results.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Empty histories (e.g. a run with zero epochs) are reported as nan
    # instead of raising IndexError on the [-1] lookup.
    cosine_history = bias_to_cbase_metrics['cosine_sim']
    final_regret = regret_history_pref[-1] if len(regret_history_pref) > 0 else float('nan')
    final_cosine = cosine_history[-1] if len(cosine_history) > 0 else float('nan')

    # Print final statistics
    print("\n" + "=" * 70)
    print("FINAL RESULTS")
    print("=" * 70)
    print(f"Final Regret: {final_regret:.6f}")
    print(f"Final Cosine Similarity: {final_cosine:.6f}")
    print("=" * 70)



def _bias_recovery_metrics(b_learned, c_base, c_base_perp, A_eq):
    """Compare a learned bias vector with the true base cost vector.

    Args:
        b_learned: Bias vector read from the model.
        c_base: True base cost vector.
        c_base_perp: ``decision_relevant_projection(A_eq, c_base)``,
            precomputed by the caller because it never changes.
        A_eq: Equality constraint matrix used for the projection.

    Returns:
        Dict with keys ``l2_distance``, ``l1_distance``, ``cosine_sim``,
        ``relative_error`` and ``max_error``.
    """
    diff = b_learned - c_base
    abs_diff = np.abs(diff)

    # Cosine similarity is measured after projecting both vectors with
    # decision_relevant_projection (i.e. with the Col(A_eq^T) part removed).
    b_learned_perp = decision_relevant_projection(A_eq, b_learned)
    norm_b_perp = np.linalg.norm(b_learned_perp)
    norm_c_perp = np.linalg.norm(c_base_perp)
    if norm_b_perp > 1e-8 and norm_c_perp > 1e-8:
        cos_sim = np.dot(b_learned_perp, c_base_perp) / (norm_b_perp * norm_c_perp)
    else:
        cos_sim = 0.0

    # Relative error (only for non-zero components of c_base)
    nonzero_mask = np.abs(c_base) > 1e-8
    if np.any(nonzero_mask):
        rel_err = np.mean(np.abs(diff[nonzero_mask] / c_base[nonzero_mask]))
    else:
        rel_err = float('nan')

    return {
        'l2_distance': np.linalg.norm(diff),
        'l1_distance': np.sum(abs_diff),
        'cosine_sim': cos_sim,
        'relative_error': rel_err,
        'max_error': np.max(abs_diff),
    }


def train_weak_dfl(c_base, A_eq, b_eq, bounds, Z_train, C_train, Z_test, C_test,
                   learning_rate, num_epochs, batch_size, tau, lambda_value,
                   n_o_delta, n_samples, n_e_scale, feature_dim):
    """Train the weak DFL model.

    For every training sample the predicted cost is solved with ``analyze_lp``
    to obtain the predicted optimum and a set of nearby vertices. The sample is
    labelled +1 when the true-cost objective of the predicted optimum is within
    a relative tolerance ``tau`` of the true optimum, and -1 otherwise. The loss
    is ``(1 - lambda_value) * mpll_pref + lambda_value * hinge_loss_pref``,
    averaged over the samples of the batch that were not skipped.

    After each epoch the model is evaluated on the test set (regret against
    each sample's own cost vector from ``C_test``) and the first Dense layer's
    bias is compared with ``c_base``.

    Args:
        c_base: Base cost vector (defines the number of variables and is the
            reference for the bias recovery metrics)
        A_eq: Equality constraint matrix
        b_eq: Equality constraint vector
        bounds: Variable bounds
        Z_train: Training features
        C_train: Training cost vectors
        Z_test: Test features
        C_test: Test cost vectors
        learning_rate: Learning rate for optimizer
        num_epochs: Number of training epochs
        batch_size: Batch size for training
        tau: Threshold for preference labeling
        lambda_value: Weight for hinge loss
        n_o_delta: Nearby objective delta parameter
        n_samples: Number of nearby samples
        n_e_scale: Nearby epsilon scale
        feature_dim: Feature dimension

    Returns:
        model: Trained model
        regret_history: List of regret values per epoch
        obj_values_history: List of objective value differences per epoch
        bias_to_cbase_metrics: Dictionary containing bias recovery metrics
    """
    num_variables = len(c_base)
    num_train_samples = len(Z_train)
    num_batches = (num_train_samples + batch_size - 1) // batch_size

    print(f"\n--- Training Preference-Based DFL (Vertex) Model: lambda = {lambda_value} ---")
    reset_random_state(42)
    model = create_model(num_variables, feature_dim, seed=42)

    # NOTE(review): beta_1=0.999 is far from Adam's default of 0.9 — confirm it
    # is intentional (and not meant to be beta_2) before changing it; the tuned
    # hyperparameters were found with this setting.
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.999)
    regret_history = []
    obj_values_history = []

    # Track distance between learned bias and true c_base
    bias_to_cbase_metrics = {
        'l2_distance': [],
        'l1_distance': [],
        'cosine_sim': [],
        'relative_error': [],
        'max_error': []
    }

    # Epoch-invariant quantities, computed once instead of every epoch.
    c_base_perp = decision_relevant_projection(A_eq, c_base)
    dense_layer = next(
        (layer for layer in model.layers if isinstance(layer, tf.keras.layers.Dense)),
        None,
    )
    # True-optimum objective of each training LP, keyed by the sample's
    # unshuffled index. It does not depend on the model, so each LP is solved
    # at most once and reused across epochs.
    true_utility_cache = {}
    # True optima of the test LPs; filled in during the first evaluation pass.
    test_optima = None

    for epoch in range(num_epochs):
        # Shuffle training data while maintaining Z-C pairing
        indices = np.arange(num_train_samples)
        np.random.shuffle(indices)
        Z_train_shuffled = Z_train[indices]
        C_train_shuffled = C_train[indices]

        for i in tqdm(range(0, num_train_samples, batch_size), total=num_batches,
                      desc=f"Epoch {epoch + 1}/{num_epochs}", leave=True):
            Z_batch = Z_train_shuffled[i:i + batch_size]
            C_batch_true = C_train_shuffled[i:i + batch_size]
            batch_ids = indices[i:i + batch_size]

            with tf.GradientTape() as tape:
                C_hat_batch = model(Z_batch, training=True)
                batch_loss = tf.constant(0.0, dtype=tf.float32)
                valid_examples = 0

                for j, c_true in enumerate(C_batch_true):
                    c_hat = C_hat_batch[j]

                    skip_example = False
                    # The LP solves are plain NumPy work; keep them off the tape.
                    with tape.stop_recording():
                        c_hat_np = tf.stop_gradient(c_hat).numpy()
                        x_hat_np, nearby_vertices = analyze_lp(
                            c_hat_np,
                            A_eq=A_eq,
                            b_eq=b_eq,
                            bounds=bounds,
                            verbose=False,
                            nearby_min_obj_delta=n_o_delta,
                            nearby_samples=n_samples,
                            nearby_eps_scales=(n_e_scale,)
                        )
                        if x_hat_np is None or not nearby_vertices:
                            skip_example = True
                        else:
                            q_negatives_np = np.array(nearby_vertices, dtype=np.float32)
                            sample_id = int(batch_ids[j])
                            if sample_id not in true_utility_cache:
                                x_true_np = solve_with_scipy(
                                    c_true, A_eq=A_eq, b_eq=b_eq, bounds=bounds
                                ).x
                                true_utility_cache[sample_id] = float(np.dot(c_true, x_true_np))
                            utility_true_val = true_utility_cache[sample_id]
                            utility_pred_val = float(np.dot(c_true, x_hat_np))
                            # +1 when the prediction's true-cost objective is
                            # within a relative tolerance tau of the optimum.
                            within_tau = (
                                abs(utility_pred_val - utility_true_val)
                                <= tau * (abs(utility_true_val) + 1e-12)
                            )
                            label_val = 1.0 if within_tau else -1.0

                    if skip_example:
                        continue

                    q_negatives = tf.constant(q_negatives_np)
                    x_hat_tf = tf.constant(x_hat_np.astype(np.float32))
                    label = tf.constant(label_val, dtype=tf.float32)
                    pl_loss = mpll_pref(c_hat, x_hat_tf, label, q_negatives)
                    hinge_loss = hinge_loss_pref(c_hat, x_hat_tf, label, q_negatives)

                    batch_loss += (1 - lambda_value) * pl_loss + lambda_value * hinge_loss
                    valid_examples += 1

                if valid_examples == 0:
                    loss = tf.constant(0.0, dtype=tf.float32)
                else:
                    loss = batch_loss / tf.cast(valid_examples, tf.float32)

            trainable_vars = model.trainable_variables
            grads = tape.gradient(loss, trainable_vars)
            # A batch with no valid example yields only None gradients and is
            # therefore skipped by the filter below.
            grads_and_vars = [(g, v) for g, v in zip(grads, trainable_vars) if g is not None]
            if grads_and_vars:
                optimizer.apply_gradients(grads_and_vars)

        # Evaluation on test set
        if test_optima is None:
            test_optima = [
                solve_with_scipy(c_sample, A_eq=A_eq, b_eq=b_eq, bounds=bounds).x
                for c_sample in C_test
            ]

        epoch_regrets = []
        obj_values = []
        for z_sample, c_sample, x_star in zip(Z_test, C_test, test_optima):
            # Regret is measured against the sample's own true cost vector,
            # the same ground truth train_spoplus uses, so both curves are
            # comparable.
            c_hat_sample = model(z_sample.reshape(1, -1), training=False)[0].numpy()
            x_hat = solve_with_scipy(c_hat_sample, A_eq=A_eq, b_eq=b_eq, bounds=bounds).x

            regret_val = calculate_regret(c_sample, x_star, x_hat)
            val_star = np.dot(c_sample, x_star)  # Optimal objective under the true cost
            val_hat = np.dot(c_hat_sample, x_hat)  # Predicted objective under the predicted cost

            # Relative gap between the predicted and the true optimal objective
            # value; the epsilon floor avoids division by zero.
            denominator = max(abs(val_star), 1e-8)
            obj_values.append(float(abs(val_hat - val_star) / denominator))
            epoch_regrets.append(regret_val)

        avg_regret = float(np.mean(epoch_regrets)) if epoch_regrets else float('nan')
        regret_history.append(avg_regret)

        avg_obj = float(np.mean(obj_values)) if obj_values else float('nan')
        obj_values_history.append(avg_obj)

        # Extract learned bias (first Dense layer) and compute metrics
        if dense_layer is not None:
            b_learned = dense_layer.bias.numpy()
        else:
            b_learned = np.zeros_like(c_base)
        epoch_metrics = _bias_recovery_metrics(b_learned, c_base, c_base_perp, A_eq)
        for metric_name, metric_value in epoch_metrics.items():
            bias_to_cbase_metrics[metric_name].append(metric_value)

        # Print progress
        print(f"\nEpoch {epoch + 1}, Avg. Regret: {avg_regret:.4f}, "
              f"||b - c_base||_2: {epoch_metrics['l2_distance']:.4f}, "
              f"Cosine Sim: {epoch_metrics['cosine_sim']:.4f}")

    return model, regret_history, obj_values_history, bias_to_cbase_metrics


def make_spo_plus_loss_eq(A_eq, b_eq, bounds):
    """Build an SPO+ loss for the LP ``min c^T x`` over the given feasible set.

    The returned function takes a single pair of 1-D cost vectors
    ``(c_true, c_pred)`` and returns a scalar float32 loss. Both LPs are solved
    with ``solve_with_scipy`` inside ``tf.numpy_function``, and the gradient is
    supplied by hand through ``tf.custom_gradient``.

    Args:
        A_eq: Equality constraint matrix
        b_eq: Equality constraint vector
        bounds: Variable bounds

    Returns:
        Callable ``spo_plus_loss_eq(c_true, c_pred) -> loss`` whose gradient
        with respect to ``c_pred`` is ``2 * (x*(c_true) - x*(2*c_pred - c_true))``
        and which has no gradient with respect to ``c_true``.
    """
    @tf.custom_gradient
    def spo_plus_loss_eq(c_true, c_pred):
        def spo_loss_numpy(c_true_np, c_pred_np):
            x_true = solve_with_scipy(c_true_np, A_eq=A_eq, b_eq=b_eq, bounds=bounds).x
            c_spo = 2.0 * c_pred_np - c_true_np
            x_spo = solve_with_scipy(c_spo, A_eq=A_eq, b_eq=b_eq, bounds=bounds).x
            # SPO+ loss for a minimisation LP:
            #   -min_x (2c_hat - c)^T x + 2 c_hat^T x*(c) - c^T x*(c)
            # which simplifies to (2c_hat - c)^T (x*(c) - x_spo). It is
            # non-negative because x_spo minimises (2c_hat - c)^T x.
            loss = np.dot(c_spo, x_true - x_spo)
            return (
                np.float32(loss),
                x_true.astype(np.float32),
                x_spo.astype(np.float32),
            )

        loss, x_true_tensor, x_spo_tensor = tf.numpy_function(
            func=spo_loss_numpy,
            inp=[c_true, c_pred],
            Tout=[tf.float32, tf.float32, tf.float32],
        )

        # numpy_function drops static shape information; restore it.
        loss.set_shape(())
        x_true_tensor.set_shape(c_pred.shape)
        x_spo_tensor.set_shape(c_pred.shape)

        def grad(dy):
            # The true cost is data, not a trainable quantity.
            grad_c_true = None
            # Subgradient of the SPO+ loss with respect to the prediction.
            grad_c_pred = dy * 2.0 * (x_true_tensor - x_spo_tensor)
            return grad_c_true, grad_c_pred

        return loss, grad

    return spo_plus_loss_eq


def train_spoplus(c_base, A_eq, b_eq, bounds, Z_train, C_train, Z_test, C_test,
                  learning_rate, num_epochs, batch_size, feature_dim,
                  beta_1=0.9, beta_2=0.999, clip_norm=1.0):
    """Train a single-Dense-layer cost predictor with the SPO+ loss.

    After every epoch the model is evaluated on the test set: regret is
    computed with ``calculate_regret`` against each sample's own cost vector
    from ``C_test``.

    Args:
        c_base: Base cost vector (only its length is used, as the output size)
        A_eq: Equality constraint matrix
        b_eq: Equality constraint vector
        bounds: Variable bounds
        Z_train: Training features
        C_train: Training cost vectors
        Z_test: Test features
        C_test: Test cost vectors
        learning_rate: Learning rate for the Adam optimizer
        num_epochs: Number of training epochs
        batch_size: Batch size for training
        feature_dim: Feature dimension
        beta_1: Adam beta_1
        beta_2: Adam beta_2
        clip_norm: Global-norm threshold applied to the gradients

    Returns:
        model: Trained model
        regret_history_spo: List of average test regret per epoch (nan for an
            epoch when the test set is empty)
    """
    num_variables = len(c_base)
    num_train_samples = len(Z_train)

    reset_random_state(42)
    kernel_init = tf.keras.initializers.GlorotUniform(seed=42)
    bias_init = tf.keras.initializers.Zeros()
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(num_variables, input_shape=(feature_dim,),
                              kernel_initializer=kernel_init,
                              bias_initializer=bias_init)
    ])

    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=beta_1, beta_2=beta_2)
    spo_plus_loss_eq = make_spo_plus_loss_eq(A_eq, b_eq, bounds)

    regret_history_spo = []
    # True optima of the test LPs do not depend on the model, so they are
    # solved once (during the first evaluation pass) and reused afterwards.
    test_optima = None

    for epoch in range(num_epochs):
        # Shuffle training data while maintaining Z-C pairing
        indices = np.arange(num_train_samples)
        np.random.shuffle(indices)
        Z_train_shuffled = Z_train[indices]
        C_train_shuffled = C_train[indices]

        for i in range(0, num_train_samples, batch_size):
            Z_batch = Z_train_shuffled[i:i + batch_size]
            C_batch = C_train_shuffled[i:i + batch_size]

            with tf.GradientTape() as tape:
                C_pred = model(Z_batch, training=True)
                C_batch_tf = tf.convert_to_tensor(C_batch, dtype=tf.float32)
                # The loss is defined per sample, so it is evaluated row by row.
                loss_batch = [spo_plus_loss_eq(c_t, c_p)
                              for c_t, c_p in zip(C_batch_tf, C_pred)]
                loss = tf.reduce_mean(loss_batch)

            grads = tape.gradient(loss, model.trainable_variables)
            grads, _ = tf.clip_by_global_norm(grads, clip_norm)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # Evaluation on test set
        if test_optima is None:
            test_optima = [
                solve_with_scipy(c_test, A_eq=A_eq, b_eq=b_eq, bounds=bounds).x
                for c_test in C_test
            ]

        epoch_regrets = []
        for z_test, c_test, x_true_optimal in zip(Z_test, C_test, test_optima):
            c_pred_test = model(z_test.reshape(1, -1), training=False)[0].numpy()
            x_pred_optimal = solve_with_scipy(c_pred_test, A_eq=A_eq, b_eq=b_eq, bounds=bounds).x
            epoch_regrets.append(calculate_regret(c_test, x_true_optimal, x_pred_optimal))

        # nan rather than ZeroDivisionError for an empty test set, matching
        # train_weak_dfl.
        avg_regret = float(np.mean(epoch_regrets)) if epoch_regrets else float('nan')
        regret_history_spo.append(avg_regret)
        print(f"[SPO+] Epoch {epoch + 1}, Avg. Regret: {avg_regret:.4f}")

    return model, regret_history_spo


# --- Main Script ---
if __name__ == '__main__':
    # ========== PROBLEM SETUP ==========

    # AFIRO is the only benchmark handled here; its .mat file is looked up in
    # the "datasets" directory one level above this script's folder.
    mat_path = Path(__file__).resolve().parents[1] / "datasets" / "lp_afiro.mat"
    print(f"Loading LP problem 'afiro' from {mat_path}")
    c_base, A_eq, b_eq, bounds, _ = load_lp_problem(str(mat_path))
    num_variables = len(c_base)
    # The training code works with a dense matrix.
    A_eq = A_eq.toarray()

    # ========== DATASET SETTINGS ==========

    # One feature per cost coefficient.
    feature_dim = len(c_base)
    num_train_samples = 128
    num_test_samples = 128
    use_scaling = True
    scale_factor = 0.1

    # ========== HYPERPARAMETERS ==========

    # Weak DFL: the first preset is the Optuna-best configuration from
    # optuna_weak_dfl_iclr.py, the second is the alternative configuration.
    optuna_best_cfg = dict(
        learning_rate=0.018761082439630705,
        num_epochs=60,
        batch_size=40,
        tau=0.3,
        lambda_value=0.0002297662414082499,
        n_o_delta=1167.9817513130797,
        n_samples=14,
        n_e_scale=784.3006551564999,
    )
    alternative_cfg = dict(
        learning_rate=0.015702970884055395,
        num_epochs=100,
        batch_size=62,
        tau=0.3,
        lambda_value=0.18272261776066237,
        n_o_delta=1e2,
        n_samples=13,
        n_e_scale=1e3,
    )
    use_optuna_best = True
    weak_cfg = optuna_best_cfg if use_optuna_best else alternative_cfg

    # SPO+ (Optuna-best hyperparameters); it trains for the same number of
    # epochs as the weak DFL run.
    spo_cfg = dict(
        learning_rate=0.009271215676808907,
        batch_size=16,
        beta_1=0.8910331049441603,
        beta_2=0.9986879553659828,
        clip_norm=0.5529579141551428,
    )

    # Set to False to skip the SPO+ baseline (e.g. while tuning); True for the
    # final comparison.
    RUN_SPO_PLUS = True

    # ========== DATA GENERATION ==========

    print("--- Generating Correlated Dataset for AFIRO ---")
    print(f"Using {'SCALED' if use_scaling else 'UNSCALED'} perturbations")
    if use_scaling:
        print(f"Scale factor: {scale_factor}")

    reset_random_state(1234)
    rng = np.random.default_rng(1234)
    shared_dataset_kwargs = dict(
        feature_dim=feature_dim,
        rng=rng,
        use_scaling=use_scaling,
        scale_factor=scale_factor,
    )

    # The test split reuses the W_true returned for the training split and is
    # drawn from the same generator, with a smaller noise level.
    Z_train, C_train, W_true = generate_correlated_afiro_dataset(
        c_base, num_samples=num_train_samples, noise_std=0.1,
        **shared_dataset_kwargs
    )
    Z_test, C_test, _ = generate_correlated_afiro_dataset(
        c_base, num_samples=num_test_samples, noise_std=0.05, W_true=W_true,
        **shared_dataset_kwargs
    )
    print("Dataset generated.\n")

    # Dataset statistics
    print("--- Dataset Statistics ---")
    print(f"c_base range: [{c_base.min():.4f}, {c_base.max():.4f}]")
    print(f"c_base non-zero elements: {np.count_nonzero(c_base)}")
    print(f"C_train range: [{C_train.min():.4f}, {C_train.max():.4f}]")
    print(f"C_test range: [{C_test.min():.4f}, {C_test.max():.4f}]")
    print(f"Number of negative costs in C_train: {(C_train < 0).sum()}")
    print(f"Perturbation std: {(C_train - c_base).std():.4f}")
    print(f"Base std: {c_base.std():.4f}")
    print()

    # ========== TRAINING ==========

    lp_and_data = (c_base, A_eq, b_eq, bounds, Z_train, C_train, Z_test, C_test)

    model, regret_history, obj_values_history, bias_to_cbase_metrics = train_weak_dfl(
        *lp_and_data, feature_dim=feature_dim, **weak_cfg
    )

    regret_history_spo = None
    if RUN_SPO_PLUS:
        _, regret_history_spo = train_spoplus(
            *lp_and_data,
            num_epochs=weak_cfg['num_epochs'],
            feature_dim=feature_dim,
            **spo_cfg
        )

    # ========== PLOTTING ==========

    gen_plots(model, regret_history, obj_values_history, bias_to_cbase_metrics,
              weak_cfg['lambda_value'], weak_cfg['tau'], c_base, regret_history_spo)
