import pandas as pd
import numpy as np
import argparse

# Command-line interface: the input CSV paths, one output path per split, and
# the row counts for the training and validation splits. The test split
# receives whatever rows remain after those two are taken.
parser = argparse.ArgumentParser(description="Combine and split CSV datasets from multiple files.")
parser.add_argument('--file_names', nargs='+', required=True, help='List of CSV file paths to combine.')
parser.add_argument('--train_output_file', type=str, required=True, help='Output path for the training dataset CSV.')
parser.add_argument('--val_output_file', type=str, required=True, help='Output path for the validation dataset CSV.')
parser.add_argument('--test_output_file', type=str, required=True, help='Output path for the test dataset CSV.')
parser.add_argument('--train_size', type=int, default=5000, help='Number of samples for the training set.')
parser.add_argument('--val_size', type=int, default=2000, help='Number of samples for the validation set.')

args = parser.parse_args()

# The splits are taken with positional slices, where a negative bound means
# "count from the end". A negative size would therefore silently produce the
# wrong rows, so reject it up front with a normal argparse usage error.
for option in ('train_size', 'val_size'):
    value = getattr(args, option)
    if value < 0:
        parser.error(f"--{option} must be a non-negative integer, got {value}.")

file_names = args.file_names
train_size = args.train_size
val_size = args.val_size
train_output_file = args.train_output_file
val_output_file = args.val_output_file
test_output_file = args.test_output_file

# Load every input CSV. Within each file, rows are de-duplicated on the
# 'phrase' column (first occurrence kept) and the optional 'model_output'
# column is discarded before the frame is collected.
dataframes = []
for path in file_names:
    # Keep the try body to the read itself so that only a genuinely missing
    # file is reported as such.
    try:
        df = pd.read_csv(path)
    except FileNotFoundError as e:
        # SystemExit with a string argument prints it to stderr and exits with
        # status 1, so a calling shell script can detect the failure.
        raise SystemExit(f"Error: {e}. Please check the paths passed to --file_names.")

    # Fail with a readable message instead of a raw KeyError traceback.
    if 'phrase' not in df.columns:
        raise SystemExit(f"Error: '{path}' has no 'phrase' column to de-duplicate on.")

    df = df.drop_duplicates(subset='phrase', keep='first')
    # errors='ignore' makes the drop a no-op when the column is absent.
    df = df.drop(columns=['model_output'], errors='ignore')
    dataframes.append(df)
    print(f"Successfully loaded '{path}' with {len(df)} samples.")

# Stack all loaded frames into one table with a fresh 0..n-1 index.
combined_df = pd.concat(dataframes, ignore_index=True)
total_samples = len(combined_df)
print(f"\nTotal combined samples: {total_samples}")

# NOTE(review): de-duplication on 'phrase' happens per input file only, so the
# same phrase coming from two different files can still land in different
# splits. Confirm whether that is intended.

# Slicing past the end of the table does not raise; it just yields fewer rows.
# Tell the user when the requested sizes leave a split short or empty.
val_end = train_size + val_size
if val_end >= total_samples:
    print(
        f"Warning: train_size + val_size ({val_end}) is not smaller than the "
        f"{total_samples} available samples; one or more splits will be short or empty."
    )

# Shuffle all rows; the fixed seed keeps the split reproducible across runs.
shuffled_df = combined_df.sample(frac=1, random_state=42).reset_index(drop=True)

# Consecutive, non-overlapping slices: train first, then validation, and the
# remainder becomes the test set.
train_df = shuffled_df.iloc[:train_size]
val_df = shuffled_df.iloc[train_size:val_end]
test_df = shuffled_df.iloc[val_end:]

# index=False keeps the positional index out of the written CSVs.
train_df.to_csv(train_output_file, index=False)
val_df.to_csv(val_output_file, index=False)
test_df.to_csv(test_output_file, index=False)

print("Splitting complete. The datasets have been saved:")
print(f"Training set:   {len(train_df)} samples saved to '{train_output_file}'")
print(f"Validation set: {len(val_df)} samples saved to '{val_output_file}'")
print(f"Test set:       {len(test_df)} samples saved to '{test_output_file}'")