from __future__ import print_function
import argparse
import torch
import torch.optim as optim
import numpy as np
import os
from torchvision import datasets, transforms
from auto_augment import AutoAugment, Cutout
from archive import autoaug_paper_cifar10
from FastAutoAugment.data import Augmentation
from models import ButterfLeNet
from train_utils import train, train_alt, test

def main():
    """Train and evaluate a ButterfLeNet configured from the command line.

    The function parses the CLI flags, seeds torch, builds the data
    pipeline and the model, separates "architecture" parameters (those whose
    name contains ``twiddle`` or ``permutation``) from the remaining model
    parameters, and gives each group its own optimizer and LR scheduler.
    It then runs ``args.epochs`` epochs; after every epoch the model is
    evaluated on the (un-augmented) training set and on the test set.  When
    ``--save-perf`` is set, the collected metrics and a model checkpoint are
    written to the results directory after every epoch.

    NOTE(review): ``--save-model`` is parsed but not consulted in this
    function; checkpoints are written whenever ``--save-perf`` is set.
    """
    args = _build_arg_parser().parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Both modules export a class of the same name; bind the requested one
    # locally (a local import in one branch requires one in the other too).
    if args.wide:
        from models_wide import ButterfLeNet
    else:
        from models import ButterfLeNet

    # The DARTS variants imply an adaptive architecture optimizer and
    # bilevel (alternating) training.
    if args.fdarts or args.sdarts:
        args.archada = True
        args.bilevel = True

    # Training loaders always shuffle (previously only when CUDA was
    # available); evaluation loaders keep dataset order and honour
    # --test-batch-size (previously ignored).
    train_kwargs = {'batch_size': args.batch_size, 'shuffle': True}
    eval_kwargs = {'batch_size': args.test_batch_size, 'shuffle': False}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1, 'pin_memory': True}
        train_kwargs.update(cuda_kwargs)
        eval_kwargs.update(cuda_kwargs)

    transform, test_transform = _build_transforms(args)
    dataset1, dataset12, dataset2 = _load_datasets(
        args, transform, test_transform)

    # Split the training set into two halves for alternating optimization.
    # The second half absorbs the remainder so that an odd-sized dataset
    # does not make random_split raise.
    half = len(dataset1) // 2
    swset, archset = torch.utils.data.random_split(
        dataset1, [half, len(dataset1) - half])

    train_loader_sw = torch.utils.data.DataLoader(swset, **train_kwargs)
    train_loader_arch = torch.utils.data.DataLoader(archset, **train_kwargs)

    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)

    traineval_loader = torch.utils.data.DataLoader(dataset12, **eval_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **eval_kwargs)

    model = ButterfLeNet(args).to(device)
    print(model)

    # Don't train the loaded architecture weights
    if args.searchdir != "" and not args.sgdr:
        args.fixed = True

    # Filter architectural parameters
    arch_params = []
    model_params = []
    for name, p in model.named_parameters():
        if ('twiddle' in name) or ('permutation' in name):
            if args.fixed or args.rand:
                p.requires_grad = False
            arch_params.append(p)
        else:
            model_params.append(p)

    tot_model_params = sum(p.numel() for p in model_params if p.requires_grad)
    print("Total model parameters:", tot_model_params)
    tot_arch_params = sum(p.numel() for p in arch_params if p.requires_grad)
    print("Total arch parameters:", tot_arch_params)

    # Only include model parameters.  Adam is used when retraining a loaded
    # search result unless --offlinesgd asks for SGD.
    if args.searchdir != "" and not args.offlinesgd:
        optimizer = torch.optim.Adam(model_params, weight_decay=args.wd)
    else:
        optimizer = optim.SGD(
            model_params,
            lr=args.lr, momentum=0.9, nesterov=True, weight_decay=args.wd)

    def sched(epoch):
        """Step schedule: x1.0 for the first half, x0.5 until 75%, then x0.1."""
        if epoch < int(args.epochs * 0.5):
            return 1.0
        if epoch < int(args.epochs * 0.75):
            return 0.5
        return 0.1

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched)

    # TODO sort out arch scheduler
    # Separate optimizer for architectural parameters (only when at least one
    # of them is still trainable).
    if tot_arch_params > 0:
        if args.archada:
            arch_optimizer = optim.Adam(arch_params)
        else:
            arch_optimizer = optim.SGD(
                arch_params,
                lr=args.lr, momentum=0.9, nesterov=True, weight_decay=0.0)
        arch_scheduler = torch.optim.lr_scheduler.LambdaLR(
            arch_optimizer,
            lr_lambda=lambda e: sched(e) * args.arch_lr_factor)
    else:
        arch_optimizer = None
        arch_scheduler = None

    train_accs = []
    test_accs = []
    train_losses = []
    test_losses = []
    lrs = []
    fnorm = []

    # Passed to the training routines together with the epoch number.
    epoch_fnorms = []
    epoch_gradnorms = []

    norm, op1, op2 = model.get_fnorms()
    fnorm.append(norm)

    # The output directory depends only on the (now final) arguments, so it
    # is computed once instead of once per epoch.
    prefix = _results_prefix(args) if args.save_perf else None

    for epoch in range(1, args.epochs + 1):

        if args.bilevel or args.unilevel:
            train_alt(args, model, device, train_loader_sw, train_loader_arch,
                      optimizer, arch_optimizer, epoch,
                      epoch_fnorms, epoch_gradnorms)
        else:
            train(args, model, device, train_loader,
                  optimizer, arch_optimizer, epoch,
                  epoch_fnorms, epoch_gradnorms)

        train_acc, train_loss = test(model, device, traineval_loader)
        test_acc, test_loss = test(model, device, test_loader)
        train_accs.append(train_acc)
        test_accs.append(test_acc)
        train_losses.append(train_loss)
        test_losses.append(test_loss)

        # The recorded value is the learning rate *after* stepping.
        scheduler.step()
        lrs.append(scheduler.get_last_lr())
        norm, op1, op2 = model.get_fnorms()
        fnorm.append(norm)

        if arch_scheduler is not None:
            arch_scheduler.step()

        if args.save_perf:
            # Creates both the results directory and its models/ child.
            os.makedirs(prefix + 'models/', exist_ok=True)

            # Every file is rewritten each epoch with the history so far.
            histories = {
                'train_acc': train_accs,
                'test_acc': test_accs,
                'train_loss': train_losses,
                'test_loss': test_losses,
                'lrs': lrs,
                'fnorm': fnorm,
                'epoch_fnorms': epoch_fnorms,
                'epoch_gradnorms': epoch_gradnorms,
                'op1': op1,
                'op2': op2,
            }
            for stem, values in histories.items():
                np.save(prefix + stem + '.npy', np.array(values))

            torch.save(model.state_dict(),
                       prefix + 'models/' + f'model_{epoch}.pt')


