from __future__ import print_function

import torch
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import random

import argparse
import csv
import os
import sys

# local imports
import models as local_models
from utils import get_dataset, select_optimal_device, progress_bar, SUPPORTED_DATASETS, LOG_DIR, CHECKPOINTS_DIR, INSTAHIDE_MODELS

# Allow the Intel OpenMP runtime to be loaded more than once in this process.
# NOTE(review): presumably a workaround for the "OMP: Error #15" abort seen when
# several libraries each bundle their own libiomp; Intel documents it as unsafe.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')


def chunks(lst, n):
    '''Yield consecutive slices of `lst` of length `n`; the last may be shorter.'''
    starts = range(0, len(lst), n)
    yield from (lst[begin:begin + n] for begin in starts)


def label_to_onehot(target, num_classes):
    '''Return float one-hot rows for a 1-D tensor of integer (scalar) class labels.

    The result has shape ``(len(target), num_classes)`` and lives on the same
    device as ``target``.
    '''
    n_rows = target.size(0)
    encoded = torch.zeros(n_rows, num_classes, device=target.device)
    # scatter_ writes a 1 at column target[r] of each row r and returns `encoded`.
    return encoded.scatter_(1, target.unsqueeze(1), 1)


def cross_entropy_for_onehot(pred, target):
    '''Batch-mean cross-entropy between logits `pred` and soft/one-hot `target` rows.'''
    log_probs = F.log_softmax(pred, dim=-1)
    per_sample = torch.sum(-target * log_probs, 1)
    return torch.mean(per_sample)


def mixup_criterion(pred, ys, lam_batch, num_classes):
    '''Soft-label cross-entropy against the lambda-weighted private labels.

    Only the first ``n_private`` label tensors in ``ys`` contribute: 1 with
    ``--pair``, otherwise ``(klam + 1) // 2``. Each is one-hot encoded and
    scaled row-wise by the matching column of ``lam_batch``. Public components
    are left out, so the soft targets need not sum to 1.
    '''
    n_private = 1 if args.pair else (args.klam + 1) // 2

    onehots = [label_to_onehot(labels, num_classes) for labels in ys]
    soft_targets = vec_mul_ten(lam_batch[:, 0], onehots[0])
    for k in range(1, n_private):
        soft_targets += vec_mul_ten(lam_batch[:, k], onehots[k])
    return cross_entropy_for_onehot(pred, soft_targets)


def vec_mul_ten(vec, tensor):
    '''Scale slice ``tensor[i]`` by the scalar ``vec[i]`` along the first axis.'''
    expand_to = list(tensor.shape)
    expand_to[0] = -1
    # View vec as (-1, 1, 1, ...) so it lines up with tensor's leading axis.
    per_row = vec.reshape([-1] + [1] * (len(expand_to) - 1))
    return per_row.expand(expand_to) * tensor


