import numpy as np
from scipy.optimize import linprog
import time
import matplotlib.pyplot as plt
import os
import pulp

from solve_afiro_lp import solve_with_scipy


def _solve_lp(c, A_ub=None, b_ub=None, A_eq=None, b_eq=None, bounds=None, new_objective=None, verbose=True):
    """
    Solve a linear program with SciPy's HiGHS backend.

    By default the problem is treated as a MAXIMIZATION of ``c @ x``. ``linprog``
    only minimizes, so ``-c`` is handed to it; the returned ``result.fun`` is
    therefore ``-max(c @ x)``.

    If ``new_objective`` is given it is passed to ``linprog`` as-is, i.e. it must
    already be written as a MINIMIZATION objective, and ``c`` is ignored.

    Parameters
    ----------
    c : array_like
        Objective coefficients of the maximization problem.
    A_ub, b_ub, A_eq, b_eq, bounds :
        Forwarded unchanged to ``scipy.optimize.linprog``.
    new_objective : array_like, optional
        Minimization objective that replaces ``-c``.
    verbose : bool
        Print a progress line when solving the default (maximization) objective.

    Returns
    -------
    scipy.optimize.OptimizeResult
        The raw ``linprog`` result.
    """
    if new_objective is not None:
        # Caller already supplies a minimization objective (used for jitter re-solves).
        objective = np.asarray(new_objective, dtype=float)
    else:
        # Maximize c @ x  <=>  minimize (-c) @ x.
        objective = -np.asarray(c, dtype=float)
        if verbose:
            print("Solving for OPTIMAL solution (Maximization)...")

    return linprog(objective, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bounds, method='highs')


def _find_nearby_vertices(optimal_solution, c, A_ub, b_ub, A_eq, b_eq, bounds,
                          num_samples: int = 12,
                          eps_scales = (1e-6, 1e-5, 1e-4),
                          min_obj_delta: float = 1e-6,
                          tolerance: float = 1e-6, seed: int = 42, verbose: bool = True):
    """
    Find alternative vertices by randomly jittering the objective and re-solving.

    For each of ``num_samples`` attempts, a Gaussian perturbation scaled by the next
    entry of ``eps_scales`` (cycled round-robin) is added to ``c``. The result is
    re-solved through ``_solve_lp`` as a minimization objective. Solutions that
    differ from ``optimal_solution.x`` (``atol=tolerance``) are collected and then
    de-duplicated.

    Parameters
    ----------
    optimal_solution :
        Object exposing the reference point as ``.x``.
    c : array_like
        Base objective coefficients that get jittered.
    A_ub, b_ub, A_eq, b_eq, bounds :
        Constraints forwarded to every re-solve.
    num_samples : int
        Total number of jittered solves to attempt.
    eps_scales : iterable of float
        Perturbation magnitudes, used in order and repeated as needed. If empty,
        no solves are attempted.
    min_obj_delta : float
        Kept for API compatibility. Candidates are retained whether or not their
        objective differs from the optimum's, so degenerate (same-objective)
        vertices are included.
    tolerance : float
        Absolute tolerance for "same point" comparisons.
    seed : int
        Seed for ``numpy.random.default_rng``, so results are reproducible.
    verbose : bool
        Print progress and summary lines.

    Returns
    -------
    list of numpy.ndarray
        Unique alternative points, in discovery order.
    """
    if verbose:
        print("\n--- Finding Nearby Vertices via Objective Jitter ---")

    x_optimal = optimal_solution.x
    base_objective = np.array(c, dtype=float)
    n = len(c)
    rng = np.random.default_rng(seed)

    # Materialise once: a generator would be exhausted after the first pass, and an
    # empty sequence previously made the sampling loop spin forever.
    scales = tuple(eps_scales)
    sample_count = num_samples if scales else 0

    # NOTE(review): `base_objective + delta` is handed to linprog as a MINIMIZATION
    # objective, so this searches near argmin(c @ x). That matches `optimal_solution`
    # only if it was obtained by minimizing `c`; under this module's stated
    # maximization convention it would need to be `-c + delta`. Confirm against the
    # solver that produced `optimal_solution`.
    candidates = []
    for sample_idx in range(sample_count):
        scale = scales[sample_idx % len(scales)]
        delta = scale * rng.normal(size=n)
        res = _solve_lp(c, A_ub, b_ub, A_eq, b_eq, bounds,
                        new_objective=base_objective + delta, verbose=False)
        if res.success and not np.allclose(res.x, x_optimal, atol=tolerance):
            # Kept regardless of objective gap: same-objective points expose degeneracy.
            candidates.append(res.x)

    # Deduplicate while preserving discovery order.
    unique_vertices = []
    for v in candidates:
        if not any(np.allclose(v, u, atol=tolerance) for u in unique_vertices):
            unique_vertices.append(v)

    if verbose:
        if unique_vertices:
            print(f"Found {len(unique_vertices)} unique nearby vertices.")
        else:
            print("Could not find any distinct nearby vertices (increase eps_scales or num_samples).")
    return unique_vertices


