import torch
from torch import nn
from typing import Callable
import torch.nn.functional as F

from configs import Config


# Type of a callable with the same signature as ``Encoder.forward``:
# (codebook, embeddings, pad_token_id) -> torch.Tensor.
CallableEncoder = Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]


class Encoder(nn.Module):
    """Base class for modules that encode codebook rows using an embedding table.

    Concrete subclasses implement :meth:`forward`. :meth:`from_config` builds
    the subclass named by ``config.embedding_encoder.embedding_encoder_name``.
    """

    def __init__(self, config: Config) -> None:
        super().__init__()
        self.config = config

    def forward(
        self, codebook: torch.Tensor, embeddings: torch.Tensor, pad_token_id: int
    ) -> torch.Tensor:
        """Encode ``codebook`` with ``embeddings``; must be overridden.

        Raises:
            NotImplementedError: always. Subclasses provide the real
                implementation; failing loudly avoids silently returning None.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement forward()")

    @staticmethod
    def from_config(config: Config) -> "Encoder":
        """Instantiate the encoder subclass selected by ``config``.

        Raises:
            ValueError: if ``embedding_encoder_name`` is not a known encoder.
        """
        classes = {
            "attention": AttentionEncoder,
            "transformer": TransformerEncoder,
        }

        name = config.embedding_encoder.embedding_encoder_name

        if name not in classes:
            raise ValueError(f"Invalid encoder type: {name}")

        return classes[name](config)


class AttentionEncoder(Encoder):
    """One bias-free multi-head self-attention layer per codebook row, mean-pooled.

    Hyperparameters are read from ``config.embedding_encoder.unsafe_config``:
    ``hidden_size`` (default 768) and ``num_heads`` (default 12). Learned
    position embeddings are allocated for ``config.compression.max_subtokens``
    positions.
    """

    def __init__(self, config: Config) -> None:
        super().__init__(config)

        self.hidden_size = int(
            config.embedding_encoder.unsafe_config.get("hidden_size", 768)
        )
        self.num_heads = int(
            config.embedding_encoder.unsafe_config.get("num_heads", 12)
        )

        self.position_embeddings = nn.Parameter(
            torch.randn(self.config.compression.max_subtokens, self.hidden_size)
            * self.hidden_size**-0.5
        )

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def forward(
        self, codebook: torch.Tensor, embeddings: torch.Tensor, pad_token_id: int
    ) -> torch.Tensor:
        """Encode each row of ``codebook`` into one vector.

        Args:
            codebook: 2-D token-id tensor ``(H, S)``; ``S`` must not exceed
                ``max_subtokens``.
            embeddings: table looked up by ``codebook``; its last dimension
                presumably equals ``hidden_size`` -- TODO confirm at call site.
            pad_token_id: id marking padding positions in ``codebook``.

        Returns:
            ``(H, hidden_size)`` tensor: attention output averaged over the
            non-padding positions of each row (zeros for an all-padding row).
        """
        H, S = codebook.shape
        valid = codebook != pad_token_id  # (H, S): True at real tokens

        codebook_embeddings = (
            F.embedding(codebook, embeddings, padding_idx=pad_token_id)
            + self.position_embeddings[:S]
        )
        queries = self.q_proj(codebook_embeddings)
        keys = self.k_proj(codebook_embeddings)
        values = self.v_proj(codebook_embeddings)

        # Mask padding *keys* only, broadcast over heads and queries. A
        # query-by-key mask leaves padded query rows with no allowed key,
        # which SDPA turns into NaN on some backends, and those NaNs reach the
        # pooled output. Rows with no real token at all are allowed to attend
        # everywhere so softmax stays finite; they are zeroed out below.
        key_mask = valid | ~valid.any(dim=-1, keepdim=True)
        output = F.scaled_dot_product_attention(
            queries.view(H, S, self.num_heads, -1).transpose(1, 2).contiguous(),
            keys.view(H, S, self.num_heads, -1).transpose(1, 2).contiguous(),
            values.view(H, S, self.num_heads, -1).transpose(1, 2).contiguous(),
            attn_mask=key_mask[:, None, None, :],
        )
        output = output.transpose(1, 2).contiguous().view(H, S, -1)
        output = self.o_proj(output)

        # Mean-pool over real tokens only so the encoding does not depend on
        # how much padding a row carries. masked_fill (not multiplication)
        # guarantees padded positions contribute exactly zero.
        output = output.masked_fill(~valid.unsqueeze(-1), 0.0)
        counts = valid.sum(dim=1, keepdim=True).clamp(min=1)
        return output.sum(dim=1) / counts


class MLP(nn.Module):
    """Bias-free feed-forward block computing ``fc2(gelu(fc1(x)))``.

    Sizes are read from ``config.embedding_encoder.unsafe_config``:
    ``hidden_size`` (default 768) and ``intermediate_size`` (default
    ``4 * hidden_size``).
    """

    def __init__(self, config: Config) -> None:
        super().__init__()

        unsafe_config = config.embedding_encoder.unsafe_config
        self.hidden_size = int(unsafe_config.get("hidden_size", 768))
        self.intermediate_size = int(
            unsafe_config.get("intermediate_size", self.hidden_size * 4)
        )

        self.fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward block to ``x`` (last dim ``hidden_size``)."""
        # Defined as ``forward`` rather than overriding ``__call__`` so that
        # nn.Module.__call__ still runs registered hooks.
        return self.fc2(F.gelu(self.fc1(x)))