def mixup_data(x, y, x_help, device):
    '''Mix every private sample with ``args.klam - 1`` other images.

    Component 0 is the sample itself. Components ``1 .. inside_cnt - 1`` are
    other private samples (random permutations of ``x``); the remaining
    components come from the public pool ``x_help``. ``inside_cnt`` is 1 with
    ``--pair`` and ``(klam + 1) // 2`` otherwise. In ``instahide`` mode every
    element of the mix is additionally multiplied by a random +1/-1 sign.

    Coefficients are sampled per sample as normalised |N(0, 1)| draws and
    re-drawn until no coefficient exceeds ``args.upper`` and the first two sum
    to at least ``args.dom``.

    Returns:
        (mixed_x, ys, lams): the mixed batch; a list of ``args.klam`` label
        tensors (entry k holds the private labels used for component k; for
        public components these are placeholders that the loss ignores); and
        the ``(batch, klam)`` float coefficient tensor on ``device``.

    Raises:
        ValueError: if ``args.upper`` / ``args.dom`` make the rejection loop
            impossible to satisfy (it would otherwise spin forever).
    '''
    batch_size = x.size(0)

    if args.klam > 1:
        # The largest of klam coefficients summing to 1 is >= 1/klam, and the
        # first two can sum to at most min(1, 2 * upper). Outside these bounds
        # the rejection loop below can never terminate.
        if args.upper * args.klam <= 1:
            raise ValueError(
                f'--upper={args.upper} is unsatisfiable with --klam={args.klam}; '
                f'it must exceed 1/klam = {1 / args.klam:.4f}')
        if args.dom > min(1.0, 2 * args.upper):
            raise ValueError(
                f'--dom={args.dom} is unsatisfiable; it must not exceed '
                f'min(1, 2 * upper) = {min(1.0, 2 * args.upper):.4f}')

    lams = np.random.normal(0, 1, size=(batch_size, args.klam))
    for i in range(batch_size):
        lams[i] = np.abs(lams[i]) / np.sum(np.abs(lams[i]))
        if args.klam > 1:
            # upper bounds a single lambda + lower bounds the sum of the first two lambdas
            while lams[i].max() > args.upper or (lams[i][0] + lams[i][1]) < args.dom:
                lams[i] = np.random.normal(0, 1, size=(1, args.klam))
                lams[i] = np.abs(lams[i]) / np.sum(np.abs(lams[i]))

    lams = torch.from_numpy(lams).float().to(device)

    mixed_x = vec_mul_ten(lams[:, 0], x)
    ys = [y]

    if args.pair:
        inside_cnt = 1
    else:
        inside_cnt = (args.klam + 1) // 2

    n_public = x_help.size(0)
    for i in range(1, args.klam):
        index = torch.randperm(batch_size).to(device)
        if i < inside_cnt:
            # mix private samples
            mixed_x += vec_mul_ten(lams[:, i], x[index, :])
        else:
            # mix public samples
            if n_public >= batch_size:
                # generate_sample shuffles x_help first, so a permutation of
                # its first batch_size rows is a random public subset.
                public_index = index
            else:
                # Public pool smaller than the private batch: `index` would go
                # out of range, so sample public rows with replacement instead.
                public_index = torch.randint(
                    n_public, (batch_size,), device=device)
            mixed_x += vec_mul_ten(lams[:, i], x_help[public_index, :])
        # One label tensor per component keeps len(ys) == klam; only the
        # private components' labels are read by the loss.
        ys.append(y[index])

    if args.mode == 'instahide':
        sign = torch.randint(2, size=list(x.shape), device=device) * 2.0 - 1
        mixed_x *= sign.float().to(device)
    return mixed_x, ys, lams


def generate_sample(trainloader, inputs_help, device):
    """Mix the whole private training set once, up front.

    Args:
        trainloader: loader that must yield the entire private set as a
            single ``(inputs, targets)`` batch.
        inputs_help: public images used as mixing partners.
        device: device the private batch is moved to before mixing.

    Returns:
        The ``(mix_inputs, mix_targets, lams)`` tuple produced by ``mixup_data``.

    Raises:
        ValueError: if ``trainloader`` does not contain exactly one batch
            (raised explicitly so the check survives ``python -O``).
    """
    if len(trainloader) != 1:
        raise ValueError(
            f'generate_sample expects a single-batch loader, got {len(trainloader)} batches')

    # Shuffle the public pool so mixup_data's indexing picks a random subset.
    inputs_help = inputs_help[torch.randperm(inputs_help.size(0))]

    inputs, targets = next(iter(trainloader))
    inputs = inputs.to(device)
    targets = targets.to(device)
    # Labels travel as floats through the mixing step; train() casts them back to long.
    return mixup_data(inputs, targets.float(), inputs_help, device)


