


from utils import *
from models import *

from torch.optim.lr_scheduler import LambdaLR
import numpy.linalg as la

import argparse




def printf(*a):
    """Append the given values to ``python-prints.log`` as a single line.

    The values are converted with ``str`` and separated by one space, and the
    line is terminated with a newline, i.e. the same text the built-in
    ``print`` would produce with its default separator and line ending.
    The log file is opened in append mode, relative to the current working
    directory, and is closed again before returning.
    """
    with open('python-prints.log', mode='a') as log_file:
        log_file.write(' '.join(str(value) for value in a) + '\n')



# Start-up banner: written to stdout every time this module is loaded.
print('''\n\n\n\n                  
99988          \n\n\n\n''')


# Command-line interface. Every option is optional; the table below lists the
# flags in the order they are registered (which is also the order in --help).
parser = argparse.ArgumentParser()

_CLI_OPTIONS = (
    ('--train_or_eva', dict(type=str, default='train')),
    ('--net_desc', dict(type=str, default='r50')),
    ('--cwd', dict(type=str, default='no pretrained cwd')),
    ('--dataset', dict(type=str, default='Cifar10', choices=['Cifar10', 'Cifar100'])),
    ('--optType', dict(type=str, default='SGD+cos', help="['SGD+cos','SGD','Adam+cos','Adam']")),
    ('--lr', dict(type=float, default=0.1)),
    ('--momentum', dict(type=float, default=0.9)),
    ('--assignGPU', dict(type=int, default=-1)),
    ('--n_epochs', dict(type=int, default=200)),
)

for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

# Parsed once at import time; the rest of the module reads `pargs`.
pargs = parser.parse_args()








# Device selection: CUDA when available, otherwise CPU.
CUDA = torch.cuda.is_available()

if CUDA:
    # --assignGPU -1 means "let bestGPU() choose"; any other value is used
    # directly as the device index.
    whichGPU = bestGPU() if pargs.assignGPU == -1 else pargs.assignGPU

    # Alternative kept for reference:
    # os.environ["CUDA_VISIBLE_DEVICES"] = str(whichGPU)
    torch.cuda.set_device(whichGPU)

DEVICE = torch.device('cuda' if CUDA else 'cpu')

# Output directories used by the script; created up front if missing.
for _out_dir in ('wz_saved_models', 'wIns'):
    os.makedirs(_out_dir, exist_ok=True)



# DEVICE = torch.device('cpu')

# print(DEVICE,whichGPU)

# raise


# if torch.device('cuda' if torch.cuda.is_available():
#     net = torch.nn.DataParallel(net)
#     cudnn.benchmark = True



def use_model(net_desc='r18',cwd='No pretrained cwd',dataset=None, **w):
    """Build the requested network, move it to ``DEVICE`` and load weights.

    Args:
        net_desc: Short architecture key. One of ``'r18'`` (ResNet18),
            ``'r50'`` (ResNet50), ``'r152'`` (ResNet152), ``'m2'``
            (MobileNetV2) or ``'dla'`` (DLA).
        cwd: Passed unchanged to ``load_model`` together with the network.
            NOTE(review): presumably a checkpoint location, with the default
            meaning "nothing to load" -- confirm against ``load_model``.
        dataset: ``'Cifar10'`` (10 classes) or ``'Cifar100'`` (100 classes);
            determines the class count handed to the model constructor.
        **w: Ignored. Accepted so callers can pass a larger keyword bundle
            such as ``vars(pargs)``.

    Returns:
        The constructed network, on ``DEVICE``, with a ``name`` attribute set
        to the architecture's display name.

    Raises:
        ValueError: If ``dataset`` or ``net_desc`` is not a supported key.
    """
    print('==> Building model..')

    class_counts = {'Cifar10': 10, 'Cifar100': 100}
    if dataset not in class_counts:
        raise ValueError(
            f'Unsupported dataset {dataset!r}; expected one of {sorted(class_counts)}')
    n_class = class_counts[dataset]

    # Only strings live in this table, so validating the key never touches a
    # model class that might be missing from the imported model zoo.
    display_names = {
        'r18': 'ResNet18',
        'r50': 'ResNet50',
        'r152': 'ResNet152',
        'm2': 'MobileNetV2',
        'dla': 'DLA',
    }
    if net_desc not in display_names:
        raise ValueError(
            f'Unsupported net_desc {net_desc!r}; expected one of {sorted(display_names)}')

    if net_desc == 'r18':
        net = ResNet18(n_class)
    elif net_desc == 'r50':
        net = ResNet50(n_class)
    elif net_desc == 'r152':
        net = ResNet152(n_class)
    elif net_desc == 'm2':
        net = MobileNetV2(n_class)
    else:  # 'dla' -- the only key left after validation
        net = DLA(n_class)
    net.name = display_names[net_desc]

    net = net.to(DEVICE)

    load_model(net, cwd)

    return net



