"""
Implementation of the minimal support computation algorithm for polynomials.
"""

import itertools
import logging
from collections import defaultdict
from typing import List, Union, Optional, Tuple, Set

import numpy as np
from scipy.spatial import ConvexHull

from sos_transformer.utils.polynomial import load_polynomials_and_bases_from_jsonl, get_newton_polytope_basis

def compute_minimal_support_phase1(exponent_vectors: Union[List[Tuple[int, ...]], np.ndarray]) -> Optional[np.ndarray]:
    """
    Phase 1: build the candidate set G0 via the Newton polytope method.

    The work is delegated to ``get_newton_polytope_basis``, which is expected
    to return the integer points of N(p)/2 for the given polynomial support.

    Args:
        exponent_vectors: Exponent vectors of the original polynomial's monomials.

    Returns:
        The candidate exponent vectors G0, or None when the helper returns
        None or an empty collection (treated as a failed computation).
    """
    basis = get_newton_polytope_basis(exponent_vectors)
    if basis is not None and len(basis) > 0:
        return basis
    # Both "no result" and "empty result" are reported to callers as failure.
    return None

def _phase2_keeps(alpha: Tuple, alive: Set[Tuple], fe_set: Set[Tuple]) -> bool:
    """
    Return True if ``alpha`` has an outgoing edge in the Phase 2 digraph.

    That is the case when 2*alpha is in ``fe_set``, or when
    2*alpha == beta + gamma for some beta, gamma in ``alive`` that both
    differ from alpha.
    """
    doubled = tuple(2 * a for a in alpha)
    if doubled in fe_set:
        return True
    for beta in alive:
        if beta == alpha:
            continue
        # The partner is forced: gamma = 2*alpha - beta. It cannot equal alpha
        # because beta != alpha, so a single set lookup replaces an inner loop.
        gamma = tuple(d - b for d, b in zip(doubled, beta))
        if gamma in alive:
            return True
    return False


def compute_minimal_support_phase2(G0: np.ndarray, Fe: np.ndarray) -> np.ndarray:
    """
    Phase 2 of the minimal support computation algorithm.

    Computes G* from G0 by repeatedly discarding every point alpha with no
    outgoing edge. Alpha has an outgoing edge when 2*alpha is an exponent
    vector of the original polynomial (in Fe), or when 2*alpha = beta + gamma
    for two points beta, gamma still in the set, both different from alpha.
    The loop stops at the first pass that removes nothing.

    Args:
        G0: Array of exponent vectors from Phase 1 (one vector per row).
        Fe: Array of exponent vectors from the original polynomial.

    Returns:
        Array of exponent vectors in G*. The rows follow G0's row order, with
        duplicate rows collapsed to their first occurrence, and G0's dtype is
        kept. If every point is pruned, the result has shape ``(0, d)``. An
        empty ``G0`` is returned unchanged.
    """
    if len(G0) == 0:
        return G0

    G0 = np.asarray(G0)
    # Plain-Python tuples are hashable, so membership tests below are O(1).
    points = [tuple(row) for row in G0.tolist()]
    fe_set = {tuple(f) for f in np.asarray(Fe).tolist()}

    alive = set(points)
    while True:
        # Removing a whole batch per pass is safe. Removing points only
        # destroys decompositions, so a point that fails now would still fail
        # later, and the fixed point is the same as one-at-a-time removal.
        dead = {alpha for alpha in alive if not _phase2_keeps(alpha, alive, fe_set)}
        if not dead:
            break
        alive -= dead

    # Rebuild from G0 rather than from the set. This makes the output order
    # deterministic and keeps the dtype and 2-D shape, even when empty.
    seen = set()
    keep = []
    for idx, point in enumerate(points):
        if point in alive and point not in seen:
            seen.add(point)
            keep.append(idx)
    return G0[np.array(keep, dtype=np.intp)]

def compute_minimal_support(exponent_vectors: Union[List[Tuple[int, ...]], np.ndarray]) -> Optional[np.ndarray]:
    """
    Compute the minimal support for a given polynomial using both phases of the algorithm.

    Phase 1 produces the candidate set G0. Phase 2 prunes G0 down to G*
    against the polynomial's own exponent vectors.

    Args:
        exponent_vectors: Exponent vectors of the original polynomial's monomials.

    Returns:
        Array of exponent vectors in the minimal support (G*). Returns None if
        Phase 1 fails. The array may be empty if Phase 2 prunes every candidate.
    """
    # Phase 1: Compute G0
    G0 = compute_minimal_support_phase1(exponent_vectors)
    if G0 is None:
        return None

    # Diagnostic only. Routed through logging instead of print so that
    # library callers' stdout stays clean unless DEBUG logging is enabled.
    logging.getLogger(__name__).debug("length of G0: %d", len(G0))

    # Phase 2: Compute G* from G0
    return compute_minimal_support_phase2(G0, np.array(exponent_vectors))
