from loss import loss_function
from tqdm import tqdm




def train(args, model, train_data, optimizers_gen):
    """Train one model per class, each on that class's own data.

    Classes are processed sequentially. For class ``i``, ``model[i]`` is
    trained for ``args.num_epochs`` epochs over ``train_data[i]`` using
    ``optimizers_gen[i]``. Only then does the next class start.

    Args:
        args: Config object; reads ``num_classes``, ``num_epochs`` and
            ``device``.
        model: Indexable collection of per-class models. ``model[i](data)``
            must return something ``loss_function`` accepts alongside
            ``data``.
        train_data: Indexable collection of per-class iterables yielding
            ``(inputs, label)`` pairs; the labels are ignored.
        optimizers_gen: Indexable collection of per-class optimizers,
            presumably each bound to the parameters of the matching
            ``model[i]`` (not checked here).

    Returns:
        None. Models are updated in place through their optimizers.
    """
    for i in range(args.num_classes):
        print(f"Train Class Number {i}")
        # Resolve this class's model and optimizer once instead of per batch.
        # NOTE(review): net.train() is never called here. If these models
        # were switched to eval mode elsewhere, dropout/batch-norm layers
        # would stay in eval behaviour during training. Confirm with callers.
        net = model[i]
        optimizer = optimizers_gen[i]
        for _epoch in tqdm(range(args.num_epochs)):
            for data, _ in train_data[i]:
                data = data.to(device=args.device)

                # The loss compares the model output against its own input
                # (labels are discarded), presumably a reconstruction loss.
                recon_batch = net(data)
                loss = loss_function(recon_batch, data)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()