def analyze_lp(c, A_ub=None, b_ub=None, A_eq=None, b_eq=None, bounds=None, verbose=True,
               nearby_samples: int = 10, seed: int = 42,
               nearby_eps_scales = (1e-2, 1e-1, 1., 1e1, 1e2), nearby_min_obj_delta: float = 1e1):
    """
    Solve an LP and collect distinct alternative vertices found by objective jitter.

    Assumes the problem is posed for MAXIMIZATION of ``c @ x``.

    Parameters
    ----------
    c, A_ub, b_ub, A_eq, b_eq, bounds :
        Problem data in ``scipy.optimize.linprog`` layout. If ``bounds`` is None and
        there are no inequality rows, every variable defaults to ``(0, None)``.
    verbose : bool
        Print progress and timing information.
    nearby_samples, seed, nearby_eps_scales, nearby_min_obj_delta :
        Forwarded to ``_find_nearby_vertices`` as ``num_samples``, ``seed``,
        ``eps_scales`` and ``min_obj_delta``.

    Returns
    -------
    tuple
        ``(x_opt, alternatives)``. ``x_opt`` is the optimal point and
        ``alternatives`` is a list of points not ``np.allclose`` to ``x_opt`` or to
        each other. Returns ``(None, None)`` if the solve is unsuccessful.

    Raises
    ------
    ValueError
        If ``bounds`` is None while ``A_ub`` is given.
    """
    start_time = time.perf_counter()

    if bounds is None:
        if A_ub is None:
            bounds = [(0, None)] * len(c)
        else:
            raise ValueError("Bounds must be provided for problems without explicit non-negativity constraints.")

    # NOTE(review): A_ub/b_ub are not forwarded here (only c, A_eq, b_eq, bounds),
    # while _find_nearby_vertices does enforce them. With inequality rows present the
    # reported optimum may ignore them; confirm solve_with_scipy's contract.
    optimal_result = solve_with_scipy(c, A_eq, b_eq, bounds)

    if not optimal_result.success:
        if verbose:
            print(
                f"Could not find an optimal solution. Solver status: {optimal_result.status} - {optimal_result.message}")
        return None, None

    near_optimal_solutions = _find_nearby_vertices(
        optimal_result, c, A_ub, b_ub, A_eq, b_eq, bounds,
        num_samples=nearby_samples, eps_scales=nearby_eps_scales,
        min_obj_delta=nearby_min_obj_delta, seed=seed, verbose=verbose
    )

    # Second filtering pass with np.allclose's default (looser, relative) tolerances.
    unique_a = []
    for arr in near_optimal_solutions:
        if np.allclose(arr, optimal_result.x):
            continue
        if any(np.allclose(arr, kept) for kept in unique_a):
            continue
        unique_a.append(arr)

    if verbose:
        print(f"\nAnalysis complete in {time.perf_counter() - start_time:.4f} seconds.")

    return optimal_result.x, unique_a

