"""
Few-Shot Transfer Learning Evaluation

Evaluates representation quality by training a linear classifier on frozen features
with limited labeled data (20%, 40%, 60%, 80%, or 100% of the training set).

This evaluation probes the linear decodability of learned representations and
their ability to support rapid task adaptation from sparse supervision.

Usage:
    python main_fewshot.py --spiking --arch resnet18 --ckpt-path <path> --fewshot-ratio 0.2
"""

from __future__ import print_function

import os
import sys
import argparse
import random
import numpy as np
from collections import defaultdict
import logging

from torch.utils.tensorboard import SummaryWriter
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets

from util import save_config_file, save_checkpoint, AverageMeter, accuracy
from networks.resnet_ann import SupCEResNet, LinearClassifier
from networks.resnet_snn import SupCEResNetSNN, LinearClassifierSNN, add_dimention_distribute


def parse_option():
    """Parse command-line options and derive dataset and logging settings.

    In addition to the raw argparse values, the returned namespace carries:

    * ``data_folder``, ``img_size``, ``n_cls`` -- dataset settings; only
      ``cifar10`` is configured.
    * ``log_path`` -- ``<save_path>/<SNN|ANN>_<arch>_fewshot<ratio>``, which
      is created on disk when it does not exist yet.

    An unsupported ``--dataset-name`` terminates the program through
    ``parser.error`` (usage message + ``SystemExit``) instead of returning a
    namespace that lacks the dataset attributes.
    """
    parser = argparse.ArgumentParser('Few-Shot Transfer Learning')

    parser.add_argument('--spiking', action='store_true', help='Use SNN model')
    parser.add_argument('--timesteps', type=int, default=4)
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18')
    parser.add_argument('--dataset-name', type=str, default='cifar10')
    parser.add_argument('-j', '--workers', default=16, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('-b', '--batch-size', default=512, type=int)
    parser.add_argument('--lr', default=0.01, type=float)
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, dest='weight_decay')

    # Few-shot parameters
    parser.add_argument('--fewshot-ratio', type=float, default=0.2,
                        choices=[0.2, 0.4, 0.6, 0.8, 1.0],
                        help='Fraction of training data to use')
    parser.add_argument('--seed', type=int, default=42)

    # Other settings
    parser.add_argument('--save-path', type=str, default='./save/FewShot')
    parser.add_argument('--ckpt-path', type=str, required=True)
    parser.add_argument('--gpu-ids', type=str, default='0')

    opt = parser.parse_args()

    if opt.dataset_name == 'cifar10':
        opt.data_folder = "~/projects/Datasets/CIFAR10/"
        opt.img_size = 32
        opt.n_cls = 10
    else:
        # Fail fast: without these attributes the run would only crash later,
        # after the log directory has already been created.
        parser.error(
            f"unsupported dataset '{opt.dataset_name}' (only 'cifar10' is configured)")

    model_tag = 'SNN' if opt.spiking else 'ANN'
    opt.log_path = os.path.join(
        opt.save_path, f'{model_tag}_{opt.arch}_fewshot{opt.fewshot_ratio}')

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(opt.log_path, exist_ok=True)

    return opt


def create_fewshot_subset(dataset, ratio, seed=42):
    """
    Create a class-balanced random subset of the dataset.

    Every class contributes the same number of samples,
    ``int(smallest_class_size * ratio)``, so the subset stays balanced even
    when the source dataset is not.

    Args:
        dataset: Indexable dataset yielding ``(sample, label)`` pairs. When it
            exposes a ``targets`` sequence and has no ``target_transform``,
            labels are read from ``targets`` directly so that no sample has to
            be loaded or transformed just to learn its class.
        ratio: Fraction of each class to keep. ``1.0`` returns ``dataset``
            itself, untouched.
        seed: Seed for the selection. NOTE: this reseeds the global ``random``
            and ``numpy.random`` generators.

    Returns:
        ``dataset`` when ``ratio == 1.0``, otherwise a
        ``torch.utils.data.Subset`` over the selected indices.
    """
    if ratio == 1.0:
        return dataset

    random.seed(seed)
    np.random.seed(seed)

    # Reading labels through __getitem__ runs the full image pipeline once per
    # sample; use the in-memory label list when it is equivalent.
    labels = getattr(dataset, 'targets', None)
    if labels is None or getattr(dataset, 'target_transform', None) is not None:
        labels = (label for _, label in dataset)

    # Group sample indices by class
    class_to_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        # Tensor / numpy scalar labels become plain Python values so that equal
        # labels land in the same bucket.
        key = label.item() if hasattr(label, 'item') else label
        class_to_indices[key].append(idx)

    # Size the per-class quota by the smallest class so random.sample can
    # always be satisfied and every class gets the same count.
    samples_per_class = min(
        (len(indices) for indices in class_to_indices.values()), default=0)
    target_samples = int(samples_per_class * ratio)

    # Sorted label order makes the draw sequence (and thus the result)
    # reproducible; for labels 0..K-1 it matches iterating range(K).
    subset_indices = []
    for class_label in sorted(class_to_indices):
        subset_indices.extend(
            random.sample(class_to_indices[class_label], target_samples))

    subset = torch.utils.data.Subset(dataset, subset_indices)
    print(f"Few-shot subset: {len(subset)} samples ({target_samples} per class)")
    return subset


