import argparse
import os
from functools import partial

import torch
import torchvision
import torchvision.transforms.functional as TF
from torch.utils.data import TensorDataset
from torchvision import transforms

from models.resnet import *


# Command-line configuration. parse_known_args() is used so unrecognised
# arguments (e.g. from a launcher) are ignored rather than rejected.
parser = argparse.ArgumentParser(description='PAG Research argparser')
parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, help='SGD weight decay')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
parser.add_argument('--seed', type=int, default=7, metavar='S', help='random seed (default: 7)')
parser.add_argument('--pag_coeff', type=float, default=0, help='pag loss term coeff')
# Only sources 1-4 have data files; `choices` rejects other values given on the
# command line. argparse does not validate the default against `choices`, so
# the data-loading code must still handle an unsupported default.
parser.add_argument('--grad_source', type=int, default=5, choices=[1, 2, 3, 4],
                    help='1: One Image, 2: Class Mean, 3: Nearest Neighbor, 4: SBG')
parser.add_argument('--augmentations', type=int, default=1, help='0: no 1: yes')

args, _ = parser.parse_known_args()
print(args)

# Per-channel (mean, std) pair, identity by default.
# NOTE(review): `stats` is not read anywhere in this file. The grad_source == 4
# branch reassigns it, but nothing applies it as a normalization.
stats = ((0.,) * 3, (1.,) * 3)

# Seed torch's global RNG from the CLI (this also drives DataLoader shuffling).
torch.manual_seed(args.seed)

# Translation to rectify improved-diffusion classes: key = class index in the
# diffusion-generated data, value = label index used for training here.
# Position k of the tuple below is the target label of diffusion class k.
trans_dict = dict(enumerate((1, 5, 6, 7, 8, 9, 0, 4, 2, 3)))

# get data - train
# Each branch loads a pre-built (data, labels) tensor pair. As consumed by
# train_func, data[:, 0] is the training image and data[:, 1:] holds one
# target-gradient image per class (slot 1 + class index).
# NOTE(review): torch.load unpickles its input; only load trusted files.
if args.grad_source == 4:
    data = torch.load('data/stl_data_tensor_gt_SM_550.pt')
    labels = torch.load('data/stl_label_tensor_gt_SM_550.pt')
    # Relabel in place from improved-diffusion class order to the training order.
    for l_i, l in enumerate(labels):
        labels[l_i] = trans_dict[l.item()]
    # Permute the per-class gradient slots to follow the same relabeling;
    # slot 0 (the image) is kept as-is by the clone.
    data_t = data.clone()
    for i in range(1, data_t.shape[1]):
        data_t[:, trans_dict[i-1] + 1, :, :, :] = data[:, i, :, :, :]
    data = data_t
    # NOTE(review): `stats` is reassigned here but never read in this file.
    stats = ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
elif args.grad_source == 3:
    data = torch.load('data/data_tensor_nn_stl.pt')
    labels = torch.load('data/label_tensor_nn_stl.pt')
elif args.grad_source == 2:
    data = torch.load('data/data_tensor_CM_stl.pt')
    labels = torch.load('data/label_tensor_CM_stl.pt')
elif args.grad_source == 1:
    data = torch.load('data/data_tensor_OI_stl.pt')
    labels = torch.load('data/label_tensor_OI_stl.pt')
else:
    # Fail fast with a clear message instead of a NameError on `data` below.
    raise ValueError(f'Unsupported --grad_source {args.grad_source!r}; expected one of 1, 2, 3, 4')

dataset = TensorDataset(data, labels)
# drop_last=True keeps every batch at exactly args.batch_size samples.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True,
                                         drop_last=True, num_workers=2)

# Train-time augmentation: a random 96x96 crop taken from the input padded by
# 4 px on each side, followed by a random horizontal flip.
# NOTE(review): when given a whole batch tensor, torchvision draws one
# crop/flip per call that is shared by every sample. Confirm this is intended.
transform_crop_flip = transforms.Compose([
    transforms.RandomCrop(96, padding=4),
    transforms.RandomHorizontalFlip(),
])


def augment(x, flag):
    """Optionally apply the module-level ``transform_crop_flip`` to ``x``.

    Args:
        x: input passed to ``transform_crop_flip`` when augmenting.
        flag: any truthy value enables augmentation; falsy returns ``x`` as is.

    Returns:
        The transformed input, or the very same ``x`` object when disabled.
    """
    return transform_crop_flip(x) if flag else x


# Bind the CLI augmentation switch once so call sites only pass the batch.
# `augment` only tests the flag's truthiness, so normalising it to bool is safe.
maybe_augment = partial(augment, flag=bool(args.augmentations))

# get data - test
# STL-10 test split as [0, 1] tensors via ToTensor (downloaded into ./data if missing).
# NOTE(review): `testset` / `testloader` are never used below; this script
# performs no evaluation.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

