import torch
import numpy as np
from src import projection
from src import Net,Net2

def eval_accuracy(model, x_test, y_test, device):
    """Return the top-1 classification accuracy of ``model`` on a data set.

    The data is fed through the model in consecutive chunks of at most 1000
    samples so that large test sets do not have to fit on the device at once.

    Args:
        model: Callable that maps a float32 batch tensor to per-class scores;
            the predicted class is the ``argmax`` over dimension 1.
        x_test: Inputs, sliceable along the first axis and convertible with
            ``torch.tensor`` (e.g. a NumPy array).
        y_test: Integer class labels aligned with ``x_test``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        A 0-dim float32 CPU tensor holding ``correct / len(x_test)``.
        The value is NaN when ``x_test`` is empty.
    """
    batch_size = 1000
    n_samples = len(x_test)

    # Accumulate on the CPU as a tensor so the return type does not depend on
    # whether any batch was processed.
    correct = torch.zeros((), dtype=torch.float32)

    # Evaluation never back-propagates; disabling autograd avoids building a
    # graph (and holding activations) for every batch.
    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            stop = min(start + batch_size, n_samples)
            y = torch.tensor(y_test[start:stop], dtype=torch.long).to(device)
            x = torch.tensor(x_test[start:stop], dtype=torch.float32).to(device)
            preds = torch.argmax(model(x), 1)
            correct += torch.sum(torch.eq(preds, y), dtype=torch.float32).cpu()

    return correct / n_samples

def pretrain(model, train_x, train_y, test_x, test_y,
            criterion, beta, G_0, tau, eta_y, rw,  device,full_epochs,
            sigma=0, target=8):
    """Train ``model`` in place and return its accuracy on the test data.

    Each mini-batch (200 shuffled samples) performs one update on the
    flattened parameter vector:

        v  <- (1 - beta) * v + beta * grad
        Z  <- params - tau * v / (sqrt(||v||) + G_0)
        params <- projection.projection_l2_ball(Z, rw)

    Every 10th epoch (starting with epoch 0) the accuracy on the full test
    set and on the test samples whose label equals ``target`` is printed.

    Args:
        model: Network to train; its parameters are overwritten in place.
        train_x, train_y: Training inputs and integer labels (indexed with
            NumPy integer arrays).
        test_x, test_y: Test inputs and integer labels.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        beta: Ignored - see the note where ``beta`` is reassigned below.
        G_0: Constant added to the step normaliser.
        tau: Step size.
        eta_y: Unused by this function.
        rw: Second argument handed to ``projection.projection_l2_ball``
            (presumably the ball radius - confirm in ``src.projection``).
        device: Device the mini-batches are moved to.
        full_epochs: Number of passes over the training data.
        sigma: Unused by this function.
        target: Class label whose test samples are scored separately.

    Returns:
        ``(test_acc, acc8)``: accuracy of the final model on the whole test
        set and on the ``target``-class subset.
    """
    target_idx = np.where(test_y == target)[0]
    X_target = test_x[target_idx]
    Y_target = test_y[target_idx]

    # NOTE(review): this overrides the ``beta`` argument, so the value passed
    # by the caller has no effect. Kept as-is to preserve the existing
    # training behaviour - confirm whether the override is intentional.
    beta = 0.1

    batch_size = 200
    n_train = len(train_x)
    v = 0
    last_eval_epoch = None
    for epoch in range(full_epochs):

        ind_shuf = np.arange(n_train)
        np.random.shuffle(ind_shuf)
        for start in range(0, n_train, batch_size):
            ind_tr = ind_shuf[start:start + batch_size]

            x_tr = torch.tensor(train_x[ind_tr], dtype=torch.float32).to(device)
            y_tr = torch.tensor(train_y[ind_tr], dtype=torch.long).to(device)

            train_outputs = model(x_tr)
            ce_loss_train = criterion(train_outputs, y_tr)
            # The loss is divided by the nominal batch size, also for a
            # shorter final batch.
            # NOTE(review): if ``criterion`` already averages over the batch
            # this scales the loss a second time - confirm.
            g = ce_loss_train / batch_size
            Gy_gradient = torch.autograd.grad(g, list(model.parameters()))
            Gy_gradients = torch.unsqueeze(torch.cat([torch.reshape(gy, [-1]) for gy in Gy_gradient]), 1)
            params = torch.unsqueeze(torch.cat([torch.reshape(param, [-1]) for param in model.parameters()]), 1)
            with torch.no_grad():
                v = (1 - beta) * v + beta * Gy_gradients
                Z = params - tau * v / (torch.sqrt(torch.norm(v)) + G_0)
                Pz = projection.projection_l2_ball(Z, rw)
                load_model_parameters(model, Pz)
                # Variant kept for reference (relaxed step using eta_y):
                #     param_tmp = (1 - eta_y) * params + eta_y * Pz
                #     load_model_parameters(model, param_tmp)

        if epoch % 10 == 0:
            test_acc = eval_accuracy(model, test_x, test_y, device)
            acc8 = eval_accuracy(model, X_target, Y_target, device)
            print("Accuracy gaussaian aug", epoch, target)
            print("Test 8:", acc8)
            print("Test:", test_acc, "\n")
            last_eval_epoch = epoch

    # The periodic report only runs every 10th epoch. Without this final
    # evaluation the returned metrics would describe an older model (or be
    # undefined when no epoch ran at all).
    if last_eval_epoch != full_epochs - 1:
        test_acc = eval_accuracy(model, test_x, test_y, device)
        acc8 = eval_accuracy(model, X_target, Y_target, device)
    return test_acc, acc8


