import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from ortools.linear_solver import pywraplp
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle as sk_shuffle
import matplotlib.pyplot as plt
import sys
from pathlib import Path

# Ensure we can import common_utils from the additional_experiments utils directory
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "utils"))
from common_utils import (
    mpll_pref,
    bce_loss_pref,
    hinge_loss_pref,
    calculate_regret,
    reset_random_state
)


class knapsack_solver:
    """
    OR-Tools (SCIP) solver for a 0/1 knapsack with a single capacity constraint.
    This class is framework-agnostic (not tied to PyTorch or TensorFlow).

    The model (one binary variable per item plus the capacity constraint) is
    built once in ``make_model``; each ``solve`` call only overwrites the
    objective coefficients and re-solves, so one instance can be reused for
    many profit vectors.
    """

    def __init__(self, weights, capacity, n_items):
        """
        Args:
            weights: Indexable item weights (positions 0..n_items-1 are used).
            capacity: Right-hand side of the weight constraint.
            n_items (int): Number of items / binary decision variables.
        """
        self.weights = weights
        self.capacity = capacity
        self.n_items = n_items
        self.make_model()

    def make_model(self):
        """Create the SCIP model: one BoolVar per item and the capacity constraint.

        Raises:
            RuntimeError: If OR-Tools cannot create a SCIP solver
                (``CreateSolver`` returns ``None`` when the backend is unavailable).
        """
        solver = pywraplp.Solver.CreateSolver('SCIP')
        if solver is None:
            # Without this check the failure surfaces later as an obscure
            # AttributeError on ``None.BoolVar``.
            raise RuntimeError(
                "OR-Tools could not create a SCIP solver; "
                "check that this OR-Tools build includes SCIP."
            )
        x = {i: solver.BoolVar(f'x_{i}') for i in range(self.n_items)}
        # Native Python floats for the coefficients, for the same reason solve()
        # casts the objective coefficients (numpy scalars caused TypeErrors there).
        solver.Add(
            sum(x[i] * float(self.weights[i]) for i in range(self.n_items))
            <= float(self.capacity)
        )
        self.x = x
        self.solver = solver

    def solve(self, y):
        """Maximise ``sum_i y[i] * x[i]`` subject to the capacity constraint.

        Args:
            y: Array-like of per-item profits (at least ``n_items`` entries);
                numpy arrays, lists and objects convertible via ``np.asarray``
                are accepted.

        Returns:
            np.ndarray: float64 vector of length ``n_items`` holding the solver's
            values for the binary variables. If the solver does not report
            OPTIMAL, a warning is printed and an all-zero vector is returned.
        """
        profits = np.asarray(y, dtype=np.float64)
        objective = self.solver.Objective()
        for i in range(self.n_items):
            # Explicit native float: pywraplp raised TypeError on numpy scalars.
            objective.SetCoefficient(self.x[i], float(profits[i]))
        objective.SetMaximization()

        status = self.solver.Solve()
        if status != pywraplp.Solver.OPTIMAL:
            print("Warning: No optimal solution found for a given y.")
            return np.zeros(self.n_items)
        return np.array([self.x[i].solution_value() for i in range(self.n_items)],
                        dtype=np.float64)