def _build_arg_parser():
    """Return the argparse parser describing every command-line flag."""
    parser = argparse.ArgumentParser(description='ButterfLeNet')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test-batch-size', type=int, default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--wd', type=float, default=0.0001, metavar='WD',
                        help='weight decay (default: 0.0001)')
    parser.add_argument('--lam', type=float, default=0.0, metavar='LAM',
                        help='regularization parameter (default: 0.0)')
    parser.add_argument('--arch-lr-factor', type=float, default=1.0,
                        metavar='ARCH_LR_FACTOR',
                        help='architecture learning rate scale (default: 1.0)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging '
                             'training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--conv', action='store_true', default=False,
                        help='Use convolutional LeNet-esque architecture')
    parser.add_argument('--fc', action='store_true', default=False,
                        help='Use fully connected architecture')
    parser.add_argument('--fixed', action='store_true', default=False,
                        help='Use fixed DFT matrices')
    parser.add_argument('--rand', action='store_true', default=False,
                        help='Use fixed random K-ops')
    parser.add_argument('--warm-start', action='store_true', default=False,
                        help='Warm start with DFT matrices')
    parser.add_argument('--kmatrix', action='store_true', default=False,
                        help='Use K-matrices')
    parser.add_argument('--save-perf', action='store_true', default=False,
                        help='For Saving the test accuracies')
    parser.add_argument('--bilevel', action='store_true', default=False,
                        help='Bilevel optimization')
    parser.add_argument('--unilevel', action='store_true', default=False,
                        help='Unilevel optimization')
    parser.add_argument('--fdarts', action='store_true', default=False,
                        help='Use first order DARTS')
    parser.add_argument('--sdarts', action='store_true', default=False,
                        help='Use second order DARTS')
    parser.add_argument('--mnist', action='store_true', default=False,
                        help='train on MNIST')
    parser.add_argument('--fmnist', action='store_true', default=False,
                        help='train on FashionMNIST')
    parser.add_argument('--cifar100', action='store_true', default=False,
                        help='train on cifar100')
    parser.add_argument('--depth', type=int, default=9,
                        help='Depth of the from-scratch K-matrices')
    parser.add_argument('--tied', action='store_true', default=False,
                        help='Use tied KOPS')
    parser.add_argument('--cutout', action='store_true', default=False,
                        help='Use cutout')
    parser.add_argument('--aa', action='store_true', default=False,
                        help='Use autoaugment')
    parser.add_argument('--faa', action='store_true', default=False,
                        help='Use fast autoaugment')
    parser.add_argument('--bn', action='store_true', default=False,
                        help='Use batch norm')
    parser.add_argument('--archada', action='store_true', default=False,
                        help='Use adaptive optimizer for architecture')
    parser.add_argument('--permute', action='store_true', default=False,
                        help='Permute pixels')
    parser.add_argument('--sgdr', action='store_true', default=False,
                        help='Retrain loaded architecture parameters')
    parser.add_argument('--offlinesgd', action='store_true', default=False,
                        help='Use SGD when retraining')
    parser.add_argument('--searchdir', type=str, default='',
                        help='K-operation weights to load')
    parser.add_argument('--loadepoch', type=int, default=100, metavar='N',
                        help='search epoch to load from')
    parser.add_argument('--wide', action='store_true', default=False,
                        help='Use Wide LeNet architecture')
    return parser


