import os
import time
import torch
from typing import Any, Union, Dict, List
import numpy as np
from Bio.Seq import Seq
from utils.util import load_config
from utils.get_embed.constant import MODEL_MAP


def get_seqs_in_txts(folder_path: str) -> List[str]:
    """Collect the distinct non-empty lines of every ``.txt`` file under a folder.

    The folder is walked recursively. Each line is stripped of surrounding
    whitespace, and blank lines are ignored.

    Args:
        folder_path: Root directory to search.

    Returns:
        The distinct sequences, in the order they were first encountered.
        Directories and files are visited in sorted order, and insertion
        order is kept. The result, and so the downstream batching, is
        reproducible across runs. Iterating a ``set`` of strings would not
        be, because string hashes are randomized per process.
    """
    seqs: Dict[str, None] = {}  # used as an insertion-ordered set
    for root, dirs, files in os.walk(folder_path):
        # Sorting ``dirs`` in place fixes the order os.walk descends in.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith('.txt'):
                continue
            with open(os.path.join(root, file), 'r') as f:
                for line in f:
                    seq = line.strip()
                    if seq:
                        # Re-assigning an existing key keeps its original position.
                        seqs[seq] = None
    print(f'Loaded {len(seqs)} fixed sequences from {folder_path}')
    return list(seqs)


class AA_filebase:
    """Folder-backed store that maps sequences to individual ``.npy`` embedding files.

    Each embedding is saved as its own ``.npy`` file inside ``save_path``.
    A sibling CSV, ``<save_path>_index.csv`` with header ``seq,filename``,
    maps each sequence to its file name relative to ``save_path``. The index
    is read fully into memory on construction and appended to by ``update``.
    """

    def __init__(self, save_path: str) -> None:
        self.addr = save_path

        # Create the save folder if needed. exist_ok avoids the race between
        # an existence check and the mkdir.
        os.makedirs(self.addr, exist_ok=True)

        # The index file sits next to the folder, not inside it.
        self.index_path = self.addr.rstrip('/') + '_index.csv'
        if not os.path.exists(self.index_path):
            with open(self.index_path, 'w') as f:
                f.write('seq,filename\n')

        # Load the index. Skip the header and any blank lines, e.g. a stray
        # trailing newline.
        self.index: Dict[str, str] = {}
        with open(self.index_path, 'r') as f:
            next(f, None)  # header row
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # Split only on the first comma. Sequences contain no commas,
                # and this tolerates commas inside a filename.
                seq, filename = line.split(',', 1)
                self.index[seq] = filename
        self.top = len(self.index)
        print(f'Loaded {self.top} embeddings from {self.addr}')

    def __len__(self) -> int:
        return len(self.index)

    def update(self, data: Dict[str, Dict[str, Any]]) -> Dict:
        """Persist embeddings for sequences not yet in the index.

        Args:
            data: Maps sequence -> ``{'embed': tensor_or_array, 'filename': optional}``.
                Sequences already indexed are skipped. Without ``filename``,
                a name built from a counter and the current time is used.

        Returns:
            A status dict. The message counts only the embeddings actually
            written by this call.
        """
        saved = 0
        with open(self.index_path, 'a') as index_file:
            for seq, log in data.items():
                if seq in self.index:
                    continue
                self.top += 1
                filename = log.get('filename', f"{self.top}_{time.time()}.npy")
                if not filename.endswith('.npy'):
                    # np.save appends '.npy' itself. Record the real on-disk
                    # name so fetch() can find the file.
                    filename += '.npy'

                embed = log['embed']
                if isinstance(embed, torch.Tensor):
                    # .numpy() requires a CPU tensor that does not require grad.
                    embed = embed.detach().cpu()
                    if embed.dtype == torch.bfloat16:
                        # NumPy has no bfloat16; upcast to float32 (lossless).
                        embed = embed.to(torch.float32)
                    embed = embed.numpy()

                np.save(os.path.join(self.addr, filename), embed)
                # Index the entry only after its file exists on disk.
                self.index[seq] = filename
                index_file.write(f"{seq},{filename}\n")
                # Flush per entry so an interrupted run keeps the index in step
                # with the files already written.
                index_file.flush()
                saved += 1
        return {"status": "success", "message": f"Saved {saved} embeddings"}

    def fetch(self, keys: Union[str, List]) -> Dict[str, Dict[str, Any]]:
        """Load stored embeddings for ``keys`` (a single sequence or a list).

        Unknown keys are silently omitted. Each found key maps to
        ``{'embed': torch.Tensor}``.
        """
        if isinstance(keys, str):
            keys = [keys]
        result = {}
        for key in keys:
            filename = self.index.get(key)
            if filename is None:
                continue
            file_path = os.path.join(self.addr, filename)
            assert os.path.exists(file_path), f'BUGs in os.path.exits: {key}={file_path}'
            result[key] = dict(embed=torch.tensor(np.load(file_path)))
        return result


