"""
Compact SAT Feature Extractor
Extracts 21 independent, non-redundant features for SAT problems
No label leakage, no duplicate features
"""

import os
import numpy as np
from typing import List
from dataclasses import dataclass


@dataclass
class CompactSATFeatures:
    """
    Fixed-size statistical summary (21 fields) of one CNF formula.

    Field order matters: the generated constructor accepts the values
    positionally in exactly the order declared below.
    """

    # --- Basic features (3) ---
    # Count of distinct variables that occur in at least one clause
    num_variables: int
    # Count of non-empty clauses
    num_clauses: int
    # num_variables divided by num_clauses
    var_clause_ratio: float

    # --- Clause structure features (8) ---
    # Mean number of literals per clause
    avg_clause_length: float
    # Standard deviation of the clause lengths
    std_clause_length: float
    # Length of the shortest clause
    min_clause_length: int
    # Length of the longest clause
    max_clause_length: int
    # Entropy of the clause-length distribution
    clause_length_entropy: float
    # Share of clauses holding exactly one literal
    unit_clause_ratio: float
    # Share of clauses holding exactly two literals
    binary_clause_ratio: float
    # Share of clauses holding more than five literals
    long_clause_ratio: float

    # --- Variable activity features (5) ---
    # Mean number of occurrences per variable
    avg_var_frequency: float
    # Standard deviation of the per-variable occurrence counts
    std_var_frequency: float
    # Occurrence count of the most frequent variable
    max_var_frequency: int
    # Entropy of the variable-frequency distribution
    var_frequency_entropy: float
    # Share of variables that occur exactly once
    singleton_var_ratio: float

    # --- Polarity features (4) ---
    # Share of all literal occurrences that are positive
    positive_literal_ratio: float
    # Share of variables whose positive/negative occurrences are balanced
    balanced_var_ratio: float
    # Share of variables that occur in one polarity only
    pure_literal_ratio: float
    # Distance of positive_literal_ratio from an even split
    polarity_bias: float

    # --- Complexity features (1) ---
    # Literal occurrences divided by (variables * clauses)
    constraint_density: float


