import numpy as np
from scipy.spatial import ConvexHull
import scipy.spatial._qhull

def is_lattice_point(point):
    """
    Return whether every coordinate of ``point`` is an integer value.

    Each coordinate is reduced modulo 1; the point is a lattice point only if
    every remainder is exactly zero. Integral floats such as ``2.0`` therefore
    count as integers, while NaN or infinite coordinates do not (their
    remainder is NaN).

    Args:
        point (np.ndarray): Point coordinates (any array-like NumPy accepts)

    Returns:
        bool: True if all coordinates are integers
    """
    fractional_part = np.mod(point, 1)
    return np.all(fractional_part == 0)

def get_lattice_points_in_hull(vertices, max_coord=10, tol=1e-9):
    """
    Get all lattice points within the closed convex hull of a point set.

    Every integer point of the axis-aligned bounding box of ``vertices`` is
    tested against the facet inequalities of the hull. Points on the boundary,
    including the vertices themselves, are therefore included.

    Args:
        vertices (np.ndarray): Points whose convex hull is used, shape
            (n_points, n_dimensions). The hull must be full-dimensional.
        max_coord (int): Ignored. It is kept only so existing callers that
            pass it keep working; the search region is always the bounding box
            of ``vertices``.
        tol (float): Slack allowed in each facet inequality, to absorb
            floating-point round-off when testing boundary points.

    Returns:
        np.ndarray: Integer array of shape (n_found, n_dimensions) holding the
        lattice points inside or on the hull, in bounding-box scan order.

    Raises:
        QhullError: Propagated from ``ConvexHull`` when no hull can be built
            (e.g. degenerate, lower-dimensional input).
    """
    vertices = np.asarray(vertices)

    # Integer bounding box that encloses every vertex.
    min_coords = np.floor(np.min(vertices, axis=0)).astype(int)
    max_coords = np.ceil(np.max(vertices, axis=0)).astype(int)

    # Enumerate every lattice point in the bounding box.
    axes = [np.arange(lo, hi + 1) for lo, hi in zip(min_coords, max_coords)]
    grid = np.meshgrid(*axes)
    points = np.vstack([g.flatten() for g in grid]).T

    hull = ConvexHull(vertices)

    # Each row of ``hull.equations`` is [normal, offset] with an outward unit
    # normal, so x lies in the closed hull iff normal . x + offset <= 0 for
    # every facet. Testing all points at once replaces the old per-point hull
    # rebuild, which also misclassified outside points that swap out a vertex.
    normals = hull.equations[:, :-1]
    offsets = hull.equations[:, -1]
    signed_distances = points @ normals.T + offsets
    in_hull = np.all(signed_distances <= tol, axis=1)

    return points[in_hull]

def get_variance_ordering(points: np.ndarray) -> np.ndarray:
    """
    Compute permutation that orders variables by their variance (highest to lowest).

    Variables with equal variance keep their original relative order (lower
    index first), so the result is deterministic.

    Args:
        points (np.ndarray): Input points with shape (n_points, n_dimensions)

    Returns:
        np.ndarray: Permutation array where result[i] gives the original index of the i-th highest variance variable
    """
    # Compute variance along each dimension
    variances = np.var(points, axis=0)

    # Negating turns argsort's ascending order into descending order. A stable
    # sort is required so ties are broken by index; NumPy's default kind makes
    # no such guarantee.
    return np.argsort(-variances, kind='stable')

