import re
import math
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.util import ngrams
from collections import Counter
from typing import Protocol, Set, Literal, List
from src.datasets.bias_dataset import BiasDataset


class CandidateKeywordSelector(Protocol):
    """Structural interface for objects that select candidate keywords."""

    def select_candidate_keywords(self, bias_dataset: BiasDataset) -> List[Counter[str]]:
        """Select candidate keywords from ``bias_dataset``.

        The return annotation matches ``SimpleCandidateKeywordSelector`` in
        this module, which returns one ``Counter`` per item of the dataset,
        mapping each retained keyword to its count.
        """
        ...

class SimpleCandidateKeywordSelector(CandidateKeywordSelector):
    """Select frequent word n-grams from the captions of each dataset item.

    Every caption is tokenized and normalized, turned into n-grams of size
    ``max_ngram_size`` down to 1, and an n-gram is kept for an item when it
    occurs in at least a ``min_freq`` fraction of that item's captions.

    Args:
        tokenization: ``'nltk'`` uses ``nltk.word_tokenize``; ``'space'``
            uses ``str.split``.
        do_lower_case: Lower-case every token before further processing.
        stopwords: Language name passed to ``nltk.corpus.stopwords.words``,
            or ``None`` to keep stop words.
        remove_punctuation: Strip every character that is not an ASCII
            letter, digit or whitespace from each token (this also removes
            non-ASCII letters) and drop tokens that become empty.
        max_ngram_size: Largest n-gram length generated.
        min_freq: Minimum fraction of an item's captions that must contain
            an n-gram for it to be kept.
    """

    def __init__(self,
                 tokenization: Literal['space', 'nltk'] = 'nltk',
                 do_lower_case: bool = True,
                 stopwords: str | None = 'english',
                 remove_punctuation: bool = True,
                 max_ngram_size: int = 2,
                 min_freq: float = 0.15
                 ) -> None:
        super().__init__()
        assert tokenization in ('space', 'nltk'), \
            f"tokenization must be 'space' or 'nltk', got {tokenization!r}"

        self.tokenization = tokenization
        self.do_lower_case = do_lower_case
        self.stopwords = stopwords
        self.remove_punctuation = remove_punctuation
        self.max_ngram_size = max_ngram_size
        self.min_freq = min_freq

        self.punkt_regex = re.compile(r'[^a-zA-Z0-9\s]')

        # Lazily filled by _stop_word_set(); nothing is loaded at construction
        # time, so a missing NLTK corpus still only fails on first use.
        self._cached_stop_words_lang: str | None = None
        self._cached_stop_words: Set[str] = set()

    def _stop_word_set(self) -> Set[str]:
        """Return the stop-word set for ``self.stopwords``, loading it once."""
        language = self.stopwords
        if language is None:
            return set()

        # getattr keeps this working on instances created without __init__;
        # comparing the language reloads the list if self.stopwords changed.
        if getattr(self, '_cached_stop_words_lang', None) != language:
            self._cached_stop_words = set(stopwords.words(language))
            self._cached_stop_words_lang = language
        return self._cached_stop_words

    def _process_text(self, text: str) -> List[str]:
        """Tokenize ``text`` and apply the configured normalization steps.

        Steps run in this order: tokenization, lower-casing, punctuation
        removal, stop-word removal. Stop-word matching is an exact,
        case-sensitive comparison against the loaded list.
        """
        if self.tokenization == 'nltk':
            words = word_tokenize(text)
        else:
            words = text.split()

        if self.do_lower_case:
            words = [w.lower() for w in words]

        if self.remove_punctuation:
            words = [self.punkt_regex.sub('', w) for w in words]
            # Tokens made only of punctuation are now empty strings.
            words = [w for w in words if w]

        if self.stopwords is not None:
            # Cached lookup: the list is not re-read for every caption.
            stop_words = self._stop_word_set()
            words = [w for w in words if w not in stop_words]

        return words

    def _generate_ngrams(self, words: List[str]) -> List[str]:
        """Return all n-grams of ``words`` as space-joined strings.

        Sizes run from ``max_ngram_size`` down to 1, so longer n-grams come
        first in the returned list.
        """
        all_ngrams: List[str] = []
        for n in range(self.max_ngram_size, 0, -1):
            all_ngrams.extend(' '.join(gram) for gram in ngrams(words, n))
        return all_ngrams

    def select_candidate_keywords(self, bias_dataset: BiasDataset) -> List[Counter[str]]:
        """Count, per dataset item, the n-grams frequent among its captions.

        For each item the captions are ``correctly_classified +
        incorrectly_classified``. Each n-gram is counted at most once per
        caption, and it is kept when ``count / number_of_captions`` is at
        least ``min_freq``.

        Args:
            bias_dataset: Dataset whose ``return_captions`` flag is set and
                whose items expose the two caption lists named above.

        Returns:
            One ``Counter`` per iterated item, in iteration order, mapping
            each retained n-gram to the number of captions containing it.

        Raises:
            AssertionError: If ``return_captions`` is falsy or a caption is
                not a ``str``.
        """
        assert bias_dataset.return_captions, \
            'bias_dataset must be configured to return captions'

        result: List[Counter[str]] = []

        for bias in bias_dataset:  # type: ignore
            captions = bias.correctly_classified + bias.incorrectly_classified
            num_captions = len(captions)

            caption_counts: Counter[str] = Counter()
            for caption in captions:
                assert isinstance(caption, str)

                words = self._process_text(caption)
                # set(): count each n-gram once per caption.
                caption_counts.update(set(self._generate_ngrams(words)))

            # Compare the ratio instead of ceil(num_captions * min_freq): the
            # float product can land just above an integer (100 * 0.07 ==
            # 7.000000000000001), which would push the threshold up by one and
            # drop keywords that meet the requested fraction exactly.
            # With no captions the counter is empty, so nothing is divided.
            result.append(Counter({
                keyword: count
                for keyword, count in caption_counts.items()
                if count / num_captions >= self.min_freq
            }))

        return result
