import pickle
from collections import defaultdict
import os
from watermark_stealing.server import Server
from tqdm.auto import tqdm
from datasets import load_dataset


class NGramCounter:
    """Count occurrences of token n-grams in text.

    Text is tokenized with ``tokenizer.encode(text, add_special_tokens=False)``
    and every contiguous window of ``n`` tokens is counted.

    When ``ordered`` is False (the default) each n-gram is stored with its
    tokens sorted, so permutations of the same tokens share a single count.
    When ``ordered`` is True the token order is kept as-is.
    """

    def __init__(self, tokenizer, ordered: bool = False):
        # Maps n-gram tuple -> number of times it was seen via add_text.
        self.ngram_counts = defaultdict(int)
        self.tokenizer = tokenizer
        # Total number of n-gram occurrences counted (with multiplicity).
        self.total_ngrams = 0
        self.ordered = ordered

    def _normalize(self, ngram) -> tuple:
        """Return the dictionary key used for *ngram* under the current mode."""
        # Always build a tuple so that lists (unhashable) are accepted in
        # ordered mode as well, matching what unordered mode already allowed.
        if self.ordered:
            return tuple(ngram)
        return tuple(sorted(ngram))

    def _generate_ngrams(self, text, n):
        """Yield the normalized n-gram key for every window of *n* tokens.

        Raises:
            ValueError: if *n* is smaller than 1 (raised on first iteration).
        """
        if n < 1:
            # A non-positive n would otherwise yield empty/garbage tuples.
            raise ValueError(f"n must be a positive integer, got {n}")
        tokens = self.tokenizer.encode(text, add_special_tokens=False)
        # Texts shorter than n tokens produce an empty range, i.e. no n-grams.
        for start in range(len(tokens) - n + 1):
            yield self._normalize(tokens[start : start + n])

    def add_text(self, text, n):
        """Tokenize *text* and add all of its n-grams of size *n* to the counts."""
        for ngram in self._generate_ngrams(text, n):
            self.ngram_counts[ngram] += 1
            self.total_ngrams += 1

    def get_ngram_count(self, ngram):
        """Return the stored count for *ngram*.

        *ngram* may be any iterable of tokens; it is normalized the same way
        as during counting. Unseen n-grams return 1 (not 0), and the lookup
        never inserts a new entry into ``ngram_counts``.
        """
        return self.ngram_counts.get(self._normalize(ngram), 1)

    def save_data(self, filename):
        """Pickle the n-gram count mapping to *filename*.

        Only ``ngram_counts`` is written; the tokenizer and the ``ordered``
        flag are not part of the file.
        """
        with open(filename, "wb") as f:
            pickle.dump(self.ngram_counts, f)

    @staticmethod
    def load_data(filename, ordered=False, tokenizer=None):
        """Build a counter from a file previously written by ``save_data``.

        ``ordered`` and ``tokenizer`` are not stored in the file and must be
        supplied by the caller; ``ordered`` has to match the mode the counts
        were created with for lookups to be meaningful.
        """
        # SECURITY: pickle.load can execute arbitrary code. Only load files
        # that this project produced itself, never untrusted input.
        with open(filename, "rb") as f:
            ngram_counts = pickle.load(f)
        obj = NGramCounter(tokenizer=tokenizer, ordered=ordered)
        obj.ngram_counts = ngram_counts
        # The total is not persisted; recompute it so it agrees with the
        # loaded counts instead of staying at 0.
        obj.total_ngrams = sum(ngram_counts.values())
        return obj


def generate_ngrams_counter(cfg, n_gram_value, ordered, data_path, n_tokens=1000000000):
    """Count n-grams over the streamed C4 English train split and pickle them.

    Args:
        cfg: Project config; ``cfg.meta`` and ``cfg.server`` are passed to
            ``Server`` to obtain the tokenizer.
        n_gram_value: Size ``n`` of the n-grams to count.
        ordered: Forwarded to ``NGramCounter``; whether token order is kept.
        data_path: Destination pickle file. Its parent directory is created
            if it does not exist.
        n_tokens: Budget that stops the run. Despite the name it is compared
            against the number of n-grams counted so far, and the check runs
            after each document, so the final total may exceed it slightly.
    """
    # NOTE(review): this mutates the caller's cfg and is not restored.
    # Presumably `skip` keeps Server from loading the full model since only
    # the tokenizer is needed here -- confirm against Server.
    cfg.server.model.skip = True
    server = Server(cfg.meta, cfg.server)
    tokenizer = server.model.tokenizer

    n_gram_counter = NGramCounter(tokenizer, ordered=ordered)

    en_dataset = load_dataset("allenai/c4", "en", split="train", streaming=True)

    # Number of n-grams already reported to the progress bar.
    count = 0

    # The context manager closes the bar even if tokenization or streaming
    # raises part-way through.
    with tqdm(total=n_tokens) as pbar:
        for example in en_dataset:
            n_gram_counter.add_text(example["text"], n_gram_value)
            # Advance the bar by the n-grams contributed by this document.
            pbar.update(n_gram_counter.total_ngrams - count)
            count = n_gram_counter.total_ngrams

            if count > n_tokens:
                break

    # os.makedirs("") raises, so only create a directory when data_path
    # actually has a directory component.
    out_dir = os.path.dirname(data_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    n_gram_counter.save_data(data_path)


def load_ngram_counter_from_cfg(cfg, n_gram_value, ordered=False):
    """Return the n-gram counter cached on disk for the configured model.

    The pickle path is derived from ``cfg.server.model.name`` (with ``/``
    replaced by ``_``), ``n_gram_value`` and the ``ordered`` flag. If no file
    exists at that path, it is first generated with
    ``generate_ngrams_counter``. The counter is loaded without a tokenizer
    argument.
    """
    # Slashes in the model name are replaced so it forms a single path segment.
    model_name = cfg.server.model.name.replace("/", "_")
    output_path = f"data/token_occurrence/{model_name}"

    suffix = ""
    if ordered:
        suffix = "_ordered"
    data_path = f"{output_path}/{n_gram_value}_ngram_counter{suffix}.pkl"

    # Fast path: the counts were already computed on a previous run.
    if os.path.exists(data_path):
        return NGramCounter.load_data(data_path, ordered=ordered)

    print(f"Data not found at {data_path}, creating new counter")
    generate_ngrams_counter(cfg, n_gram_value, ordered=ordered, data_path=data_path)
    return NGramCounter.load_data(data_path, ordered=ordered)
