import os
import argparse
import numpy as np
import torch
from tensorboardX import SummaryWriter
import random
from cifar100 import CIFAR100
from utils import *

parser = argparse.ArgumentParser()
# Mean of the per-sample loss weights used by the noisy methods (ggdCov/sgdCov):
# the weights are zero-mean Gaussian noise shifted by this value.
parser.add_argument('--mean', type=float, default=1.0)
# NOTE: there is no --std flag; the noise scale of the loss weights is
# sqrt(number of mini-batches per optimizer step), computed in the training loop.
parser.add_argument('--method', type=str, default='ggdCov', choices=['ggdCov', 'sgd', 'sgdCov'])
parser.add_argument('--iters', type=int, default=int(7.5e4+1))
# iterations at which the learning rate is multiplied by 0.1
parser.add_argument('--schedule', type=int, nargs='+', default=[int(4e4), int(6e4)])
parser.add_argument('--batchsize', type=int, default=100)
parser.add_argument('--ghostsize', type=int, default=100)
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--wd', type=float, default=5e-4)
# Data augmentation is on by default. A `store_true` flag whose default is
# already True can never switch it off, so `--no-aug` writes False to the same
# destination; `--aug` is kept so existing command lines keep working.
parser.add_argument('--aug', dest='aug', action='store_true', default=True)
parser.add_argument('--no-aug', dest='aug', action='store_false')
parser.add_argument('--randinit', action='store_true', default=False)
parser.add_argument('--model', type=str, default='resnet', choices=['vgg','resnet'])
parser.add_argument('--resume', type=str, default=None)
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--datadir', type=str, default='/home/geothe9/datasets/CIFAR100/numpy')
parser.add_argument('--logdir', type=str, default='logs/CIFAR100')

args = parser.parse_args()

# output directory: one sub-directory of --logdir per run configuration
run_name = '{}_{}_{}_{}_lr_{}_wd_{}_seed_{}'.format(
    args.model,
    args.method,
    'rand' if args.randinit else 'fix',
    'aug' if args.aug else 'no-aug',
    args.lr, args.wd, args.seed)

log_dir = os.path.join(args.logdir, run_name)
# makedirs also creates a missing --logdir parent and is a no-op when the
# directory already exists. Any other OSError (permissions, a file in the
# way) now stops the run instead of being printed and ignored, since
# everything below writes into this directory.
os.makedirs(log_dir, exist_ok=True)

logger = LogSaver(log_dir)
logger.save(str(args), 'args')


# Seed every random number generator this script touches (torch CPU, torch
# CUDA on the current device, Python's `random`, NumPy) with the same --seed.
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
    seed_fn(args.seed)

# data: load CIFAR-100 from --datadir and record its description in the log
dataset = CIFAR100(args.datadir)
logger.save(str(dataset), 'dataset')
# presumably 500 is the evaluation batch size; meaning of the flag is
# defined by CIFAR100.getTestList — TODO confirm
test_list = dataset.getTestList(500, True)

# compute device: CUDA whenever available, otherwise CPU
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# model
start_iter = 0
lr = args.lr
# Pick the architecture and the file holding its fixed initial weights.
if args.model == 'resnet':
    from resnet import ResNet18 as build_model
    init_path = './CIFAR100/init/resnet.pth'
elif args.model == 'vgg':
    from vgg import vgg11 as build_model
    init_path = './CIFAR100/init/vgg.pth'
else:
    raise NotImplementedError()

model = build_model()
if not args.randinit:
    logger.save('=> load fixed initialization')
    # The weights may have been saved from a GPU. Map them to the CPU (the
    # model is still on the CPU here) so this also works on CPU-only hosts;
    # the model is moved to `device` right below.
    model.load_state_dict(torch.load(init_path, map_location='cpu'))

model = model.to(device)

# Per-sample losses (reduction='none'): the training loop multiplies each
# sample's loss by a noise weight before averaging.
criterion = torch.nn.CrossEntropyLoss(reduction='none')

optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.wd)

