import argparse
import os

import torch
import torchvision
from torch.utils.data import TensorDataset
from torchvision import transforms

from models.resnet import *
from models.ViT import ViT
# Imported after the wildcard import so it cannot be shadowed by it.
from functools import partial


# Command-line configuration for PAG training.
parser = argparse.ArgumentParser(description='PAG Research argparser')
parser.add_argument('--batch_size', type=int, default=64, help='input batch size for training')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float)
parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
parser.add_argument('--seed', type=int, default=7, metavar='S', help='random seed (default: 7)')
parser.add_argument('--arch', type=str, default='rn18', metavar='S', help='rn18, vit')
parser.add_argument('--pag_coeff', type=float, default=2, help='pag loss term coeff')
parser.add_argument('--augmentations', type=int, default=1, help='0: no 1: yes')
# Adjacent string literals are concatenated, so each fragment carries its own separator.
parser.add_argument('--grad_source', type=int, default=4, help='1: One Image, '
                                                               '2: Class Mean, '
                                                               '3: Nearest Neighbor, '
                                                               '4: SBG')
# parse_known_args ignores unrecognized CLI arguments instead of erroring on them.
args, _ = parser.parse_known_args()
print(args)
torch.manual_seed(args.seed)

# translation to rectify improved-diffusion classes:
# key = original class index, value = rectified class index
# (position k in the list below is the target for class k).
trans_dict = dict(enumerate([2, 1, 3, 4, 5, 6, 7, 0, 8, 9]))

# get data - train
# Each gradient source is a pair of pre-saved tensors: `data`, which train_func
# splits into inputs data[:, 0] and per-class gradient targets data[:, 1:], and
# `labels`, the class targets.
if args.grad_source == 4:
    data = torch.load('data/data_tensor_gt_SM_600.pt')
    labels = torch.load('data/label_tensor_gt_SM_600.pt')
    # Rewrite each label in place through trans_dict.
    for l_i, l in enumerate(labels):
        labels[l_i] = trans_dict[l.item()]
    # Apply the same class permutation to the per-class gradient slots: slot i
    # (holding class i-1) moves to slot trans_dict[i-1] + 1. Slot 0 is carried
    # over unchanged by clone(). Writing into a copy keeps every read coming
    # from the untouched original tensor.
    data_t = data.clone()
    for i in range(1, data_t.shape[1]):
        data_t[:, trans_dict[i - 1] + 1, :, :, :] = data[:, i, :, :, :]
    data = data_t
elif args.grad_source == 3:
    data = torch.load('data/data_tensor_emp_nn.pt')
    labels = torch.load('data/label_tensor_emp_nn.pt')
elif args.grad_source == 2:
    data = torch.load('data/data_tensor_emp_mean.pt')
    labels = torch.load('data/label_tensor_emp_mean.pt')
elif args.grad_source == 1:
    data = torch.load('data/data_tensor_emp_typical.pt')
    labels = torch.load('data/label_tensor_emp_typical.pt')
else:
    # Fail fast with a clear message; otherwise `data`/`labels` would be
    # unbound and the script would die later with a confusing NameError.
    raise ValueError(f'unsupported --grad_source {args.grad_source}; expected one of 1, 2, 3, 4')

# Training loader over the pre-saved (data, labels) tensors; incomplete final
# batches are dropped.
dataset = TensorDataset(data, labels)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
    num_workers=2,
)
# Random crop (32px, 4px padding) + horizontal flip, applied to whole batches
# by the training step when augmentations are enabled.
transform_crop_flip = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
])

# get data - test: CIFAR-10 test split, converted to tensors only.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=True, num_workers=2,
)

# get model: ViT when --arch is 'vit', otherwise ResNet-18; wrapped in
# DataParallel and moved to the GPU.
backbone = ViT() if args.arch == 'vit' else ResNet18()
model = torch.nn.DataParallel(backbone).cuda()
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Using {args.arch} with {num_trainable} learnable params')

# optimizer: SGD whose lr / momentum / weight decay come from the CLI.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)

def augment(x, flag):
    """Return ``transform_crop_flip(x)`` if ``flag`` is truthy, else ``x`` unchanged."""
    return transform_crop_flip(x) if flag else x

# Bind the CLI augmentation switch once so call sites only pass the input.
maybe_augment = partial(augment, flag=args.augmentations)

# Per-channel (mean, std) statistics.
# NOTE(review): `stats` is not read anywhere else in this file, and the
# training-data loading code has no branch for grad_source 5 — confirm whether
# this is still needed.
if args.grad_source == 5:   # improved-diffusion normalization values
    stats = ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
