"""Ngram LM."""
import collections
import json


class NgramLanguageModel:
    """Count-based n-gram language model with add-one (Laplace) smoothing.

    The model keeps one ``collections.Counter`` per n-gram order from 1 to
    ``n``. Text is tokenized on whitespace and every token that is not in the
    vocabulary is replaced by ``'[UNK]'``. At inference time the longest
    usable context is tried first; if that context has never been followed by
    any token in the training data, the model backs off to a shorter context,
    down to plain unigram frequencies.

    Attributes:
        vocab: The vocabulary as loaded from JSON. It is only used through
            ``in``, ``len()`` and iteration, so either a list of tokens or a
            mapping keyed by token works.
        n: Maximum n-gram order tracked by the model.
        ngram_counts: Dict mapping each order (1..n) to a ``Counter`` whose
            keys are token tuples of that length.
    """

    def __init__(self, vocab_path, n):
        """Initialize an untrained model.

        Args:
            vocab_path: Path to a UTF-8 JSON file holding the vocabulary.
            n: Maximum n-gram order; must be at least 1.

        Raises:
            ValueError: If ``n`` is smaller than 1. Such a model would have no
                count tables and could never produce a prediction.
        """
        if n < 1:
            raise ValueError(f"n must be a positive integer, got {n!r}")
        with open(vocab_path, 'r', encoding='utf-8') as f:
            self.vocab = json.load(f)
        self.n = n
        # One Counter per order; keys are tuples of `order` tokens.
        self.ngram_counts = {order: collections.Counter() for order in range(1, n + 1)}

    def _process_tokens(self, text):
        """Split ``text`` on whitespace and map out-of-vocab tokens to ``'[UNK]'``."""
        vocab = self.vocab
        return [token if token in vocab else '[UNK]' for token in text.split()]

    def _continuation_counts(self, order, context):
        """Count, per final token, the stored ``order``-grams that begin with ``context``.

        For ``order == 1`` the context is the empty tuple, which every unigram
        matches, so the result is simply the unigram counts keyed by token.
        """
        continuations = collections.Counter()
        for ngram, count in self.ngram_counts[order].items():
            if ngram[:-1] == context:
                continuations[ngram[-1]] += count
        return continuations

    def train(self, dataset):
        """Accumulate n-gram counts from ``dataset``.

        Args:
            dataset: Iterable of examples (e.g. a Hugging Face dataset) where
                each example supports ``example['text']`` and yields a string.

        Counts are added to the existing tables, so calling ``train`` several
        times keeps accumulating.
        """
        for example in dataset:
            tokens = self._process_tokens(example['text'])
            for order in range(1, self.n + 1):
                # Slide a window of `order` tokens across the example; the
                # range is empty when the example is shorter than `order`.
                self.ngram_counts[order].update(
                    tuple(tokens[start:start + order])
                    for start in range(len(tokens) - order + 1)
                )

    def infer(self, sentence, order=-1):
        """Predict the next-token distribution that follows ``sentence``.

        Args:
            sentence: Context text; tokenized the same way as training text.
            order: Highest n-gram order to use. ``-1`` (the default) means
                "as high as possible". Any requested order is capped by
                ``self.n`` and by the number of context tokens available.

        Returns:
            A dict mapping every vocabulary token to its Laplace-smoothed
            probability ``(count + 1) / (context_count + len(vocab))``,
            ordered from most to least probable (ties keep vocabulary order).

        Raises:
            ValueError: If ``order`` is neither ``-1`` nor a positive integer.
        """
        if order != -1 and order < 1:
            raise ValueError(
                f"order must be -1 (use the highest order) or >= 1, got {order!r}"
            )

        tokens = self._process_tokens(sentence)
        # An order-k model needs k-1 context tokens, hence the len(tokens) + 1 cap.
        max_order = min(self.n, len(tokens) + 1)
        order = max_order if order == -1 else min(max_order, order)

        # Back off until some n-gram starting with the context has been seen,
        # or until only the unigram distribution is left.
        while True:
            # Written with an explicit start index because tokens[-0:] would
            # return the whole list instead of an empty context at order 1.
            context = tuple(tokens[len(tokens) - (order - 1):])
            continuations = self._continuation_counts(order, context)
            context_count = sum(continuations.values())
            if context_count > 0 or order == 1:
                break
            order -= 1

        vocab_size = len(self.vocab)
        probabilities = {
            token: (continuations.get(token, 0) + 1) / (context_count + vocab_size)
            for token in self.vocab
        }

        # Most probable token first; sorted() is stable, so equal
        # probabilities keep their vocabulary order.
        return dict(sorted(probabilities.items(), key=lambda item: item[1], reverse=True))

    def save(self, file_path):
        """Write the vocabulary, ``n`` and all n-gram counts to ``file_path`` as JSON.

        JSON object keys must be strings, so each order is stored as ``str(order)``
        and each n-gram tuple as its tokens joined by a single space. Tokens
        produced by ``_process_tokens`` never contain whitespace, which keeps
        that encoding reversible.
        """
        serializable_counts = {
            str(order): {' '.join(ngram): count for ngram, count in counter.items()}
            for order, counter in self.ngram_counts.items()
        }
        model_data = {
            'vocab': self.vocab,
            'n': self.n,
            'ngram_counts': serializable_counts
        }
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(model_data, f)

    @classmethod
    def load(cls, file_path):
        """Rebuild a model from a JSON file previously written by ``save``.

        Args:
            file_path: Path to the saved model.

        Returns:
            A new instance of ``cls`` carrying the stored vocabulary, order and counts.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            model_data = json.load(f)
        # Bypass __init__: it would try to read a separate vocabulary file,
        # whereas the saved model already embeds the vocabulary.
        instance = cls.__new__(cls)
        instance.vocab = model_data['vocab']
        instance.n = model_data['n']
        # Undo the string encoding applied by save(): "2" -> 2, "a b" -> ('a', 'b').
        instance.ngram_counts = {
            int(order_str): collections.Counter(
                {tuple(ngram.split(' ')): count for ngram, count in counts.items()}
            )
            for order_str, counts in model_data['ngram_counts'].items()
        }
        return instance