def load_mps_benchmark(mps_path: str):
    """
    Load an LP from an MPS file (e.g. Netlib AFIRO) with PuLP and convert it to
    SciPy ``linprog`` matrix form.

    The objective is returned as MAXIMIZATION coefficients ``c``. For a problem
    PuLP reports as minimization (``min f @ x``) this is ``c = -f``. For one it
    reports as maximization it is ``c = f``.

    Parameters
    ----------
    mps_path : str
        Path to the MPS file.

    Returns
    -------
    tuple
        ``(c, A_ub, b_ub, A_eq, b_eq, bounds)``. The constraint arrays are ``None``
        when the problem has no rows of that kind. ``>=`` rows are negated into
        ``<=`` form. Variables are ordered by name. ``bounds`` holds one
        ``(lowBound, upBound)`` pair per variable, with ``None`` meaning no bound
        on that side.

    Raises
    ------
    TypeError
        If no ``LpProblem`` can be found in what ``fromMPS`` returns.
    """
    print(f"--- Loading standard LP from MPS: {mps_path} ---")

    # Parse MPS (PuLP versions may return LpProblem, a tuple/list, or a dict)
    parsed = pulp.LpProblem.fromMPS(mps_path, sense=pulp.LpMinimize)
    if hasattr(parsed, 'variables'):
        items = [parsed]
    elif isinstance(parsed, (list, tuple)):
        items = list(parsed)
    elif isinstance(parsed, dict):
        items = list(parsed.values())
    else:
        items = []
    prob = next((item for item in items if hasattr(item, 'variables')), None)
    if prob is None:
        raise TypeError(f"Unable to parse MPS file into a PuLP LpProblem. Got type: {type(parsed)}")

    # Variables ordered by name for consistency
    variables = sorted(prob.variables(), key=lambda v: v.name)
    var_map = {v.name: i for i, v in enumerate(variables)}
    num_vars = len(variables)

    # Objective as a dense vector; a problem without an objective gives all zeros.
    obj_vec = np.zeros(num_vars, dtype=float)
    if prob.objective is not None:
        for var, coeff in prob.objective.items():
            obj_vec[var_map[var.name]] = float(coeff)

    # Convert to maximization coefficients. Respect the sense stored on the problem
    # rather than assuming minimization, so a maximization MPS keeps its sign.
    is_maximize = getattr(prob, "sense", pulp.LpMinimize) == pulp.LpMaximize
    c = obj_vec.copy() if is_maximize else -obj_vec

    # Extract constraints
    A_ub_rows, b_ub_vals = [], []
    A_eq_rows, b_eq_vals = [], []

    for const in prob.constraints.values():
        row = np.zeros(num_vars, dtype=float)
        for var, coeff in const.items():
            row[var_map[var.name]] = float(coeff)

        # PuLP stores constraints in canonical form: expr + constant sense 0
        # For a constraint like a^T x <= b, PuLP stores constant = -b.
        rhs = -float(const.constant)

        if const.sense == pulp.LpConstraintLE:  # <=
            A_ub_rows.append(row)
            b_ub_vals.append(rhs)
        elif const.sense == pulp.LpConstraintGE:  # >=  => multiply by -1 to convert to <=
            A_ub_rows.append(-row)
            b_ub_vals.append(-rhs)
        elif const.sense == pulp.LpConstraintEQ:  # ==
            A_eq_rows.append(row)
            b_eq_vals.append(rhs)

    A_ub = np.array(A_ub_rows, dtype=float) if A_ub_rows else None
    b_ub = np.array(b_ub_vals, dtype=float) if b_ub_vals else None
    A_eq = np.array(A_eq_rows, dtype=float) if A_eq_rows else None
    b_eq = np.array(b_eq_vals, dtype=float) if b_eq_vals else None

    # Variable bounds (PuLP's None = unbounded, same meaning as in linprog)
    bounds = [(v.lowBound, v.upBound) for v in variables]

    print("Benchmark MPS problem loaded and framed for MAXIMIZATION.\n")
    return c, A_ub, b_ub, A_eq, b_eq, bounds


