import torch.nn as nn
from sentence_transformers import SentenceTransformer

# Maps a supported model name to the substring used to pick which encoder
# parameters stay trainable: STClassifier.__init__ sets requires_grad=True on
# every parameter whose name contains the substring and freezes all others.
# NOTE(review): the substrings presumably select the last transformer layer
# ("layer.5") / the ALBERT layer group -- confirm against named_parameters().
WHERE_TO_REQUIRE_GRADS = {
    "all-MiniLM-L6-v2": "layer.5",
    "sentence-transformers/paraphrase-albert-small-v2": "albert_layer_groups.0.",
    "paraphrase-TinyBERT-L6-v2": "layer.5",
}


class STClassifier(nn.Module):
    """
    Sentence-transformer encoder followed by a single linear classifier head.

    On construction every encoder parameter is frozen except those whose name
    contains the ``WHERE_TO_REQUIRE_GRADS[model_name]`` substring, so only that
    parameter group (plus the head) is left trainable.

    The head ends in an activation (``Sigmoid`` when ``n_class == 1``,
    ``Softmax`` over dim 1 otherwise), so ``forward`` returns activated scores,
    not raw logits.
    """

    def __init__(
        self,
        model_name: str,
        model: "SentenceTransformer",
        device: str = "cuda",
        n_class: int = 1,
        normalize: bool = False,
    ):
        """
        Build the classifier around an already-instantiated encoder.

        Args:
            model_name: Key into ``WHERE_TO_REQUIRE_GRADS`` selecting which
                encoder parameters remain trainable.
            model: The sentence-transformer encoder; it is moved to ``device``.
            device: Device for the encoder, the head and tokenized inputs.
            n_class: Output width of the linear head. ``1`` gives a sigmoid
                output, any other value a softmax over dim 1.
            normalize: If true, embeddings are L2-normalized before being
                classified (``forward``) or returned (``encode``).

        Raises:
            KeyError: If ``model_name`` is not in ``WHERE_TO_REQUIRE_GRADS``.
            AssertionError: If no encoder parameter name matches the
                configured substring (nothing would be fine-tuned).
        """
        super().__init__()

        # Resolve the marker once, up front, instead of once per parameter,
        # and fail with an actionable message for unsupported models.
        try:
            trainable_marker = WHERE_TO_REQUIRE_GRADS[model_name]
        except KeyError:
            raise KeyError(
                f"Unsupported model_name {model_name!r}; expected one of "
                f"{sorted(WHERE_TO_REQUIRE_GRADS)}"
            ) from None

        self.model = model.to(device)

        any_trainable = False
        for name, param in self.model.named_parameters():
            is_trainable = trainable_marker in name
            param.requires_grad = is_trainable
            any_trainable = any_trainable or is_trainable

        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if not any_trainable:
            raise AssertionError("No group from ST was activated")

        embedding_dim = self.model.get_sentence_embedding_dimension()
        activation = nn.Sigmoid() if n_class == 1 else nn.Softmax(dim=1)
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, n_class),
            activation,
        ).to(device)

        self.device = device
        self.normalize = normalize

    def _maybe_normalize(self, x):
        """Return ``x`` L2-normalized along its last axis if enabled."""
        if not self.normalize:
            return x
        # dim=-1 equals the default dim=1 for 2-D batches, and also works for
        # a 1-D embedding, where the default raises IndexError.
        # NOTE(review): per the sentence-transformers docs, ``encode`` returns
        # a 1-D tensor for a single sentence -- the case this guards.
        return nn.functional.normalize(x, dim=-1)

    def forward(self, x):
        """
        Tokenize ``x``, embed it with the encoder and classify the embedding.

        Args:
            x: Raw input handed to ``self.model.tokenize`` (presumably a list
                of sentences -- confirm against callers).

        Returns:
            The output of the classifier head for each embedded input.
        """
        features = {
            key: value.to(self.device)
            for key, value in self.model.tokenize(x).items()
        }
        embeddings = self.model(features)["sentence_embedding"]
        return self.classifier(self._maybe_normalize(embeddings))

    def encode(self, x):
        """
        Return the encoder's embeddings for ``x`` without classifying them.

        Delegates to ``self.model.encode(x, convert_to_tensor=True)`` and
        applies L2 normalization when ``normalize`` is enabled.
        """
        embeddings = self.model.encode(x, convert_to_tensor=True)
        return self._maybe_normalize(embeddings)
