import lightning as L
import more_itertools as mit
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding
from torch import Tensor
from tqdm import tqdm

from tdc_fusion.models.tokenization import collate_strings


class MolT5(L.LightningModule):
    """Encoder-decoder transformer over token-id sequences.

    The encoder is a stack of ``TransformerEncoderLayer`` blocks and the decoder a
    stack of ``TransformerDecoderLayer`` blocks. Encoder and decoder use separate
    token embedding tables. Decoder outputs are layer-normalized and projected to
    vocabulary logits.
    """

    def __init__(
        self,
        vocab: dict[str, int],
        d_model: int = 128,
        enc_nhead: int = 8,
        enc_dim_ff: int = 512,
        enc_num_l: int = 6,
        dec_nhead: int = 8,
        dec_dim_ff: int = 256,
        dec_num_l: int = 6,
    ):
        """Build the model.

        Args:
            vocab: Mapping from token string to token id. ``"[PAD]"`` is required by
                the masking code; ``"[START]"``, ``"[STOP]"`` and ``"[MASK]"`` are
                looked up by the corresponding properties.
            d_model: Width of token embeddings and of every layer.
            enc_nhead: Attention heads per encoder layer.
            enc_dim_ff: Feed-forward hidden width per encoder layer.
            enc_num_l: Number of encoder layers.
            dec_nhead: Attention heads per decoder layer.
            dec_dim_ff: Feed-forward hidden width per decoder layer.
            dec_num_l: Number of decoder layers.
        """
        super().__init__()
        self.save_hyperparameters()

        self.vocab = vocab

        self.d_model = d_model
        # The single-element Sequential wrappers are kept so parameter names
        # (e.g. ``enc_tok_emb.0.weight``) stay unchanged for existing checkpoints.
        self.enc_tok_emb = nn.Sequential(
            nn.Embedding(len(vocab), d_model),
        )

        self.dec_tok_emb = nn.Sequential(
            nn.Embedding(len(vocab), d_model),
        )

        self.encoder = nn.ModuleList(
            [TransformerEncoderLayer(d_model, enc_nhead, enc_dim_ff) for _ in range(enc_num_l)]
        )
        self.decoder = nn.ModuleList(
            [TransformerDecoderLayer(d_model, dec_nhead, dec_dim_ff) for _ in range(dec_num_l)]
        )

        self.logit_norm = nn.LayerNorm(d_model)
        self.logit_proj = nn.Linear(d_model, len(vocab))

    def _pad_mask(self, tokens: Tensor) -> Tensor:
        """Return a boolean mask, True where ``tokens`` is not ``[PAD]``.

        The 2-D ``(batch, seq)`` comparison result is reshaped to
        ``(batch, 1, 1, seq)`` so it broadcasts over heads and query positions
        when used as an attention key mask.
        """
        keep = tokens != self.pad_tok
        return keep.reshape(keep.shape[0], 1, 1, keep.shape[1])

    def forward(self, src: Tensor, tgt: Tensor) -> Tensor:
        """Encode ``src`` and decode ``tgt`` against it, returning vocabulary logits.

        The cross-attention key mask is derived from ``src`` (the sequence the keys
        come from), so results stay correct even if ``src`` and ``tgt`` are padded
        differently. When both share the same padding this is identical to masking
        by ``tgt``.
        """
        enc = self.encode(src)
        return self.decode(tgt, enc, src_mask=self._pad_mask(src))

    def encode(self, src: Tensor) -> Tensor:
        """Run the encoder stack over source token ids; padding keys are masked out."""
        embed = self.enc_tok_emb(src)
        mask = self._pad_mask(src)

        for layer in self.encoder:
            embed = layer(embed, mask)

        return embed

    def decode(self, tgt: Tensor, enc: Tensor, src_mask: Tensor | None = None) -> Tensor:
        """Run the decoder stack and project to vocabulary logits.

        Args:
            tgt: Target token ids, ``(batch, tgt_len)``.
            enc: Encoder output used as cross-attention memory.
            src_mask: Optional boolean cross-attention key mask broadcastable to
                ``(batch, heads, tgt_len, src_len)``, True where a memory position
                may be attended to. When omitted, the mask is derived from ``tgt``;
                that fallback is only valid when source and target have the same
                length and padding layout.

        Returns:
            Logits over the vocabulary for every target position.
        """
        embed = self.dec_tok_emb(tgt)

        if src_mask is None:
            # Legacy behaviour: tgt mask and src mask are the same given how we are training.
            src_mask = self._pad_mask(tgt)

        for layer in self.decoder:
            embed = layer(embed, enc, src_mask)

        embed = self.logit_norm(embed)
        return self.logit_proj(embed)

    @torch.inference_mode()
    def encode_smiles(
        self,
        smiles: str | list[str],
        bsz: int = 512,
        kekulize: bool = True,
        disable_pbar: bool = False,
        maxlen: int = -1,
    ) -> Tensor:
        """Embed strings as mean-pooled encoder representations.

        Args:
            smiles: One string or a list of strings to encode.
            bsz: Number of strings per batch.
            kekulize: Forwarded to ``collate_strings``.
            disable_pbar: When True, the tqdm progress bar is disabled.
            maxlen: Forwarded to ``collate_strings``.

        Returns:
            Tensor of shape ``(len(smiles), d_model)`` on the model's device, with
            rows in the same order as the input. Each row is the mean of the
            encoder outputs over that string's non-padding tokens. An empty input
            list yields an empty ``(0, d_model)`` tensor.
        """
        if isinstance(smiles, str):
            smiles = [smiles]

        all_reps = torch.zeros((len(smiles), self.d_model), device=self.device)
        if not smiles:
            # Nothing to encode; avoids failing on an empty work list.
            return all_reps

        # Process longest strings first so each batch holds strings of similar length
        # (presumably to limit padding). Stable sort keeps ties in input order.
        order = sorted(range(len(smiles)), key=lambda i: len(smiles[i]), reverse=True)

        # Ceiling division: a trailing partial batch still counts as one step.
        n_batches = (len(smiles) + bsz - 1) // bsz

        for batch_indices in tqdm(
            mit.chunked(order, bsz),
            total=n_batches,
            leave=False,
            desc="Encoding",
            disable=disable_pbar,
        ):
            batch_smiles = [smiles[i] for i in batch_indices]
            tokens = collate_strings(batch_smiles, self.vocab, kekulize=kekulize, pad_to_mult=16, maxlen=maxlen).to(
                self.device
            )
            with torch.amp.autocast(device_type=self.device.type, dtype=torch.bfloat16):
                rep = self.encode(tokens)

            # Pool in full precision rather than in the autocast dtype.
            rep = rep.float()

            mask = tokens != self.pad_tok

            # Masked mean over the sequence axis: zero out padding, divide by real-token count.
            rep = rep * mask.unsqueeze(-1)
            rep = rep.sum(dim=1) / mask.sum(dim=1, keepdim=True)

            # Scatter back to the caller's original ordering.
            all_reps[batch_indices] = rep

        return all_reps

    @property
    def start_tok(self) -> int:
        """Token id of ``[START]``."""
        return self.vocab["[START]"]

    @property
    def stop_tok(self) -> int:
        """Token id of ``[STOP]``."""
        return self.vocab["[STOP]"]

    @property
    def pad_tok(self) -> int:
        """Token id of ``[PAD]``."""
        return self.vocab["[PAD]"]

    @property
    def mask_tok(self) -> int:
        """Token id of ``[MASK]``."""
        return self.vocab["[MASK]"]


