import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import linear_sum_assignment

def compute_diagonal_stats(matrix):
    """
    Compute average diagonal and off-diagonal values of a matrix.

    The diagonal is the set of entries ``matrix[i, i]`` for ``i < min(m, n)``,
    so rectangular inputs are accepted as well as square ones.

    Args:
        matrix: 2-D array-like of shape (m, n) (NumPy array, JAX array, or
            nested sequence).

    Returns:
        Tuple of (avg_diagonal, avg_off_diagonal). ``avg_off_diagonal`` is 0.0
        when the matrix has no off-diagonal entries (e.g. a 1x1 matrix) rather
        than the NaN produced by averaging an empty selection.

    Raises:
        ValueError: If ``matrix`` is not 2-D.
    """
    matrix = np.asarray(matrix)
    if matrix.ndim != 2:
        raise ValueError(f"expected a 2-D matrix, got shape {matrix.shape}")

    # np.eye(m, n) marks (i, i) for i < min(m, n); np.eye(m) alone would not
    # match the shape of a rectangular matrix.
    diagonal_mask = np.eye(*matrix.shape, dtype=bool)
    diagonal_values = matrix[diagonal_mask]
    off_diagonal_values = matrix[~diagonal_mask]

    avg_diagonal = np.mean(diagonal_values)
    avg_off_diagonal = np.mean(off_diagonal_values) if off_diagonal_values.size > 0 else 0.0

    return avg_diagonal, avg_off_diagonal

def find_best_permutation(matrix, axis=0):
    """
    Find the assignment that maximizes the average diagonal value of a matrix.

    Solves the linear sum assignment problem with
    ``scipy.optimize.linear_sum_assignment``. It then reorders the matrix so
    that the matched entries lie on the main diagonal.

    Args:
        matrix: 2-D array-like of shape (m, n); m and n may differ. NumPy
            arrays, JAX arrays and nested sequences are accepted.
        axis: 0 solves on ``matrix`` as given: its columns are reordered and
            the returned indices refer to columns. 1 solves on the transpose:
            rows are reordered and the returned indices refer to rows. Any
            other value behaves like 0.

    Returns:
        Tuple of (permuted_matrix, permutation_indices, avg_diagonal,
        avg_off_diagonal):

        * permuted_matrix: square (k, k) array with k = min(m, n), in the same
          orientation as ``matrix``. For a non-square input only the k matched
          rows/columns are kept.
        * permutation_indices: ``col_ind`` from the solver on the (possibly
          transposed) matrix; entry i is the input column (axis=0) or input
          row (axis=1) placed at diagonal position i. When the solved matrix
          has more rows than columns, the chosen row subset is not returned.
        * avg_diagonal: mean of the diagonal of permuted_matrix.
        * avg_off_diagonal: mean of its off-diagonal entries, or 0.0 when
          there are none (k <= 1).

    Raises:
        ValueError: If ``matrix`` is not 2-D.
    """
    matrix = np.asarray(matrix)
    if matrix.ndim != 2:
        raise ValueError(f"expected a 2-D matrix, got shape {matrix.shape}")

    # For axis == 1, solve on the transpose so the input's rows get reordered.
    work = matrix.T if axis == 1 else matrix

    # Solve on a float64 copy with maximize=True instead of negating the input.
    # Negation wraps around for unsigned integer dtypes (giving a wrong
    # assignment) and raises TypeError for boolean arrays.
    row_ind, col_ind = linear_sum_assignment(work.astype(np.float64), maximize=True)

    # Place each matched pair (row_ind[i], col_ind[i]) at diagonal position i.
    # The result is (k, k) with k = min(m, n).
    permuted = work[row_ind][:, col_ind]

    avg_diagonal = np.mean(np.diag(permuted))

    k = permuted.shape[0]
    off_diagonal_values = permuted[~np.eye(k, dtype=bool)]
    avg_off_diagonal = np.mean(off_diagonal_values) if off_diagonal_values.size > 0 else 0.0

    # Hand the matrix back in the caller's orientation. The statistics are
    # unaffected, since transposing preserves the diagonal and off-diagonal sets.
    permuted_matrix = permuted.T if axis == 1 else permuted
    return permuted_matrix, col_ind, avg_diagonal, avg_off_diagonal

