"""
Supervised Cross-Entropy Training

Train ANN or SNN models with standard supervised classification.

Usage:
    # Train ANN
    python main_supce.py --arch resnet18 --dataset-name cifar10
    
    # Train SNN
    python main_supce.py --spiking --arch resnet18 --timesteps 4 --dataset-name cifar10
"""

import os
import sys
import argparse
import logging

from torch.utils.tensorboard import SummaryWriter
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets

from util import save_config_file, save_checkpoint, AverageMeter, accuracy
from networks.resnet_ann import SupCEResNet
from networks.resnet_snn import SupCEResNetSNN, add_dimention_distribute


def parse_option():
    """Parse command-line arguments and derive dataset and logging settings.

    In addition to the raw CLI flags, the returned namespace carries:

    * ``data_folder``, ``img_size``, ``n_cls`` -- fixed per ``--dataset-name``.
    * ``log_path`` -- ``<save_path>/<run name>``, where the run name encodes
      ANN/SNN, dataset, architecture, timesteps (SNN only), batch size and
      learning rate.

    Side effect: the ``log_path`` directory is created if it is missing.

    Returns:
        argparse.Namespace: the parsed options plus the derived fields above.

    Raises:
        SystemExit: raised by argparse on invalid arguments, including a
            non-positive ``--timesteps`` combined with ``--spiking``.
    """
    parser = argparse.ArgumentParser('Supervised Cross-Entropy Training')

    # Model settings
    parser.add_argument('--spiking', action='store_true', help='Use SNN model')
    parser.add_argument('--timesteps', type=int, default=4, help='Number of timesteps for SNN')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=['resnet18', 'resnet34', 'resnet50'])

    # Dataset settings
    parser.add_argument('--dataset-name', type=str, default='cifar10',
                        choices=['cifar10', 'tinyimagenet'])
    parser.add_argument('-j', '--workers', default=16, type=int)

    # Training settings
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('-b', '--batch-size', default=512, type=int)
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, dest='lr')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, dest='weight_decay')

    # Other settings
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--save-path', type=str, default='./save/SupCE')
    parser.add_argument('--gpu-ids', type=str, default='0,1,2,3')
    parser.add_argument('--note', type=str, default='')

    opt = parser.parse_args()

    # An SNN needs at least one timestep; fail early with a clear CLI error.
    if opt.spiking and opt.timesteps < 1:
        parser.error('--timesteps must be a positive integer when --spiking is set')

    # Dataset-specific parameters: (data folder, image size, number of classes).
    # The keys mirror the ``choices`` of --dataset-name, so the lookup cannot miss.
    dataset_settings = {
        'cifar10': ("~/projects/Datasets/CIFAR10/", 32, 10),
        'tinyimagenet': ("~/projects/Datasets/tiny-imagenet-200/", 64, 200),
    }
    opt.data_folder, opt.img_size, opt.n_cls = dataset_settings[opt.dataset_name]

    # Run name doubles as the log directory name; only SNN runs include T.
    if opt.spiking:
        run_name = f'SNN_{opt.dataset_name}_{opt.arch}_T{opt.timesteps}_bsz{opt.batch_size}_lr{opt.lr}'
    else:
        run_name = f'ANN_{opt.dataset_name}_{opt.arch}_bsz{opt.batch_size}_lr{opt.lr}'
    opt.log_path = os.path.join(opt.save_path, run_name)

    # exist_ok avoids the check-then-create race when several runs start at once.
    os.makedirs(opt.log_path, exist_ok=True)

    return opt


