import os
import gc
import torch
import numpy as np
import random

# Seed the stdlib, NumPy and PyTorch random number generators up front so
# repeated runs start from the same RNG state.
seed = 7
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

from tqdm.auto import tqdm
from deli import load
from transformers import AutoTokenizer, AutoModelForMaskedLM


def get_tokens(seqs, tokenizer):
    """Tokenize a list of sequences and return their token-id lists.

    Calls the tokenizer directly, which is the documented replacement for the
    deprecated ``batch_encode_plus``. The tokenizer's default settings are
    used, so no padding or truncation is requested here.

    Args:
        seqs: List of sequence strings.
        tokenizer: Hugging Face-style tokenizer. Calling it must return a
            mapping with an ``'input_ids'`` entry.

    Returns:
        One list of token ids per input sequence, in input order.
    """
    return tokenizer(seqs)['input_ids']


def get_embeddings(tokens, esm_model, device):
    """Return one mean-pooled embedding per token sequence.

    Each sequence is run through ``esm_model`` on its own (batch size 1). The
    last entry of the returned ``hidden_states`` is averaged over every
    position except the first and the last. This presumably drops the special
    tokens the tokenizer adds at both ends (CLS/EOS for the ESM tokenizers
    used in this file); confirm this if another tokenizer is used.

    Args:
        tokens: Iterable of token-id sequences (lists of ints), e.g. the
            output of ``get_tokens``.
        esm_model: Model called as ``esm_model(input_ids,
            output_hidden_states=True)``. It must return a mapping whose
            ``'hidden_states'`` entry is a sequence of tensors.
        device: Device the input ids are created on (should match the model).

    Returns:
        ``np.ndarray`` with one row per input sequence. For empty input, an
        array of shape ``(0, hidden_size)`` is returned. ``hidden_size`` is
        read from ``esm_model.config.hidden_size`` when present, otherwise 0,
        and the dtype is float32.
    """
    # Materialise once so an empty input can be detected before starting the
    # progress bar; torch.stack would otherwise fail on an empty list.
    tokens = list(tokens)
    if not tokens:
        hidden_size = getattr(getattr(esm_model, 'config', None), 'hidden_size', 0)
        return np.empty((0, hidden_size), dtype=np.float32)

    embeddings = []
    with torch.no_grad():
        for i, token_ids in enumerate(tqdm(tokens)):
            # Periodically release cached allocator memory and collect
            # garbage during long runs.
            if i and i % 1000 == 0:
                torch.cuda.empty_cache()
                gc.collect()
            # Build the (1, seq_len) input directly on the target device
            # instead of allocating on the CPU and copying.
            input_ids = torch.tensor(token_ids, device=device).unsqueeze(0)
            hidden = esm_model(input_ids, output_hidden_states=True)['hidden_states'][-1]
            # Mean over positions 1..-2 of the single batch row -> one vector.
            embeddings.extend(hidden[:, 1:-1].mean(dim=1).cpu())

    return torch.stack(embeddings).numpy()


def get_embeddings_residue(tokens, esm_model, device):
    """Embed each token sequence and keep one vector per position.

    Sequences are fed to ``esm_model`` one at a time (batch size 1). From the
    last entry of ``hidden_states``, the single batch row is taken with its
    first and last positions dropped; these are presumably the special tokens
    added by the tokenizer.

    Args:
        tokens: Iterable of token-id sequences (lists of ints).
        esm_model: Model called as ``esm_model(input_ids,
            output_hidden_states=True)``.
        device: Device the input ids are moved to.

    Returns:
        A list of CPU tensors, one per input sequence. Assuming hidden states
        are ``(batch, seq_len, hidden)``, each tensor is ``(seq_len - 2, hidden)``.
    """
    per_residue = []

    with torch.no_grad():
        for step, token_ids in enumerate(tqdm(tokens)):
            # Every 1000 sequences (but not at the start), free cached GPU
            # memory and run the garbage collector.
            if step != 0 and step % 1000 == 0:
                torch.cuda.empty_cache()
                gc.collect()
            input_ids = torch.tensor(token_ids).to(device).unsqueeze(0)
            outputs = esm_model(input_ids, output_hidden_states=True)
            last_layer = outputs['hidden_states'][-1]
            per_residue.append(last_layer[0, 1:-1].cpu())

    return per_residue


