import time
import os
from typing import List, Dict, Any, Union, Tuple
import torch
import torch.export
from torch import Tensor
from transformers import EsmModel
from utils.get_embed.extract import get_seqs_in_txts
from Bio.Seq import Seq
import numpy as np
from chai_lab.data.collate.collate import Collate
from chai_lab.data.collate.utils import AVAILABLE_MODEL_SIZES
from chai_lab.data.dataset.all_atom_feature_context import (
    MAX_MSA_DEPTH,
    MAX_NUM_TEMPLATES,
    AllAtomFeatureContext,
)
from chai_lab.data.dataset.constraints.constraint_context import ConstraintContext
from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext
from chai_lab.data.dataset.inference_dataset import Input, load_chains_from_raw
from chai_lab.data.dataset.msas.msa_context import MSAContext
from chai_lab.data.dataset.structure.all_atom_structure_context import (
    AllAtomStructureContext,
)
from chai_lab.data.dataset.templates.context import TemplateContext
from chai_lab.data.features.feature_factory import FeatureFactory
from chai_lab.data.features.feature_type import FeatureType
from chai_lab.data.features.generators.atom_element import AtomElementOneHot
from chai_lab.data.features.generators.atom_name import AtomNameOneHot
from chai_lab.data.features.generators.base import EncodingType
from chai_lab.data.features.generators.blocked_atom_pair_distances import (
    BlockedAtomPairDistances,
    BlockedAtomPairDistogram,
)
from chai_lab.data.features.generators.docking import DockingConstraintGenerator
from chai_lab.data.features.generators.esm_generator import ESMEmbeddings
from chai_lab.data.features.generators.identity import Identity
from chai_lab.data.features.generators.is_cropped_chain import ChainIsCropped
from chai_lab.data.features.generators.missing_chain_contact import MissingChainContact
from chai_lab.data.features.generators.msa import (
    IsPairedMSAGenerator,
    MSADataSourceGenerator,
    MSADeletionMeanGenerator,
    MSADeletionValueGenerator,
    MSAFeatureGenerator,
    MSAHasDeletionGenerator,
    MSAProfileGenerator,
)
from chai_lab.data.features.generators.ref_pos import RefPos
from chai_lab.data.features.generators.relative_chain import RelativeChain
from chai_lab.data.features.generators.relative_entity import RelativeEntity
from chai_lab.data.features.generators.relative_sep import RelativeSequenceSeparation
from chai_lab.data.features.generators.relative_token import RelativeTokenSeparation
from chai_lab.data.features.generators.residue_type import ResidueType
from chai_lab.data.features.generators.structure_metadata import (
    IsDistillation,
    TokenBFactor,
    TokenPLDDT,
)
from chai_lab.data.features.generators.templates import (
    TemplateDistogramGenerator,
    TemplateMaskGenerator,
    TemplateResTypeGenerator,
    TemplateUnitVectorGenerator,
)
from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.data.features.generators.token_dist_restraint import (
    TokenDistanceRestraint,
)
from chai_lab.data.features.generators.token_pair_pocket_restraint import (
    TokenPairPocketRestraint,
)
from chai_lab.utils.paths import chai1_component
from chai_lab.utils.tensor_utils import move_data_to_device, set_seed, und_self
from chai_lab.utils.typing import typecheck


class UnsupportedInputError(RuntimeError):
    """Raised when an input exceeds a limit enforced by this module (tokens, templates, MSA depth)."""


def load_exported(comp_key: str, device: torch.device) -> torch.nn.Module:
    """Load an exported CHAI-1 component as a module placed on ``device``.

    ``comp_key`` (e.g. ``"<size>/trunk.pt2"``) is resolved to a local file with
    ``chai1_component``; the file is read with ``torch.export.load`` and the
    resulting program is turned back into a callable module.
    """
    program = torch.export.load(chai1_component(comp_key))
    module = program.module()
    return module.to(device)