def find_best_permutation_greedy(matrix, axis=0):
    """
    Find a diagonal-maximizing assignment with a greedy row-by-row scan.

    A simple alternative to the Hungarian algorithm that is not guaranteed to
    find the global optimum. Rows are processed in order, and each one takes
    the column with the largest value among those not yet used. Ties go to the
    lowest column index.

    Args:
        matrix: 2-D array-like of shape (m, n); m and n may differ. NumPy
            arrays, JAX arrays and nested sequences are accepted.
        axis: 0 scans the rows of ``matrix``: its columns are reordered and
            the returned indices refer to columns. 1 scans the rows of its
            transpose: rows are reordered and the indices refer to rows. Any
            other value behaves like 0.

    Returns:
        Tuple of (permuted_matrix, permutation_indices, avg_diagonal,
        avg_off_diagonal).

        * permuted_matrix: (k, k) with k = min(m, n), in the same orientation
          as ``matrix``. When the scanned matrix has more rows than columns,
          only its first k rows take part.
        * permutation_indices: entry i is the index chosen for diagonal
          position i.
        * avg_diagonal, avg_off_diagonal: computed by
          ``compute_diagonal_stats``.

    Raises:
        ValueError: If ``matrix`` is not 2-D.
    """
    matrix = np.asarray(matrix)
    if matrix.ndim != 2:
        raise ValueError(f"expected a 2-D matrix, got shape {matrix.shape}")

    # For axis == 1, scan the transpose so the input's rows get reordered.
    work = matrix.T if axis == 1 else matrix
    n_rows, n_cols = work.shape
    # Every column must be a candidate, not just the first n_rows of them,
    # and a tall matrix can only fill min(n_rows, n_cols) diagonal slots.
    k = min(n_rows, n_cols)

    # Free columns, kept in ascending order so that max() returns the lowest
    # index among equal values.
    remaining_cols = list(range(n_cols))
    permutation = np.zeros(k, dtype=int)

    for i in range(k):
        row = work[i]
        best_col = max(remaining_cols, key=row.__getitem__)
        permutation[i] = best_col
        remaining_cols.remove(best_col)

    permuted = work[:k][:, permutation]
    avg_diagonal, avg_off_diagonal = compute_diagonal_stats(permuted)

    # Hand the matrix back in the caller's orientation.
    permuted_matrix = permuted.T if axis == 1 else permuted
    return permuted_matrix, permutation, avg_diagonal, avg_off_diagonal