class TransformerEncoderLayer(nn.Module):
    """Pre-norm encoder block: masked self-attention followed by a SwiGLU feed-forward.

    Both sub-layers are applied to a layer-normalized input and added back to the
    un-normalized stream (residual connection).
    """

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048) -> None:
        super().__init__()

        # Attention sub-layer.
        self.self_attn = SelfAttention(d_model, nhead)

        # Position-wise feed-forward sub-layer.
        self.ff_block = FFNSwiGLU(d_model, dim_feedforward)

        # One LayerNorm in front of each sub-layer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: Tensor,
        mask: Tensor,
    ) -> Tensor:
        """Apply the block to ``x``; ``mask`` is passed to self-attention as its attention mask."""
        attended = self.self_attn(self.norm1(x), mask)
        hidden = x + attended

        transformed = self.ff_block(self.norm2(hidden))
        return hidden + transformed


class TransformerDecoderLayer(nn.Module):
    """Pre-norm decoder block: causal self-attention, cross-attention, SwiGLU feed-forward.

    Each sub-layer reads a layer-normalized input and is added back to the residual
    stream. The memory tensor gets its own LayerNorm before cross-attention.
    """

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int):
        super().__init__()

        # Attention sub-layers.
        self.self_attn = SelfAttention(d_model, nhead)
        self.cross_attn = CrossAttention(d_model, nhead)

        # Position-wise feed-forward sub-layer.
        self.ff_block = FFNSwiGLU(d_model, dim_feedforward)

        # norm1: self-attn input, norm2: cross-attn queries,
        # norm3: cross-attn memory, norm4: feed-forward input.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)

    def forward(self, x: Tensor, mem: Tensor, mask: Tensor | None = None) -> Tensor:
        """Apply the block to ``x`` attending over ``mem``.

        ``mask`` is used only by cross-attention (as its attention mask);
        self-attention is always causal.
        """
        self_out = self.self_attn(self.norm1(x), causal=True)
        hidden = x + self_out

        queries = self.norm2(hidden)
        memory = self.norm3(mem)
        hidden = hidden + self.cross_attn(queries, memory, mask)

        return hidden + self.ff_block(self.norm4(hidden))