else:
    stats = ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])


# train function
def train_func(model, train_loader, optimizer):
    """Train ``model`` for one epoch on cross-entropy plus the optional PAG loss.

    Each batch ``(data, target)`` is moved to the GPU and split into the input
    images ``data[:, 0]`` and the per-class gradient targets
    ``pag = data[:, 1:]``; ``pag[:, cls]`` is the target for class ``cls``
    (``cls`` in 0..9).

    When ``args.pag_coeff > 0``, for each class the input-gradient of
    ``-CE(model(maybe_augment(images)), cls)`` is compared with ``pag[:, cls]``
    via ``1 - cosine similarity`` (averaged over the batch). The ten terms are
    summed, weighted by ``args.pag_coeff``, and added to the CE loss.

    Args:
        model: network to train; switched to train mode here.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: zeroed and stepped once per batch.
    """
    model.train()
    pag_coeff = args.pag_coeff
    # One condition drives both enabling input gradients and computing the PAG
    # term, so a non-positive coefficient never builds a useless input graph.
    use_pag = pag_coeff > 0
    criterion = torch.nn.CrossEntropyLoss()
    cos_sim = torch.nn.CosineSimilarity(dim=1)
    for batch_idx, (data, target) in enumerate(train_loader):
        images, target = data.cuda(), target.cuda()
        images, pag = images[:, 0], images[:, 1:]
        if use_pag:
            # The PAG term differentiates the loss w.r.t. the input images.
            images.requires_grad_(True)
        optimizer.zero_grad()
        pred = model(maybe_augment(images))
        # CE loss
        loss_ce = criterion(pred, target)
        # PAG loss
        pag_loss = 0
        if use_pag:
            batch_size = target.shape[0]
            for cls in range(10):
                # Score every sample in the batch against the same class `cls`.
                cls_target = torch.full_like(target, cls, dtype=torch.long)
                dummy_loss = criterion(model(maybe_augment(images)), cls_target)
                # create_graph=True keeps the gradient differentiable so the PAG
                # term can itself be back-propagated into the model weights.
                grad, = torch.autograd.grad(-1 * dummy_loss, [images], create_graph=True)
                # reshape rather than view: pag[:, cls] is a strided slice and
                # need not be contiguous.
                similarity = cos_sim(grad.reshape(batch_size, -1),
                                     pag[:, cls].reshape(batch_size, -1))
                pag_loss += 1. - similarity.mean()
        loss = loss_ce + pag_coeff * pag_loss
        if batch_idx % 100 == 0:
            # Log plain floats instead of tensor reprs (device / grad_fn noise).
            pag_value = float(pag_loss)
            print(f'Train loss in batch {batch_idx}: {loss.item():.4f} | CE loss: {loss_ce.item():.4f} | '
                  f'PAG loss (before | after coeff): {pag_value:.4f} | {pag_coeff * pag_value:.4f}')
        loss.backward()
        optimizer.step()


def adjust_learning_rate(optimizer, epoch):
    """Set every param group's learning rate from a step schedule on ``args.lr``.

    Relative to ``args.epochs``, the multiplier on ``args.lr`` is:
    1 before 50%, 0.1 from 50%, 0.01 from 75%, and 0.001 once
    ``epoch >= args.epochs``. The training loop in this file runs epochs
    1..args.epochs, so its final epoch uses the 0.001 factor.
    """
    base_lr = args.lr
    total_epochs = args.epochs
    # Checked from the latest milestone backwards; the first match wins.
    if epoch >= total_epochs:
        new_lr = base_lr * 0.001
    elif epoch >= 0.75 * total_epochs:
        new_lr = base_lr * 0.01
    elif epoch >= 0.5 * total_epochs:
        new_lr = base_lr * 0.1
    else:
        new_lr = base_lr
    for group in optimizer.param_groups:
        group['lr'] = new_lr


pag_coeff = args.pag_coeff
# Create the checkpoint directory before training so a missing folder cannot
# make torch.save fail only after the full run has finished.
os.makedirs('checkpoints', exist_ok=True)
for epoch in range(1, args.epochs + 1):
    # adjust learning rate for SGD
    adjust_learning_rate(optimizer, epoch)
    # training
    print(f"Training epoch {epoch}")
    train_func(model, dataloader, optimizer)
# save model (the state_dict of the DataParallel wrapper)
checkpoint_path = (f'checkpoints/{args.arch}-cifar_grad_source-{args.grad_source}-'
                   f'pag_coeff-{args.pag_coeff}-lr-{args.lr}-aug-{args.augmentations}.pt')
torch.save(model.state_dict(), checkpoint_path)