import tensorflow as tf
import numpy as np
import torch


class ModelAdapter():
    def __init__(self, logits, x, y, sess, num_classes=10):
        self.logits = logits
        self.sess = sess
        self.x_input = x
        self.y_input = y
        self.num_classes = num_classes

        # gradients of logits
        if num_classes <= 10:
            self.grads = [None] * num_classes
            for cl in range(num_classes):
                self.grads[cl] = tf.gradients(self.logits[:, cl], self.x_input)[0]

        # cross-entropy loss
        self.xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=self.logits, labels=self.y_input)
        self.grad_xent = tf.gradients(self.xent, self.x_input)[0]

        # dlr loss
        self.dlr = dlr_loss(self.logits, self.y_input, num_classes=self.num_classes)
        self.grad_dlr = tf.gradients(self.dlr, self.x_input)[0]

        # targeted dlr loss
        self.y_target = tf.placeholder(tf.int64, shape=[None])
        self.dlr_target = dlr_loss_targeted(self.logits, self.y_input, self.y_target, num_classes=self.num_classes)
        self.grad_target = tf.gradients(self.dlr_target, self.x_input)[0]

        self.la = tf.placeholder(tf.int64, shape=[None])
        self.la_target = tf.placeholder(tf.int64, shape=[None])
        la_mask = tf.one_hot(self.la, self.num_classes)
        la_target_mask = tf.one_hot(self.la_target, self.num_classes)
        la_logit = tf.reduce_sum(la_mask * self.logits, axis=1)
        la_target_logit = tf.reduce_sum(la_target_mask * self.logits, axis=1)
        self.diff_logits = la_target_logit - la_logit
        self.grad_diff_logits = tf.gradients(self.diff_logits, self.x_input)[0]

    def predict(self, x):
        x2 = np.moveaxis(x.cpu().numpy(), 1, 3)
        y = self.sess.run(self.logits, {self.x_input: x2})

        return torch.from_numpy(y).cuda()

    def grad_logits(self, x):
        x2 = np.moveaxis(x.cpu().numpy(), 1, 3)
        logits, g2 = self.sess.run([self.logits, self.grads], {self.x_input: x2})
        g2 = np.moveaxis(np.array(g2), 0, 1)
        g2 = np.transpose(g2, (0, 1, 4, 2, 3))

        return torch.from_numpy(logits).cuda(), torch.from_numpy(g2).cuda()

    def get_grad_diff_logits_target(self, x, y=None, y_target=None):
        la = y.cpu().numpy()
        la_target = y_target.cpu().numpy()
        x2 = np.moveaxis(x.cpu().numpy(), 1, 3)
        dl, g2 = self.sess.run([self.diff_logits, self.grad_diff_logits],
                               {self.x_input: x2, self.la: la, self.la_target: la_target})
        g2 = np.transpose(np.array(g2), (0, 3, 1, 2))

        return torch.from_numpy(dl).cuda(), torch.from_numpy(g2).cuda()

    def get_logits_loss_grad_xent(self, x, y):
        x2 = np.moveaxis(x.cpu().numpy(), 1, 3)
        y2 = y.clone().cpu().numpy()
        logits_val, loss_indiv_val, grad_val = self.sess.run([self.logits, self.xent, self.grad_xent],
                                                             {self.x_input: x2, self.y_input: y2})
        grad_val = np.moveaxis(grad_val, 3, 1)

        return torch.from_numpy(logits_val).cuda(), torch.from_numpy(loss_indiv_val).cuda(), torch.from_numpy(
            grad_val).cuda()

    def get_logits_loss_grad_dlr(self, x, y):
        x2 = np.moveaxis(x.cpu().numpy(), 1, 3)
        y2 = y.clone().cpu().numpy()
        logits_val, loss_indiv_val, grad_val = self.sess.run([self.logits, self.dlr, self.grad_dlr],
                                                             {self.x_input: x2, self.y_input: y2})
        grad_val = np.moveaxis(grad_val, 3, 1)

        return torch.from_numpy(logits_val).cuda(), torch.from_numpy(loss_indiv_val).cuda(), torch.from_numpy(
            grad_val).cuda()

    def get_logits_loss_grad_target(self, x, y, y_target):
        x2 = np.moveaxis(x.cpu().numpy(), 1, 3)
        y2 = y.clone().cpu().numpy()
        y_targ = y_target.clone().cpu().numpy()
        logits_val, loss_indiv_val, grad_val = self.sess.run([self.logits, self.dlr_target, self.grad_target],
                                                             {self.x_input: x2, self.y_input: y2,
                                                              self.y_target: y_targ})
        grad_val = np.moveaxis(grad_val, 3, 1)

        return torch.from_numpy(logits_val).cuda(), torch.from_numpy(loss_indiv_val).cuda(), torch.from_numpy(
            grad_val).cuda()


def dlr_loss(x, y, num_classes=10):
    x_sort = tf.contrib.framework.sort(x, axis=1)
    y_onehot = tf.one_hot(y, num_classes)
    ### TODO: adapt to the case when the point is already misclassified
    loss = -(x_sort[:, -1] - x_sort[:, -2]) / (x_sort[:, -1] - x_sort[:, -3] + 1e-12)

    return loss


def dlr_loss_targeted(x, y, y_target, num_classes=10):
    x_sort = tf.contrib.framework.sort(x, axis=1)
    y_onehot = tf.one_hot(y, num_classes)
    y_target_onehot = tf.one_hot(y_target, num_classes)
    loss = -(tf.reduce_sum(x * y_onehot, axis=1) - tf.reduce_sum(x * y_target_onehot, axis=1)) / (
                x_sort[:, -1] - .5 * x_sort[:, -3] - .5 * x_sort[:, -4] + 1e-12)

    return loss