import warnings
from pathlib import Path

import lightning as L
import torch
from sentence_transformers import SentenceTransformer
from tqdm import trange

from tdc_fusion.models.MolCLIP import MolCLIP
from tdc_fusion.models.MolT5 import MolT5

# NOTE(review): this installs a global "ignore" filter as a side effect of
# importing this module, silencing *all* warnings for the whole process, not
# just those raised by the models below. Consider scoping it with
# ``warnings.catch_warnings()`` around the specific noisy calls instead.
warnings.filterwarnings("ignore")


class MolCLIPWrapper(L.LightningModule):
    """Frozen inference wrapper around MolT5, MolCLIP and a sentence-transformer.

    SMILES strings are encoded with ``MolT5.encode_smiles`` and then passed
    through ``MolCLIP.project_mol``; free text is encoded with a
    ``SentenceTransformer`` and passed through ``MolCLIP.project_text``.
    ``batch_inference`` scores every molecule against every text.
    """

    def __init__(
        self,
        text_embedder: str,
        clip_cpkt: Path | str = "./checkpoints/ChemCLIP.ckpt",
        t5_cpkt: Path | str = "./checkpoints/MolT5.ckpt",
    ):
        """Load all sub-models and freeze them.

        Args:
            text_embedder: Model name or path handed to ``SentenceTransformer``;
                it is loaded with ``torch_dtype="bfloat16"``.
            clip_cpkt: Checkpoint path for ``MolCLIP``. (The ``cpkt`` spelling
                is part of the public signature and saved hyperparameters, so
                it is kept as-is.)
            t5_cpkt: Checkpoint path for ``MolT5``.
        """
        super().__init__()
        # Stores the constructor arguments in ``self.hparams``.
        self.save_hyperparameters()

        self.clip = MolCLIP.load_from_checkpoint(clip_cpkt)
        self.clip.freeze()

        self.t5 = MolT5.load_from_checkpoint(t5_cpkt)
        self.t5.freeze()

        self.text_embedder = SentenceTransformer(text_embedder, model_kwargs={"torch_dtype": "bfloat16"})
        self.text_embedder.eval()

    def _to_clip(self, tensor: torch.Tensor) -> torch.Tensor:
        """Cast ``tensor`` to the CLIP head's dtype and move it to the CLIP head's device."""
        return tensor.to(dtype=self.clip.dtype, device=self.clip.device)  # type: ignore

    def encode_smiles(
        self,
        smiles: str | list[str],
        bsz: int = 512,
        kekulize: bool = True,
        disable_pbar: bool = False,
        maxlen: int = -1,
        skip_clip: bool = False,
        featureize_mol: bool = False,  # Don't project into CLIP space, return high dim rep after GELU
    ) -> torch.Tensor:
        """Encode one or more SMILES strings.

        Args:
            smiles: A single SMILES string or a list of SMILES strings.
            bsz: Batch size forwarded to ``MolT5.encode_smiles``.
            kekulize: Forwarded to ``MolT5.encode_smiles``.
            disable_pbar: Forwarded to ``MolT5.encode_smiles``.
            maxlen: Forwarded to ``MolT5.encode_smiles``.
            skip_clip: If True, return the MolT5 encodings unchanged (no cast,
                no CLIP projection).
            featureize_mol: Forwarded to ``MolCLIP.project_mol``. Per the
                inline note, this returns the high-dimensional representation
                after GELU instead of projecting into CLIP space.

        Returns:
            The raw MolT5 encodings when ``skip_clip`` is True, otherwise the
            output of ``MolCLIP.project_mol`` computed on the CLIP head's
            device.
        """
        if isinstance(smiles, str):
            smiles = [smiles]

        t5_encodings = self.t5.encode_smiles(
            smiles,
            bsz=bsz,
            kekulize=kekulize,
            disable_pbar=disable_pbar,
            maxlen=maxlen,
        )

        if skip_clip:
            return t5_encodings

        return self.clip.project_mol(self._to_clip(t5_encodings), featureize_mol=featureize_mol)

    def project_mols(
        self,
        embeddings: torch.Tensor,
        featureize_mol: bool = False,
        bsz: int = 512,
        disable_pbar: bool = False,
    ) -> torch.Tensor:
        """Project precomputed molecule encodings through the CLIP head in batches.

        This is the projection step of ``encode_smiles`` applied batch-wise to
        encodings already in memory, e.g. the output of
        ``encode_smiles(..., skip_clip=True)``.

        Args:
            embeddings: Encodings to project; batches are sliced along dim 0.
            featureize_mol: Forwarded to ``MolCLIP.project_mol``.
            bsz: Number of rows projected per batch (default 512, the
                previously hard-coded value).
            disable_pbar: If True, hide the tqdm progress bar.

        Returns:
            The projected embeddings concatenated along dim 0, on the CPU.

        Raises:
            ValueError: If ``bsz`` is not a positive integer.
        """
        if bsz <= 0:
            raise ValueError(f"bsz must be a positive integer, got {bsz}")

        n_rows = embeddings.shape[0]
        if n_rows == 0:
            # torch.cat rejects an empty list; project the empty tensor once so
            # the result keeps the projection's output feature dimension.
            return self.clip.project_mol(self._to_clip(embeddings), featureize_mol=featureize_mol).cpu()

        # Each batch is cast to the CLIP head's dtype as well as its device,
        # matching encode_smiles/encode_text; otherwise encodings whose dtype
        # differs from the CLIP weights would fail inside the projection.
        outputs = [
            self.clip.project_mol(self._to_clip(embeddings[i : i + bsz]), featureize_mol=featureize_mol)
            for i in trange(0, n_rows, bsz, leave=False, disable=disable_pbar)
        ]
        return torch.cat(outputs, dim=0).cpu()

    def encode_text(
        self,
        text: str | list[str],
        bsz: int = 512,
        skip_clip: bool = False,
    ) -> torch.Tensor:
        """Encode one or more text strings.

        Args:
            text: A single string or a list of strings.
            bsz: ``batch_size`` passed to ``SentenceTransformer.encode``.
            skip_clip: If True, return the sentence-transformer embeddings
                (cast to the CLIP head's dtype/device) without projecting them.

        Returns:
            The sentence-transformer embeddings on the CLIP head's dtype and
            device when ``skip_clip`` is True, otherwise the output of
            ``MolCLIP.project_text``.
        """
        if isinstance(text, str):
            text = [text]  # Ensure it's a list

        raw_embeddings = self.text_embedder.encode(
            text,
            convert_to_tensor=True,
            batch_size=bsz,
        )
        # The text embedder is loaded in bfloat16 (see __init__), so its output
        # may not match the CLIP head's dtype; align dtype and device first.
        text_embeddings = self._to_clip(raw_embeddings)

        if skip_clip:
            return text_embeddings

        return self.clip.project_text(text_embeddings)

    def batch_inference(
        self,
        smiles: list[str],
        text: list[str],
    ) -> torch.Tensor:
        """Score every molecule in ``smiles`` against every string in ``text``.

        Both sides are embedded via ``encode_smiles`` / ``encode_text`` and
        upcast to float32 before the matrix product.

        Args:
            smiles: SMILES strings (rows of the result).
            text: Text strings (columns of the result).

        Returns:
            ``mol_embeddings @ text_embeddings.T / self.clip.t`` in float32.
            Assuming both encoders return ``(N, D)`` embeddings, the result has
            shape ``(len(smiles), len(text))``. NOTE(review): ``self.clip.t``
            is presumably a learned temperature; confirm in ``MolCLIP``.
        """
        mol_embeddings = self.encode_smiles(smiles).float()
        text_embeddings = self.encode_text(text).float()

        logits = torch.matmul(mol_embeddings, text_embeddings.t()) / self.clip.t.float()

        return logits
