import torch
import torch.nn as nn
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets


# Demonstration: manually chaining backpropagation across a detach() boundary.
#
# The computation is split into two stages. Stage 1 builds y1 and y2 from the
# leaves x and w. Stage 2 operates on detached copies of those outputs, so
# autograd cannot flow from y back to x/w on its own; the gradients collected
# on the detached copies are fed back into stage 1 by hand.

# Leaf tensors that should receive gradients.
x = torch.tensor([3.0], requires_grad=True)
w = torch.tensor([2.0], requires_grad=True)

# Upstream gradient injected at the final output y.
grad = torch.tensor([5.0])

# Stage 1: two independent branches that both depend on x.
y1 = w * x
y2 = x + 1

# Cut the graph: the detached tensors are fresh leaves that stand in for
# y1 / y2 in stage 2 and collect d(y)/d(y1), d(y)/d(y2) in their .grad.
y1_d = y1.detach().requires_grad_(True)
y2_d = y2.detach().requires_grad_(True)

# Stage 2: consumes only the detached stand-ins.
y = y1_d + y2_d

# Backward through stage 2 only; populates y1_d.grad and y2_d.grad.
y.backward(gradient=grad)
print(y1_d.grad, y2_d.grad)

# Resume backward through stage 1, one branch at a time. Each branch must be
# seeded with the gradient of ITS OWN stand-in (y1 <- y1_d.grad,
# y2 <- y2_d.grad). The two happen to be equal here because y is a plain sum,
# but pairing them crosswise would give wrong results for any asymmetric
# stage 2.
y1.backward(gradient=y1_d.grad)
print(x.grad)  # contribution of the y1 branch only: w * 5 = 10

# .grad accumulates across backward calls, so x.grad now also includes the
# y2 branch: 10 + 5 = 15.
y2.backward(gradient=y2_d.grad)
print(x.grad)