import torch

def collect_partial_gradients(net1, net2, clf, train_loader, device):
    """Collect per-sample loss gradients w.r.t. each bottom network's output.

    For every batch, the outputs of ``net1`` and ``net2`` are concatenated
    along dim 1, detached, and fed to ``clf``. The cross-entropy loss is
    back-propagated only as far as that concatenated tensor, and its gradient
    is split back into the part belonging to ``net1``'s output and the part
    belonging to ``net2``'s output. Each part is flattened to one row per
    sample.

    Args:
        net1: Model producing the first slice of the concatenated activation.
        net2: Model producing the second slice of the concatenated activation.
        clf: Model mapping the concatenated activation to class scores.
        train_loader: Iterable yielding ``(inputs_a, inputs_b, targets)``
            batches of tensors.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        A tuple ``(partial_grad_1, partial_grad_2, labels)`` of NumPy arrays.
        The two gradient arrays have one flattened row per sample; ``labels``
        holds the targets in the order the loader yielded them.

    Raises:
        ValueError: If ``train_loader`` yields no batches.

    Note:
        Parameter gradients of ``clf``, ``net1`` and ``net2`` are cleared
        after every batch via ``zero_grad()``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    grads_1 = []
    grads_2 = []
    label_batches = []

    for inputs_a, inputs_b, targets in train_loader:
        inputs_a = inputs_a.to(device)
        inputs_b = inputs_b.to(device)
        targets = targets.to(device)

        # The activations are detached below, so no autograd graph is needed
        # for the two bottom networks.
        with torch.no_grad():
            x_a = net1.forward(inputs_a)  # keep on site-1_a
            x_b = net2.forward(inputs_b)  # keep on site-1_b
        # Fresh leaf tensor: backward() stops here and fills x_sent.grad.
        x_sent = torch.cat((x_a, x_b), dim=1).detach().requires_grad_()

        pred = clf.forward(x_sent)
        loss = criterion(pred, targets)
        loss.backward()
        return_grad = x_sent.grad.detach()

        # Split where net1's channels end rather than at the midpoint, so the
        # attribution stays correct when the two networks differ in width.
        split = x_a.shape[1]
        # Move each batch to the CPU right away so device memory does not
        # grow with the size of the dataset.
        grads_1.append(return_grad[:, :split].flatten(start_dim=1).cpu())
        grads_2.append(return_grad[:, split:].flatten(start_dim=1).cpu())
        label_batches.append(targets.detach().cpu())

        clf.zero_grad()
        net1.zero_grad()
        net2.zero_grad()

    if not label_batches:
        raise ValueError("train_loader yielded no batches")

    # Concatenate once at the end instead of re-copying on every iteration.
    partial_grad_1 = torch.cat(grads_1, dim=0).numpy()
    partial_grad_2 = torch.cat(grads_2, dim=0).numpy()
    labels = torch.cat(label_batches, dim=0).numpy()
    return partial_grad_1, partial_grad_2, labels


