import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from models import get_model
from edge_of_stability_utilities import get_gd_optimizer, get_loss_and_acc, compute_losses, iterate_dataset, load_cifar



# ---------------------------------------------------------------------------
# Experiment configuration and warm-start setup.
# ---------------------------------------------------------------------------
# NOTE(review): `dataset` and `arch_id` are not read anywhere in this script. The
# model below is fetched with the literal "resnet18nobn" (no hyphen), which does
# not match `arch_id`; confirm which spelling get_model expects before using it.
dataset = "cifar10"
arch_id = "resnet18-nobn"
# Loss identifier passed to load_cifar / get_loss_and_acc. Named `loss_name` (not
# `loss`) so the training loops' per-batch loss tensors cannot overwrite it.
loss_name = "ce"
opt = "gd"
# Mini-batch size used only for memory-bounded evaluation / gradient accumulation;
# the optimizer itself takes one full-batch step per iteration.
physical_batch_size = 1000

train_dataset, test_dataset = load_cifar(loss_name)
loss_fn, acc_fn = get_loss_and_acc(loss_name)

torch.manual_seed(0)
network = get_model("resnet18nobn").cuda()

# Placeholder path: must be replaced with a checkpoint produced by train_double_lr.py.
# Loaded onto CPU; load_state_dict copies the tensors into the CUDA model.
checkpoint = torch.load("<PATH_TO_CHECKPOINT_FROM_WARMSTARTING_USING_train_double_lr.py>", map_location="cpu")
lr = 0.01

# checkpoint["model"] is indexed with [1] -- presumably a sequence of state dicts
# saved by the warm-start script; TODO confirm which entry index 1 is.
network.load_state_dict(checkpoint["model"][1])

optimizer = get_gd_optimizer(network.parameters(), opt, lr, None)

# Weight decay is applied by overriding every param group after construction.
for param_group in optimizer.param_groups:
    param_group['weight_decay'] = 5e-4


def _make_gd_optimizer(group_lr=None):
    """Return a fresh GD optimizer over ``network`` with weight decay 5e-4.

    The optimizer is built with the module-level ``lr``; when ``group_lr`` is given,
    every param group's learning rate is then overridden with it.
    """
    new_optimizer = get_gd_optimizer(network.parameters(), opt, lr, None)
    for param_group in new_optimizer.param_groups:
        param_group['weight_decay'] = 5e-4
        if group_lr is not None:
            param_group['lr'] = group_lr
    return new_optimizer


def _run_gd(gd_optimizer, num_steps, record_every=1):
    """Run ``num_steps`` full-batch GD steps on ``network`` with ``gd_optimizer``.

    Before each step the train and test loss/accuracy are evaluated and, on steps
    where ``step % record_every == 0``, the flattened parameters are copied to CPU.
    Gradients are accumulated over ``physical_batch_size`` chunks of the training
    set, each chunk's loss scaled by 1/len(train_dataset), and then a single
    optimizer step is taken. One tab-separated progress line is printed per step.

    Returns:
        (train_loss, test_loss, train_acc, test_acc, params): four float tensors of
        length ``num_steps`` and a list of CPU parameter vectors.
    """
    tr_loss, te_loss, tr_acc, te_acc = (torch.zeros(num_steps) for _ in range(4))
    params = []
    for step in range(num_steps):
        tr_loss[step], tr_acc[step] = compute_losses(network, [loss_fn, acc_fn], train_dataset,
                                                     physical_batch_size)
        te_loss[step], te_acc[step] = compute_losses(network, [loss_fn, acc_fn], test_dataset, physical_batch_size)
        if step % record_every == 0:
            # Snapshot is taken *before* this step's update.
            params.append(parameters_to_vector(network.parameters()).cpu().detach())
        gd_optimizer.zero_grad()
        for (X, y) in iterate_dataset(train_dataset, physical_batch_size):
            batch_loss = loss_fn(network(X.cuda()), y.cuda()) / len(train_dataset)
            batch_loss.backward()
        gd_optimizer.step()
        print(f"{step}\t{tr_loss[step]:.3f}\t{tr_acc[step]:.3f}\t{te_loss[step]:.3f}\t{te_acc[step]:.3f}")
    return tr_loss, te_loss, tr_acc, te_acc, params


# Phase 1: continue from the warm-start checkpoint with the optimizer built during
# setup (learning rate `lr`), snapshotting parameters at every step.
train_loss, test_loss, train_acc, test_acc, params_large_lr = _run_gd(optimizer, 400)

# Phase 2: restart from the last *recorded* phase-1 snapshot. Snapshots are taken
# before each update, so this discards phase 1's final step and makes
# params_small_lr2[0] equal to params_large_lr[-1].
# NOTE(review): despite the "small_lr2" naming, this phase uses the unscaled `lr`
# (same as phase 1), unlike phase 3 which uses lr * 0.1 -- confirm intended.
optimizer = _make_gd_optimizer()
vector_to_parameters(params_large_lr[-1].cuda(), network.parameters())
# Only every 10th snapshot is saved, so only every 10th is kept; storing all 1000
# full parameter vectors on CPU and discarding 90% of them wasted a lot of memory.
(train_loss_small_lr2, test_loss_small_lr2, train_acc_small_lr2, test_acc_small_lr2,
 params_small_lr2) = _run_gd(optimizer, 1000, record_every=10)

# Phase 3: restart from the warm-start checkpoint with a 10x smaller learning rate.
network.load_state_dict(checkpoint["model"][1])
optimizer = _make_gd_optimizer(group_lr=lr * 0.1)
(train_loss_small_lr, test_loss_small_lr, train_acc_small_lr, test_acc_small_lr,
 params_small_lr) = _run_gd(optimizer, 4000, record_every=10)

# params_small_lr2 already holds steps 0, 10, ..., 990 -- the same entries the
# original params_small_lr2[::10] slice selected.
torch.save([params_large_lr, params_small_lr, params_small_lr2], "resnet_gd_wd.pkl")