import math

import numpy as np
import torch

# Predecessor offsets used by the DP grids in this file. Row 0 holds the row
# offsets and row 1 the column offsets, so every column is one allowed move.
# Full move set: (-1, -1), (-1, 0) and (0, -1).
sft = np.array([
    [-1, -1, 0],
    [-1, 0, -1],
])
# Reduced move set (-1, -1) and (0, -1): only slash and from left.
# Only appears in anchored_videoParse.
sft_nodown = np.array([
    [-1, 0],
    [-1, -1],
])
# Alias of the full move set (the very same array object); only appears in videoParse.
sft_noup = sft


def DTW(dist, get_acc=False):
    """Non-differentiable dynamic time warping on a pairwise distance matrix.

    Args:
        dist: 2-D matrix of shape [L1, L2]; only ``.shape`` and scalar
            indexing ``dist[i, j]`` are used.
        get_acc: when true, also return the accumulated-cost grid.

    Returns:
        ``path``, a [L1, L2] numpy array holding 1 on the cells of the
        backtracked alignment and 0 elsewhere. With ``get_acc`` the result is
        the tuple ``(path, accumulated)`` where ``accumulated`` has shape
        [L1 + 1, L2 + 1] (row 0 / column 0 are the inf border, origin is 0).
        The alignment distance can be obtained by summing ``dist`` along
        ``path``.
    """
    n_rows, n_cols = dist.shape
    acc = np.full((n_rows + 1, n_cols + 1), np.inf)
    acc[0, 0] = 0

    # Forward pass: every cell extends its cheapest predecessor among `sft`.
    for r in range(1, n_rows + 1):
        for c in range(1, n_cols + 1):
            prev_costs = acc[tuple(np.array([[r], [c]]) + sft)]  # 3 predecessors
            acc[r, c] = dist[r - 1, c - 1] + np.min(prev_costs)

    # Backward pass: walk from the bottom-right corner back to the origin,
    # always stepping to the cheapest predecessor (first one wins on ties).
    path = np.zeros((n_rows, n_cols))
    pos = np.array([[n_rows], [n_cols]])
    while pos.sum() != 0:
        path[tuple(pos - 1)] = 1  # grid is shifted by one w.r.t. dist
        cheapest = np.argmin(acc[tuple(pos + sft)])
        pos += sft[:, cheapest, np.newaxis]

    return (path, acc) if get_acc else path


def videoParse(T, R, stop, debug=False):
    """Greedily parse T from its beginning into segments matched to references.

    Args:
        T: numpy array of shape [M, f].
        R: numpy array of shape [V, N, f].
        stop: an int; parsing ends once no more than ``stop`` rows of T remain.
        debug: when true, each yield additionally carries the padded distance
            matrices built during that round.

    Yields:
        ``(v_best, M)`` -- index of the reference with the lowest accumulated
        distance in this round and the number of rows of T left after cutting
        the matched prefix; ``(v_best, M, d_container)`` when ``debug`` is set,
        where ``d_container`` lists the [M, N+1] matrices of every reference.

    ``1 - R @ T^T`` is used as the distance, which presumably expects
    unit-norm feature rows -- TODO confirm with the caller.
    """
    n_frames = T.shape[0]
    ref_len = R.shape[1]
    pad_col = ref_len + 1  # grid column of the zero-cost padding column
    while n_frames > stop:
        best_total, best_grid, best_ref = np.inf, None, -1
        cut = 0.1
        # [V, M, N]; named similarity upstream, but it is actually a distance
        dists = 1 - np.transpose(R @ T.transpose(), (0, 2, 1))
        if debug:
            padded_dists = []
        for ref_idx, dist in enumerate(dists):  # one DP grid per reference
            # The extra zero column absorbs the rows of T that follow the match.
            padded = np.concatenate([dist, np.zeros((n_frames, 1))], axis=1)  # [M, N+1]
            if debug:
                padded_dists.append(padded)
            grid = np.full((n_frames + 1, ref_len + 2), np.inf)
            grid[0, 0] = 0
            for row in range(1, n_frames + 1):
                for col in range(1, ref_len + 2):
                    moves = sft_noup if col < pad_col else sft
                    prev_costs = grid[tuple(np.array([[row], [col]]) + moves)]
                    grid[row, col] = padded[row - 1, col - 1] + np.min(prev_costs)
            total = grid[-1, -1]
            if total < best_total:
                best_total, best_grid, best_ref = total, grid, ref_idx

        # Backtrack through the winning grid; the row at which the path leaves
        # the padding column is the number of rows of T the match consumed.
        pos = np.array([[n_frames], [pad_col]])
        while pos.sum() != 0:
            in_pad_col = pos[1, 0] >= pad_col
            moves = sft if in_pad_col else sft_noup
            prev_costs = best_grid[tuple(pos + moves)]
            pos += moves[:, np.argmin(prev_costs), np.newaxis]
            if in_pad_col and pos[1, 0] == ref_len:
                cut = pos[0, 0]

        T = T[cut:]
        n_frames = T.shape[0]
        if debug:
            yield best_ref, n_frames, padded_dists
        else:
            yield best_ref, n_frames