def set_loader(opt):
    """Build the training and validation data loaders.

    Training images get a random resized crop (scale 0.2-1.0) to
    ``opt.img_size`` and a random horizontal flip; validation images are only
    converted to tensors. Both are normalised with the same per-dataset
    mean/std.

    Args:
        opt: options namespace; reads ``dataset_name``, ``img_size``,
            ``data_folder``, ``batch_size`` and ``workers``.

    Returns:
        tuple: ``(train_loader, val_loader)``. Only the training loader
        shuffles.

    Raises:
        ValueError: if ``opt.dataset_name`` is not a supported dataset.
    """
    # Per-channel (mean, std) used for normalisation, keyed by dataset name.
    # NOTE(review): the tinyimagenet mean is identical to the cifar10 mean,
    # which looks like a copy-paste -- confirm against the dataset's real
    # statistics. Left unchanged here because altering it changes the
    # preprocessing seen by any previously trained model.
    norm_stats = {
        'cifar10': ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        'tinyimagenet': ((0.4914, 0.4822, 0.4465), (0.2675, 0.2565, 0.2761)),
    }
    # Fail with a clear message instead of an UnboundLocalError further down.
    if opt.dataset_name not in norm_stats:
        raise ValueError(
            f"Unsupported dataset '{opt.dataset_name}'; "
            f"expected one of {sorted(norm_stats)}"
        )
    mean, std = norm_stats[opt.dataset_name]

    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=opt.img_size, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    if opt.dataset_name == 'cifar10':
        # Only the training split requests a download.
        train_dataset = datasets.CIFAR10(
            root=opt.data_folder, train=True, transform=train_transform, download=True
        )
        val_dataset = datasets.CIFAR10(root=opt.data_folder, train=False, transform=val_transform)
    else:  # 'tinyimagenet' -- the only other entry in norm_stats
        # Expects <data_folder>/train and <data_folder>/val in ImageFolder
        # layout (one sub-directory per class).
        train_root = os.path.join(opt.data_folder, 'train')
        train_dataset = datasets.ImageFolder(root=train_root, transform=train_transform)
        val_root = os.path.join(opt.data_folder, 'val')
        val_dataset = datasets.ImageFolder(root=val_root, transform=val_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=True,
        num_workers=opt.workers, pin_memory=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=opt.batch_size, shuffle=False,
        num_workers=opt.workers, pin_memory=True
    )

    return train_loader, val_loader


def set_model(opt):
    """Create the network and the cross-entropy loss, both moved to CUDA.

    Args:
        opt: options namespace; reads ``spiking``, ``arch``, ``n_cls`` and,
            for the spiking model, ``timesteps``.

    Returns:
        tuple: ``(model, criterion)`` where ``model`` is wrapped in
        ``torch.nn.DataParallel``.
    """
    # Pick the spiking or the conventional ResNet variant.
    net = (
        SupCEResNetSNN(name=opt.arch, timestep=opt.timesteps, num_classes=opt.n_cls)
        if opt.spiking
        else SupCEResNet(name=opt.arch, num_classes=opt.n_cls)
    )

    # Wrap for multi-GPU execution, then place model and loss on the GPU.
    net = torch.nn.DataParallel(net).cuda()
    loss_fn = torch.nn.CrossEntropyLoss().cuda()

    # Let cuDNN auto-tune its algorithms.
    cudnn.benchmark = True

    return net, loss_fn


