import os
import shutil
import pyarrow.parquet as pq
import pandas as pd
import pyarrow as pa
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm
from dask.distributed import Client
import logging
from time import time
import dask.dataframe as dd

def merge_two_files(file1, file2, output_file, round_number):
    """Merge two Parquet files of n-gram counts into ``output_file``.

    Both inputs are loaded completely, their rows are stacked, and the
    ``count`` values of rows sharing the same ``ngram`` are summed.

    Args:
        file1: Path of the first input Parquet file.
        file2: Path of the second input Parquet file.
        output_file: Path the merged Parquet file is written to.
        round_number: Merge round index. For any round after round 0 the two
            inputs are removed once the output has been written; in round 0
            they are left on disk.

    Returns:
        None. Any exception raised along the way is caught and reported with
        ``print`` instead of being propagated to the caller.
    """
    try:
        sources = (file1, file2)

        # Load both Parquet files first, then turn each one into a DataFrame.
        tables = [pq.read_table(path) for path in sources]
        frames = [table.to_pandas() for table in tables]

        # Make every 'ngram' value a tuple so it is hashable for the groupby.
        for frame in frames:
            frame['ngram'] = frame['ngram'].apply(tuple)

        # Stack the rows and add up the counts of identical n-grams.
        stacked = pd.concat(frames)
        grouped = stacked.groupby('ngram', as_index=False)
        totals = grouped.agg({'count': 'sum'})

        # Write the summed counts back out as Parquet.
        pq.write_table(pa.Table.from_pandas(totals), output_file)

        if round_number > 0:
            # Inputs of later rounds are dropped to save space.
            for path in sources:
                os.remove(path)

    except Exception as e:
        print(f"Failed to merge {file1} and {file2}: {e}")

def parallel_merge(files, directory, round_number):
    """Merge ``files`` in consecutive pairs using a pool of worker processes.

    Files at positions ``2k`` and ``2k + 1`` are merged by ``merge_two_files``
    into ``merged_round_{round_number}_pair_{k}.parquet`` inside ``directory``.
    With an odd number of files the last one has no partner and is carried
    over unchanged under the next pair name.

    Args:
        files: Paths of the Parquet files to merge.
        directory: Directory that receives the output files.
        round_number: Index of the merge round. Round 0 inputs are kept on
            disk; inputs of later rounds are consumed.

    Returns:
        The output paths in pair order, including the carried-over file when
        the count is odd. A path is listed even if its merge failed, because
        ``merge_two_files`` reports errors by printing instead of raising.
    """
    new_files = []

    with ProcessPoolExecutor(max_workers=3) as executor:
        futures = []

        for pair_index, start in enumerate(range(0, len(files), 2)):
            output_file = os.path.join(
                directory, f"merged_round_{round_number}_pair_{pair_index}.parquet"
            )

            if start + 1 < len(files):
                futures.append(
                    executor.submit(
                        merge_two_files,
                        files[start],
                        files[start + 1],
                        output_file,
                        round_number,
                    )
                )
            elif round_number > 0:
                # Intermediate file without a partner: renaming is enough.
                shutil.move(files[start], output_file)
            else:
                # Round 0 works on the original inputs. merge_two_files keeps
                # those on disk, so the unpaired one is copied rather than
                # moved; otherwise it would vanish here and be deleted for
                # good by the next round.
                try:
                    shutil.copy2(files[start], output_file)
                except shutil.SameFileError:
                    # The leftover already carries the target name.
                    pass

            new_files.append(output_file)

        # Wait for all futures to complete
        for future in tqdm(futures, desc=f"Merging files - Round {round_number}"):
            future.result()

    return new_files

def merge_all_files(files, output_dir):
    """Sum the n-gram counts of all ``files`` in a single Dask pass.

    The Parquet files are read into one Dask DataFrame, every ``ngram`` value
    is turned into a tuple, the ``count`` column is summed per ``ngram``, and
    the computed result is written to ``final_ngram_counts_dask.parquet``
    under ``output_dir``. The elapsed time of each stage is printed.

    Args:
        files: Paths of the Parquet files to combine.
        output_dir: Directory that receives the final Parquet file.
    """
    started = time()
    print('reading files')
    ddf = dd.read_parquet(files, engine='pyarrow')

    print('converting to str', time() - started)
    started = time()
    # Tuples are used as the grouping key for the 'ngram' column.
    ddf['ngram'] = ddf['ngram'].apply(
        lambda tokens: tuple(tokens), meta=('ngram', 'object')
    )

    print('computing groupby', time() - started)
    started = time()
    counts_by_ngram = ddf.groupby('ngram')['count']
    # compute() materialises the summed counts as an in-memory result.
    result = counts_by_ngram.sum().compute()

    print('finished groupby', time() - started)
    destination = f"{output_dir}/final_ngram_counts_dask.parquet"
    result.to_frame().to_parquet(destination, engine='pyarrow')
def main():
    """Merge the n-gram count Parquet files of every configured arity.

    For each arity the result directory is created if it is missing, its
    ``.parquet`` files (excluding any earlier ``final_ngram_counts`` output)
    are collected, and they are combined with ``merge_all_files``. A directory
    without input files is reported and skipped rather than handing an empty
    file list to the merge.
    """
    # Swap in another dataset here, e.g. 'source=redpajama%2Farxiv'.
    source = 'source=starcoder'

    for arity in [1]:
        directory = f'ngrams_results_final/{source}/arity={arity}'
        os.makedirs(directory, exist_ok=True)

        files = [
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith('.parquet') and 'final_ngram_counts' not in name
        ]

        if not files:
            # Always the case for a directory that was just created above.
            print(f"No parquet files to merge in {directory}, skipping.")
            continue

        # NOTE: a round-based pairwise merge through parallel_merge was the
        # earlier approach; the single Dask pass below is used instead.
        print(f"Starting merges, {len(files)} files remaining...")
        merge_all_files(files, directory)

# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
