# Adapted from https://github.com/kuangliu/pytorch-cifar/tree/master
'''Train ResNet-18 on CIFAR-10 with PyTorch, using the Muon optimizer from ``muon_VR`` together with AdamW.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import torch.distributed as dist

import os
import argparse

import numpy as np
import random

from resnet import *
from utils import progress_bar

from muon_VR import Muon

# Best test accuracy (in percent) seen so far. test() raises it and saves a
# checkpoint whenever a new best is reached; main() restores it on --resume.
best_acc = 0

def init_distributed_mode(rank, world_size, master_addr='localhost', master_port='29501', backend=None):
    """Initialize the default ``torch.distributed`` process group via ``env://``.

    ``MASTER_ADDR`` / ``MASTER_PORT`` are only filled in when they are not
    already set in the environment, so values provided by a launcher or by the
    user are respected instead of being overwritten.

    Args:
        rank: Rank of this process within the group.
        world_size: Total number of processes in the group.
        master_addr: Fallback rendezvous address (used if ``MASTER_ADDR`` is unset).
        master_port: Fallback rendezvous port (used if ``MASTER_PORT`` is unset).
        backend: Communication backend. Defaults to ``'nccl'`` when CUDA is
            available and ``'gloo'`` otherwise (NCCL requires GPUs).
    """
    os.environ.setdefault('MASTER_ADDR', master_addr)
    os.environ.setdefault('MASTER_PORT', str(master_port))

    use_cuda = torch.cuda.is_available()
    if backend is None:
        backend = 'nccl' if use_cuda else 'gloo'

    dist.init_process_group(
        backend=backend,
        init_method='env://',  # read the rendezvous info from the environment
        world_size=world_size,
        rank=rank,
    )
    if use_cuda:
        # One GPU per process; assumes a single node, where rank == local GPU index.
        torch.cuda.set_device(rank)

# Single-node defaults: this process is rank 0 in a group of exactly one
# process. Change both values when launching several processes (e.g. one per GPU).
rank, world_size = 0, 1

# The default process group is created at import time, before main() runs.
init_distributed_mode(rank, world_size)

def _build_cifar10_loaders():
    """Create the CIFAR-10 train/test DataLoaders, downloading to ./data if needed.

    Returns:
        (trainloader, testloader). The train loader shuffles, uses batch size 128
        and random crop + horizontal flip augmentation; the test loader is
        unshuffled with batch size 100 and no augmentation.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True, num_workers=8)

    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=8)
    return trainloader, testloader


def _build_optimizers(model, args):
    """Split ``model``'s parameters between Muon and AdamW; return ``[muon, adamw]``.

    Parameters with ``ndim >= 2`` go to Muon, except the classifier weight
    ``model.linear.weight``. That weight and every parameter with fewer than
    two dimensions go to AdamW.
    """
    head_weight = model.linear.weight
    muon_params = [p for p in model.parameters() if p.ndim >= 2 and p is not head_weight]
    adamw_params = [p for p in model.parameters() if p.ndim < 2 or p is head_weight]

    muon_opt = Muon(muon_params, lr=args.lr, momentum=0.95, weight_decay=args.wd, rank=0, world_size=1)
    # NOTE(review): --b1/--b2 are parsed in main() but AdamW uses fixed betas here.
    adamw_opt = torch.optim.AdamW(adamw_params, lr=3e-4, betas=(0.90, 0.95), weight_decay=0.01)
    return [muon_opt, adamw_opt]


def main():
    """Train ResNet-18 on CIFAR-10 and save per-epoch metrics to an ``.npz`` file.

    CLI flags: ``--lr`` / ``--wd`` (Muon learning rate / weight decay),
    ``--b1`` / ``--b2`` (parsed but currently unused), ``--resume`` / ``-r``
    (reload ``./checkpoint/ckpt.pth``). Per-epoch train/test loss and accuracy
    are written to ``Muon++_clip=5_lr=<lr>_wd=<wd>.npz``.

    Raises:
        FileNotFoundError: ``--resume`` was given but ``checkpoint/`` is missing.
    """
    global best_acc

    num_epochs = 200  # epochs per run; also the cosine schedule length

    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []

    # Reproducibility: seed all RNGs and pin cuDNN to deterministic kernels.
    seed = 40  # or any other integer
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True
    cudnn.benchmark = False

    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
    parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
    parser.add_argument('--wd', default=1.0, type=float, help='weight decay')
    parser.add_argument('--b1', default=0.9, type=float, help='beta1')
    parser.add_argument('--b2', default=0.99, type=float, help='beta2')
    parser.add_argument('--resume', '-r', action='store_true',
                        help='resume from checkpoint')
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    start_epoch = 0  # 0, or the epoch after the one stored in the checkpoint

    print('==> Preparing data..')
    trainloader, testloader = _build_cifar10_loaders()

    print('==> Building model..')
    # Swap in another architecture here if desired.
    net = ResNet18().to(device)
    if device == 'cuda':
        # cudnn.benchmark deliberately stays False (set above): autotuning picks
        # kernels per run and would defeat the deterministic configuration.
        net = torch.nn.DataParallel(net)
    model = net.module if isinstance(net, torch.nn.DataParallel) else net

    optimizers = _build_optimizers(model, args)

    if args.resume:
        print('==> Resuming from checkpoint..')
        if not os.path.isdir('checkpoint'):
            raise FileNotFoundError('Error: no checkpoint directory found!')
        checkpoint = torch.load('./checkpoint/ckpt.pth', map_location=device)
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        # The stored epoch has already been trained; continue with the next one.
        start_epoch = checkpoint['epoch'] + 1

    criterion = nn.CrossEntropyLoss()

    # Only the Muon optimizer is scheduled; AdamW keeps a fixed learning rate.
    # NOTE(review): eta_min equals the default --lr (0.01), so with default flags
    # the Muon LR stays constant (and increases if --lr < 0.01). Confirm intent.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizers[0], T_max=num_epochs, eta_min=1e-2)

    for epoch in range(start_epoch, start_epoch + num_epochs):
        train_loss, train_acc = train(epoch, net, trainloader, device, criterion, optimizers)
        test_loss, test_acc = test(epoch, net, testloader, device, criterion)

        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        test_losses.append(test_loss)
        test_accuracies.append(test_acc)

        scheduler.step()

    filename = f"Muon++_clip=5_lr={args.lr}_wd={args.wd}.npz"
    np.savez(
        filename,
        train_losses=np.array(train_losses),
        train_accuracies=np.array(train_accuracies),
        test_losses=np.array(test_losses),
        test_accuracies=np.array(test_accuracies),
    )


def train(epoch, net, trainloader, device, criterion, optimizers):
    """Train ``net`` for one epoch with the variance-reduced Muon + AdamW pair.

    For every batch after the first in the epoch, the gradient of the *current*
    batch is also evaluated at the weights from *before* the previous optimizer
    step. Both gradient dicts (keyed by parameter tensor) are attached to
    ``optimizers[0]`` as ``prev_grads`` / ``curr_grads`` before it steps. On an
    epoch's first batch ``prev_grads`` is ``None``.

    Args:
        epoch: Epoch index (only printed).
        net: Model to train (possibly wrapped in ``DataParallel``).
        trainloader: Iterable of ``(inputs, targets)`` batches.
        device: Device the batches are moved to.
        criterion: Loss callable ``criterion(outputs, targets)``.
        optimizers: Two optimizers; ``optimizers[0]`` must accept the
            ``prev_grads`` / ``curr_grads`` attributes. All are stepped and
            zeroed every batch.

    Returns:
        (mean training loss per batch, training accuracy in percent). Accuracy
        is measured on the logits of the main forward pass of each batch.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    prev_params = None  # weights from before the previous optimizer step
    prev_grads = None

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Snapshot the weights this batch's update starts from.
        current_params = {p: p.detach().clone() for p in net.parameters() if p.requires_grad}

        if prev_params is not None:
            # Auxiliary pass at the previous weights. A train-mode forward also
            # updates buffers (e.g. BatchNorm running statistics), so snapshot
            # them and restore afterwards: this pass must not change model state.
            saved_buffers = [buf.detach().clone() for buf in net.buffers()]
            with torch.no_grad():
                for p, value in prev_params.items():
                    p.copy_(value)

            net.zero_grad()
            loss_prev = criterion(net(inputs), targets)
            loss_prev.backward()
            prev_grads = {p: p.grad.detach().clone() for p in net.parameters() if p.grad is not None}

            # Return to the current weights and buffers before the main pass.
            with torch.no_grad():
                for p, value in current_params.items():
                    p.copy_(value)
                for buf, value in zip(net.buffers(), saved_buffers):
                    buf.copy_(value)
            net.zero_grad()

        # Main pass: gradient of this batch at the current weights.
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        # NOTE(review): captured before clipping, so the dicts handed to the
        # optimizer hold unclipped gradients while p.grad is clipped — confirm
        # this is what Muon_VR expects.
        current_grads = {p: p.grad.detach().clone() for p in net.parameters() if p.grad is not None}

        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=5.0)

        optimizers[0].prev_grads = prev_grads
        optimizers[0].curr_grads = current_grads
        for opt in optimizers:
            opt.step()
        for opt in optimizers:
            opt.zero_grad(set_to_none=True)

        prev_params = current_params

        train_loss += loss.item()
        # Score the logits already computed above. A fresh net(inputs) here would
        # build an unused autograd graph and update BatchNorm statistics again.
        _, predicted = outputs.detach().max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    return train_loss / len(trainloader), 100. * correct / total


def test(epoch, net, testloader, device, criterion):
    """Evaluate ``net`` on ``testloader``; checkpoint it on a new best accuracy.

    Runs without gradient tracking in eval mode. When the accuracy beats the
    module-level ``best_acc``, saves ``{'net', 'acc', 'epoch'}`` to
    ``./checkpoint/ckpt.pth`` (creating the directory if needed) and updates
    ``best_acc``.

    Returns:
        (mean test loss per batch, test accuracy in percent).
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc
    return test_loss / len(testloader), acc


# Script entry point: start training only when this file is executed directly.
if __name__ == '__main__':
    main()