import os
import shutil
import pyarrow.parquet as pq
import pandas as pd
import pyarrow as pa
from pathlib import Path
import logging

# Configure the root logger: INFO and above go to merge_logs.log, with each
# record prefixed by its timestamp and level name.
logging.basicConfig(
    filename='merge_logs.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)


def subtract_from_final_counts(directory, subtraction_file_name):
    """Subtract the counts of a subtraction file from ``final_ngram_counts.parquet``.

    Both Parquet files are read from ``directory`` and must each provide an
    ``ngram`` column and a ``count`` column. The final counts are left-joined
    with the subtraction counts on ``ngram``; ngrams that do not appear in the
    subtraction file are treated as having a subtraction count of 0. Rows
    whose resulting count is zero or negative are dropped. The result is
    written to ``final_ngram_counts_subtracted_last.parquet`` in the same
    directory; ``final_ngram_counts.parquet`` itself is not overwritten.

    Args:
        directory: Directory that holds both input files and receives the
            output file.
        subtraction_file_name: Name (relative to ``directory``) of the Parquet
            file whose counts are subtracted.

    Returns:
        None. Failures are logged rather than raised: if either input file is
        missing an error is logged and the function returns early, and any
        exception raised while reading, merging or writing is logged together
        with its traceback.
    """
    final_ngram_path = os.path.join(directory, 'final_ngram_counts.parquet')
    output_path = os.path.join(directory, 'final_ngram_counts_subtracted_last.parquet')
    subtraction_file_path = os.path.join(directory, subtraction_file_name)

    if not os.path.exists(final_ngram_path) or not os.path.exists(subtraction_file_path):
        logging.error("Required files for the subtraction process are missing.")
        return

    try:
        # Load both Parquet files as pandas DataFrames.
        df_final = pq.read_table(final_ngram_path).to_pandas()
        df_subtract = pq.read_table(subtraction_file_path).to_pandas()

        # Tuples are hashable, which the merge key requires. The raw column
        # values are presumably array-like after the Parquet read -- confirm.
        df_final['ngram'] = df_final['ngram'].apply(tuple)
        df_subtract['ngram'] = df_subtract['ngram'].apply(tuple)

        # Left join keeps every ngram of the final counts. The shared 'count'
        # column is split into 'count_final' / 'count_subtract' by the suffixes.
        # NOTE(review): a left join repeats a final row once per matching
        # subtraction row, so this assumes ngrams are unique in the
        # subtraction file -- confirm against the producer of that file.
        df_merged = pd.merge(
            df_final,
            df_subtract,
            on='ngram',
            how='left',
            suffixes=('_final', '_subtract'),
        )
        # Ngrams missing from the subtraction file come back as NaN: subtract 0.
        df_merged['count_subtract'] = df_merged['count_subtract'].fillna(0)
        df_merged['count'] = (df_merged['count_final'] - df_merged['count_subtract']).astype(int)

        # Keep only the resulting count and discard entries that dropped to
        # zero or below.
        df_merged = df_merged.drop(columns=['count_final', 'count_subtract'])
        df_merged = df_merged[df_merged['count'] > 0]

        # The filtered frame no longer has a plain RangeIndex; without
        # preserve_index=False pyarrow would store that index as an extra
        # '__index_level_0__' column in the output file.
        result_table = pa.Table.from_pandas(df_merged, preserve_index=False)
        pq.write_table(result_table, output_path)
        logging.info("Counts subtracted and updated file saved to %s", output_path)

    except Exception as e:
        # Best-effort by design: record the failure (with traceback) and let
        # the caller continue with the next directory.
        logging.exception("Failed to subtract counts: %s", e)


def main():
    """Run the count subtraction for the arity-1 and arity-2 result directories.

    For each arity the directory
    ``ngrams_results_final/<source>/arity=<arity>`` is created when it does
    not exist yet, then ``subtract_from_final_counts`` is invoked on it with
    the ``<source>_932.parquet`` subtraction file.
    """
    source = 'source=redpajama%2Fstackexchange'
    # The subtraction file name does not depend on the arity.
    subtraction_file_name = f'{source}_932.parquet'

    for arity in (1, 2):
        directory = f'ngrams_results_final/{source}/arity={arity}'

        directory_missing = not os.path.exists(directory)
        if directory_missing:
            os.makedirs(directory)

        # Subtract the counts inside this arity's directory.
        subtract_from_final_counts(directory, subtraction_file_name)

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
