# Import necessary libraries
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.backends import cudnn
from bisect import bisect_right
import math
import os
import time
from tqdm import tqdm

# Define the command-line interface. Help strings state the actual defaults.
parser = argparse.ArgumentParser(description='PyTorch local error training')
parser.add_argument('--model', default='vgg19', help='model, mlp, vgg13, vgg16, vgg19, vgg8b, vgg11b, resnet18, resnet34, wresnet28-10 and more (default: vgg19)')
parser.add_argument('--dataset', default='CIFAR100', help='dataset, MNIST, KuzushijiMNIST, FashionMNIST, CIFAR10, CIFAR100, SVHN, STL10, ImageNet or Omniglot (default: CIFAR100)')
parser.add_argument('--batch-size', type=int, default=128, help='input batch size for training (default: 128)')
parser.add_argument('--num-layers', type=int, default=1, help='number of hidden fully-connected layers for mlp and vgg models (default: 1)')
parser.add_argument('--num-hidden', type=int, default=1024, help='number of hidden units for mlp model (default: 1024)')
parser.add_argument('--dim-in-decoder', type=int, default=4096, help='input dimension of decoder_y used in pred and predsim loss (default: 4096)')
parser.add_argument('--feat-mult', type=float, default=1, help='multiply number of CNN features with this number (default: 1)')
parser.add_argument('--epochs', type=int, default=400, help='number of epochs to train (default: 400)')
parser.add_argument('--classes-per-batch', type=int, default=0, help='aim for this number of different classes per batch during training (default: 0, random batches)')
parser.add_argument('--classes-per-batch-until-epoch', type=int, default=0, help='limit number of classes per batch until this epoch (default: 0, until end of training)')
parser.add_argument('--lr', type=float, default=5e-4, help='initial learning rate (default: 5e-4)')
parser.add_argument('--lr-decay-milestones', nargs='+', type=int, default=[200,300,350,375], help='decay learning rate at these milestone epochs (default: [200,300,350,375])')
parser.add_argument('--lr-decay-fact', type=float, default=0.25, help='learning rate decay factor to use at milestone epochs (default: 0.25)')
parser.add_argument('--optim', default='adam', help='optimizer, adam, amsgrad or sgd (default: adam)')
parser.add_argument('--momentum', type=float, default=0.0, help='SGD momentum (default: 0.0)')
parser.add_argument('--weight-decay', type=float, default=0.0, help='weight decay (default: 0.0)')
parser.add_argument('--alpha', type=float, default=0.0, help='unsupervised fraction in similarity matching loss (default: 0.0)')
parser.add_argument('--beta', type=float, default=0.99, help='fraction of similarity matching loss in predsim loss (default: 0.99)')
parser.add_argument('--dropout', type=float, default=0.0, help='dropout after each nonlinearity (default: 0.0)')
parser.add_argument('--loss-sup', default='predsim', help='supervised local loss, sim, pred or predsim (default: predsim)')
parser.add_argument('--loss-unsup', default='none', help='unsupervised local loss, none, sim or recon (default: none)')
parser.add_argument('--nonlin', default='relu', help='nonlinearity, relu or leakyrelu (default: relu)')
parser.add_argument('--no-similarity-std', action='store_true', default=False, help='disable use of standard deviation in similarity matrix for feature maps')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disable CUDA training')
parser.add_argument('--backprop', action='store_true', default=False, help='disable local loss training')
parser.add_argument('--no-batch-norm', action='store_true', default=False, help='disable batch norm before non-linearities')
parser.add_argument('--no-detach', action='store_true', default=False, help='do not detach computational graph')
parser.add_argument('--pre-act', action='store_true', default=False, help='use pre-activation in ResNet')
parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
parser.add_argument('--save-dir', default='/hdd/results/local-error', type=str, help='the directory used to save the trained models')
parser.add_argument('--resume', default='', type=str, help='checkpoint to resume training from')
parser.add_argument('--progress-bar', action='store_true', default=False, help='show progress bar during training')
parser.add_argument('--no-print-stats', action='store_true', default=False, help='do not print layerwise statistics during training with local loss')
parser.add_argument('--bio', action='store_true', default=False, help='use more biologically plausible versions of pred and sim loss (default: False)')
parser.add_argument('--target-proj-size', type=int, default=128, help='size of target projection back to hidden layers for biologically plausible loss (default: 128)')
parser.add_argument('--cutout', action='store_true', default=False, help='apply cutout regularization')
parser.add_argument('--n_holes', type=int, default=1, help='number of holes to cut out from image')
parser.add_argument('--length', type=int, default=16, help='length of the cutout holes in pixels')

# Set arguments explicitly for Jupyter (matches the provided command); an
# explicit list keeps parse_args from reading the notebook kernel's sys.argv.
args = parser.parse_args([
    '--model', 'vgg8',
    '--dataset', 'CIFAR100',
    '--dropout', '0.2',
    '--lr', '5e-4',
    '--nonlin', 'leakyrelu',
    '--dim-in-decoder', '2048',
    '--num-hidden', '1024',
    '--num-layers', '1',
    # NOTE(review): the directory name says lr1e-4 but '--lr' above is 5e-4 — confirm which is intended.
    '--save-dir', './results_cross_lr1e-4',  # Override default save_dir to a notebook-friendly path
    '--loss-sup', 'pred'
])

# Configure CUDA and seeds
args.cuda = not args.no_cuda and torch.cuda.is_available()
print(args.cuda)
if args.cuda:
    cudnn.enabled = True
    cudnn.benchmark = True

# Seed every RNG the script draws from: torch (weights, dropout) and NumPy
# (the Cutout transform and NClassRandomSampler shuffling use np.random).
torch.manual_seed(args.seed)
np.random.seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Define custom classes and functions (unchanged from original)
class Cutout(object):
    """Zero out random square patches of an image tensor (Cutout regularization).

    Args:
        n_holes: number of patches cut out of each image.
        length: nominal side length, in pixels, of each patch; patches that
            cross the image border are clipped to the image.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Return a copy of ``img`` with ``n_holes`` random patches set to zero.

        Args:
            img: 3-D tensor laid out as (channels, height, width); the sizes are
                read from dims 1 and 2, and one mask is broadcast over dim 0.

        Returns:
            A new tensor with the same shape as ``img``.
        """
        height, width = img.size(1), img.size(2)
        half = self.length // 2
        keep = np.ones((height, width), np.float32)
        for _ in range(self.n_holes):
            # Row centre is drawn before column centre (fixed RNG order).
            cy = np.random.randint(height)
            cx = np.random.randint(width)
            top, bottom = np.clip([cy - half, cy + half], 0, height)
            left, right = np.clip([cx - half, cx + half], 0, width)
            keep[top:bottom, left:right] = 0.
        return img * torch.from_numpy(keep).expand_as(img)

