import os
import argparse
import torch
import torch.nn.functional as F
from utils.data_utils import get_sequential_data_loaders
from models.splr_model import SPLRModel

def _parse_args():
    """Build this script's command-line parser and return the parsed namespace."""
    parser = argparse.ArgumentParser(description='Evaluate SPLRModel on Sequential CIFAR10/100')
    parser.add_argument('-device', default='cuda:0', help='Device to evaluate on (e.g., "cuda:0" or "cpu")')
    parser.add_argument('-b', default=128, type=int, help='Batch size for testing')
    parser.add_argument('-data-dir', type=str, required=True, help='Root directory of CIFAR10/100 dataset')
    parser.add_argument('-class-num', type=int, default=10, help='Number of classes (10 for CIFAR-10, 100 for CIFAR-100)')
    parser.add_argument('-model-path', type=str, required=True, help='Path to the trained model checkpoint')
    parser.add_argument('-num-steps', type=int, default=10, help='Number of steps (frames) in the sequence for each sample')
    return parser.parse_args()


def _resolve_device(requested):
    """Return the torch.device for `requested`, falling back to CPU only when CUDA is requested but unavailable.

    Args:
        requested: Device string such as "cuda:0" or "cpu".

    Returns:
        The requested torch.device, or the CPU device if a CUDA device was
        requested on a machine without CUDA.
    """
    device = torch.device(requested)
    if device.type == 'cuda' and not torch.cuda.is_available():
        print(f"CUDA is not available; falling back to CPU instead of {requested}")
        return torch.device('cpu')
    return device


def _load_state_dict(path, device):
    """Load the checkpoint at `path` and return the model state_dict it holds.

    Args:
        path: Filesystem path to the checkpoint file.
        device: Passed to torch.load as map_location so tensors land on it.

    Returns:
        The state_dict stored under the 'model' key, or the loaded object
        itself when the file holds a bare state_dict.

    Raises:
        FileNotFoundError: If `path` is not an existing file.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Checkpoint not found at: {path}")
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=device)
    # Weights are expected under the 'model' key; a bare state_dict is accepted too.
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        return checkpoint['model']
    return checkpoint


def _evaluate(model, loader, device):
    """Run `model` over every batch of `loader` and return (mean loss, accuracy %).

    Both metrics are normalised by the number of samples actually seen, so
    they stay consistent even if the loader yields fewer samples than its
    dataset holds.

    Raises:
        RuntimeError: If the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            # Sum per-sample losses so batches of unequal size are weighted correctly.
            loss_sum += F.cross_entropy(outputs, labels, reduction='sum').item()

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise RuntimeError("Test loader yielded no samples; cannot compute loss/accuracy")

    return loss_sum / total, 100.0 * correct / total


def main():
    """Evaluate a trained SPLRModel checkpoint on the sequential CIFAR test split.

    Parses CLI arguments, loads the checkpoint and test data, and prints the
    average test loss and the top-1 accuracy.
    """
    args = _parse_args()
    device = _resolve_device(args.device)

    # Validate and load the checkpoint before building data loaders so a bad path fails fast.
    state_dict = _load_state_dict(args.model_path, device)

    # Load the sequential CIFAR dataset using the utility function
    _, test_loader = get_sequential_data_loaders(args, num_steps=args.num_steps)

    # NOTE(review): these hyperparameters are hardcoded and presumably must match
    # the training run; consider storing them in the checkpoint or exposing them as CLI flags.
    model = SPLRModel(
        input_channels=3,
        num_classes=args.class_num,
        decay_rate=0.1,
        tau_d_list=[2.0, 5.0, 10.0],
        tau_s=5.0
    )
    model.to(device)
    model.load_state_dict(state_dict)
    print(f"Loaded model from checkpoint: {args.model_path}")

    test_loss, accuracy = _evaluate(model, test_loader, device)

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {accuracy:.2f}%")

if __name__ == "__main__":
    # Script entry point: run the evaluation only when executed directly, not on import.
    main()
