import argparse
import torch

from model.wide_res_net import WideResNet
from model.resnet import resnet, linear
from model.PyramidNet import PyramidNet as PYRM
from model.smooth_cross_entropy import smooth_crossentropy
from data.cifar import Cifar, Cifar100
from utility.initialize import initialize
from utility.bypass_bn import enable_running_stats, disable_running_stats

import sys; sys.path.append("..")
from sam import SAM

import pdb
import time
import os


def check_dir(path):
    '''
    Create directory (including missing parents) if it does not exist.
        path:           Path of directory.

    Uses exist_ok=True rather than a separate existence check, so a
    directory created concurrently between "check" and "create" does not
    raise. If path exists but is not a directory, os.makedirs raises
    FileExistsError.
    '''
    os.makedirs(path, exist_ok=True)

def save(model, epoch, path):
    '''
    Write a checkpoint file named 'epoch_<epoch>.pth' inside path.
        model:          Object whose state_dict() result is stored under 'model'.
        epoch:          Epoch index; stored under 'epoch' and used in the file name.
        path:           Directory that receives the checkpoint file.
    '''
    print('Saving..')
    checkpoint = {}
    checkpoint['model'] = model.state_dict()
    checkpoint['epoch'] = epoch
    filename = 'epoch_' + str(epoch) + '.pth'
    torch.save(checkpoint, path + '/' + filename)

