from multiprocessing import Manager, Lock, Pool, cpu_count
from tqdm import tqdm
import os
import gc
from collections import Counter
import pyarrow.parquet as pq
import pyarrow as pa
import pandas as pd
from threading import Thread
from nltk.util import ngrams
from nltk.lm import NgramCounter
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from os import listdir
from os.path import isfile, join
import time

def calculate_ngrams(tokens_list, n=2):
    """Return the n-grams of ``tokens_list`` as produced by ``nltk.util.ngrams``.

    Args:
        tokens_list: Sequence of tokens to slide the n-gram window over.
        n: Number of tokens per n-gram (defaults to 2, i.e. bigrams).

    Returns:
        The object returned by ``ngrams(tokens_list, n)``, passed through
        unchanged.
    """
    ngram_iter = ngrams(tokens_list, n)
    return ngram_iter

def process_tokens(arr, ngram_arity):
    """Compute the n-grams of a single token array.

    This is the unit of work submitted to the thread pool in
    ``process_batch``; it simply forwards to ``calculate_ngrams``.

    Args:
        arr: Token sequence for one entry.
        ngram_arity: Number of tokens per n-gram.

    Returns:
        Whatever ``calculate_ngrams`` returns for these arguments.
    """
    return calculate_ngrams(tokens_list=arr, n=ngram_arity)

def extract_ngrams_from_ngram_counter(ngram_counter, ngram_arity=3):
    """Flatten one order of an n-gram counter into a plain ``Counter``.

    The entry ``ngram_counter[ngram_arity + 1]`` is iterated key by key; for
    each key, the first value of the associated inner mapping is stored in
    the result under that key.

    Args:
        ngram_counter: Object indexable by order; presumably an
            ``nltk.lm.NgramCounter`` — confirm against the caller.
        ngram_arity: The entry at index ``ngram_arity + 1`` is the one read.

    Returns:
        Counter: key -> first count found in that key's inner mapping.

    NOTE(review): only the first inner value is kept per key; any further
    values in the inner mapping are ignored. Confirm this is intended.
    """
    order_table = ngram_counter[ngram_arity + 1]
    flat_counter = Counter()

    progress = tqdm(order_table, desc="Converting to counter", total=len(order_table))
    for key in progress:
        inner_counts = list(order_table[key].values())
        flat_counter[key] = inner_counts[0]
    return flat_counter


def process_tokens_batch(arrays, ngram_arity):
    """Count the n-grams of every token array in ``arrays``.

    Args:
        arrays: Iterable of token sequences; n-grams are computed per
            sequence, so they never span two sequences.
        ngram_arity: Number of tokens per n-gram.

    Returns:
        Counter: n-gram -> number of occurrences across all sequences.
    """
    return Counter(
        gram
        for token_array in arrays
        for gram in calculate_ngrams(token_array, ngram_arity)
    )

def process_batch_weird(dataset, column_name='tokens', ngram_arity=3, inner_workers=50):
    """Count n-grams of one dataset column, fanning chunks out to threads.

    The column is read from ``dataset``, flattened by one list level and
    converted to a NumPy array. The array is cut into at most
    ``inner_workers`` contiguous chunks, each chunk is counted by
    ``process_tokens_batch`` in a thread pool, and the partial counters are
    merged.

    Args:
        dataset: Object exposing ``read(columns=[...])`` whose result can be
            indexed by column name (e.g. a ``pq.ParquetDataset``).
        column_name: Name of the list-typed column holding the tokens.
        ngram_arity: Number of tokens per n-gram.
        inner_workers: Thread-pool size and maximum number of chunks.

    Returns:
        Counter: n-gram -> occurrences over the whole column. Empty when the
        column holds no entries.
    """
    column_data = dataset.read(columns=[column_name])[column_name]
    tokens_all = pa.compute.list_flatten(column_data).to_numpy()

    counter_all = Counter()
    total = len(tokens_all)

    # An empty column would give chunk_size == 0 and make range() raise
    # "arg 3 must not be zero"; there is nothing to count, so skip the pool.
    if total:
        # Ceiling division: spread the entries over at most inner_workers chunks.
        quotient, remainder = divmod(total, inner_workers)
        chunk_size = quotient + (remainder > 0)
        chunks = [tokens_all[start:start + chunk_size] for start in range(0, total, chunk_size)]

        with ThreadPoolExecutor(max_workers=inner_workers) as executor:
            futures = [executor.submit(process_tokens_batch, chunk, ngram_arity) for chunk in chunks]
            with tqdm(total=len(futures), desc="Processing tokens") as progress:
                for future in as_completed(futures):
                    counter_all.update(future.result())
                    progress.update(1)

    # Drop the large intermediates before returning to keep peak memory down.
    del tokens_all, column_data
    gc.collect()
    return counter_all

