from typing import List, Optional, Union

import torch
from transformers import AutoTokenizer, AutoModel
from transformers.models.bert.configuration_bert import BertConfig


class DNABERTModel:
    """Inference wrapper that embeds DNA sequences with a DNABERT checkpoint.

    The checkpoint is selected by its Hugging Face Hub name, which must be
    one of ``model_names``. Names ending in a digit (``DNA_bert_3`` ..
    ``DNA_bert_6``) are treated as k-mer models with ``kmer`` set to that
    digit; for all other names ``kmer`` is ``None``.

    Attributes set by ``__init__``: ``tokenizer``, ``model`` (moved to
    ``device`` and put in eval mode), ``device`` and ``kmer``.
    """

    # Hub identifiers accepted by the constructor.
    model_names = [
        "zhihan1996/DNABERT-2-117M",
        "zhihan1996/DNABERT-S",
        "zhihan1996/DNA_bert_3",
        "zhihan1996/DNA_bert_4",
        "zhihan1996/DNA_bert_5",
        "zhihan1996/DNA_bert_6",
    ]

    def __init__(self, model_name: str, device: Optional[Union[str, torch.device]] = None):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: One of the entries in ``model_names``.
            device: Device to run on. When ``None`` (the default), CUDA is
                used if ``torch.cuda.is_available()`` and CPU otherwise.

        Raises:
            ValueError: If ``model_name`` is not in ``model_names``.
        """
        # Explicit raise instead of ``assert``: asserts are stripped under -O.
        if model_name not in self.model_names:
            raise ValueError(f"Model name {model_name} not found in {self.model_names}")

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # NOTE(review): trust_remote_code=True lets the Hub repository supply
        # Python code that is executed locally; only the names whitelisted in
        # ``model_names`` can reach this point.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if model_name == "zhihan1996/DNABERT-S":
            model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        else:
            # Every other checkpoint is loaded with an explicit BertConfig.
            config = BertConfig.from_pretrained(model_name)
            model = AutoModel.from_pretrained(model_name, trust_remote_code=True, config=config)
        self.model = model.to(self.device)
        self.model.eval()

        # A trailing digit in the checkpoint name is taken as the k-mer size.
        self.kmer = int(model_name[-1]) if model_name[-1].isdigit() else None

    def __call__(self, seqs: List[str]):
        """Embed ``seqs`` and return the embeddings as a tensor on ``self.device``.

        Runs under ``torch.no_grad()``, so the result carries no autograd
        graph.

        * k-mer models (``self.kmer`` is set): each sequence is cut into
          consecutive, non-overlapping chunks of ``self.kmer`` characters
          (the final chunk may be shorter), the chunks are joined by spaces,
          and the whole batch is tokenized with padding. The result is
          ``last_hidden_state`` with the first and last positions removed.
          NOTE(review): because the batch is padded, the last position is
          presumably the end-of-sequence token only for the longest
          sequence; shorter rows likely still contain their separator and
          padding positions - confirm before pooling over this output.
        * Otherwise: every sequence is tokenized and run on its own, the
          first model output is averaged over dimension 1, and the
          per-sequence results are concatenated along dimension 0.

        Args:
            seqs: DNA sequences to embed.

        Returns:
            A ``torch.Tensor`` of embeddings as described above.
        """
        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            if self.kmer is not None:
                step = self.kmer
                seqs_kmer = [
                    " ".join(seq[i : i + step] for i in range(0, len(seq), step))
                    for seq in seqs
                ]
                inputs = self.tokenizer(seqs_kmer, return_tensors="pt", padding=True).to(self.device)
                outputs = self.model(**inputs)
                return outputs.last_hidden_state[:, 1:-1]

            embeds = []
            # as DNABERT-2/S uses BPE, we need to tokenize each sequence separately
            for seq in seqs:
                inputs = self.tokenizer(seq, return_tensors="pt").to(self.device)
                outputs = self.model(**inputs)
                embeds.append(outputs[0].mean(1))
            return torch.cat(embeds, dim=0)
