"""
Sum of Squares Basis Extension using Dual SDP and Schur Complement Scores

This module implements an iterative basis extension algorithm for Sum of Squares (SOS) decomposition:
1. Start with an initial monomial basis
2. Try the primal SDP (find Gram matrix)
3. If infeasible, solve the dual to get moment values
4. Compute Schur scores for candidate monomials
5. Add monomials with most negative scores to the basis
6. Repeat until the primal becomes feasible, no candidate scores below the threshold, a solver fails, or the iteration budget is exhausted

Uses the monomial and polynomial data structures from the project.
"""

import itertools
import numpy as np
import cvxpy as cp
from typing import List, Dict, Tuple, Set, Optional, Union
from collections import defaultdict

from data_generation.monomials.monomials import Monomial, Polynomial


class SOSBasisExtender:
    """
    Iteratively extend a monomial basis for Sum of Squares (SOS) decomposition.

    Each round solves the primal Gram-matrix SDP for the current basis. When it
    is infeasible, a moment (dual) SDP is solved and candidate monomials are
    ranked by a Schur-complement score computed from the resulting moment
    matrix; candidates with the most negative scores are appended to the basis.

    Internally, monomials are handled as exponent tuples of length ``num_vars``.
    """

    def __init__(self, polynomial: Polynomial, num_vars: int, max_degree: Optional[int] = None):
        """
        Initialize the SOS basis extender.

        Args:
            polynomial: The polynomial to find an SOS decomposition for; its
                ``terms`` mapping (Monomial -> coefficient) is read.
            num_vars: Number of variables in the polynomial.
            max_degree: Upper bound on the degree of candidate monomials; the
                dual SDP reports moments up to degree ``2 * max_degree``.
                Defaults to ``2 * poly_degree`` when ``None``.
        """
        self.polynomial = polynomial
        self.num_vars = num_vars
        self.poly_degree = max((m.degree for m in polynomial.terms), default=0)
        # Compare against None so an explicit max_degree=0 is honored
        # (``max_degree or ...`` would silently replace it).
        self.max_degree = 2 * self.poly_degree if max_degree is None else max_degree

        # Solvers are tried in this order until one returns a usable status.
        self.try_solvers = ["MOSEK", "SCS", "CVXOPT"]

    def monomial_to_tuple(self, monomial: Monomial) -> Tuple[int, ...]:
        """
        Convert a Monomial to an exponent tuple of length ``num_vars``.

        Shorter exponent sequences are zero-padded; longer ones are truncated
        to their first ``num_vars`` entries.
        """
        exps = tuple(monomial.exponents)[:self.num_vars]
        return exps + (0,) * (self.num_vars - len(exps))

    def tuple_to_monomial(self, exp_tuple: Tuple[int, ...]) -> Monomial:
        """Convert an exponent tuple to a Monomial object."""
        return Monomial(exp_tuple)

    def add_monomials(self, m1: Tuple[int, ...], m2: Tuple[int, ...]) -> Tuple[int, ...]:
        """Multiply two monomials given as exponent tuples (element-wise exponent sum)."""
        return tuple(a + b for a, b in zip(m1, m2))

    def get_polynomial_coeffs_dict(self) -> Dict[Tuple[int, ...], float]:
        """
        Convert the polynomial to a ``{exponent tuple: coefficient}`` dictionary.

        Coefficients of Monomials that normalize to the same exponent tuple are
        summed rather than overwritten.
        """
        coeffs: Dict[Tuple[int, ...], float] = {}
        for monomial, coeff in self.polynomial.terms.items():
            key = self.monomial_to_tuple(monomial)
            coeffs[key] = coeffs.get(key, 0.0) + float(coeff)
        return coeffs

    def _exponent_tuples(self, max_total_degree: int) -> List[Tuple[int, ...]]:
        """Return every exponent tuple of length ``num_vars`` with entry sum <= ``max_total_degree``."""
        if max_total_degree < 0:
            return []
        n = self.num_vars
        result = []
        # Stars and bars: choosing n "bar" positions among max_total_degree + n
        # slots yields each tuple with sum <= max_total_degree exactly once
        # (the slots after the last bar absorb the unused degree).
        for bars in itertools.combinations(range(max_total_degree + n), n):
            prev = -1
            exps = []
            for bar in bars:
                exps.append(bar - prev - 1)
                prev = bar
            result.append(tuple(exps))
        return result

    def generate_candidate_monomials(self, current_basis: List[Monomial]) -> List[Monomial]:
        """
        Generate candidate monomials for basis extension.

        Candidates are (a) every monomial of degree at most
        ``min(max_degree, poly_degree // 2)`` and (b) quotients ``gamma / b`` of
        polynomial terms ``gamma`` not yet produced by any basis pair, divided
        by a basis element ``b``, subject to the same degree bound. The bound
        comes from the Newton polytope: a monomial of degree > deg(p)/2 cannot
        appear in any SOS decomposition of p, because the top-degree parts of
        the squares cannot cancel.

        Args:
            current_basis: Current monomial basis.

        Returns:
            Candidates not already in the basis, in sorted exponent-tuple order.
        """
        basis_tuples = [self.monomial_to_tuple(m) for m in current_basis]
        basis_set = set(basis_tuples)
        support = {self.monomial_to_tuple(m) for m in self.polynomial.terms}

        current_products = {
            self.add_monomials(basis_tuples[i], basis_tuples[j])
            for i in range(len(basis_tuples))
            for j in range(i, len(basis_tuples))
        }
        missing = support - current_products

        max_candidate_degree = min(self.max_degree, self.poly_degree // 2)

        # (a) All monomials up to half the polynomial degree.
        candidate_tuples = {
            exps for exps in self._exponent_tuples(max_candidate_degree) if exps not in basis_set
        }

        # (b) Quotients that would produce an uncovered term when multiplied by a basis element.
        for missing_term in missing:
            for basis_elem in basis_tuples:
                if all(m >= b for m, b in zip(missing_term, basis_elem)):
                    remainder = tuple(m - b for m, b in zip(missing_term, basis_elem))
                    if remainder not in basis_set and sum(remainder) <= max_candidate_degree:
                        candidate_tuples.add(remainder)

        return [self.tuple_to_monomial(exps) for exps in sorted(candidate_tuples)]

    def moment_matrix(self, basis: List[Tuple[int, ...]], yvals: Dict[Tuple[int, ...], float]) -> np.ndarray:
        """
        Build the moment matrix M_B(y) with ``M[i, j] = y_{alpha_i + alpha_j}``.

        Args:
            basis: Basis monomials as exponent tuples.
            yvals: Moment values; missing keys are treated as 0.0.

        Returns:
            Symmetric ``len(basis) x len(basis)`` moment matrix.
        """
        size = len(basis)
        M = np.zeros((size, size))
        for i, ai in enumerate(basis):
            for j in range(i, size):
                # alpha_i + alpha_j is symmetric in (i, j), so fill both halves at once.
                M[i, j] = M[j, i] = yvals.get(self.add_monomials(ai, basis[j]), 0.0)
        return M

    def _moment_range(self, basis: List[Tuple[int, ...]], yvals: Dict[Tuple[int, ...], float],
                      pinv_tol: float) -> Tuple[np.ndarray, np.ndarray]:
        """Return eigenvectors/eigenvalues of M_B(y) whose eigenvalue exceeds the relative tolerance."""
        if not basis:
            return np.zeros((0, 0)), np.zeros(0)
        lam, U = np.linalg.eigh(self.moment_matrix(basis, yvals))
        keep = lam > pinv_tol * max(1.0, float(lam.max()))
        return U[:, keep], lam[keep]

    def _schur_from_range(self, basis: List[Tuple[int, ...]], yvals: Dict[Tuple[int, ...], float],
                          beta: Tuple[int, ...], U_plus: np.ndarray, lam_plus: np.ndarray) -> Tuple[float, float]:
        """Schur score and range defect of ``beta`` given the kept eigenpairs of M_B(y)."""
        # k_i = y_{alpha_i + beta}
        k = np.array([yvals.get(self.add_monomials(ai, beta), 0.0) for ai in basis], dtype=float)
        coords = U_plus.T @ k  # coordinates of k in the numerical range of M
        rho = float(np.linalg.norm(k - U_plus @ coords))
        # k^T M^+ k expressed in the eigenbasis of the kept eigenvalues.
        quad = float(np.sum(coords ** 2 / lam_plus))
        s = yvals.get(self.add_monomials(beta, beta), 0.0) - quad
        return float(s), rho

    def schur_score(self, basis: List[Tuple[int, ...]], yvals: Dict[Tuple[int, ...], float],
                   beta: Tuple[int, ...], pinv_tol: float = 1e-10) -> Tuple[float, float]:
        """
        Compute the Schur complement score for adding monomial ``beta`` to the basis.

        ``s(beta) = y_{2 beta} - k^T M^+ k`` with ``k_i = y_{alpha_i + beta}``
        and ``M^+`` the pseudoinverse of M_B(y) restricted to eigenvalues above
        ``pinv_tol * max(1, lambda_max)``. A negative score means the moment
        matrix extended by ``beta`` is not PSD.

        Args:
            basis: Current basis as exponent tuples (may be empty).
            yvals: Moment values from the dual solution; missing keys are 0.0.
            beta: Candidate monomial to add.
            pinv_tol: Relative eigenvalue cutoff for the pseudoinverse.

        Returns:
            Tuple ``(schur_score, range_defect)``; ``range_defect`` is the norm
            of the component of ``k`` outside the numerical range of M_B(y).
        """
        U_plus, lam_plus = self._moment_range(basis, yvals, pinv_tol)
        return self._schur_from_range(basis, yvals, beta, U_plus, lam_plus)

    def solve_primal(self, basis: List[Monomial]) -> Tuple[str, Optional[np.ndarray]]:
        """
        Solve the primal SDP: find a PSD Gram matrix ``Q`` with ``p = z^T Q z``.

        ``z`` is the vector of basis monomials. One equality constraint is
        imposed per monomial produced by some basis pair. A nonzero polynomial
        term that no basis pair produces can never be matched, so that case is
        reported as infeasible without calling a solver.

        Args:
            basis: Current monomial basis.

        Returns:
            Tuple ``(status, Q_matrix)``. ``status`` is ``cp.OPTIMAL``,
            ``cp.OPTIMAL_INACCURATE``, ``cp.INFEASIBLE``, or ``"SOLVER_ERROR"``
            when no solver produced one of those. ``Q_matrix`` is the Gram
            matrix on success and ``None`` otherwise.
        """
        basis_tuples = [self.monomial_to_tuple(m) for m in basis]
        poly_coeffs = self.get_polynomial_coeffs_dict()
        size = len(basis_tuples)

        # Group the upper-triangular index pairs (i <= j) by the product they form.
        pairs_by_product = defaultdict(list)
        for i, ai in enumerate(basis_tuples):
            for j in range(i, size):
                pairs_by_product[self.add_monomials(ai, basis_tuples[j])].append((i, j))

        # Without this check such terms would simply be left unconstrained and
        # the SDP could report a spurious decomposition.
        if any(coeff != 0.0 and mono not in pairs_by_product for mono, coeff in poly_coeffs.items()):
            return cp.INFEASIBLE, None

        if size == 0:
            # Reaching here means every coefficient is zero: the empty sum of squares.
            return cp.OPTIMAL, np.zeros((0, 0))

        Q = cp.Variable((size, size), PSD=True)
        constraints = []
        for product, pairs in pairs_by_product.items():
            # Off-diagonal entries count twice because Q[i, j] == Q[j, i].
            lhs = cp.sum([Q[i, j] if i == j else 2 * Q[i, j] for i, j in pairs])
            constraints.append(lhs == poly_coeffs.get(product, 0.0))

        problem = cp.Problem(cp.Minimize(0), constraints)
        for solver_name in self.try_solvers:
            try:
                problem.solve(solver=getattr(cp, solver_name))
            except Exception as e:  # e.g. solver not installed: fall through to the next one
                print(f"Solver {solver_name} failed: {e}")
                continue
            if problem.status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
                return problem.status, Q.value
            if problem.status == cp.INFEASIBLE:
                return problem.status, None

        return "SOLVER_ERROR", None

    def solve_dual(self, basis: List[Monomial], target_monomial: Optional[Tuple[int, ...]] = None) -> Dict[Tuple[int, ...], float]:
        """
        Solve a moment (dual) SDP for the current basis and return moment values.

        Constraints: ``M_B(y)`` is PSD and ``trace(M_B(y)) == 1``. ``y_0 == 1``
        is additionally imposed only when the constant monomial is not an
        entry of ``M_B(y)``. If it were an entry, the two normalizations
        together would force ``M_B(y) = e_0 e_0^T`` independently of the
        polynomial.

        Objective: by default, minimize ``L_y(p) = sum_g p_g * y_g`` over the
        moments appearing in ``M_B(y)``. For any feasible Gram matrix ``Q``,
        ``L_y(p) = <M_B(y), Q> >= 0``, so a negative optimum certifies that
        no PSD Gram matrix exists for this basis. If ``target_monomial`` is
        given, that single moment is minimized instead.

        Args:
            basis: Current monomial basis.
            target_monomial: Optional exponent tuple whose moment is minimized
                instead of ``L_y(p)``.

        Returns:
            Dictionary mapping exponent tuples to moment values. It covers
            every tuple of total degree <= ``2 * max_degree`` plus every moment
            used by the SDP. Moments the SDP does not involve, or that no
            solver assigned, are 0.0.
        """
        basis_tuples = [self.monomial_to_tuple(m) for m in basis]
        zero = tuple([0] * self.num_vars)

        # Every moment up to 2 * max_degree is reported, defaulting to 0.0.
        yvals = {key: 0.0 for key in self._exponent_tuples(2 * self.max_degree)}

        if not basis_tuples:
            # No moment matrix to optimize over; report only the normalized mass.
            yvals[zero] = 1.0
            return yvals

        matrix_keys = {self.add_monomials(ai, aj) for ai in basis_tuples for aj in basis_tuples}
        needed = set(matrix_keys)
        needed.add(zero)
        if target_monomial is not None:
            target_monomial = tuple(target_monomial)
            needed.add(target_monomial)

        # Variables only for moments that enter the SDP; any other variable
        # would stay unassigned (value None) after solving anyway.
        y = {key: cp.Variable(name=f"y_{'_'.join(map(str, key))}") for key in sorted(needed)}

        M = cp.bmat([[y[self.add_monomials(ai, aj)] for aj in basis_tuples] for ai in basis_tuples])

        constraints = [
            M >> 0,               # PSD constraint
            cp.trace(M) == 1.0,   # keeps the feasible set bounded
        ]
        if zero not in matrix_keys:
            constraints.append(y[zero] == 1.0)  # normalization y_0 = 1 (no clash with the trace here)

        if target_monomial is not None:
            objective = cp.Minimize(y[target_monomial])
        else:
            poly_coeffs = self.get_polynomial_coeffs_dict()
            terms = [coeff * y[key] for key, coeff in poly_coeffs.items()
                     if key in matrix_keys and coeff != 0.0]
            objective = cp.Minimize(sum(terms) if terms else 0)

        dual_problem = cp.Problem(objective, constraints)
        solved = False
        for solver_name in self.try_solvers:
            try:
                dual_problem.solve(solver=getattr(cp, solver_name))
            except Exception as e:  # e.g. solver not installed: fall through to the next one
                print(f"Dual solver {solver_name} failed: {e}")
                continue
            if dual_problem.status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
                solved = True
                break
        if not solved:
            print(f"Dual SDP not solved to optimality (status: {dual_problem.status}); "
                  f"moment values may be unreliable")

        for key, var in y.items():
            yvals[key] = float(var.value) if var.value is not None else 0.0
        return yvals

    @staticmethod
    def _build_info(iterations: int, basis: List[Monomial], gram_matrix: Optional[np.ndarray],
                    iteration_info: List[Dict]) -> Dict:
        """Assemble the info dictionary returned by ``extend_basis_iteratively``."""
        return {
            'iterations': iterations,
            'final_basis_size': len(basis),
            'gram_matrix': gram_matrix,
            'iteration_info': iteration_info,
        }

    def extend_basis_iteratively(self, initial_basis: List[Monomial], max_iterations: int = 10, 
                                score_threshold: float = -1e-6, max_additions_per_iter: int = 3, 
                                verbose: bool = False) -> Tuple[List[Monomial], bool, Dict]:
        """
        Iteratively extend the basis using the SOS dual and Schur scores.

        Each iteration solves the primal. On infeasibility it solves the dual,
        scores all candidates, and appends up to ``max_additions_per_iter`` of
        those scoring below ``score_threshold``, most negative first. The loop
        stops early on primal success, a non-infeasible solver failure, or when
        no candidate qualifies. If the budget runs out right after an
        extension, the extended basis gets one final primal check.

        Args:
            initial_basis: Starting monomial basis (not mutated).
            max_iterations: Maximum number of extension iterations.
            score_threshold: Add monomials with scores strictly below this value.
            max_additions_per_iter: Maximum monomials to add per iteration.
            verbose: Whether to print detailed progress.

        Returns:
            Tuple ``(final_basis, is_feasible, info_dict)``. ``info_dict['iterations']``
            is the number of loop iterations actually executed.
        """
        current_basis = list(initial_basis)
        iteration_info = []
        iterations_run = 0
        per_iter_limit = max(max_additions_per_iter, 0)

        if verbose:
            print(f"Starting SOS basis extension with {len(current_basis)} monomials")
            print(f"Initial basis: {[str(m) for m in current_basis]}")

        for iteration in range(max_iterations):
            iterations_run = iteration + 1
            if verbose:
                print(f"\n--- Iteration {iteration + 1} ---")

            primal_status, Q_matrix = self.solve_primal(current_basis)
            if verbose:
                print(f"Primal status: {primal_status}")

            if primal_status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
                if verbose:
                    print("✓ SOS decomposition found!")
                return current_basis, True, self._build_info(
                    iterations_run, current_basis, Q_matrix, iteration_info)

            if primal_status != cp.INFEASIBLE:
                if verbose:
                    print(f"Solver error: {primal_status}")
                break

            if verbose:
                print("Primal infeasible, solving dual...")
            yvals = self.solve_dual(current_basis)

            candidates = self.generate_candidate_monomials(current_basis)
            if verbose:
                print(f"Generated {len(candidates)} candidate monomials")

            # The moment matrix depends only on the basis, so decompose it once per iteration
            # (1e-10 matches schur_score's default pinv_tol).
            basis_tuples = [self.monomial_to_tuple(m) for m in current_basis]
            U_plus, lam_plus = self._moment_range(basis_tuples, yvals, 1e-10)
            candidate_scores = []
            for candidate in candidates:
                score, rho = self._schur_from_range(
                    basis_tuples, yvals, self.monomial_to_tuple(candidate), U_plus, lam_plus)
                candidate_scores.append((candidate, score, rho))

            # Most negative first; qualifying candidates then form a prefix.
            candidate_scores.sort(key=lambda entry: entry[1])
            selected = [entry for entry in candidate_scores if entry[1] < score_threshold][:per_iter_limit]

            for candidate, score, rho in selected:
                current_basis.append(candidate)
                if verbose:
                    print(f"  Added {candidate} (score: {score:.6g}, rho: {rho:.3e})")

            if not selected:
                if verbose:
                    print("No candidates with negative scores found. Stopping.")
                break

            iteration_info.append({
                'iteration': iteration + 1,
                'basis_size': len(current_basis),
                'candidates_evaluated': len(candidates),
                'monomials_added': len(selected),
                'best_scores': [score for _, score, _ in candidate_scores[:5]]
            })
        else:
            # Budget exhausted without a break: the monomials added on the last
            # pass have not been tested yet, so check the extended basis once.
            primal_status, Q_matrix = self.solve_primal(current_basis)
            if verbose:
                print(f"\nFinal primal status: {primal_status}")
            if primal_status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
                if verbose:
                    print("✓ SOS decomposition found!")
                return current_basis, True, self._build_info(
                    iterations_run, current_basis, Q_matrix, iteration_info)

        if verbose:
            print(f"\nBasis extension completed after {iterations_run} iterations")
            print(f"Final basis size: {len(current_basis)}")

        return current_basis, False, self._build_info(
            iterations_run, current_basis, None, iteration_info)
