from torch import nn
from transformers import AutoModel, AutoTokenizer

import torch

def get_molformer_tokenizer():
    """Load the tokenizer that ships with the MoLFormer-XL checkpoint.

    The slow (pure-Python) tokenizer is requested explicitly via
    ``use_fast=False``. ``trust_remote_code=True`` is passed so that the
    tokenizer class bundled with the checkpoint repository can be used.

    Returns:
        The object returned by ``AutoTokenizer.from_pretrained``.
    """
    checkpoint = "ibm/MoLFormer-XL-both-10pct"
    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        use_fast=False,
        trust_remote_code=True,
    )
    return tokenizer


class MolFormerRegressor(nn.Module):
    """Scalar regressor built on the pretrained MoLFormer-XL encoder.

    The pretrained encoder (``feature_extractor``) is loaded from the
    ``ibm/MoLFormer-XL-both-10pct`` checkpoint and cast to float32. Its
    ``pooler_output`` is cast to ``self.dtype`` and fed to a small
    two-layer head (``Linear -> activation -> Linear``) with one output unit.

    Args:
        tokenizer: Tokenizer to keep alongside the model. It is only stored
            on the instance; none of the methods here call it.
        n_last_hidden_units: Width of the hidden layer in the regression head.
        activation_fn: Zero-argument callable (typically an ``nn.Module``
            class) producing the activation placed between the two linear
            layers.
        dtype: Precision of the regression head. Either one of the names
            ``"float64"``, ``"float32"``, ``"bfloat16"``, ``"float16"`` or the
            corresponding ``torch.dtype`` object.

    Raises:
        KeyError: If ``dtype`` is not one of the supported precisions.
    """

    # Supported head precisions, keyed by the string accepted in ``__init__``.
    _DTYPE_BY_NAME = {
        "float64": torch.float64,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }

    def __init__(self, tokenizer, n_last_hidden_units=100, activation_fn=nn.ReLU, dtype="float64"):
        super().__init__()
        self.dtype = self._resolve_dtype(dtype)
        self.tokenizer = tokenizer
        # The backbone is always kept in float32, independent of the head
        # precision; ``forward`` casts the pooled features to ``self.dtype``.
        self.feature_extractor = AutoModel.from_pretrained(
            "ibm/MoLFormer-XL-both-10pct",
            deterministic_eval=True,
            trust_remote_code=True,
        ).float()

        self.feature_dim = self.feature_extractor.config.hidden_size
        self.head = nn.Sequential(
            nn.Linear(self.feature_dim, n_last_hidden_units, dtype=self.dtype),
            activation_fn(),
            nn.Linear(n_last_hidden_units, 1, dtype=self.dtype),
        )

    @classmethod
    def _resolve_dtype(cls, dtype):
        """Map a dtype name (or a supported ``torch.dtype``) to a ``torch.dtype``.

        Raises:
            KeyError: If ``dtype`` is neither a supported name nor one of the
                supported ``torch.dtype`` objects.
        """
        if isinstance(dtype, torch.dtype) and dtype in cls._DTYPE_BY_NAME.values():
            return dtype
        try:
            return cls._DTYPE_BY_NAME[dtype]
        except KeyError:
            supported = ", ".join(cls._DTYPE_BY_NAME)
            raise KeyError(
                f"unsupported dtype {dtype!r}; expected one of: {supported}"
            ) from None

    def forward(self, data):
        """Run the encoder and the regression head.

        Args:
            data: Mapping with ``"input_ids"`` and ``"attention_mask"``
                tensors (see ``forward_features``).

        Returns:
            Output of the regression head in ``self.dtype``; the last
            dimension has size 1.
        """
        feat = self.forward_features(data)
        # The backbone and the head may use different precisions.
        feat = feat.to(self.dtype)
        return self.head(feat)

    def forward_features(self, data):
        """Return the encoder's pooled output for a tokenized batch.

        The input tensors are moved to the device of this module's first
        parameter before the encoder is called.

        Args:
            data: Mapping with ``"input_ids"`` and ``"attention_mask"``
                tensors.

        Returns:
            The ``pooler_output`` attribute of the encoder's output.
        """
        input_ids, attn_mask = data["input_ids"], data["attention_mask"]
        device = next(self.parameters()).device
        input_ids = input_ids.to(device, non_blocking=True)
        attn_mask = attn_mask.to(device, non_blocking=True)
        # Keyword arguments keep the mask bound to the right parameter even if
        # the remote-code ``forward`` signature orders its arguments differently.
        outputs = self.feature_extractor(input_ids=input_ids, attention_mask=attn_mask)
        return outputs.pooler_output

    def freeze_params(self):
        """Disable gradients for every encoder parameter (the head is untouched)."""
        self.feature_extractor.requires_grad_(False)

    def unfreeze_params(self):
        """Re-enable gradients for every encoder parameter."""
        self.feature_extractor.requires_grad_(True)
