import numpy as np
import random
from typing import List, Dict, Tuple
from collections import Counter, defaultdict
from data_generation.monomials.monomials import Monomial, Polynomial
from utils.polynomial import load_polynomials_and_bases_from_jsonl, remove_k_random_elements

def basis_extension(initial_basis: List[Monomial], poly: Polynomial, max_iter: int = 10) -> List[Monomial]:
    """
    Greedily extend a monomial basis until its pairwise products cover a polynomial's support.

    Each round forms the set of products b1 * b2 over unordered pairs of basis
    elements, determines which support monomials are not among those products,
    and lets every (missing term, basis element) pair in which the basis
    element divides the term componentwise "vote" for the quotient. The
    quotient with the most votes is added to the basis and the next round
    starts.

    Args:
        initial_basis: List of Monomial objects (the initial basis B0).
        poly: Polynomial object whose support (the keys of ``poly.terms``) should be covered.
        max_iter: Maximum number of extension rounds; at most one monomial is added per round.

    Returns:
        List of Monomial objects forming the extended basis. The list is built
        from a set, so its order is unspecified.
    """
    basis = set(initial_basis)
    support = set(poly.terms.keys())

    # The product set B * B is maintained incrementally: only elements that
    # have not been multiplied in yet (`unpaired`) are combined with the
    # already processed ones (`paired`) instead of recomputing every pair on
    # every round. Only unordered pairs are multiplied, which presumes that
    # Monomial multiplication is commutative.
    product_monomials = set()
    paired = []
    unpaired = list(basis)

    for _ in range(max_iter):
        # 1. Bring the set of pairwise products up to date (includes squares).
        for newcomer in unpaired:
            paired.append(newcomer)
            product_monomials.update(earlier * newcomer for earlier in paired)
        unpaired = []

        # 2. Find the support terms that no product reaches yet.
        missing = support - product_monomials
        if not missing:
            break  # All support covered

        # 3. Every basis element that divides a missing term votes for the quotient.
        divisor_counter = Counter()
        for m in missing:
            for b in basis:
                # b divides m iff no exponent of b exceeds the matching exponent of m.
                if all(e1 >= e2 for e1, e2 in zip(m.exponents, b.exponents)):
                    quotient = Monomial(tuple(e1 - e2 for e1, e2 in zip(m.exponents, b.exponents)))
                    divisor_counter[quotient] += 1

        if not divisor_counter:
            break  # No basis element divides any missing term; cannot extend further

        # 4. Add the quotient with the most votes (ties: first one counted).
        most_common_divisor, _ = divisor_counter.most_common(1)[0]
        if most_common_divisor in basis:
            # Defensive stop: presumably unreachable when `*` adds exponents,
            # because then the corresponding term would already be a product.
            break
        basis.add(most_common_divisor)
        unpaired.append(most_common_divisor)

    return list(basis)


def basis_extension_comprehensive(initial_basis: List[Monomial], poly: Polynomial, min_score: int = 7, max_iter: int = 10) -> List[Monomial]:
    """
    Extend a monomial basis in a single pass using score-based selection of divisors.

    The score of a monomial t is the number of terms s in the polynomial's
    support such that t divides s componentwise and the quotient s / t is an
    element of the initial basis. Every monomial whose score reaches
    ``min_score`` is added to the basis.

    Rather than enumerating every componentwise divisor of each support term,
    the scores are obtained by walking the basis: each basis element r that
    divides a support term s contributes one point to the divisor s / r. This
    yields the same scores because, for a fixed s, divisors t and remainders r
    determine each other via t = s / r (assuming Monomial equality is decided
    by the exponent vector).

    Args:
        initial_basis: List of Monomial objects (the initial basis B0).
        poly: Polynomial object whose support (the keys of ``poly.terms``) is scored against.
        min_score: Minimum score a divisor needs in order to be added to the basis.
        max_iter: Not used by this function; the extension is done in one pass.

    Returns:
        List of Monomial objects forming the extended basis (initial_basis plus
        all divisors meeting the threshold). The list is built from a set, so
        its order is unspecified.
    """
    basis = set(initial_basis)
    support = set(poly.terms.keys())

    print("Starting basis extension")

    # divisor -> number of support terms it divides with a remainder in the basis
    divisor_scores = defaultdict(int)

    for support_term in support:
        term_exp = support_term.exponents
        for remainder in basis:
            rem_exp = remainder.exponents
            # zip() would silently truncate; a remainder with a different
            # number of exponents cannot be the quotient of this term.
            if len(rem_exp) != len(term_exp):
                continue
            # remainder divides support_term iff none of its exponents is larger.
            if all(t >= r for t, r in zip(term_exp, rem_exp)):
                divisor = Monomial(tuple(t - r for t, r in zip(term_exp, rem_exp)))
                divisor_scores[divisor] += 1

    # Keep the initial basis and add every divisor that meets the threshold.
    extended_basis = set(initial_basis)
    extended_basis.update(
        divisor for divisor, score in divisor_scores.items() if score >= min_score
    )

    print("Basis extension complete", len(extended_basis))

    return list(extended_basis)