import os
import shutil
import pyarrow.parquet as pq
import pandas as pd
import pyarrow as pa
import vaex
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm
from time import time

def transform_and_load(file_path):
    """Read a Parquet file into a Vaex DataFrame with string-valued n-grams.

    The ``ngram`` column is normalised to a single comma-joined string per
    row so it can serve as a group-by key. Values that are already strings
    are kept as they are.

    Args:
        file_path: Path of the Parquet file to read. The file must contain
            an ``ngram`` column.

    Returns:
        A Vaex DataFrame built from the file's rows (the pandas index is
        not copied).
    """
    def _ngram_to_string(value):
        # Files written by an earlier merge round already hold the joined
        # string. Joining a string again would iterate its characters and
        # corrupt the key ("a,b" -> "a,,,b"), so pass strings through.
        if isinstance(value, str):
            return value
        return ','.join(map(str, value))

    # Load with PyArrow, then go through pandas for the row-wise conversion.
    table = pq.read_table(file_path)
    df = table.to_pandas()
    df['ngram'] = df['ngram'].apply(_ngram_to_string)
    # Hand the frame to Vaex directly, without an intermediate file on disk.
    vaex_df = vaex.from_pandas(df, copy_index=False)
    return vaex_df

def groupby_vaex(file1, file2):
    """Combine two Parquet files and total ``count`` for each ``ngram``.

    Args:
        file1: Path of the first Parquet file.
        file2: Path of the second Parquet file.

    Returns:
        The Vaex group-by result, keyed on ``ngram`` with a summed
        ``count`` column.
    """
    # Load both inputs (first file first) as Vaex DataFrames.
    left, right = [transform_and_load(path) for path in (file1, file2)]

    # Stack the rows of both frames, then aggregate per n-gram.
    stacked = left.concat(right)
    totals = {'count': vaex.agg.sum('count')}
    return stacked.groupby(by='ngram', agg=totals)

def merge_two_files(file1, file2, output_file, round_number):
    """Merge two n-gram count files into ``output_file``.

    Args:
        file1: Path of the first Parquet input.
        file2: Path of the second Parquet input.
        output_file: Path the merged Parquet file is written to.
        round_number: Index of the current merge round. Inputs are removed
            only when this is greater than zero, so the files fed into
            round 0 are left on disk.
    """
    merged = groupby_vaex(file1, file2)

    # Write the aggregated counts back out as Parquet.
    merged.export_parquet(output_file)

    if round_number <= 0:
        return

    # Inputs of later rounds are intermediates; drop them to save space.
    for consumed in (file1, file2):
        os.remove(consumed)


def parallel_merge(files, directory, round_number):
    """Merge ``files`` in consecutive pairs using a pool of worker processes.

    Each pair ``(files[2k], files[2k+1])`` is merged into
    ``merged_round_{round_number}_pair_{k}.parquet`` inside ``directory``.
    When the number of files is odd, the last file is carried over to the
    next round under the same naming scheme without being merged.

    Args:
        files: Paths of the Parquet files to merge in this round.
        directory: Directory the round's output files are written to.
        round_number: Index of the current merge round, used in output
            names and to decide whether inputs may be consumed.

    Returns:
        The list of output paths produced by this round, in pair order.

    Raises:
        Any exception raised by a worker is re-raised when its result is
        collected.
    """
    with ProcessPoolExecutor(max_workers=5) as executor:
        futures = []
        new_files = []

        for pair_index, start in enumerate(range(0, len(files), 2)):
            output_file = os.path.join(directory, f"merged_round_{round_number}_pair_{pair_index}.parquet")
            pair = files[start:start + 2]
            if len(pair) == 2:
                futures.append(executor.submit(merge_two_files, pair[0], pair[1], output_file, round_number))
            elif round_number > 0:
                # Leftover intermediate file: renaming it is enough.
                shutil.move(pair[0], output_file)
            else:
                # Round 0 inputs are kept by merge_two_files, so the odd
                # one out must be copied rather than renamed; otherwise it
                # would be deleted as an intermediate in a later round.
                shutil.copy2(pair[0], output_file)
            new_files.append(output_file)

        # Block until every merge has finished; result() re-raises worker errors.
        for future in tqdm(futures, desc=f"Merging files - Round {round_number}"):
            future.result()

        return new_files

def main():
    """Merge all n-gram count Parquet files of a source into one final file.

    For each configured arity, the Parquet files in the results directory
    (excluding earlier ``final_ngram_counts`` outputs) are merged pairwise,
    round after round, until a single file remains. That file is then
    stored as ``final_ngram_counts_vaex.parquet`` in the same directory.
    """
    for arity in [1]:
        # Other sources used with this script: 'source=gutenberg',
        # 'source=redpajama%2Farxiv'.
        source = 'source=starcoder'
        directory = f'ngrams_results_final/{source}/arity={arity}'
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(directory, exist_ok=True)

        # os.listdir() order is arbitrary; sort so pairing and intermediate
        # file names are reproducible between runs.
        # NOTE(review): leftover merged_round_* files from an interrupted
        # run are picked up here as inputs — confirm that is intended.
        files = sorted(
            os.path.join(directory, f)
            for f in os.listdir(directory)
            if f.endswith('.parquet') and 'final_ngram_counts' not in f
        )
        # Set to a value above 0 to resume from the outputs of an earlier round.
        round_number = 0
        if round_number > 0:
            files = [file for file in files if f'round_{round_number-1}' in file]

        start_time = time()
        while len(files) > 1:
            print(f"Starting round {round_number} of merges, {len(files)} files remaining...")
            files = parallel_merge(files, directory, round_number)
            round_number += 1
        print('Time processed', time()-start_time)
        if files:
            final_output_file = f'{directory}/final_ngram_counts_vaex.parquet'
            if round_number == 0:
                # No merge ran, so the only file is an original input:
                # copy it so the input itself is left in place.
                shutil.copy2(files[0], final_output_file)
            else:
                # The survivor is an intermediate; rename it to the final output.
                shutil.move(files[0], final_output_file)
            print(f"Done. Final results written to: {final_output_file}")

# Run the merge pipeline only when this file is executed as a script, so
# importing the module does not start a merge.
if __name__ == "__main__":
    main()