def _resolve_checkpoint(model_type):
    """Map a short model alias to its Hugging Face ESM-2 checkpoint name."""
    checkpoints = {
        'hf_esm_6': 'facebook/esm2_t6_8M_UR50D',
        'hf_esm_12': 'facebook/esm2_t12_35M_UR50D',
        'hf_esm_33': 'facebook/esm2_t33_650M_UR50D',
    }
    try:
        return checkpoints[model_type]
    except KeyError:
        raise ValueError(
            f'Model {model_type} not found; expected one of {sorted(checkpoints)}'
        ) from None


def _count_parameters(model):
    """Return ``(trainable, total)`` parameter counts of a torch module."""
    trainable = 0
    total = 0
    for param in model.parameters():
        total += param.numel()
        if param.requires_grad:
            trainable += param.numel()
    return trainable, total


def _prepare_sequences(all_data, reduce_to_unique_sequences):
    """Join each dataset entry into a string and return ``(sequences, names)``.

    ``names`` are the dataset keys when not de-duplicating. Otherwise they are
    the unique sequence strings themselves, and ``sequences is names``.
    Assumes ``all_data`` is a dict-like mapping whose values are iterables of
    strings; verify against the ``deli.load`` output.
    """
    sequences = [''.join(seq) for seq in all_data.values()]
    if not reduce_to_unique_sequences:
        return sequences, list(all_data.keys())

    print('Reducing to unique sequences')
    print('Number of sequences:', len(all_data))
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # output order is reproducible. Iteration order of a set of str is not,
    # because string hashing is randomized per process.
    unique = list(dict.fromkeys(sequences))
    print('Number of unique sequences:', len(unique))
    return unique, unique


def main(dataset_sequence_path, save_emb_path, model_type, emb_type,
         reduce_to_unique_sequences=False):
    """Embed every sequence of a dataset with an ESM-2 model and save them.

    Args:
        dataset_sequence_path: Path passed to ``deli.load``. The loaded object
            is used as a mapping whose values are joined into sequence strings.
        save_emb_path: Output directory; created if it does not exist.
        model_type: One of ``'hf_esm_6'``, ``'hf_esm_12'``, ``'hf_esm_33'``.
        emb_type: ``'protein'`` for one mean-pooled vector per sequence. Any
            other value produces per-residue embeddings.
        reduce_to_unique_sequences: If True, embed each distinct sequence once
            and key the output by the sequence string instead of the dataset key.

    Raises:
        ValueError: If ``model_type`` is not a known alias.
    """
    # Validate the alias before touching the filesystem or loading weights.
    model_checkpoint = _resolve_checkpoint(model_type)
    os.makedirs(save_emb_path, exist_ok=True)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Available device:', device)

    model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    model.eval()
    model.to(device=device)
    print('Model loaded')

    num_params_trainable, num_params_all = _count_parameters(model)
    print('Trainable parameters:', num_params_trainable)
    print('All parameters:', num_params_all)

    all_data = load(dataset_sequence_path)
    print('Sequences loaded')

    prepared_sequences, names = _prepare_sequences(all_data, reduce_to_unique_sequences)

    tokens = get_tokens(prepared_sequences, tokenizer)
    if emb_type == 'protein':
        embeddings = get_embeddings(tokens=tokens, esm_model=model, device=device)
    else:
        # Any value other than 'protein' selects per-residue embeddings.
        embeddings = get_embeddings_residue(tokens=tokens, esm_model=model, device=device)

    # Guard the shape report so an empty dataset does not raise IndexError.
    if len(embeddings):
        print(type(embeddings), len(embeddings), embeddings[0].shape)
    else:
        print(type(embeddings), len(embeddings))
    print('names', len(names))

    save_data = dict(zip(names, embeddings))
    # NOTE(review): the file name depends only on model_type, so a 'protein'
    # run and a per-residue run with the same model overwrite each other.
    torch.save(save_data, os.path.join(save_emb_path, f'{model_type}_embeddings.pt'))


if __name__ == "__main__":
    # Run configuration: replace the placeholder paths before running.
    run_config = {
        'dataset_sequence_path': '<path_to_dataset_sequence_path>',
        'save_emb_path': '<path_to_save_embeddings_path>',
        'model_type': 'hf_esm_12',
        'emb_type': 'residue',
        'reduce_to_unique_sequences': False,
    }
    main(**run_config)