# %%
# Create feature factory
#
# One generator per named feature. The resulting ``feature_factory`` is handed to
# ``Collate`` in ``run_folding_on_context``.
# NOTE(review): the keys presumably have to match the feature names the exported
# ``feature_embedding`` module expects as keyword arguments; do not rename them
# without checking.
# The probability/dropout knobs below are pinned (mostly to 0.0). They presumably
# disable training-time randomization; TODO confirm.
feature_generators = {
    "RelativeSequenceSeparation": RelativeSequenceSeparation(sep_bins=None),
    "RelativeTokenSeparation": RelativeTokenSeparation(r_max=32),
    "RelativeEntity": RelativeEntity(),
    "RelativeChain": RelativeChain(),
    "ResidueType": ResidueType(
        min_corrupt_prob=0.0,
        max_corrupt_prob=0.0,
        num_res_ty=32,
        key="token_residue_type",
    ),
    "ESMEmbeddings": ESMEmbeddings(),  # TODO: this can probably be the identity
    # atom-level features
    "BlockedAtomPairDistogram": BlockedAtomPairDistogram(),
    "InverseSquaredBlockedAtomPairDistances": BlockedAtomPairDistances(
        transform="inverse_squared",
        encoding_ty=EncodingType.IDENTITY,
    ),
    "AtomRefPos": RefPos(),
    "AtomRefCharge": Identity(
        key="inputs/atom_ref_charge",
        ty=FeatureType.ATOM,
        dim=1,
        can_mask=False,
    ),
    "AtomRefMask": Identity(
        key="inputs/atom_ref_mask",
        ty=FeatureType.ATOM,
        dim=1,
        can_mask=False,
    ),
    "AtomRefElement": AtomElementOneHot(max_atomic_num=128),
    "AtomNameOneHot": AtomNameOneHot(),
    # template features
    "TemplateMask": TemplateMaskGenerator(),
    "TemplateUnitVector": TemplateUnitVectorGenerator(),
    "TemplateResType": TemplateResTypeGenerator(),
    "TemplateDistogram": TemplateDistogramGenerator(),
    # restraint / constraint features
    "TokenDistanceRestraint": TokenDistanceRestraint(
        include_probability=0.0,
        size=0.33,
        min_dist=6.0,
        max_dist=30.0,
        num_rbf_radii=6,
    ),
    "DockingConstraintGenerator": DockingConstraintGenerator(
        include_probability=0.0,
        structure_dropout_prob=0.75,
        chain_dropout_prob=0.75,
    ),
    "TokenPairPocketRestraint": TokenPairPocketRestraint(
        size=0.33,
        include_probability=0.0,
        min_dist=6.0,
        max_dist=20.0,
        coord_noise=0.0,
        num_rbf_radii=6,
    ),
    # MSA summaries, structure metadata and chain flags
    "MSAProfile": MSAProfileGenerator(),
    "MSADeletionMean": MSADeletionMeanGenerator(),
    "IsDistillation": IsDistillation(),
    "TokenBFactor": TokenBFactor(include_prob=0.0),
    "TokenPLDDT": TokenPLDDT(include_prob=0.0),
    "ChainIsCropped": ChainIsCropped(),
    "MissingChainContact": MissingChainContact(contact_threshold=6.0),
    # per-row MSA features
    "MSAOneHot": MSAFeatureGenerator(),
    "MSAHasDeletion": MSAHasDeletionGenerator(),
    "MSADeletionValue": MSADeletionValueGenerator(),
    "IsPairedMSA": IsPairedMSAGenerator(),
    "MSADataSource": MSADataSourceGenerator(),
}
feature_factory = FeatureFactory(feature_generators)


def raise_if_too_many_tokens(n_actual_tokens: int):
    """Raise UnsupportedInputError if the token count exceeds the largest model size.

    The limit is ``max(AVAILABLE_MODEL_SIZES)``.
    """
    limit = max(AVAILABLE_MODEL_SIZES)
    if n_actual_tokens > limit:
        message = (
            f"Too many tokens in input: {n_actual_tokens} > {limit}. "
            "Please limit the length of the input sequence."
        )
        raise UnsupportedInputError(message)


