import torch
import torch.nn as nn
import torch.nn.functional as F

# Valid value range for inputs; attack_pgd clamps delta so X + delta stays inside it.
upper_limit = 1
lower_limit = -1

# Prefer the GPU when one is visible to torch.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"


def normalize(X, mu, std):
    """Return ``(X - mu) / std``; ``mu`` and ``std`` broadcast against ``X``."""
    centered = X - mu
    return centered / std


def clamp(X, lower_limit, upper_limit):
    """Clamp ``X`` elementwise into ``[lower_limit, upper_limit]``.

    Args:
        X: tensor to clamp.
        lower_limit: lower bound; a tensor broadcastable to ``X`` or a Python scalar.
        upper_limit: upper bound; a tensor broadcastable to ``X`` or a Python scalar.

    Returns:
        A new tensor computed as ``max(min(X, upper_limit), lower_limit)``. Where a
        lower bound exceeds its upper bound, the lower bound wins.
    """
    # torch.min(X, <int>) is the reduction overload (the int is taken as a dim),
    # so scalar bounds are wrapped as tensors to always hit the elementwise form.
    if not isinstance(upper_limit, torch.Tensor):
        upper_limit = torch.as_tensor(upper_limit, dtype=X.dtype, device=X.device)
    if not isinstance(lower_limit, torch.Tensor):
        lower_limit = torch.as_tensor(lower_limit, dtype=X.dtype, device=X.device)
    return torch.max(torch.min(X, upper_limit), lower_limit)


def attack_pgd(model, X, y, epsilon, alpha, attack_iters, restarts,
               norm, early_stop=False, mixup=False):
    """Projected gradient descent attack with random restarts.

    Each restart draws a random perturbation inside the ``epsilon`` ball. It then
    takes ``attack_iters`` gradient-ascent steps of size ``alpha`` on the
    cross-entropy loss. After every step the perturbation is projected back into
    the ball, and clamped so that ``X + delta`` stays within the module-level
    ``[lower_limit, upper_limit]`` range.

    Args:
        model: callable mapping ``X`` to logits accepted by ``F.cross_entropy``.
            It is called in whatever train/eval mode it is currently in.
        X: input batch; the first dimension is the batch.
        y: targets for ``F.cross_entropy`` (presumably class indices; confirm
            against callers).
        epsilon: radius of the perturbation ball.
        alpha: step size per iteration.
        attack_iters: number of gradient steps per restart (may be 0).
        restarts: number of random restarts.
        norm: ``"l_inf"`` or ``"l_2"``.
        early_stop: accepted for API compatibility; currently ignored.
        mixup: accepted for API compatibility; currently ignored.

    Returns:
        Tensor shaped like ``X``. For each sample, it holds the perturbation from
        the restart whose final per-sample loss was highest; ties go to the later
        restart.

    Raises:
        ValueError: if ``norm`` is not ``"l_inf"`` or ``"l_2"``.
    """
    # NOTE: early-stop and mixup variants were stubbed out in an earlier draft;
    # both flags are intentionally no-ops until they are implemented.
    if norm not in ("l_inf", "l_2"):
        raise ValueError(f"unsupported norm {norm!r}; expected 'l_inf' or 'l_2'")

    batch = X.size(0)
    # Shape that broadcasts a per-sample scalar across all non-batch dims.
    per_sample = (batch,) + (1,) * (X.dim() - 1)

    # Allocate on X's device (not a global default) so CPU/GPU inputs both work.
    max_loss = torch.zeros(y.shape[0], device=X.device)
    max_delta = torch.zeros_like(X)

    for _ in range(restarts):
        delta = torch.zeros_like(X)
        if norm == "l_inf":
            delta.uniform_(-epsilon, epsilon)
        else:
            # Random direction scaled to a uniformly drawn radius in [0, epsilon].
            delta.normal_()
            n = delta.reshape(batch, -1).norm(p=2, dim=1).reshape(per_sample)
            r = torch.zeros_like(n).uniform_(0, 1)
            delta *= r / n * epsilon
        # detach() guarantees a fresh leaf even if X itself tracks gradients.
        delta = clamp(delta, lower_limit - X, upper_limit - X).detach()

        for _ in range(attack_iters):
            delta.requires_grad_(True)
            loss = F.cross_entropy(model(X + delta), y)
            # autograd.grad (rather than backward) leaves model parameter .grad untouched.
            (grad,) = torch.autograd.grad(loss, delta)
            with torch.no_grad():
                if norm == "l_inf":
                    delta = torch.clamp(delta + alpha * torch.sign(grad),
                                        min=-epsilon, max=epsilon)
                else:
                    g_norm = grad.reshape(batch, -1).norm(dim=1).reshape(per_sample)
                    scaled_g = grad / (g_norm + 1e-10)
                    delta = ((delta + scaled_g * alpha)
                             .reshape(batch, -1)
                             .renorm(p=2, dim=0, maxnorm=epsilon)
                             .reshape(delta.shape))
                delta = clamp(delta, lower_limit - X, upper_limit - X)

        # Evaluate the final perturbation once per restart; no graph is needed.
        with torch.no_grad():
            all_loss = F.cross_entropy(model(X + delta), y, reduction='none')
        improved = all_loss >= max_loss
        max_delta[improved] = delta[improved]
        max_loss = torch.max(max_loss, all_loss)
    return max_delta


if __name__ == "__main__":
    # Smoke test for attack_pgd on a tiny linear classifier.
    # Model and data live on `device` so the demo also runs on CUDA machines.
    model = nn.Linear(10, 2).to(device)
    # X values lie in [-1, 1) with shape (4, 10), matching the clamp range.
    X = (torch.rand(4, 10) * 2 - 1).to(device)
    y = torch.randint(0, 2, (4,)).to(device)
    delta = attack_pgd(
        model, X, y,
        epsilon=0.5, alpha=0.01, attack_iters=2, restarts=1,
        norm="l_inf", early_stop=False, mixup=False,
    )
    print(delta)
