import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from ml_common import get_device, test
from ml_datasets import nclasses_dict
from torch.utils.data import DataLoader, TensorDataset
from utils import D3Model


def get_labels(X_sub: torch.Tensor, model: torch.nn.Module, pred_type="soft"):
    """Label ``X_sub`` with ``model`` and return the stacked predictions.

    The inputs are fed through ``model`` in batches of 128 with gradients
    disabled and the model in eval mode on the device from ``get_device()``.

    Args:
        X_sub: Inputs to label; iterated along the first dimension.
        model: Labeling model. When ``pred_type`` is not ``"soft"`` it is
            called as ``model(x, label=True)``, so it must accept that
            keyword.
        pred_type: ``"soft"`` returns the softmax of the model output over
            the last dimension. Any other value returns whatever
            ``model(x, label=True)`` yields; for ``"hard"`` the result is
            additionally cast to ``long``.

    Returns:
        The per-batch outputs concatenated along the first dimension.

    Side effects:
        Moves ``model`` to the device, switches it to eval mode, and prints
        the mean of ``model.coherence(x)`` when the model provides that
        method and it succeeds.
    """
    device = get_device()
    ds = TensorDataset(X_sub)
    dl = DataLoader(ds, batch_size=128)
    ys = torch.tensor((), device=device)
    model = model.to(device)
    model.eval()
    coherence = torch.tensor([], device=device)

    with torch.no_grad():
        for (x,) in dl:
            x = x.to(device)
            if pred_type == "soft":
                y = F.softmax(model(x), dim=-1)
            else:
                y = model(x, label=True)

            # The coherence statistic is optional and purely diagnostic, so a
            # model without it (or one that fails to compute it) must not
            # abort labeling. Only ``Exception`` is swallowed so that
            # KeyboardInterrupt / SystemExit still propagate.
            try:
                c = model.coherence(x)
                coherence = torch.cat((coherence, c))
            except Exception:
                pass

            ys = torch.cat((ys, y))

    if pred_type == "hard":
        ys = ys.long()
    if coherence.shape[0] > 0:
        print("coherence:{:.3f}".format(coherence.mean().item()))
    return ys


def batch_indices(batch_nb, data_length, batch_size):
    """Return the ``(start, end)`` slice bounds of batch number ``batch_nb``.

    If the nominal window would run past ``data_length``, the whole window
    is slid backwards by the overshoot so that it ends exactly at
    ``data_length``; the final batch therefore reuses some earlier samples
    instead of being shorter.
    """
    lo = int(batch_nb * batch_size)
    hi = int((batch_nb + 1) * batch_size)

    # Amount by which the window sticks out past the data (0 if it fits).
    overshoot = hi - data_length if hi > data_length else 0

    return lo - overshoot, hi - overshoot


def to_var(x, requires_grad=False, volatile=False):
    """
    Wrap ``x`` in a ``Variable``, moving it to CUDA first when CUDA is
    available; ``requires_grad`` and ``volatile`` are forwarded unchanged.
    """
    placed = x.cuda() if torch.cuda.is_available() else x
    return Variable(placed, requires_grad=requires_grad, volatile=volatile)


def jacobian(model, x, nb_classes=10):
    """Return the gradient of each class score with respect to ``x``.

    ``x`` is a single, unbatched input; a leading batch dimension of size 1
    is added before it is passed to ``model``.

    Args:
        model: Callable mapping a batch of inputs to per-class scores, with
            the class index on dimension 1 of the output.
        x: One input sample. It is detached first, so the caller's tensor
            never receives gradients.
        nb_classes: Number of leading output columns to differentiate.

    Returns:
        A list of ``nb_classes`` numpy arrays, each shaped like ``x``; entry
        ``k`` is ``d model(x)[k] / d x``.
    """
    x_var = x.detach().requires_grad_(True)

    # One forward pass suffices; the graph is retained so that every class
    # score can be differentiated against it.
    scores = model(x_var.unsqueeze(0))

    list_derivatives = []
    for class_ind in range(nb_classes):
        # torch.autograd.grad returns a fresh tensor on every call. The
        # previous ``backward()`` + ``x_var.grad.data.zero_()`` pattern
        # reused one grad buffer, and on CPU ``.numpy()`` shares memory with
        # it, so every array in the list ended up aliased to the same
        # zeroed storage.
        (grad,) = torch.autograd.grad(
            scores[:, class_ind], x_var, retain_graph=True
        )
        list_derivatives.append(grad.detach().cpu().numpy())

    return list_derivatives


