import json
import os
import re
from collections import Counter
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np


@dataclass
class SATFeatures:
    """Feature record describing a single SAT problem.

    Attributes:
        num_variables: Number of variables in the problem.
        num_clauses: Number of clauses in the problem.
        var_clause_ratio: Ratio of variables to clauses.
        clause_length_distribution: Maps a clause length to the number of
            clauses having that length.
        avg_clause_length: Mean clause length.
        max_clause_length: Length of the longest clause.
        min_clause_length: Length of the shortest clause.
        variable_frequency: Maps a variable ID to its occurrence count.
        variable_activity: Maps a variable ID to its activity level.
        problem_type: Identifier of the problem family.
        constraint_density: Constraint density of the problem.
        positive_literal_ratio: Share of literals that are positive.
        negative_literal_ratio: Share of literals that are negative.
        conflict_rate: Optional dynamic feature; ``None`` until assigned.
        learning_rate: Optional dynamic feature; ``None`` until assigned.
        restart_frequency: Optional dynamic feature; ``None`` until assigned.
    """

    # --- Size of the problem ---
    num_variables: int
    num_clauses: int
    var_clause_ratio: float

    # --- Shape of the clauses ---
    clause_length_distribution: Dict[int, int]  # clause length -> how many clauses
    avg_clause_length: float
    max_clause_length: int
    min_clause_length: int

    # --- Per-variable statistics ---
    variable_frequency: Dict[int, int]  # variable ID -> number of occurrences
    variable_activity: Dict[int, float]  # variable ID -> activity level

    # --- Complexity indicators ---
    problem_type: str  # label of the problem family
    constraint_density: float  # constraint density

    # --- Literal polarity statistics ---
    positive_literal_ratio: float  # fraction of positive literals
    negative_literal_ratio: float  # fraction of negative literals

    # --- Dynamic features (meant to be updated while solving) ---
    # They default to None, so a record can be built without them.
    conflict_rate: Optional[float] = None
    learning_rate: Optional[float] = None
    restart_frequency: Optional[float] = None