def process_batch(dataset, column_name='tokens', ngram_arity=3, inner_workers=10):
    """Count n-grams of one dataset column, one thread-pool task per entry.

    The column is read from ``dataset``, flattened by one list level and
    converted to a NumPy array. Every element of that array is submitted to
    ``process_tokens`` and the results are merged into a single counter as
    the tasks complete.

    Args:
        dataset: Object exposing ``read(columns=[...])`` whose result can be
            indexed by column name (e.g. a ``pq.ParquetDataset``).
        column_name: Name of the list-typed column holding the tokens.
        ngram_arity: Number of tokens per n-gram.
        inner_workers: Size of the thread pool.

    Returns:
        Counter: n-gram -> occurrences over the whole column.

    NOTE(review): if ``process_tokens`` returns a lazy iterator, the n-grams
    are actually materialised by ``Counter.update`` on the calling thread
    rather than in the workers — confirm whether the pool buys anything.
    """
    column_data = dataset.read(columns=[column_name])[column_name]
    tokens_all = pa.compute.list_flatten(column_data).to_numpy()

    counter_all = Counter()
    with ThreadPoolExecutor(max_workers=inner_workers) as executor:
        # One task per flattened entry.
        pending = [
            executor.submit(process_tokens, token_array, ngram_arity)
            for token_array in tokens_all
        ]

        with tqdm(total=len(pending), desc="Processing tokens") as progress_bar:
            # Merge each result as soon as its task finishes.
            for finished in as_completed(pending):
                counter_all.update(finished.result())
                progress_bar.update(1)

    # Release the large intermediates explicitly before returning.
    print('before gc')
    del tokens_all, column_data
    gc.collect()
    print('after gc')
    return counter_all

def process_batch_old(batch, ngram_arity=3):
    """Count n-grams over all tokens of a batch, treated as one sequence.

    Each cell of the ``tokens`` column is expected to be indexable, with its
    first element exposing ``.tolist()``. The token lists of all rows are
    concatenated in row order before n-grams are computed, so n-grams may
    span the boundary between consecutive rows.

    Args:
        batch: Object exposing ``to_pandas()`` that yields a DataFrame with
            a ``tokens`` column.
        ngram_arity: Number of tokens per n-gram.

    Returns:
        Counter: n-gram -> occurrences. Empty when the batch has no rows.
    """
    df = batch.to_pandas()

    # Build the flat token list in linear time. Summing a Series of lists
    # concatenates quadratically and yields the integer 0 for an empty
    # frame, which cannot be turned into n-grams.
    flat_tokens = []
    for cell in df['tokens']:
        flat_tokens.extend(cell[0].tolist())

    return Counter(calculate_ngrams(flat_tokens, ngram_arity))