def raise_if_too_many_templates(n_actual_templates: int):
    """Raise UnsupportedInputError when more than ``MAX_NUM_TEMPLATES`` templates are given."""
    limit = MAX_NUM_TEMPLATES
    if n_actual_templates > limit:
        message = (
            f"Too many templates in input: {n_actual_templates} > {limit}. "
            "Please limit the number of templates."
        )
        raise UnsupportedInputError(message)


def raise_if_msa_too_deep(msa_depth: int):
    """Raise UnsupportedInputError when an MSA is deeper than ``MAX_MSA_DEPTH``.

    Raises:
        UnsupportedInputError: if ``msa_depth > MAX_MSA_DEPTH``.
    """
    if msa_depth > MAX_MSA_DEPTH:
        raise UnsupportedInputError(
            f"MSA too deep: {msa_depth} > {MAX_MSA_DEPTH}. "
            "Please limit the MSA depth."
        )


class CHAI1:
    """Size-specific CHAI-1 trunk runner used to extract per-token embeddings.

    The exported CHAI-1 components (feature embedding, token input embedder and
    trunk) are loaded for the smallest entry of ``AVAILABLE_MODEL_SIZES`` that can
    hold ``n_actual_tokens`` tokens. ``run_inference`` only accepts inputs whose
    token count equals ``n_actual_tokens``. Optionally, the ESM-2 model
    ``facebook/esm2_t36_3B_UR50D`` is loaded to embed protein chains. Otherwise
    empty embedding contexts are used.
    """

    def __init__(
        self,
        device: torch.device,
        n_actual_tokens: int,
        use_esm_embeddings: bool = True,
    ) -> None:
        """Load the exported components (and optionally ESM-2) onto ``device``.

        Args:
            device: device the exported modules and the ESM model are moved to.
                It is also the default compute device for ``run_inference``.
            n_actual_tokens: exact token count every input must have.
            use_esm_embeddings: whether protein chains get ESM-2 embeddings.

        Raises:
            UnsupportedInputError: if ``n_actual_tokens`` exceeds the largest
                available model size.
        """
        self.n_actual_tokens = n_actual_tokens
        # Remembered so run_inference can default to the device the models live on.
        self.device = device

        # Validate before picking a size. Otherwise min() below fails with an opaque
        # "arg is an empty sequence" ValueError for oversized inputs.
        raise_if_too_many_tokens(n_actual_tokens)

        # Model is size-specific: use the smallest exported size that fits.
        model_size = min(x for x in AVAILABLE_MODEL_SIZES if n_actual_tokens <= x)

        self.feature_embedding = load_exported(f"{model_size}/feature_embedding.pt2", device)
        self.token_input_embedder = load_exported(
            f"{model_size}/token_input_embedder.pt2", device
        )
        self.trunk = load_exported(f"{model_size}/trunk.pt2", device)

        from chai_lab.data.sources.rdkit import RefConformerGenerator
        from chai_lab.data.dataset.structure.all_atom_residue_tokenizer import AllAtomResidueTokenizer

        conformer_generator = RefConformerGenerator()
        self.tokenizer = AllAtomResidueTokenizer(conformer_generator)

        self.use_esm_embeddings = use_esm_embeddings
        if use_esm_embeddings:
            # local import, requires huggingface transformers
            from transformers import EsmTokenizer

            self.tokenizer_esm = EsmTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
            self.model_esm = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D").to(device)
            self.model_esm.eval()

    @torch.no_grad()
    def run_inference(
        self,
        seqs: List[str],
        seq_type: str,
        # expose some params for easy tweaking
        num_trunk_recycles: int = 3,
        seed: int | None = None,
        device: torch.device | None = None,
    ) -> Tensor:
        """Run the CHAI-1 trunk on ``seqs`` (one chain per sequence).

        Args:
            seqs: non-empty list of sequences, all of entity type ``seq_type``.
            seq_type: ``'AA'`` (protein), ``'DNA'`` or ``'RNA'``.
            num_trunk_recycles: number of trunk passes.
            seed: optional seed forwarded to ``run_folding_on_context``.
            device: compute device. Defaults to the device passed to ``__init__``.

        Returns:
            ``token_single_trunk_repr[token_single_mask]`` from
            ``run_folding_on_context``, left on ``device``. Presumably shaped
            (n_tokens, d_single); TODO confirm.

        Raises:
            KeyError: for an unknown ``seq_type``.
            AssertionError: if ``seqs`` is not a non-empty list, or if its token
                count differs from ``n_actual_tokens``.
            UnsupportedInputError: if the input exceeds model limits.
        """
        if device is None:
            # Keep ESM inputs and the folding batch on the same device as the models.
            # Without this, ESM inputs would not be moved and folding would
            # fall back to cuda:0.
            device = self.device

        # Prepare inputs. Check the type first so len() is never called on a non-list.
        assert isinstance(seqs, list) and len(seqs) > 0, "No inputs found in fasta file"
        entity_type = dict(AA=EntityType.PROTEIN, DNA=EntityType.DNA, RNA=EntityType.RNA)[seq_type]
        inputs = [Input(sequence, entity_type.value) for sequence in seqs]

        # Load structure context
        chains = load_chains_from_raw(inputs, tokenizer=self.tokenizer)
        merged_context = AllAtomStructureContext.merge([c.structure_context for c in chains])
        n_actual_tokens = merged_context.num_tokens
        assert self.n_actual_tokens == n_actual_tokens, "n_actual_tokens does not match"
        raise_if_too_many_tokens(n_actual_tokens)

        # MSAs and templates are not used: empty contexts of the maximum size.
        msa_context = MSAContext.create_empty(
            n_tokens=n_actual_tokens,
            depth=MAX_MSA_DEPTH,
        )
        main_msa_context = MSAContext.create_empty(
            n_tokens=n_actual_tokens,
            depth=MAX_MSA_DEPTH,
        )
        template_context = TemplateContext.empty(
            n_tokens=n_actual_tokens,
            n_templates=MAX_NUM_TEMPLATES,
        )

        # ESM embeddings (or an empty context when disabled)
        if self.use_esm_embeddings:
            embedding_context = get_esm_embedding_context(
                chains, device=device, tokenizer=self.tokenizer_esm, model=self.model_esm
            )
        else:
            embedding_context = EmbeddingContext.empty(n_tokens=n_actual_tokens)

        # No constraints
        constraint_context = ConstraintContext.empty()

        feature_context = AllAtomFeatureContext(
            chains=chains,
            structure_context=merged_context,
            msa_context=msa_context,
            main_msa_context=main_msa_context,
            template_context=template_context,
            embedding_context=embedding_context,
            constraint_context=constraint_context,
        )

        t1 = time.time()
        embeds = run_folding_on_context(
            self.feature_embedding,
            self.token_input_embedder,
            self.trunk,
            feature_context,
            num_trunk_recycles=num_trunk_recycles,
            seed=seed,
            device=device,
        )
        print(f"Time taken: {time.time() - t1:.2f} seconds")
        return embeds


