import string
from tqdm import tqdm
from datasets import load_dataset
import pickle


def load_dataset_and_filter(data_path):
    """
    Download the multilingual NLI dataset, keep sentences that use a clean
    per-language character set, pickle the result and return it.

    Only splits whose name starts with 'it' or 'es' are read. For each sample,
    the translated 'premise'/'hypothesis' are filed under that split's language,
    and 'premise_original'/'hypothesis_original' are filed under 'en'.

    Args:
        data_path: Path of the pickle file to write (overwritten if present).

    Returns:
        Dict mapping 'en', 'es' and 'it' to lists of accepted sentences. The
        same dict is pickled to ``data_path`` with the highest pickle protocol.
    """
    # Explicit ASCII sets instead of str.isdigit() / a range test on
    # char.lower(): those also accepted characters such as '²', '٣',
    # the Kelvin sign U+212A (lowercases to 'k') or 'İ' (lowercases to 'i̇').
    base_chars = set(string.ascii_letters + string.digits + string.punctuation)

    def _with_accents(extra):
        # Accept the extra letters in both cases, mirroring the case-insensitive
        # membership test the original code performed via char.lower().
        return frozenset(base_chars | set(extra) | set(extra.upper()))

    allowed_chars = {
        'en': frozenset(base_chars),               # English: basic Latin only
        'es': _with_accents('áéíóúüñ¿¡'),          # Spanish-specific characters
        'it': _with_accents('àèéìíîòóùú'),         # Italian-specific characters
    }

    def filter_text(text, language='en'):
        """
        Check whether ``text`` is usable for ``language``.

        Args:
            text: Input string (may be None or empty).
            language: Language code ('en', 'es', 'it').

        Returns:
            True if the text is non-blank and every character is either in the
            language's allowed set or whitespace (any Unicode ``str.isspace``).

        Raises:
            KeyError: if ``language`` is not one of 'en', 'es', 'it'.
        """
        if not text or not text.strip():
            return False
        allowed = allowed_chars[language]
        return all(ch in allowed or ch.isspace() for ch in text)

    dataset = load_dataset(
        "MoritzLaurer/multilingual-NLI-26lang-2mil7"
    )

    samples = {
        'en': [],
        'es': [],
        'it': []
    }

    for ds_name, ds in tqdm(dataset.items(), desc='Iterate over datasets'):
        if ds_name.startswith('it'):
            lang = 'it'
        elif ds_name.startswith('es'):
            lang = 'es'
        else:
            continue

        # (sample field, language bucket); order matches the original appends.
        # NOTE(review): English originals are appended from every it/es split,
        # so the same English sentence may be stored several times (presumably
        # once per translation split / per hypothesis) — dedupe if that matters.
        fields = (
            ('premise', lang),
            ('hypothesis', lang),
            ('premise_original', 'en'),
            ('hypothesis_original', 'en'),
        )
        for sample in tqdm(ds, desc=f'Iterate over {ds_name}', leave=False):
            for key, key_lang in fields:
                text = sample[key]
                if filter_text(text, key_lang):
                    samples[key_lang].append(text)

    final_counts = {lang: len(texts) for lang, texts in samples.items()}
    print(f"Counts: {final_counts}")

    with open(data_path, 'wb') as handle:
        pickle.dump(samples, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return samples

class CharacterProcessor:
    """
    Character-level vocabulary built from a corpus of strings.

    Indices 0-3 are reserved for '<PAD>', '<UNK>', '<SOS>' and '<EOS>'. Every
    other character gets the next free index, in order of decreasing frequency.
    Ties keep the order in which the characters were first seen.

    NOTE: ``texts`` is iterated directly, so pass an iterable of strings (e.g.
    one of the lists returned by load_dataset_and_filter). Passing a dict would
    build the vocabulary from its keys.
    """

    def __init__(self, texts):
        """Count characters in ``texts``, build both index mappings and print them."""
        frequency = {}
        for ch in (c for sentence in texts for c in sentence):
            if ch in frequency:
                frequency[ch] += 1
            else:
                frequency[ch] = 1

        specials = ('<PAD>', '<UNK>', '<SOS>', '<EOS>')
        self.char_to_idx = {token: position for position, token in enumerate(specials)}

        # sorted() is stable even with reverse=True, so equally frequent
        # characters stay in first-seen (dict insertion) order.
        ranked = sorted(frequency, key=frequency.__getitem__, reverse=True)
        for ch in ranked:
            # The argument is evaluated before insertion, so this is the next free index.
            self.char_to_idx.setdefault(ch, len(self.char_to_idx))

        # reverse mapping
        self.idx_to_char = dict(zip(self.char_to_idx.values(), self.char_to_idx.keys()))
        self.vocab_size = len(self.char_to_idx)
        print("Character vocabulary size:", self.vocab_size)
        print(self.char_to_idx)

    def encode(self, text, add_special_tokens=False):
        """
        Map each character of ``text`` to its index.

        Unknown characters map to the '<UNK>' index. With ``add_special_tokens``
        the result is wrapped in the '<SOS>' ... '<EOS>' indices.
        """
        unknown = self.char_to_idx['<UNK>']
        body = [self.char_to_idx.get(ch, unknown) for ch in text]
        if not add_special_tokens:
            return body
        return [self.char_to_idx['<SOS>'], *body, self.char_to_idx['<EOS>']]

    def decode(self, indices, skip_special_tokens=True):
        """
        Turn a sequence of indices back into a string.

        Indices not in the vocabulary are dropped silently. When
        ``skip_special_tokens`` is true, the four reserved tokens are dropped
        as well.
        """
        hidden = {'<PAD>', '<UNK>', '<SOS>', '<EOS>'} if skip_special_tokens else set()
        looked_up = map(self.idx_to_char.get, indices)
        return ''.join(ch for ch in looked_up if ch is not None and ch not in hidden)