import torch
import torch.nn as nn
from tqdm import trange
from src.attacks import pgd_rand
from src.context import ctx_noparamgrad_and_eval

def test_clean(loader, model, device):
    total_loss, total_correct = 0., 0.
    for X,y in loader:
        model.eval()
        X,y = X.to(device), y.to(device)
        with torch.no_grad():
            yp = model(X)
            loss = nn.CrossEntropyLoss()(yp,y)
        
        total_correct += (yp.argmax(dim = 1) == y).sum().item()
        total_loss += loss.item() * X.shape[0]
        
    test_acc = total_correct / len(loader.dataset) * 100
    test_loss = total_loss / len(loader.dataset)
    return test_acc, test_loss

def test_adv(loader, model, attack, param, device):
    """Evaluate ``model`` on adversarial examples crafted against ``model`` itself.

    For every batch a fresh attack object is built with ``attack(**param)``.
    Its ``generate(model, X, y)`` method supplies a perturbation ``delta``, and
    the model is then scored on ``X + delta``.

    Args:
        loader: iterable of ``(X, y)`` batches.
        model: classifier scored with ``argmax(dim=1)`` and cross-entropy.
        attack: attack class/factory; called as ``attack(**param)``.
        param: keyword arguments forwarded to ``attack``.
        device: device each batch is moved to.

    Returns:
        ``(test_acc, test_loss)``: adversarial accuracy in percent and mean
        cross-entropy loss, averaged over the samples actually yielded by
        ``loader``.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    criterion = nn.CrossEntropyLoss()
    total_loss, total_correct, n_seen = 0., 0., 0
    for X, y in loader:
        model.eval()
        X, y = X.to(device), y.to(device)
        # Context manager from src.context; by its name it presumably freezes
        # parameter grads and sets eval mode while the attack runs — see there.
        with ctx_noparamgrad_and_eval(model):
            delta = attack(**param).generate(model, X, y)
        with torch.no_grad():
            yp = model(X + delta)
            loss = criterion(yp, y)

        batch_size = X.shape[0]
        total_correct += (yp.argmax(dim=1) == y).sum().item()
        # criterion returns the batch mean; rescale to a per-sample sum.
        total_loss += loss.item() * batch_size
        n_seen += batch_size

    # Normalize by samples actually evaluated, not len(loader.dataset), so the
    # result is right with drop_last=True or a subsetting sampler.
    test_acc = total_correct / n_seen * 100
    test_loss = total_loss / n_seen
    return test_acc, test_loss

def test_transfer_adv(loader, transferred_model, attacked_model, attack, param, device):
    """Evaluate a transfer attack: perturb using one model, score another.

    For every batch a fresh attack is built with ``attack(**param)``. Its
    ``generate`` method runs against ``transferred_model`` (the surrogate), and
    ``attacked_model`` is then scored on ``X + delta``.

    Args:
        loader: iterable of ``(X, y)`` batches.
        transferred_model: model the perturbation is generated from.
        attacked_model: model whose accuracy/loss is reported.
        attack: attack class/factory; called as ``attack(**param)``.
        param: keyword arguments forwarded to ``attack``.
        device: device each batch is moved to.

    Returns:
        ``(test_acc, test_loss)``: accuracy of ``attacked_model`` in percent and
        its mean cross-entropy loss, averaged over the samples actually
        yielded by ``loader``.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    criterion = nn.CrossEntropyLoss()
    total_loss, total_correct, n_seen = 0., 0., 0
    for X, y in loader:
        transferred_model.eval()
        attacked_model.eval()
        X, y = X.to(device), y.to(device)
        # Context manager from src.context; by its name it presumably freezes
        # parameter grads and sets eval mode on the surrogate — see there.
        with ctx_noparamgrad_and_eval(transferred_model):
            delta = attack(**param).generate(transferred_model, X, y)
        with torch.no_grad():
            yp = attacked_model(X + delta)
            loss = criterion(yp, y)

        batch_size = X.shape[0]
        total_correct += (yp.argmax(dim=1) == y).sum().item()
        # criterion returns the batch mean; rescale to a per-sample sum.
        total_loss += loss.item() * batch_size
        n_seen += batch_size

    # Normalize by samples actually evaluated, not len(loader.dataset), so the
    # result is right with drop_last=True or a subsetting sampler.
    test_acc = total_correct / n_seen * 100
    test_loss = total_loss / n_seen
    return test_acc, test_loss


