import torch
import torch.nn as nn
from softsignsgd import SoftSignSGD
import time

# Parity/benchmark script for SoftSignSGD: trains the same tiny linear model
# once per (foreach, fused) combination, starting every run from identical
# initial weights, and prints the final parameters (for cross-variant
# comparison) and the wall-clock time of the training loop.

HALF_FLAG = False  # when True, model weights and input are cast to float16
DEVICE = torch.device('cuda:0')
NUM_STEPS = 100
# Positional hyperparameters forwarded unchanged to every SoftSignSGD variant.
OPTIM_ARGS = (0.3, 0.9, 1e-8, 0.2, 2.0)
# (foreach, fused) combinations, in the order they are benchmarked.
VARIANTS = [(False, False), (True, False), (False, True), (True, True)]

torch.manual_seed(0)

model = nn.Linear(10, 5)
if HALF_FLAG:
    model.half()
# Move to the GPU once, before any optimizer is constructed, so every
# optimizer variant is built over CUDA parameters from the start.
model.to(DEVICE)

# Snapshot the initial weights. state_dict() returns tensors that share
# storage with the live parameters, so clone them; otherwise the in-place
# optimizer updates would overwrite the snapshot and later variants would not
# start from the same weights.
model_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}

inputs = torch.randn((10, 10), device=DEVICE, dtype=torch.float)
if HALF_FLAG:
    inputs = inputs.half()

optimizers = [
    SoftSignSGD(model.parameters(), *OPTIM_ARGS, foreach=foreach, fused=fused)
    for foreach, fused in VARIANTS
]

# Warm up cuBLAS/autograd once so the first timed variant does not absorb
# one-off CUDA initialization costs. This does not modify the weights; the
# gradients it produces are discarded.
model(inputs).mean().backward()
model.zero_grad(set_to_none=True)
torch.cuda.synchronize(DEVICE)

for optim in optimizers:
    if HALF_FLAG and optim.defaults['foreach'] and optim.defaults['fused']:
        print('Using multi tensor softsignsgd kernel, assume single type across p,g now')
        continue
    model.load_state_dict(model_dict)
    # CUDA work is asynchronous: drain the queue before starting the clock...
    torch.cuda.synchronize(DEVICE)
    s_time = time.perf_counter()
    for _ in range(NUM_STEPS):
        output = model(inputs)
        loss = output.mean()
        optim.zero_grad()
        loss.backward()
        optim.step()
    # ...and wait for all queued kernels to finish before stopping it.
    torch.cuda.synchronize(DEVICE)
    e_time = time.perf_counter()
    print(list(model.parameters()))
    print(f'time during: {e_time - s_time} s')

