import argparse
import pandas as pd
import numpy as np
import os

def _explode_list_columns(frame, columns):
    """Return *frame* with *columns* exploded, one row per list element.

    ``DataFrame.explode`` rejects an empty column list with ``ValueError``,
    so when there is nothing to explode the frame is passed through with
    only its index reset.
    """
    if not columns:
        return frame.reset_index(drop=True)
    return frame.explode(columns).reset_index(drop=True)


def split_data(args):
    """
    Load a processed Parquet dataset, shuffle and split it at the prompt
    level, explode the list columns of each split, and save both splits.

    The split is done before exploding so that all rows originating from one
    input row (one prompt) end up in the same output set.

    Args:
        args: Namespace-like object providing ``input_file``,
            ``train_output_file``, ``val_output_file``, ``split_ratio``
            (fraction of input rows assigned to training) and ``random_seed``.

    Returns:
        None. If the input file cannot be read, the error is printed and the
        function returns without writing any output.
    """
    # 1. Load the dataset
    print(f"Loading dataset from {args.input_file}...")
    try:
        df = pd.read_parquet(args.input_file)
    except Exception as e:
        # Kept as report-and-return so existing callers do not start crashing.
        print(f"Error loading Parquet file: {e}")
        return

    # 2. Shuffle the data at the prompt level (before exploding)
    print(f"Shuffling {len(df)} prompts with random seed {args.random_seed}...")
    shuffled_df = df.sample(frac=1, random_state=args.random_seed).reset_index(drop=True)

    # 3. Split the DataFrame of prompts into training and validation sets
    split_index = int(len(shuffled_df) * args.split_ratio)
    train_prompts_df = shuffled_df.iloc[:split_index]
    val_prompts_df = shuffled_df.iloc[split_index:]

    print(f"Splitting prompts: {len(train_prompts_df)} for training, {len(val_prompts_df)} for validation.")

    # 4. Explode the training and validation sets separately.
    # Only the list columns actually present in the input are exploded.
    list_columns = [
        'responses',
        'reward',
        'proposal_log_probs',
        'base_log_probs',
        'target_log_scores',
        'target_log_probs',
    ]
    existing_list_columns = [col for col in list_columns if col in df.columns]
    if not existing_list_columns:
        print("No known list columns found in the input; rows are kept as-is.")

    print("Exploding training set...")
    train_df = _explode_list_columns(train_prompts_df, existing_list_columns)

    print("Exploding validation set...")
    val_df = _explode_list_columns(val_prompts_df, existing_list_columns)

    print(f"Final dataset sizes: {len(train_df)} training samples, {len(val_df)} validation samples.")

    # 5. Ensure output directories exist and save the new files
    for path in [args.train_output_file, args.val_output_file]:
        output_dir = os.path.dirname(path)
        if output_dir:  # Empty when the path is just a filename in the current directory
            os.makedirs(output_dir, exist_ok=True)

    print(f"Saving training data to {args.train_output_file}...")
    train_df.to_parquet(args.train_output_file, index=False)

    print(f"Saving validation data to {args.val_output_file}...")
    val_df.to_parquet(args.val_output_file, index=False)

    print("\nSplitting complete! ✅")

if __name__ == "__main__":
    # Command-line entry point: parse the options, validate them, run the split.
    parser = argparse.ArgumentParser(
        description=(
            "Shuffle and split a processed dataset by prompt into training and "
            "validation sets, then explode the list columns of each set."
        )
    )

    parser.add_argument("--input-file", type=str, required=True, help="Path to the processed Parquet file from the previous script.")
    parser.add_argument("--train-output-file", type=str, required=True, help="Path to save the training data Parquet file.")
    parser.add_argument("--val-output-file", type=str, required=True, help="Path to save the validation data Parquet file.")
    parser.add_argument("--split-ratio", type=float, default=0.99, help="The proportion of the data to use for the training set (e.g., 0.99 for 99%%).")
    parser.add_argument("--random-seed", type=int, default=42, help="Random seed for shuffling to ensure reproducibility.")

    args = parser.parse_args()

    # Report bad values the argparse way (usage message, exit status 2)
    # instead of a raw traceback. The chained comparison also rejects NaN.
    if not 0 < args.split_ratio < 1:
        parser.error("--split-ratio must be between 0 and 1.")

    split_data(args)
