import torch
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
from sentence_transformers import SentenceTransformer


class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text per index on demand.

    Args:
        texts: Indexable collection of inputs handed to the tokenizer.
        tokenizer: Callable invoked as ``tokenizer(text, **kwargs)``; its
            result must expose ``.items()`` yielding values that support
            ``.squeeze(0)``.
        cfg: Configuration object; ``cfg.llm.max_length`` sets the fixed
            length every sample is padded / truncated to.
    """

    def __init__(self, texts, tokenizer, cfg):
        self.cfg = cfg
        self.max_length = cfg.llm.max_length
        self.tokenizer = tokenizer
        self.texts = texts

    def __len__(self):
        """Return the number of texts held by the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize the text at ``idx`` and return its fields as a dict.

        The tokenizer is asked for ``'pt'`` tensors padded to ``max_length``;
        the leading dimension of every returned value is squeezed away.
        """
        tokenizer_kwargs = {
            'return_tensors': 'pt',
            'padding': 'max_length',
            'truncation': True,
            'max_length': self.max_length,
        }
        encoded = self.tokenizer(self.texts[idx], **tokenizer_kwargs)

        sample = {}
        for name, tensor in encoded.items():
            # Drop the leading axis the tokenizer adds for a single input.
            sample[name] = tensor.squeeze(0)
        return sample


from transformers import BertModel, BertTokenizer


class TextEncoder(torch.nn.Module):
    """Encode raw texts into hidden states using a SentenceTransformer checkpoint.

    The tokenizer and underlying ``auto_model`` are taken from the first
    module of the loaded ``SentenceTransformer`` and called directly, so the
    output is the model's ``last_hidden_state``.

    Attributes set in ``__init__``:
        model: The loaded ``SentenceTransformer``, moved to ``device``.
        tokenizer: Tokenizer of the SentenceTransformer's first module.
        encoder: ``auto_model`` of the SentenceTransformer's first module.
        device: ``torch.device`` built from ``cfg.training.device``.
        scaler: Gradient scaler for mixed-precision training; enabled only
            when ``device`` is a CUDA device.
    """

    def __init__(self, cfg):
        """Load the model named by ``cfg.llm.model_name`` onto ``cfg.training.device``.

        Args:
            cfg: Configuration object providing ``cfg.llm.model_name`` and
                ``cfg.training.device``.
        """
        super().__init__()
        self.device = torch.device(cfg.training.device)

        # Registration order (``model`` first, then ``encoder``) is kept so
        # the keys of ``state_dict()`` stay the same as before.
        self.model = SentenceTransformer(cfg.llm.model_name)
        self.model.to(self.device)

        first_module = self.model._first_module()
        self.tokenizer = first_module.tokenizer
        self.encoder = first_module.auto_model

        # Only enable loss scaling on CUDA. Off CUDA an enabled scaler either
        # warns and disables itself or rejects non-CUDA tensors in scale();
        # disabled, it is a transparent no-op.
        self.scaler = torch.cuda.amp.GradScaler(
            enabled=self.device.type == 'cuda'
        )

    def forward(self, texts, seq_len, return_pooled=False):
        """Tokenize ``texts`` and return the encoder's last hidden state.

        Args:
            texts: Input passed straight to the tokenizer.
            seq_len: Forwarded as the tokenizer's ``max_length``; longer
                inputs are truncated, and the batch is padded dynamically
                (``padding=True``).
            return_pooled: If true, return only index 0 along the second
                axis of the hidden states.

        Returns:
            ``last_hidden_state`` of the encoder, or ``last_hidden_state[:, 0]``
            when ``return_pooled`` is true.
            NOTE(review): presumably (batch, seq, hidden), with index 0 being
            the first ([CLS]-style) token -- confirm for the configured model.
        """
        tokenized = self.tokenizer(
            texts,
            max_length=seq_len,
            truncation=True,
            padding=True,
            return_tensors="pt"
        )
        tokenized = {k: v.to(self.device) for k, v in tokenized.items()}

        # CUDA autocast only affects CUDA ops. Requesting it while running on
        # another device just emits a warning on CUDA-less hosts, so it is
        # enabled only when the configured device is CUDA.
        use_amp = self.device.type == 'cuda'
        with torch.amp.autocast(device_type='cuda', enabled=use_amp):
            outputs = self.encoder(**tokenized).last_hidden_state

        if return_pooled:
            return outputs[:, 0]
        return outputs