import argparse
import time

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import resnet

import sys
import os
# Add the parent directory of this script's folder to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from optimizers.lamb import create_lamb_optimizer
from optimizers.ALTO import create_ALTO_optimizer
from adabelief_pytorch import AdaBelief
# Sorted names of every lowercase, callable attribute of the ``resnet`` module
# whose name begins with "resnet"; used below as the valid ``--arch`` choices.
model_names = sorted(
    attr_name
    for attr_name, attr in vars(resnet).items()
    if attr_name.startswith("resnet") and attr_name.islower() and callable(attr)
)

# Command-line interface for the training script.
parser = argparse.ArgumentParser(description='Proper ResNets for CIFAR10 in pytorch')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet32',
                    choices=model_names,
                    help='model architecture: ' + ' | '.join(model_names) +
                    ' (default: resnet32)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=200, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
# The help text previously advertised 128 while the actual default is 16384.
parser.add_argument('-b', '--batch-size', default=16384, type=int,
                    metavar='N', help='mini-batch size (default: 16384)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
# Only the 'sgd' optimizer branch in main() reads the momentum value.
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
# NOTE(review): --beta is parsed but not read anywhere in the visible part of
# this file -- confirm whether it is still needed.
parser.add_argument('--beta', '--bt', default=0.5, type=float,
                    metavar='bt', help='beta of optX')
parser.add_argument('--optimizer', default='sgd',
                    choices=['sgd', 'adam', 'adamW', 'lamb', 'ALTO', 'adaBelief'],
                    help='optimizer used for training (default: sgd)')
# Parsed at import time; main() and its helpers read this module-level object.
args = parser.parse_args()

def main():
    """Train and evaluate a CIFAR-10 ResNet as configured by the module-level ``args``.

    Builds the model (wrapped in ``DataParallel`` and moved to the GPU), the
    CIFAR-10 train/validation loaders, the loss, the optimizer selected by
    ``args.optimizer`` and a ``MultiStepLR`` schedule, then runs one training
    pass and one validation pass per epoch and prints the resulting metrics.

    Raises:
        ValueError: If ``args.optimizer`` is not one of the handled names.
    """
    model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
    model.cuda()

    cudnn.benchmark = True

    # NOTE(review): these look like the widely used ImageNet channel statistics
    # rather than CIFAR-10-specific values -- confirm this is intended.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Training data: random horizontal flip plus a random 32x32 crop with
    # 4 pixels of padding, then tensor conversion and normalization.
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='../datasets', train=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ]), download=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    # Validation data: no augmentation. ``download`` is not requested here, so
    # this relies on the training dataset above having fetched the files.
    val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='../datasets', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    criterion = nn.CrossEntropyLoss().cuda()

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'adamW':
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=args.lr,
                                      weight_decay=args.weight_decay)
    elif args.optimizer == 'adaBelief':
        # NOTE(review): unlike the other branches, --weight-decay is not
        # forwarded here -- confirm whether that is deliberate.
        optimizer = AdaBelief(model.parameters(),
                              lr=args.lr,
                              betas=(0.9, 0.999))
    elif args.optimizer == 'ALTO':
        optimizer = create_ALTO_optimizer(model,
                                          lr=args.lr,
                                          betas=(0.99, 0.9, 0.99),
                                          weight_decay=args.weight_decay)
    elif args.optimizer == 'lamb':
        optimizer = create_lamb_optimizer(model,
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    # Learning-rate schedule: decay by 10x at 25%, 50% and 75% of the run.
    # Milestones must be integers: the scheduler compares them with its integer
    # epoch counter, so a fractional value such as 2.5 (from true division when
    # --epochs is not a multiple of 4) would never match and that decay would
    # be silently skipped. Very small --epochs values can produce 0 or repeated
    # milestones, so keep only distinct positive ones.
    milestones = sorted({
        step
        for step in (args.epochs // 4, args.epochs // 2, 3 * args.epochs // 4)
        if step > 0
    })
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=milestones,
                                                        gamma=0.1)

    # NOTE(review): --start-epoch only shifts the loop counter; no checkpoint
    # is loaded and the scheduler is not advanced to match -- confirm intent.
    for epoch in range(args.start_epoch, args.epochs):
        train_loss, train_prec1 = train(train_loader, model, criterion, optimizer)
        val_loss, val_prec1 = validate(val_loader, model, criterion)
        lr_scheduler.step()
        print('Epoch: [{0}]\t'
              'Train Loss {1:.4f}\t'
              'Train Prec@1 {2:.3f}\t'
              'Val Loss {3:.4f}\t'
              'Val Prec@1 {4:.3f}'.format(
                  epoch, train_loss, train_prec1,
                  val_loss, val_prec1))

def train(train_loader, model, criterion, optimizer):
    """Run one training pass over ``train_loader`` and report averaged metrics.

    Args:
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        model: Network to optimize; it is put into training mode.
        criterion: Loss callable invoked as ``criterion(output, targets)``.
        optimizer: Optimizer whose ``zero_grad()`` and ``step()`` are called
            once per batch.

    Returns:
        A ``(avg_loss, avg_prec1)`` tuple: the per-batch loss and top-1
        precision (percent), each weighted by batch size and divided by the
        total number of samples seen.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_prec1 = 0.0
    total_samples = 0

    for images, target in train_loader:
        images = images.cuda()
        target = target.cuda()
        batch_size = images.size(0)

        output = model(images)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # accuracy() returns a plain float; detach so the metric computation
        # is not tracked by autograd.
        prec1 = accuracy(output.detach(), target)
        # Weighting by batch size presumes the criterion returns a per-batch
        # mean -- TODO confirm for non-default criteria.
        total_loss += loss.item() * batch_size
        total_prec1 += prec1 * batch_size
        total_samples += batch_size

    avg_loss = total_loss / total_samples
    avg_prec1 = total_prec1 / total_samples
    return avg_loss, avg_prec1


def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        val_loader: Iterable yielding ``(inputs, targets)`` batches.
        model: Network to evaluate; it is put into evaluation mode.
        criterion: Loss callable invoked as ``criterion(output, targets)``.

    Returns:
        A ``(avg_loss, avg_prec1)`` tuple: the per-batch loss and top-1
        precision (percent), each weighted by batch size and divided by the
        total number of samples seen.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_prec1 = 0.0
    total_samples = 0

    with torch.no_grad():
        for images, target in val_loader:
            images = images.cuda()
            # Move the targets to the GPU once; the result serves both the
            # loss and the accuracy computation.
            target = target.cuda()
            batch_size = images.size(0)

            output = model(images)
            loss = criterion(output, target)

            # accuracy() returns a plain float.
            prec1 = accuracy(output, target)
            total_loss += loss.item() * batch_size
            total_prec1 += prec1 * batch_size
            total_samples += batch_size

    avg_loss = total_loss / total_samples
    avg_prec1 = total_prec1 / total_samples
    return avg_loss, avg_prec1

def accuracy(output, target, topk=(1,)):
    """Return the precision@k, in percent, for ``k = max(topk)``.

    A sample counts as correct when its target class is among the ``k``
    highest-scoring entries of its row in ``output``.

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        target: 1-D tensor of class indices, one per row of ``output``.
        topk: Iterable of ``k`` values; only the largest one is evaluated.

    Returns:
        The precision as a Python float in ``[0, 100]``; ``0.0`` for an
        empty batch.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    if batch_size == 0:
        # Avoid dividing by zero below; there is nothing to score.
        return 0.0

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    # ``correct`` is derived from a transposed tensor and may be
    # non-contiguous when maxk > 1, where view(-1) raises; reshape(-1)
    # handles both layouts.
    correct_k = correct[:maxk].reshape(-1).float().sum(0, keepdim=True)
    return correct_k.mul_(100.0 / batch_size).item()  # plain Python float

if __name__ == '__main__':
    main()
