import torch
from src.core.base import BaseClient
from src.algorithms.utils import ce_loss


class SupervisedClient(BaseClient):
    """Client whose local training uses only a supervised loss on labeled data.

    Each ``train_step`` call performs a single optimizer update driven by
    ``ce_loss`` applied to the ``'logits'`` entry of the model's output.
    """

    def train_step(self, x_lb, y_lb):
        """Run one supervised optimization step on a labeled batch.

        Args:
            x_lb: Labeled inputs, passed unchanged to ``self.model``.
                (Presumably a batched tensor; confirm against callers.)
            y_lb: Targets for ``x_lb``, passed to ``ce_loss`` together with
                the model logits.

        Returns:
            dict: ``{'client_train/s_loss': float}``, the scalar value of the
            mean-reduced supervised loss for this batch.
        """
        optimizer = self.optimizer
        # Drop gradients left over from the previous step before backprop.
        optimizer.zero_grad()

        # The model returns a mapping; only its 'logits' entry feeds the loss.
        outputs = self.model(x_lb)
        loss = ce_loss(outputs['logits'], y_lb, reduction='mean')
        loss.backward()

        # Gradient-norm clipping is opt-in: a non-positive clip_grad skips it.
        if self.clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                max_norm=self.clip_grad,
            )

        optimizer.step()

        # .item() converts the 0-dim loss tensor to a plain Python number for logging.
        loss_value = loss.item()
        return {'client_train/s_loss': loss_value}