class KnapsackDataModule:
    """
    Handles data loading, preprocessing, and dataset creation using tf.data.
    This class is the TensorFlow equivalent of the PyTorch Lightning DataModule.

    The instances from the file's train and test splits are combined, shuffled
    with ``seed`` and re-split into ``n_train`` training instances, ``n_valid``
    validation instances and all remaining instances as the test set. If the
    file holds fewer than ``n_train + n_valid`` instances, the later splits are
    simply shorter or empty (plain slicing).

    Note: when ``standardize`` is True the scaler is fit on the file's original
    training rows *before* the re-split, so rows the scaler was fit on can end
    up in the validation/test sets.
    """

    def __init__(self, capacity, standardize=True, batch_size=16, seed=0,
                 n_train=550, n_valid=100):
        """
        Args:
            capacity: Knapsack capacity used to compute the optimal training solutions.
            standardize (bool): Standardize features with ``StandardScaler``.
            batch_size (int): Batch size of every dataset returned by this module.
            seed (int): ``random_state`` of the instance shuffle done before splitting.
            n_train (int): Number of training instances (default 550).
            n_valid (int): Number of validation instances (default 100).
        """
        self.capacity = capacity
        self.standardize = standardize
        self.batch_size = batch_size
        self.seed = seed
        self.n_train = n_train
        self.n_valid = n_valid
        self._prepare_data()

    def _prepare_data(self):
        """Load ``Data.npz``, preprocess it and build the train/valid/test splits."""
        # A Data.npz file is required here.
        # You can create a dummy one with:
        # np.savez('Data.npz', weights=np.random.randint(10, 30, 48),
        #          X_1gtrain=np.random.rand(28800, 6), y_train=np.random.rand(28800),
        #          X_1gtest=np.random.rand(9600, 6), y_test=np.random.rand(9600))
        # Resolve path to the shared datasets directory in additional_experiments.
        data_path = Path(__file__).resolve().parents[1] / 'datasets' / 'Data.npz'
        # np.load on an .npz returns an NpzFile that keeps the file open; the
        # context manager releases it once the arrays have been read.
        with np.load(data_path) as data:
            weights = np.array(data['weights'])
            x_train, x_test = data['X_1gtrain'], data['X_1gtest']
            y_train, y_test = data['y_train'], data['y_test']

        self.weights = weights
        self.n_items = len(weights)

        # Slice off the first feature column as in the original code
        x_train = x_train[:, 1:]
        x_test = x_test[:, 1:]

        self.scaler = None
        if self.standardize:
            self.scaler = StandardScaler()
            x_train = self.scaler.fit_transform(x_train)
            x_test = self.scaler.transform(x_test)

        # Reshape flat item rows into instances of n_items items each
        n_features = x_train.shape[1]
        x_train = x_train.reshape(-1, self.n_items, n_features)
        y_train = y_train.reshape(-1, self.n_items)
        x_test = x_test.reshape(-1, self.n_items, n_features)
        y_test = y_test.reshape(-1, self.n_items)

        # Combine, shuffle at instance level, and split
        x_full = np.concatenate((x_train, x_test), axis=0)
        y_full = np.concatenate((y_train, y_test), axis=0)
        x_shuffled, y_shuffled = sk_shuffle(x_full, y_full, random_state=self.seed)

        train_end = self.n_train
        valid_end = self.n_train + self.n_valid
        self.x_train, self.y_train = x_shuffled[:train_end], y_shuffled[:train_end]
        self.x_valid, self.y_valid = x_shuffled[train_end:valid_end], y_shuffled[train_end:valid_end]
        self.x_test, self.y_test = x_shuffled[valid_end:], y_shuffled[valid_end:]

        # Optimal decisions for the training set (used as supervision targets)
        solver = knapsack_solver(self.weights, capacity=self.capacity, n_items=self.n_items)
        self.train_solutions = np.array([solver.solve(y) for y in self.y_train])

    def _create_dataset(self, X, y, sol=None, shuffle=False):
        """Build a batched, prefetched ``tf.data.Dataset`` from numpy arrays.

        Args:
            X, y: Feature and target arrays (cast to float32).
            sol: Optional solution array; when given, elements are ``(X, y, sol)``
                triples, otherwise ``(X, y)`` pairs.
            shuffle (bool): Shuffle instances (full-size buffer) before batching.
        """
        tensors = (X.astype(np.float32), y.astype(np.float32))
        if sol is not None:
            tensors += (sol.astype(np.float32),)
        dataset = tf.data.Dataset.from_tensor_slices(tensors)
        if shuffle:
            dataset = dataset.shuffle(buffer_size=len(X))
        return dataset.batch(self.batch_size).prefetch(tf.data.AUTOTUNE)

    def get_train_dataset(self):
        """Shuffled ``(x, y, optimal_solution)`` training batches."""
        return self._create_dataset(self.x_train, self.y_train, self.train_solutions, shuffle=True)

    def get_val_dataset(self):
        """Unshuffled ``(x, y)`` validation batches."""
        return self._create_dataset(self.x_valid, self.y_valid)

    def get_test_dataset(self):
        """Unshuffled ``(x, y)`` test batches."""
        return self._create_dataset(self.x_test, self.y_test)