def write_intermediate_results(counter, source, result_dir):
    """Persist ``counter`` as ``<result_dir>/<source>.parquet``.

    The counter is written as a two-column table (``ngram``, ``count``).
    If the target file already exists, its rows are kept and the new rows
    are added after them. Rows are concatenated, not summed, so the same
    n-gram can appear more than once and must be aggregated by the reader.

    Args:
        counter: Mapping of n-gram -> count to store.
        source: File stem of the output (without the ``.parquet`` suffix).
        result_dir: Existing directory that receives the file.
    """
    batch_df = pd.DataFrame(list(counter.items()), columns=['ngram', 'count'])
    table = pa.Table.from_pandas(batch_df)
    file_path = os.path.join(result_dir, f"{source}.parquet")

    if os.path.exists(file_path):
        if table.num_rows == 0:
            # Nothing to add; leave the existing file untouched.
            return
        # A Parquet file cannot be extended in place, and ParquetWriter does
        # not accept an ``append`` keyword, so "appending" means reading the
        # old rows back and rewriting the file with both sets of rows.
        existing = pq.read_table(file_path)
        if existing.num_rows:
            table = pa.concat_tables([existing, table])

    pq.write_table(table, file_path)

def process_file(args):
    """Pool worker: count n-grams for one dataset path and save the result.

    Args:
        args: 8-tuple ``(file_path, write_lock, threshold, counters, source,
            id_source, result_dir, arity)``. ``write_lock``, ``threshold``
            and ``counters`` are accepted but not used here.

    Returns:
        None. When the counter holds at least one count it is written to
        ``<result_dir>/<source>_<id_source>.parquet`` via
        ``write_intermediate_results``.
    """
    file_path, _write_lock, _threshold, _counters, source, id_source, result_dir, arity = args

    print('reading dataset', file_path, id_source, source)
    dataset = pq.ParquetDataset(file_path)

    # The whole dataset is handled as a single batch.
    local_counter = process_batch(dataset, ngram_arity=arity)
    print('processed batch')
    del dataset
    print('deleted batch')

    print('appending counter')
    print('saving counter', id_source)
    has_counts = sum(local_counter.values()) > 0
    if has_counts:
        output_stem = f"{source}_{id_source}"
        write_intermediate_results(local_counter, output_stem, result_dir)
    print('counter saved', id_source)

def update_progress_bar(progress_bar, merge_counter):
    """Mirror ``merge_counter.value`` onto ``progress_bar`` until it is ready.

    Polls every 0.1 s while ``merge_counter.ready`` is falsy, advancing the
    bar by the difference between the counter value and the bar position.
    Once ``ready`` becomes truthy the bar is synchronised one last time and
    closed.

    Args:
        progress_bar: Object with an ``n`` attribute and ``update(delta)`` /
            ``close()`` methods (e.g. a ``tqdm`` instance).
        merge_counter: Object with ``value`` and ``ready`` attributes.
    """
    while not merge_counter.ready:
        progress_bar.update(merge_counter.value - progress_bar.n)
        time.sleep(0.1)

    # Increments made between the last poll and ``ready`` being set would
    # otherwise never be shown; catch up before closing.
    progress_bar.update(merge_counter.value - progress_bar.n)
    progress_bar.close()

def file_already_processed(result_dir, source, id_source):
    """Tell whether the result file for ``source``/``id_source`` exists.

    Args:
        result_dir: Directory holding the result files.
        source: Source name used as the file-name prefix.
        id_source: Identifier used as the file-name suffix.

    Returns:
        bool: True when ``<result_dir>/<source>_<id_source>.parquet`` exists.
    """
    output_name = f"{source}_{id_source}.parquet"
    return os.path.exists(os.path.join(result_dir, output_name))