def embedding_context_from_sequence(tokenizer, model, seq: str, device) -> EmbeddingContext:
    """Embed one protein sequence with an ESM model and wrap it in an EmbeddingContext.

    The sequence is tokenized, moved to ``device`` and run through ``model`` without
    gradients. The first and last positions of ``last_hidden_state`` (BOS/EOS) are
    dropped from the first batch element, and the result is copied to the CPU.
    Asserts that exactly one row per residue of ``seq`` remains.
    """
    encoded = move_data_to_device(dict(**tokenizer(seq, return_tensors="pt")), device=device)

    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state

    # Strip the special tokens at both ends; results always live on the CPU.
    per_residue = hidden[0, 1:-1].to("cpu")
    n_rows, _ = per_residue.shape
    assert n_rows == len(seq)
    return EmbeddingContext(esm_embeddings=per_residue)


@typecheck
def get_esm_embedding_context(chains, device, tokenizer, model) -> EmbeddingContext:
    """Assemble one token-level EmbeddingContext covering every chain in ``chains``.

    Protein chains are embedded with ESM via ``embedding_context_from_sequence``.
    That runs on ``device``, but its result comes back on the CPU. Every other
    entity type gets ``EmbeddingContext.empty`` sized to the chain's token count
    (zero embeddings). Each chain's residue rows are gathered per token with
    ``token_residue_index``, then all chains are concatenated along dim 0.
    """

    def _embed_chain(chain) -> EmbeddingContext:
        # Non-protein chains are embedded with zeros.
        if chain.entity_data.entity_type != EntityType.PROTEIN:
            return EmbeddingContext.empty(n_tokens=chain.structure_context.num_tokens)
        return embedding_context_from_sequence(
            tokenizer=tokenizer,
            model=model,
            # modified residues represented as X
            seq=chain.entity_data.sequence,
            device=device,
        )

    residue_level = [_embed_chain(chain) for chain in chains]

    # Map each token to its residue's embedding row.
    token_level = []
    for ctx, chain in zip(residue_level, chains, strict=True):
        token_level.append(ctx.esm_embeddings[chain.structure_context.token_residue_index, :])

    # Chains are never cropped during inference, so every token row is kept.
    # (If cropping were needed, each entry of token_level would be indexed with
    # that chain's crop indices at this point.)
    return EmbeddingContext(esm_embeddings=torch.cat(token_level, dim=0))


