from __future__ import print_function
import os
import argparse
import time
import ssl
import pandas as pd
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from resnet import ResNet32


# Disable HTTPS certificate verification for the whole process.
# NOTE(review): presumably a workaround so the CIFAR-100 download below
# (download=True) works on hosts with a broken CA bundle / intercepting proxy.
# It weakens EVERY HTTPS request this process makes; prefer fixing the
# certificate store instead.
ssl._create_default_https_context = ssl._create_unverified_context

# --------------------------
# Training settings
# --------------------------
def _build_parser():
    """Return the command-line parser; every option is an integer flag."""
    cli = argparse.ArgumentParser(description='PyTorch CIFAR100 Training')
    int_options = (
        ('--batch-size', 256,
         'input batch size for training (default: 256)'),
        ('--test-batch-size', 256,
         'input batch size for testing (default: 256)'),
        ('--epochs', 160,
         'number of epochs to train (default: 160)'),
        ('--seed', 1,
         'random seed (default: 1)'),
    )
    for flag, default, help_text in int_options:
        cli.add_argument(flag, type=int, default=default, help=help_text)
    return cli


parser = _build_parser()
args = parser.parse_args()

# --------------------------
# Reproducibility
# --------------------------
# Seed the CPU RNG and the current CUDA device's RNG.
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
# Speed over bit-exactness: allow non-deterministic cuDNN kernels and let
# cuDNN benchmark/autotune convolution algorithms.
cudnn.deterministic = False
cudnn.benchmark = True

# --------------------------
# Dataset: CIFAR-100
# --------------------------
print("==> Preparing CIFAR-100 data...")

# Per-channel mean/std of the CIFAR-100 training set. (The CIFAR-10 values
# (0.4914, 0.4822, 0.4465) / (0.2023, 0.1994, 0.2010) are often copied into
# CIFAR-100 scripts by mistake; they do not describe this dataset.)
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)
DATA_ROOT = '/workspace/data'

# One Normalize shared by both pipelines so train and test inputs can never
# drift apart in preprocessing.
normalize = transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD)

# Training pipeline: pad-and-crop plus horizontal-flip augmentation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

trainset = torchvision.datasets.CIFAR100(
    root=DATA_ROOT, train=True, download=True, transform=transform_train
)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=args.batch_size, shuffle=True,
    num_workers=4, pin_memory=True, persistent_workers=True
)

testset = torchvision.datasets.CIFAR100(
    root=DATA_ROOT, train=False, download=True, transform=transform_test
)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=args.test_batch_size, shuffle=False,
    num_workers=4, pin_memory=True, persistent_workers=True
)

# --------------------------
# Model: ResNet32 (100 classes)
# --------------------------
print("==> Building ResNet32 model...")
# Move the 100-class ResNet32 to the GPU, then wrap it in DataParallel.
model = torch.nn.DataParallel(ResNet32(num_classes=100).cuda())
n_params = sum(param.numel() for param in model.parameters())
print(f'Total params: {n_params / 1e6:.2f}M')

# --------------------------
# Optimizer & Criterion
# --------------------------
criterion = nn.CrossEntropyLoss()
init_lr = 0.003
weight_decay = 1e-4


class Lion(torch.optim.Optimizer):
    """Lion optimizer ("EvoLved Sign Momentum", Chen et al., 2023).

    Defined locally because nothing in this file imported a ``Lion`` class.

    For each parameter ``p`` with gradient ``g`` and momentum buffer ``m``:

        p <- p * (1 - lr * weight_decay)                 # decoupled decay
        p <- p - lr * sign(beta1 * m + (1 - beta1) * g)  # sign update
        m <- beta2 * m + (1 - beta2) * g                 # momentum update

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate (must be > 0).
        betas: (beta1, beta2) interpolation / momentum coefficients in [0, 1).
        weight_decay: decoupled weight-decay coefficient.

    Raises:
        ValueError: on a non-positive ``lr`` or out-of-range ``betas``.
    """

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
            raise ValueError(f"Invalid betas: {betas}")
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step; return the closure's loss if given."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            beta1, beta2 = group["betas"]
            wd = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                if not state:
                    state["exp_avg"] = torch.zeros_like(p)
                exp_avg = state["exp_avg"]

                if wd != 0:
                    p.mul_(1 - lr * wd)
                # `mul` (not `mul_`) so the momentum buffer is left untouched
                # until its own update below.
                update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
                p.add_(update.sign_(), alpha=-lr)
                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss


optimizer = Lion(model.parameters(), lr=init_lr, weight_decay=weight_decay)

# --------------------------
# AMP Scaler for mixed precision
# --------------------------
scaler = torch.cuda.amp.GradScaler()

# --------------------------
# Test function
# --------------------------
def test(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
            with torch.cuda.amp.autocast():
                output = model(data)
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    return 100. * correct / total

# --------------------------
# Training Loop with Metrics Logging
# --------------------------
metrics = []
best_acc = 0.0

for epoch in range(1, args.epochs + 1):
    starttime = time.time()
    # Report the LR the optimizer will actually use, not the init_lr
    # constant, so the log stays truthful if a scheduler is ever added.
    current_lr = optimizer.param_groups[0]['lr']
    print(f"\nEpoch {epoch}/{args.epochs} | LR: {current_lr:.6f}")

    model.train()
    train_loss, correct, total = 0.0, 0, 0

    for data, target in train_loader:
        data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)

        optimizer.zero_grad(set_to_none=True)  # more memory efficient
        with torch.cuda.amp.autocast():
            output = model(data)
            loss = criterion(output, target)

        # Scaled backward; scaler.step unscales the gradients and skips the
        # update if they contain inf/NaN, then update() adapts the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight by batch size so the epoch loss is a per-sample mean even
        # when the last batch is smaller.
        train_loss += loss.item() * target.size(0)
        total += target.size(0)
        _, predicted = output.max(1)
        correct += predicted.eq(target).sum().item()

    endtime = time.time()
    train_loss /= total
    train_acc = 100. * correct / total
    test_acc = test(model, test_loader)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
          f"Test Acc: {test_acc:.2f}% | Time: {endtime - starttime:.2f}s")

    metrics.append({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "test_acc": test_acc,
        "time_sec": endtime - starttime,
        "lr": current_lr
    })

    if test_acc > best_acc:
        best_acc = test_acc
        state_ckpt = {
            'epoch': epoch,
            # Keys carry the DataParallel "module." prefix.
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            # The GradScaler keeps its own loss-scale state; without it a
            # resumed AMP run would restart from the default scale.
            'scaler': scaler.state_dict(),
            'best_accuracy': best_acc,
        }
        os.makedirs("/workspace/checkpoint", exist_ok=True)
        torch.save(state_ckpt, '/workspace/checkpoint/netbest.pkl')

print("Training completed. Best Accuracy: %.2f%%" % best_acc)

# --------------------------
# Save Metrics to Excel
# --------------------------
df = pd.DataFrame(metrics)
os.makedirs("/workspace/results", exist_ok=True)
# Always write a CSV first: it needs no optional dependency, so the results
# of a long training run survive even if the Excel export below fails.
df.to_csv("/workspace/results/metrics.csv", index=False)
print("Metrics saved to /workspace/results/metrics.csv")
try:
    # pandas needs an optional engine (e.g. openpyxl) to write .xlsx files
    # and raises ImportError when it is not installed.
    df.to_excel("/workspace/results/metrics.xlsx", index=False)
except ImportError as err:
    print(f"Skipping Excel export ({err}); CSV copy is available.")
else:
    print("Metrics saved to /workspace/results/metrics.xlsx")
