import os
import shutil
import pyarrow.parquet as pq
import pandas as pd
import pyarrow as pa
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm
from time import time
from filelock import FileLock
import logging
from pathlib import Path

# Route root-logger records at INFO level and above to merge_logs.log,
# each line carrying a timestamp and the level name.
logging.basicConfig(
    filename='merge_logs.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

def log_processed_files(file_path, directory, filename='processed_files'):
    """Append ``file_path`` as one line to ``<directory>/<filename>.log``.

    The write happens while holding the file lock ``processed_files.log.lock``.
    That lock path is relative to the current working directory and is the
    same for every ``directory``/``filename`` combination, so all writers
    using this function are serialised against each other.

    Args:
        file_path: Text to record; a newline is appended.
        directory: Directory that contains (or will contain) the log file.
        filename: Log file name without the ``.log`` extension.
    """
    log_path = f'{directory}/{filename}.log'
    with FileLock("processed_files.log.lock"), open(log_path, 'a') as log_file:
        log_file.write(file_path + '\n')

def get_processed_files(directory):
    """Return the set of lines stored in ``<directory>/processed_files.log``.

    Args:
        directory: Directory expected to hold ``processed_files.log``.

    Returns:
        A set with one entry per distinct line of the log, or an empty set
        when the log file does not exist.
    """
    log_path = f'{directory}/processed_files.log'

    # Nothing has been logged yet for this directory.
    if not os.path.exists(log_path):
        return set()

    with open(log_path, 'r') as log_file:
        entries = log_file.read().splitlines()
    return set(entries)

def merge_two_files_new(file1, file2, output_file, round_number, directory, chunk_size=500000):
    """Merge two ngram-count Parquet files one row group at a time.

    Every row group of both inputs is loaded into pandas, its 'ngram' values
    are converted to tuples (so they are hashable), and 'count' is summed per
    ngram into a running total that is finally written to ``output_file``.

    The two files are iterated independently, so they may have different
    numbers of row groups. (Indexing both readers with file1's row-group
    count would silently drop the extra row groups of file2, or fail when
    file2 has fewer.)

    Args:
        file1: Path of the first Parquet input ('ngram' and 'count' columns).
        file2: Path of the second Parquet input (same columns).
        output_file: Path the merged table is written to.
        round_number: Merge round; the inputs are deleted after a successful
            merge only when this is greater than 0.
        directory: Directory holding the 'starting' and 'processed_files'
            logs written through ``log_processed_files``.
        chunk_size: Unused; kept so existing callers keep working.

    Returns:
        None. Any exception is caught and reported with ``print``; it is not
        re-raised, so a failed merge leaves ``output_file`` unwritten.
    """
    try:
        # Open both readers first so an unreadable input fails before
        # anything is logged.
        reader1 = pq.ParquetFile(file1)
        reader2 = pq.ParquetFile(file2)

        log_processed_files(file1, directory, filename='starting')
        log_processed_files(file2, directory, filename='starting')

        aggregated_df = None
        for reader in (reader1, reader2):
            for group_index in range(reader.num_row_groups):
                chunk = reader.read_row_group(group_index).to_pandas()

                # Convert 'ngram' entries to tuples to ensure they are hashable
                chunk['ngram'] = chunk['ngram'].apply(tuple)
                partial = chunk.groupby('ngram', as_index=False).agg({'count': 'sum'})

                if aggregated_df is None:
                    aggregated_df = partial
                else:
                    # Fold this row group into the running totals.
                    aggregated_df = (
                        pd.concat([aggregated_df, partial])
                        .groupby('ngram', as_index=False)
                        .agg({'count': 'sum'})
                    )

        if aggregated_df is None:
            # Neither input had any row group: write an empty table.
            aggregated_df = pd.DataFrame()

        # Convert back to Parquet
        result_table = pa.Table.from_pandas(aggregated_df)
        pq.write_table(result_table, output_file)
        log_processed_files(file1, directory)
        log_processed_files(file2, directory)
        print(f"Successfully merged {file1} and {file2} into {output_file}")

        if round_number > 0:
            # Inputs of later rounds are intermediate files; delete them to
            # save space.
            os.remove(file1)
            os.remove(file2)

    except Exception as e:
        print(f"Failed to merge {file1} and {file2}: {e}")

def merge_two_files(file1, file2, output_file, round_number, directory):
    """Merge two ngram-count Parquet files by loading both fully into memory.

    Both inputs are read as whole tables, converted to pandas, concatenated,
    and 'count' is summed per 'ngram'; the result is written to
    ``output_file``.

    Args:
        file1: Path of the first Parquet input ('ngram' and 'count' columns).
        file2: Path of the second Parquet input (same columns).
        output_file: Path the merged table is written to.
        round_number: Merge round; the inputs are deleted after a successful
            merge only when this is greater than 0.
        directory: Directory holding the 'starting' and 'processed_files'
            logs written through ``log_processed_files``.

    Returns:
        None. Any exception is caught and reported with ``print``; it is not
        re-raised.
    """
    try:
        inputs = (file1, file2)

        # Record the attempt before any data is read.
        for path in inputs:
            log_processed_files(path, directory, filename='starting')

        tables = [pq.read_table(path) for path in inputs]
        frames = [table.to_pandas() for table in tables]

        # Tuples are hashable, which the group-by on 'ngram' requires.
        for frame in frames:
            frame['ngram'] = frame['ngram'].apply(tuple)

        # Stack both inputs and total the counts of identical ngrams.
        totals = pd.concat(frames).groupby('ngram', as_index=False).agg({'count': 'sum'})

        pq.write_table(pa.Table.from_pandas(totals), output_file)

        for path in inputs:
            log_processed_files(path, directory)
        print(f"Successfully merged {file1} and {file2} into {output_file}")

        if round_number > 0:
            # Inputs are removed only from round 1 onwards, to save space.
            for path in inputs:
                os.remove(path)

    except Exception as e:
        print(f"Failed to merge {file1} and {file2}: {e}")

def _unique_output_path(directory, round_number, pair_index):
    """Return an output path for one pair of a round that does not exist yet."""
    candidate = os.path.join(directory, f"merged_round_{round_number}_pair_{pair_index}.parquet")
    file_index = 0

    # Append an increasing numeric suffix until the name is free.
    while Path(candidate).exists():
        print('File', candidate, 'exists.')
        file_index += 1
        candidate = os.path.join(directory,
                                 f"merged_round_{round_number}_pair_{pair_index}_{file_index}.parquet")
    return candidate


def parallel_merge(files, directory, round_number, max_workers=1):
    """Merge ``files`` in consecutive pairs using a process pool.

    Files at positions (0, 1), (2, 3), ... are handed to ``merge_two_files``;
    when the number of files is odd, the last file is copied unchanged to its
    output path so it takes part in the next round.

    Args:
        files: Paths of the Parquet files to merge in this round.
        directory: Directory that receives the outputs, named
            ``merged_round_<round>_pair_<k>[_<n>].parquet``.
        round_number: Index of the current round; used in output names and
            passed on to ``merge_two_files``.
        max_workers: Number of worker processes. Defaults to 1, i.e. the
            merges of a round run one after another.

    Returns:
        The output paths of this round, in pair order. A path is listed even
        if its merge failed: ``merge_two_files`` catches its own exceptions,
        so ``future.result()`` does not surface them here.
    """
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        new_files = []

        for start in range(0, len(files), 2):
            pair = files[start:start + 2]
            # The name is reserved by an existence check at scheduling time,
            # before the worker has written anything.
            output_file = _unique_output_path(directory, round_number, start // 2)

            if len(pair) == 2:
                print('Writing to', output_file)
                futures.append(
                    executor.submit(merge_two_files, pair[0], pair[1], output_file, round_number, directory))
            else:
                # Odd file count: carry the last file over as a plain copy.
                print('Writing new file to', output_file)
                shutil.copy(pair[0], output_file)
            new_files.append(output_file)

        # Wait for all merges of this round to finish.
        for future in tqdm(futures, desc=f"Merging files - Round {round_number}"):
            future.result()

        return new_files

def main():
    """Merge one batch of per-source ngram-count Parquet files into one file.

    For each arity, up to ``num_files_merged`` Parquet files that are not yet
    listed in the output directory's ``processed_files.log`` are taken (in
    sorted order) from every source directory. They are then merged pairwise
    in rounds with ``parallel_merge`` until a single file remains, which
    becomes ``<directory_output>/<sources joined by '_'>.parquet``.
    """
    # Source sets that were previously assigned here and then overridden:
    #   ['source=starcoder', 'source=gutenberg', 'source=redpajama%2Farxiv']
    #   ['source=megawika', 'source=redpajama%2Fstackexchange', 'source=reddit',
    #    'source=falcon-refinedweb%2Fdata']
    #   ['source=megawika_source=redpajama%2Fstackexchange_source=reddit_source=falcon-refinedweb%2Fdata',
    #    'source=starcoder_source=gutenberg_source=redpajama%2Farxiv']
    sources = ['source=starcoder', 'source=falcon-refinedweb%2Fdata', 'source=reddit', 'source=s2', 'source=redpajama%2Farxiv', 'source=redpajama%2Fstackexchange', 'source=gutenberg', 'source=megawika']

    generated_filename = '_'.join(sources)
    num_files_merged = 1

    for arity in [4]:  # other arities tried earlier: [2, 3], [2, 1]
        # Only file names containing this substring are considered.
        # Alternative used before: 'final_ngram_counts_subtracted_last.parquet'
        filename_to_merge = 'source='
        print('Doing arity', arity)
        files = []
        directory_output = f'ngrams_results_final/joined_final_{num_files_merged}/arity={arity}'

        # Neither of these depends on the source, so do them once per arity.
        os.makedirs(directory_output, exist_ok=True)
        processed_files = get_processed_files(directory_output)

        for source in sources:
            directory_input = f'ngrams_results_final/{source}/arity={arity}'

            candidates = (
                os.path.join(directory_input, name)
                for name in os.listdir(directory_input)
                if name.endswith('.parquet') and filename_to_merge in name
            )
            input_files = sorted(
                path for path in candidates if path not in processed_files
            )[:num_files_merged]
            print('Using only the first file', source, input_files)

            files.extend(input_files)

        # presumably raised by hand to resume from a later round — confirm
        round_number = 0
        if round_number < 0:
            raise ValueError("round_number should be >= 0.")
        if round_number > 0:
            files = [path for path in files if f'round_{round_number-1}' in path]
        else:
            files = [path for path in files if 'round_' not in path]

        merged_any = False
        started = time()
        while len(files) > 1:
            print(f"Starting round {round_number} of merges, {len(files)} files remaining...:", files)
            files = parallel_merge(files, directory_output, round_number)
            round_number += 1
            merged_any = True
        print('Time processed', time() - started)

        if files:
            final_output_file = f'{directory_output}/{generated_filename}.parquet'
            if merged_any:
                # files[0] is an intermediate merge output: rename it.
                shutil.move(files[0], final_output_file)
            else:
                # Only one file was selected, so no merge ran and files[0] is
                # still an input file. Copy it instead of moving it, so the
                # input directory is left intact.
                shutil.copy(files[0], final_output_file)
            print(f"Done. Final results written to: {final_output_file}")

# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