def _prepare_seqs(seqs: List[str], seq_type: str, known) -> List[str]:
    """Convert raw input lines to ``seq_type``; keep unique ones not in ``known``.

    Args:
        seqs: Raw sequences as read from the dataset. They are treated as DNA;
            'RNA' replaces T with U and 'AA' translates with Biopython.
        seq_type: One of 'DNA', 'RNA' or 'AA'.
        known: Container of already-stored sequences (anything supporting ``in``).

    Returns:
        Converted sequences, de-duplicated in first-seen order and excluding
        those in ``known``.

    Raises:
        ValueError: If ``seq_type`` is not recognised.
    """
    if seq_type == 'DNA':
        converted = seqs
    elif seq_type == 'RNA':
        converted = [seq.replace('T', 'U') for seq in seqs]
    elif seq_type == 'AA':
        # Translate each sequence once. Distinct DNA can map to the same
        # protein, hence the de-duplication below.
        converted = [str(Seq(seq).translate()) for seq in seqs]
    else:
        raise ValueError(f'Invalid sequence type: {seq_type}')
    # dict.fromkeys de-duplicates while keeping order, so batches are reproducible.
    return [seq for seq in dict.fromkeys(converted) if seq not in known]


@torch.no_grad()
def extract_embed(cfg):
    """Extract and cache embeddings for every (model, sequence type) pair in ``cfg``.

    Reads sequences from the ``.txt`` files under ``cfg.dataset_path``. For
    each model in ``cfg.model_names`` and each type in ``cfg.seq_type``
    (each a single string or a list), it embeds the not-yet-stored sequences
    in batches of ``cfg.bs``. Results are saved under
    ``{cfg.root}/{seq_type}_{model basename}``. If ``cfg.use_jax`` is set,
    the JAX-backed store is used instead.
    """
    model_names = [cfg.model_names] if isinstance(cfg.model_names, str) else cfg.model_names
    seq_types = [cfg.seq_type] if isinstance(cfg.seq_type, str) else cfg.seq_type

    # The raw dataset does not depend on the model or sequence type, so read it once.
    raw_seqs = get_seqs_in_txts(cfg.dataset_path)

    for model_name in model_names:
        for seq_type in seq_types:
            print(f'\nExtracting embeddings for {model_name} for {seq_type}')
            store_path = f'{cfg.root}/{seq_type}_{model_name.split("/")[-1]}'
            if getattr(cfg, 'use_jax', False):
                from utils.get_embed.nucleotidetransformer import AA_filebase_jax
                filebase = AA_filebase_jax(store_path)
            else:
                filebase = AA_filebase(store_path)

            seqs = _prepare_seqs(raw_seqs, seq_type, filebase.index)
            print(f'Loaded {len(seqs)} unique unseen sequences to extract embeddings.')
            if not seqs:
                print(f'All sequences have been processed. Skipping {model_name}.')
                continue

            model = MODEL_MAP(model_name, seq_type, cfg)
            t0 = time.time()
            for i in range(0, len(seqs), cfg.bs):
                t1 = time.time()
                batch_seqs = seqs[i: i + cfg.bs]
                embeddings = model(batch_seqs).cpu()
                # Index rows explicitly (not zip) so a short model output still
                # raises instead of silently dropping sequences.
                filebase.update({seq: dict(embed=embeddings[j]) for j, seq in enumerate(batch_seqs)})
                done = i + len(batch_seqs)
                elapsed = time.time() - t0
                eta = elapsed / done * (len(seqs) - done)
                print(f"[{model_name}-{seq_type} {done}/{len(seqs)}] Time taken: {time.time() - t1:.2f} seconds, " +
                      f"Elapsed time: {elapsed/3600:.2f} hours, ETA: {eta/3600:.2f} hours")
            # Drop our reference before the next model is built, so its memory
            # can be reclaimed (unless MODEL_MAP keeps its own reference).
            del model


if __name__ == "__main__":
    # Dispatch to the module-level function named by the config's ``task`` field.
    cfg = load_config()
    task_fn = globals().get(cfg.task)
    if not (task_fn and callable(task_fn)):
        print(f"Function '{cfg.task}' not found or not callable.")
    else:
        task_fn(cfg)
