import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
import time
import math
import pandas as pd
import os
# TODO: import the preconditioner class, the model, and the data loaders used below.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Experiment configuration: replace every None below before running. ---
# A bare `name = # comment` assignment is a SyntaxError, so the template used to
# be unparseable until every slot was filled in. None placeholders keep the file
# valid and let us fail below with a message naming exactly what is missing.
model = None  # TODO: the network to train (make sure it lives on `device`)
test_loader = None  # TODO: DataLoader over the evaluation set
train_loader = None  # TODO: DataLoader over the training set
preconditioner = None  # TODO: the preconditioner under test; must provide .step()
opt = None  # TODO: display name of the preconditioner (used in logs and the CSV filename)

_unset_config = [
    name
    for name, value in (
        ('model', model),
        ('train_loader', train_loader),
        ('test_loader', test_loader),
        ('preconditioner', preconditioner),
        ('opt', opt),
    )
    if value is None
]
if _unset_config:
    raise RuntimeError(
        'Fill in the experiment configuration before running; still unset: '
        + ', '.join(_unset_config)
    )

criterion = nn.CrossEntropyLoss()
LEARNING_RATE = 1e-4  # plain SGD step size applied after preconditioning
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

best_acc = 0  # best test accuracy (percent) observed so far
history = []  # one dict per epoch; converted to a DataFrame/CSV after training
start_time = time.time()  # reference point for elapsed-time measurements

NUM_EPOCHS = 100  # upper bound on the number of training epochs
TARGET_TEST_ACC = 44.0  # stop early once the best test accuracy (percent) exceeds this


def _train_one_epoch(model, loader, criterion, optimizer, preconditioner, device):
    """Run one training pass over ``loader`` and return training accuracy in percent.

    Accuracy is taken from the same forward pass used for the update (model in
    train mode, weights changing between batches), so it is a running estimate
    rather than a clean re-evaluation of the final weights.
    """
    model.train()
    correct = 0
    total = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        # Called between backward() and optimizer.step(); presumably the
        # preconditioner rewrites the .grad tensors that SGD then applies — TODO confirm.
        preconditioner.step()
        optimizer.step()

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    return 100. * correct / total


def _evaluate(model, loader, device):
    """Return the top-1 accuracy (percent) of ``model`` over ``loader``, without gradients."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            _, predicted = model(inputs).max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return 100. * correct / total


for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start = time.time()
    train_acc = _train_one_epoch(
        model, train_loader, criterion, optimizer, preconditioner, device)
    acc = _evaluate(model, test_loader, device)

    now = time.time()
    # Previously "epoch time" was `now - start_time`, i.e. cumulative time since
    # setup. Report both quantities under names that match what they measure.
    epoch_time = now - epoch_start  # wall-clock duration of this epoch only
    elapsed_time = now - start_time  # cumulative wall-clock time since setup

    print('Optimizer: {} | Epoch: {} | Train Acc: {:.2f}% | Test Acc: {:.2f}% | '
          'Epoch Time: {:.2f}s | Elapsed: {:.2f}s'.format(
              opt, epoch, train_acc, acc, epoch_time, elapsed_time))

    # Record data
    history.append({
        'Optimizer': opt,
        'Epoch': epoch,
        'Train Accuracy': train_acc,
        'Test Accuracy': acc,
        'Epoch Time (s)': epoch_time,
        'Elapsed Time (s)': elapsed_time,
    })

    best_acc = max(best_acc, acc)
    # best_acc only grows, so this triggers on the first epoch whose test
    # accuracy exceeds the target; that epoch is already recorded above.
    if best_acc > TARGET_TEST_ACC:
        break

# Summarize the run: total wall-clock time, the per-epoch history as CSV,
# and a one-line report of the best test accuracy reached.
end_time = time.time()
training_time = end_time - start_time

history_df = pd.DataFrame(history)  # one row per completed epoch
history_csv_path = f'{opt}_training_history.csv'
history_df.to_csv(history_csv_path, index=False)

print(f'Optimizer: {opt} | Total Training Time: {training_time:.2f} seconds | '
      f'Best Test Accuracy: {best_acc:.2f}%')
# In Google Colab, uncomment the lines below to download the history CSV:
# from google.colab import files
# files.download(f'{opt}_training_history.csv')