testset = torchvision.datasets.STL10(
    root='./data',
    split='test',
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=True,
    num_workers=2,
)


# get model
# ResNet18 comes from the local `models.resnet` star import. It is wrapped in
# DataParallel and moved to CUDA, so state_dict keys carry a "module." prefix.
model = torch.nn.DataParallel(ResNet18()).cuda()

# optimizer: SGD with momentum and weight decay taken from the CLI.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)


# train function
def train_func(model, train_loader, optimizer):
    """Run one training epoch: cross-entropy plus the optional PAG loss.

    Each batch from ``train_loader`` is ``(data, target)``. ``data[:, 0]`` is the
    input image and ``data[:, 1:]`` holds one target-gradient image per class
    (index ``cls`` for class ``cls``, 0..9). Images and targets whose last
    dimension is not 96 are resized to 96x96.

    When ``args.pag_coeff > 0``, for every class ``cls`` the input gradient of
    ``-CE(model(x), cls)`` is compared with the matching target gradient. The
    penalty ``1 - mean cosine similarity`` (on flattened per-sample tensors) is
    summed over classes, scaled by ``args.pag_coeff`` and added to the CE loss.

    Args:
        model: network to train (switched to train mode here).
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer stepped once per batch.
    """
    model.train()
    pag_coeff = args.pag_coeff
    # One condition gates both the input-gradient tracking and the PAG term
    # (previously a negative coefficient enabled the former but not the latter).
    use_pag = pag_coeff > 0
    criterion = torch.nn.CrossEntropyLoss()
    cosine = torch.nn.CosineSimilarity(dim=1)
    for batch_idx, (batch, target) in enumerate(train_loader):
        batch, target = batch.cuda(), target.cuda()
        images, pag = batch[:, 0], batch[:, 1:]
        if use_pag:
            # Input gradients are only needed for the PAG term.
            images.requires_grad_(True)
        if images.shape[-1] != 96:
            images = TF.resize(images, [96, 96])
        optimizer.zero_grad()
        pred = model(maybe_augment(images))
        # CE loss
        loss_ce = criterion(pred, target)
        # PAG loss
        pag_loss = 0
        if use_pag:
            batch_size = target.shape[0]
            for cls in range(10):
                # Separate forward pass (with its own augmentation draw) per class.
                cls_target = torch.full_like(target, cls, dtype=torch.long)
                dummy_loss = criterion(model(maybe_augment(images)), cls_target)
                # create_graph=True so the PAG penalty itself can be backpropagated.
                grad, = torch.autograd.grad(-1 * dummy_loss, [images], create_graph=True)
                t_pag = pag[:, cls]
                if t_pag.shape[-1] != 96:
                    t_pag = TF.resize(t_pag, [96, 96])
                # reshape (not view) tolerates non-contiguous gradient layouts.
                similarity = cosine(grad.reshape(batch_size, -1), t_pag.reshape(batch_size, -1))
                pag_loss += (1. - similarity.mean())
        loss = loss_ce + pag_coeff * pag_loss
        if batch_idx % 100 == 0:
            # float() prints plain numbers instead of tensor reprs with device/grad_fn.
            print(f'Train loss in batch {batch_idx}: {float(loss)} | CE loss: {float(loss_ce)} | PAG loss (before | after coeff): '
                  f'{float(pag_loss)} | {float(pag_coeff * pag_loss)}')
        loss.backward()
        optimizer.step()


def adjust_learning_rate(optimizer, epoch):
    """Set every param group's lr from a step schedule (decrease the learning rate).

    The base rate ``args.lr`` is used until epoch 74. It is scaled by 0.1 from
    epoch 75, by 0.01 from epoch 90 and by 0.001 from epoch 100 onward.
    """
    base_lr = args.lr
    # (first epoch, multiplier), latest milestone first so the first hit wins.
    schedule = ((100, 0.001), (90, 0.01), (75, 0.1))
    new_lr = next((base_lr * factor for start, factor in schedule if epoch >= start), base_lr)
    for group in optimizer.param_groups:
        group['lr'] = new_lr


# NOTE(review): `pag_coeff` and `best_acc` are not read below. train_func reads
# args.pag_coeff itself, and no evaluation or best-model tracking happens here.
pag_coeff = args.pag_coeff
best_acc = -1
checkpoint_path = (f'checkpoints/stl_grad_source-{args.grad_source}-pag_coeff-{args.pag_coeff}-'
                   f'lr-{args.lr}-aug-{args.augmentations}.pt')
# Create the output directory up front. Otherwise a missing `checkpoints/`
# only surfaces in torch.save after every epoch has already run.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
for epoch in range(1, args.epochs + 1):
    # adjust learning rate for SGD
    adjust_learning_rate(optimizer, epoch)
    # training
    print(f"Training epoch {epoch}")
    train_func(model, dataloader, optimizer)
# save model (the final-epoch weights only)
torch.save(model.state_dict(), checkpoint_path)

