import torch
from src.core.base import BaseServer
from src.algorithms.utils import ce_loss


class SupervisedServer(BaseServer):
    """Server whose training step is purely supervised (cross-entropy only)."""

    def train_step(self, optimizer, x_lb, y_lb):
        """Perform one optimisation step on a labeled batch.

        The model is put in training mode, the batch is forwarded, the mean
        cross-entropy between the returned logits and ``y_lb`` is
        back-propagated, gradients are optionally norm-clipped, and the
        optimizer is stepped.

        Args:
            optimizer: Optimizer to zero, step, and read the learning rate
                from (first parameter group only).
            x_lb: Labeled inputs, passed unchanged to ``self.model``.
            y_lb: Targets, passed unchanged to ``ce_loss``.

        Returns:
            dict: Two entries keyed by ``"<mode>_train/s_loss"`` (the loss as
            a Python number) and ``"<mode>_train/lr"`` (the learning rate of
            the first parameter group), where ``<mode>`` is ``self.mode``.

        Note:
            ``self.model``, ``self.clip_grad`` and ``self.mode`` are not
            defined in this class; presumably they come from ``BaseServer``
            -- TODO confirm.
        """
        self.model.train()
        optimizer.zero_grad()

        # The model is expected to return a mapping with a 'logits' entry.
        outputs = self.model(x_lb)
        logits = outputs['logits']
        loss = ce_loss(logits, y_lb, reduction='mean')
        loss.backward()

        # Clipping is opt-in: a non-positive ``clip_grad`` disables it.
        max_norm = self.clip_grad
        if max_norm > 0:
            # The call returns the total gradient norm, which is handy for
            # debugging; it is intentionally discarded here.
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), max_norm=max_norm
            )

        optimizer.step()

        # Only the first parameter group's learning rate is reported.
        first_group = optimizer.param_groups[0]
        current_lr = first_group['lr']

        metrics = {}
        metrics[f"{self.mode}_train/s_loss"] = loss.item()
        metrics[f"{self.mode}_train/lr"] = current_lr
        return metrics