def eval_on_poison_data(train_x, train_y, test_x, test_y, X_poisoned, Y_poisoned,
                        criterion, beta, G_0, tau, eta_y, rw,  device,full_epochs,
                        name, sigma=0, target=8):
    """Train a fresh network on a poisoned training set and score it.

    A new ``Net`` (when ``name == 'mnist'``) or ``Net2`` (otherwise) is
    created and its initial parameters are passed through
    ``projection.projection_l2_ball``. All training samples labelled
    ``target`` are removed and replaced by ``X_poisoned`` / ``Y_poisoned``.
    Each mini-batch (200 shuffled samples) then performs the update

        v  <- (1 - beta) * v + beta * grad
        Z  <- params - tau * v / (sqrt(||v||) + G_0)
        Pz <- projection.projection_l2_ball(Z, rw)
        params <- (1 - eta_y) * params + eta_y * Pz

    Args:
        train_x, train_y: Clean training inputs and integer labels (NumPy
            arrays; they are not modified in place).
        test_x, test_y: Test inputs and integer labels.
        X_poisoned, Y_poisoned: Poisoned samples appended to the training set.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        beta: Ignored - see the note where ``beta`` is reassigned below.
        G_0: Constant added to the step normaliser.
        tau: Step size.
        eta_y: Interpolation weight between the current parameters and the
            projected step.
        rw: Second argument handed to ``projection.projection_l2_ball``
            (presumably the ball radius - confirm in ``src.projection``).
        device: Device for the model and the mini-batches.
        full_epochs: Number of passes over the poisoned training data.
        name: Data set name; selects the network architecture.
        sigma: Unused by this function.
        target: Class label that is replaced in the training set and scored
            separately on the test set.

    Returns:
        ``(test_acc, acc8)``: accuracy of the trained model on the whole test
        set and on the ``target``-class test subset.
    """
    if name == 'mnist':
        model = Net().to(device)
    else:
        model = Net2().to(device)
    params = torch.unsqueeze(torch.cat([torch.reshape(param, [-1]) for param in model.parameters()]), 1)
    Pz = projection.projection_l2_ball(params, rw)
    load_model_parameters(model, Pz)

    # Swap the clean target-class samples for the poisoned ones.
    target_train_idx = np.where(train_y == target)[0]
    train_x = np.delete(train_x, target_train_idx, 0)
    train_y = np.delete(train_y, target_train_idx, 0)
    train_x = np.concatenate([train_x, X_poisoned])
    train_y = np.concatenate([train_y, Y_poisoned])

    target_test_idx = np.where(test_y == target)[0]
    X_target = test_x[target_test_idx]
    Y_target = test_y[target_test_idx]

    batch_size = 200
    n_train = len(train_x)

    v = 0
    # NOTE(review): this overrides the ``beta`` argument, so the value passed
    # by the caller has no effect. Kept as-is to preserve the existing
    # training behaviour - confirm whether the override is intentional.
    beta = 0.1
    # Pre-set so the report below still works when no epoch is run
    # (``full_epochs <= 0``); it then prints -1 as the epoch index.
    epoch = -1
    for epoch in range(full_epochs):

        ind_shuf = np.arange(n_train)
        np.random.shuffle(ind_shuf)
        for start in range(0, n_train, batch_size):
            ind_tr = ind_shuf[start:start + batch_size]
            x_tr = torch.tensor(train_x[ind_tr], dtype=torch.float32).to(device)
            y_tr = torch.tensor(train_y[ind_tr], dtype=torch.long).to(device)
            train_outputs = model(x_tr)
            ce_loss_train = criterion(train_outputs, y_tr)
            Gy_gradient = torch.autograd.grad(ce_loss_train, list(model.parameters()))
            Gy_gradients = torch.unsqueeze(torch.cat([torch.reshape(gy, [-1]) for gy in Gy_gradient]), 1)
            params = torch.unsqueeze(torch.cat([torch.reshape(param, [-1]) for param in model.parameters()]), 1)
            with torch.no_grad():
                v = (1 - beta) * v + beta * Gy_gradients
                Z = params - tau * v / (torch.sqrt(torch.norm(v)) + G_0)
                Pz = projection.projection_l2_ball(Z, rw)
                # Relaxed step: move only a fraction eta_y towards Pz.
                param_tmp = (1 - eta_y) * params + eta_y * Pz
                load_model_parameters(model, param_tmp)

    test_acc = eval_accuracy(model, test_x, test_y, device)
    acc8 = eval_accuracy(model, X_target, Y_target, device)
    print("Accuracy gaussaian aug", epoch, target)
    print("Test 8:", acc8)
    print("Test:", test_acc, "\n")
    return  test_acc, acc8