def _build_transforms(args):
    """Return ``(train_transform, test_transform)`` for the CIFAR datasets.

    Both pipelines end with ToTensor, per-channel normalization and, when
    ``--permute`` is given, one fixed random permutation applied to the
    tensor.  The training pipeline is preceded either by the requested
    Cutout / AutoAugment / Fast AutoAugment steps or, when none of those
    flags is set, by a random horizontal flip and a padded random crop.
    """
    normalize = transforms.Normalize(
        (125.307 / 255.0, 122.95 / 255.0, 113.865 / 255.0),
        (62.9932 / 255.0, 62.0887 / 255.0, 66.7048 / 255.0))

    global_augs = []
    if args.permute:
        # Try permuting rows and columns with same perm.  The lambda runs
        # after ToTensor and indexes the tensor's second and third
        # dimensions with the same 32-element permutation.
        shuffle_idx = torch.randperm(32)
        global_augs.append(
            transforms.Lambda(
                lambda x: x[:, shuffle_idx][:, :, shuffle_idx])
        )

    if args.aa or args.cutout or args.faa:
        augs = []
        if args.cutout:
            augs.append(Cutout())
        if args.aa:
            augs.append(AutoAugment())
        if args.faa:
            augs.append(Augmentation(autoaug_paper_cifar10()))
    else:
        augs = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomCrop(32, padding=4),
        ]

    common_tail = [transforms.ToTensor(), normalize] + global_augs
    transform = transforms.Compose(augs + common_tail)
    test_transform = transforms.Compose(list(common_tail))
    return transform, test_transform