def visualize_results(optimal_result, adjacent_vertices, c):
    """
    Plot the optimal solution's composition and compare objective values.

    The left panel shows every variable of the optimal point whose value exceeds
    1e-5. The right panel shows ``c @ x`` for the optimal point and for each
    alternative vertex. The call blocks in ``plt.show()`` until the window is closed.

    Parameters
    ----------
    optimal_result : OptimizeResult-like or array_like
        Object exposing the optimal point as ``.x`` (e.g.
        ``scipy.optimize.OptimizeResult``), or the optimal point itself.
    adjacent_vertices : iterable of array_like or None
        Alternative points to compare against; ``None`` is treated as empty.
    c : array_like
        Objective coefficients (maximization form) used to score every point.
    """
    print("\n--- Generating Visualizations ---")
    x_optimal = np.asarray(getattr(optimal_result, "x", optimal_result), dtype=float)
    c_vec = np.asarray(c, dtype=float)
    others = [] if adjacent_vertices is None else list(adjacent_vertices)
    # Score every bar with the same formula. The sign of optimal_result.fun depends
    # on how the solver was called, so -fun is not reliably comparable to c @ v.
    optimal_value = float(c_vec @ x_optimal)

    try:
        plt.style.use('seaborn-v0_8-whitegrid')
    except OSError:
        # Older Matplotlib releases lack this style name; the default style is
        # acceptable, so styling is deliberately best-effort.
        pass
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
    fig.suptitle("LP Optimization Analysis (Standard MPS Benchmark)", fontsize=18)

    # Plot 1: Composition of the Optimal Solution
    solution_indices = np.where(x_optimal > 1e-5)[0]
    solution_quantities = x_optimal[solution_indices]
    labels = [f"x[{i}]" for i in solution_indices]

    ax1.bar(labels, solution_quantities, color='skyblue', edgecolor='black', linewidth=0.7)
    ax1.set_title(f"Optimal Solution Composition ({len(labels)} Non-Zero Variables)", fontsize=14)
    ax1.set_xlabel("Decision Variable", fontsize=12)
    ax1.set_ylabel("Value", fontsize=12)
    ax1.tick_params(axis='x', rotation=90)

    # Plot 2: Value Comparison
    values = [optimal_value] + [float(c_vec @ np.asarray(v, dtype=float)) for v in others]
    value_labels = ['Optimal'] + [f'Adj. {i + 1}' for i in range(len(others))]
    colors = ['cornflowerblue'] + ['lightcoral'] * len(others)
    bars = ax2.bar(value_labels, values, color=colors, edgecolor='black', linewidth=0.7)

    ax2.set_title("Objective Value Comparison: Optimal vs. Adjacent", fontsize=14)
    ax2.set_ylabel("Objective Value (to be maximized)", fontsize=12)
    ax2.bar_label(bars, fmt='%.4f', padding=3, fontsize=10)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    print("Displaying plots. Close the plot window to exit the script.")
    plt.show()


def main():
    """
    Demonstrate analyze_lp on the AFIRO benchmark.

    Locates ``afiro.mps`` in one of the known dataset directories, analyzes it,
    prints a summary and shows the plots.

    Raises
    ------
    FileNotFoundError
        If ``afiro.mps`` is not found in any candidate directory.
    """
    from scipy.optimize import OptimizeResult

    # 1. Load the benchmark LP problem from an MPS file (e.g., AFIRO)
    candidate_paths = [
        "./datasets/afiro.mps",                 # project-level datasets directory
        "dfl_tests_LO/datasets/afiro.mps",      # module-local datasets directory
    ]
    mps_path = next((p for p in candidate_paths if os.path.exists(p)), None)
    if mps_path is None:
        raise FileNotFoundError("Could not locate afiro.mps. Checked: " + ", ".join(candidate_paths))
    c, A_ub, b_ub, A_eq, b_eq, bounds = load_mps_benchmark(mps_path)

    # 2. Analyze the problem. analyze_lp returns (x_opt, alternatives) on success
    #    and (None, None) on failure -- not a solver result object.
    x_opt, near_optimal_solutions = analyze_lp(c, A_ub, b_ub, A_eq, b_eq, bounds)

    if x_opt is None:
        print("\n--- Optimization Failed ---")
        print("analyze_lp returned no solution; see the solver status reported above.")
        return

    optimal_value = float(np.dot(c, x_opt))  # c is in maximization form
    num_vars_in_solution = int(np.sum(x_opt > 1e-5))
    print("\n--- Optimal Solution Summary ---")
    print(f"Optimal Objective Value: {optimal_value:.4f}")
    print(f"Number of non-zero variables in solution: {num_vars_in_solution} out of {len(c)}")

    # visualize_results reads .x and .fun; fun = -(c @ x) so that -fun is the maximized value.
    optimal_result = OptimizeResult(x=x_opt, fun=-optimal_value, success=True)
    visualize_results(optimal_result, near_optimal_solutions, c)


if __name__ == "__main__":
    # Run the AFIRO demo only when executed as a script, not when imported.
    main()

