#cifarconv_relu6_nobn.py
import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from time import time as tm
import argparse
import os

import TorchDiffPC as T2PC


#  Testing Function

def test_model(model, test_loader, device):
    """Evaluate the classification accuracy of ``model`` over ``test_loader``.

    The model is put in evaluation mode for the duration of the pass and is
    then returned to whichever mode (train or eval) it was in on entry, even
    if the pass raises.

    Args:
        model: Module called as ``model(X)``. Predicted labels are taken as
            the argmax of its output along dimension 1.
        test_loader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Percentage (0-100) of samples whose predicted label equals the
        target, or ``0.0`` when the loader yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for X_test, Y_test in test_loader:
                X_test, Y_test = X_test.to(device), Y_test.to(device)
                outputs = model(X_test)
                predicted = outputs.argmax(dim=1)
                total += Y_test.size(0)
                correct += (predicted == Y_test).sum().item()
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        model.train(was_training)
    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return 100 * correct / total


# The name matches pytest's ``test_*`` pattern; opt out of collection so that
# importing this helper into a test module does not make pytest run it as a
# test case with unresolved fixtures.
test_model.__test__ = False


#  Main Training Function

def main(args):
    """Train the CIFAR-10 network via ``T2PC.PCInfer`` and save the results.

    Seeds torch and numpy, builds the CIFAR-10 train/test loaders and the
    model, runs ``args.num_epochs`` epochs, evaluates test accuracy after
    each epoch, and finally writes the per-epoch accuracies (``.npy``) and an
    accuracy plot (``.png``) into ``args.results_dir``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``seed``, ``device``, ``results_dir``, ``batch_size``,
            ``use_stride_conv``, ``num_epochs``, ``lt_m``, ``lt_n``,
            ``lt_a`` and ``e_mult``.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = _resolve_device(args.device)
    print(f'Running with args: {args}')

    os.makedirs(args.results_dir, exist_ok=True)

    cifar10_mean = (0.4914, 0.4822, 0.4465)
    cifar10_std = (0.2023, 0.1994, 0.2010)

    # Training data is augmented with random horizontal flips; test data is
    # only converted to tensors and normalized.
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ])

    train_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=args.batch_size, shuffle=True
    )

    test_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=1000, shuffle=False
    )

    learning_rate = 0.002
    weight_decay = 1e-4

    model = _build_model(args.use_stride_conv, device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )

    test_accuracies = []
    print(f"Starting training for {args.num_epochs} epochs.")
    for epoch in range(args.num_epochs):
        # Per-epoch accumulators for spike counts and processed neurons.
        epoch_ff_spikes = 0
        epoch_lrn_spikes = 0
        total_neurons_processed = 0

        for X, Y in train_loader:
            X, Y = X.to(device), Y.to(device)
            optimizer.zero_grad()

            # NOTE(review): backward() is never called in this loop, so
            # PCInfer is presumably what fills the parameter gradients that
            # optimizer.step() consumes -- confirm against TorchDiffPC.
            vhat, _, _, _, _, ff_spikes, lrn_spikes = T2PC.PCInfer(
                model, loss_fn, X, Y, "QuantizedPred", eta=0.05, n=15,
                lt_m=args.lt_m, lt_n=args.lt_n, lt_a=args.lt_a, e_mult=args.e_mult
            )
            optimizer.step()

            epoch_ff_spikes += ff_spikes
            epoch_lrn_spikes += lrn_spikes
            total_neurons_processed += sum(v.numel() for v in vhat)

        accuracy = test_model(model, test_loader, device)
        test_accuracies.append(accuracy)

        # Guard the averages against an epoch that processed no neurons.
        if total_neurons_processed > 0:
            avg_ff_spikes = epoch_ff_spikes / total_neurons_processed
            avg_lrn_spikes = epoch_lrn_spikes / total_neurons_processed
        else:
            avg_ff_spikes = 0
            avg_lrn_spikes = 0

        print(f"--- End of Epoch {epoch} | Test Accuracy: {accuracy:.2f}% ---")
        print(f"Spike Stats | Avg Fwd Spikes/Neuron: {avg_ff_spikes:.4f} | Avg Lrn Spikes/Neuron: {avg_lrn_spikes:.4f}")

    if not test_accuracies:
        # With zero epochs there is no final accuracy to index and nothing
        # meaningful to save or plot.
        print("No epochs were run; nothing to report or save.")
        return

    print(f"Final accuracy: {test_accuracies[-1]:.2f}%")
    _save_results(args, test_accuracies)


