import torch 
from torch.utils.data import DataLoader
import torch.nn.functional as F

# --- Gradient for Taylor ---
def get_grads(
        model: torch.nn.Module, 
        dataloader: DataLoader, 
        device: torch.device,
        text_logits = None
        ) -> None:
    """
    Populate ``.grad`` on the model parameters from one batch of data.

    Runs one forward and one backward pass with a Cross-Entropy loss on the
    FIRST batch yielded by ``dataloader``. No optimizer step is taken, so
    the weights are not updated by this function itself; only gradients
    are produced. Gradients already stored on the model are cleared before
    the backward pass, so the result reflects this batch alone.

    Side effects: the model is switched to training mode and moved to
    ``device``; neither is restored afterwards.

    Args:
        model (torch.nn.Module): Model whose gradients are to be computed.
            When ``text_logits`` is given, the model must additionally
            provide ``encode_image(images, normalize=True)`` and a
            ``logit_scale`` attribute (CLIP-style interface).
        dataloader (DataLoader): Iterable yielding ``(images, labels)``
            batches. Only the first batch is consumed.
        device (torch.device): Device to run the forward and backward pass on.
        text_logits (torch.Tensor, optional): Precomputed text-side tensor
            for CLIP models. It is right-multiplied onto the image features,
            so it is presumably shaped (feature_dim, num_classes) -- confirm
            against the caller. If None, ``model(images)`` is used directly
            as the classification logits.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Fetch the single batch up front so that an empty loader fails loudly
    # instead of returning with stale (or missing) gradients on the model.
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        raise ValueError(
            "dataloader yielded no batches; cannot compute gradients"
        )
    images, labels = first_batch
    images, labels = images.to(device), labels.to(device)

    model.train()
    model.to(device)
    # Drop gradients left over from earlier calls so they do not accumulate.
    model.zero_grad()

    if text_logits is not None:
        # Keep every operand of the matmul on the same device as the model.
        text_logits = text_logits.to(device)
        image_features = model.encode_image(images, normalize=True)
        # `*` and `@` share precedence and group left-to-right; the
        # parentheses only make the existing evaluation order explicit.
        outputs = (model.logit_scale.exp() * image_features) @ text_logits
    else:
        outputs = model(images)

    loss = F.cross_entropy(outputs, labels)
    loss.backward()