def get_probs(model, x, y):
    """Return, for each row of ``x``, the softmax probability of its label.

    Args:
        model: callable returning 2-D scores of shape ``(batch, n_classes)``
            (softmax is taken over dim 1).
        x: model input.
        y: per-row class indices, shape ``(batch,)``.

    Returns:
        1-D tensor of shape ``(batch,)`` holding ``softmax(model(x))[i, y[i]]``.
        It is detached from autograd.
    """
    # Pure query: no gradients are needed, so don't build an autograd graph.
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)
    # Pick element (i, y[i]) directly instead of materializing the B x B
    # matrix probs[:, y] and taking its diagonal.
    rows = torch.arange(probs.size(0), device=probs.device)
    return probs[rows, y]

# Untargeted SimBA (pixel basis) for a single input image.
def simba_single(model, x, y, device, num_iters=3072, epsilon=8/255):
    """Greedy black-box attack on one input, querying only output probabilities.

    Coordinates of ``x`` are visited in a random order. For each one, the
    function first tries ``x - epsilon`` on that coordinate, then
    ``x + epsilon``. A step is kept when it lowers the probability of class
    ``y``, as reported by ``get_probs``. Every accepted candidate is clamped
    to ``[0, 1]``.

    Args:
        model: classifier queried through ``get_probs``.
        x: input tensor with a leading batch dimension of size 1. The
            probability comparisons below only work for a single sample.
        y: label tensor for that sample, presumably shape ``(1,)``.
        device: device on which the one-hot perturbation is allocated.
        num_iters: maximum number of coordinates to try. It is capped at the
            number of elements in ``x``, since each coordinate is tried once.
        epsilon: step size applied to a single coordinate.

    Returns:
        The (possibly) perturbed input, with the same shape as ``x``.
    """
    n_dims = x.view(1, -1).size(1)
    perm = torch.randperm(n_dims)
    # Each coordinate is tried at most once; asking for more iterations than
    # coordinates would previously index past the end of ``perm``.
    num_iters = min(num_iters, n_dims)
    # Only forward queries are made, so skip autograd bookkeeping.
    with torch.no_grad():
        last_prob = get_probs(model, x, y)
        for i in range(num_iters):
            diff = torch.zeros(n_dims, device=device)
            diff[perm[i]] = epsilon
            step = diff.view(x.size())

            x_left = (x - step).clamp(0, 1)
            left_prob = get_probs(model, x_left, y)
            if left_prob < last_prob:
                x, last_prob = x_left, left_prob
                continue

            x_right = (x + step).clamp(0, 1)
            right_prob = get_probs(model, x_right, y)
            if right_prob < last_prob:
                x, last_prob = x_right, right_prob
    return x

def _iter_single_samples(loader, device):
    """Yield ``(x, y)`` pairs of batch size 1 from ``loader``, moved to ``device``."""
    for X, y in loader:
        X, y = X.to(device), y.to(device)
        # split(1) keeps the leading batch dim; simba_single needs exactly one
        # sample per call because it compares single-element probabilities.
        yield from zip(X.split(1), y.split(1))


def test_simba(loader, model, device, num_iters=3072, epsilon=8/255, max_samples=1000):
    """Measure accuracy of ``model`` under the SimBA attack on up to ``max_samples`` inputs.

    Every sample from ``loader`` is attacked individually with
    ``simba_single``, whatever the loader's batch size. The perturbed input is
    then classified with ``argmax(dim=1)``. A ``tqdm`` bar shows the running
    accuracy.

    Args:
        loader: iterable of ``(X, y)`` batches (any batch size).
        model: classifier under attack.
        device: device samples are moved to.
        num_iters: forwarded to ``simba_single``.
        epsilon: forwarded to ``simba_single``.
        max_samples: number of samples to attack before stopping.

    Returns:
        ``(test_acc, 0)``. ``test_acc`` is a fraction in ``[0, 1]``, NOT a
        percentage like ``test_clean``/``test_adv``; the scale is kept for
        existing callers. The loss is not computed, so the second element is
        always 0.

    Raises:
        ZeroDivisionError: if no samples are attacked (empty loader or
            ``max_samples <= 0``).
    """
    from itertools import islice

    model.eval()
    total_correct, total_tested = 0, 0
    samples = islice(_iter_single_samples(loader, device), max_samples)
    with trange(max_samples) as t:
        for x, y in samples:
            x_adv = simba_single(model, x, y, device, num_iters=num_iters, epsilon=epsilon)
            with torch.no_grad():
                yp = model(x_adv)
            total_correct += (yp.argmax(dim=1) == y).sum().item()
            total_tested += 1

            t.set_postfix(acc = '{0:.2f}%'.format(total_correct/total_tested*100))
            t.update()

    test_acc = total_correct / total_tested
    return test_acc, 0