if args.resume:
    # map_location lets a checkpoint written on one device be resumed on
    # another (e.g. a GPU checkpoint on a CPU-only machine).
    checkpoint = torch.load(args.resume, map_location=device)
    start_iter = checkpoint['iter'] + 1
    lr = checkpoint['lr']
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    logger.save("=> loaded checkpoint '{}'".format(args.resume))
logger.save(str(model), 'classifier')
logger.save(str(optimizer), 'optimizer')

# writer
writer = SummaryWriter(log_dir)

# optimization
torch.backends.cudnn.benchmark = True
try:
    for i in range(start_iter, args.iters):
        # step decay: multiply lr by 0.1 at every iteration listed in --schedule
        if i in args.schedule:
            lr *= 0.1
            logger.save('update lr: %f'%(lr))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        # train: accumulate gradients over every batch in train_list, then
        # take a single optimizer step on their average
        model.train()
        optimizer.zero_grad()
        bs = 0
        if args.method == 'ggdCov':
            # NOTE(review): the 5x batch-size factor is hard-coded — confirm intent
            bs = int(args.batchsize * 5)
            train_list = dataset.getTrainList(bs, args.aug, True)
            # one zero-mean Gaussian loss weight per training sample, drawn once per step
            w_noise = torch.randn(dataset.n_train).to(device)
            noise = w_noise - w_noise.mean()
            mask = noise * np.sqrt(len(train_list)) + args.mean
        else:
            train_list = dataset.getTrainGhostBatch(args.batchsize, args.ghostsize, args.aug, True)
        n_batches = len(train_list)

        train_loss, train_acc = 0, 0

        for j in range(n_batches):
            x, y = train_list[j]
            out = model(x)

            if args.method == 'sgdCov':
                # fresh zero-mean Gaussian weights for the samples of this batch
                w_noise = torch.randn(y.shape).to(device)
                mask = w_noise - w_noise.mean()
                mask = mask * np.sqrt(n_batches) + args.mean
                loss = criterion(out, y) * mask
            elif args.method == 'ggdCov':
                # assumes batch j holds samples j*bs .. (j+1)*bs - 1 of the
                # per-sample mask — TODO confirm against getTrainList
                loss = criterion(out, y) * mask[j * bs:(j + 1) * bs]
            else:
                loss = criterion(out, y)

            loss = loss.mean()
            loss.backward()
            train_acc += accuracy(out, y).item()
            train_loss += loss.item()

        # average the accumulated gradients over the batches; parameters that
        # took no part in the forward pass have no gradient and are skipped
        for param in model.parameters():
            if param.grad is not None:
                param.grad.div_(n_batches)
        optimizer.step()
        train_acc /= n_batches
        train_loss /= n_batches

        # evaluate
        if i % 500 == 0 or i <= 100:
            model.eval()
            writer.add_scalar('lr', lr, i)
            writer.add_scalar('acc/train', train_acc, i)
            writer.add_scalar('loss/train', train_loss, i)

            test_loss, test_acc = 0, 0
            # no autograd graph is needed for evaluation: saves memory and time
            with torch.no_grad():
                for x, y in test_list:
                    out = model(x)
                    test_loss += criterion(out, y).mean().item()
                    test_acc += accuracy(out, y).item()
            test_loss /= len(test_list)
            test_acc /= len(test_list)
            writer.add_scalar('loss/test', test_loss, i)
            writer.add_scalar('acc/test', test_acc, i)
            writer.add_scalar('acc/diff', train_acc - test_acc, i)

            logger.save('Iter:%d, Test [acc: %.2f, loss: %.4f], Train [acc: %.2f, loss: %.4f]'
                        % (i, test_acc, test_loss, train_acc, train_loss))

        # checkpoint; every multiple of 5000 is also a multiple of 500, so
        # test_acc/test_loss were just computed by the evaluation above
        if i % 5000 == 0:
            state = {'iter': i, 'lr': lr, 'ts_acc': test_acc, 'ts_loss': test_loss,
                     'tr_acc': train_acc, 'tr_loss': train_loss,
                     'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
            torch.save(state, os.path.join(log_dir, 'iter-' + str(i) + '.pth.tar'))
finally:
    # flush pending TensorBoard events even if training is interrupted
    writer.close()