def analyze_lattice_polytope(points):
    """
    Analyze a set of lattice points and their convex hull (lattice polytope).

    Args:
        points (array-like): Input lattice points with shape
            (n_points, n_dimensions)

    Returns:
        dict: Dictionary with the keys
            - 'hull_vertices': input points that are vertices of the convex
              hull (all input points when no full-dimensional hull exists),
            - 'interior_points': lattice points returned by
              ``get_lattice_points_in_hull`` for the hull vertices. Despite the
              key name, this includes the hull's boundary. It is an empty
              (0, n_dimensions) integer array when there is no hull.
            - 'n_interior_points', 'n_vertices': lengths of the two arrays above,
            - 'means', 'variances': per-coordinate statistics over ALL points,
            - 'volume': hull volume (0.0 when no hull can be built),
            - 'variance_ordering': coordinate indices by decreasing variance,
            - 'is_degenerate': True when the points span fewer dimensions
              than n_dimensions,
            - 'effective_dimension': numerical rank of the centered points,
            - 'n_total_points': number of input points.

    Raises:
        ValueError: If ``points`` is not a non-empty 2-D array, or if any
            coordinate is not an integer.
    """
    points = np.array(points)
    # Reject empty / wrongly-shaped input up front instead of producing NaN
    # statistics followed by an opaque tuple-unpacking error.
    if points.ndim != 2 or len(points) == 0:
        raise ValueError(
            "Input must be a non-empty 2-D array of shape (n_points, n_dimensions)"
        )
    if not all(is_lattice_point(point) for point in points):
        raise ValueError("All input points must be lattice points (integer coordinates)")

    n_points, n_dims = points.shape

    # Statistics are computed over ALL input points, not only hull vertices.
    means = np.mean(points, axis=0)
    variances = np.var(points, axis=0)
    variance_ordering = get_variance_ordering(points)

    # "Nothing found" result with the same column count as real results.
    no_lattice_points = np.empty((0, n_dims), dtype=int)

    if n_points < 2:
        # Single point case
        return {
            'hull_vertices': points,
            'interior_points': no_lattice_points,
            'n_interior_points': 0,
            'n_vertices': 1,
            'means': means,
            'variances': variances,
            'volume': 0.0,
            'variance_ordering': variance_ordering,
            'is_degenerate': True,
            'effective_dimension': 0,
            'n_total_points': n_points
        }

    # Effective dimension = number of non-negligible singular values of the
    # centered point cloud.
    singular_values = np.linalg.svd(points - means, compute_uv=False)
    effective_dim = int(np.sum(singular_values > 1e-10))

    try:
        hull = ConvexHull(points)
    except scipy.spatial._qhull.QhullError:
        # Degenerate case: report every point as a vertex with zero volume.
        # Lattice enumeration is skipped because it would build the same
        # failing hull again.
        hull_vertices = points
        volume = 0.0
        interior_points = no_lattice_points
    else:
        hull_vertices = points[hull.vertices]
        volume = hull.volume
        try:
            interior_points = get_lattice_points_in_hull(hull_vertices)
        except scipy.spatial._qhull.QhullError:
            # Only Qhull failures are tolerated here; other errors propagate.
            interior_points = no_lattice_points

    return {
        'hull_vertices': hull_vertices,
        'interior_points': interior_points,
        'n_interior_points': len(interior_points),
        'n_vertices': len(hull_vertices),
        'means': means,
        'variances': variances,
        'volume': volume,
        'variance_ordering': variance_ordering,
        'is_degenerate': effective_dim < n_dims,
        'effective_dimension': effective_dim,
        'n_total_points': n_points
    }

def generate_random_lattice_points(n_points, dim, min_coord=-5, max_coord=5, rng=None):
    """
    Generate random lattice points in specified dimension and range.

    Args:
        n_points (int): Number of points to generate
        dim (int): Dimension of the space
        min_coord (int): Minimum coordinate value (inclusive)
        max_coord (int): Maximum coordinate value (inclusive)
        rng (np.random.Generator, optional): Random generator to draw from,
            for reproducible results independent of global NumPy state. When
            omitted, the legacy global ``np.random`` state is used as before.

    Returns:
        np.ndarray: Integer array of shape (n_points, dim) with every
        coordinate in [min_coord, max_coord]
    """
    if rng is None:
        # Legacy path: honours np.random.seed(), matching prior behavior.
        return np.random.randint(min_coord, max_coord + 1, size=(n_points, dim))
    return rng.integers(min_coord, max_coord, size=(n_points, dim), endpoint=True)

