import numpy as np
import time
import torch
#import scipy.io

#import numpy.linalg as nl

#
import os
import sys

import torch.nn as nn
import torch.nn.functional as F

def dense_to_onehot(y_test, n_cls):
    """Convert integer class labels into a boolean one-hot matrix.

    Args:
        y_test: sequence of integer class indices, one per sample.
        n_cls: number of classes (width of the returned matrix).

    Returns:
        Boolean array of shape ``(len(y_test), n_cls)`` where entry
        ``[i, y_test[i]]`` is True and every other entry is False.
    """
    n_samples = len(y_test)
    onehot = np.zeros((n_samples, n_cls), dtype=bool)
    row_ids = np.arange(n_samples)
    onehot[row_ids, y_test] = True
    return onehot

def softmax(x):
    """Row-wise softmax of a 2-D array of scores.

    The per-row maximum is subtracted before exponentiating so that large
    scores do not overflow; this does not change the result.

    Args:
        x: array of shape ``(n_rows, n_cols)``.

    Returns:
        Array of the same shape whose rows are non-negative and sum to 1.
    """
    row_max = np.max(x, axis=1, keepdims=True)
    exp_shifted = np.exp(x - row_max)
    normaliser = exp_shifted.sum(axis=1, keepdims=True)
    return exp_shifted / normaliser

def get_loss(y, logits, targeted=False, loss_type='margin_loss'):
    """Compute a per-example attack loss from logits.

    Two losses are supported:

    * ``'margin_loss'``: logit of the class marked in ``y`` minus the largest
      other logit. Negated when ``targeted`` is True.
    * ``'cross_entropy'``: ``-log softmax(logits)`` at the class marked in
      ``y``. Negated when ``targeted`` is False, i.e. the untargeted value is
      the log-probability of the marked class.

    Args:
        y: boolean one-hot mask with the same shape as ``logits``. It is used
            both as a multiplicative mask and as a boolean index, so it must
            be of dtype bool.
        logits: array of shape ``(n_examples, n_classes)``; not modified.
        targeted: selects the sign convention described above.
        loss_type: ``'margin_loss'`` or ``'cross_entropy'``.

    Returns:
        1-D array with one loss value per example.

    Raises:
        ValueError: if ``loss_type`` is not one of the two supported names.
    """
    if loss_type not in ('margin_loss', 'cross_entropy'):
        raise ValueError('Wrong loss.')

    if loss_type == 'margin_loss':
        marked_logit = (logits * y).sum(1, keepdims=True)
        gaps = marked_logit - logits  # gap between the marked class and every class
        gaps[y] = np.inf  # drop the trivial zero gap of the marked class with itself
        values = gaps.min(1, keepdims=True)
        if targeted:
            values = values * -1
    else:
        values = -np.log(softmax(logits)[y])
        if not targeted:
            values = values * -1

    return values.flatten()


def p_selection(p_init, it, n_iters):
    """Piece-wise constant schedule for p (the fraction of pixels changed per iteration).

    The iteration index is first rescaled to a nominal budget of 10000
    iterations; the rescaled value selects how much ``p_init`` is divided by.

    Args:
        p_init: initial fraction of pixels to change.
        it: current iteration index.
        n_iters: total number of iterations the attack will run.

    Returns:
        ``p_init`` divided by the divisor of the interval the rescaled
        iteration falls in, or ``p_init`` itself when it falls in none
        (rescaled value <= 10 or > 10000).
    """
    scaled_it = int(it / n_iters * 10000)

    # (exclusive lower bound, inclusive upper bound, divisor)
    schedule = (
        (10, 50, 2),
        (50, 200, 4),
        (200, 500, 8),
        (500, 1000, 16),
        (1000, 2000, 32),
        (2000, 4000, 64),
        (4000, 6000, 128),
        (6000, 8000, 256),
        (8000, 10000, 512),
    )
    for lower, upper, divisor in schedule:
        if lower < scaled_it <= upper:
            return p_init / divisor

    return p_init

