import torch
from tqdm import tqdm
import sys


def _token_ids(batch_size, num_tokens, device):
    """Return an int64 tensor of shape ``(batch_size, num_tokens)`` whose rows are ``0..num_tokens-1``.

    Built directly on ``device`` to avoid a host allocation plus host-to-device copy.
    """
    return torch.arange(num_tokens, device=device).repeat(batch_size, 1)


def kd_train(synthesizer, model, criterion, optimizer, num_tokens=197, max_grad_norm=10, device="cuda"):
    """Run one distillation pass of the student against the teacher over synthesized data.

    Args:
        synthesizer: object exposing ``get_data(labeled=True)``, an iterable of
            ``(images, labels)`` batches. Labels are not used here.
        model: ``(student, teacher)`` pair. Both are called as
            ``net(images, ids_a, ids_b)`` and must return a 3-tuple whose first
            element is the logits/output fed to ``criterion``.
        criterion: distillation loss, called as ``criterion(student_out, teacher_out)``.
        optimizer: optimizer over the student's parameters; stepped once per batch.
        num_tokens: length of the per-sample index rows passed to both models.
            The default 197 presumably corresponds to 196 patch tokens + 1 CLS
            token for a ViT at 224px / patch 16 -- TODO confirm against the models.
        max_grad_norm: total-norm threshold for student gradient clipping.
        device: device the batches are moved to (``"cuda"`` matches the old ``.cuda()``).

    Returns:
        None. The student is updated in place.
    """
    student, teacher = model
    student.train()
    teacher.eval()
    bar = tqdm(synthesizer.get_data(labeled=True), file=sys.stdout)
    for images, _labels in bar:
        optimizer.zero_grad()
        # Detach up front so no gradient flows back into whatever produced the
        # batch, and so we never mutate autograd flags on the synthesizer's tensor.
        images = images.detach().to(device)
        batch_size = images.shape[0]  # recomputed per batch: the last batch may be smaller

        # A fresh index tensor per argument keeps the inputs independent in case
        # either model edits them in place (the models' internals are not visible here).
        s_out, _, _ = student(images,
                              _token_ids(batch_size, num_tokens, images.device),
                              _token_ids(batch_size, num_tokens, images.device))
        with torch.no_grad():
            # Teacher is frozen: under no_grad its outputs carry no graph, so no detach is needed.
            t_out, _, _ = teacher(images,
                                  _token_ids(batch_size, num_tokens, images.device),
                                  _token_ids(batch_size, num_tokens, images.device))

        loss = criterion(s_out, t_out)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(parameters=student.parameters(), max_norm=max_grad_norm)
        optimizer.step()
