import logging
import os
import shutil
from concurrent.futures import ProcessPoolExecutor, as_completed
from time import time

import dask.dataframe as dd
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.parquet as pq
from dask.distributed import Client
from tqdm import tqdm


def merge_two_files(file1, file2, output_file, round_number):
    """Merge two n-gram count parquet files into one, summing counts per n-gram.

    Both inputs are read with Polars, stacked vertically, grouped by the
    ``ngram`` column, and their ``count`` columns summed. The result is first
    written to ``<output_file>.tmp`` and then renamed onto ``output_file``.
    A crash mid-write therefore never leaves a truncated ``output_file``
    behind that a later round (or a resumed run) would pick up as valid.

    Args:
        file1: Path of the first input parquet file.
        file2: Path of the second input parquet file.
        output_file: Path of the merged parquet file to create.
        round_number: Merge round index. The inputs are deleted after a
            successful write only when ``round_number > 0``. Round-0 inputs
            are the original data and are left in place.

    Raises:
        Any error from reading, aggregating or writing propagates to the
        caller. In that case the temporary file is removed and no input is
        deleted.
    """
    df1 = pl.read_parquet(file1)
    df2 = pl.read_parquet(file2)

    # Combine both tables and sum the counts of identical n-grams.
    merged = (
        df1.vstack(df2)
        .group_by('ngram')
        .agg(pl.col('count').sum().alias('count'))
    )

    # Write-then-rename so output_file only ever appears fully written.
    # The ".tmp" suffix also keeps the partial file out of *.parquet listings.
    tmp_file = f"{output_file}.tmp"
    try:
        merged.write_parquet(tmp_file)
        os.replace(tmp_file, output_file)
    except Exception:
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise

    if round_number > 0:
        # Intermediate inputs are no longer needed; free the disk space.
        os.remove(file1)
        os.remove(file2)

def parallel_merge(files, directory, round_number, max_workers=10):
    """Merge ``files`` pairwise in worker processes and return the next round's files.

    Consecutive files are paired (0+1, 2+3, ...). Each pair is merged by
    ``merge_two_files`` into
    ``<directory>/merged_round_{round_number}_pair_{k}.parquet``. When the
    number of files is odd, the last file has no partner and is carried over
    to the next round under the same naming scheme.

    Args:
        files: Paths of the parquet files to merge in this round.
        directory: Directory where this round's output files are written.
        round_number: Index of the current merge round. ``merge_two_files``
            keeps its inputs in round 0 (they are the original data). For the
            same reason the unpaired file is copied rather than moved in that
            round. In later rounds it is moved, because it is an intermediate
            file.
        max_workers: Maximum number of worker processes.

    Returns:
        Paths of the files produced by this round, ordered by pair index.

    Raises:
        Any exception raised by ``merge_two_files`` in a worker is re-raised
        here when its future is consumed.
    """
    new_files = []
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for pair_index, start in enumerate(range(0, len(files), 2)):
            output_file = os.path.join(
                directory, f"merged_round_{round_number}_pair_{pair_index}.parquet"
            )
            if start + 1 < len(files):
                futures.append(
                    executor.submit(
                        merge_two_files,
                        files[start],
                        files[start + 1],
                        output_file,
                        round_number,
                    )
                )
            elif round_number == 0:
                # Odd count in round 0: the leftover is an original input, so
                # keep it intact as merge_two_files does for its round-0 inputs.
                shutil.copy2(files[start], output_file)
            else:
                # Odd count in a later round: just carry the intermediate over.
                shutil.move(files[start], output_file)
            new_files.append(output_file)

        # Advance the progress bar as merges actually finish. result() re-raises
        # a worker's exception as soon as that merge completes.
        for future in tqdm(
            as_completed(futures),
            total=len(futures),
            desc=f"Merging files - Round {round_number}",
        ):
            future.result()

    return new_files

def _list_round_inputs(directory, start_round):
    """Return the sorted parquet paths in ``directory`` that feed round ``start_round``."""
    names = [
        name for name in os.listdir(directory)
        if name.endswith('.parquet') and 'final_ngram_counts' not in name
    ]
    if start_round > 0:
        # Resume from the previous round's outputs. Match the exact prefix so
        # that e.g. round 1 does not also select round_10 ... round_19 files.
        prefix = f'merged_round_{start_round - 1}_'
        names = [name for name in names if name.startswith(prefix)]
    else:
        # Leftovers of an interrupted run already hold summed counts of the
        # original files. Merging them again would double count them, and
        # their names could collide with this run's round-0 outputs.
        names = [name for name in names if not name.startswith('merged_round_')]
    return [os.path.join(directory, name) for name in sorted(names)]


def main(source='source=starcoder', arities=(1,), start_round=0):
    """Reduce the parquet files of each n-gram arity directory to a single file.

    For each arity, the files in ``ngrams_results_final/{source}/arity={arity}``
    are merged pairwise, round after round (see ``parallel_merge``), until
    one file remains. That file becomes ``final_ngram_counts.parquet`` in the
    same directory.

    Args:
        source: Source partition directory name. Other values used so far are
            ``'source=gutenberg'`` and ``'source=redpajama%2Farxiv'``.
        arities: N-gram arities to process.
        start_round: Round to start from. ``0`` starts from the original input
            files. A value ``n > 0`` resumes an interrupted run from the
            ``merged_round_{n-1}_*`` files left by round ``n - 1``.
    """
    for arity in arities:
        directory = f'ngrams_results_final/{source}/arity={arity}'
        os.makedirs(directory, exist_ok=True)

        files = _list_round_inputs(directory, start_round)
        round_number = start_round

        started = time()
        while len(files) > 1:
            print(f"Starting round {round_number} of merges, {len(files)} files remaining...")
            files = parallel_merge(files, directory, round_number)
            round_number += 1
        print('Time processed', time() - started)

        if not files:
            continue

        final_output_file = f'{directory}/final_ngram_counts.parquet'
        if round_number == 0:
            # No merge ran, so the lone file is an original input: copy it so
            # the inputs stay intact, matching the round-0 policy of the merges.
            shutil.copy2(files[0], final_output_file)
        else:
            shutil.move(files[0], final_output_file)
        print(f"Done. Final results written to: {final_output_file}")


if __name__ == "__main__":
    main()