def train(device, net, optimizer, inputs_all, mix_targets_all, lams, epoch, num_classes, new_lr):
    """Run one training epoch over the pre-mixed samples.

    Args:
        device: device the model and label batches are moved to.
        net: model being trained (switched to train mode here).
        optimizer: optimizer stepping ``net``'s parameters.
        inputs_all: tensor of pre-mixed inputs, one row per sample.
        mix_targets_all: list of ``args.klam`` label tensors, one per mixing
            component, each indexed by sample.
        lams: per-sample mixing coefficients, one row per sample.
        epoch: epoch index (printed only).
        num_classes: number of classes for the one-hot targets.
        new_lr: learning rate shown in the progress bar (display only).

    Returns:
        ``(mean_batch_loss, accuracy)``. Accuracy is always 0.0: targets are
        mixed label distributions, so no top-1 accuracy is tracked here.

    Raises:
        ValueError: if ``inputs_all`` is empty.
    """
    print('\nEpoch: %d' % epoch)
    net.to(device)
    net.train()

    num_samples = len(inputs_all)
    if num_samples == 0:
        raise ValueError('train() received no samples')

    # Same shuffling call as before, so seeded runs keep their batch order.
    order = random.sample(range(num_samples), num_samples)
    batches = list(chunks(order, args.batch_size))

    train_loss = 0.0
    for batch_idx, batch in enumerate(batches):
        # Gather a whole batch with one indexing op per tensor instead of
        # stacking per-sample Python views.
        idx = torch.tensor(batch, dtype=torch.long)
        inputs = inputs_all[idx.to(inputs_all.device)]
        lam_batch = lams[idx.to(lams.device)]
        targets = [
            mix_targets_all[k][idx.to(mix_targets_all[k].device)].long().to(device)
            for k in range(args.klam)
        ]

        outputs = net(inputs)
        loss = mixup_criterion(outputs, targets, lam_batch, num_classes)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        msg = f"Loss: {(train_loss / (batch_idx + 1)):.3f} | LR: {new_lr}"
        progress_bar(batch_idx, len(batches), msg)

    return (train_loss / len(batches), 0.0)


