from __future__ import print_function
import os
import argparse
import time
import ssl
import pandas as pd
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from zen import ZenGrad_M
from resnet import ResNet32  
from lion_pytorch import Lion 
# NOTE(review): this swaps the stdlib's default HTTPS context factory for one
# that skips certificate verification, for the whole process. Presumably added
# so the CIFAR-10 download succeeds in an environment with broken/intercepted
# certificates; it also weakens every other stdlib HTTPS request this process
# makes. Consider removing it once the dataset is cached under the data root.
ssl._create_default_https_context = ssl._create_unverified_context

# --------------------------
# Training settings
# --------------------------
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--batch-size', type=int, default=256,
                    help='input batch size for training (default: 256)')
parser.add_argument('--test-batch-size', type=int, default=256,
                    help='input batch size for testing (default: 256)')
parser.add_argument('--epochs', type=int, default=160,
                    help='number of epochs to train (default: 160)')
# The help text previously advertised a default of 0.1, which did not match
# the actual default below.
parser.add_argument('--lr', type=float, default=0.0001,
                    help='initial learning rate (default: 0.0001)')
parser.add_argument('--schedule', type=int, nargs='+', default=[80, 120],
                    help='epochs to decay learning rate')
parser.add_argument('--gamma', type=float, default=0.1,
                    help='learning rate decay factor')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
parser.add_argument('--weight-decay', type=float, default=1e-4,
                    help='weight decay (default: 1e-4)')

args = parser.parse_args()

# --------------------------
# Reproducibility
# --------------------------
# Seed the CPU RNG and the RNG of every visible GPU (the model is later wrapped
# in DataParallel, so more than one device may be in use).
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
# `deterministic = True` restricts cuDNN to deterministic kernels, but with
# `benchmark = True` cuDNN still autotunes and may select a different
# algorithm on each run, so results would not be reproducible. Disable
# autotuning to match the intent of this section (at some speed cost).
torch.backends.cudnn.deterministic = True
cudnn.benchmark = False

# --------------------------
# Dataset: CIFAR-10
# --------------------------
print("==> Preparing CIFAR-10 data...")

# Per-channel normalisation statistics, shared by the train and test pipelines
# so the two can never drift apart.
# NOTE(review): these std values are the figures commonly copied between
# CIFAR-10 examples; confirm they are the intended statistics before reuse.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2023, 0.1994, 0.2010)
_DATA_ROOT = '/workspace/data'

# Training pipeline: random 32x32 crop from a 4-pixel-padded image plus a
# random horizontal flip, then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
])

# Evaluation pipeline: no augmentation, same normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
])

trainset = torchvision.datasets.CIFAR10(
    root=_DATA_ROOT,
    train=True,
    download=True,
    transform=transform_train,
)
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=4,
)

testset = torchvision.datasets.CIFAR10(
    root=_DATA_ROOT,
    train=False,
    download=True,
    transform=transform_test,
)
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=args.test_batch_size,
    shuffle=False,
    num_workers=4,
)

# --------------------------
# Model: ResNet32
# --------------------------
print("==> Building ResNet32 model...")
# Move the network to the GPU first, then wrap it so DataParallel can
# replicate it across the visible devices.
model = torch.nn.DataParallel(ResNet32(num_classes=10).cuda())

_n_params = sum(param.numel() for param in model.parameters())
print('Total params: %.2fM' % (_n_params / 1e6))

