"""
Solve the NETLIB AFIRO linear program, loaded from a MAT file, with
off-the-shelf solvers (scipy.optimize.linprog and, if installed, CVXPY).

Problem format: minimize c^T x subject to A x = b, lo <= x <= hi
"""

import numpy as np
import scipy.io
from scipy.optimize import linprog
import time
from pathlib import Path


def _mat_string_field(struct, field, default='Unknown'):
    """
    Read an optional string field from a struct loaded by ``scipy.io.loadmat``.

    loadmat wraps a MATLAB char array in a 1-element array, so the string is
    element ``[0]``. A field that is absent from the struct, or present but
    empty, yields ``default`` instead of raising.
    """
    if field not in (struct.dtype.names or ()):
        return default
    value = struct[field]
    return value[0] if value.size > 0 else default


def load_lp_problem(mat_file_path):
    """
    Load an LP problem from a MAT file containing a ``Problem`` struct.

    The struct must provide ``A`` and ``b`` plus an ``aux`` sub-struct with
    ``c``, ``lo`` and ``hi``. The metadata fields ``name``, ``kind`` and
    ``title`` are optional and are reported as 'Unknown' when missing or empty.

    Args:
        mat_file_path: path (str or Path) of the MAT file to read

    Returns:
        c: cost vector (n,)
        A_eq: equality constraint matrix (m, n), as stored in the file
            (dense ndarray or scipy.sparse matrix)
        b_eq: equality constraint vector (m,)
        bounds: list of (lo, hi) tuples for each variable
        problem_info: dict with problem metadata

    Raises:
        ValueError: if ``lo`` or ``hi`` does not have one entry per variable.
    """
    data = scipy.io.loadmat(mat_file_path)
    prob = data['Problem'][0, 0]
    aux = prob['aux'][0, 0]

    # Extract problem data; loadmat returns vectors as 2-D, so flatten them.
    A_eq = prob['A']  # (m, n)
    b_eq = prob['b'].flatten()  # (m,)
    c = aux['c'].flatten()  # (n,)
    lo = aux['lo'].flatten()  # (n,)
    hi = aux['hi'].flatten()  # (n,)

    # One (lo, hi) pair per variable; a length mismatch would otherwise be
    # silently truncated or surface as an obscure IndexError.
    if lo.shape != c.shape or hi.shape != c.shape:
        raise ValueError(
            f"bound vectors lo{lo.shape}/hi{hi.shape} do not match "
            f"cost vector c{c.shape}"
        )
    bounds = list(zip(lo, hi))

    problem_info = {
        'name': _mat_string_field(prob, 'name'),
        'type': _mat_string_field(prob, 'kind'),
        'num_constraints': A_eq.shape[0],
        'num_variables': A_eq.shape[1],
        'title': _mat_string_field(prob, 'title'),
    }

    return c, A_eq, b_eq, bounds, problem_info


def solve_with_scipy(c, A_eq, b_eq, bounds, method='highs', verbose=False):
    """
    Solve ``min c^T x  s.t.  A_eq x = b_eq`` within ``bounds`` using linprog.

    Args:
        c: cost vector
        A_eq: equality constraint matrix (dense array or scipy.sparse matrix)
        b_eq: equality constraint vector
        bounds: sequence of (lo, hi) pairs, one per variable; ``None`` on
            either side means unbounded, as accepted by linprog
        method: linprog method name ('highs', 'highs-ds', 'highs-ipm')
        verbose: if True, enable the solver's display output and print
            status, timing, equality residual and bound-violation diagnostics

    Returns:
        result: OptimizeResult object returned by linprog
    """
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    result = linprog(
        c=c,
        A_eq=A_eq,
        b_eq=b_eq,
        bounds=bounds,
        method=method,
        options={'disp': bool(verbose)},
    )
    solve_time = time.perf_counter() - start_time

    if not verbose:
        return result

    print(f"\nSolver Status: {result.message}")
    print(f"Success: {result.success}")
    print(f"Optimal Value: {result.fun:.10f}" if result.success else "No solution found")
    print(f"Iterations: {result.nit if hasattr(result, 'nit') else 'N/A'}")
    print(f"Solve Time: {solve_time:.4f} seconds")

    if result.success:
        x = result.x
        print("\nSolution vector x (first 10 components):")
        print(x[:10])
        print(f"\nConstraint violation (||Ax - b||): {np.linalg.norm(A_eq @ x - b_eq):.2e}")

        # Map None (linprog's "no bound") to +/-inf so the comparison works on
        # a float array instead of failing on an object array holding None.
        lo = np.array([-np.inf if b[0] is None else b[0] for b in bounds], dtype=float)
        hi = np.array([np.inf if b[1] is None else b[1] for b in bounds], dtype=float)
        violations = np.count_nonzero(x < lo - 1e-6) + np.count_nonzero(x > hi + 1e-6)
        print(f"Bound violations: {violations}")

    return result