@torch.no_grad()
def run_folding_on_context(
    feature_embedding: torch.nn.Module,
    token_input_embedder: torch.nn.Module,
    trunk: torch.nn.Module,
    feature_context: AllAtomFeatureContext,
    # expose some params for easy tweaking
    num_trunk_recycles: int = 3,
    seed: int | None = None,
    device: torch.device | None = None,
) -> Tensor:
    """
    Function for in-depth explorations.
    User completely controls folding inputs.

    Collates ``feature_context`` into a batch of one and runs it through
    ``feature_embedding``, ``token_input_embedder`` and ``trunk``, in that order.
    Only these three modules are run.

    Args:
        feature_embedding / token_input_embedder / trunk: exported CHAI-1 modules.
        feature_context: fully built input context.
        num_trunk_recycles: number of trunk passes. Each pass feeds the previous
            single/pair outputs back in. With 0, the initial representation from
            the token input embedder is returned.
        seed: if given, passed to ``set_seed`` before anything else runs.
        device: device the batch is moved to. Defaults to ``cuda:0``.

    Returns:
        ``token_single_trunk_repr[token_single_mask]``: the trunk's single
        representation for tokens that exist. Presumably (n_tokens, d_single);
        TODO confirm.

    Raises:
        UnsupportedInputError: if tokens, templates or MSA depth exceed the limits.
    """
    # Set seed
    if seed is not None:
        set_seed([seed])

    if device is None:
        device = torch.device("cuda:0")

    ##
    ## Validate inputs
    ##

    n_actual_tokens = feature_context.structure_context.num_tokens
    raise_if_too_many_tokens(n_actual_tokens)
    raise_if_too_many_templates(feature_context.template_context.num_templates)
    raise_if_msa_too_deep(feature_context.msa_context.depth)
    raise_if_msa_too_deep(feature_context.main_msa_context.depth)

    ##
    ## Prepare batch
    ##

    # Collate inputs into a batch of one
    collator = Collate(
        feature_factory=feature_factory,
        num_key_atoms=128,
        num_query_atoms=32,
    )
    batch = collator([feature_context])
    batch = move_data_to_device(batch, device=device)

    # Get features and the inputs actually consumed below from the batch
    features = dict(batch["features"])
    inputs = batch["inputs"]
    block_indices_h = inputs["block_atom_pair_q_idces"]
    block_indices_w = inputs["block_atom_pair_kv_idces"]
    atom_single_mask = inputs["atom_exists_mask"]
    atom_token_indices = inputs["atom_token_index"].long()
    token_single_mask = inputs["token_exists_mask"]
    token_pair_mask = und_self(token_single_mask, "b i, b j -> b i j")
    msa_mask = inputs["msa_mask"]
    template_input_masks = und_self(
        inputs["template_mask"], "b t n1, b t n2 -> b t n1 n2"
    )
    block_atom_pair_mask = inputs["block_atom_pair_mask"]

    ##
    ## Run the features through the feature embedder
    ##

    embedded_features = feature_embedding.forward(**features)
    token_single_input_feats = embedded_features["TOKEN"]
    # These embeddings are split in two along the last dim. The second halves
    # (structure inputs) are not needed to compute the trunk representation.
    token_pair_input_feats, _token_pair_structure_feats = embedded_features[
        "TOKEN_PAIR"
    ].chunk(2, dim=-1)
    atom_single_input_feats, _atom_single_structure_feats = embedded_features[
        "ATOM"
    ].chunk(2, dim=-1)
    block_atom_pair_input_feats, _block_atom_pair_structure_feats = (
        embedded_features["ATOM_PAIR"].chunk(2, dim=-1)
    )
    template_input_feats = embedded_features["TEMPLATES"]
    msa_input_feats = embedded_features["MSA"]

    ##
    ## Run the inputs through the token input embedder
    ##

    token_input_embedder_outputs: tuple[Tensor, ...] = token_input_embedder.forward(
        token_single_input_feats=token_single_input_feats,
        token_pair_input_feats=token_pair_input_feats,
        atom_single_input_feats=atom_single_input_feats,
        block_atom_pair_feat=block_atom_pair_input_feats,
        block_atom_pair_mask=block_atom_pair_mask,
        block_indices_h=block_indices_h,
        block_indices_w=block_indices_w,
        atom_single_mask=atom_single_mask,
        atom_token_indices=atom_token_indices,
    )
    # The middle output (single structure input) is not used by the trunk.
    token_single_initial_repr, _token_single_structure_input, token_pair_initial_repr = (
        token_input_embedder_outputs
    )

    ##
    ## Run the input representations through the trunk
    ##

    # Recycle the representations by feeding the output back into the trunk as input for
    # the subsequent recycle
    token_single_trunk_repr = token_single_initial_repr
    token_pair_trunk_repr = token_pair_initial_repr
    for _ in range(num_trunk_recycles):
        (token_single_trunk_repr, token_pair_trunk_repr) = trunk.forward(
            token_single_trunk_initial_repr=token_single_initial_repr,
            token_pair_trunk_initial_repr=token_pair_initial_repr,
            token_single_trunk_repr=token_single_trunk_repr,  # recycled
            token_pair_trunk_repr=token_pair_trunk_repr,  # recycled
            msa_input_feats=msa_input_feats,
            msa_mask=msa_mask,
            template_input_feats=template_input_feats,
            template_input_masks=template_input_masks,
            token_single_mask=token_single_mask,
            token_pair_mask=token_pair_mask,
        )
    # Boolean-mask indexing keeps only existing (non-padding) tokens.
    return token_single_trunk_repr[token_single_mask]