def create_model(n_items, n_features, *, seed=42):
    """
    Build a per-item linear scoring model for knapsack value prediction.

    A single ``Dense(1)`` layer, wrapped in ``TimeDistributed``, maps every
    item's feature vector to one scalar, so all items share the same weights.

    Args:
        n_items (int): Items per knapsack instance.
        n_features (int): Features per item.
        seed (int): Seed for the kernel initializer; the bias initializer
            uses ``seed + 1``.

    Returns:
        tf.keras.Model: Maps inputs of shape ``(batch, n_items, n_features)``
        to predicted item values of shape ``(batch, n_items)``.
    """
    item_scorer = layers.Dense(
        units=1,
        kernel_initializer=tf.keras.initializers.GlorotNormal(seed=seed),
        bias_initializer=tf.keras.initializers.GlorotUniform(seed=seed + 1),
    )

    features = tf.keras.Input(shape=(n_items, n_features), dtype=tf.float32)
    per_item_scores = layers.TimeDistributed(item_scorer)(features)
    # Drop the trailing singleton axis: (batch, n_items, 1) -> (batch, n_items).
    values = layers.Lambda(lambda t: tf.squeeze(t, axis=-1))(per_item_scores)

    return tf.keras.Model(inputs=features, outputs=values, name="knapsack_model")


def find_nearby_feasible_solutions(sol_pred, weights, capacity, k, rng=None):
    """
    Find feasible neighbours of a binary solution by flipping one bit at a time.

    Not part of the TensorFlow graph; operates on numpy arrays.

    Args:
        sol_pred (np.ndarray): Binary solution vector (e.g. from the knapsack
            solver). Entries are snapped to the nearest integer first, so solver
            output such as ``0.9999999`` or ``1e-9`` is treated as 1 or 0.
        weights (np.ndarray): Item weights.
        capacity (float): Knapsack capacity.
        k (int): Maximum number of neighbours to return.
        rng (np.random.Generator | np.random.RandomState | None): Source of the
            random visiting order. ``None`` (default) uses the global
            ``np.random`` state.

    Returns:
        list[np.ndarray]: Up to ``k`` vectors (same dtype as ``sol_pred``), each
        differing from the snapped solution in exactly one position and within
        capacity.
    """
    arr = np.asarray(sol_pred)
    # Snap to {0, 1}: an exact ``== 1`` test on raw solver floats would send a
    # near-zero bit down the "remove item" path and emit infeasible,
    # non-binary neighbours.
    base = np.rint(arr).astype(arr.dtype, copy=False)
    w = np.asarray(weights)
    current_weight = np.dot(base, w)

    # Visit positions in random order so different calls yield varied neighbours.
    indices = np.arange(len(base))
    if rng is None:
        np.random.shuffle(indices)
    else:
        rng.shuffle(indices)

    neighbours = []
    for i in indices:
        if len(neighbours) >= k:
            break
        if base[i] == 0:
            # Adding item i is feasible only if it still fits.
            if current_weight + w[i] > capacity:
                continue
            new_bit = 1
        else:
            # Removing an item cannot exceed capacity (for nonnegative weights).
            new_bit = 0
        neighbour = base.copy()
        neighbour[i] = new_bit
        neighbours.append(neighbour)

    return neighbours