class NClassRandomSampler(torch.utils.data.sampler.Sampler):
    """Yield dataset indices so that most batches span ``n_classes_per_batch`` classes.

    Indices are reshuffled with ``np.random`` on every iteration. Each batch is
    assembled by:

    1. taking the next ``n_classes_per_batch`` shuffled indices,
    2. adding indices of classes not yet in the batch until it spans
       ``n_classes_per_batch`` distinct classes (the scan is bounded),
    3. filling the batch with indices whose class is already present, and
    4. if still short of ``batch_size``, topping up with the next indices
       regardless of class.

    Every index is yielded exactly once per pass, so ``len(sampler)`` equals
    the number of examples.

    Args:
        targets: class label of each example (list, 1-D NumPy array, or CPU tensor
            of 0-based integer labels).
        n_classes_per_batch: desired number of distinct classes per batch.
        batch_size: number of consecutive yielded indices that form one batch.
    """

    def __init__(self, targets, n_classes_per_batch, batch_size):
        self.targets = targets
        # Labels are 0-based, so the class count is the largest label plus one.
        self.n_classes = int(np.max(np.asarray(targets))) + 1
        self.n_classes_per_batch = n_classes_per_batch
        self.batch_size = batch_size

    def __iter__(self):
        n = self.n_classes_per_batch
        # Convert to plain Python ints so set membership compares labels by value
        # (0-dim tensors would otherwise hash by identity).
        ts = np.asarray(self.targets).tolist()
        ts_i = list(range(len(ts)))
        np.random.shuffle(ts_i)
        while ts_i:
            # 1) Seed the batch with the next n shuffled indices.
            idxs, ts_i = ts_i[:n], ts_i[n:]
            t_slice_set = {ts[i] for i in idxs}

            # 2) Pull in examples of unseen classes until the batch spans n
            #    classes; the scan is capped at n * 10 positions to stay cheap.
            k = 0
            while len(t_slice_set) < n and k < n * 10 and k < len(ts_i):
                if ts[ts_i[k]] not in t_slice_set:
                    idx = ts_i.pop(k)
                    idxs.append(idx)
                    t_slice_set.add(ts[idx])
                else:
                    k += 1

            # 3) Fill the batch with examples whose class is already present.
            j = 0
            while j < len(ts_i) and len(idxs) < self.batch_size:
                if ts[ts_i[j]] in t_slice_set:
                    idxs.append(ts_i.pop(j))
                else:
                    j += 1

            # 4) Those classes ran out: top up with the next indices of any class.
            if len(idxs) < self.batch_size:
                needed = self.batch_size - len(idxs)
                idxs += ts_i[:needed]
                ts_i = ts_i[needed:]

            yield from idxs

    def __len__(self):
        return len(self.targets)

class KuzushijiMNIST(getattr(datasets, 'KMNIST', datasets.MNIST)):
    """Kuzushiji-MNIST dataset with the torchvision MNIST interface.

    Recent torchvision releases download MNIST-style data from the ``mirrors`` /
    ``resources`` class attributes and no longer read ``urls``. Overriding only
    ``urls`` on ``datasets.MNIST`` would therefore fetch plain MNIST under this
    name. When torchvision provides ``datasets.KMNIST``, inherit from it so the
    correct download locations and checksums are used. ``urls`` is kept for
    older releases that still read it. The class name is unchanged, so the
    on-disk folder layout derived from it stays the same.
    """
    urls = [
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz'
    ]

# Data loading
# Labels for NClassRandomSampler are read from each dataset's public label
# attribute: `targets` for the MNIST-style and CIFAR datasets, `labels` for
# SVHN/STL10. The old `train_labels` alias is deprecated on MNIST and absent
# from CIFAR datasets in recent torchvision. A ConcatDataset has no label
# attribute of its own, so SVHN's labels are gathered from its parts.
kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}
if args.dataset == 'MNIST':
    input_dim = 28
    input_ch = 1
    num_classes = 10
    train_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=2),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    dataset_train = datasets.MNIST('../data/MNIST', train=True, download=True, transform=train_transform)
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        sampler = None if args.classes_per_batch == 0 else NClassRandomSampler(dataset_train.targets.numpy(), args.classes_per_batch, args.batch_size),
        batch_size=args.batch_size, shuffle=args.classes_per_batch == 0, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
elif args.dataset == 'FashionMNIST':
    input_dim = 28
    input_ch = 1
    num_classes = 10
    train_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.286,), (0.353,))
        ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    dataset_train = datasets.FashionMNIST('../data/FashionMNIST', train=True, download=True, transform=train_transform)
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        sampler = None if args.classes_per_batch == 0 else NClassRandomSampler(dataset_train.targets.numpy(), args.classes_per_batch, args.batch_size),
        batch_size=args.batch_size, shuffle=args.classes_per_batch == 0, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST('../data/FashionMNIST', train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.286,), (0.353,))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
elif args.dataset == 'KuzushijiMNIST':
    input_dim = 28
    input_ch = 1
    num_classes = 10
    train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1904,), (0.3475,))
        ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    dataset_train = KuzushijiMNIST('../data/KuzushijiMNIST', train=True, download=True, transform=train_transform)
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        sampler = None if args.classes_per_batch == 0 else NClassRandomSampler(dataset_train.targets.numpy(), args.classes_per_batch, args.batch_size),
        batch_size=args.batch_size, shuffle=args.classes_per_batch == 0, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        KuzushijiMNIST('../data/KuzushijiMNIST', train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1904,), (0.3475,))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
elif args.dataset == 'CIFAR10':
    input_dim = 32
    input_ch = 3
    num_classes = 10
    train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.424, 0.415, 0.384), (0.283, 0.278, 0.284))
        ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    dataset_train = datasets.CIFAR10('../data/CIFAR10', train=True, download=True, transform=train_transform)
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        sampler = None if args.classes_per_batch == 0 else NClassRandomSampler(dataset_train.targets, args.classes_per_batch, args.batch_size),
        batch_size=args.batch_size, shuffle=args.classes_per_batch == 0, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../data/CIFAR10', train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.424, 0.415, 0.384), (0.283, 0.278, 0.284))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
elif args.dataset == 'CIFAR100':
    input_dim = 32
    input_ch = 3
    num_classes = 100
    train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.438, 0.418, 0.377), (0.300, 0.287, 0.294))
        ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    dataset_train = datasets.CIFAR100('../data/CIFAR100', train=True, download=True, transform=train_transform)
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        sampler = None if args.classes_per_batch == 0 else NClassRandomSampler(dataset_train.targets, args.classes_per_batch, args.batch_size),
        batch_size=args.batch_size, shuffle=args.classes_per_batch == 0, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR100('../data/CIFAR100', train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.438, 0.418, 0.377), (0.300, 0.287, 0.294))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
elif args.dataset == 'SVHN':
    input_dim = 32
    input_ch = 3
    num_classes = 10
    train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.431, 0.430, 0.446), (0.197, 0.198, 0.199))
        ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    dataset_train = torch.utils.data.ConcatDataset((
        datasets.SVHN('../data/SVHN', split='train', download=True, transform=train_transform),
        datasets.SVHN('../data/SVHN', split='extra', download=True, transform=train_transform)))
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        # Labels are concatenated in the same order ConcatDataset indexes its parts.
        sampler = None if args.classes_per_batch == 0 else NClassRandomSampler(
            np.concatenate([d.labels for d in dataset_train.datasets]), args.classes_per_batch, args.batch_size),
        batch_size=args.batch_size, shuffle=args.classes_per_batch == 0, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.SVHN('../data/SVHN', split='test', download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.431, 0.430, 0.446), (0.197, 0.198, 0.199))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