def test(device, net, testloader, epoch, start_epoch):
    """Evaluate ``net`` on ``testloader`` and checkpoint it when appropriate.

    A checkpoint is written on the run's final epoch (``args.epochs - 1``,
    matching main's ``range(start_epoch, args.epochs)`` loop) and whenever
    top-1 accuracy beats the module-level ``best_acc``, which is then updated.

    Args:
        device: device the model and batches are moved to.
        net: model to evaluate (switched to eval mode here).
        testloader: iterable of ``(inputs, targets)`` batches of class indices.
        epoch: current epoch index.
        start_epoch: first epoch of this run; unused, kept for call compatibility.

    Returns:
        ``(mean_batch_loss, top1_accuracy_percent)`` as Python floats.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    global best_acc
    net.to(device)
    net.eval()

    test_loss, correct_1, total, num_batches = 0.0, 0.0, 0, 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            num_batches += 1
            _, pred = outputs.topk(1, 1, largest=True, sorted=True)
            total += targets.size(0)
            # .item() keeps counts as Python numbers, so the returned and
            # logged accuracy is a float rather than a 0-dim tensor.
            correct_1 += pred.eq(targets.view(targets.size(0), -1)).float().sum().item()

            progress_bar(
                batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (test_loss /
                    (batch_idx + 1), 100. * correct_1 / total, correct_1, total))

    if total == 0:
        raise ValueError('testloader yielded no samples')

    acc = 100. * correct_1 / total
    if epoch == args.epochs - 1 or acc > best_acc:
        save_checkpoint(net, acc, epoch)
    if acc > best_acc:
        best_acc = acc
    return (test_loss / num_batches, acc)


def save_checkpoint(net, acc, epoch):
    """Write ``net`` (the whole module), its accuracy, the epoch and the CPU RNG state.

    The file goes to ``CHECKPOINTS_DIR`` (created if missing) under a name
    built from the run's model/data/mode/klam/name/seed arguments, replacing
    any earlier checkpoint of the same run.
    """
    print('Saving model checkpoint...')
    snapshot = dict(
        net=net,
        acc=acc,
        epoch=epoch,
        rng_state=torch.get_rng_state(),
    )
    os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
    filename = f'{args.model}_{args.data}_{args.mode}_{args.klam}_{args.name}_{args.seed}.t7'
    torch.save(snapshot, os.path.join(f'{CHECKPOINTS_DIR}/', filename))


def adjust_learning_rate(optimizer, epoch) -> float:
    """
    Set every param group's learning rate to the scheduled value for ``epoch``.

    Starting from ``args.lr``, a per-dataset decay is applied once for each
    milestone that ``epoch`` has reached:

    * mnist:    divide by 10 at epochs 15 and 45
    * fashion:  divide by 10 at epoch 15
    * cifar10:  divide by 10 at epochs 100 and 150
    * cifar100: multiply by 0.2 at epochs 60, 120 and 160
    * tiny:     multiply by 0.2 at epochs 75, 150, 200 and 250

    Any other dataset keeps ``args.lr``. Returns the learning rate that was set.
    """
    def divide_by_ten(rate):
        return rate / 10

    def scale_by_fifth(rate):
        return rate * 0.2

    schedules = {
        'mnist': ((15, 45), divide_by_ten),
        'fashion': ((15,), divide_by_ten),
        'cifar10': ((100, 150), divide_by_ten),
        'cifar100': ((60, 120, 160), scale_by_fifth),
        'tiny': ((75, 150, 200, 250), scale_by_fifth),
    }
    milestones, decay = schedules.get(args.data, ((), None))

    lr = args.lr
    for milestone in milestones:
        if epoch >= milestone:
            lr = decay(lr)

    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr


def main(device: torch.device, use_cuda: bool):
    """Prepare data, build or resume the model, then train and evaluate.

    Args:
        device: device the public pool, mixed samples and model are placed on.
        use_cuda: if True the model is moved to CUDA, wrapped in
            ``nn.DataParallel`` and cuDNN benchmarking is enabled.

    Side effects: with ``--resume`` loads the run's checkpoint from
    ``CHECKPOINTS_DIR``; appends one tab-separated row per epoch to a log in
    ``LOG_DIR``; updates the module-level ``best_acc``.

    Raises:
        FileNotFoundError: if ``--resume`` is given but the checkpoint is missing.
    """
    import inspect   # probe torch.load's keyword arguments across versions
    import platform  # portable host name (os.uname() does not exist on Windows)

    global best_acc
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    if args.seed != 0:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        # train() orders its batches with the stdlib `random` module, so it
        # must be seeded too for a seeded run to be reproducible.
        random.seed(args.seed)

    print('==> Number of lambdas: %g' % args.klam)

    ## --------------- Prepare data --------------- ##
    print('==> Preparing data..')
    if args.data in ['mnist', 'fashion']:
        print(f'Making {str(args.data).upper()} 3 channels...')
        data_workers = 0

        # the public dataset
        public_dataset, _, _ = get_dataset(
            dataset_type="cifar10", match_for_mnist=True)

        # the private datasets
        trainset, testset, num_classes = get_dataset(
            args.data, make_mnist_3_channels=True)
    else:
        data_workers = 4

        # the public dataset
        public_dataset, _, _ = get_dataset(dataset_type="tiny")

        # the private datasets
        trainset, testset, num_classes = get_dataset(args.data)

    # The private-train and public loaders yield their whole dataset as one
    # batch: every private sample is mixed once, up front, by generate_sample().
    # `num_workers`: https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=len(trainset),
        shuffle=True,
        num_workers=data_workers)
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=data_workers)
    trainloader_public = torch.utils.data.DataLoader(
        public_dataset,
        batch_size=len(public_dataset),
        shuffle=True,
        num_workers=data_workers)

    # load the public dataset onto device
    for _, (inputs_help, _) in enumerate(trainloader_public):
        inputs_help = inputs_help.to(device)

    ## --------------- Create the model --------------- ##
    if args.resume:
        print('==> Resuming from checkpoint..')
        # Must be the same file name that save_checkpoint() writes.
        ckpt_path = os.path.join(
            CHECKPOINTS_DIR,
            f'{args.model}_{args.data}_{args.mode}_{args.klam}_{args.name}_{args.seed}.t7')
        if not os.path.isfile(ckpt_path):
            raise FileNotFoundError(f'Error: no checkpoint found at {ckpt_path}')
        # The checkpoint pickles the whole nn.Module, so torch versions that
        # default to weights-only loading need weights_only=False.
        # Unpickling runs arbitrary code: only resume from checkpoints you trust.
        load_kwargs = {'map_location': 'cpu'}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        checkpoint = torch.load(ckpt_path, **load_kwargs)
        net = checkpoint['net']
        # Checkpoints are written after the DataParallel wrap below; unwrap so
        # the resumed model is not wrapped twice.
        if isinstance(net, nn.DataParallel):
            net = net.module
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch'] + 1
        torch.set_rng_state(checkpoint['rng_state'])
    else:
        print(f'==> Building model: {args.model}')
        net = local_models.__dict__[
            args.model](num_classes=num_classes)

    os.makedirs(LOG_DIR, exist_ok=True)
    logname = f'{LOG_DIR}/log_{args.model}_{args.data}_{args.mode}_{args.klam}_{args.name}_{args.seed}seed_{args.epochs}epochs.csv'
    print(f'Saving everything to: {logname}')

    if use_cuda:
        net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
        print('==> Using CUDA..')
    else:
        net.to(device)

    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=args.decay)

    ## --------------- Train and Eval --------------- ##
    print(f'==> Training on {platform.node()} for #{args.epochs} epochs')
    if not os.path.exists(logname):
        with open(logname, 'w') as logfile:
            logwriter = csv.writer(logfile, delimiter='\t')
            logwriter.writerow([
                'Epoch', 'Train loss', 'Test loss',
                'Test acc', 'LR'
            ])

    mix_inputs_all, mix_targets_all, lams = generate_sample(
        trainloader, inputs_help, device)

    # The schedule is advanced after each epoch, so a resumed run must first
    # restore the LR an uninterrupted run would have reached by now;
    # otherwise its first epoch would train at the initial args.lr.
    new_lr = args.lr
    if start_epoch > 0:
        new_lr = adjust_learning_rate(optimizer, start_epoch - 1)

    for epoch in range(start_epoch, args.epochs):
        train_loss, _ = train(device, net, optimizer, mix_inputs_all,
                              mix_targets_all, lams, epoch, num_classes, new_lr)
        test_loss, test_acc1 = test(
            device, net, testloader, epoch, start_epoch)
        new_lr = adjust_learning_rate(optimizer, epoch)
        with open(logname, 'a') as logfile:
            logwriter = csv.writer(logfile, delimiter='\t')
            logwriter.writerow(
                [epoch, train_loss, test_loss, test_acc1, new_lr])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch InstaHide Training')

    # device configuration
    # NOTE(review): --gpu is parsed but not read in this file;
    # select_optimal_device() below is called without it.
    parser.add_argument('--gpu',
                        default=None,
                        type=str,
                        help='Select the GPU mode. (E.g., CUDA for NVIDIA GPUs or MPS for Apple M Series.)')

    # Training configurations
    parser.add_argument('--model',
                        default="resnet18",
                        type=str,
                        help='Model architecture (default: resnet18)')
    parser.add_argument('--data', default='cifar10', type=str,
                        help='dataset')

    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--batch-size', default=128,
                        type=int, help='batch size')
    parser.add_argument('--epochs',
                        default=200,
                        type=int,
                        help='total epochs to run')
    # NOTE(review): args.augment is not read anywhere in this file.
    parser.add_argument('--no-augment',
                        dest='augment',
                        action='store_false',
                        help='disable standard data augmentation (enabled by default)')
    parser.add_argument('--decay', default=1e-4,
                        type=float, help='weight decay')

    # Saving configurations
    parser.add_argument('--name', default='cross',
                        type=str, help='name of run')
    parser.add_argument('--seed', default=0, type=int,
                        help='random seed (0 leaves the RNGs unseeded)')
    parser.add_argument('--resume',
                        '-r',
                        action='store_true',
                        help='resume from checkpoint')

    # InstaHide configurations
    parser.add_argument('--klam', default=4, type=int,
                        help='number of lambdas')
    parser.add_argument('--mode', default='instahide',
                        type=str, choices=['instahide', 'mixup'],
                        help='instahide (mixup plus random sign flips) or plain mixup')
    parser.add_argument('--pair', action='store_true',
                        help='mix each private image only with public images')
    parser.add_argument('--upper', default=0.65, type=float,
                        help='the upper bound of any coefficient')
    parser.add_argument('--dom', default=0.3, type=float,
                        help='the lower bound of the sum of coefficients of two private images')

    args = parser.parse_args()
    # model to lowercase
    args.model = str(args.model).lower()

    if args.data not in SUPPORTED_DATASETS:
        raise ValueError(
            f"Dataset type not supported. Currently supported datasets: {SUPPORTED_DATASETS}")
    if args.model not in INSTAHIDE_MODELS:
        raise ValueError(
            f"Model architecture not supported. Currently supported models: {INSTAHIDE_MODELS}")
    if args.klam < 1:
        # mixup always uses the first coefficient, so at least one is required.
        raise ValueError(f"--klam must be at least 1, got {args.klam}")

    # Device configuration
    use_cuda = torch.cuda.is_available()
    device = select_optimal_device()
    print(f'`use_cuda`: {use_cuda}')
    print(f'Using device: {device}')

    criterion = nn.CrossEntropyLoss()
    best_acc = 0  # best test accuracy

    main(device, use_cuda)
