import logging
import re
from abc import ABC, abstractmethod
from functools import partial
from types import SimpleNamespace
from typing import Dict, List, Literal, Optional

import numpy as np
import torch
import tqdm
from datasets import Dataset
from torch import Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoTokenizer,
    BatchEncoding,
    DefaultDataCollator,
    T5EncoderModel,
    T5Tokenizer,
)
from transformers.modeling_outputs import BaseModelOutput

from .modality import Modality
from .eval_utils import ForwardHook, pool

logger = logging.getLogger(__name__)


class BioSeqTransformer(ABC):
    """
    Abstract base class wrapping models that map biological sequences (DNA/protein) to embeddings.
    Modelled after SentenceTransformer (https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/SentenceTransformer.py).

    Args:
        model_name: Name or path to the pretrained model.
        layers: Model layers to probe. Either a list of layer indices, or the string "mid" or "last".
        devices: List of CUDA device ids to use for inference. Falls back to CPU if CUDA is unavailable.
        num_processes: Number of processes to use for data loading.
        max_seq_length: Maximum sequence length of the input sequences.
        l2_norm: If True, embeddings are L2-normalized before being returned.
        batch_size: Batch size for encoding.
        pool_type: Pooling strategy to use. One of "mean", "max", "cls", "last".
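
    Example:
        A minimal usage sketch (uses the `ESM` subclass defined below with a small
        public checkpoint)::

            model = ESM("facebook/esm2_t6_8M_UR50D", layers="last")
            embeds = model.encode(["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"])
            # embeds has shape [1, 1, model.embed_dim]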
    """

    def __init__(
        self,
        model_name: str,
        layers: Optional[List[int] | Literal["mid", "last"]] = None,
        devices: List[int] = [0],
        num_processes: int = 16,
        max_seq_length: int = 1024,
        l2_norm: bool = False,
        batch_size: int = 128,
        pool_type: str = "mean",
    ):
        super().__init__()

        self.id = self.__class__.__name__
        self.hf_name = model_name
        self.encoder = self._load_model(model_name)
        if not hasattr(self.encoder, "config"):
            raise ValueError(
                'The model from `self._load_model()` must have a "config" attribute.'
            )
        self.config = self.encoder.config
        self.tokenizer = self._get_tokenizer(model_name)
        self.num_param = sum(p.numel() for p in self.encoder.parameters())
        self.data_collator = DefaultDataCollator()
        self.gpu_count = len(devices)
        self.l2_norm = l2_norm

        self.device = torch.device(
            f"cuda:{devices[0]}" if torch.cuda.is_available() else "cpu"
        )
        self.num_processes = num_processes
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.pool_type = pool_type

        if self.gpu_count > 1:
            self.encoder = torch.nn.DataParallel(self.encoder, device_ids=devices)
        self.encoder.to(self.device)
        self.encoder.eval()

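        # Resolve the requested layers into concrete hidden-state indices,
        # defaulting to the middle and last layers when none are specified.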
        mid_layer = self.num_layers // 2
        last_layer = self.num_layers - 1
        mid_layer_label = f"mid ({mid_layer})"
        last_layer_label = f"last ({last_layer})"

        if layers is None:
            logger.debug(f"Using default layers: {mid_layer_label}, {last_layer_label}")
            self.layers = [mid_layer, last_layer]
            self.layer_labels = [mid_layer_label, last_layer_label]
        elif layers == "mid":
            self.layers = [mid_layer]
            self.layer_labels = [mid_layer_label]
        elif layers == "last":
            self.layers = [last_layer]
            self.layer_labels = [last_layer_label]
        else:
            self.layers = layers
            self.layer_labels = [str(layer) for layer in layers]

    def _encode_single_batch(self, batch_dict: Dict[str, Tensor]):
        """Returns the output embedding for the given batch with shape [batch, num_layers, D]."""
        outputs = self.encoder(**batch_dict, output_hidden_states=True)
        embeds = [outputs.hidden_states[layer] for layer in self.layers]
        embeds = [
            pool(layer_embeds, batch_dict["attention_mask"], self.pool_type)
            for layer_embeds in embeds
        ]
        # Stack with shape [B, num_layers, D].
        embeds = torch.stack(embeds, dim=1)
        return embeds

    def _load_model(self, model_name):
        return AutoModel.from_pretrained(model_name, trust_remote_code=True)

    def _get_tokenizer(self, model_name):
        return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    def _tokenize_func(
        self, tokenizer, examples: Dict[str, List], max_seq_length: int
    ) -> BatchEncoding:
        batch_dict = tokenizer(
            examples["input_seqs"],
            max_length=max_seq_length,
            padding=True,
            truncation=True,
        )
        return batch_dict

    @property
    def metadata(self) -> Dict:
        return {
            "hf_name": self.hf_name,
            "num_layers": self.num_layers,
            "num_params": self.num_param,
            "embed_dim": self.embed_dim,
        }

    @property
    @abstractmethod
    def num_layers(self) -> int:
        pass

    @property
    @abstractmethod
    def embed_dim(self) -> int:
        pass

    @property
    @abstractmethod
    def modality(self) -> Modality:
        pass

    @torch.no_grad()
    def encode(self, sequences, **kwargs) -> np.ndarray:
        """Returns a list of embeddings for the given sequences.
        Args:
            sequences (`List[str]`): List of sequences to encode
        Returns:
            `np.ndarray`: Embeddings for the given sequences of shape [num_sequences, num_layers, embedding_dim].
        """
        dataset = Dataset.from_dict({"input_seqs": sequences})
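        # Tokenize lazily via `set_transform` so sequences are only tokenized as
        # batches are drawn by the DataLoader.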
        dataset.set_transform(
            partial(
                self._tokenize_func, self.tokenizer, max_seq_length=self.max_seq_length
            )
        )
        data_loader = DataLoader(
            dataset,
            batch_size=self.batch_size * self.gpu_count,
            shuffle=False,
            drop_last=False,
            num_workers=self.num_processes,
            collate_fn=self.data_collator,
            pin_memory=True,
        )

        if max(self.layers) >= self.num_layers:
            raise ValueError(
                f"Layer {max(self.layers)} is not available in the model. Choose a layer between 0 and {self.num_layers - 1}"
            )

        encoded_embeds = []
        for batch_dict in tqdm.tqdm(
            data_loader, desc="encoding", mininterval=10, disable=len(sequences) < 128
        ):
            batch_dict = {k: v.to(self.device) for k, v in batch_dict.items()}

            embeds = self._encode_single_batch(batch_dict)

            if self.l2_norm:
                embeds = F.normalize(embeds, p=2, dim=-1)
            encoded_embeds.append(embeds.cpu().numpy())

        return np.concatenate(encoded_embeds, axis=0)


class ESM(BioSeqTransformer):
    """ESM model from https://huggingface.co/docs/transformers/en/model_doc/esm"""

    MODEL_NAMES = [
        "facebook/esm2_t6_8M_UR50D",
        "facebook/esm2_t12_35M_UR50D",
        "facebook/esm2_t30_150M_UR50D",
        "facebook/esm2_t33_650M_UR50D",
        "facebook/esm2_t36_3B_UR50D",
        "facebook/esm2_t48_15B_UR50D",
    ]

    @property
    def modality(self) -> Modality:
        return Modality.PROTEIN

    @property
    def num_layers(self) -> int:
        return self.config.num_hidden_layers

    @property
    def embed_dim(self) -> int:
        return self.config.hidden_size


class ESM3(BioSeqTransformer):
    """ESM3 model from https://github.com/evolutionaryscale/esm"""

    MODEL_NAMES = ["esm3_sm_open_v1"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Register forward hooks to store embeddings per layer.
        self.hooks = [
            ForwardHook(self.encoder.transformer.blocks[layer]) for layer in self.layers
        ]

    @property
    def modality(self) -> Modality:
        return Modality.PROTEIN

    @property
    def num_layers(self) -> int:
        return self.config.num_hidden_layers

    @property
    def embed_dim(self) -> int:
        return self.config.hidden_size

    def _load_model(self, model_name):
        try:
            from esm.models.esm3 import ESM3 as ModelESM3
        except ImportError:
            raise ImportError(
                "ESM3 is not installed. Please install it with `pip install esm`."
            )
        model = ModelESM3.from_pretrained(model_name)
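        # ESM3 does not ship a Hugging Face-style config, so attach a minimal
        # stand-in exposing the fields BioSeqTransformer expects.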
        model.config = SimpleNamespace(
            num_hidden_layers=len(model.transformer.blocks),
            hidden_size=model.transformer.blocks[0].ffn[-1].out_features,
        )
        return model

    def _get_tokenizer(self, model_name):
        try:
            from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
        except ImportError:
            raise ImportError(
                "ESM3 is not installed. Please install it with `pip install esm`."
            )
        return EsmSequenceTokenizer()

    def _encode_single_batch(self, batch_dict: Dict[str, Tensor]):
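        """Run ESM3 and pool the per-layer embeddings captured by the forward hooks registered in __init__."""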
        _ = self.encoder.forward(sequence_tokens=batch_dict["input_ids"])
        embeds = [hook.output for hook in self.hooks]
        embeds = [
            pool(layer_embeds, batch_dict["attention_mask"], self.pool_type)
            for layer_embeds in embeds
        ]
        # Stack with shape [B, num_layers, D].
        embeds = torch.stack(embeds, dim=1)
        embeds = embeds.to(torch.float32)
        return embeds


class ProtT5(BioSeqTransformer):
    """ProtT5 model from https://github.com/agemagician/ProtTrans"""

    MODEL_NAMES = [
        "Rostlab/prot_t5_xl_uniref50",
        "Rostlab/prot_t5_xl_bfd",
        "Rostlab/prot_t5_xxl_uniref50",
        "Rostlab/prot_t5_xxl_bfd",
    ]

    @property
    def modality(self) -> Modality:
        return Modality.PROTEIN

    @property
    def num_layers(self) -> int:
        return self.config.num_layers

    @property
    def embed_dim(self) -> int:
        return self.config.d_model

    def _load_model(self, model_name):
        return T5EncoderModel.from_pretrained(model_name)

    def _get_tokenizer(self, model_name):
        return T5Tokenizer.from_pretrained(model_name, do_lower_case=False)

    def _tokenize_func(
        self, tokenizer, examples: Dict[str, List], max_seq_length: int
    ) -> BatchEncoding:
        example_sequences = examples["input_seqs"]
        # Add space between amino acids to make sure they are tokenized correctly.
        example_sequences = [" ".join(sequence) for sequence in example_sequences]
        example_sequences = [
            re.sub(r"[UZOB]", "X", sequence) for sequence in example_sequences
        ]
        batch_dict = tokenizer(
            example_sequences,
            max_length=max_seq_length,
            padding=True,
            truncation=True,
            add_special_tokens=True,
        )

        return batch_dict


class ProGen(BioSeqTransformer):
    """ProGen models from https://github.com/salesforce/progen."""

    MODEL_NAMES = [
        "hugohrban/progen2-small",
        "hugohrban/progen2-medium",
        "hugohrban/progen2-base",
        "hugohrban/progen2-large",
        "hugohrban/progen2-xlarge",
    ]

    @property
    def modality(self) -> Modality:
        return Modality.PROTEIN

    @property
    def num_layers(self) -> int:
        return self.config.n_layer

    @property
    def embed_dim(self) -> int:
        return self.config.embed_dim

    def _load_model(self, model_name):
        return AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    def _get_tokenizer(self, model_name):
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
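        # The ProGen2 tokenizer does not define a pad token by default; set one so
        # padded batches can be built.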
        tokenizer.pad_token = "<|pad|>"
        return tokenizer

    def _encode_single_batch(self, batch_dict: Dict[str, Tensor]):
        """Returns the output embedding for the given batch with shape [batch, num_layers, D]."""
        outputs: BaseModelOutput = self.encoder(
            input_ids=batch_dict["input_ids"],
            output_hidden_states=True,
            use_cache=False,
        )
        embeds = [outputs.hidden_states[layer] for layer in self.layers]
        embeds = [
            pool(layer_embeds, batch_dict["attention_mask"], self.pool_type)
            for layer_embeds in embeds
        ]
        # Stack with shape [B, num_layers, D].
        embeds = torch.stack(embeds, dim=1)
        return embeds


class EvoModel(BioSeqTransformer):
    """https://github.com/evo-design/evo."""

    MODEL_NAMES = [
        "togethercomputer/evo-1-8k-base",
        "togethercomputer/evo-1-131k-base",
    ]

    @property
    def modality(self) -> Modality:
        return Modality.DNA

    @property
    def num_layers(self) -> int:
        return self.config.num_layers

    @property
    def embed_dim(self) -> int:
        return self.config.hidden_size

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Register forward hooks to store embeddings per layer.
        self.hooks = []
        for layer in self.layers:
            # For the last layer, get the output of `backbone.norm`, which directly precedes `backbone.unembed`.
            # This is equivalent to the approach in https://github.com/evo-design/evo/issues/32.
            if layer == self.num_layers - 1 or layer == -1:
                self.hooks.append(ForwardHook(self.encoder.backbone.norm))
            else:
                self.hooks.append(ForwardHook(self.encoder.backbone.blocks[layer]))

    def _load_model(self, model_name):
        config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=True, revision="1.1_fix"
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name, config=config, trust_remote_code=True, revision="1.1_fix"
        )
        return model

    def _get_tokenizer(self, model_name):
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, revision="1.1_fix", trust_remote_code=True
        )
        # Evo tokenizer is missing pad_token by default.
        tokenizer.add_special_tokens({"pad_token": "N"})
        return tokenizer

    def _encode_single_batch(self, batch_dict: Dict[str, Tensor]):
        _ = self.encoder(batch_dict["input_ids"], use_cache=False)
        embeds = [hook.output for hook in self.hooks]
        # The hook output for Evo middle layers is a tuple (embedding, inference_params=None).
        embeds = [x[0] if isinstance(x, tuple) else x for x in embeds]
        embeds = [
            pool(layer_embeds, batch_dict["attention_mask"], self.pool_type)
            for layer_embeds in embeds
        ]
        # Stack with shape [B, num_layers, D].
        embeds = torch.stack(embeds, dim=1)
        embeds = embeds.to(torch.float32)
        return embeds


class NTModel(BioSeqTransformer):
    """Nucleotide Transformer https://github.com/instadeepai/nucleotide-transformer"""

    MODEL_NAMES = [
        "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species",
        "InstaDeepAI/nucleotide-transformer-v2-100m-multi-species",
        "InstaDeepAI/nucleotide-transformer-v2-250m-multi-species",
        "InstaDeepAI/nucleotide-transformer-v2-500m-multi-species",
        "InstaDeepAI/nucleotide-transformer-2.5b-multi-species",
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
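        # Prefer the tokenizer's own model_max_length over the generic default set
        # in the base class.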
        self.max_seq_length = self.tokenizer.model_max_length

    @property
    def modality(self) -> Modality:
        return Modality.DNA

    @property
    def num_layers(self) -> int:
        return self.config.num_hidden_layers

    @property
    def embed_dim(self) -> int:
        return self.config.hidden_size

    def _load_model(self, model_name):
        return AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