def out_f(model, val_x, val_y, beta_macer,
          nclass, k_macer, sigma_macer):
    """Compute the robustness term on a batch of Gaussian-perturbed inputs.

    Gaussian noise with standard deviation ``sigma_macer`` is added to
    ``val_x``. The model outputs are reshaped to ``[-1, k_macer, nclass]``,
    scaled by ``beta_macer``, soft-maxed over the class axis and averaged
    over the ``k_macer`` axis. For every sample whose top class equals its
    label, the difference of the standard-normal inverse CDF of the two
    largest (clipped) averaged probabilities is taken. The result is

        sum(sigma_macer / 2 * difference over correct samples) / n_samples

    where ``n_samples`` is the first dimension of the reshaped outputs.

    Args:
        model: Callable mapping the noisy inputs to class scores.
        val_x: Input tensor; presumably holds ``k_macer`` rows per sample so
            that the reshape yields one group per label - confirm with caller.
        val_y: Integer labels, one per reshaped sample.
        beta_macer: Factor applied to the scores before the softmax.
        nclass: Number of classes (size of the last reshaped axis).
        k_macer: Number of model outputs averaged per sample.
        sigma_macer: Noise standard deviation; also scales the result.

    Returns:
        Scalar tensor with the averaged robustness term. It stays attached to
        the autograd graph of ``model`` / ``val_x``.
    """
    noise_inputs = val_x + torch.randn_like(val_x) * sigma_macer
    outputs = torch.reshape(model(noise_inputs), [-1, k_macer, nclass])
    beta_outputs = outputs * beta_macer

    beta_outputs_softmax = torch.mean(torch.nn.functional.softmax(beta_outputs, dim=2), dim=1)
    top2_score, top2_idx = torch.topk(beta_outputs_softmax, 2)
    # Normalise by the actual number of samples rather than a fixed 100 so
    # any batch size works; the result is unchanged for batches of 100.
    n_samples = beta_outputs_softmax.shape[0]

    # Only correctly classified samples contribute.
    correct_mask = torch.eq(top2_idx[:, 0], val_y)

    # Keep the probabilities inside [alpha, 1 - alpha] so that the inverse
    # CDF remains finite.
    alpha = 0.001
    dist = torch.distributions.normal.Normal(loc=0, scale=1)
    margin = (dist.icdf((1 - 2 * alpha) * top2_score[:, 0] + alpha)
              - dist.icdf((1 - 2 * alpha) * top2_score[:, 1] + alpha))
    robustness_loss_1 = torch.masked_select(margin, correct_mask)
    return torch.sum(robustness_loss_1 * sigma_macer * 0.5) / n_samples