if __name__ == "__main__":
    # Script entry point: parse the command line, build dataset / model /
    # optimizer / LR scheduler, then alternate one training pass and one test
    # pass per epoch, writing checkpoints under ./checkpoints/.
    parser = argparse.ArgumentParser()
    parser.add_argument('--arch', default='wideresnet', type=str,
                        choices=['resnet', 'wideresnet', 'pynet', 'linear'],
                        help="Model architecture to train.")
    parser.add_argument('--dataset', default='cifar10', type=str,
                        choices=['cifar10', 'cifar100'],
                        help="Dataset to train on.")
    # NOTE(review): --data-dir is not read anywhere in this script; presumably
    # it is consumed through `args` elsewhere -- confirm.
    parser.add_argument('--data-dir', default='cifar10', type=str, help="Dataset directory.")
    parser.add_argument("--depth", default=16, type=int, help="Number of layers.")
    parser.add_argument("--epochs", default=200, type=int, help="Total number of epochs.")
    parser.add_argument("--threads", default=2, type=int, help="Number of CPU threads for dataloaders.")
    parser.add_argument("--weight_decay", default=0.0005, type=float, help="L2 weight decay.")
    parser.add_argument("--width_factor", default=8, type=int, help="How many times wider compared to normal ResNet.")

    # optimization parameters
    parser.add_argument("--batch_size", default=128, type=int, help="Batch size used in the training and validation loop.")
    parser.add_argument("--label_smoothing", default=0.1, type=float, help="Use 0.0 for no label smoothing.")
    parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum.")
    parser.add_argument("--learning_rate", default=0.1, type=float, help="Base learning rate at the start of the training.")
    parser.add_argument("--dropout", default=0.0, type=float, help="Dropout rate.")
    parser.add_argument("--cos", action='store_true',
                        help="Use a cosine-annealing LR schedule instead of the multi-step schedule.")

    # sam parameters
    parser.add_argument("--adaptive", action='store_true', help="True if you want to use the Adaptive SAM.")
    parser.add_argument("--rho", default=2.0, type=float, help="Rho parameter for SAM.")
    parser.add_argument('--optimizer', default='SAM', type=str, choices=['SAM', 'SGD'])
    parser.add_argument("--analysis", action='store_true', help="analyzing whether the last gradient is helpful in predicting the current gradient")

    # efficient parameters
    parser.add_argument("--loss_thred", default=2., type=float,
                        help="Loss threshold; a value below 1 forces the top-k subset SAM path.")
    parser.add_argument("--top_k", default=128, type=int,
                        help="Number of highest-loss samples per batch used for the first SAM step.")
    parser.add_argument("--top_sgd", default=128, type=int,
                        help="Number of highest-loss samples per batch used for the second SAM step / the SGD update.")

    # wb names
    parser.add_argument("--exp_name", default="baseline", type=str, help="exp name shown in the w&b")

    # training setting
    parser.add_argument("--gpu", default="0", type=str, help="which gpu are you using")
    parser.add_argument("--seed", default=42, type=int, help="which random seed are you using")
    args = parser.parse_args()

    initialize(args, seed=args.seed)
    gpu_id = "cuda:" + args.gpu
    # Falls back to the CPU when CUDA is unavailable; every tensor and model
    # below is moved with .to(device) so the --gpu choice is honoured.
    device = torch.device(gpu_id if torch.cuda.is_available() else "cpu")

    # argparse `choices` guarantees one branch of each chain below is taken.
    if args.dataset == "cifar10":
        dataset = Cifar(args.batch_size, args.threads)
        n_class = 10
    elif args.dataset == "cifar100":
        dataset = Cifar100(args.batch_size, args.threads)
        n_class = 100

    if args.arch == "resnet":
        model = resnet(args.depth, n_class).to(device)
    elif args.arch == "wideresnet":
        model = WideResNet(args.depth, args.width_factor, args.dropout, in_channels=3, labels=n_class).to(device)
    elif args.arch == "pynet":
        model = PYRM(args.dataset, 110, 270, n_class, False).to(device)
    elif args.arch == "linear":
        model = linear(n_class).to(device)

    epochs = int(args.epochs)

    if args.optimizer == "SGD":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        # --optimizer is restricted to {'SAM', 'SGD'}, so this is the SAM case.
        optimizer = SAM(model.parameters(), torch.optim.SGD, rho=args.rho, adaptive=args.adaptive, analysis=args.analysis,
                        lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay,)
    if args.cos:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    else:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)

    test_acc_best = 0
    train_acc = 0

    path = './checkpoints/'+ args.dataset + '/'+ args.arch + '/' + args.exp_name
    check_dir(path)

    # Checkpoint roughly four times per run; clamp to 1 so that fewer than
    # four epochs does not produce a modulo-by-zero.
    save_every = max(1, epochs // 4)

    for epoch in range(epochs):
        model.train()

        print('\nEpoch: %d, LR: %.5f' % (epoch, scheduler.optimizer.param_groups[0]['lr']))
        total = 0
        total_correct = 0
        train_loss = 0
        epoch_time = 0

        for idx_b, batch in enumerate(dataset.train):
            start = time.time()
            inputs, targets = (b.to(device) for b in batch)

            # Subset sizes, clamped per batch so a short batch cannot
            # request more samples than it holds.
            top_k = min(args.top_k, len(inputs))
            sgd_k = min(args.top_sgd, len(inputs))

            # first forward-backward step
            # (presumably switches BatchNorm running-statistic updates on;
            # see utility.bypass_bn -- TODO confirm)
            enable_running_stats(model)
            if args.optimizer == "SGD" and (sgd_k == len(inputs)):
                # Plain SGD on the whole batch.
                predictions = model(inputs)
                # NOTE(review): this call, and the loss_adv calls below, do not
                # pass smoothing=args.label_smoothing unlike the other
                # branches -- confirm whether that is intended.
                loss = smooth_crossentropy(predictions, targets)

                # The original read args.rand_k here, an option the parser
                # never defines (AttributeError). The slice it guarded,
                # loss[:args.top_sgd], is a no-op in this branch because
                # top_sgd >= len(inputs), so it is dropped.

                optimizer.zero_grad()
                loss.mean().backward()
                optimizer.step()
            elif args.optimizer == "SGD":
                # SGD on the sgd_k highest-loss samples of the batch.
                with torch.no_grad():
                    predictions = model(inputs)
                    loss = smooth_crossentropy(predictions, targets, smoothing=args.label_smoothing)

                    _, idx_sgd = torch.topk(loss, sgd_k)

                idx_sgd = idx_sgd[torch.randperm(len(idx_sgd))]
                predictions_adv = model(inputs[idx_sgd])
                loss_adv = smooth_crossentropy(predictions_adv, targets[idx_sgd])
                optimizer.zero_grad()
                loss_adv.mean().backward()
                optimizer.step()

            elif (args.top_k == args.batch_size) and (args.top_sgd == args.batch_size) and args.loss_thred >= 1:
                # Standard SAM: both steps use the whole batch.
                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets, smoothing=args.label_smoothing)
                loss.mean().backward()
                optimizer.first_step(zero_grad=True)

                # second forward-backward step
                disable_running_stats(model)
                smooth_crossentropy(model(inputs), targets, smoothing=args.label_smoothing).mean().backward()
                optimizer.second_step(zero_grad=True)
            else:
                # Subset SAM: rank samples by loss without tracking gradients,
                # then run each SAM step on its own top-loss subset.
                with torch.no_grad():
                    predictions = model(inputs)
                    loss = smooth_crossentropy(predictions, targets, smoothing=args.label_smoothing)

                    _, idx_k = torch.topk(loss, top_k)
                    _, idx_sgd = torch.topk(loss, sgd_k)

                idx_k = idx_k[torch.randperm(len(idx_k))]

                predictions_k = model(inputs[idx_k])
                loss_k = smooth_crossentropy(predictions_k, targets[idx_k], smoothing=args.label_smoothing)
                loss_k.mean().backward()

                optimizer.first_step(zero_grad=True)

                # second forward-backward step
                disable_running_stats(model)

                idx_sgd = idx_sgd[torch.randperm(len(idx_sgd))]
                predictions_adv = model(inputs[idx_sgd])
                loss_adv = smooth_crossentropy(predictions_adv, targets[idx_sgd])
                loss_adv.mean().backward()

                optimizer.second_step(zero_grad=True)

            end = time.time()
            epoch_time += (end - start)

            iter_time = epoch_time/(idx_b + 1)

            # Running statistics use the full-batch predictions / loss that
            # every branch above leaves in `predictions` and `loss`.
            with torch.no_grad():
                correct = torch.argmax(predictions.data, 1) == targets

                total += len(inputs)
                total_correct +=  correct.sum()
                train_loss += loss.sum()
                train_acc = total_correct/total
                if (idx_b + 1) % 50 == 0 or idx_b + 1 == len(dataset.train):
                    print(idx_b + 1, len(dataset.train), 'Loss: %.3f | Acc: %.3f%% (%.3f) | Iter Time: %.3f (%.3f)'
                                 % (train_loss/total, 100.*total_correct/total, 100.*(correct.sum()/len(inputs)), iter_time, epoch_time))

        scheduler.step()

        model.eval()

        total_eval = 0
        total_correct_eval = 0
        test_loss = 0
        print("testing")
        with torch.no_grad():
            for idx_b, batch in enumerate(dataset.test):
                inputs, targets = (b.to(device) for b in batch)

                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets)
                correct = torch.argmax(predictions, 1) == targets
                total_eval += len(inputs)
                total_correct_eval +=  correct.sum()
                test_loss += loss.sum()
                test_acc = 100.*total_correct_eval/total_eval
                if (idx_b + 1) % 10 == 0 or idx_b + 1 == len(dataset.test):
                    print(idx_b + 1, len(dataset.test), 'Loss: %.3f | Acc: %.3f%% (%.3f)'
                             % (test_loss/total_eval, test_acc, 100.*(correct.sum()/len(inputs))))
            test_acc_best = max(test_acc, test_acc_best)

        # Periodic checkpoint plus one after the final epoch.
        if epoch % save_every == 0 or epoch == (epochs - 1):
            save(model, epoch, path)