def train(model, criterion, train_loader, optimizer, opt):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: network to optimise; switched to training mode here.
        criterion: loss callable applied to ``(output, labels)``.
        train_loader: iterable yielding ``(images, labels)`` batches.
        optimizer: optimizer stepped once per batch.
        opt: options namespace; reads ``spiking`` and ``timesteps``.

    Returns:
        tuple: ``(losses.avg, top1.avg)`` -- the batch-size-weighted running
        averages of the loss and the top-1 accuracy kept by ``AverageMeter``.
    """
    model.train()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    for batch, targets in train_loader:
        batch = batch.cuda()
        targets = targets.cuda()
        n_samples = targets.shape[0]

        if opt.spiking:
            # presumably expands the batch along a timestep axis -- confirm in
            # networks.resnet_snn; the model output is then averaged over dim 1.
            spike_input = add_dimention_distribute(batch, opt.timesteps)
            logits = model(spike_input).mean(1)
        else:
            logits = model(batch)

        batch_loss = criterion(logits, targets)

        # Record metrics for this batch before the parameter update.
        loss_meter.update(batch_loss.item(), n_samples)
        acc1, _ = accuracy(logits, targets, topk=(1, 5))
        acc_meter.update(acc1[0], n_samples)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    return loss_meter.avg, acc_meter.avg


def validate(model, val_loader, opt):
    """Measure top-1 accuracy of ``model`` on ``val_loader``.

    Args:
        model: network to evaluate; switched to evaluation mode here.
        val_loader: iterable yielding ``(images, labels)`` batches.
        opt: options namespace; reads ``spiking`` and ``timesteps``.

    Returns:
        The batch-size-weighted running average of top-1 accuracy kept by
        ``AverageMeter``.
    """
    model.eval()
    acc_meter = AverageMeter()

    # No gradients are needed anywhere during evaluation.
    with torch.no_grad():
        for batch, targets in val_loader:
            batch = batch.cuda()
            targets = targets.cuda()

            if opt.spiking:
                # Same input expansion and dim-1 averaging as in training.
                spike_input = add_dimention_distribute(batch, opt.timesteps)
                logits = model(spike_input).mean(1)
            else:
                logits = model(batch)

            acc1, _ = accuracy(logits, targets, topk=(1, 5))
            acc_meter.update(acc1[0], targets.shape[0])

    return acc_meter.avg


def _run_training(opt, writer):
    """Run the train/validate loop, logging to ``writer`` and saving checkpoints."""
    logging.basicConfig(
        filename=os.path.join(writer.log_dir, 'training.log'),
        level=logging.DEBUG
    )
    save_config_file(writer.log_dir, opt)

    train_loader, val_loader = set_loader(opt)
    model, criterion = set_model(opt)

    optimizer = torch.optim.Adam(model.parameters(), opt.lr, weight_decay=opt.weight_decay)

    # The learning rate is held at opt.lr for the first `hold_epochs` epochs and
    # then annealed with one scheduler step per epoch. T_max counts scheduler
    # steps, so it must be the number of annealing epochs -- not the number of
    # batches per epoch, which would tie the cosine period to the dataset size
    # and could make the learning rate rise again before training ends.
    hold_epochs = 10
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, opt.epochs - hold_epochs), eta_min=0, last_epoch=-1
    )

    logging.info(f"Start training for {opt.epochs} epochs.")

    best_val_acc = 0.0

    for epoch_counter in range(opt.epochs):
        loss, train_top1 = train(model, criterion, train_loader, optimizer, opt)
        val_top1 = validate(model, val_loader, opt)

        # Track best validation accuracy
        if val_top1 > best_val_acc:
            best_val_acc = val_top1
            save_checkpoint({
                'epoch': epoch_counter,
                'arch': opt.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_acc': best_val_acc,
            }, is_best=True, filename=os.path.join(writer.log_dir, 'best_checkpoint.pth.tar'))

        # Save periodic checkpoints
        if (epoch_counter + 1) % 10 == 0:
            save_checkpoint({
                'epoch': epoch_counter,
                'arch': opt.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, is_best=False, filename=os.path.join(writer.log_dir, f'checkpoint_{epoch_counter+1:04d}.pth.tar'))

        writer.add_scalar('train loss', loss, global_step=epoch_counter)
        writer.add_scalar('train acc/top1', train_top1, global_step=epoch_counter)
        writer.add_scalar('val acc/top1', val_top1, global_step=epoch_counter)
        writer.add_scalar('best val acc', best_val_acc, global_step=epoch_counter)
        # Logged before stepping, so this is the rate used during this epoch.
        writer.add_scalar('learning_rate', scheduler.get_last_lr()[0], global_step=epoch_counter)

        # Start annealing only after the constant-rate phase.
        if epoch_counter >= hold_epochs:
            scheduler.step()

        logging.debug(f"Epoch: {epoch_counter}\tTrain Loss: {loss:.4f}\tTrain Acc: {train_top1:.2f}")
        logging.debug(f"Epoch: {epoch_counter}\tVal Acc: {val_top1:.2f}\tBest Val Acc: {best_val_acc:.2f}")
        print(f"Epoch: {epoch_counter}\tTrain Loss: {loss:.4f}\tTrain Acc: {train_top1:.2f}\tVal Acc: {val_top1:.2f}")

    logging.info(f"Training finished. Best validation accuracy: {best_val_acc:.2f}")

    # Save final checkpoint (optionally prefixed with --note)
    checkpoint_name = f'checkpoint_{opt.epochs:04d}.pth.tar'
    if opt.note:
        checkpoint_name = f'{opt.note}_{checkpoint_name}'
    save_checkpoint({
        'epoch': opt.epochs,
        'arch': opt.arch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'best_acc': best_val_acc,
    }, is_best=False, filename=os.path.join(writer.log_dir, checkpoint_name))


def main():
    """Parse options, seed the RNGs and train with per-epoch validation.

    Per epoch this logs scalars to TensorBoard, writes progress to
    ``training.log`` inside the log directory, saves the best-so-far
    checkpoint and a periodic checkpoint every 10 epochs; a final checkpoint
    is written after the last epoch.
    """
    import random  # function-scope import: only needed here for seeding

    opt = parse_option()
    # Set before the first CUDA call in this function so it takes effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids

    # Apply --seed. Results may still vary between runs because
    # cudnn.benchmark is enabled in set_model.
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)

    writer = SummaryWriter(log_dir=opt.log_path)
    try:
        _run_training(opt, writer)
    finally:
        # Flush pending events and release the file handle even on failure.
        writer.close()


# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