class SelfAttention(nn.Module):
    """Bias-free multi-head self-attention on top of ``F.scaled_dot_product_attention``.

    ``hidden_size`` (default 768) and ``num_heads`` (default 12) come from
    ``config.embedding_encoder.unsafe_config``.
    """

    def __init__(self, config: Config) -> None:
        super().__init__()

        settings = config.embedding_encoder.unsafe_config
        self.hidden_size = int(settings.get("hidden_size", 768))
        self.num_heads = int(settings.get("num_heads", 12))

        width = self.hidden_size
        # Creation order (q, k, v, o) is kept so parameter init and
        # state_dict key order stay the same.
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` of shape ``(B, S, hidden_size)``.

        ``mask`` is passed unchanged as SDPA's ``attn_mask`` and must
        broadcast to ``(B, num_heads, S, S)``.
        """
        batch, seq_len, _ = x.size()

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, S, D) -> (B, heads, S, D // heads)
            return t.view(batch, seq_len, self.num_heads, -1).transpose(1, 2).contiguous()

        attended = F.scaled_dot_product_attention(
            to_heads(self.q_proj(x)),
            to_heads(self.k_proj(x)),
            to_heads(self.v_proj(x)),
            attn_mask=mask,
        )
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.o_proj(merged)


class Layer(nn.Module):
    """Post-norm transformer layer: attention then MLP, each as residual + LayerNorm.

    ``hidden_size`` (default 768) comes from
    ``config.embedding_encoder.unsafe_config``.
    """

    def __init__(self, config: Config) -> None:
        super().__init__()

        self.hidden_size = int(
            config.embedding_encoder.unsafe_config.get("hidden_size", 768)
        )

        # Submodule creation order is kept so parameter init and state_dict
        # key order are unchanged.
        self.mlp = MLP(config)
        self.attention = SelfAttention(config)
        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size)
        self.post_mlp_layernorm = nn.LayerNorm(self.hidden_size)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x``; ``mask`` is forwarded to the attention."""
        # Defined as ``forward`` rather than overriding ``__call__`` so that
        # nn.Module.__call__ still runs registered hooks.
        x = self.post_attention_layernorm(x + self.attention(x, mask))
        x = self.post_mlp_layernorm(x + self.mlp(x))
        return x


class TransformerEncoder(Encoder):
    """Stack of ``Layer`` modules over each codebook row, mean-pooled.

    Reads ``hidden_size`` (default 768) and ``num_hidden_layers`` (default 4)
    from ``config.embedding_encoder.unsafe_config``. Learned position
    embeddings are allocated for ``config.compression.max_subtokens``
    positions.
    """

    def __init__(self, config: Config) -> None:
        super().__init__(config)

        unsafe_config = config.embedding_encoder.unsafe_config
        self.hidden_size = int(unsafe_config.get("hidden_size", 768))
        num_hidden_layers = int(unsafe_config.get("num_hidden_layers", 4))

        self.position_embeddings = nn.Parameter(
            torch.randn(self.config.compression.max_subtokens, self.hidden_size)
            * self.hidden_size**-0.5
        )

        self.layers = nn.ModuleList([Layer(config) for _ in range(num_hidden_layers)])

    def forward(
        self, codebook: torch.Tensor, embeddings: torch.Tensor, pad_token_id: int
    ) -> torch.Tensor:
        """Encode each row of ``codebook`` into one vector.

        Args:
            codebook: 2-D token-id tensor ``(H, S)``; ``S`` must not exceed
                ``max_subtokens``.
            embeddings: table indexed by ``codebook``; its last dimension
                presumably equals ``hidden_size`` -- TODO confirm at call site.
            pad_token_id: id marking padding positions in ``codebook``.

        Returns:
            ``(H, hidden_size)`` tensor: final hidden states averaged over the
            non-padding positions of each row (zeros for an all-padding row).
        """
        seq_len = codebook.shape[-1]
        valid = codebook != pad_token_id  # (H, S): True at real tokens

        x = embeddings[codebook] + self.position_embeddings[:seq_len]

        # Mask padding *keys* only, broadcast over heads and queries. A
        # query-by-key mask leaves padded query rows with no allowed key, so
        # SDPA can emit NaN there. Those NaNs then flow through the padded
        # values into every position in later layers. Rows with no real token
        # attend everywhere to keep softmax finite; they are zeroed out below.
        key_mask = valid | ~valid.any(dim=-1, keepdim=True)
        mask = key_mask[:, None, None, :]

        for layer in self.layers:
            x = layer(x, mask)

        # Mean-pool over real tokens only; masked_fill guarantees padded
        # positions contribute exactly zero.
        x = x.masked_fill(~valid.unsqueeze(-1), 0.0)
        counts = valid.sum(dim=1, keepdim=True).clamp(min=1)
        return x.sum(dim=1) / counts
