import os
import shutil
import pyarrow.parquet as pq
import pandas as pd
import pyarrow as pa
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm
from time import time
import numpy as np
from merge_counters import FastCounter  # Import the updated Cython class handling strings

def serialize_ngrams(df):
    """Replace every entry of ``df['ngram']`` with a comma-joined string, in place.

    Each element of an n-gram is passed through ``str`` and the pieces are
    joined with ``','`` (e.g. ``(1, 2, 3)`` becomes ``'1,2,3'``). The frame is
    modified directly and nothing is returned.
    """
    def _to_text(ngram):
        return ','.join(str(token) for token in ngram)

    df['ngram'] = df['ngram'].apply(_to_text)

def deserialize_ngrams(df):
    """Turn comma-joined strings in ``df['ngram']`` back into tuples of ints, in place.

    Each string is split on ``','`` and every piece is parsed with ``int``
    (e.g. ``'1,2,3'`` becomes ``(1, 2, 3)``). The empty string maps to the
    empty tuple, which makes this the exact inverse of ``serialize_ngrams``
    for integer n-grams. The frame is modified directly and nothing is
    returned.

    Raises:
        ValueError: if a piece of a non-empty string is not an integer literal.
    """
    def _parse(text):
        # ''.split(',') yields [''], which int() rejects, so the empty
        # n-gram needs its own branch.
        return tuple(map(int, text.split(','))) if text else ()

    df['ngram'] = df['ngram'].apply(_parse)

def merge_two_files(file1, file2, output_file, round_number):
    """Combine the n-gram counts of two Parquet files into one Parquet file.

    Both inputs are loaded fully into memory. Each must provide an ``ngram``
    column (sequences of integer tokens) and a ``count`` column. N-grams are
    turned into comma-joined strings so they can be fed to the Cython
    ``FastCounter``, and converted back to tuples before writing.

    Args:
        file1: Path of the first input Parquet file.
        file2: Path of the second input Parquet file.
        output_file: Path the merged Parquet file is written to.
        round_number: Merge round this call belongs to. In round 0 the inputs
            are left on disk; in any later round both inputs are deleted once
            the output has been written.
    """
    # Read data
    df1 = pq.read_table(file1).to_pandas()
    df2 = pq.read_table(file2).to_pandas()

    # Serialize ngrams (in place) to the string keys FastCounter works with
    serialize_ngrams(df1)
    serialize_ngrams(df2)

    # Prepare data for Cython processing
    keys1 = df1['ngram'].tolist()
    counts1 = df1['count'].tolist()
    keys2 = df2['ngram'].tolist()
    counts2 = df2['count'].tolist()

    # Size the counter for the worst case where no key is shared.
    # NOTE(review): FastCounter presumably accumulates counts per key --
    # confirm against merge_counters.
    fc = FastCounter(len(keys1) + len(keys2))
    fc.add(keys1, counts1)
    fc.add(keys2, counts2)

    # Get results and convert to DataFrame
    results = fc.get_results()
    result_df = pd.DataFrame(results, columns=['ngram', 'count'])

    # Deserialize ngrams back to tuples of ints
    deserialize_ngrams(result_df)

    # Write to a temporary sibling first and rename it into place, so an
    # interrupted write never leaves a truncated '*.parquet' file that a later
    # run would pick up as input.
    result_table = pa.Table.from_pandas(result_df)
    tmp_file = f"{output_file}.tmp"
    try:
        pq.write_table(result_table, tmp_file)
        os.replace(tmp_file, output_file)
    finally:
        # Only still present if the write or the rename failed.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)

    if round_number > 0:
        # Inputs of later rounds are intermediate files; delete them to save space
        os.remove(file1)
        os.remove(file2)

def parallel_merge(files, directory, round_number, max_workers=3):
    """Merge ``files`` pairwise in worker processes and return the new paths.

    Files are paired in list order (0 with 1, 2 with 3, ...). Every pair is
    submitted to ``merge_two_files``; its output is named
    ``merged_round_{round_number}_pair_{k}.parquet`` inside ``directory``.
    With an odd number of files the last one is carried forward under the
    same naming scheme without being merged.

    Args:
        files: Paths of the Parquet files to merge in this round.
        directory: Directory the round's output files are created in.
        round_number: Index of the current round, used in output names and
            forwarded to ``merge_two_files``.
        max_workers: Number of worker processes (defaults to 3).

    Returns:
        List of output paths for this round, in pair order.

    Raises:
        Any exception raised inside a merge worker is re-raised here.
    """
    new_files = []
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        num_files = len(files)

        for i in range(0, num_files, 2):
            output_file = os.path.join(directory, f"merged_round_{round_number}_pair_{i//2}.parquet")
            if i + 1 < num_files:
                futures.append(executor.submit(merge_two_files, files[i], files[i+1], output_file, round_number))
            elif round_number > 0:
                # Odd file count: the leftover is an intermediate file, so
                # renaming it into this round is enough.
                shutil.move(files[i], output_file)
            else:
                # Odd file count in round 0: the leftover is an original
                # input. merge_two_files keeps round-0 inputs on disk, so copy
                # instead of moving to keep the original as well.
                shutil.copy2(files[i], output_file)
            new_files.append(output_file)

        # Wait for all futures to complete; result() re-raises worker errors
        for future in tqdm(futures, desc=f"Merging files - Round {round_number}"):
            future.result()

    return new_files

def main():
    """Merge all per-shard n-gram count files of a corpus into one Parquet file.

    For each configured arity, the ``*.parquet`` files in the corpus directory
    (excluding earlier final outputs) are merged pairwise, round after round,
    until a single file remains. That file is stored as
    ``final_ngram_counts_check.parquet`` in the same directory.
    """
    # Only arity 1 is processed; extend the list to merge other arities.
    for arity in [1]:
        # Other corpora used with this script: 'source=gutenberg',
        # 'source=redpajama%2Farxiv'.
        source = 'source=starcoder'
        directory = f'ngrams_results_final/{source}/arity={arity}'
        os.makedirs(directory, exist_ok=True)

        # os.listdir returns entries in arbitrary order; sort so the pairing
        # (and therefore the intermediate file contents) is reproducible.
        files = sorted(
            os.path.join(directory, f)
            for f in os.listdir(directory)
            if f.endswith('.parquet') and 'final_ngram_counts' not in f
        )
        # NOTE(review): presumably a manual resume knob -- setting this to
        # N > 0 restricts the inputs to the outputs of round N-1.
        round_number = 0
        if round_number > 0:
            files = [file for file in files if f'round_{round_number-1}' in file]

        started = time()
        while len(files) > 1:
            print(f"Starting round {round_number} of merges, {len(files)} files remaining...")
            files = parallel_merge(files, directory, round_number)
            round_number += 1
        print('Time processed', time() - started)

        if files:
            final_output_file = f'{directory}/final_ngram_counts_check.parquet'
            if round_number == 0:
                # No merge round ran, so the single remaining file is an
                # original input: copy it so the source data stays on disk.
                shutil.copy2(files[0], final_output_file)
            else:
                # The remaining file is an intermediate result; rename it.
                shutil.move(files[0], final_output_file)
            print(f"Done. Final results written to: {final_output_file}")

# Run the merge pipeline only when this file is executed as a script.
# The guard also matters for ProcessPoolExecutor: worker processes may
# re-import this module (e.g. under the "spawn" start method) and must not
# start main() again.
if __name__ == "__main__":
    main()
