import os
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_model, download_ckpt_and_hyperparams
from typing import List, Dict, Any, Union
import numpy as np
import time


class AA_filebase_jax:
    """On-disk store mapping sequence strings to embedding arrays.

    Every embedding is written as ``<n>.npy`` inside ``save_path``. A sibling
    CSV file, ``<save_path>_index.csv`` (header ``seq,filename``), records
    which file belongs to which sequence. The index is read into memory on
    construction and appended to whenever a new sequence is stored.

    Attributes are plain instance state:
      * ``addr``       -- directory holding the ``.npy`` files.
      * ``index_path`` -- path of the CSV index file.
      * ``index``      -- dict ``seq -> filename`` mirroring the CSV.
      * ``top``        -- highest file number handed out so far.
    """

    def __init__(self, save_path: str) -> None:
        """Open the store rooted at ``save_path``, creating it if missing."""
        self.addr = save_path

        # Create save folder if not exists (exist_ok tolerates another
        # process creating it between the check and the call).
        if not os.path.exists(self.addr):
            os.makedirs(self.addr, exist_ok=True)

        # Create index file if not exists
        self.index_path = self.addr.rstrip('/') + '_index.csv'
        if not os.path.exists(self.index_path):
            with open(self.index_path, 'w') as f:
                f.write('seq,filename\n')

        # Load index file
        self.index = dict()
        highest_number = 0
        with open(self.index_path, 'r') as f:
            next(f, None)  # skip the header row
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                # Filenames never contain a comma, so split from the right;
                # this keeps sequences that themselves contain commas intact.
                seq, filename = line.rsplit(',', 1)
                self.index[seq] = filename
                stem = os.path.splitext(filename)[0]
                if stem.isdecimal():
                    highest_number = max(highest_number, int(stem))
        # With duplicate rows or gaps, len(index) can be smaller than the
        # largest number already used; starting from the maximum guarantees a
        # new entry never overwrites an existing .npy file.
        self.top = max(len(self.index), highest_number)

    def __len__(self,) -> int:
        """Return the number of distinct sequences in the index."""
        return len(self.index)

    def update(self, data: Dict[str, Dict[str, Any]]) -> Dict:
        """Store the ``'embed'`` entry of every sequence not yet indexed.

        Args:
            data: mapping ``seq -> record``; only ``record['embed']`` is read.
                Sequences already present in the index are skipped.

        Returns:
            ``{"status": "success", "message": "Saved <k> embeddings"}`` where
            ``k`` is the number of embeddings actually written by this call.
        """
        saved = 0
        for seq, log in data.items():
            if seq in self.index:
                continue
            embed = jax.device_get(log['embed'])
            number = self.top + 1
            filename = f"{number}.npy"

            # Write the array first so the index never points at a file that
            # failed to be written; commit in-memory state only afterwards.
            np.save(os.path.join(self.addr, filename), embed)
            with open(self.index_path, 'a') as f:
                f.write(f"{seq},{filename}\n")
            self.index[seq] = filename
            self.top = number
            saved += 1
        return {"status": "success", "message": f"Saved {saved} embeddings"}

    def fetch(self, keys: Union[str, List[str]]) -> Dict[str, Dict[str, Any]]:
        """Load the stored embeddings for ``keys``.

        Args:
            keys: a single sequence or a list of sequences.

        Returns:
            Mapping ``seq -> {'embed': jax array}`` containing only the keys
            that are present in the index; unknown keys are silently omitted.

        Raises:
            AssertionError: if the index references a file that is missing on
                disk (an internal-consistency check).
        """
        if isinstance(keys, str):
            keys = [keys]
        result = {}
        for key in keys:
            filename = self.index.get(key)
            if filename is None:
                continue
            file_path = os.path.join(self.addr, filename)
            assert os.path.exists(file_path), f'BUGs in os.path.exits: {key}={file_path}'
            result[key] = dict(embed=jax.numpy.array(np.load(file_path)))
        return result


class NucleotideTransformerModel:
    """Callable wrapper around a pretrained Nucleotide Transformer checkpoint.

    Construction downloads the checkpoint's hyperparameters, builds the model
    with ``get_pretrained_model`` and wraps its forward function with
    ``hk.transform``. Calling the instance tokenizes a batch of sequences and
    returns the embeddings saved for layer ``hyperparams['num_layers']``.
    """

    # Accepted names: the upstream checkpoint name with an "NT" prefix.
    model_names = [    
        "NT50M_multi_species_v2",
        "NT100M_multi_species_v2",
        "NT250M_multi_species_v2",
        "NT500M_multi_species_v2",
        "NT500M_human_ref",
        "NT500M_1000G",
        "NT2B5_1000G",
        "NT2B5_multi_species",
        "NT1B_agro_nt",
    ]

    def __init__(self, model_name: str, max_positions: int = 87):
        """Load the checkpoint identified by ``model_name``.

        Args:
            model_name: one of ``model_names``.
            max_positions: forwarded unchanged to ``get_pretrained_model`` as
                its ``max_positions`` argument (presumably the tokenized
                length sequences are padded to -- confirm against the library
                documentation). Defaults to 87, the previously fixed value.

        Raises:
            AssertionError: if ``model_name`` is not in ``model_names``.
        """
        # Explicit raise instead of an `assert` statement so the check still
        # runs under `python -O`, before any download is attempted.
        if model_name not in self.model_names:
            raise AssertionError(
                f"Model name {model_name} not found in {self.model_names}"
            )
        self.random_key = jax.random.PRNGKey(0)

        # Drop the leading "NT" to obtain the name the library expects.
        checkpoint_name = model_name[len("NT"):]
        # Get pretrained model; the hyperparameters are fetched first because
        # the layer count is needed to select which embeddings to save.
        _, self.hyperparams = download_ckpt_and_hyperparams(checkpoint_name)
        self.parameters, forward_fn, self.tokenizer, _ = get_pretrained_model(
            model_name=checkpoint_name,
            embeddings_layers_to_save=(self.hyperparams['num_layers'],),
            max_positions=max_positions,
        )
        self.forward_fn = hk.transform(forward_fn)

    def __call__(self, seqs: List[str]):
        """Return the saved-layer embeddings for a batch of sequences.

        Args:
            seqs: sequences passed to ``tokenizer.batch_tokenize``.

        Returns:
            The ``embeddings_<num_layers>`` entry of the forward function's
            output for the int32 token-id batch.
        """
        # batch_tokenize yields one item per sequence; element 1 of each item
        # is used as that sequence's token ids.
        tokens_ids = [item[1] for item in self.tokenizer.batch_tokenize(seqs)]
        tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)
        outs = self.forward_fn.apply(self.parameters, self.random_key, tokens)
        return outs[f"embeddings_{self.hyperparams['num_layers']}"]