class CrossAttention(nn.Module):
    """Multi-head attention whose queries and keys/values come from different tensors.

    Queries are projected from ``query``; keys and values are produced by a single
    fused projection of ``kv``. Rotary embeddings are applied to queries and keys.
    """

    def __init__(self, embed_dim: int, heads: int):
        super().__init__()

        self.embed_size = embed_dim
        self.num_heads = heads
        self.head_dim = embed_dim // heads

        assert embed_dim == heads * self.head_dim, "Embedding size needs to be divisible by heads"

        # Separate query projection, fused key/value projection (split in forward).
        self.q_proj = nn.Linear(self.embed_size, self.embed_size)
        self.kv_proj = nn.Linear(self.embed_size, 2 * self.embed_size)

        self.out_proj = nn.Linear(heads * self.head_dim, embed_dim)

        # Rotary embedding is built over half of each head's width.
        self.rotary_emb = RotaryEmbedding(self.head_dim // 2)

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape ``(..., n, h*d)`` to ``(..., h, n, d)``."""
        return rearrange(t, "... n (h d) -> ... h n d", h=self.num_heads)

    def forward(
        self,
        query: Tensor,
        kv: Tensor,
        mask: Tensor | None = None,
    ) -> Tensor:
        """Attend from ``query`` positions over ``kv`` positions.

        ``mask`` is forwarded unchanged as ``attn_mask`` to
        ``F.scaled_dot_product_attention``. The result has the same leading and
        sequence dimensions as ``query`` with last dimension ``embed_dim``.
        """
        keys_in, values_in = self.kv_proj(kv).chunk(2, dim=-1)

        queries = self._split_heads(self.q_proj(query))
        keys = self._split_heads(keys_in)
        values = self._split_heads(values_in)

        rotate = self.rotary_emb.rotate_queries_or_keys
        queries = rotate(queries)
        keys = rotate(keys)

        context = F.scaled_dot_product_attention(queries, keys, values, attn_mask=mask)

        # Merge heads back into the feature dimension before the output projection.
        merged = rearrange(context, "... h n d -> ... n (h d)")
        return self.out_proj(merged)


class SelfAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and rotary embeddings.

    Supports either an explicit attention mask or causal masking, never both.
    """

    def __init__(self, embed_dim: int, heads: int):
        super().__init__()

        self.embed_size = embed_dim
        self.num_heads = heads
        self.head_dim = embed_dim // heads

        assert embed_dim == heads * self.head_dim, "Embedding size needs to be divisible by heads"

        # One projection yields queries, keys and values (split in forward).
        self.qkv_proj = nn.Linear(self.embed_size, 3 * self.embed_size)
        self.out_proj = nn.Linear(heads * self.head_dim, embed_dim)

        # Rotary embedding is built over half of each head's width.
        self.rotary_emb = RotaryEmbedding(self.head_dim // 2)

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape ``(..., n, h*d)`` to ``(..., h, n, d)``."""
        return rearrange(t, "... n (h d) -> ... h n d", h=self.num_heads)

    def forward(
        self,
        x: Tensor,
        mask: Tensor | None = None,
        causal: bool = False,
    ) -> Tensor:
        """Self-attend over ``x``.

        ``mask`` and ``causal`` are forwarded to ``F.scaled_dot_product_attention``
        as ``attn_mask`` and ``is_causal``; supplying a mask together with
        ``causal=True`` trips the assertion below. The output has the same shape
        as ``x``.
        """
        assert (not causal) or mask is None, "Causal and mask are mutually exclusive"

        q_in, k_in, v_in = self.qkv_proj(x).chunk(3, dim=-1)

        queries = self._split_heads(q_in)
        keys = self._split_heads(k_in)
        values = self._split_heads(v_in)

        rotate = self.rotary_emb.rotate_queries_or_keys
        queries = rotate(queries)
        keys = rotate(keys)

        context = F.scaled_dot_product_attention(queries, keys, values, attn_mask=mask, is_causal=causal)

        # Merge heads back into the feature dimension before the output projection.
        merged = rearrange(context, "... h n d -> ... n (h d)")
        return self.out_proj(merged)


class FFNSwiGLU(nn.Module):
    """
    GLU Variants Improve Transformer
    https://arxiv.org/abs/2002.05202
    """

    def __init__(
        self,
        dim: int,
        dim_feedforward: int,
        out_dim: int | None = None,
        use_bias: bool = True,
    ):
        super().__init__()

        out_dim = out_dim or dim

        self.ff1 = nn.Linear(dim, dim_feedforward * 2, bias=use_bias)
        self.ff2 = nn.Linear(dim_feedforward, out_dim, bias=use_bias)

    def forward(self, x: Tensor) -> Tensor:
        y, gate = self.ff1(x).chunk(2, dim=-1)
        x = y * F.silu(gate)
        return self.ff2(x)