def test_permutation_optimizer():
    """
    Exercise both optimizers on 2x2, 3x3, 4x4, 9x9 and 3x4 matrices.

    Results are printed, and the values fully determined by the inputs are
    asserted, so the closing "All tests passed!" line is backed by real checks.
    Where the optimum is not unique (2x2, 3x3, 3x4), only the averages are
    checked, not the chosen indices.
    """
    
    # Test 2x2 matrix
    matrix_2x2 = np.array([[0.25, 0.5],
                          [0.75, 1.0]])
    
    # Test optimal permutation
    opt_perm_2x2, opt_ind_2x2, opt_diag_2x2, opt_off_diag_2x2 = find_best_permutation(matrix_2x2)
    print(f'opt_perm_2x2: {opt_perm_2x2}')
    print(f'opt_ind_2x2: {opt_ind_2x2}')
    print(f'opt_avg_diagonal: {opt_diag_2x2}')
    print(f'opt_avg_off_diagonal: {opt_off_diag_2x2}')
    # Both assignments of this matrix have diagonal sum 1.25.
    assert np.isclose(opt_diag_2x2, 0.625)
    assert np.isclose(opt_off_diag_2x2, 0.625)
    
    # Test greedy permutation
    greedy_perm_2x2, greedy_ind_2x2, greedy_diag_2x2, greedy_off_diag_2x2 = find_best_permutation_greedy(matrix_2x2)
    print(f'greedy_avg_diagonal: {greedy_diag_2x2}')
    print(f'greedy_avg_off_diagonal: {greedy_off_diag_2x2}')
    assert np.isclose(greedy_diag_2x2, 0.625)
    assert np.isclose(greedy_off_diag_2x2, 0.625)
    
    # Test 3x3 matrix
    matrix_3x3 = np.array([[0.11, 0.22, 0.33],
                          [0.44, 0.55, 0.66],
                          [0.77, 0.88, 0.99]])
    
    # Entry (i, j) is 0.11 * (3i + j + 1), so every permutation has the same
    # diagonal sum (1.65) and the same off-diagonal sum (3.3).
    # Test optimal permutation
    opt_perm_3x3, opt_ind_3x3, opt_diag_3x3, opt_off_diag_3x3 = find_best_permutation(matrix_3x3)
    print(f'opt_avg_diagonal_3x3: {opt_diag_3x3}')
    print(f'opt_avg_off_diagonal_3x3: {opt_off_diag_3x3}')
    assert np.isclose(opt_diag_3x3, 0.55)
    assert np.isclose(opt_off_diag_3x3, 0.55)
    
    # Test greedy permutation
    greedy_perm_3x3, greedy_ind_3x3, greedy_diag_3x3, greedy_off_diag_3x3 = find_best_permutation_greedy(matrix_3x3)
    print(f'greedy_avg_diagonal_3x3: {greedy_diag_3x3}')
    print(f'greedy_avg_off_diagonal_3x3: {greedy_off_diag_3x3}')
    assert np.isclose(greedy_diag_3x3, 0.55)
    assert np.isclose(greedy_off_diag_3x3, 0.55)
    
    # Test with JAX array
    jax_matrix_2x2 = jnp.array([[0.25, 0.5],
                               [0.75, 1.0]])
    opt_perm_jax_2x2, opt_ind_jax_2x2, opt_diag_jax_2x2, opt_off_diag_jax_2x2 = find_best_permutation(jax_matrix_2x2)
    print(f'opt_avg_diagonal_jax: {opt_diag_jax_2x2}')
    print(f'opt_avg_off_diagonal_jax: {opt_off_diag_jax_2x2}')
    assert np.isclose(opt_diag_jax_2x2, opt_diag_2x2)
    assert np.isclose(opt_off_diag_jax_2x2, opt_off_diag_2x2)
    
    # Test 9x9 matrix with a clear pattern
    matrix_9x9 = np.array([
        [0.11, 0.22, 0.33, 0.44, 0.55, 0.66, 0.77, 0.88, 0.99],
        [0.22, 0.33, 0.44, 0.55, 0.66, 0.77, 0.88, 0.99, 0.11],
        [0.33, 0.44, 0.55, 0.66, 0.77, 0.88, 0.99, 0.11, 0.22],
        [0.44, 0.55, 0.66, 0.77, 0.88, 0.99, 0.11, 0.22, 0.33],
        [0.55, 0.66, 0.77, 0.88, 0.99, 0.11, 0.22, 0.33, 0.44],
        [0.66, 0.77, 0.88, 0.99, 0.11, 0.22, 0.33, 0.44, 0.55],
        [0.77, 0.88, 0.99, 0.11, 0.22, 0.33, 0.44, 0.55, 0.66],
        [0.88, 0.99, 0.11, 0.22, 0.33, 0.44, 0.55, 0.66, 0.77],
        [0.99, 0.11, 0.22, 0.33, 0.44, 0.55, 0.66, 0.77, 0.88]
    ])
    # Entry (i, j) is 0.11 * ((i + j) % 9 + 1). The unique optimum maps row i
    # to column 8 - i, putting 0.99 on every diagonal slot.
    expected_ind_9x9 = list(range(8, -1, -1))
    
    print("\nTesting 9x9 matrix:")
    print("Original matrix diagonal stats:")
    orig_diag_9x9, orig_off_diag_9x9 = compute_diagonal_stats(matrix_9x9)
    print(f'Original avg_diagonal: {orig_diag_9x9}')
    print(f'Original avg_off_diagonal: {orig_off_diag_9x9}')
    assert np.isclose(orig_diag_9x9, 0.55)
    assert np.isclose(orig_off_diag_9x9, 0.55)
    
    # Test optimal permutation
    opt_perm_9x9, opt_ind_9x9, opt_diag_9x9, opt_off_diag_9x9 = find_best_permutation(matrix_9x9)
    print(f'Optimal permutation avg_diagonal: {opt_diag_9x9}')
    print(f'Optimal permutation avg_off_diagonal: {opt_off_diag_9x9}')
    print(f'Optimal permutation indices: {opt_ind_9x9}')
    assert list(opt_ind_9x9) == expected_ind_9x9
    assert np.isclose(opt_diag_9x9, 0.99)
    assert np.isclose(opt_off_diag_9x9, 0.495)
    
    # Test greedy permutation
    greedy_perm_9x9, greedy_ind_9x9, greedy_diag_9x9, greedy_off_diag_9x9 = find_best_permutation_greedy(matrix_9x9)
    print(f'Greedy permutation avg_diagonal: {greedy_diag_9x9}')
    print(f'Greedy permutation avg_off_diagonal: {greedy_off_diag_9x9}')
    print(f'Greedy permutation indices: {greedy_ind_9x9}')
    assert list(greedy_ind_9x9) == expected_ind_9x9
    assert np.isclose(greedy_diag_9x9, 0.99)
    assert np.isclose(greedy_off_diag_9x9, 0.495)
    
    # Test matrix with diagonal close to 1 and off-diagonal close to 0
    print("\nTesting matrix with diagonal close to 1 and off-diagonal close to 0:")
    matrix_4x4 = np.array([
        [0.01, 0.02, 0.95, 0.01],
        [0.92, 0.01, 0.01, 0.03],
        [0.01, 0.94, 0.02, 0.01],
        [0.03, 0.01, 0.01, 0.96]
    ])
    # The large entries sit at columns [2, 0, 1, 3]; their sum is 3.77, and
    # the whole matrix sums to 3.95.
    expected_ind_4x4 = [2, 0, 1, 3]
    
    print("Original matrix:")
    print(matrix_4x4)
    print("Original matrix diagonal stats:")
    orig_diag_4x4, orig_off_diag_4x4 = compute_diagonal_stats(matrix_4x4)
    print(f'Original avg_diagonal: {orig_diag_4x4}')
    print(f'Original avg_off_diagonal: {orig_off_diag_4x4}')
    assert np.isclose(orig_diag_4x4, 0.25)
    assert np.isclose(orig_off_diag_4x4, 2.95 / 12)
    
    # Test optimal permutation
    opt_perm_4x4, opt_ind_4x4, opt_diag_4x4, opt_off_diag_4x4 = find_best_permutation(matrix_4x4)
    print("\nOptimal permutation:")
    print(opt_perm_4x4)
    print(f'Optimal permutation indices: {opt_ind_4x4}')
    print(f'Optimal permutation avg_diagonal: {opt_diag_4x4}')
    print(f'Optimal permutation avg_off_diagonal: {opt_off_diag_4x4}')
    assert list(opt_ind_4x4) == expected_ind_4x4
    assert np.isclose(opt_diag_4x4, 0.9425)
    assert np.isclose(opt_off_diag_4x4, 0.015)
    
    # Test greedy permutation
    greedy_perm_4x4, greedy_ind_4x4, greedy_diag_4x4, greedy_off_diag_4x4 = find_best_permutation_greedy(matrix_4x4)
    print("\nGreedy permutation:")
    print(greedy_perm_4x4)
    print(f'Greedy permutation indices: {greedy_ind_4x4}')
    print(f'Greedy permutation avg_diagonal: {greedy_diag_4x4}')
    print(f'Greedy permutation avg_off_diagonal: {greedy_off_diag_4x4}')
    assert list(greedy_ind_4x4) == expected_ind_4x4
    assert np.isclose(greedy_diag_4x4, 0.9425)
    assert np.isclose(greedy_off_diag_4x4, 0.015)
    
    # Test rectangular matrix (3x4)
    print("\nTesting 3x4 rectangular matrix:")
    matrix_3x4 = np.array([
        [0.1, 0.2, 0.3, 0.4],
        [0.5, 0.6, 0.7, 0.8],
        [0.9, 0.2, 0.3, 0.4]
    ])
    # The best total over any three distinct columns is 2.0 (row 2 takes 0.9
    # and rows 0-1 share 1.1). Two assignments tie, so only averages are
    # checked.
    expected_opt_diag_3x4 = 2.0 / 3
    
    print("Original matrix:")
    print(matrix_3x4)
    
    # Test optimal permutation along rows (axis=0)
    opt_perm_3x4, opt_ind_3x4, opt_diag_3x4, opt_off_diag_3x4 = find_best_permutation(matrix_3x4, axis=0)
    print("\nOptimal permutation (axis=0):")
    print(opt_perm_3x4)
    print(f'Optimal permutation indices: {opt_ind_3x4}')
    print(f'Optimal permutation avg_diagonal: {opt_diag_3x4}')
    print(f'Optimal permutation avg_off_diagonal: {opt_off_diag_3x4}')
    assert np.isclose(opt_diag_3x4, expected_opt_diag_3x4)
    
    # Test optimal permutation along columns (axis=1)
    opt_perm_3x4_col, opt_ind_3x4_col, opt_diag_3x4_col, opt_off_diag_3x4_col = find_best_permutation(matrix_3x4, axis=1)
    print("\nOptimal permutation (axis=1):")
    print(opt_perm_3x4_col)
    print(f'Optimal permutation indices: {opt_ind_3x4_col}')
    print(f'Optimal permutation avg_diagonal: {opt_diag_3x4_col}')
    print(f'Optimal permutation avg_off_diagonal: {opt_off_diag_3x4_col}')
    assert np.isclose(opt_diag_3x4_col, expected_opt_diag_3x4)
    
    # Test greedy permutation along rows (axis=0)
    greedy_perm_3x4, greedy_ind_3x4, greedy_diag_3x4, greedy_off_diag_3x4 = find_best_permutation_greedy(matrix_3x4, axis=0)
    print("\nGreedy permutation (axis=0):")
    print(greedy_perm_3x4)
    print(f'Greedy permutation indices: {greedy_ind_3x4}')
    print(f'Greedy permutation avg_diagonal: {greedy_diag_3x4}')
    print(f'Greedy permutation avg_off_diagonal: {greedy_off_diag_3x4}')
    # Greedy yields some valid assignment, so it can never beat the optimum.
    assert greedy_diag_3x4 <= opt_diag_3x4 + 1e-12
    
    print("\nAll tests passed!")

if __name__ == "__main__":
    # Run the printed demo/self-checks only when executed as a script,
    # not when this module is imported.
    test_permutation_optimizer()