def square_attack_linf_adapt(GradOracle, x, y, eps, n_iters, p_init, device, loss_type='cross_entropy', update_times=5):
    """Run the Linf Square Attack against a model that is re-fetched during the attack.

    Starting from a random vertical-stripe perturbation, every iteration
    re-samples the perturbation of one random square window per image and keeps
    the candidate for the images whose loss decreased. The attacked model is
    obtained from ``GradOracle.get_model(x_best, y, update=True)`` before the
    first query and again every ``n_iters // update_times`` iterations (at
    least every iteration); after each refresh the reference loss of the
    current best images is recomputed with the new model.

    Args:
        GradOracle: object exposing ``get_model(x, y, update=True)`` that
            returns a callable mapping a float32 batch on ``device`` to logits.
            What ``update=True`` does is up to the oracle — presumably it
            adapts the model to the passed batch; confirm against its
            implementation.
        x: tensor of images with shape ``(n, c, h, w)``. Values are clipped to
            ``[0, 1]`` if ``x.max() <= 1`` and to ``[0, 255]`` otherwise.
            ``h`` and ``w`` are expected to be at least 2.
        y: tensor of integer class labels, one per image.
        eps: Linf radius; every perturbed pixel is moved by ``-eps`` or
            ``+eps`` before clipping. Must be non-zero when ``n_iters > 1``,
            otherwise the window re-sampling loop cannot make progress.
        n_iters: total number of attack queries, including the initial one.
        p_init: initial fraction of pixels changed per iteration
            (see ``p_selection``).
        device: torch device the model inputs and the result are placed on.
        loss_type: loss minimised by the search, forwarded to ``get_loss``.
        update_times: approximate number of model refreshes over the run.

    Returns:
        Tuple ``(acc, x_adv)``: ``acc`` is the fraction of images whose
        recorded margin loss is still positive, ``x_adv`` is the float32
        tensor of best perturbed images on ``device``.
    """
    np.random.seed(0)  # important to leave it here as well
    y_torch = y
    x = x.detach().cpu().numpy()
    y = y.detach().cpu().numpy()
    min_val = 0
    max_val = 1 if x.max() <= 1 else 255
    c, h, w = x.shape[1:]
    n_features = c * h * w
    n_ex_total = x.shape[0]
    # The window must fit along both spatial axes, otherwise randint() below
    # is asked for an empty range on non-square images.
    max_side = min(h, w) - 1
    # Guard against n_iters < update_times, which would make the period 0.
    update_every = max(n_iters // update_times, 1)

    def to_tensor(batch):
        return torch.as_tensor(batch, device=device).to(torch.float32)

    def query(net, batch):
        # detach() so that outputs tracking gradients can be converted to numpy.
        return net(to_tensor(batch)).detach().cpu().numpy()

    # [c, 1, w], i.e. vertical stripes work best for untargeted attacks
    init_delta = np.random.choice([-eps, eps], size=[n_ex_total, c, 1, w])
    x_best = np.clip(x + init_delta, min_val, max_val)
    model = GradOracle.get_model(to_tensor(x_best), y_torch, update=True)
    logits = query(model, x_best)
    # Number of classes is taken from the model output instead of being fixed.
    y = dense_to_onehot(y, logits.shape[1])
    loss_min = get_loss(y, logits, loss_type=loss_type)
    margin_min = get_loss(y, logits, loss_type='margin_loss')

    for i_iter in range(n_iters - 1):
        deltas = x_best - x

        p = p_selection(p_init, i_iter, n_iters)
        s = int(round(np.sqrt(p * n_features / c)))
        s = min(max(s, 1), max_side)  # at least a c x 1 x 1 window, at most one pixel short of the image
        for i_img in range(n_ex_total):
            center_h = np.random.randint(0, h - s)
            center_w = np.random.randint(0, w - s)
            rows = slice(center_h, center_h + s)
            cols = slice(center_w, center_w + s)

            x_window = x[i_img, :, rows, cols]
            x_best_window = x_best[i_img, :, rows, cols]
            # prevent trying out a delta if it doesn't change x (e.g. an overlapping patch)
            while np.all(np.abs(np.clip(x_window + deltas[i_img, :, rows, cols], min_val, max_val)
                                - x_best_window) < 1e-7):
                deltas[i_img, :, rows, cols] = np.random.choice([-eps, eps], size=[c, 1, 1])

        x_new = np.clip(x + deltas, min_val, max_val)

        if (i_iter + 1) % update_every == 0:
            model = GradOracle.get_model(to_tensor(x_best), y_torch, update=True)
            # Re-score the current best images so the comparison below uses the refreshed model.
            loss_min = get_loss(y, query(model, x_best), loss_type=loss_type)
            # NOTE(review): margin_min is not re-evaluated here, so the returned
            # accuracy can mix margins measured under different model versions.

        logits = query(model, x_new)
        loss = get_loss(y, logits, loss_type=loss_type)
        margin = get_loss(y, logits, loss_type='margin_loss')

        # np.where instead of mask arithmetic: 0 * inf would turn infinite losses into NaN.
        idx_improved = loss < loss_min
        loss_min = np.where(idx_improved, loss, loss_min)
        margin_min = np.where(idx_improved, margin, margin_min)
        x_best = np.where(idx_improved.reshape(-1, 1, 1, 1), x_new, x_best)

    # Computed after the loop so it is defined even when no iteration ran.
    acc = (margin_min > 0.0).sum() / n_ex_total

    return acc, to_tensor(x_best)


def square_attack_linf(model, x, y, eps, n_iters, p_init, device, loss_type='cross_entropy'):
    """Run the Linf Square Attack (random search over square patches) on a fixed model.

    Starting from a random vertical-stripe perturbation, every iteration
    re-samples the perturbation of one random square window per image and keeps
    the candidate for the images whose loss decreased.

    Args:
        model: callable mapping a float32 batch on ``device`` to a tensor of
            logits with shape ``(n, n_classes)``.
        x: tensor of images with shape ``(n, c, h, w)``. Values are clipped to
            ``[0, 1]`` if ``x.max() <= 1`` and to ``[0, 255]`` otherwise.
            ``h`` and ``w`` are expected to be at least 2.
        y: tensor of integer class labels, one per image.
        eps: Linf radius; every perturbed pixel is moved by ``-eps`` or
            ``+eps`` before clipping. Must be non-zero when ``n_iters > 1``,
            otherwise the window re-sampling loop cannot make progress.
        n_iters: total number of model queries, including the initial one.
        p_init: initial fraction of pixels changed per iteration
            (see ``p_selection``).
        device: torch device the model inputs and the result are placed on.
        loss_type: loss minimised by the search, forwarded to ``get_loss``.

    Returns:
        Tuple ``(acc, x_adv)``: ``acc`` is the fraction of images whose
        recorded margin loss is still positive, ``x_adv`` is the float32
        tensor of best perturbed images on ``device``.
    """
    np.random.seed(0)  # important to leave it here as well
    x = x.detach().cpu().numpy()
    y = y.detach().cpu().numpy()
    min_val = 0
    max_val = 1 if x.max() <= 1 else 255
    c, h, w = x.shape[1:]
    n_features = c * h * w
    n_ex_total = x.shape[0]
    # The window must fit along both spatial axes, otherwise randint() below
    # is asked for an empty range on non-square images.
    max_side = min(h, w) - 1

    def to_tensor(batch):
        return torch.as_tensor(batch, device=device).to(torch.float32)

    def query(batch):
        # detach() so that outputs tracking gradients can be converted to numpy.
        return model(to_tensor(batch)).detach().cpu().numpy()

    # [c, 1, w], i.e. vertical stripes work best for untargeted attacks
    init_delta = np.random.choice([-eps, eps], size=[n_ex_total, c, 1, w])
    x_best = np.clip(x + init_delta, min_val, max_val)
    logits = query(x_best)
    # Number of classes is taken from the model output instead of being fixed.
    y = dense_to_onehot(y, logits.shape[1])
    loss_min = get_loss(y, logits, loss_type=loss_type)
    margin_min = get_loss(y, logits, loss_type='margin_loss')

    for i_iter in range(n_iters - 1):
        deltas = x_best - x

        p = p_selection(p_init, i_iter, n_iters)
        s = int(round(np.sqrt(p * n_features / c)))
        s = min(max(s, 1), max_side)  # at least a c x 1 x 1 window, at most one pixel short of the image
        for i_img in range(n_ex_total):
            center_h = np.random.randint(0, h - s)
            center_w = np.random.randint(0, w - s)
            rows = slice(center_h, center_h + s)
            cols = slice(center_w, center_w + s)

            x_window = x[i_img, :, rows, cols]
            x_best_window = x_best[i_img, :, rows, cols]
            # prevent trying out a delta if it doesn't change x (e.g. an overlapping patch)
            while np.all(np.abs(np.clip(x_window + deltas[i_img, :, rows, cols], min_val, max_val)
                                - x_best_window) < 1e-7):
                deltas[i_img, :, rows, cols] = np.random.choice([-eps, eps], size=[c, 1, 1])

        x_new = np.clip(x + deltas, min_val, max_val)

        logits = query(x_new)
        loss = get_loss(y, logits, loss_type=loss_type)
        margin = get_loss(y, logits, loss_type='margin_loss')

        # np.where instead of mask arithmetic: 0 * inf would turn infinite losses into NaN.
        idx_improved = loss < loss_min
        loss_min = np.where(idx_improved, loss, loss_min)
        margin_min = np.where(idx_improved, margin, margin_min)
        x_best = np.where(idx_improved.reshape(-1, 1, 1, 1), x_new, x_best)

    # Computed after the loop so it is defined even when no iteration ran.
    acc = (margin_min > 0.0).sum() / n_ex_total

    return acc, to_tensor(x_best)