import math

import numpy as np
from torch.nn import functional as F


def sweep_path(N):

    def zigzag_path_lr(N, start_row=0, start_col=0, dir_row=1, dir_col=1):
        path = []
        for i in range(N):
            for j in range(N):
                col = j
                path.append((start_row + dir_row * i) * N + start_col + dir_col * col)
        return path

    def zigzag_path_tb(N, start_row=0, start_col=0, dir_row=1, dir_col=1):
        path = []
        for j in range(N):
            for i in range(N):
                row = i
                path.append((start_row + dir_row * row) * N + start_col + dir_col * j)
        return path

    paths = []
    for start_row, start_col, dir_row, dir_col in [
        (0, 0, 1, 1),
        (0, N - 1, 1, -1),
        (N - 1, 0, -1, 1),
        (N - 1, N - 1, -1, -1),
    ]:
        paths.append(zigzag_path_lr(N, start_row, start_col, dir_row, dir_col))
        paths.append(zigzag_path_tb(N, start_row, start_col, dir_row, dir_col))

    for _index, _p in enumerate(paths):
        paths[_index] = np.array(_p)
    return paths


def zigma_path(N):

    def zigzag_path_lr(N, start_row=0, start_col=0, dir_row=1, dir_col=1):
        path = []
        for i in range(N):
            for j in range(N):
                # If the row number is even, move right; otherwise, move left
                col = j if i % 2 == 0 else N - 1 - j
                path.append((start_row + dir_row * i) * N + start_col + dir_col * col)
        return path

    def zigzag_path_tb(N, start_row=0, start_col=0, dir_row=1, dir_col=1):
        path = []
        for j in range(N):
            for i in range(N):
                # If the column number is even, move down; otherwise, move up
                row = i if j % 2 == 0 else N - 1 - i
                path.append((start_row + dir_row * row) * N + start_col + dir_col * j)
        return path

    paths = []
    for start_row, start_col, dir_row, dir_col in [
        (0, 0, 1, 1),
        (0, N - 1, 1, -1),
        (N - 1, 0, -1, 1),
        (N - 1, N - 1, -1, -1),
    ]:
        paths.append(zigzag_path_lr(N, start_row, start_col, dir_row, dir_col))
        paths.append(zigzag_path_tb(N, start_row, start_col, dir_row, dir_col))

    for _index, _p in enumerate(paths):
        paths[_index] = np.array(_p)
    return paths


def jpeg_zigzag(N):
    def zigzag_path_lr(N, start_row, start_col, dir_row, dir_col):
        # initializing the variables
        # ----------------------------------
        h = 0
        v = 0

        vmin = 0
        hmin = 0

        vmax = N  # input.shape[0]
        hmax = N  # input.shape[1]
        # print(vmax, hmax)

        i = 0

        # output = np.zeros(( vmax * hmax))
        indices = []
        # ----------------------------------

        while (v < vmax) and (h < hmax):
            # indices.append(v*vmax + h)
            indices.append((start_row + dir_row * v) * vmax + start_col + dir_col * h)
            # output[i] = input[v, h]        # if we got to the first line

            if ((h + v) % 2) == 0:  # going up

                if v == vmin:
                    # print(1)

                    if h == hmax:
                        v = v + 1
                    else:
                        h = h + 1

                elif (h == hmax - 1) and (v < vmax):  # if we got to the last column
                    # print(2)

                    v = v + 1

                elif (v > vmin) and (h < hmax - 1):  # all other cases
                    # print(3)
                    v = v - 1
                    h = h + 1

            else:  # going down

                if (v == vmax - 1) and (h <= hmax - 1):  # if we got to the last line
                    # print(4)
                    h = h + 1

                elif h == hmin:  # if we got to the first column
                    # print(5)

                    if v == vmax - 1:
                        h = h + 1
                    else:
                        v = v + 1

                elif (v < vmax - 1) and (h > hmin):  # all other cases
                    # print(6)
                    v = v + 1
                    h = h - 1

            i = i + 1
            if (v == vmax - 1) and (h == hmax - 1):  # bottom right element
                # print(7)
                # output[i] = input[v, h]
                # indices.append(v*vmax + h)
                indices.append((start_row + dir_row * v) * vmax + start_col + dir_col * h)
                break

        return np.array(indices)

    def zigzag_path_tb(N, start_row, start_col, dir_row, dir_col):
        # initializing the variables
        # ----------------------------------
        h = 0
        v = 0

        vmin = 0
        hmin = 0

        vmax = N  # input.shape[0]
        hmax = N  # input.shape[1]
        # print(vmax, hmax)

        i = 0

        # output = np.zeros(( vmax * hmax))
        indices = []
        # ----------------------------------

        while (v < vmax) and (h < hmax):
            indices.append((start_row + dir_row * v) * vmax + start_col + dir_col * h)
            # indices.append(v*vmax + h)
            # output[i] = input[v, h]        # if we got to the first line

            if ((h + v) % 2) == 0:  # going up

                if h == hmin:  # if we got to the first column
                    # print(5)

                    if v == vmax - 1:
                        h = h + 1
                    else:
                        v = v + 1

                elif (v == vmax - 1) and (h <= hmax - 1):  # if we got to the last line
                    # print(4)
                    h = h + 1

                elif (v < vmax - 1) and (h > hmin):  # all other cases
                    # print(6)
                    v = v + 1
                    h = h - 1

            else:  # going down

                if (h == hmax - 1) and (v < vmax):  # if we got to the last column
                    # print(2)

                    v = v + 1

                elif v == vmin:
                    # print(1)

                    if h == hmax:
                        v = v + 1
                    else:
                        h = h + 1

                elif (v > vmin) and (h < hmax - 1):  # all other cases
                    # print(3)
                    v = v - 1
                    h = h + 1

            i = i + 1
            if (v == vmax - 1) and (h == hmax - 1):  # bottom right element
                # print(7)
                # output[i] = input[v, h]
                indices.append((start_row + dir_row * v) * vmax + start_col + dir_col * h)
                # indices.append(v*vmax + h)
                break

        return np.array(indices)

    paths = []
    for start_row, start_col, dir_row, dir_col in [
        (0, 0, 1, 1),
        (0, N - 1, 1, -1),
        (N - 1, 0, -1, 1),
        (N - 1, N - 1, -1, -1),
    ]:
        paths.append(zigzag_path_lr(N, start_row, start_col, dir_row, dir_col))
        paths.append(zigzag_path_tb(N, start_row, start_col, dir_row, dir_col))

    for _index, _p in enumerate(paths):
        paths[_index] = np.array(_p)
    return paths


