import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import grad
from tensorboardX import SummaryWriter

from collections import OrderedDict
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from collections import OrderedDict


from vgg import vgg11
from utils import *

def _build_arg_parser():
    """Return the command-line parser for this experiment.

    Every flag is a plain ``type``/``default`` option. They are registered in the
    order listed below, so the parsed namespace (logged verbatim right after
    parsing) keeps the same field order.
    """
    specs = [
        # (flag, type, default)
        ('--total_reps', int, 5),
        ('--max_epoch', int, 300),
        ('--lr_update_epoch', int, 30),
        ('--lr_decay_rate', float, 2.0),
        ('--init_lr', float, 0.1),
        ('--batch-size', int, 128),
        ('--phase_end_epoch', int, 200),
        ('--big_batch_size', int, 1024),
        ('--logdir', str, 'logs/data_augmentation_reps_our_cifar10_vgg11'),
        ('--seed', int, 1),
        ('--model_weight_dir', str, 'logs/VGG11_initialization'),
        ('--iter_eval_train', int, 100),
        ('--epoch_save_model', int, 10),
        ('--heavy_tail_noise_alpha', float, 1.4),
        ('--heavy_tail_noise_magnitude', float, 0.5),
        ('--gradient_clip_b', float, 1),
        ('--weight_decay', float, 5e-4),
    ]
    arg_parser = argparse.ArgumentParser()
    for flag, arg_type, default in specs:
        arg_parser.add_argument(flag, type=arg_type, default=default)
    return arg_parser


parser = _build_arg_parser()
args = parser.parse_args()

# Record the full run configuration next to the logs.
logger = LogSaver(args.logdir)
logger.save(str(args), 'args')

# All training runs on the first CUDA device.
device = torch.device("cuda:0")
torch.cuda.set_device(device)

# Short local names for the clipping / heavy-tailed-noise hyper-parameters
# that the training loop reads on every step.
gradient_clip_b, heavy_tail_noise_alpha, heavy_tail_noise_magnitude = (
    args.gradient_clip_b,
    args.heavy_tail_noise_alpha,
    args.heavy_tail_noise_magnitude,
)

# Seed NumPy (used for the Pareto noise draw) and torch's CPU and CUDA
# generators so repeated runs are reproducible.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# data
# Per-channel normalization applied after ToTensor(). These mean/std values are
# the standard ImageNet statistics, not ones computed on CIFAR-10.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Augmented training pipeline: random horizontal flip, then a 32x32 random crop
# taken from the image padded by 4 pixels.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, 4),
    transforms.ToTensor(),
    normalize,
])

# Build (and, if needed, download/verify) the augmented training set once. Both
# training loaders read from it. The random transforms draw fresh randomness on
# every sample access, so sharing one dataset object is equivalent to two copies.
train_set = datasets.CIFAR10(root='./data', train=True,
                             transform=train_transform, download=True)

# Small-batch loader for the regular SGD step.
train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size, shuffle=True,
        num_workers=0, pin_memory=True)

# Large-batch loader over the same augmented training set (only the batch size
# differs).
train_loader_lb = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.big_batch_size, shuffle=True,
        num_workers=0, pin_memory=True)

# Test split: no augmentation, fixed order.
val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=0, pin_memory=True)


# writer
# NOTE(review): every repetition logs to the same scalar tags with iter_count
# restarting at 0, so TensorBoard curves from different reps overlap.
writer = SummaryWriter(args.logdir)


def _evaluate(model, loader, criterion, device):
    """Evaluate ``model`` over every batch of ``loader``.

    Returns ``(acc, loss)`` as per-sample means. Each batch's ``accuracy`` and
    ``criterion`` value is weighted by its size, so a short final batch does
    not count as much as a full one. Evaluation runs under ``torch.no_grad()``
    because nothing here calls ``backward()``; this avoids building autograd
    graphs for the whole test set.
    """
    model.eval()
    acc_sum, loss_sum, n_seen = 0.0, 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            n = y.size(0)
            # assumes accuracy() returns the mean over the batch — TODO confirm
            # against utils.accuracy. criterion uses CrossEntropyLoss's default
            # 'mean' reduction.
            acc_sum += accuracy(out, y).item() * n
            loss_sum += criterion(out, y).item() * n
            n_seen += n
    if n_seen == 0:
        return 0.0, 0.0
    return acc_sum / n_seen, loss_sum / n_seen


def _save_history(path, history):
    """Write ``history`` (a list of equal-length rows) to ``path`` as .npy, closing the file."""
    with open(path, 'wb') as f:
        np.save(f, np.array(history))


# Let cuDNN benchmark and cache convolution algorithms (set once for all reps).
torch.backends.cudnn.benchmark = True

