import os
from collections import Counter, defaultdict
from functools import partial, reduce
from multiprocessing import Pool, cpu_count

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from nltk.util import ngrams
from tqdm import tqdm

def calculate_ngrams(tokens_list, n=2):
    """Count every contiguous run of ``n`` tokens in ``tokens_list``.

    Args:
        tokens_list: Sequence of tokens to slide the window over.
        n: Window length (the n-gram size); defaults to bigrams.

    Returns:
        A ``Counter`` mapping each n-gram tuple to how often it occurs.
    """
    grams = ngrams(tokens_list, n)
    counts = Counter()
    counts.update(grams)
    return counts

def process_batches(batches, ngram_arity):
    """Count the n-grams contained in an iterable of record batches.

    N-grams are formed within each row's token sequence only, so no n-gram
    spans two rows. Concatenating rows first would invent n-grams across
    unrelated rows and make the counts depend on where batch boundaries fall.

    Args:
        batches: Iterable of objects with a ``to_pandas()`` method (e.g.
            ``pyarrow.RecordBatch``) whose frame has a ``tokens`` column.
        ngram_arity: Size ``n`` of the n-grams to count.

    Returns:
        A ``Counter`` mapping each n-gram tuple to its total count.
    """
    total_counter = Counter()
    for batch in batches:
        for cell in batch.to_pandas()['tokens']:
            # NOTE(review): each cell appears to be a nested sequence whose
            # first element holds the row's token ids (hence ``cell[0]``);
            # confirm against the tokenizer's output schema.
            if cell is None or len(cell) == 0:
                continue  # a row without tokens contributes no n-grams
            # .tolist() turns numpy scalars into plain Python values so the
            # n-gram keys are ordinary tuples of Python objects.
            total_counter.update(calculate_ngrams(cell[0].tolist(), ngram_arity))
    return total_counter

def _count_file_ngrams(path, ngram_arity, batch_size):
    """Count the n-grams of one Parquet file; runs inside a worker process.

    Only the file *path* crosses the process boundary. ``ParquetFile``
    objects hold open handles and ``iter_batches`` returns a generator, and
    ``multiprocessing`` cannot pickle a generator.
    """
    # The context manager closes the OS file handle as soon as the file has
    # been consumed, instead of leaving it to garbage collection.
    with pa.OSFile(path) as source:
        parquet_file = pq.ParquetFile(source)
        # Read only the column that is actually used.
        batches = parquet_file.iter_batches(batch_size=batch_size, columns=['tokens'])
        return process_batches(batches, ngram_arity)


def process_files(file_list, ngram_arity, batch_size=1_000_000):
    """Count n-grams over all Parquet files belonging to one source.

    Each file is counted in its own worker process. The per-file counters
    are merged in the parent process as they arrive.

    Args:
        file_list: Paths of the Parquet files for one source.
        ngram_arity: Size ``n`` of the n-grams to count.
        batch_size: Rows per record batch read from each file.

    Returns:
        A ``Counter`` mapping n-gram tuples to their total count across all
        files. It is empty when ``file_list`` is empty.
    """
    total = Counter()
    if not file_list:
        return total  # no work: don't spin up a pool for nothing
    worker = partial(_count_file_ngrams, ngram_arity=ngram_arity, batch_size=batch_size)
    with Pool(min(cpu_count(), len(file_list))) as pool:
        # Merging is linear in the number of entries wherever it happens.
        # Shipping counters back to workers for a pairwise tree reduction
        # only adds pickling round-trips and a new Pool per level. imap
        # streams results in input order (deterministic key order), so each
        # per-file counter can be folded in and freed right away.
        for file_counter in pool.imap(worker, file_list):
            total.update(file_counter)
    return total

def merge_two_counters(counter1, counter2):
    """Add the counts of ``counter2`` into ``counter1`` and return it.

    The merge happens in place: ``counter1`` itself is modified and is the
    object returned. ``counter2`` is left untouched.
    """
    accumulate = counter1.update
    accumulate(counter2)
    return counter1

def _discover_source_files(source_dir):
    """Group the ``.parquet`` files under *source_dir* by partition value.

    Files are expected to live in Hive-style ``key=value`` directories. The
    value (everything after the first ``=``) of the file's immediate parent
    directory names the source. File lists are sorted for deterministic runs.

    Raises:
        ValueError: if a ``.parquet`` file sits in a directory whose name is
            not of the form ``key=value``.
    """
    source_files = defaultdict(list)
    for root, _dirs, files in os.walk(source_dir):
        parquet_names = sorted(name for name in files if name.endswith('.parquet'))
        if not parquet_names:
            continue
        _key, sep, source = os.path.basename(root).partition('=')
        if not sep:
            raise ValueError(
                f"cannot derive a source name for Parquet files in {root!r}: "
                "expected a 'key=value' partition directory"
            )
        source_files[source].extend(os.path.join(root, name) for name in parquet_names)
    return source_files


def _write_counter(counter, path):
    """Write *counter* to *path* as a Parquet table with ``ngram``/``count`` columns."""
    final_df = pd.DataFrame(list(counter.items()), columns=['ngram', 'count'])
    pq.write_table(pa.Table.from_pandas(final_df), path)


def main(source_dir='tokenized_data/rest', result_dir='ngrams_results_final', ngram_arity=3):
    """Count n-grams per source under *source_dir* and write one Parquet file per source.

    Args:
        source_dir: Root directory containing ``key=value`` partition folders.
        result_dir: Output directory; created if missing. Each source is
            written to ``<result_dir>/<source>.parquet``.
        ngram_arity: Size ``n`` of the n-grams to count.

    Raises:
        FileNotFoundError: if *source_dir* does not exist. Without this check,
            ``os.walk`` would silently yield nothing.
    """
    if not os.path.isdir(source_dir):
        raise FileNotFoundError(f"source directory not found: {source_dir!r}")
    os.makedirs(result_dir, exist_ok=True)

    source_files = _discover_source_files(source_dir)
    for source, files in tqdm(sorted(source_files.items()), desc="Processing Sources"):
        counter = process_files(files, ngram_arity)
        # Write each source as soon as it is finished. Only one source's
        # counter is then in memory at a time, and completed sources survive
        # a failure on a later one.
        _write_counter(counter, os.path.join(result_dir, f"{source}.parquet"))

if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status is 0 on success.
    raise SystemExit(main())