def set_loader(opt):
    """Build the few-shot training loader and the full validation loader.

    Reads ``opt.data_folder``, ``opt.fewshot_ratio``, ``opt.seed``,
    ``opt.batch_size`` and ``opt.workers``. The training split is augmented
    (random resized crop + horizontal flip) and reduced with
    ``create_fewshot_subset``; the held-out split is only converted to a tensor
    and normalised.

    Returns:
        ``(train_loader, val_loader)``
    """
    # Per-channel normalisation shared by both splits.
    normalize = transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    )

    augmented = transforms.Compose([
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    plain = transforms.Compose([transforms.ToTensor(), normalize])

    # Only the training constructor requests a download.
    full_train = datasets.CIFAR10(root=opt.data_folder, transform=augmented, download=True)
    val_dataset = datasets.CIFAR10(root=opt.data_folder, train=False, transform=plain)

    train_dataset = create_fewshot_subset(full_train, opt.fewshot_ratio, opt.seed)

    # Settings common to both loaders; only shuffling differs.
    shared = dict(batch_size=opt.batch_size, num_workers=opt.workers, pin_memory=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **shared)
    val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **shared)

    return train_loader, val_loader


def set_model(opt):
    """Load the pretrained encoder, freeze it, and create a fresh linear head.

    The encoder (SNN or ANN variant, chosen by ``opt.spiking``) is wrapped in
    ``DataParallel``, moved to the GPU and populated from ``opt.ckpt_path``.
    The checkpoint may store the weights under a ``'model'`` key, under a
    ``'state_dict'`` key, or be a bare state dict itself.

    Returns:
        ``(model, classifier, criterion)`` -- the frozen encoder, the trainable
        linear classifier and a cross-entropy loss, all on the GPU.

    Raises:
        RuntimeError: from ``load_state_dict`` when the checkpoint's keys do
            not match the (DataParallel-wrapped) encoder.
    """
    if opt.spiking:
        model = SupCEResNetSNN(name=opt.arch, timestep=opt.timesteps,
                               num_classes=opt.n_cls, mode='encoder')
        classifier = LinearClassifierSNN(name=opt.arch, num_classes=opt.n_cls)
    else:
        model = SupCEResNet(name=opt.arch, num_classes=opt.n_cls, mode='encoder')
        classifier = LinearClassifier(name=opt.arch, num_classes=opt.n_cls)

    model = torch.nn.DataParallel(model)
    model.cuda()

    # Load pretrained weights.
    # NOTE(security): weights_only=False unpickles arbitrary objects; only point
    # --ckpt-path at checkpoints from a trusted source.
    checkpoint = torch.load(opt.ckpt_path, map_location="cpu", weights_only=False)
    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        # No wrapper key: treat the checkpoint as the state dict itself. If it
        # is not one, load_state_dict raises instead of silently leaving the
        # encoder randomly initialised.
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    print(f"Loaded pretrained model from {opt.ckpt_path}")

    # Freeze encoder
    for param in model.parameters():
        param.requires_grad = False

    classifier.cuda()
    criterion = torch.nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True

    return model, classifier, criterion


def train_epoch(model, classifier, criterion, train_loader, optimizer, opt):
    """Fit the classifier for one pass over ``train_loader``.

    The encoder stays in eval mode and is run under ``torch.no_grad()``, so the
    backward pass reaches only the classifier.

    Returns:
        ``(average loss, average top-1 accuracy)`` as tracked by the two
        ``AverageMeter`` instances, each weighted by batch size.
    """
    model.eval()
    classifier.train()

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    for batch, targets in train_loader:
        batch, targets = batch.cuda(), targets.cuda()
        batch_size = targets.shape[0]

        # Feature extraction never needs gradients.
        with torch.no_grad():
            if opt.spiking:
                # presumably replicates the input across opt.timesteps -- see helper
                batch = add_dimention_distribute(batch, opt.timesteps)
            features = model(batch)

        logits = classifier(features.detach())
        if opt.spiking:
            # Collapse dim 1 of the SNN output (assumed to be the time axis).
            logits = logits.mean(1)

        loss = criterion(logits, targets)

        # Metrics are recorded before the parameter update of this step.
        loss_meter.update(loss.item(), batch_size)
        acc1, _ = accuracy(logits, targets, topk=(1, 5))
        acc_meter.update(acc1[0], batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return loss_meter.avg, acc_meter.avg


def validate(model, classifier, val_loader, opt):
    """Measure top-1 accuracy of encoder + classifier on ``val_loader``.

    Both modules are switched to eval mode and the forward pass runs without
    gradient tracking.

    Returns:
        The batch-size-weighted average top-1 accuracy from ``AverageMeter``.
    """
    model.eval()
    classifier.eval()

    acc_meter = AverageMeter()

    for batch, targets in val_loader:
        batch = batch.cuda()
        targets = targets.cuda()

        with torch.no_grad():
            if opt.spiking:
                batch = add_dimention_distribute(batch, opt.timesteps)
            logits = classifier(model(batch))
            if opt.spiking:
                # Collapse dim 1 of the SNN output (assumed to be the time axis).
                logits = logits.mean(1)

        acc1, _ = accuracy(logits, targets, topk=(1, 5))
        acc_meter.update(acc1[0], targets.shape[0])

    return acc_meter.avg


def main():
    """Run the few-shot linear evaluation end to end.

    Parses options, builds the loaders and the frozen encoder, then trains the
    linear classifier for ``opt.epochs`` epochs with Adam. The learning rate is
    held constant for the first 10 epochs and cosine-annealed to 0 over the
    remaining ones. The best validation accuracy seen is reported on stdout
    and written to ``<log_path>/training.log``.
    """
    opt = parse_option()
    # Restrict visible devices before any CUDA call is made by the setup below.
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids

    logging.basicConfig(
        filename=os.path.join(opt.log_path, 'training.log'),
        level=logging.DEBUG
    )

    train_loader, val_loader = set_loader(opt)
    model, classifier, criterion = set_model(opt)

    optimizer = torch.optim.Adam(classifier.parameters(), opt.lr, weight_decay=opt.weight_decay)

    # The scheduler is stepped once per epoch (after the hold phase), so T_max
    # must count those epochs -- not batches per epoch -- for the cosine curve
    # to complete exactly one descent instead of oscillating.
    hold_epochs = 10
    decay_epochs = max(1, opt.epochs - hold_epochs)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=decay_epochs, eta_min=0)

    print(f"Starting few-shot training with {opt.fewshot_ratio*100:.0f}% of training data")

    best_val_acc = 0.0
    for epoch in range(opt.epochs):
        loss, train_acc = train_epoch(model, classifier, criterion, train_loader, optimizer, opt)
        val_acc = validate(model, classifier, val_loader, opt)

        if val_acc > best_val_acc:
            best_val_acc = val_acc

        if epoch >= hold_epochs:
            scheduler.step()

        if (epoch + 1) % 10 == 0:
            summary = (f"Epoch {epoch}: Train Acc={train_acc:.2f}%, "
                       f"Val Acc={val_acc:.2f}%, Best={best_val_acc:.2f}%")
            print(summary)
            # Mirror progress into training.log, which was otherwise left empty.
            logging.info("%s (loss=%.4f)", summary, loss)

    print("\nFew-shot evaluation complete")
    print(f"Data ratio: {opt.fewshot_ratio*100:.0f}%")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    logging.info("Few-shot evaluation complete: data ratio %.0f%%, best validation accuracy %.2f%%",
                 opt.fewshot_ratio * 100, float(best_val_acc))


# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