def _load_datasets(args, transform, test_transform):
    """Return ``(train, train_for_eval, test)`` datasets for the chosen task.

    Precedence of the dataset flags is --mnist, --fmnist, --cifar100, and
    CIFAR-10 when none is given.  For the CIFAR datasets the first dataset
    uses ``transform`` and the other two use ``test_transform``.

    NOTE(review): the MNIST / FashionMNIST branches use their own
    Resize(32x32)+ToTensor pipeline and ignore ``transform`` and
    ``test_transform``, so --permute/--cutout/--aa/--faa have no effect
    there - confirm this is intended.
    """
    if args.mnist or args.fmnist:
        if args.mnist:
            dataset_cls, root = datasets.MNIST, '../data/mnist'
        else:
            dataset_cls, root = datasets.FashionMNIST, '../data/fashionmnist'

        def resize_to_tensor():
            return transforms.Compose([
                transforms.Resize((32, 32)),
                transforms.ToTensor()])

        dataset1 = dataset_cls(root, download=True,
                               transform=resize_to_tensor())
        dataset12 = dataset_cls(root, download=True,
                                transform=resize_to_tensor())
        dataset2 = dataset_cls(root, train=False, download=True,
                               transform=resize_to_tensor())
        return dataset1, dataset12, dataset2

    if args.cifar100:
        dataset_cls, root = datasets.CIFAR100, '../data/cifar100'
    else:
        dataset_cls, root = datasets.CIFAR10, '../data/cifar10'

    dataset1 = dataset_cls(root=root, train=True, download=True,
                           transform=transform)
    dataset12 = dataset_cls(root=root, train=True, download=True,
                            transform=test_transform)
    dataset2 = dataset_cls(root=root, train=False, download=True,
                           transform=test_transform)
    return dataset1, dataset12, dataset2


def _results_prefix(args):
    """Return the results directory (with trailing ``/``) for this run.

    When ``--searchdir`` is given the directory is that path plus a single
    suffix; otherwise the name is assembled under ``./results/`` from the
    architecture variant, the optimization mode, the seed and one tag per
    enabled option.
    """
    # TODO improve directory naming
    if args.mnist:
        dataset_tag = "_MNIST"
    elif args.fmnist:
        dataset_tag = "_FMNIST"
    elif args.cifar100:
        dataset_tag = "_CIFAR100"
    else:
        dataset_tag = ""

    if args.searchdir != "":
        if dataset_tag:
            suffix = dataset_tag
        elif args.sgdr:
            suffix = "_SGDR"
        else:
            suffix = "_OFFLINE"
        return args.searchdir + suffix + '/'

    if args.fc:
        variant = 'fc'
    elif args.conv:
        variant = 'conv'
    elif args.kmatrix:
        variant = 'kmatrix'
    elif args.fixed:
        variant = 'fixed'
    elif args.warm_start:
        variant = f'warmstart_archlrf{args.arch_lr_factor}'
    else:
        variant = f'kop{args.depth}_archlrf{args.arch_lr_factor}'

    prefix = (f'./results/butterflenet_padded_{variant}'
              f'_unilevel{args.unilevel}_bilevel{args.bilevel}_{args.seed}')

    # Tags are appended in this fixed order; the weight-decay tag is always
    # present.
    optional_tags = [
        (args.permute, '_PERMUTED'),
        (args.rand, 'rand'),
        (args.tied, 'tied'),
        (args.bn, 'bn'),
        (args.lam > 0, f'lambda2{args.lam}'),
        (args.cutout, 'c'),
        (args.aa, 'aa'),
        (args.faa, 'faa'),
        (True, f'_wd{args.wd}'),
        (args.wide, '_WIDE'),
        (args.archada, '_archada'),
        (args.fdarts, '_fdarts'),
        (args.sdarts, '_sdarts'),
    ]
    prefix += ''.join(tag for enabled, tag in optional_tags if enabled)

    return prefix + dataset_tag + '/'

# Run the training entry point only when executed as a script.
if __name__ == "__main__":
    main()
