from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import torch
import re

from transformers.tokenization_utils_base import PreTrainedTokenizerBase

@dataclass
class TwoStageDataCollator:
    """
    Data collator for two-stage training.

    In stage 1, coefficient tokens (``C`` followed by digits, e.g. ``C123``) are
    replaced with ``[C]`` before tokenization; in stage 2 the text is used as is.

    Each feature is a mapping that may contain an ``'input'`` string, a
    ``'target'`` string, and any number of extra fields. Whether ``'input'`` /
    ``'target'`` are tokenized, and which extra fields are collated, is decided
    by the first feature of the batch.

    Attributes:
        tokenizer: Tokenizer applied to the input and target texts.
        padding: Forwarded to the tokenizer's ``padding`` argument.
        max_length: Forwarded to the tokenizer (texts are tokenized with
            ``truncation=True``).
        pad_to_multiple_of: Forwarded to the tokenizer.
        return_tensors: Forwarded to the tokenizer. Label post-processing uses
            torch tensor methods (``.clone()``), so this presumably must stay
            ``"pt"``.
        stage: Current training stage, 1 or 2.
        mask_target_only: In stage 1, mask coefficients only in the target text
            and leave the input untouched.
    """
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"
    stage: int = 1  # Current stage (1 or 2)
    mask_target_only: bool = False  # In stage 1, mask only the target text

    def __post_init__(self):
        self.coefficient_pattern = re.compile(r'C\d+')
        self.mask_token = "[C]"
        # Validate the constructor argument the same way set_stage does, so an
        # invalid stage cannot silently behave like stage 2.
        self.set_stage(self.stage)

    def set_stage(self, stage: int):
        """
        Switch the collator to ``stage``.

        Raises:
            ValueError: If ``stage`` is not 1 or 2. (An explicit raise instead
                of ``assert``, which is stripped when Python runs with -O.)
        """
        if stage not in (1, 2):
            raise ValueError(f"Stage must be 1 or 2, got {stage!r}")
        self.stage = stage

    def mask_coefficients(self, text: str) -> str:
        """Replace every coefficient token (C123 etc.) in ``text`` with [C]."""
        return self.coefficient_pattern.sub(self.mask_token, text)

    def _mask_feature(self, feature):
        """Return a shallow copy of ``feature`` with stage-1 coefficient masking applied."""
        masked = dict(feature)
        if not self.mask_target_only and 'input' in masked:
            masked['input'] = self.mask_coefficients(masked['input'])
        if 'target' in masked:
            masked['target'] = self.mask_coefficients(masked['target'])
        return masked

    def _tokenize(self, texts):
        """Tokenize ``texts`` with this collator's padding/truncation settings."""
        return self.tokenizer(
            texts,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
            truncation=True,
        )

    def __call__(self, features: List[Dict[str, Union[str, List[int]]]]) -> Dict[str, torch.Tensor]:
        """
        Collate a list of features into a batch.

        Returns a dict that may contain:
            ``input_ids`` / ``attention_mask``: tokenized ``'input'`` texts.
            ``labels``: tokenized ``'target'`` ids with padded positions set
                to -100.
            ``decoder_attention_mask``: attention mask of the targets.
            Any extra field of the first feature, built with ``torch.tensor``.

        Raises:
            ValueError: If ``features`` is empty.
        """
        if not features:
            raise ValueError("Cannot collate an empty list of features")

        # Stage 1: hide the concrete coefficient values behind [C].
        if self.stage == 1:
            features = [self._mask_feature(feature) for feature in features]

        batch = {}
        first = features[0]

        if 'input' in first:
            inputs_encodings = self._tokenize([f['input'] for f in features])
            batch['input_ids'] = inputs_encodings['input_ids']
            batch['attention_mask'] = inputs_encodings['attention_mask']

        if 'target' in first:
            targets_encodings = self._tokenize([f['target'] for f in features])
            target_mask = targets_encodings['attention_mask']
            # Labels are the target ids themselves (no shifting happens here).
            labels = targets_encodings['input_ids'].clone()
            # Ignore padded positions in the loss. Use the attention mask rather
            # than comparing ids with pad_token_id, so real tokens that share
            # the pad id (e.g. pad == eos) are not dropped from the loss.
            labels[target_mask == 0] = -100
            batch['labels'] = labels
            batch['decoder_attention_mask'] = target_mask

        # Pass any remaining fields through as tensors.
        for key in first:
            if key not in ('input', 'target'):
                batch[key] = torch.tensor([f[key] for f in features])

        return batch