# -*- coding: utf-8 -*-


import numpy as np

def robust_ar_statistics(Ez, tau, x, y, J, h):
    """Accumulate weighted second-order statistics into ``J`` and ``h`` in place.

    For every ``k < K`` and ``d < D`` this adds::

        J[k, d, m, n] += sum_t Ez[t, k] * tau[t, k, d] * x[t, m] * x[t, n]
        h[k, d, m]    += sum_t Ez[t, k] * tau[t, k, d] * x[t, m] * y[t, d]

    The sums over ``t`` are done with ``np.einsum`` rather than nested Python
    loops. The result matches the loop formulation up to floating-point
    rounding.

    Parameters
    ----------
    Ez : array, shape (T, K)
        Weights indexed ``[t, k]``; its shape defines ``T`` and ``K``.
    tau : array, shape (T, K, D)
        Weights indexed ``[t, k, d]``; ``tau.shape[2]`` defines ``D``.
    x : array, shape (T, N)
        ``x.shape[1]`` defines ``N``; only the first ``T`` rows are used.
    y : array, shape (T, D)
        Only the first ``T`` rows and ``D`` columns are used.
    J : array, shape (K, D, N, N)
        Accumulator, updated in place.
    h : array, shape (K, D, N)
        Accumulator, updated in place.
    """
    T, K = Ez.shape
    D = tau.shape[2]
    N = x.shape[1]

    # Combined weight w[t, k, d] = Ez[t, k] * tau[t, k, d].  The explicit
    # slices keep exactly the index ranges the element-wise version touched,
    # in case any input is larger than (T, K, D, N).
    w = Ez[:, :, None] * tau[:T, :K, :D]
    xt = x[:T, :N]

    # Slice-assignment with += mutates the caller's arrays in place.
    J[:K, :D, :N, :N] += np.einsum('tkd,tm,tn->kdmn', w, xt, xt)
    h[:K, :D, :N] += np.einsum('tkd,tm,td->kdm', w, xt, y[:T, :D])

def _blocks_to_bands_lower(Ad, Aod):
    T, D = Ad.shape[0], Ad.shape[1]
    L = np.zeros((2 * D, T * D))

    for t in range(T):
        for u in range(2 * D):
            for d in range(D):
                j = t * D + d
                i = u + j

                trow = i // D
                drow = i % D

                if trow >= T:
                    continue

                if t == trow:
                    L[u, j] = Ad[t, drow, d]
                elif t == trow - 1:
                    L[u, j] = Aod[t, drow, d]

    return np.asarray(L)

def _blocks_to_bands_upper(Ad, Aod):
    T, D = Ad.shape[0], Ad.shape[1]
    U = np.zeros((2 * D, T * D))

    for t in range(T):
        for u in range(2 * D):
            for d in range(D):
                j = t * D + d
                i = u + j - (2 * D - 1)

                if i < 0:
                    continue

                trow = i // D
                drow = i % D

                if trow >= T:
                    continue

                if t == trow:
                    U[u, j] = Ad[t, drow, d]
                elif t == trow + 1:
                    U[u, j] = Aod[t-1, drow, d]

    return np.asarray(U)

def _bands_to_blocks_lower(A_banded):
    D = A_banded.shape[0] // 2
    T = A_banded.shape[1] // D
    Ad = np.zeros((T, D, D))
    Aod = np.zeros((T-1, D, D))

    for t in range(T):
        for u in range(2 * D):
            for d in range(D):
                j = t * D + d
                i = u + j

                trow = i // D
                drow = i % D

                if trow >= T:
                    continue

                if t == trow:
                    Ad[t, drow, d] = A_banded[u, j]
                elif t == trow - 1:
                    Aod[t, drow, d] = A_banded[u, j]

    return np.asarray(Ad), np.asarray(Aod)

def _bands_to_blocks_upper(A_banded):
    D = A_banded.shape[0] // 2
    T = A_banded.shape[1] // D
    Ad = np.zeros((T, D, D))
    Aod = np.zeros((T-1, D, D))

    for t in range(T):
        for u in range(2 * D):
            for d in range(D):
                j = t * D + d
                i = u + j - (2 * D - 1)

                if i < 0:
                    continue

                trow = i // D
                drow = i % D

                if trow >= T:
                    continue

                if t == trow:
                    Ad[t, drow, d] = A_banded[u, j]
                elif t == trow + 1:
                    Aod[t-1, drow, d] = A_banded[u, j]

    return np.asarray(Ad), np.asarray(Aod)