def solve_with_cvxpy(c, A_eq, b_eq, bounds, solver=None):
    """
    Solve ``min c^T x  s.t.  A_eq x = b_eq`` within ``bounds`` using CVXPY.

    Args:
        c: cost vector
        A_eq: equality constraint matrix (dense array or scipy.sparse matrix)
        b_eq: equality constraint vector
        bounds: sequence of (lo, hi) pairs, one per variable; ``None`` or an
            infinite value means unbounded on that side
        solver: CVXPY solver name. When None (default), ECOS is used if it is
            installed, as before; otherwise CVXPY picks its default solver.

    Returns:
        problem: CVXPY Problem object (None if CVXPY is not installed)
        x_opt: optimal solution, or None if no (possibly inaccurate) optimum
            was found
    """
    try:
        import cvxpy as cp
    except ImportError:
        print("\n" + "="*60)
        print("CVXPY not available. Install with: pip install cvxpy")
        print("="*60)
        return None, None

    print(f"\n{'='*60}")
    print("Solving with CVXPY")
    print(f"{'='*60}")

    # Recent CVXPY releases no longer install ECOS by default, so only request
    # it when it is actually available.
    if solver is None and cp.ECOS in cp.installed_solvers():
        solver = cp.ECOS

    n = len(c)
    x = cp.Variable(n)

    lo = np.array([-np.inf if b[0] is None else b[0] for b in bounds], dtype=float)
    hi = np.array([np.inf if b[1] is None else b[1] for b in bounds], dtype=float)

    constraints = [A_eq @ x == b_eq]
    # Only finite bounds become constraints, so no +/-inf constants end up in
    # the problem data handed to the solver.
    lo_idx = np.flatnonzero(np.isfinite(lo))
    hi_idx = np.flatnonzero(np.isfinite(hi))
    if lo_idx.size:
        constraints.append(x[lo_idx] >= lo[lo_idx])
    if hi_idx.size:
        constraints.append(x[hi_idx] <= hi[hi_idx])

    problem = cp.Problem(cp.Minimize(c @ x), constraints)

    start_time = time.perf_counter()
    problem.solve(solver=solver, verbose=True)
    solve_time = time.perf_counter() - start_time

    # OPTIMAL_INACCURATE still yields a usable primal point in x.value.
    solved = problem.status in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE)

    print(f"\nSolver Status: {problem.status}")
    print(f"Optimal Value: {problem.value:.10f}" if solved else "No solution found")
    print(f"Solve Time: {solve_time:.4f} seconds")

    if not solved:
        return problem, None

    x_opt = x.value
    print("\nSolution vector x (first 10 components):")
    print(x_opt[:10])
    print(f"\nConstraint violation (||Ax - b||): {np.linalg.norm(A_eq @ x_opt - b_eq):.2e}")

    return problem, x_opt


def main(mat_file=None):
    """
    Load the AFIRO LP, solve it with several solvers and print a summary.

    Args:
        mat_file: path to the MAT file. Defaults to ``datasets/lp_afiro.mat``
            next to this script.

    Raises:
        SystemExit: if the MAT file does not exist.
    """
    # NETLIB reference optimum for AFIRO, used to report each solver's deviation.
    expected_optimum = -464.7531428571

    if mat_file is None:
        mat_file = Path(__file__).parent / 'datasets' / 'lp_afiro.mat'
    mat_file = Path(mat_file)

    print("="*60)
    print("AFIRO LP Problem Solver")
    print("="*60)

    # Exit with a clear message (and non-zero status) rather than a loadmat traceback.
    if not mat_file.is_file():
        raise SystemExit(f"MAT file not found: {mat_file}")

    print("\nLoading problem from:", mat_file)
    c, A_eq, b_eq, bounds, info = load_lp_problem(mat_file)

    print("\nProblem Information:")
    print(f"  Name: {info['name']}")
    print(f"  Type: {info['type']}")
    print(f"  Title: {info['title']}")
    print(f"  Number of variables: {info['num_variables']}")
    print(f"  Number of constraints: {info['num_constraints']}")
    print(f"  Objective coefficients (first 10): {c[:10]}")

    # Solve with different scipy methods; one failing method must not stop the others.
    methods = ['highs', 'highs-ds', 'highs-ipm']
    results = {}
    for method in methods:
        try:
            results[method] = solve_with_scipy(c, A_eq, b_eq, bounds, method=method)
        except Exception as e:
            print(f"\nError with method '{method}': {e}")

    # CVXPY is optional; keep its outcome so it appears in the summary.
    cvxpy_problem, cvxpy_x = None, None
    try:
        cvxpy_problem, cvxpy_x = solve_with_cvxpy(c, A_eq, b_eq, bounds)
    except Exception as e:
        print(f"\nError with CVXPY: {e}")

    print("\n" + "="*60)
    print("SUMMARY")
    print("="*60)

    for method, result in results.items():
        if result.success:
            print(f"{method:20s}: Optimal value = {result.fun:.10f} "
                  f"(diff vs expected = {result.fun - expected_optimum:+.2e})")
        else:
            print(f"{method:20s}: Failed - {result.message}")

    # cvxpy_problem is None when CVXPY is not installed or raised an error.
    if cvxpy_problem is not None:
        if cvxpy_x is not None:
            print(f"{'cvxpy':20s}: Optimal value = {cvxpy_problem.value:.10f} "
                  f"(diff vs expected = {cvxpy_problem.value - expected_optimum:+.2e})")
        else:
            print(f"{'cvxpy':20s}: Failed - status {cvxpy_problem.status}")

    print("\nNote: The AFIRO problem is a well-known LP test problem from NETLIB.")
    print("Expected optimal value: -464.7531428571 (approximately)")


if __name__ == "__main__":
    # Run the solver comparison only when executed as a script, not on import.
    main()
