import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision import models

# Config
seed = 4  # seed passed to set_seed() for random / numpy / torch
batch_size_global = 32  # batch size used by both the train and test DataLoaders
epochs = 200  # total number of training epochs
save_every = 20  # checkpoint + evaluate after every `save_every` epochs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# One checkpoint directory per (batch size, seed) run.
save_dir = f"neurips25_batch_order/models_vgg16_clean_cifar10/batch{batch_size_global}_seed{seed}"

print(save_dir)

# exist_ok=True replaces the racy "if not exists: makedirs" check-then-create pattern.
os.makedirs(save_dir, exist_ok=True)

# Set seed
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch, plus every CUDA device if one exists.

    Args:
        seed: Integer seed applied to each generator. Defaults to 42.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(seed)

# Transforms
# Per-channel normalization statistics, shared by the train and test pipelines.
_cifar10_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))

# Training pipeline: random 32x32 crop from a 4-px padded image, random horizontal
# flip, then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _cifar10_normalize,
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalization.
transform = transforms.Compose([transforms.ToTensor(), _cifar10_normalize])

# Load CIFAR-10 from ./data (download=False: the files must already be present).
_data_root = 'data'
train_dataset = torchvision.datasets.CIFAR10(root=_data_root, train=True,
                                             download=False, transform=train_transform)
test_dataset = torchvision.datasets.CIFAR10(root=_data_root, train=False,
                                            download=False, transform=transform)

# Both loaders share batch size and worker count; only the training set is shuffled.
_loader_kwargs = dict(batch_size=batch_size_global, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

# Model
# VGG-16 trained from scratch. `weights=None` is the current torchvision spelling of the
# deprecated `pretrained=False` (same random initialization, no deprecation warning).
model = models.vgg16(weights=None)
# Replace the default final classifier layer with a 10-class output for CIFAR-10.
model.classifier[6] = nn.Linear(model.classifier[6].in_features, 10)
model = model.to(device)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
# MultiStepLR counts scheduler.step() calls. With epochs = 200, milestones [100, 150]
# are presumably meant as epochs, so step() should be called once per epoch.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)


# Evaluation function
def evaluate(model, loader, criterion, device, dataset_name="Test"):
    """Run ``model`` over ``loader`` without gradients and print loss and accuracy.

    Args:
        model: Network to evaluate. It is switched to ``eval()`` mode and left there.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``. It is assumed to
            return the batch *mean* (the ``nn.CrossEntropyLoss`` default).
        device: Device each batch is moved to before the forward pass.
        dataset_name: Label used in the printed summary line.

    Returns:
        Top-1 accuracy in percent (0-100).

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            # Re-weight the per-batch mean by batch size, so a smaller final batch
            # does not count as much as a full one in the reported average.
            total_loss += criterion(outputs, labels).item() * batch_size
            _, preds = torch.max(outputs, 1)
            correct += preds.eq(labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError(f"{dataset_name} loader yielded no samples; cannot evaluate.")

    accuracy = 100 * correct / total
    print(f"{dataset_name} Loss: {total_loss / total:.4f}, Accuracy: {accuracy:.2f}%")
    return accuracy

# Training function
def train(model, criterion, optimizer, device, epochs, *, loader=None,
          lr_scheduler=None, eval_loader=None, checkpoint_dir=None,
          checkpoint_every=None):
    """Train ``model`` for ``epochs`` epochs, checkpointing and evaluating periodically.

    The keyword-only arguments default to the module-level ``train_loader``,
    ``scheduler``, ``test_loader``, ``save_dir`` and ``save_every``, so existing
    calls behave as before. Passing them explicitly decouples the loop from globals.

    Args:
        model: Network to optimize. It is put in ``train()`` mode at the start of each epoch.
        criterion: Loss called as ``criterion(outputs, labels)``. It is assumed to return
            the batch mean (the ``nn.CrossEntropyLoss`` default).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to.
        epochs: Number of passes over ``loader``.
        loader: Iterable of ``(images, labels)`` training batches.
        lr_scheduler: LR scheduler, stepped once per epoch.
        eval_loader: Batches passed to ``evaluate`` at each checkpoint.
        checkpoint_dir: Directory the ``vgg16_epoch{N}.pth`` state dicts are written to.
        checkpoint_every: Save and evaluate after every this-many epochs.

    Raises:
        ValueError: If ``loader`` yields no samples in an epoch.
    """
    # Fall back to the module-level objects only when not supplied explicitly.
    if loader is None:
        loader = train_loader
    if lr_scheduler is None:
        lr_scheduler = scheduler
    if eval_loader is None:
        eval_loader = test_loader
    if checkpoint_dir is None:
        checkpoint_dir = save_dir
    if checkpoint_every is None:
        checkpoint_every = save_every

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # Sample-weighted, so a short final batch is not over-counted.
            running_loss += loss.item() * batch_size
            _, preds = torch.max(outputs, 1)
            correct += preds.eq(labels).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError("training loader yielded no samples")

        # Step once per *epoch*. The milestones are epoch counts; stepping per batch
        # reached both of them after only 150 optimizer steps, inside the first epoch.
        lr_scheduler.step()

        train_acc = 100 * correct / total
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/total:.4f}, Train Accuracy: {train_acc:.2f}%")

        # Save checkpoint & evaluate every `checkpoint_every` epochs
        if (epoch + 1) % checkpoint_every == 0:
            checkpoint_path = f'{checkpoint_dir}/vgg16_epoch{epoch+1}.pth'
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Saved checkpoint: {checkpoint_path}")
            evaluate(model, eval_loader, criterion, device, dataset_name="Test")

# Run training
train(model, criterion, optimizer, device, epochs)

# Persist the weights as they stand after the last epoch.
final_checkpoint = f'{save_dir}/vgg16_final.pth'
torch.save(model.state_dict(), final_checkpoint)
print("Final model saved.")

# Report test-set loss/accuracy of the final weights.
evaluate(model, test_loader, criterion, device, dataset_name="Test")