def analyze_and_print_results(points, dim_name=""):
    """
    Analyze a set of lattice points and print a human-readable report.

    A ``ValueError`` raised while converting or analyzing the input (for
    example, non-integer coordinates) is printed as an ``Error:`` line and is
    not propagated.

    Args:
        points (array-like): Input points with shape (n_points, n_dimensions);
            plain lists are accepted as well as NumPy arrays
        dim_name (str): Name/description of the dimension for printing
    """
    print(f"\n{'-'*20} {dim_name} {'-'*20}")
    try:
        # Convert here so the ``points.shape`` check below also works when the
        # caller passes a plain list (previously an AttributeError).
        points = np.asarray(points)
        results = analyze_lattice_polytope(points)

        print(f"Number of total points: {results['n_total_points']}")
        print(f"Number of vertices in convex hull: {results['n_vertices']}")
        print(f"Number of interior lattice points: {results['n_interior_points']}")
        print(f"Volume: {results['volume']:.4f}")
        print(f"Effective dimension: {results['effective_dimension']}")
        print(f"Is degenerate: {results['is_degenerate']}")

        print("\nCoordinate-wise Statistics (using ALL points):")
        for i, (mean, var) in enumerate(zip(results['means'], results['variances'])):
            print(f"Dimension {i}:")
            print(f"  Mean: {mean:.4f}")
            print(f"  Variance: {var:.4f}")

        print("\nVariable ordering (by variance):")
        print(f"  {results['variance_ordering']}")

        if points.shape[1] <= 3:  # Only print points for 2D and 3D examples
            print("\nAll points:")
            print(points)

    except ValueError as e:
        print(f"Error: {e}")


def get_exponents_from_M_tokens(tokens):
    """
    Convert a list of tokenized monomials in M_i_j_k_l format into their exponent vectors.

    Everything after the first ``_``-separated field of each token is parsed
    as an integer exponent. The prefix itself is not validated.

    Args:
        tokens (list): List of strings, each representing a monomial in M_i_j_k_l format
                      where i,j,k,l are the exponents for each variable
                      Example: ['M_0_1_2_0', 'M_1_0_0_1']

    Returns:
        np.ndarray: Array of exponent vectors, shape (n_monomials, n_variables);
            shape (0, 0) for an empty token list
        int: Number of variables detected (0 for an empty token list)

    Raises:
        ValueError: If a token has a different number of exponents than the
            first token, or an exponent is not an integer.
    """
    # An empty list has no first token to infer the variable count from.
    if len(tokens) == 0:
        return np.zeros((0, 0), dtype=int), 0

    # The first token fixes the number of variables for all others.
    n_vars = len(tokens[0].split('_')) - 1
    exponents = np.zeros((len(tokens), n_vars), dtype=int)

    for row, token in enumerate(tokens):
        # Skip the prefix field and keep only the exponents.
        fields = token.split('_')[1:]
        if len(fields) != n_vars:
            raise ValueError(
                f"Token {token!r} has {len(fields)} exponents; expected {n_vars} "
                f"(from the first token {tokens[0]!r})"
            )
        exponents[row] = [int(v) for v in fields]

    return exponents, n_vars

def get_M_token_variance_ordering(tokens):
    """
    Order the variables of M_i_j_k_l monomial tokens by exponent variance.

    The tokens are parsed into an exponent matrix (one row per monomial, one
    column per variable). The columns are then ranked by variance, highest first.

    Args:
        tokens (list): List of strings, each representing a monomial in M_i_j_k_l format
                      Example: ['M_0_1_2_0', 'M_1_0_0_1']

    Returns:
        dict: Dictionary containing:
            - 'permutation': Array of indices ordering variables by variance (highest to lowest)
            - 'variances': Array of variances for each variable
            - 'n_variables': Number of variables
            - 'exponents': The computed exponent vectors
    """
    exponent_matrix, variable_count = get_exponents_from_M_tokens(tokens)
    ranking = get_variance_ordering(exponent_matrix)
    per_variable_variance = np.var(exponent_matrix, axis=0)

    return dict(
        permutation=ranking,
        variances=per_variable_variance,
        n_variables=variable_count,
        exponents=exponent_matrix,
    )

def analyze_and_print_M_token_results(tokens, name=""):
    """
    Print a variance report for M_i_j_k_l format tokenized monomials.

    The report shows a header line, the monomial and variable counts, the
    per-variable variances, the variable ranking (highest variance first), and
    the raw exponent matrix.

    Args:
        tokens (list): List of tokenized monomials in M_i_j_k_l format
        name (str): Name/description for printing
    """
    banner = '-' * 20
    print(f"\n{banner} {name} {banner}")
    summary = get_M_token_variance_ordering(tokens)

    print(f"Number of monomials: {len(tokens)}")
    print(f"Number of variables: {summary['n_variables']}")

    print("\nVariances for each variable:")
    for idx, variance in enumerate(summary['variances']):
        print(f"  var_{idx}: {variance:.4f}")

    print("\nVariable ordering (by variance, highest to lowest):")
    labels = [f"var_{idx}" for idx in summary['permutation']]
    print(f"  {labels}")

    print("\nExponent vectors:")
    print(summary['exponents'])
