import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class MolCLIP(L.LightningModule):
    """CLIP-style contrastive model over paired molecule / text embeddings.

    Two MLP heads with an identical layout project molecule and text
    embeddings into a shared ``emb_dim`` space. The training objective is the
    symmetric cross-entropy computed by ``clip_loss`` with a learnable
    temperature that is stored in log-space (see ``t``).

    Args:
        mol_dim: Size of the last dimension of incoming molecule embeddings.
        text_dim: Size of the last dimension of incoming text embeddings.
        emb_dim: Size of the shared output embedding space.
        normalize_output: If true, projected embeddings are L2-normalized and
            the temperature starts at 0.1; otherwise they are left
            unnormalized and the temperature starts at 0.5.
        hdim: Hidden width of both heads; defaults to ``2 * emb_dim``.
        dropout: Dropout probability applied before the final linear layer.
        lr: Peak AdamW learning rate, reached at the end of warmup.
        weight_decay: AdamW weight decay.
        warmup_steps: Number of optimizer steps of linear learning-rate
            warmup. Values ``<= 0`` disable warmup.
    """

    def __init__(
        self,
        mol_dim: int,
        text_dim: int,
        emb_dim: int,
        normalize_output: bool = False,
        hdim: int | None = None,
        dropout: float = 0.0,
        *,
        lr: float = 1e-4,
        weight_decay: float = 0.01,
        warmup_steps: int = 4096,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.mol_dim = mol_dim
        self.text_dim = text_dim
        self.emb_dim = emb_dim
        self.hdim = hdim if hdim is not None else 2 * emb_dim

        self.lr = lr
        self.weight_decay = weight_decay
        self.warmup_steps = warmup_steps

        # The text head is built before the molecule head so that parameter
        # initialization consumes the RNG in a fixed order.
        self.text_proj = self._make_projector(
            text_dim, self.hdim, emb_dim, dropout, normalize_output
        )
        self.mol_proj = self._make_projector(
            mol_dim, self.hdim, emb_dim, dropout, normalize_output
        )

        # The temperature is stored as log(t) so that `t = exp(_t)` is always
        # strictly positive.
        init_t = 0.1 if normalize_output else 0.5
        self._t = nn.Parameter(torch.tensor(init_t).log())

    @staticmethod
    def _make_projector(
        in_dim: int,
        hdim: int,
        emb_dim: int,
        dropout: float,
        normalize_output: bool,
    ) -> nn.Sequential:
        """Build one projection head: two Linear+SiLU stages, dropout, output Linear, UnitNorm."""
        return nn.Sequential(
            nn.Linear(in_dim, hdim),
            nn.SiLU(),
            nn.Linear(hdim, hdim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hdim, emb_dim),
            UnitNorm(normalize_output=normalize_output),
        )

    def project_mol(self, mol_emb: Tensor, featureize_mol: bool = False) -> Tensor:
        """Project molecule embeddings with the molecule head.

        The input is L2-normalized along its last dimension first.

        Args:
            mol_emb: Molecule embeddings whose last dimension is ``mol_dim``.
            featureize_mol: If true, return the ``hdim``-wide hidden
                activations instead of the final ``emb_dim`` embedding.

        Returns:
            The projected embedding, or the hidden features when
            ``featureize_mol`` is set.
        """
        mol_emb = F.normalize(mol_emb, p=2, dim=-1)

        if featureize_mol:
            # Dropping the last three modules skips dropout, the final
            # projection to `emb_dim`, and the output normalization; only the
            # two Linear+SiLU stages run.
            return self.mol_proj[:-3](mol_emb)

        return self.mol_proj(mol_emb)

    def project_text(self, text_emb: Tensor) -> Tensor:
        """L2-normalize text embeddings along the last dimension and project them with the text head."""
        text_emb = F.normalize(text_emb, p=2, dim=-1)
        return self.text_proj(text_emb)

    def forward(
        self, mol_emb: Tensor, text_emb: Tensor
    ) -> tuple[Tensor, dict[str, Tensor]]:
        """Project both modalities and compute the contrastive loss.

        Returns:
            ``(loss, metrics)`` as returned by ``clip_loss``, with the current
            temperature added to ``metrics`` under the key ``"t"``.
        """
        mol_emb = self.project_mol(mol_emb)
        text_emb = self.project_text(text_emb)

        loss, metrics = clip_loss(
            mol_emb=mol_emb,
            text_emb=text_emb,
            t=self.t,
        )
        # Detached so the logged temperature does not hold on to the graph.
        metrics["t"] = self.t.detach()

        return loss, metrics

    def _shared_step(
        self, batch: tuple[Tensor, Tensor], stage: str, sync_dist: bool
    ) -> Tensor:
        """Run the forward pass on a ``(mol_emb, text_emb)`` batch and log metrics as ``<stage>/<name>``."""
        mol_emb, text_emb = batch
        loss, metrics = self.forward(mol_emb, text_emb)

        self.log_dict(
            {f"{stage}/{k}": v for k, v in metrics.items()},
            prog_bar=True,
            sync_dist=sync_dist,
        )
        return loss

    def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Compute the training loss and log metrics under the ``train/`` prefix."""
        return self._shared_step(batch, stage="train", sync_dist=False)

    def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Compute the validation loss and log metrics under ``val/`` with ``sync_dist=True``."""
        return self._shared_step(batch, stage="val", sync_dist=True)

    def configure_optimizers(self):  # type: ignore
        """Create AdamW plus a per-step linear-warmup learning-rate schedule.

        Returns:
            A dict holding the optimizer and a scheduler config that steps
            the scheduler once per optimizer step.
        """
        # NOTE(review): `self.parameters()` includes the log-temperature `_t`,
        # so weight decay is applied to it as well — confirm this is intended.
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            betas=(0.9, 0.99),
            weight_decay=self.weight_decay,
        )

        warmup_steps = self.warmup_steps

        def warmup_factor(step: int) -> float:
            # `step + 1` keeps the very first optimizer step from running
            # with a learning rate of exactly zero, and makes the factor hit
            # 1.0 on step number `warmup_steps` rather than one step later.
            if warmup_steps <= 0:
                return 1.0
            return min(1.0, (step + 1) / warmup_steps)

        sched = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_factor)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": sched,
                "interval": "step",
                "frequency": 1,
            },
        }

    @property
    def t(self) -> Tensor:
        """Current temperature, i.e. ``exp`` of the learnable log-temperature ``_t``."""
        return self._t.exp()