def jacobian_augmentation(
    model, X_sub_prev, Y_sub, lmbda=0.1, nb_classes=10, bounds=(-1, 1)
):
    """Double the substitute training set with Jacobian-based synthetic points.

    For every sample ``x`` in ``X_sub_prev`` a new point
    ``x - lmbda * sign(d model(x)[y] / d x)`` is created, where ``y`` is the
    label of ``x`` taken from ``Y_sub``.

    Args:
        model: Substitute model; it is switched to eval mode.
        X_sub_prev: Tensor of current training inputs, samples on dim 0.
        Y_sub: Labels for ``X_sub_prev``. A 2-D tensor is treated as
            per-class scores and reduced with argmax over dim 1; otherwise
            it is used directly as integer class indices.
        lmbda: Step size of the perturbation.
        nb_classes: Number of classes passed on to ``jacobian``.
        bounds: ``(low, high)`` range the returned tensor is clamped to.

    Returns:
        A tensor with twice as many rows as ``X_sub_prev``: the original
        samples followed by the synthetic ones, clamped to ``bounds``. The
        result still has to be labeled by the caller.
    """
    model.eval()
    if Y_sub.ndim == 2:
        # Labels could be a posterior probability distribution. Use argmax
        # as a proxy.
        Y_sub = torch.argmax(Y_sub, dim=1)

    # First half keeps the originals; the second half is overwritten below.
    X_sub = torch.cat((X_sub_prev, X_sub_prev), dim=0)
    offset = X_sub_prev.shape[0]

    for ind, x in enumerate(X_sub_prev):
        grads = jacobian(model, x, nb_classes)
        # Select gradient corresponding to the label predicted by the oracle
        grad = grads[Y_sub[ind]]

        # Build the sign tensor on the sample's own device so the update
        # cannot hit a device mismatch.
        grad_sign = torch.sign(torch.as_tensor(grad, device=x.device))

        # Create new synthetic point in adversary substitute training set
        X_sub[offset + ind] = x - lmbda * grad_sign

    # Return augmented training data (needs to be labeled afterwards)
    return torch.clamp(X_sub, bounds[0], bounds[1])


def jacobian_tr_augmentation(
    model, X_sub_prev, Y_sub, lmbda=0.1, nb_classes=10, bounds=(-1, 1), fgsm_iter=5
):
    """Double the substitute training set with targeted, iterative perturbations.

    For every sample a target class is drawn at random and the sample is
    moved in ``fgsm_iter`` steps of size ``lmbda / fgsm_iter`` along the
    sign of the gradient of that class score, clamping to ``bounds`` after
    every step.

    Args:
        model: Substitute model; it is switched to eval mode.
        X_sub_prev: Tensor of current training inputs, samples on dim 0.
            It is not modified.
        Y_sub: Accepted for signature parity with ``jacobian_augmentation``;
            the target class is random, so the labels are not used.
        lmbda: Total step size spread over the ``fgsm_iter`` steps.
        nb_classes: Number of classes (range of the random target and the
            number of gradients requested from ``jacobian``).
        bounds: ``(low, high)`` range values are clamped to.
        fgsm_iter: Number of gradient-sign steps per sample.

    Returns:
        A tensor with twice as many rows as ``X_sub_prev``: the original
        samples followed by the perturbed ones, clamped to ``bounds``. The
        result still has to be labeled by the caller.
    """
    model.eval()

    # First half keeps the originals; the second half is overwritten below.
    X_sub = torch.cat((X_sub_prev, X_sub_prev), dim=0)
    offset = len(X_sub_prev)

    for ind, x in enumerate(X_sub_prev):
        # pick a random target class
        ind_tar = (ind + np.random.randint(nb_classes)) % nb_classes

        for _ in range(fgsm_iter):
            grads = jacobian(model, x, nb_classes)
            # Select gradient corresponding to the label picked as the target
            grad_sign = torch.sign(
                torch.as_tensor(grads[ind_tar], device=x.device)
            )

            # Out-of-place update: ``x`` starts as a view into X_sub_prev,
            # so an in-place ``+=`` would silently modify the caller's data
            # on the first step.
            x = torch.clamp(
                x + lmbda * grad_sign / fgsm_iter, bounds[0], bounds[1]
            )

        # Create new synthetic point in adversary substitute training set
        X_sub[offset + ind] = x

    # Return augmented training data (needs to be labeled afterwards)
    return torch.clamp(X_sub, bounds[0], bounds[1])


