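''' Operators used by the model: row-major vectorisation (Vec / Vec_inv),
the block-rearrangement R operator and its inverse, covariance-matrix
rearrangement, and single-channel 1-D / 2-D convolution helpers. '''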
import numpy as np
import torch


def Vec(M):
    """Flatten ``M`` row by row (C order) into a column vector of shape (M.size, 1).

    Note: this stacks *rows*; the textbook vec operator stacks columns
    (Fortran order). ``Vec_inv`` undoes this with the same ordering.
    """
    column_shape = (-1, 1)
    return np.reshape(M, column_shape, order='C')


def Vec_inv(M, shape):
    """Reshape ``M`` to ``shape`` reading its entries in row-major (C) order.

    This is the inverse of ``Vec``: ``Vec_inv(Vec(X), X.shape)`` returns ``X``.
    """
    target_shape = shape
    return np.reshape(M, target_shape, order='C')


''' R (block-rearrangement) operator for matrices, and its inverse '''


def R_opt(M, blockshape, stride):
    """Rearrange ``M`` into a matrix whose rows are its flattened sliding-window blocks.

    A ``blockshape`` = (d1, d2) window is slid over the 2-D array ``M`` (shape
    (D1, D2)) with step ``stride`` = (s1, s2). Windows are visited in row-major
    order (left to right, then top to bottom). Each window is flattened
    row-major (the same ordering as ``Vec``) into one row of the result.

    Args:
        M: 2-D NumPy array of shape (D1, D2).
        blockshape: (d1, d2) window size.
        stride: (s1, s2) step sizes; both must be positive.

    Returns:
        Array of shape (K1 * K2, d1 * d2) with the dtype of ``M``, where
        K1 = (D1 - d1) // s1 + 1 and K2 = (D2 - d2) // s2 + 1.

    Raises:
        ValueError: if a stride entry is not positive, or the block is larger
            than ``M``.
        AssertionError: if the windows do not end exactly on the border of
            ``M`` ((D1 - d1) or (D2 - d2) not divisible by the stride).
    """
    D1, D2 = M.shape
    d1, d2 = blockshape
    s1, s2 = stride
    # A negative stride previously looped forever and a zero stride raised
    # ZeroDivisionError inside the assert; reject both up front.
    if s1 <= 0 or s2 <= 0:
        raise ValueError(f"Stride entries must be positive, got {tuple(stride)}.")
    assert (D1 - d1) % s1 == 0 and (D2 -d2) % s2 == 0, "Dislocation. The number of step (D1 - d1) / s1 and (D2 - d2) / s2 should be integer."
    # Otherwise no window fits, and stacking an empty list would fail with an
    # unhelpful numpy error.
    if d1 > D1 or d2 > D2:
        raise ValueError(
            f"Block shape {(d1, d2)} does not fit inside matrix of shape {(D1, D2)}."
        )
    rows = [
        M[i:i + d1, j:j + d2].reshape(-1, order='C')
        for i in range(0, D1 - d1 + 1, s1)
        for j in range(0, D2 - d2 + 1, s2)
    ]
    return np.stack(rows, axis=0)


def R_inv(RM, blockshape, idctshape):
    """Reassemble a matrix from its non-overlapping, row-flattened blocks.

    Row ``i`` of ``RM`` holds one ``blockshape`` = (d1, d2) block flattened
    row-major. It is written to grid cell ``divmod(i, p2)`` of an
    ``idctshape`` = (p1, p2) grid of blocks. This undoes ``R_opt`` when its
    stride equals ``blockshape``.

    Args:
        RM: 2-D array of shape (p1 * p2, d1 * d2).
        blockshape: (d1, d2) block size.
        idctshape: (p1, p2) number of blocks along each axis.

    Returns:
        float64 array of shape (d1 * p1, d2 * p2).

    Raises:
        AssertionError: if ``RM``'s shape does not match the block and grid sizes.
    """
    m, n = RM.shape
    d1, d2 = blockshape
    p1, p2 = idctshape
    assert m == p1 * p2 and n == d1 * d2, "Dimension wrong"
    # (p1*p2, d1*d2) -> (p1, p2, d1, d2) indexed [grid_row, grid_col, r, c].
    # Swap the middle axes to [grid_row, r, grid_col, c], so a row-major reshape
    # lays the blocks out side by side on the big grid.
    tiled = np.asarray(RM).reshape(p1, p2, d1, d2).transpose(0, 2, 1, 3)
    return tiled.reshape(d1 * p1, d2 * p2).astype(np.float64)


def Rearrange_Covmat(Sigma, shape, kernel_size, stride):
    """Reorder (and possibly duplicate) the rows/columns of ``Sigma`` to follow ``R_opt``.

    The row-major flat indices of a ``shape`` = (P, D) matrix are passed
    through ``R_opt`` with ``kernel_size`` and ``stride``. The result is
    flattened to ``idx``, and the returned matrix satisfies
    ``RSigma[i, j] == Sigma[idx[i], idx[j]]``. Presumably ``Sigma`` is the
    covariance of the row-major flattening of a (P, D) matrix; TODO confirm
    against callers.

    With overlapping windows (stride smaller than kernel_size) an index can
    appear several times, so the result can be larger than ``Sigma``.

    Args:
        Sigma: square array indexable by flat indices in ``[0, P * D)``.
        shape: (P, D).
        kernel_size: (d1, d2) window size forwarded to ``R_opt``.
        stride: (s1, s2) step forwarded to ``R_opt``.

    Returns:
        float64 array of shape (new_dim, new_dim), where new_dim is the number
        of windows times d1 * d2.
    """
    P, D = shape
    index = np.arange(P * D).reshape(shape)
    Rindex = R_opt(index, blockshape=kernel_size, stride=stride)
    Rindex_vec = Rindex.ravel()
    # A single vectorised gather replaces the Python double loop over
    # new_dim**2 scalar lookups. Converting to float keeps the original
    # float64 result dtype.
    Sigma_arr = np.asarray(Sigma, dtype=float)
    return Sigma_arr[np.ix_(Rindex_vec, Rindex_vec)]


def conv_vector(input_vector, kernel, stride):
    """Cross-correlate a 1-D signal with a 1-D kernel using PyTorch (no padding, no bias).

    PyTorch "convolutions" do not flip the kernel: ``out[t] = sum_k x[t*s + k] * kernel[k]``.

    Args:
        input_vector: 1-D array-like of length L.
        kernel: 1-D array-like of length k.
        stride: int (or 1-tuple) step, forwarded to ``torch.nn.functional.conv1d``.

    Returns:
        float32 NumPy array of shape (L_out,), L_out = (L - k) // stride + 1.
        The result stays 1-D even when L_out == 1.
    """
    N = 1
    C_in, C_out = 1, 1
    L = len(input_vector)
    weight = torch.tensor(kernel, dtype=torch.float32).reshape(C_out, C_in, *np.shape(kernel))
    signal = torch.tensor(input_vector, dtype=torch.float32).reshape(N, C_in, L)
    # Use the functional form, so no throwaway nn.Conv1d is built. Its random
    # weight initialisation would advance the global torch RNG, and no
    # autograd graph is needed.
    with torch.no_grad():
        output = torch.nn.functional.conv1d(signal, weight, bias=None, stride=stride)
    # Drop only the batch and channel axes. torch.squeeze also removed a
    # length-1 output axis and returned a 0-d array.
    return output[0, 0].detach().numpy()


def conv_matrix(input_matrix, kernel, stride):
    """Cross-correlate a 2-D array with a 2-D kernel using PyTorch (no padding, no bias).

    PyTorch "convolutions" do not flip the kernel:
    ``out[a, b] = sum_{r, c} X[a*s1 + r, b*s2 + c] * kernel[r, c]``.

    Args:
        input_matrix: 2-D array-like of shape (H, W).
        kernel: 2-D array-like of shape (k1, k2).
        stride: int or (s1, s2) tuple, forwarded to ``torch.nn.functional.conv2d``.

    Returns:
        float32 NumPy array of shape (H_out, W_out), with
        H_out = (H - k1) // s1 + 1 and W_out = (W - k2) // s2 + 1. The result
        stays 2-D even when H_out or W_out is 1.
    """
    N = 1
    C_in, C_out = 1, 1
    H, W = np.shape(input_matrix)
    weight = torch.tensor(kernel, dtype=torch.float32).reshape(C_out, C_in, *np.shape(kernel))
    image = torch.tensor(input_matrix, dtype=torch.float32).reshape(N, C_in, H, W)
    # Use the functional form, so no throwaway nn.Conv2d is built. Its random
    # weight initialisation would advance the global torch RNG, and no
    # autograd graph is needed.
    with torch.no_grad():
        output = torch.nn.functional.conv2d(image, weight, bias=None, stride=stride)
    # Drop only the batch and channel axes. torch.squeeze also collapsed a
    # singleton spatial axis, e.g. turning a (1, 3) result into shape (3,).
    return output[0, 0].detach().numpy()