def inner_g(model,criterion, u, Y_poisoned, train_x, train_y, sigma_macer):
    """Return the combined loss on noisy poisoned and noisy clean samples.

    Gaussian noise with standard deviation ``sigma_macer`` is added first to
    the poisoned inputs ``u`` and then to ``train_x``. ``criterion`` is
    evaluated on the model outputs of each noisy batch, and the two losses
    are summed and divided by ``len(Y_poisoned) + len(train_y)``.

    Args:
        model: Callable mapping an input batch to outputs for ``criterion``.
        criterion: Callable ``criterion(outputs, labels)`` returning a loss.
        u: Poisoned input tensor.
        Y_poisoned: Labels for ``u``.
        train_x: Clean input tensor.
        train_y: Labels for ``train_x``.
        sigma_macer: Standard deviation of the added noise.

    Returns:
        The summed loss divided by the total number of labels.
    """
    def _noisy_loss(inputs, labels):
        # Draw the noise, perturb the inputs and score the model on them.
        perturbed = inputs + torch.randn_like(inputs) * sigma_macer
        return criterion(model(perturbed), labels)

    total_samples = len(Y_poisoned) + len(train_y)

    # The poisoned batch is processed first so the random draws happen in a
    # fixed order: poison noise, then clean noise.
    poison_loss = _noisy_loss(u, Y_poisoned)
    clean_loss = _noisy_loss(train_x, train_y)
    return (clean_loss + poison_loss) / total_samples

# def eval_accuracy(model, x_test, y_test, add_noise, sigma_macer, device):
#     lenth = np.float32(len(x_test))
#     batch_size = 1000
#     nb_batches = int(len(x_test) / batch_size)
#     if len(x_test) % batch_size != 0:
#         nb_batches += 1
#
#     acc = 0
#     for batch in range(nb_batches):
#         ind_batch = range(batch_size * batch, min(batch_size * (1 + batch), len(x_test)))
#         y = torch.tensor(y_test[ind_batch], dtype=torch.long).to(device)
#         if add_noise:
#             noise = np.random.normal(0, 1, x_test[ind_batch].shape) * sigma_macer
#             x = x_test[ind_batch] + noise
#             x = torch.tensor(x, dtype=torch.float32).to(device)
#             cls_test_noisy = model(x)
#         else:
#             x = x_test[ind_batch]
#             x = torch.tensor(x, dtype=torch.float32).to(device)
#             cls_test_noisy = model(x)
#         preds = torch.argmax(cls_test_noisy, 1)
#         acc += torch.sum(torch.eq(preds, y), dtype=torch.float32).detach().cpu()
#     acc /= lenth
#     return acc


def load_model_parameters(model, parameters_old):
    """Overwrite the parameters of ``model`` with values from a flat tensor.

    ``parameters_old`` is consumed front to back in the order given by
    ``model.parameters()``: each parameter takes as many leading entries as
    it has elements, reshaped to the parameter's own shape.

    Args:
        model: Module whose parameters are replaced through ``.data``.
        parameters_old: Tensor holding the new values, sliced along its
            first axis.

    Returns:
        The same ``model`` object, now carrying the new parameter values.
    """
    cursor = 0
    for weight in model.parameters():
        count = weight.numel()
        chunk = parameters_old[cursor:cursor + count]
        weight.data = torch.reshape(chunk, weight.shape)
        cursor += count
    return model