elif args.dataset == 'STL10':
    input_dim = 96
    input_ch = 3
    num_classes = 10
    train_transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.447, 0.440, 0.407), (0.260, 0.257, 0.271))
        ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    dataset_train = datasets.STL10('../data/STL10', split='train', download=True, transform=train_transform)
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        sampler = None if args.classes_per_batch == 0 else NClassRandomSampler(dataset_train.labels, args.classes_per_batch, args.batch_size),
        batch_size=args.batch_size, shuffle=args.classes_per_batch == 0, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.STL10('../data/STL10', split='test',
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.447, 0.440, 0.407), (0.260, 0.257, 0.271))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
elif args.dataset == 'ImageNet':
    input_dim = 224
    input_ch = 3
    num_classes = 1000
    train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    dataset_train = datasets.ImageFolder('../data/ImageNet/train', transform=train_transform)
    labels = np.array([a[1] for a in dataset_train.samples])
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        sampler = None if args.classes_per_batch == 0 else NClassRandomSampler(labels, args.classes_per_batch, args.batch_size),
        batch_size=args.batch_size, shuffle=args.classes_per_batch == 0, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder('../data/ImageNet/val',
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
elif args.dataset == 'Omniglot':
    input_dim = 28
    input_ch = 1
    num_classes = 964
    train_transform = transforms.Compose([
        transforms.Resize(28),
        transforms.RandomCrop(28, padding=2),
        transforms.ToTensor(),
        transforms.Normalize((0.9221,), (0.2681,))
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    # Load Omniglot background dataset
    full_dataset = datasets.Omniglot('../data/Omniglot', background=True, download=True, transform=train_transform)
    # Assign labels (each class has 20 samples)
    # NOTE(review): assumes the dataset lists each class's 20 images contiguously, in class order.
    full_dataset.train_labels = np.repeat(np.arange(num_classes), 20)
    # Split dataset into train and test (e.g., 80% train, 20% test)
    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    dataset_train, dataset_test = torch.utils.data.random_split(
        full_dataset,
        [train_size, test_size],
        generator=torch.Generator().manual_seed(42)  # Fixed seed for reproducibility
    )
    # DataLoader for training
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        sampler=None if args.classes_per_batch == 0 else NClassRandomSampler(
            full_dataset.train_labels[dataset_train.indices], args.classes_per_batch, args.batch_size
        ),
        batch_size=args.batch_size, shuffle=args.classes_per_batch == 0, **kwargs)
    # DataLoader for testing (apply same transform as training for consistency)
    test_loader = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size, shuffle=False, **kwargs)
else:
    print('No valid dataset is specified')
    raise ValueError('Dataset handling for {} not implemented in this example'.format(args.dataset))

# Model components (unchanged, abbreviated for brevity; full definitions should be included)
class LinearFAFunction(torch.autograd.Function):
    """Autograd function for a linear layer trained with feedback alignment.

    Forward computes ``input @ weight.T`` (plus ``bias`` when given). Backward
    propagates the error to ``input`` through the separate matrix ``weight_fa``
    instead of ``weight``. ``weight`` and ``bias`` receive their ordinary
    gradients; ``weight_fa`` receives none.
    """

    @staticmethod
    def forward(context, input, weight, weight_fa, bias=None):
        # input should be 2-D (batch, in_features): backward's
        # grad_output.t().matmul(input) only gives a weight-shaped gradient then.
        context.save_for_backward(input, weight, weight_fa, bias)
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(context, grad_output):
        # saved_tensors replaces the deprecated saved_variables accessor.
        input, weight, weight_fa, bias = context.saved_tensors
        grad_input = grad_weight = grad_weight_fa = grad_bias = None
        if context.needs_input_grad[0]:
            # Feedback alignment: send the error back through weight_fa, not weight.
            grad_input = grad_output.matmul(weight_fa)
        if context.needs_input_grad[1]:
            grad_weight = grad_output.t().matmul(input)
        # bias is forward's 4th input, so its flag is needs_input_grad[3]
        # (index 2 belongs to weight_fa).
        if bias is not None and context.needs_input_grad[3]:
            # No squeeze: keep shape (out_features,) even when out_features == 1.
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_weight_fa, grad_bias

class LinearFA(nn.Module):
    """Fully connected layer whose backward pass uses feedback alignment.

    The forward pass is the affine map ``input @ weight.T + bias`` computed by
    ``LinearFAFunction``; that function routes the input gradient through
    ``weight_fa`` rather than ``weight``.

    Args:
        input_features: size of each input sample.
        output_features: size of each output sample.
        bias: if True, add a learnable bias (initialised to zero).
    """

    def __init__(self, input_features, output_features, bias=True):
        super(LinearFA, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        shape = (output_features, input_features)
        # Registration order (weight, weight_fa, bias) fixes parameter ordering.
        self.weight = nn.Parameter(torch.Tensor(*shape))
        self.weight_fa = nn.Parameter(torch.Tensor(*shape))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.Tensor(output_features))
        self.reset_parameters()
        if args.cuda:
            # Swap each parameter's storage for a GPU copy (module-level args).
            for param in (self.weight, self.weight_fa):
                param.data = param.data.cuda()
            if bias:
                self.bias.data = self.bias.data.cuda()

    def reset_parameters(self):
        """Draw both matrices from U(-1/sqrt(fan_in), 1/sqrt(fan_in)); zero the bias."""
        bound = 1.0 / math.sqrt(self.weight.size(1))
        for param in (self.weight, self.weight_fa):
            param.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, input):
        return LinearFAFunction.apply(input, self.weight, self.weight_fa, self.bias)

