import torch
import torch.nn as nn
from transformers import AutoModel
import numpy as np
from tqdm import tqdm

class BertRegressor(nn.Module):
    """Predict a fixed-size vector from text using a BERT encoder plus an MLP head.

    The final hidden state at the first sequence position (the [CLS] token
    for BERT-style tokenizers) is passed through a two-layer MLP:
    ``hidden_size -> hidden_size -> output_dim``, with a ReLU in between.

    Args:
        bert_name: Model name or path handed to ``AutoModel.from_pretrained``.
        output_dim: Length of the predicted vector.
    """

    def __init__(self, bert_name="bert-base-uncased", output_dim=4096):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_name)
        width = self.bert.config.hidden_size
        # The attribute names ("bert", "mlp") and the layer order inside the
        # head determine the state_dict keys, so keep them stable for
        # checkpoint compatibility.
        head_layers = [
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, output_dim),
        ]
        self.mlp = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Encode the batch with BERT and regress from the first-position embedding."""
        hidden_states = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # Select position 0 along the sequence axis for every example.
        first_token = hidden_states[:, 0]
        return self.mlp(first_token)


def cosine_similarity(a, b, dim=1):
    """Return the mean cosine similarity between ``a`` and ``b`` as a Python float.

    Similarity is computed along ``dim`` (default 1, i.e. row-wise for 2-D
    inputs) and then averaged over all remaining elements.

    This is a metric only. It runs under ``torch.no_grad()``, so calling it on
    tensors that require grad (e.g. training predictions) does not record an
    autograd graph.

    Args:
        a: First tensor.
        b: Second tensor, broadcastable against ``a``.
        dim: Dimension along which the similarity is computed.

    Returns:
        The mean similarity as a ``float``.
    """
    with torch.no_grad():
        return torch.nn.functional.cosine_similarity(a, b, dim=dim).mean().item()


def run_epoch(model, dataloader, loss_fn, device, optimizer=None):
    """Run one training, validation, or testing pass over ``dataloader``.

    If ``optimizer`` is given, the model is put in training mode and an
    optimizer step is taken after every batch. Otherwise the model is put in
    evaluation mode and gradients are disabled.

    Args:
        model: Module called as ``model(**inputs)``.
        dataloader: Iterable of ``(inputs, targets)`` pairs, where ``inputs``
            is a dict of tensors. Any ``"token_type_ids"`` entry is dropped
            before the forward pass. It does not need to define ``__len__``.
        loss_fn: Callable ``loss_fn(preds, targets)`` returning a scalar tensor.
        device: Device that inputs and targets are moved to.
        optimizer: Optimizer to step in training mode; ``None`` for evaluation.

    Returns:
        ``(avg_loss, avg_cos)``: the unweighted mean over batches of the loss
        and of the batch-mean cosine similarity between predictions and
        targets. Both are ``float("nan")`` if the dataloader yields no batches.
    """
    train_mode = optimizer is not None
    model.train(train_mode)  # train(False) is equivalent to eval()

    if train_mode:
        # Drop any gradients left on the parameters before this call so they
        # cannot leak into the first optimizer step.
        optimizer.zero_grad()

    total_loss, cos_sims = 0.0, []
    desc = "Training..." if train_mode else "Evaluating..."
    for inputs, targets in tqdm(dataloader, desc=desc):
        inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"}
        targets = targets.to(device)

        with torch.set_grad_enabled(train_mode):
            preds = model(**inputs)
            loss = loss_fn(preds, targets)
            if train_mode:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

        total_loss += loss.item()
        # Metric only: detach so no autograd graph is built for it.
        cos_sims.append(cosine_similarity(preds.detach(), targets))

    # Count batches as they are seen instead of calling len(dataloader), which
    # fails for iterable-style loaders and divides by zero when empty.
    n_batches = len(cos_sims)
    if n_batches == 0:
        nan = float("nan")
        return nan, nan

    avg_loss = total_loss / n_batches
    avg_cos = np.mean(cos_sims)
    return avg_loss, avg_cos