def serialized_videoParse(T, R, stop, sticky_threshold=2, debug=False):
    """Post-process videoParse results by merging consecutive same-word segments.

    A segment is merged into the previous one when it has the same reference
    index and either it is shorter than ``sticky_threshold`` or the segment
    right before it was (the "sticky" state).

    Args:
        T: numpy array of shape [M, f], forwarded to videoParse.
        R: numpy array of shape [V, N, f], forwarded to videoParse.
        stop: an int, forwarded to videoParse.
        sticky_threshold: segments consuming fewer rows than this are
            considered short and trigger merging.
        debug: accepted for interface compatibility but not used; videoParse
            is always run with ``debug=False`` because its debug mode yields
            3-tuples that this loop cannot unpack.

    Yields:
        ``(v, remaining)`` for every merged segment, in order, where
        ``remaining`` is the number of rows of T left after that segment.
    """
    # TODO: consider some CTC-style readout method
    v_container = []
    L_container = []
    M0 = T.shape[0]

    prev_v = -1
    prev_M = M0
    sticky = False
    for v, M_ in videoParse(T, R, stop, debug=False):
        delta = prev_M - M_  # rows consumed by this segment
        prev_M = M_
        is_short = delta < sticky_threshold
        # Merge a short segment into its predecessor, and also whatever
        # directly follows a short segment. (A former separate
        # `elif v == prev_v and sticky` branch was unreachable because this
        # condition already covers it.)
        if v == prev_v and (is_short or sticky):
            L_container[-1] += delta
        else:
            prev_v = v
            v_container.append(v)
            L_container.append(delta)
        sticky = is_short

    remaining = M0
    for v, L in zip(v_container, L_container):
        remaining -= L
        yield v, remaining


def nms_score(len, score):
    """Return the ranking score used by the NMS stage.

    Args:
        len: match length normed by N+1, thus between min/(N+1) and N/(N+1).
            Currently unused. (The name shadows the builtin but is kept so
            keyword callers keep working.)
        score: centroid length N minus the distance accumulated along a path
            also of length N, normalised to 0~1.

    Returns:
        ``score`` unchanged.

    Design intent of the disabled combined formula
    ``100 * math.sqrt(score) - math.log(1 - len)``: score should dominate
    while being insensitive to small changes, length should be secondary
    while being sensitive.
    """
    return score
# def get_overlap(a, b):
#     # both are tuple of (start, end)
#     return max(0, min(a[1], b[1]) - max(a[0], b[0])) / (max(a[1], b[1]) - min(a[0], b[0]))