# --------------------------
# Optimizer & Criterion
# --------------------------
# Lion (from lion_pytorch) over all model parameters, using the CLI LR and
# weight decay.
# NOTE(review): Lion is usually paired with a smaller LR and a larger weight
# decay than Adam-style optimizers; confirm these defaults were tuned for it.
optimizer = Lion(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# Multi-class classification loss on raw logits.
criterion = nn.CrossEntropyLoss()

# --------------------------
# Adjust Learning Rate (manual MultiStep)
# --------------------------
# Mirror of the most recently applied learning rate, read by the training
# loop for logging.
state = {'lr': args.lr}


def adjust_learning_rate(optimizer, epoch):
    """Apply the step-decay schedule for ``epoch`` to ``optimizer``.

    Starting from ``args.lr``, the rate is multiplied by ``args.gamma`` once
    for every milestone in ``args.schedule`` with ``epoch >= milestone``. The
    result is written into every entry of ``optimizer.param_groups`` and into
    the module-level ``state['lr']``.

    NOTE(review): the training loop passes 1-based epochs, so a milestone of
    80 already applies on the 80th epoch; ``torch.optim.lr_scheduler.
    MultiStepLR`` would decay one epoch later. Confirm which is intended.
    """
    new_lr = args.lr
    reached = (milestone for milestone in args.schedule if epoch >= milestone)
    for _ in reached:
        new_lr *= args.gamma

    for group in optimizer.param_groups:
        group['lr'] = new_lr
    state['lr'] = new_lr

# --------------------------
# Test function
# --------------------------
def test(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.cuda(), target.cuda()
            output = model(data)
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    return 100. * correct / total

# --------------------------
# Training Loop with Metrics Logging
# --------------------------
def _train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and return ``(loss, acc)``.

    ``loss`` is the per-sample mean of ``criterion`` over the epoch. ``acc``
    is the top-1 accuracy in percent. Both are measured on the batches as
    they were trained on: with whatever augmentation the loader applies, and
    while the weights were still being updated. Both are 0.0 if the loader
    yields no samples.
    """
    model.train()
    loss_sum, correct, total = 0.0, 0, 0
    for data, target in loader:
        data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        n = target.size(0)
        # Assumes `criterion` averages over the batch (the default for
        # nn.CrossEntropyLoss); re-weighting by batch size keeps the epoch
        # figure a true per-sample mean when the final batch is short.
        loss_sum += loss.item() * n
        total += n
        _, predicted = output.max(1)
        correct += predicted.eq(target).sum().item()

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, 100. * correct / total


metrics = []
best_acc = 0.0

for epoch in range(1, args.epochs + 1):
    adjust_learning_rate(optimizer, epoch)
    print(f"\nEpoch {epoch}/{args.epochs} | LR: {state['lr']:.6f}")

    # Time only the training pass (evaluation excluded), using a monotonic
    # clock intended for measuring intervals.
    starttime = time.perf_counter()
    train_loss, train_acc = _train_one_epoch(
        model, train_loader, optimizer, criterion)
    epoch_time = time.perf_counter() - starttime
    test_acc = test(model, test_loader)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
          f"Test Acc: {test_acc:.2f}% | Time: {epoch_time:.2f}s")

    metrics.append({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "test_acc": test_acc,
        "time_sec": epoch_time,
        "lr": state['lr'],
    })

    # Overwrite the single checkpoint whenever test accuracy improves.
    if test_acc > best_acc:
        best_acc = test_acc
        state_ckpt = {
            'epoch': epoch,
            # NOTE(review): `model` is a DataParallel wrapper, so these keys
            # carry a "module." prefix; restore into a DataParallel-wrapped
            # model, or strip the prefix before loading into a bare network.
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_accuracy': best_acc,
        }
        os.makedirs("/workspace/checkpoint", exist_ok=True)
        torch.save(state_ckpt, '/workspace/checkpoint/netbest.pkl')

print("Training completed. Best Accuracy: %.2f%%" % best_acc)

# --------------------------
# Save Metrics to Excel
# --------------------------
df = pd.DataFrame(metrics)
os.makedirs("/workspace/results", exist_ok=True)
try:
    df.to_excel("/workspace/results/metrics.xlsx", index=False)
except ImportError:
    # Writing .xlsx needs an optional pandas engine (e.g. openpyxl). If it is
    # missing, fall back to CSV rather than crash and lose the metrics of a
    # complete training run.
    df.to_csv("/workspace/results/metrics.csv", index=False)
    print("Excel writer unavailable; metrics saved to /workspace/results/metrics.csv")
else:
    print("Metrics saved to /workspace/results/metrics.xlsx")
