import numpy as np
import torch
from scipy.special import softmax
from torch.autograd import Function, gradcheck
from torch.autograd.function import once_differentiable

# Keys naming the four Eisner chart item types. EisnerSurrogate maps each key
# to an index along axis 0 of its chart arrays (see its item map).
RIGHT_C, LEFT_C = 'right_complete', 'left_complete'
RIGHT_I, LEFT_I = 'right_incomplete', 'left_incomplete'


def eisner_surrogate(arc_scores, hard=False):
    """Apply the ``EisnerSurrogate`` autograd function to a matrix of arc scores.

    Args:
        arc_scores: score matrix forwarded unchanged to ``EisnerSurrogate``.
        hard: forwarded as the ``hard`` flag (one-hot argmax split weights
            instead of softmax ones).

    Returns:
        Whatever ``EisnerSurrogate.apply`` returns for these arguments.
    """
    surrogate_fn = EisnerSurrogate.apply
    return surrogate_fn(arc_scores, hard)


class EisnerSurrogate(Function):
    """Differentiable surrogate of Eisner's chart algorithm.

    ``forward`` runs an inside pass over the Eisner chart in which every ``max``
    over split points is replaced by a softmax-weighted sum (or by a one-hot
    argmax when ``hard`` is true). It then propagates soft counts top-down from
    the ``RIGHT_C`` item spanning ``0..n-1`` to produce a soft arc matrix.
    ``backward`` is a hand-written reverse-mode pass through both steps. It runs
    in NumPy, so the function is only once differentiable.

    Chart layout shared by all helpers:
    - Axis 0 of every chart array indexes the item type via ``_item_map()``.
    - ``cw_arr[t, i, j]`` is the value of item ``t`` over span ``i..j``.
    - ``w_arr[t, i, j, k]`` and ``bp_arr[t, i, j, k]`` hold the score and the
      (soft) backpointer weight of split point ``k`` for that item.
    """

    @staticmethod
    def _item_map():
        """Return the mapping from item-type key to its index along chart axis 0."""
        return {RIGHT_C: 0,
                LEFT_C: 1,
                RIGHT_I: 2,
                LEFT_I: 3}

    @staticmethod
    def _relaxed_argmax(scores, hard):
        """Return split-point weights for ``scores``.

        The weights are one-hot at the argmax when ``hard`` is true, and
        ``softmax(scores)`` otherwise.
        """
        if hard:
            weights = np.zeros_like(scores)
            weights[np.argmax(scores)] = 1.
            return weights
        return softmax(scores)

    @staticmethod
    def _inside(n, weights, cw_arr, bp_arr, w_arr, item_map, hard):
        """Fill the chart bottom-up, by span length, with relaxed Eisner scores.

        For every span ``i..j`` and item type:
        - the candidate scores over split points are written to ``w_arr``;
        - their relaxed argmax weights are written to ``bp_arr``;
        - the expected score ``dot(bp, w)`` is written to ``cw_arr``.
        Incomplete items also add an arc weight: ``weights[i, j]`` for
        ``RIGHT_I`` and ``weights[j, i]`` for ``LEFT_I``.
        ``cw_arr``, ``bp_arr`` and ``w_arr`` are modified in place and are
        expected to start zeroed.
        """
        for l in range(1, n):
            for i in range(n - l):
                j = i + l

                # left to right incomplete
                w_right_i_i_j = \
                    cw_arr[item_map[RIGHT_C], i, i: j] \
                    + cw_arr[item_map[LEFT_C], i + 1: j + 1, j]

                bp_right_i_i_j = EisnerSurrogate._relaxed_argmax(w_right_i_i_j, hard)

                right_i_term = np.dot(bp_right_i_i_j, w_right_i_i_j)
                w_arr[item_map[RIGHT_I], i, j, i: j] = w_right_i_i_j
                bp_arr[item_map[RIGHT_I], i, j, i: j] = bp_right_i_i_j
                cw_arr[item_map[RIGHT_I], i, j] = weights[i, j] + right_i_term

                # right to left incomplete: same split scores/weights as the
                # right incomplete item; only the arc weight differs.
                w_arr[item_map[LEFT_I], i, j, i: j] = w_right_i_i_j
                bp_arr[item_map[LEFT_I], i, j, i: j] = bp_right_i_i_j
                cw_arr[item_map[LEFT_I], i, j] = weights[j, i] + right_i_term

                # left to right complete (reads RIGHT_I[i, j] computed just above)
                w_right_c_i_j = \
                    cw_arr[item_map[RIGHT_I], i, i + 1: j + 1] \
                    + cw_arr[item_map[RIGHT_C], i + 1: j + 1, j]

                bp_right_c_i_j = EisnerSurrogate._relaxed_argmax(w_right_c_i_j, hard)

                w_arr[item_map[RIGHT_C], i, j, i + 1: j + 1] = w_right_c_i_j
                bp_arr[item_map[RIGHT_C], i, j, i + 1:j + 1] = bp_right_c_i_j
                cw_arr[item_map[RIGHT_C], i, j] = np.dot(bp_right_c_i_j, w_right_c_i_j)

                # right to left complete (reads LEFT_I[i, j] computed just above)
                w_left_c_i_j = \
                    cw_arr[item_map[LEFT_C], i, i: j] \
                    + cw_arr[item_map[LEFT_I], i: j, j]

                bp_left_c_i_j = EisnerSurrogate._relaxed_argmax(w_left_c_i_j, hard)

                w_arr[item_map[LEFT_C], i, j, i: j] = w_left_c_i_j
                bp_arr[item_map[LEFT_C], i, j, i:j] = bp_left_c_i_j
                cw_arr[item_map[LEFT_C], i, j] = np.dot(bp_left_c_i_j, w_left_c_i_j)

    @staticmethod
    def _backptr(n, soft_cw_arr, bp_arr, item_map):
        """Propagate soft counts top-down through the backpointer weights.

        Seeds ``RIGHT_C`` over ``0..n-1`` with 1. Each item then passes its
        count, scaled by ``bp_arr``, to the two sub-items of every split point.
        Spans are visited longest-first. Within a span, complete items are
        processed before the shared incomplete update, because the complete
        items feed the same-span incomplete items. ``soft_cw_arr`` is modified
        in place and must start zeroed.
        """
        if n == 0:
            # Empty chart: there is no full-span item to seed.
            return

        soft_cw_arr[item_map[RIGHT_C], 0, n - 1] = 1

        for l in range(n - 1, 0, -1):
            for i in range(0, n - l):
                j = i + l

                right_c_term = soft_cw_arr[item_map[RIGHT_C], i, j] * bp_arr[item_map[RIGHT_C], i, j,
                                                                        i + 1: j + 1]
                soft_cw_arr[item_map[RIGHT_I], i, i + 1: j + 1] += right_c_term
                soft_cw_arr[item_map[RIGHT_C], i + 1: j + 1, j] += right_c_term

                left_c_term = soft_cw_arr[item_map[LEFT_C], i, j] * bp_arr[item_map[LEFT_C], i, j, i: j]
                soft_cw_arr[item_map[LEFT_C], i, i: j] += left_c_term
                soft_cw_arr[item_map[LEFT_I], i: j, j] += left_c_term

                # Both incomplete items share the same split weights, so their
                # counts are pooled before being pushed down.
                update_term = (soft_cw_arr[item_map[LEFT_I], i, j] + soft_cw_arr[item_map[RIGHT_I], i, j]) \
                              * bp_arr[item_map[RIGHT_I], i, j, i: j]

                soft_cw_arr[item_map[RIGHT_C], i, i: j] += update_term
                soft_cw_arr[item_map[LEFT_C], i + 1: j + 1, j] += update_term

    @staticmethod
    def forward(ctx, input, hard=False):
        """Return the relaxed arc matrix for a square score matrix.

        Args:
            input: 2-D square tensor of scores. ``input[i, j]`` is the arc weight
                used for the incomplete item built over ``i..j`` (``RIGHT_I`` if
                ``i < j``, ``LEFT_I`` over ``j..i`` if ``j < i``).
            hard: if true, use one-hot argmax split weights instead of softmax.

        Returns:
            Tensor with the shape, dtype and device of ``input``. Entry ``[i, j]``
            is the propagated soft count of the item that used ``input[i, j]``.
            Column 0 and the diagonal are always zero.

        Raises:
            ValueError: if ``input`` is not a square 2-D tensor.
        """
        if input.dim() != 2 or input.shape[0] != input.shape[1]:
            raise ValueError(
                f'EisnerSurrogate expects a square 2-D score matrix, got shape {tuple(input.shape)}')

        item_map = EisnerSurrogate._item_map()
        n_item_types = len(item_map)

        w = input.cpu().detach().numpy()
        n = w.shape[-1]

        # w_arr / bp_arr: (item type, start, end, split point)
        # cw_arr / soft_cw_arr: (item type, start, end)
        w_bp_arr_dim = (n_item_types,) + (n,) * 3
        cw_arr_dim = (n_item_types,) + (n,) * 2

        w_arr = np.zeros(shape=w_bp_arr_dim, dtype=np.float64)
        bp_arr = np.zeros_like(w_arr)
        cw_arr = np.zeros(shape=cw_arr_dim, dtype=np.float64)
        soft_cw_arr = np.zeros_like(cw_arr)

        EisnerSurrogate._inside(n, w, cw_arr, bp_arr, w_arr, item_map, hard)
        EisnerSurrogate._backptr(n, soft_cw_arr, bp_arr, item_map)

        d_tree = np.zeros_like(w)
        for i in range(n):
            # Column 0 is never filled, so no arc ever points into index 0.
            for j in range(1, n):
                if i < j:
                    d_tree[i, j] = soft_cw_arr[item_map[RIGHT_I], i, j]
                elif j < i:
                    d_tree[i, j] = soft_cw_arr[item_map[LEFT_I], j, i]

        # backward needs these NumPy charts. save_for_backward only accepts
        # tensors, so they are kept as plain ctx attributes.
        ctx.soft_cw = soft_cw_arr
        ctx.bp = bp_arr
        ctx.w = w_arr

        return torch.as_tensor(d_tree, dtype=input.dtype, device=input.device)

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Reverse-mode gradient of ``forward`` with respect to ``input``.

        The computation is done in NumPy and is not itself differentiable.
        ``@once_differentiable`` makes double backward raise instead of
        silently returning wrong higher-order gradients.

        Returns:
            ``(grad_input, None)``. ``grad_input`` has the dtype and device of
            ``grad_output``; ``None`` is the gradient for the ``hard`` flag.
            Column 0 and the diagonal of ``grad_input`` are always zero.
        """
        soft_cw_arr, bp_arr, w_arr = ctx.soft_cw, ctx.bp, ctx.w

        item_map = EisnerSurrogate._item_map()
        n_item_types = len(item_map)

        np_grad_output = grad_output.cpu().detach().numpy()
        n = np_grad_output.shape[-1]

        w_bp_arr_dim = (n_item_types,) + (n,) * 3
        cw_arr_dim = (n_item_types,) + (n,) * 2

        g_w_arr = np.zeros(shape=w_bp_arr_dim, dtype=np.float64)
        g_bp_arr = np.zeros_like(g_w_arr)
        g_cw_arr = np.zeros(shape=cw_arr_dim, dtype=np.float64)
        g_soft_cw_arr = np.zeros_like(g_cw_arr)

        # Inverse of the chart -> d_tree mapping done at the end of forward.
        for i in range(n):
            for j in range(1, n):
                if i < j:
                    g_soft_cw_arr[item_map[RIGHT_I], i, j] = np_grad_output[i, j]
                elif j < i:
                    g_soft_cw_arr[item_map[LEFT_I], j, i] = np_grad_output[i, j]

        EisnerSurrogate._backward_backptr(n, g_soft_cw_arr, g_bp_arr, soft_cw_arr, bp_arr, item_map)
        EisnerSurrogate._backward_inside(n, g_cw_arr, g_bp_arr, g_w_arr, bp_arr, w_arr, item_map)
        np_grad_input = np.zeros_like(np_grad_output)

        # The incomplete items' gradients are exactly the gradients of the arc
        # weights, which enter the chart additively in _inside.
        for i in range(n):
            for j in range(1, n):
                if i < j:
                    np_grad_input[i, j] = g_cw_arr[item_map[RIGHT_I], i, j]
                elif j < i:
                    np_grad_input[i, j] = g_cw_arr[item_map[LEFT_I], j, i]

        grad_input = torch.as_tensor(np_grad_input, dtype=grad_output.dtype, device=grad_output.device)
        return grad_input, None

    @staticmethod
    def _backward_backptr(n, g_soft_cw_arr, g_bp_arr, soft_cw_arr, bp_arr, item_map):
        """Reverse of ``_backptr``.

        Accumulates gradients with respect to the soft counts into
        ``g_soft_cw_arr`` and with respect to the backpointer weights into
        ``g_bp_arr``. Spans are visited shortest-first, undoing ``_backptr``'s
        order. Within a span the statements are reversed: the incomplete
        update first, then left complete, then right complete. The
        ``g_bp_arr`` gradient of the shared incomplete split weights is kept
        separately for ``LEFT_I`` and ``RIGHT_I``; ``_backward_inside`` sums
        their (linear) softmax backward contributions.
        """
        for l in range(1, n):
            for i in range(n - l):
                j = i + l

                # right to left incomplete
                update_term = np.dot(g_soft_cw_arr[item_map[RIGHT_C], i, i:j]
                                     + g_soft_cw_arr[item_map[LEFT_C], i + 1: j + 1, j],
                                     bp_arr[item_map[LEFT_I], i, j, i:j])
                g_soft_cw_arr[item_map[LEFT_I], i, j] += update_term

                g_bp_arr[item_map[LEFT_I], i, j, i:j] = \
                    (g_soft_cw_arr[item_map[RIGHT_C], i, i:j]
                     + g_soft_cw_arr[item_map[LEFT_C], i + 1: j + 1, j]) \
                    * soft_cw_arr[item_map[LEFT_I], i, j]

                # left to right incomplete
                g_soft_cw_arr[item_map[RIGHT_I], i, j] += update_term

                g_bp_arr[item_map[RIGHT_I], i, j, i:j] = \
                    (g_soft_cw_arr[item_map[RIGHT_C], i, i:j]
                     + g_soft_cw_arr[item_map[LEFT_C], i + 1: j + 1, j]) \
                    * soft_cw_arr[item_map[RIGHT_I], i, j]

                # right to left complete
                g_soft_cw_arr[item_map[LEFT_C], i, j] += \
                    np.dot(g_soft_cw_arr[item_map[LEFT_C], i, i:j]
                           + g_soft_cw_arr[item_map[LEFT_I], i:j, j],
                           bp_arr[item_map[LEFT_C], i, j, i:j])

                g_bp_arr[item_map[LEFT_C], i, j, i:j] = \
                    (g_soft_cw_arr[item_map[LEFT_C], i, i:j]
                     + g_soft_cw_arr[item_map[LEFT_I], i:j, j]) \
                    * soft_cw_arr[item_map[LEFT_C], i, j]

                # left to right complete
                g_soft_cw_arr[item_map[RIGHT_C], i, j] += \
                    np.dot(g_soft_cw_arr[item_map[RIGHT_I], i, i + 1:j + 1]
                           + g_soft_cw_arr[item_map[RIGHT_C], i + 1: j + 1, j],
                           bp_arr[item_map[RIGHT_C], i, j, i + 1:j + 1])

                g_bp_arr[item_map[RIGHT_C], i, j, i + 1:j + 1] = \
                    (g_soft_cw_arr[item_map[RIGHT_I], i, i + 1:j + 1]
                     + g_soft_cw_arr[item_map[RIGHT_C], i + 1: j + 1, j]) \
                    * soft_cw_arr[item_map[RIGHT_C], i, j]

    @staticmethod
    def _backward_inside(n, g_cw_arr, g_bp_arr, g_w_arr, bp_arr, w_arr, item_map):
        """Reverse of ``_inside``.

        Accumulates chart-value gradients into ``g_cw_arr``. Spans are visited
        longest-first. Within a span the items are processed in the reverse of
        ``_inside``'s order: LEFT_C, RIGHT_C, LEFT_I, RIGHT_I. Each item
        ``cw = dot(bp, w)`` with ``bp = softmax(w)`` contributes:
        - ``g_bp += g_cw * w``;
        - ``g_w += g_cw * bp`` (the direct path);
        - ``g_w += bp * (g_bp - <g_bp, bp>)`` (the softmax Jacobian).
        ``g_w`` is then scattered to the two sub-items of each split point.

        NOTE: when ``hard`` was used, ``bp`` is one-hot and the softmax-Jacobian
        term is zero, so only the direct path contributes.
        """
        for l in range(n - 1, 0, -1):
            for i in range(0, n - l):
                j = i + l

                # right to left complete
                g_bp_arr[item_map[LEFT_C], i, j, i:j] += \
                    g_cw_arr[item_map[LEFT_C], i, j] \
                    * w_arr[item_map[LEFT_C], i, j, i:j]

                g_w_arr[item_map[LEFT_C], i, j, i:j] += \
                    g_cw_arr[item_map[LEFT_C], i, j] \
                    * bp_arr[item_map[LEFT_C], i, j, i:j]

                s = np.dot(g_bp_arr[item_map[LEFT_C], i, j, i:j],
                           bp_arr[item_map[LEFT_C], i, j, i:j])

                g_w_arr[item_map[LEFT_C], i, j, i:j] += \
                    bp_arr[item_map[LEFT_C], i, j, i:j] \
                    * (g_bp_arr[item_map[LEFT_C], i, j, i:j] - s)

                g_cw_arr[item_map[LEFT_C], i, i:j] += g_w_arr[item_map[LEFT_C], i, j, i:j]
                g_cw_arr[item_map[LEFT_I], i:j, j] += g_w_arr[item_map[LEFT_C], i, j, i:j]

                # left to right complete
                g_bp_arr[item_map[RIGHT_C], i, j, i + 1:j + 1] += \
                    g_cw_arr[item_map[RIGHT_C], i, j] \
                    * w_arr[item_map[RIGHT_C], i, j, i + 1:j + 1]

                g_w_arr[item_map[RIGHT_C], i, j, i + 1:j + 1] += \
                    g_cw_arr[item_map[RIGHT_C], i, j] \
                    * bp_arr[item_map[RIGHT_C], i, j, i + 1:j + 1]

                s = np.dot(g_bp_arr[item_map[RIGHT_C], i, j, i + 1:j + 1],
                           bp_arr[item_map[RIGHT_C], i, j, i + 1:j + 1])

                g_w_arr[item_map[RIGHT_C], i, j, i + 1:j + 1] += \
                    bp_arr[item_map[RIGHT_C], i, j, i + 1:j + 1] \
                    * (g_bp_arr[item_map[RIGHT_C], i, j, i + 1:j + 1] - s)

                g_cw_arr[item_map[RIGHT_I], i, i + 1:j + 1] += g_w_arr[item_map[RIGHT_C], i, j, i + 1:j + 1]
                g_cw_arr[item_map[RIGHT_C], i + 1:j + 1, j] += g_w_arr[item_map[RIGHT_C], i, j, i + 1:j + 1]

                # right to left incomplete
                g_bp_arr[item_map[LEFT_I], i, j, i:j] += \
                    g_cw_arr[item_map[LEFT_I], i, j] \
                    * w_arr[item_map[LEFT_I], i, j, i:j]

                g_w_arr[item_map[LEFT_I], i, j, i:j] += \
                    g_cw_arr[item_map[LEFT_I], i, j] \
                    * bp_arr[item_map[LEFT_I], i, j, i:j]

                s = np.dot(g_bp_arr[item_map[LEFT_I], i, j, i:j],
                           bp_arr[item_map[LEFT_I], i, j, i:j])

                g_w_arr[item_map[LEFT_I], i, j, i:j] += \
                    bp_arr[item_map[LEFT_I], i, j, i:j] \
                    * (g_bp_arr[item_map[LEFT_I], i, j, i:j] - s)

                g_cw_arr[item_map[RIGHT_C], i, i:j] += g_w_arr[item_map[LEFT_I], i, j, i:j]
                g_cw_arr[item_map[LEFT_C], i + 1:j + 1, j] += g_w_arr[item_map[LEFT_I], i, j, i:j]

                # left to right incomplete
                g_bp_arr[item_map[RIGHT_I], i, j, i:j] += \
                    g_cw_arr[item_map[RIGHT_I], i, j] \
                    * w_arr[item_map[RIGHT_I], i, j, i:j]

                g_w_arr[item_map[RIGHT_I], i, j, i:j] += \
                    g_cw_arr[item_map[RIGHT_I], i, j] \
                    * bp_arr[item_map[RIGHT_I], i, j, i:j]

                s = np.dot(g_bp_arr[item_map[RIGHT_I], i, j, i:j],
                           bp_arr[item_map[RIGHT_I], i, j, i:j])

                g_w_arr[item_map[RIGHT_I], i, j, i:j] += \
                    bp_arr[item_map[RIGHT_I], i, j, i:j] \
                    * (g_bp_arr[item_map[RIGHT_I], i, j, i:j] - s)

                g_cw_arr[item_map[RIGHT_C], i, i:j] += g_w_arr[item_map[RIGHT_I], i, j, i:j]
                g_cw_arr[item_map[LEFT_C], i + 1:j + 1, j] += g_w_arr[item_map[RIGHT_I], i, j, i:j]


if __name__ == '__main__':
    # Numerically check the hand-written backward of EisnerSurrogate on a
    # random 20x20 score matrix in soft mode (hard=False).
    torch.manual_seed(0)
    torch.set_printoptions(sci_mode=False)

    scores = torch.randn((20, 20), requires_grad=True, dtype=torch.float64)
    print(gradcheck(EisnerSurrogate.apply, (scores, False), eps=1e-6, atol=1e-4))