def anchored_videoParse(T, R, overlap_threshold=0.1):
    """Match every reference against every length-N window of T, then keep the
    best-scoring, mutually non-overlapping matches via NMS.

    Args:
        T: numpy array of shape [M, f].
        R: numpy array of shape [V, N, f].
        overlap_threshold: a candidate is suppressed when its
            intersection-over-union with an already picked match exceeds
            this value.

    Returns:
        A list of ``[start, end, score, v]`` rows (as floats, since they come
        from one float array): ``start``/``end`` are inclusive row indices
        into T, ``score`` comes from ``nms_score`` and ``v`` is the reference
        index. Rows are ordered from highest to lowest score. An empty list
        is returned when no candidate match was found (e.g. ``M <= N`` or no
        match longer than 5 rows); previously that case raised IndexError.

    ``1 - R @ T^T`` is used as the distance, which presumably expects
    unit-norm feature rows -- TODO confirm with the caller.
    """
    M = T.shape[0]
    N = R.shape[1]
    sims = 1 - np.transpose(R @ T.transpose(), (0, 2, 1))  # [V, M, N], well, it is actually distance
    res_container = []
    # NOTE(review): range(M - N) never anchors a window at s == M - N although
    # that window still fits inside T -- confirm whether this is intended.
    for s in range(M - N):
        e = s + N
        cropped = sims[:, s:e]  # [V, N, N]
        for v, sim in enumerate(cropped):  # do grid building V times
            # The extra zero column absorbs window rows that follow the match.
            d = np.concatenate([sim, np.zeros((N, 1))], axis=1)  # [N, N+1]
            D = np.full((N + 1, N + 2), np.inf)
            D[0, 0] = 0
            for i in range(1, N + 1):
                for j in range(1, N + 2):
                    cost = d[i - 1, j - 1]
                    indices = np.array([[i], [j]]) + (sft_nodown if j < N + 1 else sft)  # [2, 2/3]
                    costs = D[tuple(indices)]
                    D[i, j] = cost + np.min(costs)

            # Backtrack from the bottom-right corner; the row at which the
            # path leaves the padding column is the length of the match.
            cut = 0  # reset per candidate so a stale value can never leak in
            start = np.array([[N], [N + 1]])
            while start.sum() != 0:
                if start[1, 0] < N + 1:
                    indices = start + sft_nodown
                    costs = D[tuple(indices)]
                    move = sft_nodown[:, np.argmin(costs), np.newaxis]
                    start += move
                else:
                    indices = start + sft
                    costs = D[tuple(indices)]
                    move = sft[:, np.argmin(costs), np.newaxis]
                    start += move
                    if start[1, 0] == N:
                        cut = start[0, 0]  # length of the match
            if cut > 5:  # we are never interested in snippet this short
                res_container.append(
                    [s, s + cut - 1, nms_score(cut / (N + 1), (N - D[-1, -1]) / N), v])  # start and end both inclusive

    if not res_container:
        # np.array([]) would be 1-D and the column indexing below would fail.
        return []
    res_container = np.array(res_container)

    # NMS, two stage NMS should be equivalent to one-stage
    idxs = np.argsort(res_container[:, 2])  # ascending score; best is last
    pick = []
    while len(idxs) > 0:
        last = len(idxs) - 1
        i = idxs[last]
        pick.append(i)

        best = res_container[i]
        rest = res_container[idxs[:last]]
        # Inclusive-interval intersection and union of the pick with the rest.
        xx1 = np.maximum(best[0], rest[:, 0])
        xx2 = np.minimum(best[1], rest[:, 1])
        intersection = np.maximum(0, xx2 - xx1 + 1)
        yy1 = np.minimum(best[0], rest[:, 0])
        yy2 = np.maximum(best[1], rest[:, 1])
        union = yy2 - yy1 + 1
        overlap = intersection / union
        # Drop the pick itself plus everything overlapping it too much.
        idxs = np.delete(idxs, np.concatenate(([last], np.where(overlap > overlap_threshold)[0])))

    return res_container[pick][:, [0, 1, 2, 3]].tolist()


class SnippetDistance:
    """Callable computing a frame-by-frame similarity map between two snippets.

    The constructor stores its arguments and then unconditionally raises
    NotImplementedError, so the class cannot currently be instantiated the
    normal way.
    """

    def __init__(self, TYPE, model):
        self.type = TYPE
        self.model = model
        raise NotImplementedError

    def __call__(self, a, b):
        """Return the matrix of dot products between L2-normalised features.

        Each input gets a leading batch axis and is passed to ``self.model``
        together with a one-element tensor holding its length along dim 0.
        Index ``[:, 0]`` of each model output is L2-normalised along the last
        axis (in place, on that view) before the two are multiplied.

        Presumably the selected features are [T1, f] and [T2, f], giving a
        [T1, T2] result -- TODO confirm against the model's output layout.
        """
        len_a = a.size(0)
        len_b = b.size(0)
        out_a = self.model(a.unsqueeze(0), torch.tensor([len_a]))
        out_b = self.model(b.unsqueeze(0), torch.tensor([len_b]))

        unit_feats = []
        for out in (out_a, out_b):
            feat = out[:, 0]
            feat /= torch.linalg.norm(feat, ord=2, dim=-1, keepdim=True)
            unit_feats.append(feat)

        forward_feat, reverse_feat = unit_feats
        return forward_feat @ reverse_feat.t()