def getArgs(pargs):
    """Assemble the keyword bundle shared by training and evaluation.

    Args:
        pargs: Parsed command-line namespace. ``pargs.dataset`` selects the
            torchvision dataset, and ``vars(pargs)`` is forwarded to
            ``use_model`` to build the network.

    Returns:
        A dict with the keys ``num_workers``, ``resume``, ``pin_memory``,
        ``criterion`` (a cross-entropy loss), ``trainloader``, ``testloader``
        and ``net``.
    """
    # The same per-channel normalisation is applied to both splits.
    # NOTE(review): presumably CIFAR-10 channel statistics, used for
    # CIFAR-100 as well -- confirm this is intended.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # Training images are augmented (random crop with padding + flip);
    # test images are only converted and normalised.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    # Only the CIFAR-100 training split asks torchvision to download the data.
    if pargs.dataset == 'Cifar10':
        dataset_cls, download_train = torchvision.datasets.CIFAR10, False
    elif pargs.dataset == 'Cifar100':
        dataset_cls, download_train = torchvision.datasets.CIFAR100, True

    trainset = dataset_cls(root='./datasets', train=True, download=download_train, transform=transform_train)
    testset = dataset_cls(root='./datasets', train=False, download=False, transform=transform_test)

    num_workers = 2
    pin_memory = True
    loader_options = dict(num_workers=num_workers, pin_memory=pin_memory)

    return {
        'num_workers': num_workers,
        'resume': 0,
        'pin_memory': pin_memory,
        'criterion': nn.CrossEntropyLoss(),
        'trainloader': torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, **loader_options),
        'testloader': torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, **loader_options),
        'net': use_model(**vars(pargs)),
    }




def trainNN_zoo(num_workers=None,resume=None,lr=None,cwd='',n_epochs=None,trainloader=None,testloader=None,net=None,criterion=None,optType=None,momentum=None,**w):
    """Train ``net`` for ``n_epochs`` epochs and evaluate it after each one.

    After every epoch the network is evaluated with ``eva_net`` and the
    test-accuracy history so far is handed to ``wzRec``.

    Args:
        num_workers, resume, cwd: Accepted for call compatibility; not used
            inside this function.
        lr: Learning rate for the optimizer.
        n_epochs: Number of epochs; also the ``T_max`` of the cosine schedule.
        trainloader: Iterable of ``(inputs, targets)`` training batches.
        testloader: Iterable of test batches, forwarded to ``eva_net``.
        net: Model to train; must expose a ``name`` attribute.
        criterion: Loss callable applied as ``criterion(outputs, targets)``.
        optType: Option string matched by substring:
            ``'SGD'`` / ``'Adam'`` / ``'symbolicL2O'`` select the optimizer,
            ``'nesterov'`` enables Nesterov momentum for SGD,
            ``'cos'`` adds a ``CosineAnnealingLR`` schedule, and
            ``'switch3'`` rebuilds the optimizer at 1/3 and 2/3 of training.
        momentum: Momentum for SGD.
        **w: Ignored extra keyword arguments.

    Returns:
        None.

    Raises:
        ValueError: If ``'switch3'`` is requested and ``n_epochs`` is not a
            multiple of 3.
        NotImplementedError: If ``optType`` names no known optimizer.
    """
    # Validate up front with a real exception; an assert would be stripped
    # when Python runs with -O.
    if 'switch3' in optType and n_epochs % 3 != 0:
        raise ValueError(
            f"'switch3' requires n_epochs to be a multiple of 3, got {n_epochs}")

    best_acc = 0

    def getNewOpt(epoch, optimizer, scheduler):
        """Create a fresh optimizer (and scheduler, if requested) from optType."""
        if 'SGD' in optType:
            optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=5e-4, nesterov=('nesterov' in optType))
        elif 'Adam' in optType:
            optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=5e-4)
        elif 'symbolicL2O' in optType:
            optimizer = getSymbolicL2O(epoch, optimizer, scheduler)
        else:
            raise NotImplementedError(f'Unknown optimizer type: {optType!r}')

        if 'cos' in optType:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
        else:
            scheduler = None
        return optimizer, scheduler

    def getOptimizer(epoch=0, optimizer=None, scheduler=None):
        """Return the (optimizer, scheduler) pair to use for ``epoch``."""
        if epoch == 0:
            return getNewOpt(epoch, optimizer, scheduler)

        if 'switch3' in optType:
            third = n_epochs // 3
            if epoch in (third, 2 * third):
                optimizer, scheduler = getNewOpt(epoch, optimizer, scheduler)
                if scheduler is not None:
                    # Fast-forward the new schedule to the current epoch so
                    # the learning rate continues where the old one left off.
                    for _ in range(epoch):
                        scheduler.step()
        return optimizer, scheduler

    def train_1epoch(optimizer):
        """Run one pass over ``trainloader``; return (accuracy %, mean loss)."""
        net.train()
        train_loss = 0
        correct = 0
        total = 0
        # Defined before the loop so an empty loader yields zeros instead of
        # an unbound-variable error.
        acc, avg_loss = 0.0, 0.0

        # The progress bar is only shown when running on CPU.
        iterator = enumerate(trainloader) if CUDA else enumerate(tqdm(trainloader))

        for batch_idx, (inputs, targets) in iterator:

            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            acc = 100.*correct/total
            avg_loss = train_loss/(batch_idx+1)

        return acc, avg_loss

    # The optimizer and scheduler live here and are threaded through
    # getOptimizer so their state survives from one epoch to the next.
    optimizer, scheduler = None, None

    rec = []
    for epoch in range(n_epochs):
        printf(f'\n ~ {net.name} ~ {optType} \n\t Epoch: < {epoch}/{n_epochs} >\n\n')
        print(f'\n\t Epoch: < {epoch}/{n_epochs} >')

        optimizer, scheduler = getOptimizer(epoch, optimizer, scheduler)
        train_acc, train_loss = train_1epoch(optimizer)
        if scheduler is not None:
            # Advance the schedule once per epoch, after the optimizer steps.
            scheduler.step()

        test_acc, best_acc = eva_net(net, testloader, best_acc, criterion)
        rec.append([train_acc, train_loss, test_acc, best_acc])
        wzRec(np.array(rec)[:,2], ttl=f'test_acc << {net.name}', task_desc=f'{optType} $ {lr} ^ {momentum}', want_save=True)

    return