class SATFeatureExtractor:
    """Extract static features from SAT problems stored as DIMACS CNF files.

    Besides single-file and whole-directory extraction, the class can persist
    a ``{filename: SATFeatures}`` mapping to JSON, load it back, and compute
    aggregate statistics over such a mapping.
    """

    # Names of the SATFeatures fields that must be present in serialized data,
    # in the order they are written to JSON.
    _STATIC_FIELDS = (
        'num_variables',
        'num_clauses',
        'var_clause_ratio',
        'clause_length_distribution',
        'avg_clause_length',
        'max_clause_length',
        'min_clause_length',
        'variable_frequency',
        'variable_activity',
        'problem_type',
        'constraint_density',
        'positive_literal_ratio',
        'negative_literal_ratio',
    )
    # Optional solver-time fields; missing keys load as None.
    _DYNAMIC_FIELDS = ('conflict_rate', 'learning_rate', 'restart_frequency')
    # Dict-valued fields keyed by int. JSON object keys are always strings,
    # so these need their keys converted back to int on load.
    _INT_KEYED_FIELDS = (
        'clause_length_distribution',
        'variable_frequency',
        'variable_activity',
    )

    def __init__(self):
        # Problem-type label -> filename regex. Patterns are tried in
        # insertion order and the first match wins.
        self.problem_type_patterns = {
            'CNP': r'CNP-\d+-\d+',
            'SCPC': r'SCPC-\d+-\d+',
            'PRP': r'PRP-\d+-\d+',
            'SAT': r'.*_sat\.cnf',
            'UNSAT': r'.*_unsat\.cnf',
            'QUEEN': r'queen\d+-\d+',
            'CARRIER': r'Carry_Bits_Fast_\d+',
            'PRIME': r'prime_\d+',
            'SUMMLE': r'summle_X\d+',
            'KAKURO': r'Kakuro-.*',
            'TRAFFIC': r'traffic_.*',
            'ORTHOLATIN': r'ortholatin-\d+',
            'FRB': r'frb\d+-\d+-\d+'
        }

    @staticmethod
    def _parse_clauses(cnf_file_path: str) -> List[List[int]]:
        """Read all non-empty clauses from a DIMACS CNF file.

        Clauses are delimited by the literal ``0`` rather than by line ends,
        so several clauses on one line and a clause wrapped over several
        lines are both handled. Comment (``c``) and header (``p``) lines are
        skipped, and a line starting with ``%`` ends the clause section.
        A trailing clause without its terminating ``0`` is still kept.

        Raises:
            ValueError: If a clause token is not an integer.
        """
        clauses: List[List[int]] = []
        pending: List[int] = []

        # Only comments can legitimately hold non-ASCII bytes; replacing
        # undecodable bytes keeps such files readable.
        with open(cnf_file_path, 'r', encoding='utf-8', errors='replace') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith(('c', 'p')):
                    continue
                if line.startswith('%'):
                    # SATLIB-style end marker; whatever follows is not a clause.
                    break
                for token in line.split():
                    literal = int(token)
                    if literal != 0:
                        pending.append(literal)
                    elif pending:
                        clauses.append(pending)
                        pending = []

        if pending:
            clauses.append(pending)
        return clauses

    def extract_features_from_cnf(self, cnf_file_path: str) -> SATFeatures:
        """Extract features from a single DIMACS CNF file.

        All counts are derived from the clauses actually present in the file;
        the ``p cnf`` header is not consulted. Ratios fall back to ``0.0``
        when their denominator would be zero (e.g. a file without clauses).

        Args:
            cnf_file_path: Path of the CNF file to analyse.

        Returns:
            A ``SATFeatures`` record with the dynamic fields left as ``None``.

        Raises:
            FileNotFoundError: If ``cnf_file_path`` does not exist.
            ValueError: If a clause contains a non-integer token.
        """
        if not os.path.exists(cnf_file_path):
            raise FileNotFoundError(f"CNF file not found: {cnf_file_path}")

        clauses = self._parse_clauses(cnf_file_path)
        clause_lengths = [len(clause) for clause in clauses]

        num_clauses = len(clauses)
        total_literals = sum(clause_lengths)

        # Occurrences per variable, ignoring polarity.
        variable_freq = dict(
            Counter(abs(lit) for clause in clauses for lit in clause)
        )
        num_variables = len(variable_freq)

        positive_count = sum(1 for clause in clauses for lit in clause if lit > 0)
        negative_count = total_literals - positive_count

        var_clause_ratio = num_variables / num_clauses if num_clauses > 0 else 0.0

        clause_length_dist = dict(Counter(clause_lengths))
        avg_clause_length = total_literals / num_clauses if num_clauses > 0 else 0.0
        max_clause_length = max(clause_lengths) if clause_lengths else 0
        min_clause_length = min(clause_lengths) if clause_lengths else 0

        # Activity is the frequency scaled so the most frequent variable is 1.0.
        max_freq = max(variable_freq.values()) if variable_freq else 1
        variable_activity = {
            var_id: freq / max_freq for var_id, freq in variable_freq.items()
        }

        problem_type = self._identify_problem_type(os.path.basename(cnf_file_path))

        # Fraction of the variables-by-clauses incidence matrix that is filled.
        matrix_size = num_variables * num_clauses
        constraint_density = total_literals / matrix_size if matrix_size > 0 else 0.0

        if total_literals > 0:
            positive_literal_ratio = positive_count / total_literals
            negative_literal_ratio = negative_count / total_literals
        else:
            positive_literal_ratio = 0.0
            negative_literal_ratio = 0.0

        return SATFeatures(
            num_variables=num_variables,
            num_clauses=num_clauses,
            var_clause_ratio=var_clause_ratio,
            clause_length_distribution=clause_length_dist,
            avg_clause_length=avg_clause_length,
            max_clause_length=max_clause_length,
            min_clause_length=min_clause_length,
            variable_frequency=variable_freq,
            variable_activity=variable_activity,
            problem_type=problem_type,
            constraint_density=constraint_density,
            positive_literal_ratio=positive_literal_ratio,
            negative_literal_ratio=negative_literal_ratio
        )

    def _identify_problem_type(self, filename: str) -> str:
        """Return the label of the first pattern matching ``filename``.

        Matching is case-insensitive and uses ``re.search``, so a pattern may
        match anywhere in the name. ``"UNKNOWN"`` is returned when nothing
        matches.
        """
        for problem_type, pattern in self.problem_type_patterns.items():
            if re.search(pattern, filename, re.IGNORECASE):
                return problem_type
        return "UNKNOWN"

    def extract_features_from_directory(self, data_dir: str) -> Dict[str, SATFeatures]:
        """Extract features from every ``*.cnf`` file directly inside a directory.

        Files are processed in sorted name order so the result and any error
        output are reproducible. A file that cannot be processed is reported
        on stdout and skipped; the remaining files are still processed.

        Args:
            data_dir: Directory to scan (not recursive).

        Returns:
            Mapping of filename to its extracted ``SATFeatures``.
        """
        features_dict = {}

        for filename in sorted(os.listdir(data_dir)):
            if not filename.endswith('.cnf'):
                continue
            file_path = os.path.join(data_dir, filename)
            try:
                features_dict[filename] = self.extract_features_from_cnf(file_path)
            except Exception as e:
                # Deliberate best effort: one bad file must not abort the batch.
                print(f"Error extracting features from {filename}: {e}")

        return features_dict

    def save_features_to_json(self, features_dict: Dict[str, SATFeatures], output_path: str):
        """Write a ``{filename: SATFeatures}`` mapping to a JSON file.

        Integer dictionary keys are stored as strings, as JSON requires;
        ``load_features_from_json`` converts them back.

        Args:
            features_dict: Features to persist, keyed by filename.
            output_path: Destination file; overwritten if it exists.
        """
        field_names = self._STATIC_FIELDS + self._DYNAMIC_FIELDS
        features_data = {
            filename: {name: getattr(features, name) for name in field_names}
            for filename, features in features_dict.items()
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(features_data, f, indent=2)

    def load_features_from_json(self, json_path: str) -> Dict[str, SATFeatures]:
        """Load a ``{filename: SATFeatures}`` mapping from a JSON file.

        The keys of ``clause_length_distribution``, ``variable_frequency`` and
        ``variable_activity`` are converted back to ``int``, since JSON stores
        them as strings. Missing dynamic fields load as ``None``.

        Args:
            json_path: File previously produced by ``save_features_to_json``.

        Returns:
            Mapping of filename to the reconstructed ``SATFeatures``.

        Raises:
            KeyError: If an entry lacks one of the required static fields.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            features_data = json.load(f)

        features_dict = {}
        for filename, data in features_data.items():
            kwargs = {name: data[name] for name in self._STATIC_FIELDS}
            for name in self._INT_KEYED_FIELDS:
                kwargs[name] = {int(key): value for key, value in kwargs[name].items()}
            for name in self._DYNAMIC_FIELDS:
                kwargs[name] = data.get(name)
            features_dict[filename] = SATFeatures(**kwargs)

        return features_dict

    @staticmethod
    def _describe(values: List[float]) -> Dict[str, float]:
        """Return mean, population std, min and max of ``values`` as floats."""
        # Converting to built-in float keeps the summary JSON-serializable even
        # when NumPy would otherwise hand back an integer scalar.
        return {
            'mean': float(np.mean(values)),
            'std': float(np.std(values)),
            'min': float(np.min(values)),
            'max': float(np.max(values))
        }

    def get_feature_summary(self, features_dict: Dict[str, SATFeatures]) -> Dict:
        """Summarize a collection of extracted features.

        Args:
            features_dict: Features keyed by filename.

        Returns:
            An empty dict for empty input. Otherwise a dict with
            ``total_problems``, ``problem_type_distribution`` (type -> count)
            and mean/std/min/max statistics for the variable-clause ratio,
            the average clause length and the constraint density.
        """
        if not features_dict:
            return {}

        all_features = list(features_dict.values())

        return {
            'total_problems': len(all_features),
            'problem_type_distribution': dict(
                Counter(f.problem_type for f in all_features)
            ),
            'var_clause_ratio_stats': self._describe(
                [f.var_clause_ratio for f in all_features]
            ),
            'avg_clause_length_stats': self._describe(
                [f.avg_clause_length for f in all_features]
            ),
            'constraint_density_stats': self._describe(
                [f.constraint_density for f in all_features]
            )
        }