def make_spo_plus_loss_knapsack(solver, minimise=False):
    """Build a TensorFlow SPO+ loss for a single knapsack instance.

    Mirrors the canonical SPOlayer in
    dfl_tests_LO/experiments/knapsack_experiment/knapsack_module.py:

    * forward value: ``mm * (sol_hat - sol_true) . y_true``, where ``sol_hat``
      is the solver's decision for ``y_hat`` (a regret-like quantity);
    * gradient w.r.t. ``y_hat``: ``mm * (sol_true - sol_spo)``, where
      ``sol_spo`` is the solver's decision for ``2 * y_hat - y_true``.

    Here ``mm`` is +1 when ``minimise`` is true and -1 otherwise.

    Args:
        solver: Object exposing ``solve(costs) -> np.ndarray``.
        minimise (bool): Optimisation sense of the underlying problem.

    Returns:
        Callable ``(y_hat, y_true, sol_true) -> scalar float32 tensor`` with a
        custom gradient; no gradient is propagated to ``y_true`` or ``sol_true``.
    """
    sense = -1.0 if not minimise else 1.0

    def _forward_numpy(y_hat_np, y_true_np, sol_true_np):
        # Runs outside the graph: two solver calls (predicted costs, then the
        # SPO+ perturbed costs), results returned as float32.
        pred_costs = np.asarray(y_hat_np, dtype=np.float64)
        true_costs = np.asarray(y_true_np, dtype=np.float64)
        true_sol = np.asarray(sol_true_np, dtype=np.float64)

        pred_sol = solver.solve(pred_costs)
        loss_value = sense * float(np.dot(pred_sol - true_sol, true_costs))

        spo_sol = solver.solve(2.0 * pred_costs - true_costs)

        return np.float32(loss_value), true_sol.astype(np.float32), spo_sol.astype(np.float32)

    @tf.custom_gradient
    def spo_plus_loss_knapsack(y_hat, y_true, sol_true):
        loss, sol_true_out, sol_spo = tf.numpy_function(
            func=_forward_numpy,
            inp=[y_hat, y_true, sol_true],
            Tout=[tf.float32, tf.float32, tf.float32],
        )
        # numpy_function drops static shapes; restore them.
        loss.set_shape(())
        sol_true_out.set_shape(y_hat.shape)
        sol_spo.set_shape(y_hat.shape)

        def grad(upstream):
            return upstream * sense * (sol_true_out - sol_spo), None, None

        return loss, grad

    return spo_plus_loss_knapsack


def _spo_val_true_costs(val_dataset, solver):
    """Optimal objective value of each instance in ``val_dataset``, in iteration order."""
    return [
        np.dot(y_true_np, solver.solve(y_true_np))
        for _, yv in val_dataset
        for y_true_np in yv.numpy()
    ]


def _spo_mean_val_regret(model, val_dataset, solver, true_costs):
    """Mean per-instance relative regret (in %) of ``model`` on ``val_dataset``.

    ``true_costs`` must list the optimal objective of each instance in the same
    order ``val_dataset`` yields them (see ``_spo_val_true_costs``).
    Returns 0.0 when the dataset is empty.
    """
    regrets = []
    cost_iter = iter(true_costs)
    for xv, yv in val_dataset:
        y_pred_val = model(xv, training=False).numpy()
        for y_hat_np, y_true_np in zip(y_pred_val, yv.numpy()):
            pred_cost = np.dot(y_true_np, solver.solve(y_hat_np))
            true_cost = next(cost_iter)
            regrets.append((true_cost - pred_cost) * 100 / true_cost)
    return float(np.mean(regrets)) if regrets else 0.0


