import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import argparse
import os

#  Testing Function

def test_model(model, test_loader, device):
    """Evaluate the classification accuracy of ``model`` on ``test_loader``.

    The model is put in eval mode and gradients are disabled for the pass.
    Afterwards the model's previous train/eval mode is restored, so calling
    this never silently changes the caller's mode.

    Args:
        model: A ``torch.nn.Module`` that maps a batch of inputs to per-class
            scores. The predicted class is the argmax over dimension 1.
        test_loader: Any iterable of ``(inputs, labels)`` batches, e.g. a
            ``torch.utils.data.DataLoader``.
        device: Device (or device string) that inputs and labels are moved to
            before the forward pass.

    Returns:
        Accuracy as a percentage (``100 * correct / total``), a float.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                predicted = model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore the caller's mode rather than forcing train mode; also runs
        # if the forward pass raises.
        model.train(was_training)
    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")
    return 100 * correct / total


# The name starts with "test_", so pytest would otherwise try to collect this
# helper as a test and fail looking for fixtures named model/test_loader/device.
test_model.__test__ = False


#  Main Training Function

def _build_loaders(batch_size):
    """Create the CIFAR-10 train/test DataLoaders, downloading into ./data if needed.

    Training images get a random horizontal flip. Both splits are converted to
    tensors and normalized with the same per-channel mean/std. The training
    loader is shuffled with ``batch_size``; the test loader is unshuffled with
    batches of 1000.
    """
    # NOTE(review): verify these normalization constants against the dataset;
    # kept unchanged so results stay comparable with earlier runs.
    cifar10_mean = (0.4914, 0.4822, 0.4465)
    cifar10_std = (0.2023, 0.1994, 0.2010)
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ])
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)
    return train_loader, test_loader


def _build_model(use_stride_conv, device):
    """Build the small two-conv / three-linear classifier (10 logits) on ``device``.

    With ``use_stride_conv``, both 5x5 convolutions downsample with stride 2
    (for 32x32 inputs: 32 -> 16 -> 8). Otherwise a 3x3 conv, 2x2 max-pool and
    a second 3x3 conv are used (32 -> 30 -> 15 -> 13). The Linear input sizes
    depend on these spatial sizes. The nested ``nn.Sequential`` layout is
    kept as-is so ``state_dict`` keys stay stable.
    """
    if use_stride_conv:
        return nn.Sequential(
            nn.Sequential(nn.Conv2d(3, 10, 5, stride=2, padding=2), nn.ReLU6(inplace=True)),
            nn.Sequential(nn.Conv2d(10, 5, 5, stride=2, padding=2), nn.ReLU6(inplace=True), nn.Flatten()),
            nn.Sequential(nn.Linear(5 * 8 * 8, 50), nn.ReLU6(inplace=True)),
            nn.Sequential(nn.Linear(50, 30), nn.ReLU6(inplace=True)),
            nn.Sequential(nn.Linear(30, 10))
        ).to(device)

    # This part is kept for completeness but should not be used for this comparison.
    print("Warning: Max Pooling model is being used. You should use --use_stride_conv.")
    return nn.Sequential(
        nn.Sequential(nn.Conv2d(3, 10, 3), nn.ReLU6(inplace=True), nn.MaxPool2d(2)),
        nn.Sequential(nn.Conv2d(10, 5, 3), nn.ReLU6(inplace=True), nn.Flatten()),
        nn.Sequential(nn.Linear(5 * 13 * 13, 50), nn.ReLU6(inplace=True)),
        nn.Sequential(nn.Linear(50, 30), nn.ReLU6(inplace=True)),
        nn.Sequential(nn.Linear(30, 10))
    ).to(device)


def _save_results(test_accuracies, results_dir, base_filename, seed):
    """Write per-epoch test accuracies to ``<base>_acc.npy`` and a plot to ``<base>_plot.png``."""
    np.save(os.path.join(results_dir, f"{base_filename}_acc.npy"), test_accuracies)

    # Epochs are reported 1-based in the training log, so plot them that way too.
    epochs = range(1, len(test_accuracies) + 1)
    plt.figure(figsize=(10, 5))
    plt.plot(epochs, test_accuracies, marker='o', linestyle='--')
    plt.title(f'Backpropagation Test Accuracy (Seed: {seed})')
    plt.xlabel('Epoch')
    plt.ylabel('Test Accuracy (%)')
    plt.ylim(0, 100)
    plt.grid(True)
    plt.savefig(os.path.join(results_dir, f"{base_filename}_plot.png"))
    plt.close()


def main(args):
    """Train the CIFAR-10 classifier with backpropagation and save accuracy results.

    Args:
        args: Parsed CLI namespace. Reads ``seed``, ``device``, ``results_dir``,
            ``batch_size``, ``use_stride_conv`` and ``num_epochs``.

    Side effects:
        Seeds torch and numpy. Downloads CIFAR-10 into ./data if it is missing.
        Creates ``args.results_dir`` and writes the accuracy ``.npy`` and plot
        ``.png`` files there. The requested device is used only when CUDA is
        available; otherwise training falls back to CPU.
    """
    # Setup and data loading (seed before anything consumes RNG state).
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print(f'Running with args: {args}')
    os.makedirs(args.results_dir, exist_ok=True)

    train_loader, test_loader = _build_loaders(args.batch_size)

    # Hyperparameters
    learning_rate = 0.002
    weight_decay = 1e-4

    model = _build_model(args.use_stride_conv, device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # The training loop: one full pass over the training set, then evaluate.
    test_accuracies = []
    print(f"Starting training for {args.num_epochs} epochs.")
    for epoch in range(args.num_epochs):
        for X, Y in train_loader:
            X, Y = X.to(device), Y.to(device)
            loss = loss_fn(model(X), Y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        accuracy = test_model(model, test_loader, device)
        test_accuracies.append(accuracy)
        print(f"--- End of Epoch {epoch+1}/{args.num_epochs} | Test Accuracy: {accuracy:.2f}% ---")

    # Guard against --num_epochs 0, where there is no last accuracy to index.
    if test_accuracies:
        print(f"Final accuracy: {test_accuracies[-1]:.2f}%")
    else:
        print("No epochs were run; no final accuracy to report.")

    model_type = "stride_conv" if args.use_stride_conv else "max_pool"
    _save_results(test_accuracies, args.results_dir, f"bp_{model_type}_seed_{args.seed}", args.seed)

if __name__ == '__main__':
    # Command-line entry point: parse the options below and hand them to main().
    cli = argparse.ArgumentParser(description='CIFAR-10 Training with Backpropagation (Comparison Baseline)')

    # (flag, keyword arguments for add_argument), registered in this order.
    option_specs = [
        ('--num_epochs', {'type': int, 'default': 60, 'help': 'Number of training epochs'}),
        ('--seed', {'type': int, 'default': 42, 'help': 'Random seed for reproducibility'}),
        ('--device', {'type': str, 'default': 'cuda:0', 'help': 'Device to run on (e.g., cuda:0)'}),
        ('--batch_size', {'type': int, 'default': 300, 'help': 'Training batch size'}),
        ('--results_dir', {'type': str, 'default': 'results_bp', 'help': 'Directory to save results'}),
        ('--use_stride_conv', {'action': 'store_true', 'help': 'Use strided convolutions instead of max pooling'}),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    main(cli.parse_args())