import os
import argparse
from datasets import load_dataset, load_from_disk
import json

def convert_arrow_to_json(arrow_path, output_path, max_convert=None, split="train"):
    """
    Converts a .arrow dataset into a JSON file.

    The dataset is first loaded with ``load_dataset(arrow_path, split=split)``;
    if that raises, it falls back to ``load_from_disk(arrow_path)``. When the
    fallback yields a mapping of splits (``DatasetDict``), the requested
    ``split`` is selected from it before conversion.

    Args:
        arrow_path (str): Path to the folder containing the .arrow file(s).
        output_path (str): Path where the JSON file will be saved. Missing
            parent directories are created.
        max_convert (int, optional): Maximum number of data points to convert. If None, converts all data.
        split (str, optional): Name of the split to convert. Defaults to "train".

    Raises:
        ValueError: If the loaded data contains several splits and ``split``
            is not one of them.
    """
    # Load the dataset
    print(f"Loading dataset from {arrow_path}...")
    try:
        dataset = load_dataset(arrow_path, split=split)
    except Exception:
        # load_dataset cannot read every on-disk layout (e.g. folders written by
        # save_to_disk), so fall back to load_from_disk. If the fallback also
        # fails, its exception carries the load_dataset error as __context__.
        dataset = load_from_disk(arrow_path)

    # load_from_disk may return a DatasetDict (a dict subclass keyed by split
    # name). Pick the requested split; otherwise the row loop below would index
    # the dict with integers and fail.
    if isinstance(dataset, dict) and dataset:
        if split not in dataset:
            raise ValueError(
                f"Split '{split}' not found in {arrow_path}; "
                f"available splits: {sorted(dataset)}"
            )
        dataset = dataset[split]

    # Check if dataset is loaded (an empty dataset/split is also treated as missing)
    if not dataset:
        print("No datasets found in the provided path.")
        return

    # Slice the dataset if max_convert is specified
    if max_convert is not None:
        print(f"Limiting dataset to the first {max_convert} examples...")
        data_list = [dataset[i] for i in range(min(max_convert, len(dataset)))]
    else:
        # Iterating a dataset yields one dict per example.
        data_list = list(dataset)

    # Create the destination folder if needed so open() does not fail.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save to JSON file
    print(f"Saving dataset to {output_path}...")
    with open(output_path, "w", encoding="utf-8") as json_file:
        json.dump(data_list, json_file, indent=4, ensure_ascii=False)

    print(f"Dataset successfully converted and saved to {output_path}.")

if __name__ == "__main__":
    # Command-line argument parsing
    parser = argparse.ArgumentParser(description="Convert .arrow dataset to JSON format.")
    parser.add_argument("--arrow_path", type=str, required=True, help="Path to the folder containing the .arrow file(s).")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save the output JSON file.")
    parser.add_argument("--max_convert", type=int, default=None, help="Maximum number of data points to convert.")
    parser.add_argument("--split", type=str, default="train", help="Split of the dataset to convert, by default 'train'")

    args = parser.parse_args()

    # A negative limit would silently produce an empty JSON file; reject it up front.
    if args.max_convert is not None and args.max_convert < 0:
        parser.error("--max_convert must be a non-negative integer")

    # Run conversion, forwarding --split so the chosen split is actually used.
    convert_arrow_to_json(args.arrow_path, args.output_path, args.max_convert, split=args.split)
