import torch
from evo2 import Evo2
import h5py
import numpy as np
from transformers import AutoTokenizer, AutoConfig

# Script purpose: run every sample of an HDF5 dataset through Evo2 and store
# the resulting logits (cast to float16, gzip-compressed) in a new HDF5 file,
# one group per sample.

# NOTE: `tokenizer` is only referenced by the commented-out transcript-dataset
# path inside the loop below; the active path reads the raw sequence string
# straight from the HDF5 attributes.
tokenizer = AutoTokenizer.from_pretrained('/home/jovyan/dnalm/data/tokenizers/t2t_1000h_multi_32k/')

evo2_model = Evo2('evo2_1b_base')

# Layer whose activations are requested from the model alongside the logits.
# Loop-invariant, so it is defined once here rather than per sample.
layer_name = 'blocks.23.mlp.l3'

# Alternate input/output pair (MANE transcript dataset):
# input_path = '/home/jovyan/dnalm/datasets/mane_transcript_train_dataset_max_exon_cds.hdf5'
# output_path = 'mane_transcript_train_dataset_max_exon_cds_evo2_embeddings_compressed_length_no_greater_32k.h5'
input_path = '/home/jovyan/evo2_experiments/32k_human_chr21.hdf5'
output_path = 'minja_human_chr21_t2t_evo2_mlm_embeddings.h5'

# Both files are context-managed so the read handle is released (and the
# output is flushed/closed) even if a sample fails midway.
with h5py.File(input_path, "r") as dataset_file, h5py.File(output_path, 'w') as f:

    # Samples are addressed by index ("sample_0", "sample_1", ...) rather than
    # by iterating the keys, which keeps the processing order numeric.
    # Assumes every top-level key follows that naming scheme.
    for i in range(len(dataset_file)):

        print(f'Processing sample: {i}')

        # Alternate (transcript dataset) way of obtaining the sequence:
        # sample_name = "transcript_" + str(i)
        # token_atcg = np.array(dataset_file[sample_name]["token_atcg"])[:32000]
        # sequence = ''.join(list(tokenizer.decode(token_atcg)))
        sample_name = "sample_" + str(i)
        sequence = dataset_file[sample_name].attrs['seq']

        # Tokenize once and reuse the result for both the sanity check and
        # the model input.
        token_ids = evo2_model.tokenizer.tokenize(sequence)

        # The logits are expected to line up one-to-one with sequence
        # positions. An explicit raise is used instead of `assert` so the
        # check still runs under `python -O`.
        if len(token_ids) != len(sequence):
            raise ValueError(
                f'{sample_name}: tokenized length {len(token_ids)} does not '
                f'match sequence length {len(sequence)}'
            )

        input_ids = torch.tensor(
            token_ids,
            dtype=torch.int,
        ).unsqueeze(0).to('cuda:0')

        # Inference only: disable autograd so no graph is retained on the GPU.
        # NOTE(review): presumably Evo2's forward does not already do this
        # itself; nesting no_grad is harmless either way.
        with torch.no_grad():
            outputs, embeddings = evo2_model(input_ids, return_embeddings=True, layer_names=[layer_name])
        logits = outputs[0]

        # Go through float32 on the way to NumPy, then store as float16 to
        # halve the on-disk size.
        # embeddings_fp16 = embeddings[layer_name].to(dtype=torch.float32).detach().cpu().numpy().astype(np.float16)
        logits_fp16 = logits.to(dtype=torch.float32).detach().cpu().numpy().astype(np.float16)

        print(f'Saving sample: {i}')
        group = f.create_group(sample_name)
        # group.create_dataset('embeddings', data=embeddings_fp16, compression='gzip', compression_opts=4) # 4 is default
        group.create_dataset('logits', data=logits_fp16, compression='gzip', compression_opts=4) # 4 is default

# print('Logits: ', logits)
# print('Shape (batch, length, vocab): ', logits.shape)
# print('Embeddings shape: ', embeddings[layer_name].shape)