def _transpose_banded(l, u, A_banded):
    D, N = A_banded.shape[0], A_banded.shape[1]
    A_banded_T = np.zeros_like(A_banded)

    for d in range(D):
        for j in range(N):
            i = d + j - l
            if i < 0 or i >= N:
                continue

            A_banded_T[d, j] = A_banded[D-1-d, i]

    return np.asarray(A_banded_T)

def vjp_cholesky_banded_lower(L_bar, L_banded, A_banded, A_bar):
    """Back-propagate a gradient through a lower banded Cholesky factor.

    All arrays use lower banded storage: full-matrix entry ``(i, j)`` with
    ``i >= j`` lives at ``[i - j, j]``, so band row 0 is the main diagonal.
    The recurrence visits full-matrix rows ``i = N-1, ..., 0``. Within each
    row it visits columns ``j = i, i-1, ...`` down to the edge of the band.

    NOTE(review): this appears to be the element-wise reverse-mode Cholesky
    adjoint restricted to the band; confirm against the forward routine.

    Parameters
    ----------
    L_bar : banded array
        Incoming gradient w.r.t. the banded factor. It is modified in place
        (used as scratch while gradients are propagated backwards).
    L_banded : banded array
        The banded Cholesky factor; ``L_banded[0, j]`` is the pivot of
        column ``j``.
    A_banded : banded array
        Only its shape is used: ``D`` (band height) and ``N`` (matrix size).
    A_bar : banded array
        Output. ``A_bar[i - j, j]`` is assigned for every in-band position
        ``(i, j)`` with ``i < N``. Other entries are left untouched.

    Returns
    -------
    None. Results are written into ``A_bar``, and ``L_bar`` is clobbered.
    """
    D, N = A_banded.shape[0], A_banded.shape[1]

    for i in range(N-1, -1, -1):
        # In-band columns of row i: j = i down to max(i - D + 1, 0).
        for j in range(i, max(i-D, -1), -1):
            if j == i:
                # Diagonal entry: A_bar[i, i] = L_bar[i, i] / (2 * L[i, i]).
                A_bar[0, j] = 0.5 * L_bar[0, j] / L_banded[0, j]
            else:
                # Off-diagonal entry: divide by the pivot L[j, j], then push
                # the remaining contribution onto the pivot's gradient.
                A_bar[i-j, j] = L_bar[i - j, j] / L_banded[0, j]
                L_bar[0, j] -= L_bar[i - j, j] * L_banded[i - j, j] / L_banded[0, j]

            # Propagate A_bar[i, j] to earlier in-band factor entries L[i, k]
            # and L[j, k] for k < j. When j == i both lines update the same
            # entry L_bar[i - k, k], doubling the contribution.
            for k in range(j-1, max(i-D, -1), -1):
                L_bar[i-k, k] -= A_bar[i-j, j] * L_banded[j-k, k]
                L_bar[j-k, k] -= A_bar[i-j, j] * L_banded[i-k, k]

def _vjp_solve_banded_A(A_bar, b_bar, C_bar, C, u, A_banded):
    D, N, K = A_banded.shape[0], A_banded.shape[1], C_bar.shape[1]

    for d in range(D):
        for j in range(N):
            i = d + j - u
            if i >= 0 and i < N:
                for k in range(K):
                    A_bar[d, j] -= b_bar[i, k] * C[j, k]

def _vjp_solveh_banded_A(A_bar, b_bar, C_bar, C, lower, A_banded):
    D, N, K = A_banded.shape[0], A_banded.shape[1], C_bar.shape[1]

    for j in range(N):
        for d in range(D):
            i = d + j if lower else d + j - D + 1
            if i < 0 or i >= N:
                continue

            for k in range(K):
                A_bar[d, j] -= b_bar[i, k] * C[j, k]
                if i != j:
                    A_bar[d, j] -= b_bar[j, k] * C[i, k]