class AA_filebase:
    """A small on-disk store mapping sequences to ``.npy`` embedding files.

    Each embedding is saved as its own ``.npy`` file inside ``save_path``. A CSV
    index at ``save_path.rstrip('/') + '_index.csv'`` (next to the folder, not
    inside it) maps each sequence to its file name. The index is append-only.
    When a sequence appears more than once, the last row wins on reload.
    """

    def __init__(self, save_path: str) -> None:
        """Open (creating if needed) the store at ``save_path`` and load its index."""
        self.addr = save_path

        # Create save folder if it does not exist (exist_ok avoids a check-then-create race)
        os.makedirs(self.addr, exist_ok=True)

        # Create the index file with its header row if it does not exist
        self.index_path = self.addr.rstrip('/') + '_index.csv'
        if not os.path.exists(self.index_path):
            with open(self.index_path, 'w') as f:
                f.write('seq,filename\n')

        # Load index file; later rows override earlier rows for the same sequence
        self.index = dict()
        with open(self.index_path, 'r') as f:
            next(f, None)  # skip the header row
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines instead of crashing on unpacking
                    continue
                # Split on the last comma: the file name is always the final field
                seq, filename = line.rsplit(',', 1)
                self.index[seq] = filename
        # Counter used to build default file names in update()
        self.top = len(self.index)
        print(f'Loaded {self.top} embeddings from {self.addr}')

    def __len__(self,) -> int:
        """Number of distinct sequences in the index."""
        return len(self.index)

    def update(self, data: Dict[str, Dict[str, Any]]) -> Dict:
        """Save embeddings to disk and append them to the index.

        Args:
            data: maps each sequence to a dict with key ``'embed'`` and an optional
                ``'filename'``. ``'embed'`` is a torch tensor or anything
                ``np.save`` accepts. The default file name is
                ``"<counter>_<unix time>.npy"``.

        Returns:
            ``{"status": "success", "message": "Saved <n> embeddings"}``.
        """
        for seq, log in data.items():
            # Existing sequences are not skipped: a new file and index row are
            # written, and the in-memory entry is replaced.
            self.top += 1
            filename = log.get('filename', f"{self.top}_{time.time()}.npy")
            # np.save appends '.npy' to names lacking it. Record the real on-disk
            # name so fetch() can find the file.
            if not filename.endswith('.npy'):
                filename += '.npy'

            embed = log['embed']
            if isinstance(embed, torch.Tensor):
                # numpy conversion requires a CPU tensor that does not require grad
                embed = embed.detach().cpu()
                # numpy has no bfloat16; upcast before converting
                if embed.dtype == torch.bfloat16:
                    embed = embed.to(torch.float32)
                embed = embed.numpy()

            # Write the file first so the index never points at a file that failed to save
            np.save(os.path.join(self.addr, filename), embed)
            self.index[seq] = filename
            with open(self.index_path, 'a') as f:
                f.write(f"{seq},{filename}\n")
        return {"status": "success", "message": f"Saved {len(data)} embeddings"}

    def fetch(self, keys: Union[str, List]) -> Dict[str, Dict[str, Any]]:
        """Load stored embeddings for one sequence or a list of sequences.

        Sequences not in the index are silently omitted from the result.

        Returns:
            ``{seq: {"embed": torch.Tensor}}`` for every indexed key.

        Raises:
            FileNotFoundError: if the index lists a file missing on disk.
        """
        if isinstance(keys, str):
            keys = [keys]
        result = {}
        for key in keys:
            if key not in self.index:
                continue
            file_path = os.path.join(self.addr, self.index[key])
            if not os.path.exists(file_path):
                raise FileNotFoundError(f'Indexed embedding file is missing: {key}={file_path}')
            result[key] = dict(embed=torch.tensor(np.load(file_path)))
        return result


