import numpy as np
import cvxpy as cp
from typing import Dict, List, Optional, Tuple, Set
from itertools import combinations_with_replacement
from data_generation.monomials.monomials import Polynomial, Monomial, MonomialBasis
from sdp_solver.sdp_interface import SDPSolver
from utils.polynomial import get_newton_polytope_basis

class CVXPYSOSSolver(SDPSolver):
    """SDP solver for sum-of-squares (SOS) problems using CVXPY.

    A polynomial p is checked for an SOS certificate by searching for a
    symmetric matrix Q (the Gram matrix) with Q ⪰ 0 and p(x) = z(x)^T Q z(x),
    where z is a vector of monomials (the basis).
    """

    def __init__(self, solver: str = 'SCS', verbose: bool = False):
        """
        Initialize the solver.

        Args:
            solver: CVXPY solver to use (e.g., 'MOSEK', 'SCS', 'CVXOPT')
            verbose: Whether to print solver output
        """
        self.solver = solver
        self.verbose = verbose

    def _generate_monomial_basis(self, poly: Polynomial) -> MonomialBasis:
        """Generate every monomial of total degree <= (max degree of poly) // 2.

        Monomials are produced grouped by increasing total degree; within one
        degree they appear in increasing lexicographic order of exponents.

        Raises:
            ValueError: if ``poly`` has no terms (the number of variables
                cannot be determined).
        """
        if not poly.terms:
            raise ValueError("Cannot generate a monomial basis for a polynomial with no terms")

        # Get the number of variables from any monomial
        num_vars = len(next(iter(poly.terms)).exponents)

        # Maximum degree of basis elements is half the maximum degree of poly
        max_deg = max(sum(m.exponents) for m in poly.terms.keys())
        basis_max_deg = max_deg // 2

        basis = []

        def generate_exponents(current: List[int], remaining_vars: int, remaining_deg: int):
            # Enumerate all exponent vectors of length num_vars summing to exactly
            # the requested degree.
            if remaining_vars == 0:
                if remaining_deg == 0:  # Only add when we've used exactly the degree
                    basis.append(Monomial(tuple(current)))
                return

            # Try all possible exponents for the current variable
            for d in range(remaining_deg + 1):
                generate_exponents(
                    current + [d],
                    remaining_vars - 1,
                    remaining_deg - d
                )

        # Generate all monomials up to basis_max_deg
        for d in range(basis_max_deg + 1):
            generate_exponents([], num_vars, d)

        if self.verbose:
            print(f"Generated basis of size {len(basis)}:")
            for m in basis:
                print(f"  {m}")

        return basis

    def _build_coefficient_map(self,
                             basis: MonomialBasis,
                             Q: cp.Variable) -> Dict[Monomial, cp.Expression]:
        """Map each monomial of z^T Q z to its (affine) coefficient in Q.

        Both Q[i, j] and Q[j, i] are added for i != j, so the map describes
        the full quadratic form z^T Q z.
        """
        coeff_map = {}
        n = len(basis)

        for i in range(n):
            for j in range(n):
                result_monomial = basis[i] * basis[j]
                if result_monomial in coeff_map:
                    coeff_map[result_monomial] += Q[i,j]
                else:
                    coeff_map[result_monomial] = Q[i,j]

        if self.verbose:
            print("\nCoefficient map:")
            for m, coeff in coeff_map.items():
                print(f"  {m}: {coeff}")

        return coeff_map

    def solve_sos_feasibility(self,
                            poly: Polynomial,
                            basis: Optional[MonomialBasis] = None,
                            solver_options: dict = None) -> Tuple[bool, Optional[np.ndarray]]:
        """
        Check if a polynomial is SOS by solving a feasibility SDP.

        Args:
            poly: The polynomial to check
            basis: Optional monomial basis to use. If None, generates a complete basis.
                  If provided, attempts to find an SOS decomposition using only these basis elements.
            solver_options: Optional solver options, passed to ``problem.solve``.
                  When the solver is SCS, SCS-specific defaults are merged in
                  underneath them (user values win). Other solvers receive
                  only the options given here.

        Returns:
            Tuple of (is_sos, gram_matrix). gram_matrix is in the original
            (unscaled) coefficient units, or None when infeasible or on error.
        """
        if self.verbose:
            print("\nSolving SOS feasibility for polynomial:")
            print(poly)

        # Use provided basis or generate complete basis
        if basis is None:
            basis = self._generate_monomial_basis(poly)

        if self.verbose:
            print(f"\nUsing basis of size {len(basis)}:")
            for m in basis:
                print(f"  {m}")

        n = len(basis)

        try:
            Q = cp.Variable((n, n), symmetric=True)

            # NOTE: epsilon > 0 requires Q to be positive *definite* (on the
            # scaled problem). Polynomials whose only Gram matrices are
            # singular are then feasible only within solver tolerance.
            epsilon = 1e-8
            constraints = [Q >> epsilon * np.eye(n)]

            coeff_map = self._build_coefficient_map(basis, Q)

            # Normalize so the largest |coefficient| is 1 (better conditioning).
            # max() raises ValueError for a polynomial with no terms, which is
            # handled by the except clause below.
            scale = max(abs(coeff) for coeff in poly.terms.values())
            if scale > 0:
                scaled_poly = {m: c/scale for m, c in poly.terms.items()}
            else:
                scaled_poly = poly.terms

            # Add constraints matching coefficients
            for monomial, target_coeff in scaled_poly.items():
                if monomial in coeff_map:
                    constraints.append(coeff_map[monomial] == target_coeff)
                    if self.verbose:
                        print(f"\nAdding constraint for {monomial}: {coeff_map[monomial]} == {target_coeff}")
                elif target_coeff == 0:
                    # An explicit zero term that no basis product produces is
                    # matched trivially; it must not make the problem infeasible.
                    continue
                else:
                    if self.verbose:
                        print(f"\nMonomial {monomial} not in basis products, cannot represent with this basis")
                    return False, None

            # All monomials produced by z^T Q z but absent from poly must vanish
            for monomial, expr in coeff_map.items():
                if monomial not in scaled_poly:
                    constraints.append(expr == 0)
                    if self.verbose:
                        print(f"\nAdding zero constraint for {monomial}: {expr} == 0")

            # Bound the trace to keep the (scaled) feasible set bounded
            constraints.append(cp.trace(Q) <= 100 * n)

            problem = cp.Problem(cp.Minimize(0), constraints)

            if self.verbose:
                print("\nSolving SDP...")

            # Copy so the caller's dict is never modified
            options = dict(solver_options) if solver_options else {}
            if str(self.solver).upper() == 'SCS':
                # These keys are SCS-specific; other backends may reject them.
                default_options = {
                    'normalize': True,
                    'scale': 1.0,
                    'eps': 1e-8,
                    'max_iters': 10000,
                }
                # Merge with user options, preferring user values
                options = {**default_options, **options}

            result = problem.solve(solver=self.solver, verbose=self.verbose, **options)

            is_feasible = problem.status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]
            Q_val = Q.value if is_feasible else None

            if Q_val is not None:
                # Undo the coefficient normalization (without mutating Q.value).
                # When scale == 0 no normalization was applied.
                if scale > 0:
                    Q_val = Q_val * float(scale)

                # Verify PSD property (loose absolute tolerance)
                min_eig = np.min(np.linalg.eigvalsh(Q_val))
                is_feasible = is_feasible and min_eig > -1e-2

            if self.verbose:
                print(f"Solver status: {problem.status}")
                print(f"Optimal value: {result}")
                if Q_val is not None:
                    print("Q matrix eigenvalues:", np.linalg.eigvalsh(Q_val))

            return is_feasible, Q_val

        except (cp.error.SolverError, ValueError) as e:
            if self.verbose:
                print(f"Solver error: {e}")
            return False, None

    def get_sos_decomposition(self,
                            poly: Polynomial,
                            Q: np.ndarray,
                            basis: Optional[MonomialBasis] = None) -> str:
        """Render an SOS decomposition from an eigendecomposition of Q.

        Each eigenpair (λ, v) with λ > 1e-10 contributes the square
        (sqrt(λ) * Σ v_k z_k)²; components with |v_k| <= 1e-10 are dropped.
        ``basis`` must be the basis Q was computed with (defaults to the
        complete basis generated from ``poly``).
        """
        if basis is None:
            basis = self._generate_monomial_basis(poly)

        if self.verbose:
            print("\nGenerating SOS decomposition using basis:")
            for m in basis:
                print(f"  {m}")

        eigvals, eigvecs = np.linalg.eigh(Q)

        # Keep only significant eigenvalues and vectors
        tol = 1e-10
        significant = eigvals > tol
        eigvals = eigvals[significant]
        eigvecs = eigvecs[:, significant]

        terms = []
        for val, vec in zip(eigvals, eigvecs.T):
            # Construct the term sqrt(λ)(v₁z₁ + v₂z₂ + ...)
            coeff_str = f"{np.sqrt(val):.6f}"
            term_parts = []
            for coeff, monomial in zip(vec, basis):
                if abs(coeff) > tol:
                    term = self._format_monomial(monomial, coeff)
                    if term:
                        term_parts.append(term)

            if term_parts:
                terms.append(f"({coeff_str}*({' + '.join(term_parts)}))²")

        return " + ".join(terms)

    def _format_monomial(self, monomial: Monomial, coeff: float) -> str:
        """Format ``coeff * monomial`` as e.g. ``0.500000*x1*x2^2``.

        Returns "" for |coeff| < 1e-10; omits a coefficient of ±1.
        Variables are 1-indexed (x1, x2, ...).
        """
        if abs(coeff) < 1e-10:
            return ""

        var_terms = []
        for i, exp in enumerate(monomial.exponents):
            if exp == 0:
                continue
            elif exp == 1:
                var_terms.append(f"x{i+1}")
            else:
                var_terms.append(f"x{i+1}^{exp}")

        if not var_terms:
            return f"{coeff:.6f}"

        term = "*".join(var_terms)
        if abs(coeff - 1.0) < 1e-10:
            return term
        elif abs(coeff + 1.0) < 1e-10:
            return f"-{term}"
        else:
            return f"{coeff:.6f}*{term}"

import numpy as np
import cvxpy as cp
from typing import List, Optional, Tuple, Set
from itertools import combinations_with_replacement

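# ---- project-specific helpers -----------------------------------------------
# The column-generation solver reuses the Monomial / Polynomial datatypes.
# NOTE(review): these imports rebind names already imported at the top of the
# file under a different package path (without the `sos_transformer.` prefix);
# confirm which path is canonical, since objects from two distinct module
# copies may not compare equal. Adjust the import paths if they live elsewhere.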
from sos_transformer.data_generation.monomials.monomials import Polynomial, Monomial, MonomialBasis
from sos_transformer.utils.polynomial import get_newton_polytope_basis

from cvxpy.error import SolverError

# -----------------------------------------------------------------------------
#  Column‑generation variant of the CVXPYSOSSolver
# -----------------------------------------------------------------------------

class ColumnGenSOSSolver:
    """Sum-of-squares feasibility via solver-guided basis extension.

    Start from an initial basis (typically an ML prediction) and solve the
    Gram-matrix SDP. If it is infeasible, build a symmetric matrix that matches
    p's coefficients but is *not* required to be PSD, and look for blocking
    pairs (i, j): Q_ii ≈ Q_jj ≈ 0 while Q_ij ≠ 0, a pattern no PSD matrix can
    have. For each such pair, the monomial (m_i + m_j) / 2, when its exponents
    are integral, is added to the basis (its square reproduces m_i * m_j), and
    the SDP is solved again.

    The loop stops when the SDP becomes feasible, when no new monomial can be
    proposed, or after ``max_loops`` rounds.
    """

    def __init__(self, solver: str = "SCS", verbose: bool = False):
        self.solver = solver
        self.verbose = verbose

    # ---------------------------------------------------------------------
    # 1)  "Vanilla" SDP feasibility on a *fixed* basis
    # ---------------------------------------------------------------------

    def _solve_fixed_basis(self, poly: Polynomial, basis: MonomialBasis,
                           solver_opts: dict) -> Tuple[bool, Optional[np.ndarray]]:
        """Try to find a Gram matrix Q ⪰ 0 such that p(x)=z_B^T Q z_B.

        Returns (is_feasible, Q_value or None). Q is None whenever the status
        is not optimal / optimal_inaccurate, including solver crashes.
        """
        if self.verbose:
            print(f"\n[solver]  solving SDP on basis of size {len(basis)} …")

        n = len(basis)
        Q = cp.Variable((n, n), symmetric=True)

        # Mild regularization for numerical stability (asks for Q ≻ 0)
        eps = 1e-9
        constraints = [Q >> eps * np.eye(n)]

        # Coefficient map of z^T Q z; off-diagonal entries count twice (i < j)
        coeff_map = {}
        for i in range(n):
            for j in range(i, n):
                m = basis[i] * basis[j]
                entry = Q[i, j] if i == j else 2 * Q[i, j]
                coeff_map[m] = coeff_map.get(m, 0) + entry

        # Match polynomial coefficients (an unreachable monomial yields a
        # constant constraint "0 == target")
        for m, target in poly.terms.items():
            constraints.append(coeff_map.get(m, 0) == target)
        # All other monomials must vanish
        for m, expr in coeff_map.items():
            if m not in poly.terms:
                constraints.append(expr == 0)

        prob = cp.Problem(cp.Minimize(0), constraints)
        try:
            prob.solve(solver=self.solver, verbose=self.verbose, **solver_opts)
        except (SolverError, ValueError) as e:
            if self.verbose:
                print("Solver crashed:", e)
            return False, None

        feasible = prob.status in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE)
        return feasible, Q.value if feasible else None

    # ---------------------------------------------------------------------
    # 2)  Candidate matrix + blocking pairs for an *infeasible* run
    # ---------------------------------------------------------------------

    @staticmethod
    def _relaxed_gram(poly: Polynomial, basis: MonomialBasis) -> Optional[np.ndarray]:
        """Return the minimum-norm symmetric Q with z_B^T Q z_B ≈ p, ignoring Q ⪰ 0.

        The coefficient-matching constraints are linear in the upper-triangular
        entries of Q, so they are solved in the least-squares sense with
        ``np.linalg.lstsq``. Monomials of p that no basis product can produce
        have no unknowns and are therefore left out. Returns None for an
        empty basis.
        """
        n = len(basis)
        if n == 0:
            return None

        pairs = [(i, j) for i in range(n) for j in range(i, n)]
        # monomial -> [(column index, weight)]; off-diagonal entries count twice
        rows = {}
        for col, (i, j) in enumerate(pairs):
            rows.setdefault(basis[i] * basis[j], []).append((col, 1.0 if i == j else 2.0))

        A = np.zeros((len(rows), len(pairs)))
        b = np.zeros(len(rows))
        for r, (m, entries) in enumerate(rows.items()):
            for col, weight in entries:
                A[r, col] = weight
            b[r] = float(poly.terms[m]) if m in poly.terms else 0.0

        x = np.linalg.lstsq(A, b, rcond=None)[0]
        Q = np.zeros((n, n))
        for col, (i, j) in enumerate(pairs):
            Q[i, j] = Q[j, i] = x[col]
        return Q

    @staticmethod
    def _find_blocking_pairs(Q_dual: np.ndarray, tol: float = 1e-7) -> Set[Tuple[int, int]]:
        """Return index pairs (i, j), i < j, with |Q_ii|, |Q_jj| < tol but |Q_ij| > tol.

        Such a pattern is impossible for a PSD matrix, so it pinpoints where
        the basis lacks a monomial to carry the product basis[i] * basis[j].
        """
        n = Q_dual.shape[0]
        pairs = set()
        diag_zero = np.abs(np.diag(Q_dual)) < tol
        for i in range(n):
            if not diag_zero[i]:
                continue
            for j in range(i + 1, n):
                if diag_zero[j] and abs(Q_dual[i, j]) > tol:
                    pairs.add((i, j))
        return pairs

    @staticmethod
    def _half_sum(m1: Monomial, m2: Monomial) -> Optional[Monomial]:
        """Return the component-wise average if it is integral; else None."""
        mid = []
        for a, b in zip(m1.exponents, m2.exponents):
            s = a + b
            if s & 1:                 # odd ⇒ non-integral
                return None
            mid.append(s // 2)
        return Monomial(tuple(mid))

    # ------------------------------------------------------------------
    # 3)  *Column-generation* outer loop
    # ------------------------------------------------------------------

    def solve_sos(self, poly: Polynomial, init_basis: Optional[MonomialBasis] = None,
                  max_loops: int = 10, solver_opts: Optional[dict] = None) -> Tuple[bool, MonomialBasis, Optional[np.ndarray]]:
        """Main entry: returns (is_sos, final_basis, Gram_matrix).

        Args:
            poly: polynomial to certify.
            init_basis: starting basis; defaults to the Newton-polytope basis.
            max_loops: maximum number of SDP solves.
            solver_opts: options for ``problem.solve``. When None, SCS-specific
                defaults are used for SCS and no options for other backends.
        """
        if solver_opts is None:
            # "eps"/"max_iters" are SCS keys; other backends may reject them.
            if str(self.solver).upper() == "SCS":
                solver_opts = {"eps": 1e-6, "max_iters": 10000}
            else:
                solver_opts = {}

        # -- build *initial* basis -------------------------------------------------
        if init_basis is None:
            # fallback: Newton polytope as a safe but large default
            exps = np.array([m.exponents for m in poly.terms])
            pts = get_newton_polytope_basis(exps)
            if pts is None:
                if self.verbose:
                    print("[solver]  Newton polytope basis unavailable — abort")
                return False, [], None
            init_basis = [Monomial(tuple(map(int, p))) for p in pts]

        basis: MonomialBasis = list(init_basis)  # make a modifiable copy
        if not basis:
            if self.verbose:
                print("[solver]  empty basis — abort")
            return False, basis, None

        for loop in range(1, max_loops + 1):
            feasible, Q_val = self._solve_fixed_basis(poly, basis, solver_opts)
            if feasible:
                if self.verbose:
                    print(f"[solver]  feasible after {loop} iteration(s)")
                return True, basis, Q_val

            # CVXPY does not report variable values for infeasible problems,
            # so derive a coefficient-matching (non-PSD) candidate instead.
            Q_candidate = self._relaxed_gram(poly, basis)
            if Q_candidate is None:
                if self.verbose:
                    print("[solver]  infeasible & no candidate matrix available — abort")
                return False, basis, None

            new_monos: List[Monomial] = []
            for i, j in sorted(self._find_blocking_pairs(Q_candidate)):
                t = self._half_sum(basis[i], basis[j])
                if t is not None and t not in basis and t not in new_monos:
                    new_monos.append(t)
            if not new_monos:
                if self.verbose:
                    print("[solver]  no integral half-sums found — fallback fails")
                return False, basis, None

            if self.verbose:
                print(f"[solver]  adding {len(new_monos)} monomials: {new_monos}")
            basis.extend(new_monos)

        # exceeded iteration budget ------------------------------------------------
        if self.verbose:
            print(f"[solver]  gave up after {max_loops} outer iterations")
        return False, basis, None



if __name__ == "__main__":
    # Demo: sample a random SOS polynomial and certify it with three bases
    # (complete, the sampled "reduced" basis, and the Newton-polytope basis).
    import time

    print("Starting CVXPYSOSSolver example...")

    from data_generation import (
        SparseUniformBasisSampler,
        SimpleRandomPSDSampler,
        SOSPolynomialSampler
    )

    print("\nCreating samplers...")
    basis_sampler = SparseUniformBasisSampler(
        min_sparsity=0.3,
        max_sparsity=0.5,
        min_degree=1,
        max_degree=2
    )

    matrix_sampler = SimpleRandomPSDSampler(
        min_eigenval=0.0,
        scale=0.1,
        random_state=42
    )

    print("\nSampling polynomial...")
    poly_sampler = SOSPolynomialSampler(basis_sampler, matrix_sampler)
    poly, basis, Q = poly_sampler.sample(num_vars=5, max_degree=3)

    print("\nCreating and running SDP solver...")
    solver = CVXPYSOSSolver(
        solver='MOSEK',  # overrides the class default ('SCS')
        verbose=False    # set to True to print SDP construction and solver logs
    )

    # MOSEK interior-point tolerances and limits shared by all runs below
    solver_options = {
        'mosek_params': {
            'MSK_DPAR_INTPNT_CO_TOL_DFEAS': 1e-6,
            'MSK_DPAR_INTPNT_CO_TOL_PFEAS': 1e-6,
            'MSK_DPAR_INTPNT_CO_TOL_REL_GAP': 1e-6,
            'MSK_DPAR_INTPNT_CO_TOL_MU_RED': 1e-8,
            'MSK_DPAR_INTPNT_CO_TOL_INFEAS': 1e-6,
            'MSK_DPAR_INTPNT_TOL_DFEAS': 1e-6,
            'MSK_DPAR_INTPNT_TOL_PFEAS': 1e-6,
            'MSK_DPAR_INTPNT_TOL_REL_GAP': 1e-6,
            'MSK_DPAR_INTPNT_TOL_MU_RED': 1e-8,
            'MSK_DPAR_INTPNT_TOL_INFEAS': 1e-6,
            'MSK_IPAR_INTPNT_MAX_ITERATIONS': 10000,
            'MSK_IPAR_INTPNT_SCALING': 1,
            'MSK_IPAR_LOG_INTPNT': 1,
        }
    }

    # Try with full basis first
    print("\nTrying with full basis:")
    t0 = time.time()
    is_sos, Q_full = solver.solve_sos_feasibility(
        poly,
        solver_options=solver_options
    )
    t1 = time.time()
    print(f"\nTime for full basis: {t1 - t0:.4f} seconds")

    print("\nResults with full basis:")
    print("Generated polynomial:")
    print(poly)
    print("\nIs SOS?", is_sos)

    if is_sos:
        print("\nSOS decomposition:")
        print(solver.get_sos_decomposition(poly, Q_full))

    # Try with reduced basis (original basis from sampling)
    print(basis)
    print("\nTrying with reduced basis (original):")
    t0 = time.time()
    is_sos_reduced, Q_reduced = solver.solve_sos_feasibility(
        poly,
        basis=basis,
        solver_options=solver_options
    )
    t1 = time.time()
    print(f"\nTime for reduced basis: {t1 - t0:.4f} seconds")

    print("\nResults with reduced basis:")
    print("Is SOS with reduced basis?", is_sos_reduced)

    if is_sos_reduced:
        print("\nSOS decomposition with reduced basis:")
        print(solver.get_sos_decomposition(poly, Q_reduced, basis=basis))

    # Try with Newton polytope basis
    print("\nTrying with Newton polytope basis:")
    # Stays False if the Newton polytope basis cannot be computed
    is_sos_newton = False

    exponents = np.array([monomial.exponents for monomial in poly.terms.keys()])

    newton_basis_points = get_newton_polytope_basis(exponents)
    if newton_basis_points is not None:
        # Convert points back to Monomial objects
        newton_basis = [Monomial(tuple(map(int, point))) for point in newton_basis_points]

        print(f"\nNewton polytope basis size: {len(newton_basis)}")
        print("Newton polytope basis:")
        for m in newton_basis:
            print(f"  {m}")

        t0 = time.time()
        is_sos_newton, Q_newton = solver.solve_sos_feasibility(
            poly,
            basis=newton_basis,
            solver_options=solver_options
        )
        t1 = time.time()
        print(f"\nTime for Newton polytope basis: {t1 - t0:.4f} seconds")

        print("\nResults with Newton polytope basis:")
        print("Is SOS with Newton polytope basis?", is_sos_newton)

        if is_sos_newton:
            print("\nSOS decomposition with Newton polytope basis:")
            print(solver.get_sos_decomposition(poly, Q_newton, basis=newton_basis))
    else:
        print("\nFailed to compute Newton polytope basis")

    # Print comparison
    print("\nComparison of methods:")
    print(f"Full basis      : {'✓' if is_sos else '✗'}")
    print(f"Reduced basis   : {'✓' if is_sos_reduced else '✗'}")
    print(f"Newton polytope : {'✓' if is_sos_newton else '✗'}")

    print("\nExample completed.")