import sys
sys.path.insert(0, "..")
import model, envs, torch

# Instantiate the Sumo environment and read off the two sizes that the policy
# networks below are constructed with.
env = envs.SumoEnv()
state_dim, nact = env.state_dim, env.nact

# Report the observation/action spaces of the object held in `env.env`,
# followed by the low/high bounds of every entry in its action space.
inner = env.env
print('obs', inner.observation_space, 'act', inner.action_space)
for space in inner.action_space:
    print('action range:', space.low, space.high)

def _new_policy():
    """Return a freshly constructed MLPControl sized for the environment."""
    return model.MLPControl(state_dim, nact, init_std=1.0)


# Write 8 pairs of newly constructed policies (presumably randomly
# initialised -- confirm in model.MLPControl) as state dicts to
# randx{i}.pt / randy{i}.pt, relative to the current working directory.
# `x` is always built before `y`, and once the loop finishes both names
# still refer to the last pair.
for i in range(8):
    x = _new_policy()
    y = _new_policy()
    torch.save(x.state_dict(), f"randx{i}.pt")
    torch.save(y.state_dict(), f"randy{i}.pt")

# Check the numerical range of the policy outputs: run the `x` and `y`
# policies against each other for 10 environment steps and print what each
# one returns. Gradients are not needed, so everything runs under no_grad.
with torch.no_grad():
    obs = env.reset()
    for t in range(10):
        # obs[0] is fed to policy x and obs[1] to policy y; [None] adds a
        # leading batch axis of size 1 before the forward pass.
        obs_x = torch.from_numpy(obs[0])[None]
        obs_y = torch.from_numpy(obs[1])[None]

        # Each policy returns (a, v), where a is a (mean, log_std) pair and
        # log_std is initialized to 0.
        a, v = x(obs_x)
        print(f't={t} x a={a[0]}, v={v[0]}')
        b, v = y(obs_y)
        print(f't={t} y a={b[0]}, v={v[0]}')

        # Step with the single batch row of each policy's mean; only the
        # first value returned by env.step (the next observation) is kept.
        act_x = a[0][0].numpy()
        act_y = b[0][0].numpy()
        obs, *_ = env.step(act_x, act_y)