def reverse_permut_np(permutation):
    n = len(permutation)
    reverse = np.array([0] * n)
    for i in range(n):
        reverse[permutation[i]] = i
    return reverse

def inverse_jpeg_zigzag(input, vmax, hmax):

    # print input.shape

    # initializing the variables
    # ----------------------------------
    h = 0
    v = 0

    vmin = 0
    hmin = 0

    output = np.zeros((vmax, hmax))

    i = 0
    # ----------------------------------

    while (v < vmax) and (h < hmax):
        # print ('v:',v,', h:',h,', i:',i)
        if ((h + v) % 2) == 0:  # going up

            if v == vmin:
                # print(1)
                output[v, h] = input[i]  # if we got to the first line

                if h == hmax:
                    v = v + 1
                else:
                    h = h + 1

                i = i + 1

            elif (h == hmax - 1) and (v < vmax):  # if we got to the last column
                # print(2)
                output[v, h] = input[i]
                v = v + 1
                i = i + 1

            elif (v > vmin) and (h < hmax - 1):  # all other cases
                # print(3)
                output[v, h] = input[i]
                v = v - 1
                h = h + 1
                i = i + 1

        else:  # going down

            if (v == vmax - 1) and (h <= hmax - 1):  # if we got to the last line
                # print(4)
                output[v, h] = input[i]
                h = h + 1
                i = i + 1

            elif h == hmin:  # if we got to the first column
                # print(5)
                output[v, h] = input[i]
                if v == vmax - 1:
                    h = h + 1
                else:
                    v = v + 1
                i = i + 1

            elif (v < vmax - 1) and (h > hmin):  # all other cases
                output[v, h] = input[i]
                v = v + 1
                h = h - 1
                i = i + 1

        if (v == vmax - 1) and (h == hmax - 1):  # bottom right element
            # print(7)
            output[v, h] = input[i]
            break

    return output


"""PyTorch code for local scan and local reverse"""


def local_scan(x, w=7, H=14, W=14, flip=False, column_first=False):
    B, L, C = x.shape
    x = x.view(B, H, W, C)
    Hg, Wg = math.ceil(H / w), math.ceil(W / w)
    if H % w != 0 or W % w != 0:
        newH, newW = Hg * w, Wg * w
        x = F.pad(x, (0, 0, 0, newW - W, 0, newH - H))
    if column_first:
        x = x.view(B, Hg, w, Wg, w, C).permute(0, 3, 1, 4, 2, 5).reshape(B, -1, C)
    else:
        x = x.view(B, Hg, w, Wg, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, -1, C)
    if flip:
        x = x.flip([1])
    return x


def local_scan_bchw(x, w=7, H=14, W=14, flip=False, column_first=False):
    """Local windowed scan in LocalMamba
    Input:
        x: [B, C, H, W]
        H, W: original width and height before padding
        column_first: column-wise scan first (the additional direction in VMamba)
    Return: [B, C, L]
    """
    B, C, _, _ = x.shape
    x = x.view(B, C, H, W)
    Hg, Wg = math.ceil(H / w), math.ceil(W / w)
    if H % w != 0 or W % w != 0:
        newH, newW = Hg * w, Wg * w
        x = F.pad(x, (0, newW - W, 0, newH - H))
    if column_first:
        x = x.view(B, C, Hg, w, Wg, w).permute(0, 1, 4, 2, 5, 3).reshape(B, C, -1)
    else:
        x = x.view(B, C, Hg, w, Wg, w).permute(0, 1, 2, 4, 3, 5).reshape(B, C, -1)
    if flip:
        x = x.flip([-1])
    return x


def local_reverse(x, w=7, H=14, W=14, flip=False, column_first=False):
    """Local windowed scan in LocalMamba
    Input:
        x: [B, L, C]
        H, W: original width and height before padding
        column_first: column-wise scan first (the additional direction in VMamba)
    Return: [B, L, C]
    """
    B, L, C = x.shape
    Hg, Wg = math.ceil(H / w), math.ceil(W / w)
    if flip:
        x = x.flip([1])
    if H % w != 0 or W % w != 0:
        if column_first:
            x = x.view(B, Wg, Hg, w, w, C).permute(0, 2, 4, 1, 3, 5).reshape(B, C, Hg * w, Wg * w)
        else:
            x = x.view(B, Hg, Wg, w, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, C, Hg * w, Wg * w)
        x = x[:, :H, :W, :].reshape(B, -1, C)
    else:
        if column_first:
            x = x.view(B, Wg, Hg, w, w, C).permute(0, 2, 4, 1, 3, 5).reshape(B, L, C)
        else:
            x = x.view(B, Hg, Wg, w, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, L, C)
    return x


SCAN_ZOO = {
    "sweep": sweep_path,
    "zigma": zigma_path,
    "jpeg": jpeg_zigzag,
}