try:
    for idx_rep in range(args.total_reps):
        logger.save('       ')
        logger.save('              REP '+str(idx_rep))
        logger.save('       ')

        # model: each repetition starts from its own pre-saved initialization
        model = vgg11()
        checkpoint = torch.load(args.model_weight_dir+'/initialized_weight_'+str(idx_rep)+'.pth.tar')
        model.load_state_dict(checkpoint)
        del checkpoint
        model.to(device)
        logger.save(str(model), 'classifier')

        criterion = nn.CrossEntropyLoss().to(device)
        optimizer = torch.optim.SGD(model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
        logger.save(str(optimizer), 'optimizer')
        current_lr = args.init_lr

        # optimization
        train_history = []  # rows: [epoch, iter, batch acc, batch loss]
        test_history = []   # rows: [epoch, iter, test acc, test loss]
        iter_count = 0
        for idx_epoch in range(args.max_epoch):
            model.train()
            for x, y in train_loader:
                # small-batch forward/backward
                x, y = x.to(device), y.to(device)
                out = model(x)
                loss = criterion(out, y)
                optimizer.zero_grad()
                loss.backward()

                # First phase only: build the update gradient from the
                # large-batch gradient plus heavy-tailed noise.
                if idx_epoch < args.phase_end_epoch:
                    sb_direction_dict = get_grads_dict(model)
                    # one freshly shuffled large batch per step
                    xx, yy = next(iter(train_loader_lb))
                    xx, yy = xx.to(device), yy.to(device)
                    optimizer.zero_grad()
                    out_lb = model(xx)
                    loss_lb = criterion(out_lb, yy)
                    loss_lb.backward()
                    lb_direction_dict = get_grads_dict(model)

                    # noise direction from the small- vs large-batch gradients
                    # (exact formula lives in utils.get_dict_differnce)
                    noise_direction = get_dict_differnce(sb_direction_dict, lb_direction_dict)
                    # Pareto (Lomax) draw gives a heavy-tailed noise multiplier
                    noise_size = np.random.pareto(heavy_tail_noise_alpha) * heavy_tail_noise_magnitude
                    #============================
                    # produce gradient for update
                    #============================
                    optimizer.zero_grad()
                    # utils.modify_model_noise writes the gradient used by the
                    # step from the large-batch direction and the scaled noise
                    # direction
                    modify_model_noise(model, lb_direction_dict, noise_direction, 1 + noise_size)

                # Clip ||g|| to gradient_clip_b / lr, i.e. cap the SGD step
                # lr * ||g|| at gradient_clip_b (ignoring weight decay).
                if gradient_clip_b is not None and gradient_clip_b > 0.1:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_b / current_lr)

                optimizer.step()

                # update train history (metrics of the small batch just seen)
                acc_train = accuracy(out, y).item()
                loss_train = loss.item()
                train_history.append([idx_epoch + 1, iter_count, acc_train, loss_train])

                # print training performance
                if iter_count % args.iter_eval_train == 0:
                    logger.save('EPOCH '+str(idx_epoch + 1)+' ITER '+str(iter_count)+' TRAIN BATCH ACC = '+str(acc_train)+', LOSS = '+str(loss_train))

                iter_count += 1

            # check model performance on the test split
            acc_test, loss_test = _evaluate(model, val_loader, criterion, device)
            test_history.append([idx_epoch + 1, iter_count, acc_test, loss_test])

            logger.save('='*100)
            logger.save('EPOCH '+str(idx_epoch + 1)+' ITER '+str(iter_count)+' TEST ACC = '+str(acc_test)+', LOSS = '+str(loss_test)+', LR = '+str(current_lr))
            logger.save('='*100)

            # train scalars are those of the last small batch of the epoch
            writer.add_scalar('acc/train', acc_train, iter_count)
            writer.add_scalar('loss/train', loss_train, iter_count)

            writer.add_scalar('acc/test', acc_test, iter_count)
            writer.add_scalar('loss/test', loss_test, iter_count)
            writer.add_scalar('acc/diff', acc_train-acc_test, iter_count)

            # save model weights and numpy training history
            if (idx_epoch + 1) % args.epoch_save_model == 0:
                state = {'epoch': idx_epoch + 1, 'iter': iter_count, 'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
                torch.save(state, args.logdir+'/REP_'+str(idx_rep)+'epoch-'+str(idx_epoch + 1)+'.pth.tar')
            _save_history(args.logdir+'/REP_'+str(idx_rep)+'train_history.npy', train_history)
            _save_history(args.logdir+'/REP_'+str(idx_rep)+'test_history.npy', test_history)

            # Decay the learning rate by --lr_decay_rate every lr_update_epoch
            # epochs, only after the noisy first phase. The strict '>' means the
            # epoch index equal to phase_end_epoch never triggers a decay.
            if idx_epoch > args.phase_end_epoch and (idx_epoch + 1) % args.lr_update_epoch == 0:
                current_lr /= args.lr_decay_rate
                adjust_learning_rate(optimizer, current_lr)
finally:
    # flush and close the TensorBoard writer even if training fails
    writer.close()
