import os
import shutil
import pyarrow.parquet as pq
import pandas as pd
import pyarrow as pa
from collections import Counter
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm
from time import time

def merge_two_files(file1, file2, output_file, round_number):
    """Sum the n-gram counts of two Parquet files into a single Parquet file.

    Both inputs must provide an 'ngram' column (each entry an iterable, which
    is converted to a tuple so it can be used as a dictionary key) and a
    'count' column. The output has the same two columns with one row per
    distinct n-gram; rows keep the order in which n-grams are first seen.

    Args:
        file1: Path of the first input Parquet file.
        file2: Path of the second input Parquet file.
        output_file: Path the merged Parquet file is written to.
        round_number: Merge round. For rounds > 0 the inputs are intermediate
            results and are deleted after a successful merge; round-0 inputs
            are left untouched.

    Returns:
        True if ``output_file`` was written, False if the merge failed.
        Failures are reported on stdout and never raised, so the function
        stays safe to run as a worker task.
    """
    # The temporary name does not end in '.parquet', so a half-written file is
    # never mistaken for a merge input by code that lists '*.parquet' files.
    tmp_file = f"{output_file}.tmp"

    try:
        totals = Counter()
        # Process one file at a time so only a single input DataFrame is held
        # in memory alongside the running totals.
        for path in (file1, file2):
            frame = pq.read_table(path).to_pandas()
            ngrams = frame['ngram'].apply(tuple)
            # Accumulate row by row: building a dict straight from the rows
            # would keep only the last count of an n-gram that appears in
            # several rows of the same file.
            for ngram, count in zip(ngrams, frame['count']):
                totals[ngram] += count

        result_df = pd.DataFrame(list(totals.items()), columns=['ngram', 'count'])
        result_table = pa.Table.from_pandas(result_df)

        # Write under a temporary name and rename once complete, so
        # output_file only ever appears as a fully written file.
        pq.write_table(result_table, tmp_file)
        os.replace(tmp_file, output_file)
    except Exception as e:
        print(f"Failed to merge {file1} and {file2}: {e}")
        try:
            os.remove(tmp_file)
        except OSError:
            pass  # Nothing was written, or it cannot be cleaned up.
        return False

    if round_number > 0:
        # Intermediate inputs are no longer needed; delete them to save space.
        # A failed delete does not invalidate the merged output.
        try:
            os.remove(file1)
            os.remove(file2)
        except OSError as e:
            print(f"Merged {file1} and {file2} but could not delete inputs: {e}")

    return True

def parallel_merge(files, directory, round_number, max_workers=3):
    """Merge ``files`` pairwise in worker processes and return the new paths.

    Files are paired in list order: (0, 1), (2, 3), ... Each pair is merged by
    ``merge_two_files`` into
    ``<directory>/merged_round_<round_number>_pair_<k>.parquet``. If the file
    count is odd, the last file is carried into the next round under the next
    pair index without being merged.

    Args:
        files: Paths of the Parquet files to merge in this round.
        directory: Directory the round's output files are written to.
        round_number: Index of this round, used in output names. Round-0
            inputs are preserved; inputs of later rounds are consumed.
        max_workers: Number of worker processes used for merging.

    Returns:
        Output paths for the next round: one per merged pair, in pair order,
        followed by the carried-over file if there was one.

    Raises:
        RuntimeError: If any pair failed to merge. Continuing would only fail
            later with a missing-file error.
    """
    def output_path(pair_index):
        return os.path.join(directory, f"merged_round_{round_number}_pair_{pair_index}.parquet")

    # One (first input, second input, output) job per complete pair.
    jobs = [
        (files[i], files[i + 1], output_path(i // 2))
        for i in range(0, len(files) - 1, 2)
    ]
    new_files = [output_file for _, _, output_file in jobs]

    if len(files) % 2:
        # Odd file count: carry the last file over under this round's naming.
        carried = output_path(len(files) // 2)
        if round_number > 0:
            shutil.move(files[-1], carried)
        else:
            # Round-0 inputs are kept by the pairwise merge as well, so copy
            # instead of moving to leave the original in place.
            shutil.copy2(files[-1], carried)
        new_files.append(carried)

    if jobs:
        failed = []
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            submitted = [
                (executor.submit(merge_two_files, first, second, output_file, round_number),
                 first, second, output_file)
                for first, second, output_file in jobs
            ]

            # Wait for every merge so all failures are reported together.
            for future, first, second, output_file in tqdm(
                    submitted, desc=f"Merging files - Round {round_number}"):
                succeeded = future.result()
                # The worker reports errors on stdout rather than raising, so
                # also check that the output actually exists.
                if succeeded is False or not os.path.exists(output_file):
                    failed.append((first, second))

        if failed:
            raise RuntimeError(
                f"Round {round_number}: {len(failed)} merge(s) failed: {failed}"
            )

    return new_files

def main():
    """Merge all per-shard n-gram count files of a dataset into one file.

    For each configured arity, every '.parquet' file in
    ``ngrams_results_final/<source>/arity=<arity>`` (except earlier final
    outputs) is merged pairwise, round after round, until a single file
    remains. That file becomes ``final_ngram_counts_check.parquet`` in the
    same directory.
    """
    # Round to start from. Leave at 0 for a fresh run over the original
    # inputs; set to N > 0 to continue from the files that round N-1 produced.
    start_round = 0

    for arity in [1]:  # e.g. [3] for trigrams
        source = 'source=starcoder'  # e.g. 'source=redpajama%2Farxiv'
        directory = f'ngrams_results_final/{source}/arity={arity}'
        os.makedirs(directory, exist_ok=True)

        # Sorted so that file pairing does not depend on the arbitrary order
        # in which os.listdir returns directory entries.
        files = sorted(
            os.path.join(directory, f)
            for f in os.listdir(directory)
            if f.endswith('.parquet') and 'final_ngram_counts' not in f
        )
        round_number = start_round
        if round_number > 0:
            files = [file for file in files if f'round_{round_number-1}' in file]

        a = time()
        while len(files) > 1:
            print(f"Starting round {round_number} of merges, {len(files)} files remaining...")
            files = parallel_merge(files, directory, round_number)
            round_number += 1
        print('Time processed', time()-a)

        if files:
            final_output_file = f'{directory}/final_ngram_counts_check.parquet'
            if round_number == 0:
                # No merge round ran, so the remaining file is an original
                # input: copy it rather than renaming the input away.
                shutil.copy2(files[0], final_output_file)
            else:
                # The remaining file is an intermediate result; rename it.
                shutil.move(files[0], final_output_file)
            print(f"Done. Final results written to: {final_output_file}")

# Run the merge pipeline only when this file is executed as a script, so that
# importing the module (for example to reuse merge_two_files) has no side effects.
if __name__ == "__main__":
    main()