def _resolve_device(requested):
    """Return the torch device to run on, falling back when unavailable.

    Uses the CPU whenever CUDA is not available. If a CUDA index is requested
    that this machine does not have, falls back to ``cuda:0`` with a warning
    instead of failing later when the model is moved to the device.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    device = torch.device(requested)
    if (
        device.type == 'cuda'
        and device.index is not None
        and device.index >= torch.cuda.device_count()
    ):
        print(
            f"Warning: requested device '{requested}' is not available "
            f"({torch.cuda.device_count()} CUDA device(s) found). Using cuda:0."
        )
        return torch.device('cuda:0')
    return device


def _build_model(use_stride_conv, device):
    """Build the 5-stage network (2 conv stages, 3 linear stages) on ``device``.

    Each stage is its own ``nn.Sequential`` inside the outer ``nn.Sequential``.
    The two variants differ only in the first stage: a stride-2 5x5 convolution
    when ``use_stride_conv`` is true, otherwise a 3x3 convolution followed by
    2x2 max pooling. Both feed a flattened ``5 * 8 * 8`` feature vector into
    the first linear stage.
    """
    if use_stride_conv:
        first_stage = nn.Sequential(
            nn.Conv2d(3, 10, 5, stride=2, padding=2),
            nn.ReLU6(inplace=True),
        )
    else:  # Not used in tests
        print("Warning: Max Pooling model is being used. You should use --use_stride_conv.")
        first_stage = nn.Sequential(
            nn.Conv2d(3, 10, 3, padding=1),
            nn.ReLU6(inplace=True),
            nn.MaxPool2d(2),
        )

    # The first stage is created before the shared stages so that layers are
    # constructed in the same order for a given seed.
    return nn.Sequential(
        first_stage,
        nn.Sequential(nn.Conv2d(10, 5, 5, stride=2, padding=2), nn.ReLU6(inplace=True), nn.Flatten()),
        nn.Sequential(nn.Linear(5 * 8 * 8, 50), nn.ReLU6(inplace=True)),
        nn.Sequential(nn.Linear(50, 30), nn.ReLU6(inplace=True)),
        nn.Sequential(nn.Linear(30, 10)),
    ).to(device)


def _save_results(args, test_accuracies):
    """Save per-epoch test accuracies as ``.npy`` and as a PNG line plot.

    Both files are written to ``args.results_dir`` and share a base name built
    from ``lt_m``, ``lt_n``, ``lt_a``, ``e_mult`` and ``seed``.
    """
    base_filename = (
        f"ltm_{args.lt_m}-ltn_{args.lt_n}-lta_{args.lt_a}-"
        f"emult_{args.e_mult}-seed_{args.seed}"
    )

    np.save(os.path.join(args.results_dir, f"{base_filename}_acc.npy"), test_accuracies)

    fig, ax1 = plt.subplots(figsize=(12, 5))
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Test Accuracy (%)', color='tab:red')
    ax1.plot(test_accuracies, color='tab:red', marker='o', linestyle='--')
    ax1.tick_params(axis='y', labelcolor='tab:red')
    ax1.set_ylim(0, 100)
    ax1.set_title(f'Test Accuracy (lt_n={args.lt_n}, lt_a={args.lt_a}, e_mult={args.e_mult})')
    ax1.grid(True)
    # Run the layout pass only after the title exists so the title is
    # accounted for when the axes are fitted into the figure.
    fig.tight_layout()
    fig.savefig(os.path.join(args.results_dir, f"{base_filename}_plot.png"))
    plt.close(fig)

if __name__ == '__main__':
    # Script entry point: every command-line option is declared as a
    # (flag, argparse keyword arguments) pair, registered in order, and the
    # parsed namespace is handed to main().
    parser = argparse.ArgumentParser(
        description='CIFAR-10 Training with Quantized Predictive Coding'
    )

    option_table = [
        ('--lt_m', dict(type=int, default=2, help='Quantization parameter m')),
        ('--lt_n', dict(type=int, default=16, help='Quantization parameter n (precision)')),
        ('--lt_a', dict(type=float, default=1.0, help='Quantization parameter alpha')),
        ('--e_mult', dict(type=float, default=0.0025, help='Multiplier for error quantization alpha')),
        ('--num_epochs', dict(type=int, default=60, help='Number of training epochs')),
        ('--seed', dict(type=int, default=2, help='Random seed for reproducibility')),
        ('--device', dict(type=str, default='cuda:1', help='Device to run on (e.g., cuda:0)')),
        ('--batch_size', dict(type=int, default=300, help='Training batch size')),
        ('--results_dir', dict(type=str, default='results', help='Directory to save results')),
        ('--use_stride_conv', dict(action='store_true', help='Use strided convolutions instead of max pooling')),
    ]
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    main(args)