@torch.no_grad()
def eva_net(net,testloader,best_acc=101.,criterion=None,trainloader=None,**w):
    """Evaluate ``net`` on ``testloader`` and checkpoint it on improvement.

    Runs with gradient tracking disabled. The final accuracy is stored on
    ``net.test_acc``. When it exceeds ``best_acc`` the model is saved via
    ``save_model`` to ``./wz_saved_models/<class name>.pth`` and
    ``net.best_acc`` is updated.

    Args:
        net: Model to evaluate.
        testloader: Iterable of ``(inputs, targets)`` batches.
        best_acc: Best accuracy (in percent) seen so far. Accuracy is
            computed as a percentage, so the default of 101 means nothing
            is saved unless the caller passes a lower value.
        criterion: Loss callable applied as ``criterion(outputs, targets)``.
        trainloader: Unused; accepted so a shared keyword bundle can be
            passed in.
        **w: Ignored extra keyword arguments.

    Returns:
        ``(acc, best_acc)``: the accuracy in percent of this run and the
        possibly updated best accuracy. Both the accuracy and the reported
        mean loss are 0.0 when ``testloader`` yields no batches.
    """
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    # Defined before the loop so an empty loader is reported as zeros instead
    # of raising an unbound-variable error.
    acc, avg_loss = 0.0, 0.0

    # tqdm wraps the loader itself (not the enumerate object) so it can read
    # the loader's length and show a total.
    for batch_idx, (inputs, targets) in enumerate(tqdm(testloader)):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        outputs = net(inputs)

        loss = criterion(outputs, targets)

        test_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        acc = 100.*correct/total
        avg_loss = test_loss/(batch_idx+1)

    desc = f'\n\n  Final test acc is:  {acc:.2f}\n  avg_loss is: {avg_loss} \n'

    print(desc)
    net.test_acc = acc
    if acc > best_acc:
        save_model(net, f'./wz_saved_models/{net._get_name()}.pth')
        best_acc = acc
        net.best_acc = best_acc

    return acc, best_acc












if __name__ == '__main__':
    # --train_or_eva selects the entry point; any other value does nothing.
    mode = pargs.train_or_eva

    if mode in ('train', 'eva'):
        # Both modes need the data loaders, the loss and the network.
        args = getArgs(pargs)

        if mode == 'train':
            # Train the selected network with the command-line settings.
            trainNN_zoo(**args, **vars(pargs))
        else:
            # Evaluate the selected network on the test split.
            test_acc, _ = eva_net(**args)
