import math

import torch
import torch.nn as nn

class RingCountClassifier(nn.Module):
    """Guidance module that penalises partial SMILES sequences containing too many rings.

    Rings are counted from ring-closure tokens (a string of digits such as ``1``,
    or ``%`` followed by digits such as ``%10``): the first occurrence of a label
    opens a ring and the next occurrence of the same label closes it. Only closed
    rings are counted.
    """

    def __init__(self, tgt_vocab=None):
        """
        Create the guidance module.

        Args:
            tgt_vocab: Target vocabulary. It must support ``len()`` and expose an
                ``itos`` sequence mapping token index -> token string. It is only
                read when ``ring_count_guidance``/``forward`` is called.
        """
        super().__init__()
        self.tgt_vocab = tgt_vocab

    @staticmethod
    def _is_ring_closure(token):
        """Return True if *token* is a ring-closure label: all digits, or ``%`` + digits."""
        return token.isdigit() or (token.startswith('%') and token[1:].isdigit())

    @classmethod
    def _count_closed_rings(cls, tokens):
        """Return how many ring-closure labels in *tokens* were opened and later closed."""
        open_labels = set()
        closed = 0
        for token in tokens:
            if not cls._is_ring_closure(token):
                continue
            if token in open_labels:
                # Second occurrence closes the ring; the label may be reused afterwards.
                open_labels.remove(token)
                closed += 1
            else:
                open_labels.add(token)
        return closed

    def ring_count_guidance(self, partial_seq, max_rings=0, base_probability=0.9):
        """
        Count closed rings in partial SMILES sequences and return a penalty per sequence.

        Args:
            partial_seq: 2-D tensor of shape (batch_size*beam_size, tgt_len) containing
                token indices. The final position of every row is not counted.
            max_rings: Maximum number of closed rings allowed before a penalty applies.
            base_probability: Per-excess-ring probability factor, expected in (0, 1].
                A sequence with ``k`` rings over the limit scores ``k * log(base_probability)``;
                a value <= 0 yields ``-inf`` for every sequence over the limit.

        Returns:
            Tensor of shape (batch_size*beam_size, 1) on ``partial_seq.device`` holding
            0 for sequences within the limit and a log-probability penalty otherwise.

        Raises:
            ValueError: if ``partial_seq`` is not 2-D.
        """
        if partial_seq.dim() != 2:
            raise ValueError(
                f"partial_seq must be 2-D (batch*beam, tgt_len), got shape {tuple(partial_seq.shape)}"
            )
        seq_len = partial_seq.size(1)
        vocab_size = len(self.tgt_vocab)
        idx_to_token = {i: self.tgt_vocab.itos[i] for i in range(vocab_size)}

        # log(p ** k) computed as k * log(p): p ** k can underflow to 0 and give a
        # spurious -inf, whereas the product stays finite for any p > 0.
        log_base = math.log(base_probability) if base_probability > 0 else float('-inf')

        # One host transfer for the whole batch instead of a .item() sync per token.
        # NOTE(review): the last position is excluded from counting — presumably the
        # step currently being scored; confirm against the decoding loop.
        rows = partial_seq[:, :seq_len - 1].tolist()

        penalties = []
        for row in rows:
            excess = self._count_closed_rings(idx_to_token[idx] for idx in row) - max_rings
            penalties.append(excess * log_base if excess > 0 else 0.0)

        return torch.tensor(
            penalties, dtype=torch.get_default_dtype(), device=partial_seq.device
        ).unsqueeze(1)

    def forward_(self, partial_seq):
        """Disabled-guidance variant: return a zero penalty for every sequence.

        Returns:
            Zeros of shape (partial_seq.shape[0], 1), allocated on ``partial_seq.device``
            so they can be combined with scores living on the same device.
        """
        return torch.zeros(partial_seq.shape[0], 1, device=partial_seq.device)

    def forward(self, partial_seq):
        """
        Score partial sequences by how many closed rings they contain.

        Args:
            partial_seq: Tensor of shape (batch_size*beam_size, tgt_len) with token indices.

        Returns:
            Tensor of shape (batch_size*beam_size, 1): 0 for sequences within the ring
            limit, a negative log-probability penalty otherwise. Uses the default
            ``max_rings`` and ``base_probability`` of ``ring_count_guidance``.
        """
        return self.ring_count_guidance(partial_seq)
    