class UnitNorm(nn.Module):
    """Optional L2 normalization along the last dimension.

    With ``normalize_output=False`` the module is a pass-through, which lets
    it sit at the end of a ``nn.Sequential`` unconditionally.
    """

    def __init__(self, normalize_output: bool = True):
        super().__init__()
        # Fixed at construction time; decides whether forward() rescales.
        self.normalize_output = normalize_output

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled to unit L2 norm over ``dim=-1``, or ``x`` itself when disabled."""
        if not self.normalize_output:
            return x
        return F.normalize(x, p=2, dim=-1)


@torch.compile
def clip_loss(
    mol_emb: Tensor,
    text_emb: Tensor,
    t: Tensor,
) -> tuple[Tensor, dict[str, Tensor]]:
    """Symmetric contrastive (CLIP-style) loss over a batch of paired embeddings.

    Row ``i`` of ``mol_emb`` is treated as the positive match for row ``i``
    of ``text_emb``; every other row in the batch acts as a negative.

    Args:
        mol_emb: Projected molecule embeddings, one row per sample.
        text_emb: Projected text embeddings, one row per sample.
        t: Temperature that the similarity matrix is divided by.

    Returns:
        ``(loss, metrics)`` where ``loss`` is the mean of the
        molecule-to-text and text-to-molecule cross-entropies, and
        ``metrics`` has the keys ``"loss"``, ``"mol_acc"`` and ``"text_acc"``.
    """
    device = mol_emb.device

    # Similarities and both cross-entropies are evaluated in float32 with
    # autocast switched off, whatever context the caller is running in.
    with torch.amp.autocast(device_type=device.type, enabled=False):
        mol_to_text = (mol_emb.float() @ text_emb.t().float()) / t.float()
        text_to_mol = mol_to_text.t()

        # The matching pair for row i sits on the diagonal.
        targets = torch.arange(mol_emb.shape[0], device=device)
        mol_loss = F.cross_entropy(mol_to_text, targets)
        text_loss = F.cross_entropy(text_to_mol, targets)

        # Average the two retrieval directions.
        total_loss = (mol_loss + text_loss) / 2.0

    # Top-1 retrieval accuracy in each direction; metrics only, no gradient.
    with torch.no_grad():
        mol_acc = (mol_to_text.argmax(dim=1) == targets).float().mean()
        text_acc = (text_to_mol.argmax(dim=1) == targets).float().mean()

    metrics = {"loss": total_loss, "mol_acc": mol_acc, "text_acc": text_acc}
    return total_loss, metrics
