from typing import List, Dict, Any

import torch
from tqdm import tqdm

from esm import pretrained as esm_pretrained

from coarsebind_public.esm2.io_schema import ESM2IOSchema


def check_input(sequence: str):
    """Validate a raw sequence string before it is tokenized.

    Args:
        sequence: The sequence to validate.

    Raises:
        ValueError: If ``sequence`` is empty, or if every character in it
            is ``"X"``.
    """
    # An empty string also equals "X" * 0, so it used to be reported as
    # "All Xs"; report it under its own message instead.
    if not sequence:
        raise ValueError("Empty sequence")

    # check if all Xs
    if sequence == "X" * len(sequence):
        raise ValueError("All Xs")


class ESM2Infer:
    """Compute per-token embeddings for sequences with a pretrained ESM model.

    The model and tokenizer are loaded lazily: ``predict`` calls
    ``load_model`` on first use if they have not been loaded yet.
    """

    def __init__(self, model_name: str, rep_layer: int, device: str):
        """Store the configuration; no model is loaded here.

        Args:
            model_name: Name passed to the ``esm.pretrained`` loaders.
            rep_layer: Index of the layer whose representations are returned.
            device: Device the model and token tensors are moved to.
        """
        self.model_name = model_name
        self.rep_layer = rep_layer
        self.device = device

        # Hard-coded token indices.
        # NOTE(review): presumably these mirror the ESM-2 alphabet
        # (<cls>=0, <pad>=1, <eos>=2) -- confirm against the loaded alphabet.
        self.PAD_TOK = 1
        self.START_TOK = 0
        self.STOP_TOK = 2

        # Populated by load_model().
        self.esm_model = None
        self.esm_tokenizer = None

    def load_model(self):
        """Load the model named ``self.model_name`` and its batch converter.

        Sets ``self.esm_model`` (in eval mode, moved to ``self.device``) and
        ``self.esm_tokenizer`` (the alphabet's batch converter).
        """
        # TODO: preload model from local path for use in container
        # NOTE(review): load_model_and_alphabet_hub appears to perform this
        # download itself; confirm and drop this call if it is redundant.
        esm_pretrained._download_model_and_regression_data(self.model_name)

        model, alphabet = esm_pretrained.load_model_and_alphabet_hub(self.model_name)
        model.eval()
        model = model.to(self.device)

        self.esm_model = model
        self.esm_tokenizer = alphabet.get_batch_converter()

    def featurize(self, sequences: List[str], max_seq_len: int) -> Dict[str, Any]:
        """Tokenize sequences one by one and pad them into a single batch.

        A sequence that fails validation or tokenization does not abort the
        batch: the error is printed, a single ``STOP_TOK`` placeholder row is
        used in its place, and its entry in ``featurize_error`` is ``True``.

        Args:
            sequences: Sequences to tokenize. Requires ``self.esm_tokenizer``
                to be set (see ``load_model``).
            max_seq_len: Maximum allowed token count per sequence, counted
                on the tokenizer output (i.e. including the two extra tokens
                it adds around the sequence).

        Returns:
            A dict with:
              * ``"tokens"``: integer tensor of shape
                ``(len(sequences), longest_token_count)``, right-padded with
                ``PAD_TOK``.
              * ``"featurize_error"``: bool tensor of shape
                ``(len(sequences),)``.
        """
        esm_tokens = []
        featurize_error = []

        for seq_id, seq in enumerate(sequences):
            try:
                check_input(seq)

                _, _, tokens = self.esm_tokenizer([(seq_id, seq)])
                tokens = tokens.squeeze(0)

                if len(tokens) > max_seq_len:
                    raise ValueError(
                        f"Sequence length {len(tokens)} exceeds max_seq_len {max_seq_len} for sequence: {seq}"
                    )

                # Downstream code requires exactly one token per character
                # plus the two added boundary tokens. Checking here turns a
                # violation into a per-sequence error instead of a failure
                # of the whole batch after the forward pass.
                if len(tokens) != len(seq) + 2:
                    raise ValueError(
                        f"Token count {len(tokens)} does not equal sequence length {len(seq)} + 2 for sequence: {seq}"
                    )
            except Exception as e:
                # Deliberate best-effort: record the failure and keep going.
                print(f"Error processing {seq}: {e}")
                tokens = torch.tensor([self.STOP_TOK])
                failed = True
            else:
                failed = False

            esm_tokens.append(tokens)
            featurize_error.append(failed)

        if esm_tokens:
            # Right-pad every row with PAD_TOK up to the longest row.
            max_tokens = max(len(t) for t in esm_tokens)
            padded_tokens = torch.stack(
                [
                    torch.nn.functional.pad(
                        tokens,
                        (0, max_tokens - len(tokens)),
                        mode="constant",
                        value=self.PAD_TOK,
                    )
                    for tokens in esm_tokens
                ]
            )
        else:
            # max() and torch.stack() both reject empty input.
            padded_tokens = torch.empty((0, 0), dtype=torch.long)

        return {
            "tokens": padded_tokens,
            # Explicit dtype so an empty list still yields a bool tensor.
            "featurize_error": torch.tensor(featurize_error, dtype=torch.bool),
        }

    def predict(
        self, sequences: List[str], chunk_size=16, max_seq_len=2048
    ) -> List[ESM2IOSchema]:
        """Embed ``sequences`` in chunks, loading the model on first use.

        Args:
            sequences: Sequences to embed.
            chunk_size: Number of sequences per forward pass; must be >= 1.
            max_seq_len: Maximum token count per sequence (see ``featurize``).

        Returns:
            One ``ESM2IOSchema`` per input sequence, in input order.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        if self.esm_model is None or self.esm_tokenizer is None:
            self.load_model()

        output = []
        # tqdm takes the total from len(range(...)), which is the exact
        # number of chunks; the former `len // chunk_size + 1` over-counted
        # whenever the length was a multiple of chunk_size.
        for i in tqdm(
            range(0, len(sequences), chunk_size),
            desc="ESM2Infer.predict",
        ):
            chunk = sequences[i : i + chunk_size]
            output.extend(self._predict(chunk, max_seq_len))

        return output

    @torch.no_grad()
    def _predict(self, sequences: List[str], max_seq_len: int) -> List[ESM2IOSchema]:
        """Run one forward pass over ``sequences`` and build the results.

        Args:
            sequences: The sequences of a single chunk.
            max_seq_len: Maximum token count per sequence (see ``featurize``).

        Returns:
            One ``ESM2IOSchema`` per sequence. For a sequence flagged by
            ``featurize`` the record has ``error=True`` and no embedding.
            Otherwise ``embed`` is a NumPy array with ``len(sequence) + 2``
            rows (padding removed) and ``embed_mask`` is the boolean non-pad
            mask over the chunk's padded token length.

        Raises:
            AssertionError: If an embedding's row count is not
                ``len(sequence) + 2``.
        """
        if not sequences:
            return []

        featurized_inputs = self.featurize(sequences, max_seq_len=max_seq_len)
        tokens = featurized_inputs["tokens"].to(self.device)
        # Read the flags once as Python bools rather than calling .item()
        # on a device tensor for every sequence.
        featurize_error = featurized_inputs["featurize_error"].tolist()

        results = self.esm_model(tokens, repr_layers=[self.rep_layer])
        embeds = results["representations"][self.rep_layer]
        # Padded positions are dropped by the boolean indexing below, so
        # there is no need to zero them first.
        embeds_mask = tokens != self.PAD_TOK

        output = []
        for _seq, _embed, _embed_mask, _featurize_err in zip(
            sequences, embeds, embeds_mask, featurize_error
        ):
            if _featurize_err:
                # If there was an error in featurization, we skip this sequence
                output.append(
                    ESM2IOSchema(
                        sequence=_seq,
                        model_name=self.model_name,
                        rep_layer=self.rep_layer,
                        embed=None,
                        embed_mask=None,
                        error=True,
                        error_msg="Featurization error",
                    )
                )
                continue

            _unpad_embed = _embed[_embed_mask].cpu().numpy()
            # Explicit raise (not `assert`) so the check survives `python -O`.
            if _unpad_embed.shape[0] != len(_seq) + 2:
                raise AssertionError(
                    f"Embedding has {_unpad_embed.shape[0]} rows, expected {len(_seq) + 2}"
                )

            output.append(
                ESM2IOSchema(
                    sequence=_seq,
                    model_name=self.model_name,
                    rep_layer=self.rep_layer,
                    embed=_unpad_embed,
                    embed_mask=_embed_mask.cpu().numpy(),
                    error=False,
                    error_msg=None,
                )
            )

        return output