def train_spoplus_knapsack(data, solver, epochs, learning_rate=1e-3, clip_norm=1.0, seed=123):
    """Train a fresh model with the SPO+ loss and record validation regret.

    Args:
        data (KnapsackDataModule): Provides ``n_items``, ``x_train``,
            ``get_train_dataset()`` and ``get_val_dataset()``.
        solver (knapsack_solver): Solver used by the SPO+ loss and for regret.
        epochs (int): Number of training epochs.
        learning_rate (float): Constant Adam learning rate.
        clip_norm (float): Global-norm threshold for gradient clipping.
        seed (int): Passed to ``reset_random_state`` and to ``create_model``.

    Returns:
        tuple: ``(model, val_regret_history)``, where ``val_regret_history[e]``
        is the mean per-instance relative regret (%) on the validation set after
        epoch ``e + 1``.
    """
    reset_random_state(seed)
    n_items = data.n_items
    n_features = data.x_train.shape[-1]
    model = create_model(n_items, n_features, seed=seed)
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    spo_plus_loss_fn = make_spo_plus_loss_knapsack(solver, minimise=False)

    # The validation set is fixed and yielded in the same (unshuffled) order
    # every time, so its optimal objectives only need to be solved once.
    val_true_costs = _spo_val_true_costs(data.get_val_dataset(), solver)
    val_regret_history = []

    for epoch in range(1, epochs + 1):
        # A fresh dataset per epoch gives a new shuffle order.
        for x, y, sol in data.get_train_dataset():
            with tf.GradientTape() as tape:
                y_hat_batch = model(x, training=True)
                y_true_batch = tf.cast(y, tf.float32)
                sol_true_batch = tf.cast(sol, tf.float32)

                # The SPO+ loss works on one instance at a time (one solver call each).
                batch_losses = [
                    spo_plus_loss_fn(y_hat, y_true, sol_true)
                    for y_hat, y_true, sol_true in zip(
                        tf.unstack(y_hat_batch),
                        tf.unstack(y_true_batch),
                        tf.unstack(sol_true_batch),
                    )
                ]
                if batch_losses:
                    loss = tf.reduce_mean(batch_losses)
                else:
                    loss = tf.constant(0.0, dtype=tf.float32)

            grads = tape.gradient(loss, model.trainable_variables)
            if grads:
                grads, _ = tf.clip_by_global_norm(grads, clip_norm)
                optimizer.apply_gradients(zip(grads, model.trainable_variables))

        epoch_val_regret = _spo_mean_val_regret(
            model, data.get_val_dataset(), solver, val_true_costs
        )
        val_regret_history.append(epoch_val_regret)
        print(
            f"[SPO+] Epoch {epoch}/{epochs} | mean_val_regret = {epoch_val_regret:.4f}"
        )

    return model, val_regret_history