def main(seq_type: str, use_esm_embeddings: bool, ratio: Tuple[float, float] = (0.0, 1.0)):
    """Extract CHAI-1 embeddings for a slice of the dataset and store them on disk.

    Args:
        seq_type: ``'DNA'``, ``'RNA'`` (``T`` replaced by ``U``) or ``'AA'``
            (DNA translated with ``Bio.Seq.Seq.translate``).
        use_esm_embeddings: whether to use ESM-2 embeddings for protein chains.
        ratio: ``(start, end)`` fractions selecting a contiguous slice of the sorted
            sequences, so several runs can split the dataset.

    Raises:
        ValueError: for an unsupported ``seq_type``. Validation happens before
            any I/O.
    """
    # seq_type -> (token count every input must have, per-sequence conversion)
    converters = {
        'RNA': (501, lambda s: s.replace('T', 'U')),
        'DNA': (501, lambda s: s),
        'AA': (167, lambda s: str(Seq(s).translate())),
    }
    if seq_type not in converters:
        raise ValueError(f"Unsupported seq_type {seq_type!r}; expected one of {sorted(converters)}")
    n_actual_tokens, convert = converters[seq_type]

    tag = f"CHAI1{'ESM' if use_esm_embeddings else ''}"
    sequences = get_seqs_in_txts('path/to/your/dataset')
    folder = f"path/to/your/embedding/{seq_type}_{tag}"
    sequences = sorted(sequences)[int(len(sequences)*ratio[0]): int(len(sequences)*ratio[1])]
    filebase = AA_filebase(folder)

    # Convert each sequence once and drop already-stored ones. Deduplicate while
    # keeping a deterministic order (first occurrence in the sorted input).
    converted = (convert(seq) for seq in sequences)
    seqs = list(dict.fromkeys(s for s in converted if s not in filebase.index))
    print(f"Loaded {len(seqs)} unique unseen sequences to extract embeddings.")

    device = torch.device("cuda:0")
    model = CHAI1(device=device, n_actual_tokens=n_actual_tokens, use_esm_embeddings=use_esm_embeddings)

    t0 = time.time()
    for i, seq in enumerate(seqs):
        # Defensive: skip anything that got stored since the list was built
        if seq in filebase.index:
            continue

        t1 = time.time()
        embeds = model.run_inference([seq], seq_type, device=device)

        filebase.update({seq: {"embed": embeds}})
        print(f"[{tag}-{seq_type} {i+1}/{len(seqs)}] Time taken: {time.time() - t1:.2f} seconds, Elapsed time: {(time.time() - t0)/3600:.2f} hours, ETA: {(time.time() - t0) / (i+1) * (len(seqs) - i - 1) / 3600:.2f} hours")


