#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
This script concatenates the INTERNAL contents of specific columns from two 
Apache Parquet files on a row-by-row basis.

It reads two Parquet files and for each row, it concatenates the lists/arrays
found within the 'responses', 'log_probs', and 'tokenized_responses' columns.

- 'responses': Concatenates the internal list of strings.
  (e.g., ['a'] from file1 + ['b'] from file2 becomes ['a', 'b'])
- 'log_probs', 'tokenized_responses': Concatenates the internal NumPy arrays.

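The concatenation is performed independently for each row. All other columns
are copied unchanged from the first file, and any of the three columns that is
missing from either file is skipped with a warning.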

Usage:
    python concat_parquet_apply.py <path_to_file1.parquet> <path_to_file2.parquet> <output_path.parquet>
"""

import argparse
import sys

import numpy as np
import pandas as pd

def _concat_cells(left, right, data_type: str):
    """Concatenate two per-row cell values according to ``data_type``.

    A ``None`` cell (e.g. a null list in the Parquet file) is treated as empty;
    if both cells are ``None`` the result is ``None``.

    Args:
        left: Cell value from the first file (a list-like, or None).
        right: Cell value from the second file (a list-like, or None).
        data_type (str): "list" to build a Python list of the items, or
            "numpy" to build a NumPy array with ``np.concatenate``.

    Returns:
        The concatenated list/array, or None when both inputs are None.

    Raises:
        ValueError: If ``data_type`` is not "list" or "numpy".
    """
    parts = [cell for cell in (left, right) if cell is not None]
    if not parts:
        return None
    if data_type == "list":
        return [item for part in parts for item in part]
    if data_type == "numpy":
        return np.concatenate(parts)
    raise ValueError(f"Unsupported data type '{data_type}'.")


def _concat_column(left: pd.Series, right: pd.Series, data_type: str) -> pd.Series:
    """Concatenate two columns cell by cell, pairing rows by position.

    Positional pairing (row i with row i) deliberately ignores the index labels,
    so files written with different or non-default indexes still line up.

    Returns:
        An object-dtype Series carrying ``left``'s index, so it can be assigned
        back onto a copy of the first DataFrame without re-alignment.
    """
    values = [
        _concat_cells(a, b, data_type)
        for a, b in zip(left.to_numpy(), right.to_numpy())
    ]
    # dtype=object keeps every list/array as a single cell instead of letting
    # pandas interpret equal-length rows as a 2-D block.
    return pd.Series(values, index=left.index, dtype=object)


def concatenate_parquet_columns(file1_path: str, file2_path: str, output_path: str) -> bool:
    """
    Reads two Parquet files, concatenates specific columns row by row, and
    saves the result to a new Parquet file.

    Rows are paired by position (row i of the first file with row i of the
    second), regardless of any index stored in the files. Columns other than
    the processed ones are copied unchanged from the first file; a processed
    column missing from either file is skipped with a warning.

    Args:
        file1_path (str): Path to the first input Parquet file.
        file2_path (str): Path to the second input Parquet file.
        output_path (str): Path to save the output Parquet file.

    Returns:
        bool: True if the output file was written, False if an error occurred
        (the error message is printed to stderr).
    """
    try:
        # Step 1: Read the Parquet files
        print(f"INFO: Reading data from '{file1_path}' and '{file2_path}'...")
        df1 = pd.read_parquet(file1_path)
        df2 = pd.read_parquet(file2_path)

        # Step 2: Positional pairing requires identical row counts.
        if len(df1) != len(df2):
            raise ValueError(
                f"Row count mismatch: File 1 has {len(df1)} rows, while File 2 has {len(df2)}."
            )

        # Step 3: Start from a copy of the first DataFrame to hold the result
        result_df = df1.copy()
        print("INFO: Starting row-wise concatenation...")

        # Column name -> how its per-row values are concatenated
        # ("list" -> Python list, "numpy" -> NumPy array).
        columns_to_process = {
            "responses": "list",
            "log_probs": "numpy",
            "tokenized_responses": "numpy"
        }

        for col, data_type in columns_to_process.items():
            if col not in df1.columns or col not in df2.columns:
                print(f"  - WARNING: Column '{col}' not found in both files. Skipping.")
                continue

            print(f"  - Processing column '{col}'...")
            result_df[col] = _concat_column(df1[col], df2[col], data_type)

        # Step 4: Save the resulting DataFrame
        print(f"INFO: Saving concatenated DataFrame to '{output_path}'...")
        result_df.to_parquet(output_path, index=False)

        print("\nDone! Concatenation complete. ✨")
        return True

    except FileNotFoundError as e:
        print(f"ERROR: File not found - {e}", file=sys.stderr)
        return False
    except Exception as e:  # CLI boundary: report the failure instead of a traceback
        print(f"An unexpected error occurred: {e}", file=sys.stderr)
        return False

def main():
    """Parse command-line arguments, run the concatenation, and set the exit status.

    Exits with status 1 when the concatenation fails so that shell scripts and
    pipelines can detect the failure.
    """
    parser = argparse.ArgumentParser(
        description="Concatenate the internal lists/arrays of specific columns from two Parquet files.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("file1", type=str, help="Path to the first input Parquet file.")
    parser.add_argument("file2", type=str, help="Path to the second input Parquet file.")
    parser.add_argument("output", type=str, help="Path for the output Parquet file.")
    args = parser.parse_args()
    if not concatenate_parquet_columns(args.file1, args.file2, args.output):
        sys.exit(1)

if __name__ == "__main__":
    main()