import torch
from evo2 import Evo2
import h5py
import numpy as np
from transformers import AutoTokenizer, AutoConfig
import math

# Local directory holding the pretrained HuggingFace tokenizer files.
TOKENIZER_DIR = '/home/jovyan/dnalm/data/tokenizers/t2t_1000h_multi_32k/'
# Name of the Evo2 checkpoint to instantiate.
EVO2_CHECKPOINT = 'evo2_1b_base'

# NOTE(review): `tokenizer` is not referenced anywhere else in the visible
# part of this script (the Evo2 model's own tokenizer is used instead) --
# confirm whether this load is still needed.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)

evo2_model = Evo2(EVO2_CHECKPOINT)

# dataset_file = h5py.File('/home/jovyan/dnalm/datasets/mane_transcript_train_dataset_max_exon_cds.hdf5', "r")
# Input HDF5 file: one group per sample named "transcript_<index>", whose
# attributes 'transcript_seq' and 'gene_seq' hold the raw sequences.
INPUT_PATH = '/home/jovyan/evo2_experiments/sequences.hdf5'

# with h5py.File('mane_transcript_train_dataset_max_exon_cds_evo2_embeddings_compressed_length_no_greater_32k.h5', 'w') as f:
OUTPUT_PATH = 'sequences_exon_gene_level_embeddings.h5'

# Model layer whose activations are exported as embeddings.
EMBEDDING_LAYER = 'blocks.23.mlp.l3'
# Maximum number of characters sent through the model in one forward pass.
MAX_CHUNK_LEN = 32000
DEVICE = 'cuda:0'

# (output dataset name, input attribute name) pairs processed for each sample.
SEQUENCE_FIELDS = (
    ('transcript_embeddings', 'transcript_seq'),
    ('gene_embeddings', 'gene_seq'),
)


def _embed_sequence(model, sequence, layer_name=EMBEDDING_LAYER,
                    chunk_len=MAX_CHUNK_LEN, device=DEVICE):
    """Return the activations of `layer_name` for `sequence` as a float16 array.

    The sequence is cut into consecutive pieces of at most `chunk_len`
    characters. Each piece is run through `model` in a separate forward pass
    (so no context is shared across piece boundaries) and the per-piece
    activations are concatenated along dim 1.

    Args:
        model: callable Evo2 model exposing `tokenizer.tokenize` and accepting
            `return_embeddings` / `layer_names` keyword arguments.
        sequence: raw sequence to embed.
        layer_name: name of the layer whose output is collected.
        chunk_len: maximum number of characters per forward pass.
        device: torch device the token ids are moved to.

    Returns:
        numpy.ndarray of dtype float16 with all chunk activations joined along
        axis 1 (presumably the sequence axis of a (batch, length, hidden)
        tensor -- TODO confirm against the Evo2 output shape).

    Raises:
        ValueError: if the sequence is empty, or if the tokenizer does not
            produce exactly one token per character (character-based chunking
            would then not correspond to token-based chunking).
    """
    if len(sequence) == 0:
        raise ValueError('Cannot embed an empty sequence')
    if len(sequence) != len(model.tokenizer.tokenize(sequence)):
        raise ValueError(
            'Tokenizer did not produce one token per character; '
            'character-based chunking is not valid for this sequence'
        )

    chunk_embeddings = []
    # Inference only: make sure no autograd state is recorded.
    with torch.no_grad():
        for start in range(0, len(sequence), chunk_len):
            input_ids = torch.tensor(
                model.tokenizer.tokenize(sequence[start:start + chunk_len]),
                dtype=torch.int,
            ).unsqueeze(0).to(device)

            _, embeddings = model(input_ids, return_embeddings=True, layer_names=[layer_name])

            # Cast to float32 before leaving torch (presumably because the
            # model emits a dtype numpy cannot represent, e.g. bfloat16 --
            # TODO confirm); the final array is down-cast to float16 below.
            chunk_embeddings.append(
                embeddings[layer_name].to(dtype=torch.float32).detach().cpu()
            )

    return torch.cat(chunk_embeddings, dim=1).numpy().astype(np.float16)


with h5py.File(INPUT_PATH, "r") as dataset_file, h5py.File(OUTPUT_PATH, 'w') as f:

    # Samples are addressed by position: group names are "transcript_0",
    # "transcript_1", ... up to the number of groups in the input file.
    for sample_idx in range(len(dataset_file)):

        print(f'Processing sample: {sample_idx}')

        sample_name = "transcript_" + str(sample_idx)
        sample_attrs = dataset_file[sample_name].attrs

        # Compute every embedding before touching the output file so that a
        # failure does not leave a partially written group behind.
        sample_embeddings = {
            dataset_name: _embed_sequence(evo2_model, sample_attrs[attr_name])
            for dataset_name, attr_name in SEQUENCE_FIELDS
        }

        print(f'Saving sample: {sample_idx}')
        group = f.create_group(sample_name)
        for dataset_name, data in sample_embeddings.items():
            group.create_dataset(dataset_name, data=data, compression='gzip', compression_opts=4)  # 4 is default

# print('Logits: ', logits)
# print('Shape (batch, length, vocab): ', logits.shape)
# print('Embeddings shape: ', embeddings[layer_name].shape)