class LocalLossBlockLinear(nn.Module):
    """Fully connected block (linear -> [batchnorm] -> nonlinearity -> [dropout]) with a local loss.

    Unless ``args.backprop`` is set, the block builds auxiliary heads
    (``decoder_x``, ``decoder_y``, ``proj_y``, ``linear_loss``) and its own
    optimizer. ``forward`` then computes the local loss chosen by
    ``args.loss_sup`` / ``args.loss_unsup`` and backpropagates it.
    All configuration comes from the module-level ``args`` namespace.

    Args:
        num_in: number of input features.
        num_out: number of output features.
        num_classes: number of target classes (input width of ``proj_y`` and
            output width of the non-bio ``decoder_y``).
        first_layer: if True, the 'recon' unsupervised loss is skipped.
        dropout: dropout probability; ``None`` means ``args.dropout``.
        batchnorm: use BatchNorm1d; ``None`` means ``not args.no_batch_norm``.
    """

    def __init__(self, num_in, num_out, num_classes, first_layer=False, dropout=None, batchnorm=None):
        super(LocalLossBlockLinear, self).__init__()
        self.num_classes = num_classes
        self.first_layer = first_layer
        self.dropout_p = args.dropout if dropout is None else dropout
        self.batchnorm = not args.no_batch_norm if batchnorm is None else batchnorm
        self.encoder = nn.Linear(num_in, num_out, bias=True)
        # Auxiliary heads exist only for local-loss training (not in backprop mode).
        if not args.backprop and args.loss_unsup == 'recon':
            self.decoder_x = nn.Linear(num_out, num_in, bias=True)
        if not args.backprop and (args.loss_sup == 'pred' or args.loss_sup == 'predsim'):
            if args.bio:
                self.decoder_y = LinearFA(num_out, args.target_proj_size)
            else:
                self.decoder_y = nn.Linear(num_out, num_classes)
            # The prediction head starts from all-zero weights.
            self.decoder_y.weight.data.zero_()
        if not args.backprop and args.bio:
            self.proj_y = nn.Linear(num_classes, args.target_proj_size, bias=False)
        if not args.backprop and not args.bio and (args.loss_unsup == 'sim' or args.loss_sup == 'sim' or args.loss_sup == 'predsim'):
            self.linear_loss = nn.Linear(num_out, num_out, bias=False)
        if self.batchnorm:
            self.bn = torch.nn.BatchNorm1d(num_out)
            nn.init.constant_(self.bn.weight, 1)
            nn.init.constant_(self.bn.bias, 0)
        if args.nonlin == 'relu':
            self.nonlin = nn.ReLU(inplace=True)
        elif args.nonlin == 'leakyrelu':
            self.nonlin = nn.LeakyReLU(negative_slope=0.01, inplace=True)
        if self.dropout_p > 0:
            self.dropout = torch.nn.Dropout(p=self.dropout_p, inplace=False)
        # Per-block optimizer over all of this block's parameters. lr starts at 0;
        # set_learning_rate() assigns the real value.
        # NOTE(review): any other args.optim leaves self.optimizer undefined.
        if args.optim == 'sgd':
            self.optimizer = optim.SGD(self.parameters(), lr=0, weight_decay=args.weight_decay, momentum=args.momentum)
        elif args.optim == 'adam' or args.optim == 'amsgrad':
            self.optimizer = optim.Adam(self.parameters(), lr=0, weight_decay=args.weight_decay, amsgrad=args.optim == 'amsgrad')
        self.clear_stats()

    def clear_stats(self):
        """Reset the accumulated loss/accuracy statistics (unless stats are disabled)."""
        if not args.no_print_stats:
            self.loss_sim = 0.0
            self.loss_pred = 0.0
            self.correct = 0
            self.examples = 0

    def print_stats(self):
        """Return a one-line summary of the accumulated statistics ('' in backprop mode).

        Losses are averaged per example. Error is the percentage of examples whose
        local prediction missed; it reads 100% when no prediction loss is used
        ('sim'), since ``correct`` is then never updated.
        """
        if not args.backprop:
            # Avoid ZeroDivisionError when nothing was recorded since clear_stats().
            n = max(self.examples, 1)
            stats = '{}, loss_sim={:.4f}, loss_pred={:.4f}, error={:.3f}%, num_examples={}\n'.format(
                self.encoder, self.loss_sim / n, self.loss_pred / n,
                100.0 * float(self.examples - self.correct) / n, self.examples)
            return stats
        return ''

    def set_learning_rate(self, lr):
        """Set the learning rate of every param group in this block's optimizer."""
        self.lr = lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def optim_zero_grad(self):
        self.optimizer.zero_grad()

    def optim_step(self):
        self.optimizer.step()

    def forward(self, x, y, y_onehot):
        """Apply the block and, when enabled, train it with its local loss.

        Args:
            x: input features (2-D, fed to ``nn.Linear``).
            y: integer class labels (used by ``F.cross_entropy`` and accuracy).
            y_onehot: target encoding fed to ``similarity_matrix`` / ``proj_y``;
                presumably one-hot of shape (batch, num_classes) — TODO confirm.

        Returns:
            (h_return, loss): the block output (after dropout, if any) and the
            local loss as a Python float (0.0 when no local loss was computed).
        """
        h = self.encoder(x)
        if self.batchnorm:
            h = self.bn(h)
        h = self.nonlin(h)
        h_return = h
        if self.dropout_p > 0:
            h_return = self.dropout(h_return)
        # Losses are computed while training, or in eval when stats are collected.
        if (self.training or not args.no_print_stats) and not args.backprop:
            if args.loss_unsup == 'sim' or args.loss_sup == 'sim' or args.loss_sup == 'predsim':
                if args.bio:
                    h_loss = h
                else:
                    h_loss = self.linear_loss(h)
                # similarity_matrix is defined elsewhere in this file.
                Rh = similarity_matrix(h_loss)
            if args.loss_unsup == 'sim':
                Rx = similarity_matrix(x).detach()
                loss_unsup = F.mse_loss(Rh, Rx)
            elif args.loss_unsup == 'recon' and not self.first_layer:
                x_hat = self.nonlin(self.decoder_x(h))
                loss_unsup = F.mse_loss(x_hat, x.detach())
            else:
                loss_unsup = torch.cuda.FloatTensor([0]) if args.cuda else torch.FloatTensor([0])
            if args.loss_sup == 'sim':
                Ry = similarity_matrix(self.proj_y(y_onehot) if args.bio else y_onehot).detach()
                loss_sup = F.mse_loss(Rh, Ry)
                if not args.no_print_stats:
                    self.loss_sim += loss_sup.item() * h.size(0)
                    self.examples += h.size(0)
            elif args.loss_sup == 'pred':
                y_hat_local = self.decoder_y(h.view(h.size(0), -1))
                if args.bio:
                    # Binary targets: sign pattern of the projected labels.
                    float_type = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor
                    y_onehot_pred = self.proj_y(y_onehot).gt(0).type(float_type).detach()
                    loss_sup = F.binary_cross_entropy_with_logits(y_hat_local, y_onehot_pred)
                else:
                    loss_sup = F.cross_entropy(y_hat_local, y.detach())
                if not args.no_print_stats:
                    self.loss_pred += loss_sup.item() * h.size(0)
                    # NOTE(review): with args.bio, y_hat_local has target_proj_size
                    # columns, so this argmax-vs-label accuracy is not a class accuracy.
                    self.correct += y_hat_local.max(1)[1].eq(y).cpu().sum()
                    self.examples += h.size(0)
            elif args.loss_sup == 'predsim':
                y_hat_local = self.decoder_y(h.view(h.size(0), -1))
                if args.bio:
                    Ry = similarity_matrix(self.proj_y(y_onehot)).detach()
                    float_type = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor
                    y_onehot_pred = self.proj_y(y_onehot).gt(0).type(float_type).detach()
                    loss_pred = (1 - args.beta) * F.binary_cross_entropy_with_logits(y_hat_local, y_onehot_pred)
                else:
                    Ry = similarity_matrix(y_onehot).detach()
                    loss_pred = (1 - args.beta) * F.cross_entropy(y_hat_local, y.detach())
                loss_sim = args.beta * F.mse_loss(Rh, Ry)
                loss_sup = loss_pred + loss_sim
                if not args.no_print_stats:
                    self.loss_pred += loss_pred.item() * h.size(0)
                    self.loss_sim += loss_sim.item() * h.size(0)
                    self.correct += y_hat_local.max(1)[1].eq(y).cpu().sum()
                    self.examples += h.size(0)
            # Mix unsupervised and supervised parts with weight alpha.
            loss = args.alpha * loss_unsup + (1 - args.alpha) * loss_sup
            if self.training:
                # Keep the graph when it is not detached, so it can be traversed again.
                loss.backward(retain_graph=args.no_detach)
            if self.training and not args.no_detach:
                # Update this block immediately and cut the graph so later
                # blocks' losses do not backpropagate into it.
                self.optimizer.step()
                self.optimizer.zero_grad()
                h_return.detach_()
            loss = loss.item()
        else:
            loss = 0.0
        return h_return, loss