def jbda(
    T: D3Model,
    S,
    dataloader_train,
    dataloader_test,
    opt,
    acc_tar,
    num_seed=100,
    aug_rounds=6,
    epochs=10,
    batch_size=128,
    dataset="mnist",
    bounds=(-1, 1),
    mode="jbda",
    lmbda=0.1,
    pred_type="soft",
):
    """Train substitute ``S`` on labels from ``T`` using Jacobian augmentation.

    Starting from the first ``num_seed`` samples of
    ``dataloader_train.dataset``, the procedure alternates between training
    ``S`` on the current set (labeled by ``T`` through ``get_labels``) and
    enlarging the set with synthetic points derived from ``S``.

    Args:
        T: Target model used only to label data.
        S: Substitute model that is trained in place.
        dataloader_train: Only its ``.dataset`` is used, as the seed source.
        dataloader_test: Passed to ``test`` to evaluate ``S`` each round.
        opt: Optimizer stepping the parameters of ``S``.
        acc_tar: Reference accuracy; the printed ratio is
            ``test_acc / acc_tar``.
        num_seed: Number of seed samples taken (unshuffled) from the dataset.
        aug_rounds: Number of train rounds; augmentation runs after every
            round except the last.
        epochs: Training epochs per round.
        batch_size: Training batch size.
        dataset: Key into ``nclasses_dict`` giving the number of classes.
        bounds: ``(low, high)`` clamp range for augmented inputs.
        mode: ``"jbda"`` or ``"jbda-tr"``. Any other value skips the
            augmentation step (the set is only relabeled).
        lmbda: Perturbation step size for the augmentation.
        pred_type: ``"soft"`` trains with KL divergence on probabilities;
            otherwise cross-entropy on the labels from ``get_labels``.

    Returns:
        None. Progress and per-round accuracy are printed.
    """
    # Label seed data
    device = get_device()
    S = S.to(device)
    T = T.to(device)

    num_classes = nclasses_dict[dataset]
    seed_loader = DataLoader(
        dataloader_train.dataset, batch_size=num_seed, shuffle=False
    )
    # Builtin next(): the iterator's ``.next()`` method was a Python 2 alias
    # that newer PyTorch releases no longer provide.
    X_sub, _ = next(iter(seed_loader))
    X_sub = X_sub.to(device)

    Y_sub = get_labels(X_sub, T, pred_type=pred_type)
    if pred_type == "soft":
        criterion = torch.nn.KLDivLoss(reduction="batchmean")
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # Train the substitute and augment dataset alternatively
    T.eval()
    for aug_round in range(1, aug_rounds + 1):
        # Model training on the current (reshuffled) substitute set.
        ds = TensorDataset(X_sub, Y_sub)
        dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True)
        # The augmentation helpers switch S to eval mode, so training mode
        # has to be restored at the start of every round.
        S.train()

        for _ in range(epochs):
            for x, y in dataloader:
                opt.zero_grad()
                x, y = x.to(device), y.to(device)
                # KLDivLoss expects log-probabilities as its input.
                Sout = F.log_softmax(S(x), dim=-1)

                lossS = criterion(Sout, y)
                lossS.backward()
                opt.step()

        test_acc = test(S, dataloader_test)

        # If we are not in the last substitute training iteration, augment dataset
        if aug_round < aug_rounds:
            print("[{}] Augmenting substitute training data.".format(aug_round))
            # Perform the Jacobian augmentation
            if mode == "jbda":
                X_sub = jacobian_augmentation(
                    S, X_sub, Y_sub, nb_classes=num_classes, bounds=bounds, lmbda=lmbda
                )
            elif mode == "jbda-tr":
                X_sub = jacobian_tr_augmentation(
                    S, X_sub, Y_sub, nb_classes=num_classes, bounds=bounds, lmbda=lmbda
                )

            print("Labeling substitute training data.")
            Y_sub = get_labels(X_sub, T, pred_type=pred_type)
        print(
            "Aug Round {} Clone Accuracy: {:.2f}({:.2f}x)".format(
                aug_round, test_acc * 100, test_acc / acc_tar
            )
        )

