"""
Solver for functional-group tasks.

The solver finds the functional groups, each defined by a SMARTS pattern, that
occur in a given molecule. For each group it returns the number of matched atoms,
the matched atom indices, and the number of match instances.
"""

#---------------------------------------------------------------------------------------
# Config
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Union
@dataclass
class FunctionalGroupSolverConfig:
    """Configuration for ``FunctionalGroupSolver``."""

    # Location of the SMARTS definitions file. The default is
    # ``smarts_functional_groups.txt`` in the same directory as this module.
    # The default is a ``pathlib.Path``; a plain ``str`` path may be passed
    # instead, hence the ``Union`` annotation (the previous ``str`` annotation
    # contradicted the ``Path`` default).
    smarts_patterns_path: Union[str, Path] = (
        Path(__file__).parent / "smarts_functional_groups.txt")

#---------------------------------------------------------------------------------------
# Imports
from typing import List, Tuple, Dict
from rdkit import Chem
#---------------------------------------------------------------------------------------
# Class definitions

class FunctionalGroupSolver:
    """Find SMARTS-defined functional groups in molecules given as SMILES.

    The SMARTS library is read from ``config.smarts_patterns_path`` and
    compiled once in ``__init__``. ``get_counts_and_indices`` then matches
    every compiled functional group against one molecule.
    """

    def __init__(self,
                 config: Optional["FunctionalGroupSolverConfig"] = None):
        """Load and pre-compile the functional-group SMARTS patterns.

        Args:
            config: Solver configuration. When omitted (or ``None``), a fresh
                ``FunctionalGroupSolverConfig()`` is built for this instance.
                This avoids one default object, created at definition time,
                being shared by every solver.

        Raises:
            OSError: If the SMARTS file cannot be opened.
        """
        if config is None:
            config = FunctionalGroupSolverConfig()
        self.config = config

        # Raw SMARTS strings keyed by group name, plus the group names in
        # file order.
        self.smarts_patterns, self.names = self._load_smarts_patterns(
            config.smarts_patterns_path)

        # Compile once here so per-molecule calls only do matching.
        self.compiled_patterns = {}
        self._compile_patterns()

    def get_counts_and_indices(self, smiles: str) -> Dict[str, Any]:
        """Match every functional group against one molecule.

        Args:
            smiles: SMILES string of the molecule.

        Returns:
            A flat dict with three entries for each functional group ``<name>``:

            * ``functional_group_<name>_count``: number of distinct atoms
              covered by any match (atoms, not match instances).
            * ``functional_group_<name>_index``: sorted list of those atom
              indices.
            * ``functional_group_<name>_nbrInstances``: total number of
              matches summed over the group's SMARTS patterns.

            Every input produces the same set of keys. A group is reported
            with a zero count and an empty index list when the SMILES cannot
            be parsed by RDKit or when matching that group raises.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Unparseable SMILES: keep the same flat schema as the normal
            # path, with every group empty.
            empty = {}
            for fg_name in self.compiled_patterns:
                empty.update(self._format_result(fg_name, [], 0))
            return empty

        results = {}
        for fg_name, patterns in self.compiled_patterns.items():
            atom_indices = set()
            total_instances = 0
            try:
                # A group may carry several SMARTS; pool all of their matches.
                for pattern in patterns:
                    matches = mol.GetSubstructMatches(pattern)
                    total_instances += len(matches)
                    for match in matches:
                        atom_indices.update(match)
            except Exception as e:
                # Best effort: report the failure and continue with the other
                # groups, but still emit this group's keys so the schema stays
                # stable.
                print(f"Error processing {fg_name} in molecule {smiles}: {e}")
                results.update(self._format_result(fg_name, [], 0))
                continue
            results.update(
                self._format_result(fg_name, sorted(atom_indices),
                                    total_instances))
        return results

    @staticmethod
    def _format_result(fg_name: str, indices: List[int],
                       instances: int) -> Dict[str, Any]:
        """Build the three flat result entries for one functional group."""
        prefix = f"functional_group_{fg_name}"
        return {
            f"{prefix}_count": len(indices),
            f"{prefix}_index": indices,
            f"{prefix}_nbrInstances": instances,
        }

    #-----------------------------------------------------------------------------------
    # General utility methods
    def _load_smarts_patterns(
            self, smarts_file: Union[str, Path]
    ) -> Tuple[Dict[str, str], List[str]]:
        """Read functional-group SMARTS definitions from a text file.

        Blank lines and lines starting with ``#`` are ignored. Each remaining
        line is split on its first two colons into ``name``, a middle field
        (not used here) and the SMARTS. NOTE(review): this assumes the SMARTS
        is the last field on the line; confirm against the data file. Only
        the first two colons act as separators, so SMARTS that themselves
        contain ``:`` (aromatic bonds such as ``c:c``, atom maps such as
        ``[C:1]``) are kept intact. Lines with fewer than three fields are
        skipped.

        Args:
            smarts_file: Path of the SMARTS definitions file.

        Returns:
            ``(smarts_patterns, names)``: a dict mapping each name to its
            SMARTS string, and the names in the order they were read. If a
            name repeats, the dict keeps the last SMARTS.

        Raises:
            OSError: If the file cannot be opened.
        """
        smarts_patterns: Dict[str, str] = {}
        names: List[str] = []

        with open(smarts_file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith('#'):
                    continue
                parts = line.split(':', 2)
                if len(parts) < 3:
                    continue
                name = parts[0].strip()
                smarts_patterns[name] = parts[2].strip()
                names.append(name)

        return smarts_patterns, names

    def get_functional_groups(self) -> Tuple[Dict[str, str], List[str]]:
        """Return the functional-group definitions.

        Returns:
            ``(smarts_patterns, names)``: the name -> SMARTS dict and the list
            of names, as loaded at construction time. These are the
            instance's own objects, not copies.
        """
        return self.smarts_patterns, self.names

    def _compile_patterns(self):
        """Compile every SMARTS string in ``self.smarts_patterns``.

        Results are stored in ``self.compiled_patterns`` as ``name -> [Mol]``.
        A list is used so that a group could hold several patterns. A SMARTS
        that RDKit cannot parse prints a warning and is stored as an empty
        list, so that group never matches.
        """
        for name, smarts_str in self.smarts_patterns.items():
            if smarts_str is None:
                continue
            pattern_mol = Chem.MolFromSmarts(smarts_str)
            if pattern_mol is None:
                print(f"Warning: Could not compile SMARTS for {name}: {smarts_str}")
                self.compiled_patterns[name] = []
            else:
                self.compiled_patterns[name] = [pattern_mol]
    