import subprocess
def get_free_gpu_memory(gpu_index: int = 0):
    """Return one GPU's free memory in MB, as reported by ``nvidia-smi``.

    Args:
        gpu_index: position in nvidia-smi's output (one line per GPU). Defaults
            to the first GPU listed.
            NOTE(review): nvidia-smi enumerates GPUs independently of
            CUDA_VISIBLE_DEVICES, so index 0 here may not be torch's ``cuda:0``.
            Confirm on multi-GPU hosts.

    Raises:
        subprocess.CalledProcessError: if nvidia-smi exits with a non-zero status.
        IndexError: if ``gpu_index`` is out of range for the GPUs reported.
    """
    # Query free memory for every GPU: one integer (MB) per line, no header or units
    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,nounits,noheader'],
        stdout=subprocess.PIPE, encoding='utf-8', check=True,
    )
    # Parse the output and take the free memory (in MB) of the requested GPU
    per_gpu = result.stdout.strip().split('\n')
    return int(per_gpu[gpu_index])


def wait_for_gpu_memory(
    required_free_memory=1000,
    required_checks: int = 10,
    poll_interval: float = 5.0,
    check_interval: float = 0.0,
):
    """Block until the GPU queried by ``get_free_gpu_memory`` has enough free memory.

    Memory counts as available once more than ``required_free_memory`` MB is
    reported free on ``required_checks`` consecutive checks. Any insufficient
    reading resets the streak, so memory that is only briefly free does not let
    the caller through.

    Args:
        required_free_memory: threshold in MB (strictly greater is required).
        required_checks: consecutive successful checks needed.
        poll_interval: seconds to sleep after an insufficient reading.
        check_interval: seconds to sleep between successful checks. The default 0
            keeps back-to-back confirmations.
    """
    start_time = time.time()
    streak = 0
    while streak < required_checks:
        free_memory = get_free_gpu_memory()
        elapsed_time = time.time() - start_time
        print(f"\rFree GPU Memory: {free_memory}MB | Elapsed Time: {int(elapsed_time)}s", end="")

        if free_memory > required_free_memory:
            streak += 1
            if streak < required_checks and check_interval > 0:
                time.sleep(check_interval)
        else:
            # Not enough memory: start counting again and wait before re-checking
            streak = 0
            time.sleep(poll_interval)
    print("\nSufficient GPU memory available. Proceeding with the program...")


if __name__ == '__main__':
    # Do not load any models until more than 7000 MB of GPU memory is free.
    wait_for_gpu_memory(7000)
    # This run extracts only the non-ESM variant, for the first 40% of the
    # sorted DNA sequences.
    for use_esm_embeddings in (False,):
        main('DNA', use_esm_embeddings, ratio=(0.0, 0.4))
