# Adapted from https://github.com/kuangliu/pytorch-cifar/tree/master
'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

import numpy as np
import random

from resnet import *
from utils import progress_bar

from lion_VR import Lion_VR


# Best test accuracy (percent) reached so far. test() raises it and writes a
# checkpoint whenever it is beaten; main() reloads it from the checkpoint on --resume.
best_acc = 0


def _seed_everything(seed):
    """Seed every RNG the script uses and force deterministic cuDNN kernels.

    Args:
        seed: Integer seed applied to ``random``, NumPy and torch (CPU + all GPUs).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True
    # Keep autotuning off: with benchmark=True cuDNN may select different
    # convolution algorithms from run to run, which undoes the seeding above.
    cudnn.benchmark = False


def _build_dataloaders():
    """Build the CIFAR-10 train/test loaders, downloading the data into ./data if needed.

    Returns:
        ``(trainloader, testloader)``: shuffled batches of 128 with random-crop and
        horizontal-flip augmentation, and unshuffled, unaugmented batches of 100.
    """
    # Per-channel mean/std normalization shared by both splits.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True, num_workers=8)

    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=8)

    return trainloader, testloader


def main():
    """Train ResNet18 on CIFAR-10 with Lion_VR and save the per-epoch metrics.

    Command-line options: ``--lr``, ``--wd``, ``--b1``, ``--b2`` (optimizer
    hyper-parameters), ``--resume`` (reload ``./checkpoint/ckpt.pth``),
    ``--seed`` (default 40) and ``--epochs`` (default 200; also the cosine
    schedule's ``T_max``). The best model is checkpointed by ``test()``; the
    loss/accuracy curves are written to ``Lion++_lr=<lr>_wd=<wd>.npz``.
    """
    global best_acc

    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []

    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
    parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
    parser.add_argument('--wd', default=0.1, type=float, help='weight decay')
    parser.add_argument('--b1', default=0.9, type=float, help='beta1')
    parser.add_argument('--b2', default=0.99, type=float, help='beta2')
    parser.add_argument('--resume', '-r', action='store_true',
                        help='resume from checkpoint')
    parser.add_argument('--seed', default=40, type=int, help='random seed')
    parser.add_argument('--epochs', default=200, type=int,
                        help='number of epochs to train (also cosine T_max)')
    args = parser.parse_args()

    _seed_everything(args.seed)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # Data
    print('==> Preparing data..')
    trainloader, testloader = _build_dataloaders()

    # Model
    print('==> Building model..')
    # Other architectures from the upstream pytorch-cifar repo (VGG, PreActResNet18,
    # DenseNet121, MobileNetV2, ...) can be swapped in here once imported.
    net = ResNet18()
    net = net.to(device)
    if device == 'cuda':
        # cudnn.benchmark is deliberately NOT re-enabled here; see _seed_everything.
        net = torch.nn.DataParallel(net)

    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        if not os.path.isdir('checkpoint'):
            raise FileNotFoundError('Error: no checkpoint directory found!')
        # map_location lets a checkpoint written on GPU be opened on a CPU-only host.
        # NOTE(review): state-dict keys still carry DataParallel's 'module.' prefix
        # when saved on GPU, so loading into an unwrapped CPU model needs key remapping.
        checkpoint = torch.load('./checkpoint/ckpt.pth', map_location=device)
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        # Only 'net', 'acc' and 'epoch' are stored: optimizer state and the LR
        # schedule restart from scratch, and the saved epoch is trained again.
        start_epoch = checkpoint['epoch']

    criterion = nn.CrossEntropyLoss()

    optimizer = Lion_VR(net.parameters(), betas=(args.b1, args.b2), lr=args.lr, weight_decay=args.wd)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    for epoch in range(start_epoch, start_epoch + args.epochs):
        train_loss, train_acc = train(epoch, net, trainloader, device, criterion, optimizer)
        test_loss, test_acc = test(epoch, net, testloader, device, criterion)

        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        test_losses.append(test_loss)
        test_accuracies.append(test_acc)

        scheduler.step()

    filename = f"Lion++_lr={args.lr}_wd={args.wd}.npz"
    np.savez(filename,
             train_losses=np.array(train_losses),
             train_accuracies=np.array(train_accuracies),
             test_losses=np.array(test_losses),
             test_accuracies=np.array(test_accuracies),
             )


# Training
def train(epoch, net, trainloader, device, criterion, optimizer):
    """Run one training epoch, giving the optimizer two gradients per batch.

    For every mini-batch the loss gradient is evaluated twice on the *same*
    batch: once at the weights from the previous step (``prev_grads``) and once
    at the current weights (``curr_grads``). Both dicts (keyed by parameter) are
    attached to ``optimizer`` as attributes before ``optimizer.step()``. On the
    first batch of the epoch there are no previous weights, so ``prev_grads``
    is ``None``.

    Args:
        epoch: Epoch index, used only for logging.
        net: Model to train (possibly wrapped in ``DataParallel``).
        trainloader: Iterable of ``(inputs, targets)`` batches.
        device: Device the batches are moved to.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        optimizer: Optimizer that reads ``prev_grads`` / ``curr_grads``
            attributes (presumably ``Lion_VR`` — confirm against lion_VR.py).

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` over the epoch. Loss and
        accuracy are both measured on the forward pass at the pre-step weights.

    Raises:
        ValueError: if ``trainloader`` yields no samples.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    # The set of trainable parameters does not change within the epoch.
    params = [p for p in net.parameters() if p.requires_grad]

    prev_params = None
    prev_grads = None

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        # --- Save current weights (restored after the auxiliary pass, and used
        # as the "previous" weights on the next batch) ---
        current_params = {p: p.data.clone() for p in params}

        if prev_params is not None:
            # The auxiliary pass runs in train mode, so BatchNorm would update its
            # running statistics (and num_batches_tracked) from stale-weight
            # activations. Snapshot all buffers so only prev_grads survives it.
            saved_buffers = [b.clone() for b in net.buffers()]

            # --- Load previous weights ---
            with torch.no_grad():
                for p in params:
                    p.data.copy_(prev_params[p])

            # --- Compute gradient at previous weights with current batch ---
            net.zero_grad()
            loss_prev = criterion(net(inputs), targets)
            loss_prev.backward()

            prev_grads = {p: p.grad.detach().clone() for p in net.parameters() if p.grad is not None}

            # --- Restore current weights and buffers ---
            with torch.no_grad():
                for p in params:
                    p.data.copy_(current_params[p])
                for buf, saved in zip(net.buffers(), saved_buffers):
                    buf.copy_(saved)

            net.zero_grad()

        # --- Compute gradient at current weights (main step) ---
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        curr_grads = {p: p.grad.detach().clone() for p in net.parameters() if p.grad is not None}

        optimizer.prev_grads = prev_grads
        optimizer.curr_grads = curr_grads

        # Clip the live .grad tensors. NOTE(review): prev_grads/curr_grads were
        # cloned before this point and are therefore unclipped — confirm that is
        # what Lion_VR expects.
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=5.0)

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        # --- Save weights for next iteration ---
        prev_params = current_params

        train_loss += loss.item()
        # Score the same forward pass that produced `loss`. A second forward here
        # would run in train mode, update BatchNorm statistics again and build an
        # unused autograd graph.
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    if total == 0:
        raise ValueError('trainloader yielded no samples')
    return train_loss / len(trainloader), 100. * correct / total


def test(epoch, net, testloader, device, criterion):
    """Evaluate ``net`` on ``testloader`` and checkpoint it on a new best accuracy.

    Runs in eval mode without autograd. When the accuracy beats the module-level
    ``best_acc``, saves ``{'net', 'acc', 'epoch'}`` to ``./checkpoint/ckpt.pth``
    (creating the directory if needed) and updates ``best_acc``.

    Args:
        epoch: Epoch index, stored in the checkpoint.
        net: Model to evaluate (possibly wrapped in ``DataParallel``).
        testloader: Iterable of ``(inputs, targets)`` batches.
        device: Device the batches are moved to.
        criterion: Loss function called as ``criterion(outputs, targets)``.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    if total == 0:
        raise ValueError('testloader yielded no samples')

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc
    return test_loss / len(testloader), acc


if __name__ == '__main__':
    # Start training only when run as a script; importing this module has no side effects.
    main()