"""
Linear Evaluation Protocol

Evaluate learned representations by training a linear classifier
on frozen encoder features (standard protocol for self-supervised learning).

Usage:
    # Evaluate ANN
    python main_evaluator.py --arch resnet18 --ckpt-path <path>
    
    # Evaluate SNN
    python main_evaluator.py --spiking --arch resnet18 --ckpt-path <path>
"""

import os
import sys
import argparse
import logging
from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets

from util import save_config_file, save_checkpoint, AverageMeter, accuracy
from networks.resnet_ann import SupConResNet, LinearClassifier
from networks.resnet_snn import SupConResNetSNN, LinearClassifierSNN, add_dimention_distribute


def parse_option():
    """Parse command-line options and derive dataset and logging settings.

    In addition to the raw CLI flags, the returned namespace carries:

    * ``data_folder``, ``img_size`` and ``n_cls`` -- selected from
      ``--dataset-name``;
    * ``log_path`` -- a run directory under ``--save-path`` whose name encodes
      the model type, dataset, architecture, batch size and learning rate
      (plus the number of timesteps for SNN runs). It is created if missing.

    Returns:
        argparse.Namespace: parsed options plus the derived attributes above.

    Raises:
        SystemExit: raised by argparse on invalid arguments, including a
            ``--dataset-name`` that has no entry in the settings table.
    """
    # Per-dataset settings: (data folder, image size, number of classes).
    dataset_settings = {
        'cifar10': ("~/projects/Datasets/CIFAR10/", 32, 10),
        'tinyimagenet': ("~/projects/Datasets/tiny-imagenet-200/", 64, 200),
    }

    parser = argparse.ArgumentParser('Linear Evaluation')

    # Model settings
    parser.add_argument('--spiking', action='store_true', help='Use SNN model')
    parser.add_argument('--timesteps', type=int, default=4)
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18')

    # Dataset settings
    # Restricting the choices makes an unknown dataset fail at parse time
    # instead of leaving data_folder/img_size/n_cls unset and crashing later.
    parser.add_argument('--dataset-name', type=str, default='cifar10',
                        choices=sorted(dataset_settings))
    parser.add_argument('-j', '--workers', default=16, type=int)

    # Training settings
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('-b', '--batch-size', default=512, type=int)
    parser.add_argument('--lr', default=0.0003, type=float)
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, dest='weight_decay')

    # Other settings
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--save-path', type=str, default='./save/Eval')
    parser.add_argument('--ckpt-path', type=str, required=True, help='Path to pretrained checkpoint')
    parser.add_argument('--gpu-ids', type=str, default='0,1,2,3')
    parser.add_argument('--note', type=str, default='')

    opt = parser.parse_args()

    # Set dataset-specific parameters
    opt.data_folder, opt.img_size, opt.n_cls = dataset_settings[opt.dataset_name]

    # Set log path
    if opt.spiking:
        run_name = f'SNN_{opt.dataset_name}_{opt.arch}_T{opt.timesteps}_bsz{opt.batch_size}_lr{opt.lr}'
    else:
        run_name = f'ANN_{opt.dataset_name}_{opt.arch}_bsz{opt.batch_size}_lr{opt.lr}'
    opt.log_path = os.path.join(opt.save_path, run_name)

    # exist_ok avoids the check-then-create race when runs start concurrently.
    os.makedirs(opt.log_path, exist_ok=True)

    return opt


def set_loader(opt):
    """Build the CIFAR-10 training and validation data loaders.

    Args:
        opt: options namespace. Reads ``data_folder``, ``batch_size`` and
            ``workers``; ``dataset_name`` and ``img_size`` are read when
            present and default to ``'cifar10'`` and ``32``.

    Returns:
        tuple: ``(train_loader, val_loader)``.

    Raises:
        ValueError: if ``opt.dataset_name`` is anything other than
            ``'cifar10'``. Only a CIFAR-10 pipeline is implemented here, and
            silently feeding CIFAR-10 images to a run configured for another
            dataset would produce meaningless accuracy numbers.
    """
    dataset_name = getattr(opt, 'dataset_name', 'cifar10')
    if dataset_name != 'cifar10':
        raise ValueError(
            f"set_loader only implements the CIFAR-10 pipeline, got dataset '{dataset_name}'"
        )

    # Follow the configured image size rather than a hard-coded 32.
    crop_size = getattr(opt, 'img_size', 32)

    # Commonly used CIFAR-10 per-channel normalization statistics.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=crop_size, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Validation uses the images as stored: no augmentation.
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.CIFAR10(
        root=opt.data_folder, train=True, transform=train_transform, download=True
    )
    val_dataset = datasets.CIFAR10(root=opt.data_folder, train=False, transform=val_transform)

    # Shuffle the training set so that each epoch sees a different batch
    # composition; without it the classifier is fit on a fixed batch order.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=True,
        num_workers=opt.workers, pin_memory=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=opt.batch_size, shuffle=False,
        num_workers=opt.workers, pin_memory=True
    )

    return train_loader, val_loader