class OrthogonalLoss(nn.Module):
    """Squared distance between tanh-squashed outputs and one-hot targets.

    With ``p = tanh(outputs)`` and label ``c``, each sample contributes
    ``(1 - p_c)**2 + sum_{k != c} p_k**2`` (i.e. ``||p - e_c||**2``). The
    result is the sum of these terms over the batch divided by the batch size.
    """

    def __init__(self):
        super(OrthogonalLoss, self).__init__()

    def forward(self, outputs, labels):
        """Compute the batch-averaged loss.

        Args:
            outputs: scores of shape (batch, num_classes).
            labels: int64 class indices of shape (batch,).

        Returns:
            0-dim tensor holding the mean per-sample loss.
        """
        probs = torch.tanh(outputs)
        # Pick each sample's true-class activation in one vectorized step instead
        # of looping over all classes with a one-hot mask.
        p_true = probs.gather(1, labels.unsqueeze(1)).squeeze(1)
        diag_term = (1.0 - p_true) ** 2
        off_diag_term = (probs ** 2).sum(dim=1) - p_true ** 2
        return (diag_term + off_diag_term).sum() / outputs.size(0)

class LocalLossBlockConv(nn.Module):
    """Conv block: [pre-act BN/nonlin/dropout] -> conv -> [BN] -> [+shortcut] -> [nonlin] -> [dropout].

    ``forward`` returns the block output for the next layer, detached so no gradient
    flows back into earlier blocks through it. It also returns this block's flattened
    features (average-pooled when ``self.avg_pool`` is set). These features keep their
    autograd graph, so a loss computed on them trains this block.

    Behavior is controlled by the module-level ``args`` namespace. The auxiliary
    local-loss modules (``decoder_x``, ``decoder_y``, ``proj_y``, ``conv_loss``) are
    still built when ``args`` asks for them, which keeps ``state_dict`` keys compatible
    with existing checkpoints. ``forward`` does not use them.
    """

    def __init__(self, ch_in, ch_out, kernel_size, stride, padding, num_classes, dim_out, first_layer=False, dropout=None, bias=None, pre_act=False, post_act=True):
        super(LocalLossBlockConv, self).__init__()
        self.ch_in = ch_in
        self.ch_out = ch_out
        self.num_classes = num_classes
        self.first_layer = first_layer
        self.dropout_p = args.dropout if dropout is None else dropout
        self.bias = True if bias is None else bias
        self.pre_act = pre_act
        self.post_act = post_act
        # Default to "no pooling" so forward() never reads a missing attribute, e.g.
        # when the decoder-input pooling below is not configured (args.backprop set).
        self.avg_pool = None
        self.encoder = nn.Conv2d(ch_in, ch_out, kernel_size, stride=stride, padding=padding, bias=self.bias)
        if not args.backprop and args.loss_unsup == 'recon':
            self.decoder_x = nn.ConvTranspose2d(ch_out, ch_in, kernel_size, stride=stride, padding=padding)
        if args.bio or (not args.backprop and (args.loss_sup == 'pred' or args.loss_sup == 'predsim')):
            # Grow the average-pooling kernel, doubling height then width. Stop when the
            # flattened pooled map fits in args.dim_in_decoder features or the kernel
            # height reaches dim_out.
            ks_h, ks_w = 1, 1
            dim_out_h, dim_out_w = dim_out, dim_out
            dim_in_decoder = ch_out * dim_out_h * dim_out_w
            while dim_in_decoder > args.dim_in_decoder and ks_h < dim_out:
                ks_h *= 2
                dim_out_h = math.ceil(dim_out / ks_h)
                dim_in_decoder = ch_out * dim_out_h * dim_out_w
                if dim_in_decoder > args.dim_in_decoder:
                    ks_w *= 2
                    dim_out_w = math.ceil(dim_out / ks_w)
                    dim_in_decoder = ch_out * dim_out_h * dim_out_w
            if ks_h > 1 or ks_w > 1:
                # Padding intended to yield ceil(dim_out / ks) outputs per axis when ks
                # does not divide dim_out.
                pad_h = (ks_h * (dim_out_h - dim_out // ks_h)) // 2
                pad_w = (ks_w * (dim_out_w - dim_out // ks_w)) // 2
                self.avg_pool = nn.AvgPool2d((ks_h, ks_w), padding=(pad_h, pad_w))
        if not args.backprop and (args.loss_sup == 'pred' or args.loss_sup == 'predsim'):
            if args.bio:
                self.decoder_y = LinearFA(dim_in_decoder, args.target_proj_size)
            else:
                self.decoder_y = nn.Linear(dim_in_decoder, num_classes)
            self.decoder_y.weight.data.zero_()
        if not args.backprop and args.bio:
            self.proj_y = nn.Linear(num_classes, args.target_proj_size, bias=False)
        if not args.backprop and (args.loss_unsup == 'sim' or args.loss_sup == 'sim' or args.loss_sup == 'predsim'):
            self.conv_loss = nn.Conv2d(ch_out, ch_out, 3, stride=1, padding=1, bias=False)
        if not args.no_batch_norm:
            if pre_act:
                self.bn_pre = torch.nn.BatchNorm2d(ch_in)
            if not (pre_act and args.backprop):
                self.bn = torch.nn.BatchNorm2d(ch_out)
                nn.init.constant_(self.bn.weight, 1)
                nn.init.constant_(self.bn.bias, 0)
        if args.nonlin == 'relu':
            self.nonlin = nn.ReLU(inplace=True)
        elif args.nonlin == 'leakyrelu':
            self.nonlin = nn.LeakyReLU(negative_slope=0.01, inplace=True)
        if self.dropout_p > 0:
            self.dropout = torch.nn.Dropout2d(p=self.dropout_p, inplace=False)
        # Per-block optimizer; its learning rate starts at 0 and is set via set_learning_rate().
        if args.optim == 'sgd':
            self.optimizer = optim.SGD(self.parameters(), lr=0, weight_decay=args.weight_decay, momentum=args.momentum)
        elif args.optim == 'adam' or args.optim == 'amsgrad':
            self.optimizer = optim.Adam(self.parameters(), lr=0, weight_decay=args.weight_decay, amsgrad=args.optim == 'amsgrad')
        self.clear_stats()

    def clear_stats(self):
        """Reset the per-block statistics counters (only when stats printing is enabled)."""
        if not args.no_print_stats:
            self.loss_sim = 0.0
            self.loss_pred = 0.0
            self.correct = 0
            self.examples = 0

    def print_stats(self):
        """Return this block's statistics line; forward() collects none, so it is empty."""
        return ''

    def set_learning_rate(self, lr):
        """Set the learning rate of every parameter group of the per-block optimizer."""
        self.lr = lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def optim_zero_grad(self):
        self.optimizer.zero_grad()

    def optim_step(self):
        self.optimizer.step()

    def forward(self, x, y, y_onehot, x_shortcut=None):
        """Apply the block.

        Args:
            x: input feature map.
            y, y_onehot: targets; accepted for interface compatibility, unused here.
            x_shortcut: optional tensor added after BN and before the post-activation
                nonlinearity.

        Returns:
            ``(h_return, features)``. ``h_return`` is the block output after optional
            dropout, detached from the graph. ``features`` is the pre-dropout output,
            optionally average-pooled and flattened to (batch, -1), and still attached
            to the graph.
        """
        if self.pre_act:
            if not args.no_batch_norm:
                x = self.bn_pre(x)
            x = self.nonlin(x)
            if self.dropout_p > 0:
                x = self.dropout(x)
        h = self.encoder(x)
        if self.post_act and not args.no_batch_norm:
            h = self.bn(h)
        if x_shortcut is not None:
            h = h + x_shortcut
        if self.post_act:
            h = self.nonlin(h)
        h_return = h
        if self.post_act and self.dropout_p > 0:
            h_return = self.dropout(h_return)

        # Take the classifier features from h while it is still attached. h_return must
        # not be detached in place: without dropout it *is* h, and an in-place detach
        # would also cut the encoder's gradient through these features.
        features = h if self.avg_pool is None else self.avg_pool(h)
        features = features.view(features.size(0), -1)

        # NOTE(review): without dropout the detached output shares storage with h, so a
        # following block must not modify its input in place. VGGn's blocks do not use
        # pre_act, whose nonlinearity is in-place.
        return h_return.detach(), features

# VGG layer configurations consumed by VGGn._make_layers. Each int is the output-channel
# count of a 3x3 LocalLossBlockConv (scaled by feat_mult); each 'M' is a 2x2, stride-2
# max-pooling layer.
# NOTE(review): the --model default 'vgg19' in the argument parser is not a key in this
# table (only 'vgg19a' is); confirm the intended default.
cfg = {
    'vgg6':  [128, 'M', 256, 'M', 512, 'M', 512, 'M'],
    'vgg8':  [128, 128, 'M', 256, 256, 'M', 512, 'M', 512, 'M'],
    'vgg11a': [128, 128, 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 'M'],
    'vgg11a2x': [256, 256, 256, 256, 'M', 512, 512, 'M', 1024, 1024, 'M', 1024, 'M'],
    'vgg11a4x': [512, 512, 512, 512, 'M', 1024, 1024, 'M', 1024, 1024, 'M', 1024, 'M'],
    'vgg16a': [128, 128, 128, 128, 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 'M'],
    'vgg16a2x': [256, 256, 256, 256, 256, 256, 'M', 512, 512, 512, 'M', 1024, 1024, 1024, 'M', 1024, 1024, 'M'],
    'vgg16a4x': [512, 512, 512, 512, 512, 512, 'M', 1024, 1024, 1024, 'M', 1024, 1024, 1024, 'M', 1024, 1024, 'M'],
    'vgg19a': [128, 128, 128, 128, 128, 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
}


class VGGn(nn.Module):
    """VGG-style classifier built from LocalLossBlockConv blocks and an MLP head.

    The final linear classifier ``decoder_y`` reads the concatenation of every conv
    block's flattened (pooled) features and every hidden fully-connected layer's
    activations. The input of each fully-connected layer is detached.

    Args:
        vgg_name: key into the module-level ``cfg`` table.
        input_dim: spatial size of the (square) input images.
        input_ch: number of input channels.
        num_classes: number of output classes.
        feat_mult: multiplier applied to every conv block's channel count.
    """

    def __init__(self, vgg_name, input_dim, input_ch, num_classes, feat_mult=1):
        super(VGGn, self).__init__()
        self.cfg = cfg[vgg_name]
        self.input_dim = input_dim
        self.input_ch = input_ch
        self.num_classes = num_classes
        self.total_feature = 0
        self.features, output_dim = self._make_layers(self.cfg, input_ch, input_dim, feat_mult)

        # Layer statistics come from this model's own config, not cfg[args.model].
        conv_widths = [item for item in self.cfg if isinstance(item, (int, float))]
        lay_num = len(conv_widths)
        pooling_num = sum(1 for item in self.cfg if isinstance(item, str))

        # Assumes every conv block emits exactly args.dim_in_decoder features (see the
        # pooling setup in LocalLossBlockConv) -- TODO confirm for non-power-of-two sizes.
        self.decoder_y = nn.Linear(args.dim_in_decoder * lay_num + args.num_hidden * args.num_layers, num_classes)
        self.decoder_y.weight.data.zero_()
        self.criterion = OrthogonalLoss()  # kept as a submodule; not used by forward()

        # _make_layers applies one 2x2/stride-2 max-pool per 'M' and floor-divides the
        # spatial size each time; output_dim is the resulting size.
        downsampled_dim = output_dim
        if downsampled_dim < 1:
            raise ValueError(f"Input dimension {input_dim} reduced to {downsampled_dim} after {pooling_num} pooling operations. Increase input_dim or reduce pooling_num.")
        # The conv blocks were built with int(width * feat_mult) channels, so the flattened
        # size of the last feature map must use the scaled width as well.
        last_ch = int(conv_widths[-1] * feat_mult)
        self.layers = nn.ModuleList([nn.Linear(downsampled_dim ** 2 * last_ch, args.num_hidden)])
        self.layers.extend([nn.Linear(args.num_hidden, args.num_hidden) for _ in range(1, args.num_layers)])

        self.bns = nn.ModuleList([torch.nn.BatchNorm1d(args.num_hidden) for _ in range(args.num_layers)])
        for bn in self.bns:
            nn.init.constant_(bn.weight, 1)
            nn.init.constant_(bn.bias, 0)
        self.nonlins = nn.ModuleList([nn.LeakyReLU(negative_slope=0.01, inplace=True) for _ in range(args.num_layers)])

    def _make_layers(self, cfg, input_ch, input_dim, feat_mult):
        """Build the conv trunk; return ``(nn.Sequential, spatial size after all pooling)``."""
        layers = []
        first_layer = True
        scale_cum = 1
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
                scale_cum *= 2
            else:
                x = int(x * feat_mult)
                self.total_feature += x
                layers += [LocalLossBlockConv(input_ch, x, kernel_size=3, stride=1, padding=1,
                                              num_classes=self.num_classes, dim_out=input_dim // scale_cum,
                                              first_layer=first_layer)]
                input_ch = x
                first_layer = False
        return nn.Sequential(*layers), input_dim // scale_cum

    def _conv_blocks(self):
        """Yield the LocalLossBlockConv entries of ``self.features`` (one per int in ``self.cfg``)."""
        for i, layer in enumerate(self.cfg):
            if isinstance(layer, int):
                yield self.features[i]

    def set_learning_rate(self, lr):
        for block in self._conv_blocks():
            block.set_learning_rate(lr)

    def optim_zero_grad(self):
        for block in self._conv_blocks():
            block.optim_zero_grad()

    def optim_step(self):
        for block in self._conv_blocks():
            block.optim_step()

    def forward(self, x, y, y_onehot, epoch=0):
        """Return ``(logits, cross-entropy loss)``.

        ``epoch`` is accepted for call-site compatibility and is unused.
        """
        collected = []
        for i, layer in enumerate(self.cfg):
            if isinstance(layer, int):
                x, x_for_loss = self.features[i](x, y, y_onehot)
                collected.append(x_for_loss)
            else:
                x = self.features[i](x)
        h = x.view(x.size(0), -1)

        for i in range(args.num_layers):
            # Detached input: each hidden layer is trained only through its own
            # activations feeding decoder_y, not through the layers after it.
            h = self.layers[i](h.detach())
            h = self.bns[i](h)
            h = self.nonlins[i](h)
            collected.append(h)

        # Concatenate once instead of growing a tensor inside the loops.
        x_save = torch.cat(collected, dim=1)
        y_hat_local = self.decoder_y(x_save)
        loss_sup = F.cross_entropy(y_hat_local, y.detach())
        return y_hat_local, loss_sup

class Net(nn.Module):
    """MLP of LocalLossBlockLinear layers followed by a linear output layer.

    Args:
        num_layers: number of LocalLossBlockLinear layers.
        num_hidden: width of each hidden layer.
        input_dim: spatial size of the flattened input (input_dim * input_dim * input_ch features).
        input_ch: number of input channels.
        num_classes: number of output classes.
    """

    def __init__(self, num_layers, num_hidden, input_dim, input_ch, num_classes):
        super(Net, self).__init__()
        self.num_hidden = num_hidden
        self.num_layers = num_layers
        reduce_factor = 1
        self.layers = nn.ModuleList([LocalLossBlockLinear(input_dim*input_dim*input_ch, num_hidden, num_classes, first_layer=True)])
        self.layers.extend([LocalLossBlockLinear(int(num_hidden // (reduce_factor**(i-1))), int(num_hidden // (reduce_factor**i)), num_classes) for i in range(1, num_layers)])
        self.layer_out = nn.Linear(int(num_hidden // (reduce_factor**(num_layers-1))), num_classes)
        if not args.backprop:
            self.layer_out.weight.data.zero_()

    def parameters(self, recurse=True):
        """Return the parameters exposed to the global optimizer.

        Without ``args.backprop`` only ``layer_out`` is exposed. ``recurse`` is forwarded
        so this override keeps ``nn.Module.parameters``'s signature.
        """
        if not args.backprop:
            return self.layer_out.parameters(recurse)
        return super(Net, self).parameters(recurse)

    def set_learning_rate(self, lr):
        for layer in self.layers:
            layer.set_learning_rate(lr)

    def optim_zero_grad(self):
        for layer in self.layers:
            layer.optim_zero_grad()

    def optim_step(self):
        for layer in self.layers:
            layer.optim_step()

    def forward(self, x, y, y_onehot):
        """Return ``(logits, summed per-layer loss)``.

        Assumes each LocalLossBlockLinear returns ``(output, loss)``; verify against
        that class.
        """
        x = x.view(x.size(0), -1)
        total_loss = 0.0
        for layer in self.layers:
            x, loss = layer(x, y, y_onehot)
            total_loss += loss
        x = self.layer_out(x)
        return x, total_loss

def count_parameters(model):
    """Return the total number of scalar entries across ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def to_one_hot(y, n_dims=None):
    """One-hot encode integer labels ``y`` into a CPU float tensor of shape ``(*y.shape, n_dims)``.

    When ``n_dims`` is None, the width is ``max(y) + 1``.
    """
    flat = y.type(torch.LongTensor).view(-1, 1)
    if n_dims is None:
        width = int(torch.max(flat)) + 1
    else:
        width = n_dims
    encoded = torch.zeros(flat.size()[0], width)
    encoded.scatter_(1, flat, 1)
    return encoded.view(*y.shape, -1)

def similarity_matrix(x):
    """Return the pairwise cosine-similarity matrix (clamped to [-1, 1]) of mean-centered rows.

    A 4-D input is first reduced to one vector per sample. If ``args.no_similarity_std``
    is off and the input has more than 3 channels and spatial height above 1, each
    channel becomes its spatial standard deviation; otherwise the input is flattened.
    """
    if x.dim() == 4:
        reduce_by_std = not args.no_similarity_std and x.size(1) > 3 and x.size(2) > 1
        if reduce_by_std:
            per_channel = x.view(x.size(0), x.size(1), -1)
            x = per_channel.std(dim=2)
        else:
            x = x.view(x.size(0), -1)
    row_means = x.mean(dim=1).unsqueeze(1)
    centered = x - row_means
    row_norms = 1e-8 + torch.sqrt(torch.sum(centered ** 2, dim=1))
    normalized = centered / row_norms.unsqueeze(1)
    gram = normalized.matmul(normalized.transpose(1, 0))
    return gram.clamp(-1, 1)

# Model instantiation
checkpoint = None
if not args.resume == '':
    if os.path.isfile(args.resume):
        # map_location='cpu' allows resuming a GPU-saved checkpoint on a CPU-only host;
        # the model is moved to the GPU below when args.cuda is set.
        # NOTE(review): torch.load unpickles the stored argparse Namespace, so only resume
        # from trusted files; torch releases whose torch.load defaults to weights_only=True
        # may additionally need weights_only=False here.
        checkpoint = torch.load(args.resume, map_location='cpu')
        args.model = checkpoint['args'].model
        args_backup = args
        # Build the model with the checkpoint's arguments, but keep the training options
        # below from the current command line.
        args = checkpoint['args']
        args.optim = args_backup.optim
        args.momentum = args_backup.momentum
        args.weight_decay = args_backup.weight_decay
        args.dropout = args_backup.dropout
        args.no_batch_norm = args_backup.no_batch_norm
        args.cutout = args_backup.cutout
        args.length = args_backup.length
        # Report the path that was actually loaded (args now holds the checkpoint's copy).
        print('=> loaded checkpoint "{}" (epoch {})'.format(args_backup.resume, checkpoint['epoch']))
    else:
        print('Checkpoint not found: {}'.format(args.resume))

if args.model.startswith('vgg'):
    model = VGGn(args.model, input_dim, input_ch, num_classes, args.feat_mult)
else:
    raise ValueError('Model {} not implemented in this example'.format(args.model))

if checkpoint is not None:
    model.load_state_dict(checkpoint['state_dict'])
    # Switch back to the current command-line arguments for the rest of the run.
    args = args_backup

if args.cuda:
    model.cuda()

if args.progress_bar:
    from tqdm import tqdm

# Global optimizer over all model parameters.
if args.optim == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
elif args.optim == 'adam' or args.optim == 'amsgrad':
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, amsgrad=args.optim == 'amsgrad')
else:
    raise ValueError('Optimizer {} not supported (use sgd, adam or amsgrad)'.format(args.optim))

model.set_learning_rate(args.lr)
print(model)
print('Model {} has {} parameters influenced by global loss'.format(args.model, count_parameters(model)))

# Training and testing functions

def train(epoch, lr):
    """Train the global ``model`` for one epoch over ``train_loader``.

    Prints forward/backward timing statistics and a summary line.

    Args:
        epoch: current epoch number (passed to the model and printed).
        lr: current learning rate; used only for reporting.

    Returns:
        ``(average loss, error percent, summary string)``.
    """
    model.train()
    correct = 0
    loss_total_local = 0.0
    loss_total_global = 0.0  # no separate global loss is computed; reported as 0

    total_forward_time = 0.0
    total_backward_time = 0.0

    # NOTE(review): forces the progress bar on regardless of the command-line option
    # (and mutates the global args, which are saved into checkpoints).
    args.progress_bar = 1

    def _sync():
        # CUDA work is asynchronous; synchronize so the timers measure completed work.
        # Skipped on CPU runs, where torch.cuda.synchronize() can fail.
        if args.cuda:
            torch.cuda.synchronize()

    if args.progress_bar:
        pbar = tqdm(total=len(train_loader))
    if not args.no_print_stats:
        for m in model.modules():
            if isinstance(m, LocalLossBlockLinear) or isinstance(m, LocalLossBlockConv):
                m.clear_stats()

    for data, target in train_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        target_onehot = to_one_hot(target, num_classes)
        if args.cuda:
            target_onehot = target_onehot.cuda()

        optimizer.zero_grad()

        # ---- forward-pass timing ----
        _sync()
        t_fwd_start = time.perf_counter()
        output, loss = model(data, target, target_onehot, epoch)
        _sync()
        forward_time = time.perf_counter() - t_fwd_start
        total_forward_time += forward_time

        # ---- backward-pass timing ----
        _sync()
        t_bwd_start = time.perf_counter()
        loss.backward()
        _sync()
        backward_time = time.perf_counter() - t_bwd_start
        total_backward_time += backward_time

        optimizer.step()

        # Accumulate a plain float so the running total does not keep each batch's
        # autograd history alive.
        loss_value = loss.item()
        loss_total_local += loss_value * data.size(0)

        pred = output.max(1)[1]
        correct += pred.eq(target).cpu().sum()

        if args.progress_bar:
            pbar.set_postfix(loss=loss_value,
                             fwd_time=f"{forward_time*1000:.2f}ms",
                             bwd_time=f"{backward_time*1000:.2f}ms",
                             refresh=False)
            pbar.update()

    if args.progress_bar:
        pbar.close()

    # ===== average & total times =====
    num_batches = len(train_loader)
    avg_forward_time = total_forward_time / num_batches
    avg_backward_time = total_backward_time / num_batches

    print(f"[Time Stats] Forward: total={total_forward_time:.3f}s, avg={avg_forward_time*1000:.2f}ms/batch")
    print(f"[Time Stats] Backward: total={total_backward_time:.3f}s, avg={avg_backward_time*1000:.2f}ms/batch")

    loss_average_local = loss_total_local / len(train_loader.dataset)
    loss_average_global = loss_total_global / len(train_loader.dataset)
    error_percent = 100 - 100.0 * float(correct) / len(train_loader.dataset)

    string_print = (
        'Train epoch={}, lr={:.2e}, loss_local={:.4f}, loss_global={:.4f}, '
        'error={:.3f}%, mem={:.0f}MiB, max_mem={:.0f}MiB\n'
        .format(epoch, lr, loss_average_local, loss_average_global,
                error_percent,
                torch.cuda.memory_allocated()/1e6,
                torch.cuda.max_memory_allocated()/1e6)
    )

    if not args.no_print_stats:
        for m in model.modules():
            if isinstance(m, LocalLossBlockLinear) or isinstance(m, LocalLossBlockConv):
                string_print += m.print_stats()
    print(string_print)
    return loss_average_local + loss_average_global, error_percent, string_print

def test(epoch):
    """Evaluate the global ``model`` on ``test_loader``.

    Returns:
        ``(mean cross-entropy loss, error percent, summary string)``. The loss is scaled
        by ``(1 - args.beta)`` for the predsim local-loss setting.
    """
    model.eval()
    summed_loss = 0
    num_correct = 0
    if not args.no_print_stats:
        for module in model.modules():
            if isinstance(module, (LocalLossBlockLinear, LocalLossBlockConv)):
                module.clear_stats()

    for inputs, labels in test_loader:
        labels_onehot = to_one_hot(labels, num_classes)
        if args.cuda:
            inputs, labels, labels_onehot = inputs.cuda(), labels.cuda(), labels_onehot.cuda()
        with torch.no_grad():
            logits, _ = model(inputs, labels, labels_onehot)
            summed_loss += F.cross_entropy(logits, labels).item() * inputs.size(0)
        num_correct += logits.max(1)[1].eq(labels).cpu().sum()

    dataset_size = len(test_loader.dataset)
    mean_loss = summed_loss / dataset_size
    if args.loss_sup == 'predsim' and not args.backprop:
        # NOTE(review): train() reports the unscaled cross-entropy from VGGn.forward, so
        # this rescale makes test and train losses not directly comparable — confirm intent.
        mean_loss *= (1 - args.beta)
    error_percent = 100 - 100.0 * float(num_correct) / dataset_size
    report = 'Test loss_global={:.4f}, error={:.3f}%\n'.format(mean_loss, error_percent)
    if not args.no_print_stats:
        for module in model.modules():
            if isinstance(module, (LocalLossBlockLinear, LocalLossBlockConv)):
                report += module.print_stats()
    print(report)
    return mean_loss, error_percent, report

# Main training loop
start_epoch = 1 if checkpoint is None else 1 + checkpoint['epoch']
for epoch in range(start_epoch, args.epochs + 1):
    # Step decay: the base lr is multiplied by lr_decay_fact once for every milestone
    # already passed (milestones are compared against epoch - 1).
    lr = args.lr * args.lr_decay_fact ** bisect_right(args.lr_decay_milestones, (epoch-1))
    for ms in args.lr_decay_milestones:
        if (epoch - 1) == ms:
            print('Decaying learning rate to {}'.format(lr))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    model.set_learning_rate(lr)
    # Switch back to plain shuffled batches once the class-restricted sampling period ends.
    if args.classes_per_batch_until_epoch > 0 and epoch > args.classes_per_batch_until_epoch and isinstance(train_loader.sampler, NClassRandomSampler):
        print('Remove NClassRandomSampler from train_loader')
        train_loader = torch.utils.data.DataLoader(dataset_train, sampler=None, batch_size=args.batch_size, shuffle=True, **kwargs)
    train_loss, train_error, train_print = train(epoch, lr)
    test_loss, test_error, test_print = test(epoch)
    if args.save_dir != '':
        filename = 'chkp_ep{}_lr{:.2e}_trainloss{:.2f}_testloss{:.2f}_trainerr{:.2f}_testerr{:.2f}.tar'.format(
            epoch, lr, train_loss, test_loss, train_error, test_error)
        dirname = os.path.join(args.save_dir, args.dataset)
        dirname = os.path.join(dirname, '{}_mult{:.1f}'.format(args.model, args.feat_mult))
        dirname = os.path.join(dirname, '{}_{}x{}_{}_{}_dimdec{}_alpha{}_beta{}_bs{}_cpb{}_drop{}{}_bn{}_{}_wd{}_bp{}_detach{}_lr{:.2e}'.format(
            args.nonlin, args.num_layers, args.num_hidden, args.loss_sup + '-bio' if args.bio else args.loss_sup, args.loss_unsup, args.dim_in_decoder, args.alpha,
            args.beta, args.batch_size, args.classes_per_batch, args.dropout, '_cutout{}x{}'.format(args.n_holes, args.length) if args.cutout else '',
            int(not args.no_batch_norm), args.optim, args.weight_decay, int(args.backprop), int(not args.no_detach), args.lr))
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        elif epoch == 1 and os.path.exists(dirname):
            # Fresh run into an existing directory: clear files from a previous run.
            for f in os.listdir(dirname):
                os.remove(os.path.join(dirname, f))
        with open(os.path.join(dirname, 'log.txt'), 'a') as f:
            if epoch == 1:
                f.write('{}\n\n'.format(args))
                f.write('{}\n\n'.format(model))
                f.write('{}\n\n'.format(optimizer))
                f.write('Model {} has {} parameters influenced by global loss\n\n'.format(args.model, count_parameters(model)))
            f.write(train_print)
            f.write(test_print)
            f.write('\n')
        # Same payload for the per-epoch checkpoint and the rolling "last epoch" one.
        # 'train_loss' now stores the training loss (it previously held train_error).
        checkpoint_state = {
            'epoch': epoch,
            'args': args,
            'state_dict': model.state_dict(),
            'train_loss': train_loss,
            'train_error': train_error,
            'test_loss': test_loss,
            'test_error': test_error,
        }
        torch.save(checkpoint_state, os.path.join(dirname, filename))
        torch.save(checkpoint_state, os.path.join(dirname, 'chkp_last_epoch.tar'))