#!/usr/bin/env python

import argparse
import pandas as pd
import torch
import h5py
from project.wrappers import EsmWrapper

def _load_sequences(csv_file, max_length):
    """Return the non-empty ``Sequence`` entries of ``csv_file`` no longer than ``max_length``.

    ``keep_default_na=False`` stops pandas from converting cells such as ``"NA"`` or
    ``"NULL"`` into NaN. Such cells may be legitimate short sequences (e.g. Asn-Ala).
    With that flag, blank cells come back as ``""`` and are skipped here. Without it
    they would be float NaN and make ``len()`` raise.
    """
    column = pd.read_csv(csv_file, keep_default_na=False)["Sequence"].astype(str)
    return [seq for seq in column.tolist() if seq and len(seq) <= max_length]


def _append_rows(h5f, name, rows):
    """Append ``rows`` along axis 0 of dataset ``name`` in ``h5f``, creating it on first use.

    Returns the dataset length (number of rows) after the append.
    """
    if name not in h5f:
        # A ``None`` max extent on axis 0 makes the dataset growable. HDF5 needs
        # chunked storage for resizable datasets, hence ``chunks=True``.
        dset = h5f.create_dataset(name, data=rows, maxshape=(None,) + rows.shape[1:], chunks=True)
    else:
        dset = h5f[name]
        start = dset.shape[0]
        dset.resize((start + rows.shape[0],) + dset.shape[1:])
        dset[start:] = rows
    return dset.shape[0]


def main(csv_file, output_h5_file, embedding_dim, max_length, batch_size=100):
    """Encode sequences from ``csv_file`` with ESM and store them in ``output_h5_file``.

    Sequences longer than ``max_length`` are skipped. Embeddings are appended batch by
    batch to one resizable dataset named ``'embeddings'``, so the full output never has
    to be held in memory at once. Any existing output file is overwritten.

    Args:
        csv_file: Path to a CSV file containing a ``Sequence`` column.
        output_h5_file: Destination HDF5 path (opened with mode ``'w'``).
        embedding_dim: Passed through to ``EsmWrapper``.
        max_length: Longest sequence kept; also passed through to ``EsmWrapper``.
        batch_size: Number of sequences given to each ``encode`` call.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got {}".format(batch_size))

    # Set device for model computation
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load ESM wrapper
    esm_model = EsmWrapper(embedding_dim, max_length, device)

    sequences = _load_sequences(csv_file, max_length)

    with h5py.File(output_h5_file, 'w') as h5f:
        if not sequences:
            # The (empty) output file is still created, as before; just say why it is empty.
            print("No sequences of length <= {} found; no embeddings stored".format(max_length))
            return

        for start in range(0, len(sequences), batch_size):
            chunk = sequences[start:start + batch_size]

            # Inference only: disable autograd bookkeeping. This is harmless if
            # encode already does so internally.
            with torch.no_grad():
                embeddings = esm_model.encode(chunk)

            stored = _append_rows(h5f, 'embeddings', embeddings.cpu().numpy())

            # Report the real running total; the final batch may be shorter than batch_size.
            print("{} embeddings stored".format(stored))

if __name__ == "__main__":
    # Command-line entry point. Every option has a default, so the script can run bare.
    parser = argparse.ArgumentParser(description='Generate embeddings for sequences from a CSV file and store in HDF5 format.')
    parser.add_argument('--csv_file', type=str, default='data/generative-model-data/generative-model-dataset.csv',
                        help='Path to the input CSV file')
    parser.add_argument('--output_h5_file', type=str, default='data/generative-model-data/generative-model-embeddings-esm2.h5',
                        help='Output HDF5 file path')
    parser.add_argument('--embedding_dim', type=int, default=320,
                        help='Dimension size of embeddings (default: 320)')
    parser.add_argument('--max_length', type=int, default=40,
                        help='Maximum length of sequences (default: 40)')
    # Same default as main(), so omitting the flag keeps the previous behaviour.
    parser.add_argument('--batch_size', type=int, default=100,
                        help='Number of sequences encoded per batch (default: 100)')

    args = parser.parse_args()

    # Reject non-positive batch sizes here with a usage message. Otherwise a zero
    # would crash inside main() and a negative value would silently write nothing.
    if args.batch_size <= 0:
        parser.error('--batch_size must be a positive integer')

    main(args.csv_file, args.output_h5_file, args.embedding_dim, args.max_length,
         batch_size=args.batch_size)
