# Standard library imports
from functools import cache
from typing import List, Tuple, Optional
from pprint import pprint
from bisect import bisect_left, insort
from time import time
from collections import defaultdict
from itertools import product

# Third-party imports
from mpmath.libmp.libelefun import exponential_series
from sage.all import *
from tqdm import tqdm
from sortedcontainers import SortedList

# Local imports
from .border_basis import BorderBasisCalculator
from .sanity_checks import OBorderBasisChecker
from .utils import (
    plot_multiple_monomials,
    find_maximal_terms,
    collect_all_indices,
    compute_index_and_expansion_direction,
    expand_within_universe
)
from ..oracle import (
    Oracle,
    HardPartialExtensionOracle,
    HistoricalOracle,
    TransformerOracle
)


class ImprovedBorderBasisCalculator(BorderBasisCalculator):
    """
    Implementation of the Improved Border Basis algorithm from the paper
    "Computing Border Bases" by Kehrein & Kreuzer (2005).

    This class extends BorderBasisCalculator to implement an optimized version
    of the border basis computation algorithm with oracle-based predictions.

    Several private helpers used below (``_compute_lstable_span``,
    ``_check_universe_sufficiency``, ``_initialize_timings``, ...) are not
    defined in this class; presumably they are provided by
    ``BorderBasisCalculator`` — verify there.

    Attributes:
        datasets (List[List[Tuple]]): Training data for oracle prediction.
            Each inner list contains tuples of (L, V, successful_expansion_directions).
        oracle (TransformerOracle): Oracle for predicting successful expansion directions.
        timings (dict): Dictionary tracking execution times for various steps.
    """

    def __init__(
        self,
        ring,
        corollary_23: bool = False,
        N: int = 1,
        save_universes: bool = False,
        save_expansion_directions: bool = True,
        L: Optional[List] = None,
        verbose: bool = False,
        sorted_V: bool = False,
        relative_gap: float = 0.5,
        absolute_gap: int = 100,
        order_ideal_size: int = 5,
        oracle_max_calls: int = 5,
        min_universe: int = 20,
        save_path: Optional[str] = None,
        leading_term_k: int = 5,
        oracle: bool = False,
        oracle_model: Optional[TransformerOracle] = None
    ):
        """
        Initialize the ImprovedBorderBasisCalculator.

        Args:
            ring: Polynomial ring for computations
            corollary_23: Whether to use Corollary 23 optimization
            N: Frequency of full border extension
            save_universes: Whether to save universe evolution
            save_expansion_directions: Whether to save expansion directions
            L: Initial order ideal
            verbose: Whether to print detailed progress
            sorted_V: Whether to use a SortedList for V
            relative_gap: Minimum ratio len(V)/len(L) required before the oracle is used
            absolute_gap: The oracle is used only while len(L) - len(V) is below this value
            order_ideal_size: Target size for order ideal
            oracle_max_calls: Maximum number of oracle calls
            min_universe: The oracle is used only when len(L) exceeds this value
            save_path: Path to save oracle model
            leading_term_k: Number of leading terms to consider
            oracle: Whether to use oracle
            oracle_model: Pre-trained oracle model; if None, a TransformerOracle is built
        """
        super().__init__(ring)

        # Core parameters.
        # NOTE: the attribute name is misspelled ("corollarly"); it is kept as-is
        # because other code may read it under this name.
        self.corollarly_23 = corollary_23
        self.N = N
        self.L = L
        self.verbose = verbose
        self.sorted_V = sorted_V

        # Oracle configuration (thresholds checked in _compute_border_terms)
        self.relative_gap = relative_gap
        self.absolute_gap = absolute_gap
        self.order_ideal_size = order_ideal_size
        self.oracle_max_calls = oracle_max_calls
        self.min_universe = min_universe

        # Initialize oracle. Note that it is constructed even when oracle=False;
        # usage is gated separately via self.no_oracle.
        start_time = time()
        if oracle_model is None:
            self.oracle = TransformerOracle(ring, save_path, leading_term_k=leading_term_k)
        else:
            self.oracle = oracle_model
        if verbose:
            print(f"Time taken to load oracle: {time() - start_time:.2f}s")

        # State tracking
        self.logging_enabled = False
        self.count = 1
        self.no_oracle = not oracle
        self.oracle_calls = 0
        self.total_reduction_steps = 0

        # Data collection
        self.save_universes = save_universes
        self.save_expansion_directions = save_expansion_directions
        self.universes = []
        self.leading_terms = []
        self.border_terms_add = []
        self.successful_expansion_directions = []
        self.surviving_indices = []
        self.datasets = []
        self.efficiency = []
        self.zero_reduction_steps = []

    def enable_logging(self, enabled: bool):
        """Enable or disable logging for debugging and performance tracking."""
        self.logging_enabled = enabled

    def log(self, message: str):
        """Print ``message`` with a ``[LOG]:`` prefix if logging is enabled."""
        if self.logging_enabled:
            print(f"[LOG]: {message}")

    def border(self, V: List, origin: bool = False) -> List:
        """
        Extend every element of V by every ring variable, optionally noting the origin.

        Each product ``t * var`` is kept only if it is not already in V and has
        not been produced before; the first origin that produces a term wins.

        Example (variables x, y): V = [x**2 + y, y] with origin=True gives
        [(x**2 + y, x**3 + x*y), (x**2 + y, x**2*y + y**2), (y, x*y), (y, y**2)].

        For an order ideal of monomials (origin=False) this is the border of the
        order ideal.

        Args:
            V: List of polynomials (or monomials) in the polynomial ring.
            origin: If True, return (origin, new_term) pairs instead of bare terms.
        Returns:
            List: The border terms of V (or (origin, term) pairs), in discovery order.
        """
        members = set(V)  # O(1) "already in V" checks instead of list scans
        seen = set()      # dedup on the new term itself, independent of `origin`
        border = []
        for t in V:
            for var in self.variables:
                new_term = t * var
                if new_term in members or new_term in seen:
                    continue
                seen.add(new_term)
                border.append((t, new_term) if origin else new_term)
        return border

    def compute_order_ideal_monomials(self, monomials):
        """
        Compute the order ideal spanned by a given set of monomials.

        The order ideal consists of all divisors of the given monomials.
        The calculator's ``self.ring`` is not modified; the ring is taken from
        the first monomial.

        Args:
            monomials (list): A list of monomials in a polynomial ring.

        Returns:
            list: All monomials of the order ideal, sorted by (degree, term order).
        """
        if not monomials:
            return []

        # Use the monomials' own ring locally rather than rebinding self.ring.
        ring = monomials[0].parent()

        # Collect the exponent vectors of every divisor: each coordinate ranges
        # over 0..e for the corresponding exponent e.
        order_ideal_exponents = set()
        for monomial in monomials:
            exponents = monomial.exponents()[0]
            order_ideal_exponents.update(product(*(range(e + 1) for e in exponents)))

        order_ideal_monomials = [ring.monomial(*exponents) for exponents in order_ideal_exponents]

        # Sort by degree, breaking ties with the ring's term order.
        return sorted(order_ideal_monomials, key=lambda t: (t.degree(), t))

    def compute_lstable_span_optimized(
        self,
        F: List,
        L: List,
        use_fast_elimination: bool = False,
        hints: Optional[List] = None
    ) -> Tuple[List, List]:
        """
        Optimized version of compute_lstable_span with better handling for large computational universes.

        Repeatedly extends V by border terms, reduces them via Gaussian
        elimination and adds the surviving polynomials, until an iteration
        yields nothing new.

        Args:
            F: List of generating polynomials
            L: Current computational universe
            use_fast_elimination: Whether to use fast Gaussian elimination
            hints: Currently unused; accepted for interface compatibility

        Returns:
            Tuple containing:
            - List of polynomials in L-stable span (a SortedList if sorted_V)
            - Updated computational universe
        """
        self.log("Starting optimized L-stable span computation.")

        # Initial Gaussian elimination
        V, _, _ = super().gaussian_elimination([], F, use_fast_elimination=use_fast_elimination)

        if self.sorted_V:
            V = SortedList(V)

        zero_reductions = 0

        while True:
            # Compute border terms (possibly restricted by the oracle)
            border_terms = self._compute_border_terms(V, L)
            self.total_reduction_steps += len(border_terms)

            # Perform Gaussian elimination
            W, non_zero_reductions_indices, reduction_indices = self._perform_gaussian_elimination(
                V, border_terms, use_fast_elimination
            )

            # Process results
            W_prime, L = self._process_elimination_results(W, L)

            # Update tracking variables
            self._update_tracking_variables(
                W_prime, non_zero_reductions_indices,
                reduction_indices, border_terms, V
            )

            # Border terms whose reduction did not contribute a new polynomial
            zero_reductions += len(border_terms) - len(non_zero_reductions_indices)

            if not W_prime:
                break

            if self.sorted_V:
                V.update(W_prime)
            else:
                V.extend(W_prime)

        self.zero_reduction_steps.append(zero_reductions)
        return V, L

    def compute_border_basis_optimized(
        self,
        F: List,
        use_fast_elimination: bool = False,
        lstabilization_only: bool = False
    ) -> Tuple[List, List, dict]:
        """
        Implementation of Improved Border Basis algorithm.

        Alternates L-stable span computation and universe extension until the
        universe is sufficient. If it is sufficient but ``_should_stop_oracle``
        reports True, the oracle is disabled and the loop runs again without it.

        Args:
            F: List of generating polynomials
            use_fast_elimination: Whether to use fast Gaussian elimination
            lstabilization_only: Whether to only compute L-stable span

        Returns:
            Tuple containing:
            - Border basis G
            - Order ideal O
            - Dictionary of timing information
        """
        self._initialize_timings()
        global_start_time = time()

        # Initialize computational universe
        L = self._initialize_computational_universe(F)

        while True:
            # Compute L-stable span; later rounds start from the previous span
            M, L = self._compute_lstable_span(F, L, use_fast_elimination)
            F = M

            # Check if universe is sufficient
            O, sufficient = self._check_universe_sufficiency(M, L)

            if sufficient:
                if self._should_stop_oracle(O):
                    # Recompute without the oracle
                    self.no_oracle = True
                    self.timings['fallback_to_border_basis'] = 1
                else:
                    break
            else:
                L = self._extend_universe(L, O)

        # Final reduction if needed
        G = self._final_reduction(M, O, lstabilization_only)

        # Update final timings
        self._update_final_timings(global_start_time)

        return G, O, self.timings

    def final_reduction_algorithm(self, V, O):
        """
        Compute the O-border basis of a zero-dimensional ideal, i.e., transform V into a border basis.

        This is the final reduction algorithm (Proposition 17) from the paper
        "Computing Border Bases" by Kehrein & Kreuzer (2005).

        Polynomials are processed in increasing leading-term order. Each one
        has every non-leading term outside O eliminated using previously
        reduced polynomials, then is made monic.

        Args:
            V (list): A vector basis of the span FL with pairwise different leading terms.
            O (iterable): Order ideal for border basis.

        Returns:
            list: The O-border basis {g1, ..., gτ}, sorted by leading term.

        Raises:
            ValueError: If a term to eliminate has no reducer, or a border term
                has no polynomial with that leading term.
        """
        order_ideal = set(O)  # hoisted: tested against every polynomial's support
        leading_term_map = {}  # leading term -> reduced, monic polynomial

        for v in sorted(V, key=lambda p: p.lm()):
            support = {self.ring.monomial(*exponent) for exponent in v.dict()}
            # Terms to eliminate: everything except the leading term and terms in O.
            H = support - order_ideal - {v.lm()}

            for h in H:
                if h not in leading_term_map:
                    raise ValueError(f"No wh found in VR for h = {h}")
                wh = leading_term_map[h]
                ch = v.monomial_coefficient(h) / wh.lc()
                v -= ch * wh

            normalized_v = v / v.lc()
            leading_term_map[normalized_v.lm()] = normalized_v

        # Construct border basis using the leading term map
        border_basis = []
        for b in super().border(O):
            if b not in leading_term_map:
                raise ValueError(f"No polynomial in VR with leading term {b}")
            border_basis.append(leading_term_map[b])

        return sorted(border_basis, key=lambda g: g.lm())

    def check_pure_monomials(self, M):
        """
        Collect the leading terms of polynomials in M that are pure powers.

        A leading term qualifies when at most one variable has a non-zero
        exponent; the constant monomial 1 therefore also qualifies.

        :param M: Iterable of polynomials
        :return: list of the qualifying leading monomials
        """
        pure_monomials = []
        for f in M:
            t = f.lm()
            exponents = t.exponents()[0]
            total_degree = sum(exponents)
            if all(e == 0 or e == total_degree for e in exponents):
                pure_monomials.append(t)
        return pure_monomials

    def plot_evolution(self):
        """
        Plot the evolution of the computational universe.

        Uses the recorded ``universes``, ``leading_terms`` and
        ``border_terms_add`` lists, converted to exponent tuples.
        """
        universes_exponents = [[tuple(t.exponents()[0]) for t in universe] for universe in self.universes]
        leading_terms_exponents = [[tuple(t.exponents()[0]) for t in leading_terms] for leading_terms in self.leading_terms]
        border_terms_add_exponents = [[tuple(t.exponents()[0]) for t in border_terms] for border_terms in self.border_terms_add]

        # Debug output goes through the class logger instead of an unconditional print.
        self.log(str(border_terms_add_exponents))
        plot_multiple_monomials(universes_exponents, leading_terms_exponents, border_terms_add_exponents)

    def _compute_border_terms(self, V: List, L: List) -> List:
        """
        Compute border terms using oracle prediction if applicable.

        The oracle is consulted only when it is enabled, the call budget is not
        exhausted, and the universe is large enough relative to V. If the oracle
        selects no polynomials, the full extension is used instead. Otherwise an
        empty term list would end the L-stable span loop prematurely.

        Args:
            V: Current set of polynomials
            L: Current computational universe

        Returns:
            List of border terms
        """
        universe_size = len(L)
        use_oracle = (
            not self.no_oracle
            and self.oracle_calls < self.oracle_max_calls
            # Checked before the ratio so an empty universe cannot divide by zero.
            and universe_size > max(self.min_universe, 0)
            and len(V) / universe_size > self.relative_gap
            and universe_size - len(V) < self.absolute_gap
        )

        if use_oracle:
            prediction = self.oracle.predict(V, find_maximal_terms(L))
            predicted_terms = [t for _, t in prediction]
            hV = [poly for poly in V if poly.lm() in predicted_terms]
            self.oracle_calls += 1
            if hV:
                return super().extend_V(hV)

        return super().extend_V(V)

    def _perform_gaussian_elimination(
        self,
        V: List,
        border_terms: List,
        use_fast_elimination: bool
    ) -> Tuple[List, List, List]:
        """
        Perform Gaussian elimination and track timing.

        The elapsed time is appended to ``self.timings["gaussian_elimination_times"]``.

        Args:
            V: Current set of polynomials
            border_terms: Border terms to eliminate
            use_fast_elimination: Whether to use fast elimination

        Returns:
            Tuple of (reduced polynomials, non-zero reduction indices, reduction indices)
        """
        start_time = time()
        result = super().gaussian_elimination(
            V, border_terms, use_fast_elimination=use_fast_elimination
        )
        self.timings["gaussian_elimination_times"].append(time() - start_time)
        return result





# Example usage
if __name__ == '__main__':
    # Minimal demonstration on a tiny zero-dimensional system in x, y.
    # NOTE: this module uses relative imports (see the top of the file), so
    # run it as a module (``python -m <package>.<module>``), not as a script.
    fast_gauss = True
    only_lstable = False

    ring = PolynomialRing(QQ, 'x, y', order='degrevlex')
    x, y = ring.gens()
    generators = [x**2 - y, y**2 - x]

    print('Input polynomials (F):')
    pprint(generators)
    print()

    print(f'Ideal dimension (should be 0): {ideal(generators).dimension()}')
    print(f'Use fast gaussian elimination: {fast_gauss}')

    calc = ImprovedBorderBasisCalculator(
        ring, corollary_23=True, N=20, save_universes=True, verbose=True
    )
    basis, order_ideal, _ = calc.compute_border_basis_optimized(
        generators,
        use_fast_elimination=fast_gauss,
        lstabilization_only=only_lstable,
    )

    print("Order ideal:", order_ideal)
    print("Border basis:", basis)

    # Validate the result against the original generators.
    verdict = OBorderBasisChecker(ring).check_oborder_basis(basis, order_ideal, generators)
    print(f"Is border basis by Buchberger Criterion: {verdict}")
