import argparse
import os
from datasets import load_dataset
from memgpt.database.utils.utils_database import DatabaseManager
from memgpt.trl.utils.utils_filter import filter_invalid_dblookups


def load_args(argv=None) -> argparse.Namespace:
    """Parse command-line options for the database build script.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so existing
            callers are unaffected; passing a list lets other code and tests
            drive the parser directly.

    Returns:
        Namespace with ``annotation_path``, ``save_path``, ``filter_key`` and
        ``filter_value`` attributes (the last three default to ``None``).

    Exits via ``parser.error`` (status 2) when only one of ``--filter_key`` /
    ``--filter_value`` is supplied, because a half-specified filter would
    otherwise be silently ignored downstream.
    """
    parser = argparse.ArgumentParser(description="Process and store dataset in a database.")
    parser.add_argument("--annotation_path", type=str, required=True, help="Path to the JSON annotation file.")
    parser.add_argument("--save_path", type=str, default=None, help="Path to save the database JSON file.")
    parser.add_argument("--filter_key", type=str, default=None, help="Key to filter the dataset on.")
    parser.add_argument("--filter_value", type=str, default=None, help="Value to filter the dataset on.")
    args = parser.parse_args(argv)

    # A key without a value (or vice versa) cannot form a filter; fail loudly
    # instead of quietly processing the unfiltered dataset.
    if (args.filter_key is None) != (args.filter_value is None):
        parser.error("--filter_key and --filter_value must be given together.")
    return args

def load_and_filter_dataset(annotation_path, filter_key=None, filter_value=None):
    """Load the ``train`` split of a dataset and optionally keep matching rows.

    Inputs whose file name contains a ``.json`` extension (``.json``,
    ``.jsonl``, ``.json.gz``, ...) are loaded with the ``json`` builder via
    ``data_files``; any other path/name is passed to ``load_dataset``
    directly. Both first try ``field='examples'`` (records nested under a
    top-level ``examples`` key). If that attempt raises, loading is retried
    with the SAME builder but without ``field``.

    Args:
        annotation_path: Path to a JSON-style file, or any path/name accepted
            by ``datasets.load_dataset``.
        filter_key: Column to filter on. Filtering is skipped when empty/None.
        filter_value: Value the column must equal. Besides ``==``, values are
            compared as strings so a CLI-supplied ``"3"`` selects an integer
            ``3``. Filtering is skipped when ``None``.

    Returns:
        The loaded, and possibly filtered, dataset.
    """
    # Look only at the file name so a directory such as "exports_json/" does
    # not force the json builder onto non-JSON inputs.
    is_json_file = ".json" in os.path.basename(annotation_path).lower()

    def _load(**extra):
        # Keep the builder choice identical between the first attempt and the
        # fallback; only the optional ``field`` kwarg differs.
        if is_json_file:
            return load_dataset("json", data_files=annotation_path, split="train", **extra)
        return load_dataset(annotation_path, split="train", **extra)

    try:
        dataset = _load(field="examples")
    except Exception as e:
        # Broad on purpose: load_dataset can raise many error types here. The
        # common cause is input that does not wrap its records in "examples"
        # (a plain JSON list or JSON Lines), so retry without ``field``.
        print(f"Error loading dataset: {e}")
        dataset = _load()

    if filter_key and filter_value is not None:
        def _matches(example):
            value = example.get(filter_key)
            if value == filter_value:
                return True
            # CLI values are always str; also compare textually so that
            # e.g. "3" matches a column holding the integer 3.
            return value is not None and str(value) == str(filter_value)

        dataset = dataset.filter(_matches)

    return dataset

def main():
    """Build a cleaned database from an annotation dataset and save it as JSON.

    Steps:
    1. Parse the CLI arguments.
    2. Load the dataset and apply the optional key/value filter.
    3. Apply ``filter_invalid_dblookups`` to every example via ``map``.
    4. Print basic dataset info.
    5. Build a ``DatabaseManager`` from the dataset and save it to
       ``args.save_path``.

    Without ``--save_path`` the output goes to
    ``./database/<annotation basename>_cleaned_database.json``.
    """
    args = load_args()

    # Set default save path if not provided. normpath strips a trailing
    # separator so an argument like "data/foo/" yields "foo" rather than an
    # empty basename.
    if not args.save_path:
        base_name = os.path.basename(os.path.normpath(args.annotation_path))
        save_name = base_name + "_cleaned_database.json"
        args.save_path = os.path.join("./database", save_name)

    # Create the output directory up front, in case save_database does not
    # create parent directories itself; otherwise a missing folder would only
    # surface after all the processing below has run.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Load and filter the dataset
    dataset = load_and_filter_dataset(args.annotation_path, args.filter_key, args.filter_value)

    # NOTE(review): despite its name, filter_invalid_dblookups is applied with
    # map (a per-example transform), not Dataset.filter — confirm intended.
    dataset = dataset.map(filter_invalid_dblookups)

    # Output dataset info
    print(f"Dataset size: {len(dataset)}")
    print(f"First entry: {dataset[0] if len(dataset) > 0 else 'Empty dataset'}")

    # Build and save database
    db_manager = DatabaseManager()
    db_manager.build_database(dataset)
    print(f"Database built successfully: {db_manager}")
    db_manager.save_database(args.save_path)
    print(f"Database saved to: {args.save_path}")

def load_database(database_dir="./database"):
    """Load every ``*.json`` file in a directory and print the loaded manager.

    Each matching file is loaded into a fresh ``DatabaseManager`` and the
    manager's string form is printed. Files are visited in sorted name order
    so the output is reproducible, because ``os.listdir`` order is arbitrary.

    Args:
        database_dir: Directory to scan. Defaults to ``./database``, the
            directory ``main`` writes to when no save path is given.

    Raises:
        FileNotFoundError: If ``database_dir`` does not exist.
    """
    for file_name in sorted(os.listdir(database_dir)):
        database_path = os.path.join(database_dir, file_name)
        # Only regular files: a subdirectory named "*.json" is not a database.
        if not file_name.endswith(".json") or not os.path.isfile(database_path):
            continue
        db_manager = DatabaseManager()
        db_manager.load_database(database_path)
        print(f"Loaded database from {database_path}: {db_manager}")

# Script entry point: run the build pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

