import torch
from copy import deepcopy

# Base checkpoint that every generated variant is copied from.
# map_location='cpu' remaps the stored tensors onto the CPU at load time.
s0 = torch.load('k553_epoch_0.pt', map_location='cpu')

# Show which entries the checkpoint holds. Output recorded from a previous run:
#   odict_keys(['cnn.0.weight', 'cnn.0.bias', 'cnn.2.weight', 'cnn.2.bias',
#               'cnn.4.weight', 'cnn.4.bias', 'policy.weight', 'policy.bias',
#               'value.weight', 'value.bias'])
print(s0.keys())
# d = {k:v for k,v in d0.items() if 'value' not in k}

def perturb(w, scale=0.2):
    """Add zero-mean Gaussian noise to ``w`` in place, sized relative to its RMS.

    The noise standard deviation is ``scale`` times the root-mean-square of
    the entries of ``w``, so the perturbation is proportional to the typical
    magnitude of the tensor rather than being an absolute amount.

    Args:
        w: Floating-point tensor to perturb; it is modified in place.
        scale: Noise standard deviation as a fraction of the RMS of ``w``.
            Defaults to 0.2, the value that used to be hard-coded.

    Returns:
        None. The result is written into ``w``.
    """
    rms = (w ** 2).mean().sqrt()
    # randn_like draws the noise with the same dtype and device as w, instead
    # of always producing a default-dtype CPU tensor.
    w.add_(torch.randn_like(w) * rms * scale)

# Generate 16 different initialization points from the base checkpoint,
# saved as '0.pt' through '15.pt' in the current working directory.
for i in range(16):
    # Work on a private copy so the base checkpoint s0 is never modified.
    s = deepcopy(s0)

    # Reset both head biases to zero (no randomness involved).
    s['policy.bias'].zero_()
    s['value.bias'].zero_()

    # Add RMS-relative noise to the 'cnn.4' weights and the policy weights.
    perturb(s['cnn.4.weight'])
    perturb(s['policy.weight'])

    # Replace the value weights with small Gaussian values (std 1e-4).
    s['value.weight'].copy_(1e-4 * torch.randn(s['value.weight'].shape))

    torch.save(s, '%d.pt' % i)
    print(i)