def main():
    """Count n-grams for the selected sources and store per-directory results.

    For every arity and selected source, ``_SOURCE_DIR`` is walked for
    directories that contain ``.parquet`` files. Directories whose result
    file already exists under ``ngrams_results_final/<source>/arity=<arity>``
    are skipped; the remaining ones are queued and passed to
    ``process_file`` through a process pool. As written, only the first
    queued directory (in sorted order) of each source is processed.

    Any exception raised while handling a source is caught and only its
    message is printed.
    """
    _SOURCE_DIR = 'tokenized_data/rest/'


    # Manager-backed objects are forwarded to the workers inside every task
    # tuple; the process_file defined in this file unpacks them without
    # using them.
    manager = Manager()
    counters = manager.list()
    write_lock = manager.Lock()

    for arity in [4]:  # alternatives left by the author: [1, 8, 16, 32, 64], [4, 5, 6, 7, 8]
        threshold = 500000  # forwarded in the task tuple; not read by process_file
        pool_size = 10  # number of worker processes

        # source name -> list of directories still to process; reset for
        # every arity but shared by all selected sources of that arity.
        file_group = {}

        # Source names of files that were not queued (printed for information).
        sources_set = set()
        for selected_source in ['source=gutenberg']:
            # Other source names that appear in the author's notes:
            # 'source=starcoder', 'source=falcon-refinedweb%2Fdata',
            # 'source=reddit', 'source=s2', 'source=redpajama%2Farxiv',
            # 'source=redpajama%2Fstackexchange', 'source=megawika',
            # 'source=wikipedia'.
            try:
                result_dir = f'ngrams_results_final/{selected_source}/arity={arity}'

                if os.path.exists(result_dir):
                    # Recover the ids of finished directories from the result
                    # file names: text before '.par', second '_'-separated field.
                    # NOTE(review): a file name without '_' makes the [1] index
                    # raise IndexError, which the except below reduces to a
                    # printed message and skips this source.
                    processed_ids = [f.split('.par')[0].split('_')[1] for f in listdir(result_dir) if
                                     isfile(join(result_dir, f))]
                    print('processed files', len(processed_ids), sorted(processed_ids))

                else:
                    # No results yet for this source/arity.
                    print(f"Directory '{result_dir}' does not exist.")
                    processed_ids = []


                for root, dirs, files in os.walk(_SOURCE_DIR):
                    for file in files:
                        if file.endswith('.parquet'):
                            # source: name of the directory that holds the file.
                            source = os.path.basename(os.path.dirname(os.path.join(root, file)))
                            # id_source: second-to-last '/'-separated component of root.
                            id_source = root.split('/')[-2]
                            # Queue the directory when it belongs to the selected
                            # source, has no result file yet and is not queued already.
                            if source == selected_source \
                                    and not file_already_processed(result_dir, source, id_source) \
                                    and (source not in file_group or root not in file_group[source]):
                                print(source, root)
                                if source not in file_group:
                                    file_group[source] = []
                                # Second guard, based on the ids parsed from result_dir.
                                if root.split('/')[-2] not in processed_ids:
                                    file_group[source].append(root)

                            else:
                                sources_set.add(source)
                print(sources_set)


                for source, files in file_group.items():
                    # Only the first directory in sorted order is kept.
                    # NOTE(review): looks like a debugging limit — confirm
                    # before a full run.
                    print('only 1 firs files')
                    files = sorted(files)[:1]
                    print('source', source)
                    print('files', files)
                    if source == selected_source:
                        print(f'Processing source: {source} with arity {arity}', len(files))
                        # The string literal below is an inert expression
                        # statement holding an older task-tuple layout.
                        """
                        file_args = [
                            (file, write_lock, threshold, counters,
                             merge_counter, source, file.split('/')[-2], result_dir)
                            for file in files]
                        """
                        os.makedirs(result_dir, exist_ok=True)

                        # One task tuple per directory, in the order expected
                        # by process_file; the id is again the second-to-last
                        # path component.
                        file_args = [
                            (file, write_lock, threshold, counters, source, file.split('/')[-2], result_dir, arity)
                            for file in files]

                        # list(...) drains the iterator, so the with-block
                        # does not exit before every task has finished.
                        with Pool(pool_size) as pool:
                            _ = list(tqdm(pool.imap_unordered(process_file, file_args), total=len(file_args), desc="Processing Files"))

            except Exception as err:
                # Best-effort: report the message and move on to the next source.
                print(str(err))
# Run the pipeline only when this file is executed as a script, so that
# importing the module does not start it.
if __name__ == "__main__":
    main()