def set_model(opt):
    """Build the encoder and linear classifier and load pretrained weights.

    Args:
        opt: options namespace. Reads ``spiking``, ``arch``, ``timesteps``
            (SNN only), ``n_cls`` and ``ckpt_path``.

    Returns:
        tuple: ``(model, classifier, criterion)``, all moved to CUDA.

    Raises:
        KeyError: if the checkpoint contains neither a ``'model'`` nor a
            ``'state_dict'`` entry. Without this check the evaluation would
            silently run on a randomly initialised encoder.
    """
    if opt.spiking:
        model = SupConResNetSNN(name=opt.arch, timestep=opt.timesteps)
        classifier = LinearClassifierSNN(name=opt.arch, num_classes=opt.n_cls)
    else:
        model = SupConResNet(name=opt.arch)
        classifier = LinearClassifier(name=opt.arch, num_classes=opt.n_cls)

    # The encoder is wrapped before loading, so the checkpoint's parameter
    # names are matched against the DataParallel-wrapped layout.
    model.encoder = torch.nn.DataParallel(model.encoder)
    model.cuda()

    # Load pretrained checkpoint
    # NOTE(review): weights_only=False unpickles arbitrary objects; only use
    # this with checkpoints from a trusted source.
    checkpoint = torch.load(opt.ckpt_path, map_location="cpu", weights_only=False)
    for key in ('model', 'state_dict'):
        if key in checkpoint:
            model.load_state_dict(checkpoint[key])
            break
    else:
        raise KeyError(
            f"Checkpoint '{opt.ckpt_path}' has neither a 'model' nor a 'state_dict' entry"
        )
    print("Successfully loaded pretrained model")

    classifier.cuda()
    criterion = torch.nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True

    return model, classifier, criterion


def train(model, classifier, criterion, train_loader, optimizer, opt):
    """Run one epoch of classifier optimisation on top of the fixed encoder.

    The encoder is only used for inference (eval mode, no gradient tracking);
    the optimizer step updates whatever parameters ``optimizer`` was built
    with.

    Args:
        model: network whose ``encoder`` attribute produces the features.
        classifier: head trained on the encoder output.
        criterion: loss applied to the classifier output and the labels.
        train_loader: iterable yielding ``(images, labels)`` batches.
        optimizer: optimizer stepped once per batch.
        opt: options namespace; reads ``spiking`` and ``timesteps``.

    Returns:
        tuple: ``(average loss, average top-1 accuracy)`` over the epoch, as
        accumulated by the two ``AverageMeter`` instances.
    """
    model.eval()  # encoder stays in inference mode
    classifier.train()

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    for batch, targets in train_loader:
        batch, targets = batch.cuda(), targets.cuda()
        n_samples = targets.shape[0]

        # Feature extraction happens without gradient tracking.
        with torch.no_grad():
            encoder_input = (
                add_dimention_distribute(batch, opt.timesteps) if opt.spiking else batch
            )
            feats = model.encoder(encoder_input)

        logits = classifier(feats.detach())
        if opt.spiking:
            # Average over dim 1 (presumably the timestep axis -- TODO confirm).
            logits = logits.mean(1)

        batch_loss = criterion(logits, targets)

        loss_meter.update(batch_loss.item(), n_samples)
        top1_acc, _ = accuracy(logits, targets, topk=(1, 5))
        acc_meter.update(top1_acc[0], n_samples)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    return loss_meter.avg, acc_meter.avg


