#!/usr/bin/env python

import argparse
import pandas as pd
import h5py
from project.wrappers import OneHotWrapper

def main(csv_file, output_h5_file, max_length, batch_size=100):
    """Encode the ``Sequence`` column of a CSV file and store the result in HDF5.

    The CSV is read in chunks of ``batch_size`` rows so the whole file never
    has to be held in memory. Each chunk is encoded with ``OneHotWrapper`` and
    appended to a single extendable dataset named ``embeddings``.

    Args:
        csv_file: Path to the input CSV; it must contain a ``Sequence`` column.
        output_h5_file: Path of the HDF5 file to write. It is opened in ``'w'``
            mode.
        max_length: Value passed to the ``OneHotWrapper`` constructor.
        batch_size: Number of CSV rows read and encoded per chunk.
    """
    one_hot_encoder = OneHotWrapper(max_length)

    with h5py.File(output_h5_file, 'w') as h5f:
        # Created lazily: the trailing dimensions are only known once the
        # first chunk has been encoded.
        dataset = None

        with pd.read_csv(csv_file, chunksize=batch_size) as reader:
            for chunk in reader:
                sequences = chunk["Sequence"].tolist()

                # NOTE(review): encode() presumably returns a torch tensor
                # (the result must support .cpu().numpy()) - confirm against
                # OneHotWrapper. Convert once and reuse the array below.
                batch = one_hot_encoder.encode(sequences).cpu().numpy()

                if dataset is None:
                    # None in maxshape makes the first axis extendable.
                    dataset = h5f.create_dataset(
                        'embeddings',
                        data=batch,
                        maxshape=(None,) + batch.shape[1:],
                        chunks=True,
                    )
                else:
                    start = dataset.shape[0]
                    end = start + batch.shape[0]
                    dataset.resize((end,) + batch.shape[1:])
                    dataset[start:end] = batch

                # Report the real number of stored rows; the last chunk may
                # hold fewer than batch_size rows.
                print("{} embeddings stored".format(dataset.shape[0]))

if __name__ == "__main__":
    # Command-line entry point: parse the paths and encoding options, then
    # hand them to main().
    parser = argparse.ArgumentParser(description='Generate embeddings for sequences from a CSV file and store in HDF5 format.')
    parser.add_argument('--csv_file', type=str, default='data/generative-model-data/generative-model-dataset.csv',
                        help='Path to the input CSV file')
    parser.add_argument('--output_h5_file', type=str, default='data/generative-model-data/generative-model-embeddings.h5',
                        help='Output HDF5 file path')
    parser.add_argument('--max_length', type=int, default=100,
                        help='Maximum length of sequences')
    # Expose main()'s batch_size parameter; the default matches main()'s own
    # default, so existing invocations behave as before.
    parser.add_argument('--batch_size', type=int, default=100,
                        help='Number of CSV rows to read and encode per chunk')

    args = parser.parse_args()

    # Fail early with a usage message instead of a late error from the CSV
    # reader when the chunk size is not a positive integer.
    if args.batch_size < 1:
        parser.error('--batch_size must be a positive integer')

    main(args.csv_file, args.output_h5_file, args.max_length, batch_size=args.batch_size)
