import re

import lightning as L
import more_itertools as mit
import torch
from tqdm import tqdm
from transformers import T5EncoderModel, T5Tokenizer

# Import-time, process-wide side effect: selects PyTorch's "medium" float32 matmul
# precision setting (per the PyTorch docs this permits reduced-precision internals in
# exchange for speed). Any code importing this module inherits the setting.
torch.set_float32_matmul_precision("medium")


class ProtT5(L.LightningModule):
    """ProtT5 encoder wrapper that maps protein sequences to fixed-size embeddings.

    Loads the ``Rostlab/prot_t5_xl_half_uniref50-enc`` tokenizer and encoder
    (weights in bfloat16, model switched to eval mode) and exposes
    :meth:`encode_proteins`, which returns one mean-pooled vector per sequence.
    """

    def __init__(self):
        super().__init__()

        self.tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
        self.model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.bfloat16)
        self.model.eval()

        # Width of each pooled embedding returned by encode_proteins; it must match the
        # last dimension of the encoder's hidden states.
        self.d_model = 1024

    @torch.inference_mode()
    def encode_proteins(
        self,
        proteins: str | list[str],
        bsz: int = 32,
        disable_pbar: bool = False,
        maxlen: int = 512,
    ) -> torch.Tensor:
        """Embed one or more protein sequences with masked mean pooling.

        Args:
            proteins: A single sequence or a list of sequences. A bare string is
                treated as a one-element list.
            bsz: Number of sequences per forward pass.
            disable_pbar: If true, the tqdm progress bar is suppressed.
            maxlen: Maximum number of residues kept per sequence (passed to
                ``sanitize_sequence``, which truncates longer sequences).

        Returns:
            A float32 tensor of shape ``(len(proteins), d_model)`` on ``self.device``.
            Row ``i`` is the embedding of ``proteins[i]`` (input order is preserved
            even though batches are processed in length-sorted order). An empty
            input yields a ``(0, d_model)`` tensor.
        """
        if isinstance(proteins, str):
            proteins = [proteins]

        n_proteins = len(proteins)
        all_reps = torch.zeros((n_proteins, self.d_model), device=self.device)
        if n_proteins == 0:
            # Nothing to encode; the unzip of the sorted pairs below would fail on
            # an empty list, so return the empty result directly.
            return all_reps

        # Sort longest-first while remembering each sequence's original position, so
        # that every batch holds sequences of similar length (padding="longest" then
        # pads less) and results can be written back in input order.
        sorted_pairs = sorted(zip(proteins, range(n_proteins)), key=lambda x: len(x[0]), reverse=True)

        # Ceiling division: a trailing partial batch still counts as one iteration.
        n_batches = (n_proteins + bsz - 1) // bsz

        for batch_pairs in tqdm(
            mit.chunked(sorted_pairs, bsz),
            total=n_batches,
            leave=False,
            desc="Encoding",
            disable=disable_pbar,
        ):
            batch_proteins = [sanitize_sequence(prot, maxlen=maxlen) for prot, _ in batch_pairs]
            batch_indices = [orig_idx for _, orig_idx in batch_pairs]

            ids = self.tokenizer.batch_encode_plus(batch_proteins, add_special_tokens=True, padding="longest")
            input_ids = torch.tensor(ids["input_ids"]).to(self.device)
            attention_mask = torch.tensor(ids["attention_mask"]).to(self.device)

            rep = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

            # Masked mean over the sequence dimension: zero out padded positions, then
            # divide by the number of attended tokens. Computed in float32 because the
            # encoder runs in bfloat16. Any special tokens the tokenizer adds
            # (add_special_tokens=True) that carry attention_mask == 1 are included.
            rep = rep.float() * attention_mask.unsqueeze(-1)
            rep = rep.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)

            # Scatter the batch back to the rows matching the original input order.
            all_reps[batch_indices] = rep

        return all_reps


def sanitize_sequence(sequence: str, maxlen: int = 512) -> str:
    """Return *sequence* as space-separated residues, cleaned and truncated.

    Every occurrence of the uppercase letters ``U``, ``Z``, ``O`` and ``B`` is
    replaced by ``X``; the result is cut to at most ``maxlen`` characters and
    the remaining characters are joined with single spaces.

    >>> sanitize_sequence("MKUB", maxlen=3)
    'M K X'
    """
    # The substitution is one-character-for-one, so slicing the substituted string is
    # the same as slicing a list of its characters.
    residues = re.sub(r"[UZOB]", "X", sequence)[:maxlen]
    return " ".join(residues)