def validate(model, classifier, val_loader, opt):
    """Measure top-1 accuracy of the classifier on the validation set.

    Args:
        model: network whose ``encoder`` attribute produces the features.
        classifier: head applied to the encoder output.
        val_loader: iterable yielding ``(images, labels)`` batches.
        opt: options namespace; reads ``spiking`` and ``timesteps``.

    Returns:
        The running average of the per-batch top-1 accuracy, as accumulated
        by ``AverageMeter``.
    """
    model.eval()
    classifier.eval()

    top1 = AverageMeter()

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.cuda()
            labels = labels.cuda()
            bsz = labels.shape[0]

            if opt.spiking:
                images = add_dimention_distribute(images, opt.timesteps)
                output = classifier(model.encoder(images))
                output = output.mean(1)
            else:
                output = classifier(model.encoder(images))

            # accuracy() returns one entry per requested k. Unpack it the same
            # way train() does, so the meter receives the same kind of value
            # instead of the first element of the whole result list.
            acc1, _ = accuracy(output, labels, topk=(1, 5))
            top1.update(acc1[0], bsz)

    return top1.avg


def main():
    """Run the full linear-evaluation pipeline.

    Parses options, sets up TensorBoard and file logging in the run
    directory, builds the data loaders and models, trains the classifier for
    ``opt.epochs`` epochs (validating after each one), and finally saves the
    classifier and optimizer state to the run directory.
    """
    opt = parse_option()
    # Set before the first CUDA call below so only these devices are visible.
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
    writer = SummaryWriter(log_dir=opt.log_path)

    log_name = f'{opt.note}_training.log' if opt.note else 'training.log'
    logging.basicConfig(
        filename=os.path.join(writer.log_dir, log_name),
        level=logging.DEBUG
    )

    try:
        save_config_file(writer.log_dir, opt)

        train_loader, val_loader = set_loader(opt)
        model, classifier, criterion = set_model(opt)

        # Only the classifier's parameters are handed to the optimizer.
        optimizer = torch.optim.Adam(classifier.parameters(), opt.lr, weight_decay=opt.weight_decay)

        # The learning rate is held constant for the first epochs, then the
        # scheduler is stepped once per epoch. T_max therefore has to count
        # those remaining epochs, not the number of batches per epoch;
        # otherwise the shape of the schedule depends on the batch size.
        constant_lr_epochs = 10
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max(1, opt.epochs - constant_lr_epochs), eta_min=0, last_epoch=-1
        )

        logging.info(f"Start linear evaluation for {opt.epochs} epochs.")

        for epoch_counter in range(opt.epochs):
            loss, train_top1 = train(model, classifier, criterion, train_loader, optimizer, opt)
            val_top1 = validate(model, classifier, val_loader, opt)

            # The meters may hold tensors; convert to plain floats so that
            # logging and ':.2f' formatting work regardless of tensor shape.
            train_top1 = float(train_top1)
            val_top1 = float(val_top1)

            writer.add_scalar('train loss', loss, global_step=epoch_counter)
            writer.add_scalar('train acc/top1', train_top1, global_step=epoch_counter)
            writer.add_scalar('val acc/top1', val_top1, global_step=epoch_counter)
            writer.add_scalar('learning_rate', scheduler.get_last_lr()[0], global_step=epoch_counter)

            if epoch_counter >= constant_lr_epochs:
                scheduler.step()

            logging.debug(f"Epoch: {epoch_counter}\tTrain Loss: {loss}\tTrain Acc: {train_top1}")
            print(f"Epoch: {epoch_counter}\tTrain Loss: {loss:.4f}\tTrain Acc: {train_top1:.2f}\tVal Acc: {val_top1:.2f}")

        logging.info("Linear evaluation finished.")

        # Save final classifier
        checkpoint_name = f'checkpoint_{opt.epochs:04d}.pth.tar'
        if opt.note:
            checkpoint_name = f'{opt.note}_{checkpoint_name}'
        save_checkpoint({
            'epoch': opt.epochs,
            'arch': opt.arch,
            'state_dict': classifier.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, is_best=False, filename=os.path.join(writer.log_dir, checkpoint_name))
    finally:
        # Flush pending TensorBoard events even if the run fails part-way.
        writer.close()


# Run the evaluation only when this file is executed as a script, so that
# importing the module does not parse arguments or start training.
if __name__ == '__main__':
    main()
