import torch
import torch.nn as nn

import copy

# Small experiment checking how ``state_dict`` / ``load_state_dict`` interact
# with an optimizer step:
#   1. two independently initialised models disagree on the same input,
#   2. after loading model0's state into model1 they agree,
#   3. after one SGD step on model1 only, they disagree again (i.e. the
#      loaded parameters are model1's own and updating them leaves model0
#      untouched).
x = torch.randn((5,))
# Upstream gradient fed to ``backward``; it has the same shape as the
# models' 4-element output.
grad = torch.randn((4,))
model0 = nn.Sequential(nn.Linear(5, 4))
model1 = nn.Sequential(nn.Linear(5, 4))


# Stage 1: fresh random initialisations, so the outputs should differ.
y0 = model0(x)
y1 = model1(x)
print("expecting diffs")
print(y1 - y0)

# Stage 2: load model0's parameters into model1; the outputs should now match.
m0_dict = model0.state_dict()
model1.load_state_dict(m0_dict)
# NOTE: eval() only changes layers with train/eval-specific behaviour; this
# stack is a single Linear layer, so the outputs are unaffected by it.
model1.eval()
y0 = model0(x)
y1 = model1(x)
print("expecting sames")
print(y1 - y0)


# Stage 3: take one SGD step on model1 only.
opt = torch.optim.SGD(model1.parameters(), lr=1e-2)

# Clear the gradients of the parameters that are about to be stepped.  The
# optimizer owns model1's parameters, so zero through the optimizer rather
# than through model0, which takes no part in this backward pass.
opt.zero_grad()
y1.backward(gradient=grad)
opt.step()

# model1's parameters moved and model0's did not, so the outputs differ again.
y0 = model0(x)
y1 = model1(x)
print("expecting diffs")
print(y1 - y0)

