import os
import sys
import torch
import torch.nn as nn
import open_clip


class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> act -> Dropout -> Linear -> Dropout.

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: width of the hidden layer; a falsy value (``None``,
            ``0``) falls back to ``in_features``.
        out_features: size of the last dimension of the output; a falsy value
            falls back to ``in_features``.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability, applied both after the activation and
            after the second linear layer (one shared ``nn.Dropout``).
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        hidden_dim = in_features if not hidden_features else hidden_features
        output_dim = in_features if not out_features else out_features

        # Creation order (fc1, act, fc2, drop) is kept fixed so parameter
        # initialisation consumes the RNG identically and state_dict keys match.
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the block to ``x``; only the last dimension is transformed."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class LangModule(nn.Module):
    """Encode text descriptions with an open_clip text encoder and project them.

    The open_clip model ``convnext_large_d_320`` is created with weights loaded
    from ``<pretrained_weights_path>/open_clip_pytorch_model.bin``. Its
    ``encode_text`` output is passed through an :class:`Mlp` head that maps it
    to ``out_features`` dimensions.

    Args:
        out_features: output width of the projection head.
        pretrained_weights_path: directory containing
            ``open_clip_pytorch_model.bin``. Required; the ``None`` default is
            kept only for signature compatibility.

    Raises:
        ValueError: if ``pretrained_weights_path`` is ``None``.
    """

    def __init__(
        self,
        out_features=256,
        pretrained_weights_path=None,
    ):
        super().__init__()
        if pretrained_weights_path is None:
            # Without this check os.path.join(None, ...) fails with an opaque TypeError.
            raise ValueError(
                "LangModule requires `pretrained_weights_path`: the directory "
                "containing 'open_clip_pytorch_model.bin'."
            )
        self.out_features = out_features
        model_path = os.path.join(
            pretrained_weights_path,
            "open_clip_pytorch_model.bin",
        )
        # open_clip.create_model returns a full CLIP model; only its text path
        # (encode_text) is used in forward.
        self.text_encoder = open_clip.create_model(
            "convnext_large_d_320",
            pretrained=model_path,
        )
        # NOTE(review): eval() only sets the mode once. A later .train() call on
        # this module (or any parent) switches the encoder back to train mode.
        # Its parameters also still require grad, because nothing here freezes
        # them. Freeze explicitly if the encoder is meant to stay fixed — confirm
        # the intended behaviour.
        self.text_encoder.eval()

        # Assumes text_projection is a (width, embed_dim) tensor/Parameter, so
        # shape[1] is the width of encode_text's output — TODO confirm for
        # other open_clip model configs.
        embed_dim = self.text_encoder.text_projection.shape[1]
        self.Mlp = Mlp(
            in_features=embed_dim,
            hidden_features=embed_dim,
            out_features=self.out_features,
        )

    def forward(self, text_tokens):
        """Encode the input descriptions, one batch element at a time.

        Args:
            text_tokens: sequence of length ``batch_size``. Each element is a
                token tensor of shape ``(n_samples, seq_len)``, presumably
                produced by the open_clip tokenizer (confirm with caller).
                Elements are encoded separately and returned as a list rather
                than a stacked tensor.

        Returns:
            list of tensors, one per element of ``text_tokens``, each of shape
            ``(n_samples, out_features)``.
        """
        return [
            self.Mlp(self.text_encoder.encode_text(tokens))
            for tokens in text_tokens
        ]