if __name__ == "__main__":
    seed = 42
    np.random.seed(seed)
    tf.random.set_seed(seed)
    capacity_knapsack = 60
    n_samples = 10        # max number of one-bit-flip neighbours used as negatives
    epochs = 50
    tau = 0.3             # relative-utility tolerance for labelling a prediction +1
    lambda_value = 0.2    # loss mix: (1 - lambda) * PL loss + lambda * hinge loss
    learning_rate = 1e-3
    batch_size = 8
    decay_steps = 28*69
    decay_rate = 0.5

    # Load the data before building the model so the model's input shape comes
    # from the data instead of hard-coded values (48 items, 8 features).
    data = KnapsackDataModule(capacity=capacity_knapsack, batch_size=batch_size)
    n_items = data.n_items
    n_features = data.x_train.shape[-1]

    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        learning_rate,
        decay_steps=decay_steps,
        decay_rate=decay_rate,
        staircase=True)
    model = create_model(n_items, n_features, seed=seed)
    optimizer = tf.keras.optimizers.Adam(lr_schedule)

    train_df = data.get_train_dataset()
    # Instantiate the solver once outside the loop for efficiency
    solver = knapsack_solver(data.weights, capacity_knapsack, n_items)

    # The validation set is fixed and yielded in the same (unshuffled) order
    # every epoch, so the optimal objective of each instance is solved once here.
    val_true_costs = [
        np.dot(y_true_np, solver.solve(y_true_np))
        for _, yv in data.get_val_dataset()
        for y_true_np in yv.numpy()
    ]

    epoch_train_history = []
    epoch_val_regret_history = []

    for epoch in range(1, epochs + 1):
        train_losses = []
        print(f"\n===== Epoch {epoch}/{epochs} =====")
        for x, y, sol in train_df:
            y_np = y.numpy()
            sol_np = sol.numpy()
            with tf.GradientTape() as tape:
                batch_loss = tf.constant(0.0, dtype=tf.float32)
                y_pred = model(x, training=True)
                with tape.stop_recording():
                    # Solver calls, neighbour search and labels are non-differentiable.
                    y_pred_np = y_pred.numpy()

                    batch_sols_pred = []
                    batch_negatives = []
                    batch_labels = []
                    for i in range(y_pred_np.shape[0]):
                        sol_pred = solver.solve(y_pred_np[i])
                        negatives = find_nearby_feasible_solutions(sol_pred, data.weights, capacity_knapsack, n_samples)
                        utility_true_val = float(np.dot(y_np[i], sol_np[i]))
                        utility_pred_val = float(np.dot(y_np[i], sol_pred))
                        # +1 if the predicted decision's true utility is within
                        # tau (relative) of the optimum, otherwise -1.
                        label_val = 1.0 if abs(utility_pred_val - utility_true_val) <= tau * (
                                    abs(utility_true_val) + 1e-12) else -1.0
                        batch_sols_pred.append(sol_pred)
                        batch_negatives.append(negatives)
                        batch_labels.append(label_val)

                for i, (sol_pred, negatives, label_val) in enumerate(
                        zip(batch_sols_pred, batch_negatives, batch_labels)):
                    if not negatives:
                        # No feasible one-bit neighbour: there is nothing to rank
                        # the prediction against, and tf.constant([]) would not
                        # have the (n_negatives, n_items) shape.
                        continue
                    x_hat_tf = tf.constant(sol_pred, dtype=tf.float32)
                    q_negatives = tf.constant(negatives, dtype=tf.float32)
                    labels = tf.constant(label_val, dtype=tf.float32)
                    pl_loss = mpll_pref(y_pred[i], x_hat_tf, labels, q_negatives, maximise=True)
                    hinge_loss = hinge_loss_pref(y_pred[i], x_hat_tf, labels, q_negatives)
                    batch_loss += (1 - lambda_value) * pl_loss + lambda_value * hinge_loss
            trainable_vars = model.trainable_variables
            grads = tape.gradient(batch_loss, trainable_vars)
            grads_and_vars = [(g, v) for g, v in zip(grads, trainable_vars) if g is not None]
            if grads_and_vars:
                optimizer.apply_gradients(grads_and_vars)
            train_losses.append(float(batch_loss.numpy()))

        # Validation pass after epoch: mean over *instances* (not over batch
        # means), matching train_spoplus_knapsack so both plotted curves are
        # aggregated the same way despite a partial last batch.
        val_regrets = []
        val_cost_iter = iter(val_true_costs)
        for xv, yv in data.get_val_dataset():
            y_pred_val = model(xv, training=False).numpy()
            for y_hat_np, y_true_np in zip(y_pred_val, yv.numpy()):
                sol_pred = solver.solve(y_hat_np)
                pred_cost = np.dot(y_true_np, sol_pred)
                true_cost = next(val_cost_iter)
                val_regrets.append((true_cost - pred_cost)*100/true_cost)

        epoch_train_loss = float(np.mean(train_losses)) if train_losses else 0.0
        epoch_val_regret = float(np.mean(val_regrets)) if val_regrets else 0.0
        epoch_train_history.append(epoch_train_loss)
        epoch_val_regret_history.append(epoch_val_regret)
        print(
            f"Epoch {epoch} summary | mean_train_loss = {epoch_train_loss:.4f} | "
            f"mean_val_regret = {epoch_val_regret:.4f} | batches = {len(train_losses)}"
        )

    spo_learning_rate = learning_rate
    spo_clip_norm = 1.0

    _, spo_val_regret_history = train_spoplus_knapsack(
        data=data,
        solver=solver,
        epochs=epochs,
        learning_rate=spo_learning_rate,
        clip_norm=spo_clip_norm,
        seed=seed + 1,
    )

    if epoch_val_regret_history and spo_val_regret_history:
        epochs_range = np.arange(1, len(epoch_val_regret_history) + 1)
        fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4))
        ax2.plot(epochs_range, epoch_val_regret_history, marker='o', label='Weak DFL (pref)')
        ax2.plot(epochs_range, spo_val_regret_history, marker='x', label='SPO+')
        ax2.set_title('Knapsack: Mean Validation Regret per Epoch')
        ax2.set_xlabel('Epoch')
        # Regrets are computed as (true - pred) * 100 / true, i.e. in percent.
        ax2.set_ylabel('Relative Regret (%)')
        ax2.legend()
        fig2.tight_layout()
        fig2.savefig('training_metrics_weak_vs_spoplus.png')
        plt.close(fig2)