class CompactSATFeatureExtractor:
    """Compute the 21 ``CompactSATFeatures`` statistics of a DIMACS CNF file.

    The ``p cnf`` header is ignored; every statistic is derived from the
    clauses actually present in the file.
    """

    # A variable counts as "balanced" when its minority polarity makes up
    # more than this share of its occurrences.
    BALANCE_THRESHOLD = 0.3

    # Clauses with more literals than this count as "long".
    LONG_CLAUSE_THRESHOLD = 5

    # Single source of truth for the feature order, shared by
    # extract_features_to_vector() and get_feature_names() so the two
    # cannot drift apart.
    _FEATURE_NAMES = (
        # Basic features (3)
        'num_variables', 'num_clauses', 'var_clause_ratio',
        # Clause structure features (8)
        'avg_clause_length', 'std_clause_length', 'min_clause_length', 'max_clause_length',
        'clause_length_entropy', 'unit_clause_ratio', 'binary_clause_ratio', 'long_clause_ratio',
        # Variable activity features (5)
        'avg_var_frequency', 'std_var_frequency', 'max_var_frequency', 'var_frequency_entropy',
        'singleton_var_ratio',
        # Polarity features (4)
        'positive_literal_ratio', 'balanced_var_ratio', 'pure_literal_ratio', 'polarity_bias',
        # Complexity features (1)
        'constraint_density',
    )

    def __init__(self):
        pass

    def extract_features_from_cnf(self, cnf_file_path: str) -> CompactSATFeatures:
        """Parse a DIMACS CNF file and return its feature record.

        Parsing rules:
          * blank lines, comment lines (``c ...``) and the header (``p ...``)
            are skipped;
          * a line starting with ``%`` ends the clause section (SATLIB-style
            files close with ``%`` followed by a lone ``0``);
          * a ``0`` token terminates the current clause, and any further
            tokens on the same line start a new clause;
          * the end of a line also terminates the current clause, so a
            clause split across several lines is NOT joined.

        Args:
            cnf_file_path: Path of the CNF file to read.

        Returns:
            A ``CompactSATFeatures`` instance describing the formula.

        Raises:
            FileNotFoundError: If ``cnf_file_path`` does not exist.
            ValueError: If a clause token is not an integer.
        """
        if not os.path.exists(cnf_file_path):
            raise FileNotFoundError(f"CNF file not found: {cnf_file_path}")

        variables = set()
        clauses = []
        clause_lengths = []
        var_positive_count = {}  # variable id -> positive occurrences
        var_negative_count = {}  # variable id -> negative occurrences
        var_total_count = {}     # variable id -> all occurrences
        literal_counts = {'positive': 0, 'negative': 0}

        def close_clause(clause):
            # Empty clauses (e.g. a lone "0") are not recorded.
            if clause:
                clauses.append(clause)
                clause_lengths.append(len(clause))

        with open(cnf_file_path, 'r', encoding='utf-8', errors='ignore') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith(('c', 'p')):
                    continue
                if line.startswith('%'):
                    # End-of-formula marker; nothing after it is clause data.
                    break

                clause = []
                for token in line.split():
                    literal = int(token)
                    if literal == 0:
                        # Terminator: finish this clause but keep scanning,
                        # the line may hold further clauses.
                        close_clause(clause)
                        clause = []
                        continue

                    clause.append(literal)
                    var_id = abs(literal)
                    variables.add(var_id)

                    var_total_count[var_id] = var_total_count.get(var_id, 0) + 1
                    if literal > 0:
                        literal_counts['positive'] += 1
                        var_positive_count[var_id] = var_positive_count.get(var_id, 0) + 1
                    else:
                        literal_counts['negative'] += 1
                        var_negative_count[var_id] = var_negative_count.get(var_id, 0) + 1

                # A clause without a trailing 0 still ends with the line.
                close_clause(clause)

        return self._calculate_features(
            variables, clauses, clause_lengths, var_positive_count,
            var_negative_count, var_total_count, literal_counts
        )

    @staticmethod
    def _distribution_entropy(values):
        """Return the base-2 entropy of the empirical distribution of ``values``."""
        counts = {}
        for value in values:
            counts[value] = counts.get(value, 0) + 1
        total = len(values)
        entropy = -sum((count / total) * np.log2(count / total)
                       for count in counts.values())
        # Adding 0.0 turns the -0.0 produced by a one-value distribution
        # into a plain 0.0.
        return entropy + 0.0

    def _calculate_features(self, variables, clauses, clause_lengths,
                          var_positive_count, var_negative_count, 
                          var_total_count, literal_counts) -> CompactSATFeatures:
        """Turn the raw parse counters into the 21 features.

        Args:
            variables: Set of variable ids seen in the clauses.
            clauses: List of parsed clauses (only its length is used).
            clause_lengths: Number of literals of each clause.
            var_positive_count: Variable id -> positive occurrence count.
            var_negative_count: Variable id -> negative occurrence count.
            var_total_count: Variable id -> total occurrence count.
            literal_counts: Dict with the keys ``'positive'`` and
                ``'negative'`` holding the overall literal counts.

        Returns:
            A populated ``CompactSATFeatures``. Every feature is 0 when the
            data it is computed from is empty.
        """
        # Basic features (3)
        num_variables = len(variables)
        num_clauses = len(clauses)
        var_clause_ratio = num_variables / num_clauses if num_clauses > 0 else 0.0

        # Clause structure features (8)
        if clause_lengths:
            n_lengths = len(clause_lengths)
            avg_clause_length = np.mean(clause_lengths)
            std_clause_length = np.std(clause_lengths)
            min_clause_length = min(clause_lengths)
            max_clause_length = max(clause_lengths)
            clause_length_entropy = self._distribution_entropy(clause_lengths)

            unit_clause_ratio = sum(1 for n in clause_lengths if n == 1) / n_lengths
            binary_clause_ratio = sum(1 for n in clause_lengths if n == 2) / n_lengths
            long_clause_ratio = sum(
                1 for n in clause_lengths if n > self.LONG_CLAUSE_THRESHOLD
            ) / n_lengths
        else:
            avg_clause_length = std_clause_length = 0.0
            min_clause_length = max_clause_length = 0
            clause_length_entropy = 0.0
            unit_clause_ratio = binary_clause_ratio = long_clause_ratio = 0.0

        # Variable activity features (5)
        if var_total_count:
            frequencies = list(var_total_count.values())
            avg_var_frequency = np.mean(frequencies)
            std_var_frequency = np.std(frequencies)
            max_var_frequency = max(frequencies)
            var_frequency_entropy = self._distribution_entropy(frequencies)
            singleton_var_ratio = sum(1 for f in frequencies if f == 1) / len(frequencies)
        else:
            avg_var_frequency = std_var_frequency = 0.0
            max_var_frequency = 0
            var_frequency_entropy = singleton_var_ratio = 0.0

        # Polarity features (4)
        total_literals = literal_counts['positive'] + literal_counts['negative']
        if total_literals > 0:
            positive_literal_ratio = literal_counts['positive'] / total_literals
            # 0 for an even positive/negative split, 1 when all literals
            # share one polarity.
            polarity_bias = abs(positive_literal_ratio - 0.5) * 2
        else:
            # With no literals there is no bias to measure; report 0 like
            # every other feature does on empty input (the formula above
            # would otherwise yield 1.0 here).
            positive_literal_ratio = 0.0
            polarity_bias = 0.0

        balanced_vars = 0
        pure_literal_vars = 0
        if variables:
            for var_id in variables:
                pos_count = var_positive_count.get(var_id, 0)
                neg_count = var_negative_count.get(var_id, 0)
                occurrences = pos_count + neg_count
                if occurrences == 0:
                    continue

                if min(pos_count, neg_count) / occurrences > self.BALANCE_THRESHOLD:
                    balanced_vars += 1
                if pos_count == 0 or neg_count == 0:  # only one polarity seen
                    pure_literal_vars += 1

            balanced_var_ratio = balanced_vars / len(variables)
            pure_literal_ratio = pure_literal_vars / len(variables)
        else:
            balanced_var_ratio = pure_literal_ratio = 0.0

        # Complexity features (1)
        literal_count = sum(clause_lengths)
        density_denominator = num_variables * num_clauses
        constraint_density = literal_count / density_denominator if density_denominator > 0 else 0.0

        return CompactSATFeatures(
            # Basic features (3)
            num_variables=num_variables,
            num_clauses=num_clauses,
            var_clause_ratio=var_clause_ratio,

            # Clause structure features (8)
            avg_clause_length=avg_clause_length,
            std_clause_length=std_clause_length,
            min_clause_length=min_clause_length,
            max_clause_length=max_clause_length,
            clause_length_entropy=clause_length_entropy,
            unit_clause_ratio=unit_clause_ratio,
            binary_clause_ratio=binary_clause_ratio,
            long_clause_ratio=long_clause_ratio,

            # Variable activity features (5)
            avg_var_frequency=avg_var_frequency,
            std_var_frequency=std_var_frequency,
            max_var_frequency=max_var_frequency,
            var_frequency_entropy=var_frequency_entropy,
            singleton_var_ratio=singleton_var_ratio,

            # Polarity features (4)
            positive_literal_ratio=positive_literal_ratio,
            balanced_var_ratio=balanced_var_ratio,
            pure_literal_ratio=pure_literal_ratio,
            polarity_bias=polarity_bias,

            # Complexity features (1)
            constraint_density=constraint_density
        )

    def extract_features_to_vector(self, cnf_file_path: str) -> np.ndarray:
        """Extract the features of a CNF file as a 21-element float32 vector.

        The element order is the one reported by ``get_feature_names()``.

        Raises:
            FileNotFoundError: If ``cnf_file_path`` does not exist.
            ValueError: If a clause token is not an integer.
        """
        features = self.extract_features_from_cnf(cnf_file_path)
        return np.array(
            [getattr(features, name) for name in self._FEATURE_NAMES],
            dtype=np.float32,
        )

    def get_feature_names(self) -> List[str]:
        """Return the feature names in vector order as a new list."""
        return list(self._FEATURE_NAMES)