"""
Causal Monarch, real numbers version.

Parameterized by R_tilde and L_tilde:

R_tilde is a sqrt_N x sqrt_N x sqrt_N tensor (one sqrt_N x sqrt_N block per diagonal block of R), where each block looks like this (example is for sqrt_N = 8):

0 x 0 x 0 x 0 x
x 0 x 0 x 0 x 0
0 x 0 x 0 x 0 0
x 0 x 0 x 0 0 0
0 x 0 x 0 0 0 0
x 0 x 0 0 0 0 0
0 x 0 0 0 0 0 0
x 0 0 0 0 0 0 0

The x's can be learned, except those on the main anti-diagonal, which are fixed to 1 by R_tilde_base.

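L_tilde is conceptually an N x sqrt_N matrix whose rows below the *top* sqrt_N x sqrt_N block are all zeros,
so only that top block is stored (as a sqrt_N x sqrt_N tensor).

The top block is a flipped lower triangular: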

x x x x x x x x
x x x x x x x 0
x x x x x x 0 0
x x x x x 0 0 0
x x x x 0 0 0 0
x x x 0 0 0 0 0
x x 0 0 0 0 0 0
x 0 0 0 0 0 0 0

These are converted into block-diagonal R and L matrices, which are used to construct a Monarch matrix:

M = P @ L @ P @ R @ P

P is the butterfly permutation.

R and L are generated from R_tilde and L_tilde via the Chebyshev matrix.
Let C_N be a Chebyshev matrix of size N x N:

C_N[i,j] = cos(i * j * pi / N)

To get R from R_tilde, you just multiply each block by C_{sqrt_N} and convert it to block-diagonal.
To get L from L_tilde, you multiply P @ C_N @ L_tilde.

Next, we need to modify this matrix M (consequence of the cosine rule):

mod is an N x N matrix made of sqrt_N x sqrt_N constant blocks, each filled with 1s or -1s.
Let the block indices be 0 <= i, j < sqrt_N.
Block (i, j) is all 1s if i * (sqrt_N - j - 1) is even, and all -1s if i * (sqrt_N - j - 1) is odd.
Example for sqrt_N = 4:

mod = [
    block[1]  block[1]  block[1]  block[1]
    block[-1] block[1]  block[-1] block[1]
    block[1]  block[1]  block[1]  block[1]
    block[-1] block[1]  block[-1] block[1]
]

M_prime = mod * M

Then M_prime @ u is a real-valued, drop-in replacement for FFT(u), and this operation is causal in u:

    inv(M_prime) @ ((M_prime @ u) * (M_prime @ k))

We don't have to materialize M_prime to compute it.
Instead, we can do this trick (assume sqrt_N is even for now):

    Let u be split into sqrt_N blocks [u0, u1, ..., u_{sqrt_N-1}]

    Let u_even = u = [u0, u1, u2, ..., u_{sqrt_N - 1}]   (the even block-rows of mod are all 1s)
    Let u_odd = [-u0, u1, -u2, ..., u_{sqrt_N - 1}]

    Let
        y_even = M @ u_even = [y_even_0, y_even_1, .., y_even_{sqrt_N - 1}],
        y_odd = M @ u_odd = [y_odd_0, y_odd_1, ..., y_odd_{sqrt_N - 1}]

    Then, we have:
        y = [y_even_0, y_odd_1, y_even_2, ..., y_odd_{sqrt_N - 1}]

Note: we should be able to save some compute here, but I haven't tested it yet.
"""

import torch
from einops import rearrange, repeat, einsum
import math
from torch import nn
from torch.nn import functional as F
from src.models.nn import Activation
from src.utils.train import OptimModule


def inverse_mat(N):
    """Return the inverse of the N x N Chebyshev matrix with its columns reversed."""
    flipped_cheb = ref_cheb_matrix(N).flip(1)
    return torch.linalg.inv(flipped_cheb)

def ref_cheb_matrix(N, device=None):
    """Compute the Chebyshev matrix of size N x N.

    C_N[i, j] = cos(pi * i * j / N), returned in torch's default floating dtype.

    This is where we could add extra compute for free.

    Args:
        N: matrix size.
        device: optional device to build the matrix on (defaults to torch's
            default device, as before).

    Precision note: cos(pi * m / N) is periodic in m with period 2N, so the
    integer product i * j is first reduced modulo 2N (exact in integer
    arithmetic) and the cosine is then evaluated in float64. Evaluating
    pi * i * j directly in float32 loses precision for large N: for N = 1024
    the intermediate reaches ~3.3e6, where float32 spacing is 0.25.
    """
    n = torch.arange(N, device=device)
    k = n.view(-1, 1)
    # Exact range reduction; values are in [0, 2N).
    m = (n * k) % (2 * N)
    M = torch.cos(m.to(torch.float64) * (torch.pi / N))
    # Keep the original output dtype (default float dtype).
    return M.to(torch.get_default_dtype())


def get_P_matrix(N, sqrt_N):
    """Return the N x N butterfly permutation matrix.

    Row r is the one-hot vector selecting input position perm[r], where perm
    is arange(N) viewed as a (sqrt_N, N // sqrt_N) grid and read out
    column-major, i.e. the "(m n) -> (n m)" transpose with m = sqrt_N.
    """
    perm = torch.arange(N).reshape(sqrt_N, -1).t().reshape(-1)
    return torch.eye(N)[perm]


def monarch_R(N, sqrt_N, R_tilde_base, R_tilde, ref_mat_block):
    '''
    Build the stacked diagonal blocks of R from R_tilde.

    R_tilde is reshaped to R_tilde_base.shape and added to the base; each
    resulting block is then left-multiplied by the matching block of
    ref_mat_block with a batched matmul (CausalMonarch passes the cached
    sqrt_N x sqrt_N Chebyshev matrices here). N and sqrt_N are accepted for
    signature compatibility and are not used.
    '''
    blocks = R_tilde_base + R_tilde.reshape(R_tilde_base.shape)
    return torch.bmm(ref_mat_block, blocks)


def monarch_L(N, sqrt_N, L_tilde_base, L_tilde, permuted_ref_mat):
    '''
    Build the stacked diagonal blocks of L from L_tilde.

    Left-multiplies (L_tilde + L_tilde_base) by permuted_ref_mat (CausalMonarch
    passes P @ C_N[:, :sqrt_N], an N x sqrt_N matrix) and splits the resulting
    rows into sqrt_N consecutive (sqrt_N, sqrt_N) blocks. N is accepted for
    signature compatibility and is not used.
    '''
    stacked = permuted_ref_mat @ (L_tilde + L_tilde_base)
    return stacked.reshape(sqrt_N, sqrt_N, sqrt_N)

def apply_permutation(x, sqrt_N):
    """Apply the butterfly permutation along the last axis of ``x``.

    The last axis (whose length must be a multiple of sqrt_N) is viewed as a
    (length // sqrt_N, sqrt_N) grid, transposed, and flattened back. Leading
    axes are left untouched.
    """
    lead = x.shape[:-1]
    length = x.shape[-1]
    grid = x.reshape(*lead, length // sqrt_N, sqrt_N)
    return grid.transpose(-1, -2).reshape(*lead, length)

def apply_monarch(x, block_R, block_L):
    """Apply the Monarch product P @ L @ P @ R @ P along the last axis of x.

    Args:
        x: tensor whose last axis has length sqrt_N * block_R.shape[-1].
            With block_R sliced to its last sqrt_N // 2 columns (as
            CausalMonarch.forward does, even sqrt_N) this is N // 2; with
            unsliced blocks it is N.
        block_R: (sqrt_N, sqrt_N, w) stacked diagonal blocks of R; only the
            w columns that meet the supplied input are needed.
        block_L: (sqrt_N, sqrt_N, sqrt_N) stacked diagonal blocks of L.

    Returns:
        Tensor with the leading axes of x and a last axis of length N.
    """
    sqrt_N = block_R.shape[0]
    N = sqrt_N ** 2
    # Per-block input width seen by R: inferred from block_R instead of being
    # hard-coded to sqrt_N // 2, so both the half-padded causal input and a
    # full-length input are handled.
    in_width = block_R.shape[-1]

    # permute
    x = apply_permutation(x, sqrt_N)

    # block diagonal matmul with R; the sliced blocks only touch the
    # non-padding part of each input block
    x = torch.einsum(
        "bnm,...bm->...bn", block_R, x.reshape(*x.shape[:-1], sqrt_N, in_width)
    ).reshape(*x.shape[:-1], N)

    # permute
    x = apply_permutation(x, sqrt_N)

    # block diagonal matmul with L
    x = torch.einsum(
        "bnm,...bm->...bn", block_L, x.reshape(*x.shape[:-1], sqrt_N, sqrt_N)
    ).reshape(x.shape)

    # permute
    return apply_permutation(x, sqrt_N)

def block_fft(x, block_R, block_L, x_mod, x_even_mask, x_odd_mask):
    """Evaluate the sign-modified Monarch transform without materializing it.

    Two Monarch products are computed: one on ``x`` unchanged and one on ``x``
    multiplied by the second half of ``x_mod`` (x is aligned with positions
    N // 2 .. N - 1 of the length-N input; CausalMonarch.forward pads it to
    N // 2). Output positions selected by ``x_even_mask`` are taken from the
    first product, those selected by ``x_odd_mask`` from the second.
    """
    half = x_mod.shape[0] // 2
    y_plain = apply_monarch(x, block_R, block_L)
    y_signed = apply_monarch(x * x_mod[half:], block_R, block_L)
    return y_plain * x_even_mask + y_signed * x_odd_mask


class CausalMonarch(OptimModule):
    """
    Learnable Causal Block FFT module.

    Represents M = P @ L @ P @ R @ P (P the butterfly permutation). R's
    diagonal blocks are C_{sqrt_N} @ (R_tilde_base + masked R_tilde), and L's
    diagonal blocks come from (P @ C_N[:, :sqrt_N]) @ (L_tilde_base + masked
    L_tilde). ``forward`` applies the sign-modified transform to an input
    padded (or cropped) to length N // 2; the other half of the length-N
    signal is treated as implicit zeros.

    Args:
        N: transform length; must equal ``sqrt_N ** 2``.
        sqrt_N: block size.
        learnable: if True, R_tilde / L_tilde are registered via
            ``self.register`` with ``dft_lr``; otherwise they are fixed zeros,
            so only the bases contribute.
        dft_lr: value passed to ``register`` for R_tilde / L_tilde; must be
            positive when ``learnable``.
        dropout: stored as ``self.dropout``; not used by this class.
        init_scale: multiplier for the random normal initialization of
            R_tilde / L_tilde.
    """

    def __init__(self, N, sqrt_N, learnable=True, dft_lr=0.001, dropout=0.0, init_scale=0.02):
        super().__init__()
        self.N = N
        self.sqrt_N = sqrt_N
        self.learnable = learnable
        self.dft_lr = dft_lr
        self.dropout = dropout

        assert sqrt_N**2 == N

        # Fixed bases: a flipped identity (anti-diagonal of ones) per block.
        R_tilde_base = repeat(torch.flip(torch.diag(torch.ones(sqrt_N)), dims=(1,)), "n m -> b n m", b=sqrt_N)
        L_tilde_base = torch.flip(torch.diag(torch.ones(sqrt_N)), dims=(1,))

        # Chebyshev matrices to cache
        Cheb_sqrt_N = repeat(ref_cheb_matrix(sqrt_N), "n m -> b n m", b=sqrt_N)
        # Only the first sqrt_N columns of C_N are needed: L_tilde is the top
        # block of an otherwise-zero N x sqrt_N matrix.
        Cheb_N = get_P_matrix(N, sqrt_N) @ ref_cheb_matrix(N)[:, :sqrt_N]

        # Per-position masks and signs, constant over each block of sqrt_N
        # positions: even-indexed blocks -> even mask / +1, odd -> odd mask / -1.
        block_parity = (torch.arange(N) // sqrt_N) % 2
        x_mask_even = (block_parity == 0).long()
        x_mask_odd = (block_parity == 1).long()
        x_mod = x_mask_even.float() - x_mask_odd.float()

        # R_tilde: sqrt_N blocks of sqrt_N x sqrt_N. Learnable entries lie on
        # the anti-diagonals row + col = sqrt_N - 1 - offset for even
        # offset >= 2; the main anti-diagonal comes from R_tilde_base.
        # Pattern for sqrt_N = 8 (base included):
        #
        #   0 x 0 x 0 x 0 x
        #   x 0 x 0 x 0 x 0
        #   0 x 0 x 0 x 0 0
        #   x 0 x 0 x 0 0 0
        #   0 x 0 x 0 0 0 0
        #   x 0 x 0 0 0 0 0
        #   0 x 0 0 0 0 0 0
        #   x 0 0 0 0 0 0 0
        R_tilde = torch.zeros(sqrt_N, sqrt_N, sqrt_N)
        R_tilde_mask = torch.zeros(sqrt_N, sqrt_N, sqrt_N)
        for i in range(sqrt_N):
            for offset in range(2, sqrt_N, 2):
                R_tilde_mask[i] += torch.diag(torch.ones(sqrt_N - offset), diagonal=-offset)
                R_tilde[i] += torch.diag(torch.randn(sqrt_N - offset), diagonal=-offset)

            R_tilde_mask[i] = torch.flip(R_tilde_mask[i], dims=(0,))
            R_tilde[i] = torch.flip(R_tilde[i], dims=(0,))

        # L_tilde: a lower triangular sqrt_N x sqrt_N tensor flipped along
        # dim 0. The mask excludes the anti-diagonal, which L_tilde_base fixes.
        L_tilde = torch.flip(torch.tril(torch.randn(sqrt_N, sqrt_N)), dims=(0,))
        L_tilde_mask = torch.flip(torch.tril(torch.ones(sqrt_N, sqrt_N), diagonal=-1), dims=(0,))

        R_tilde = R_tilde * init_scale
        L_tilde = L_tilde * init_scale

        if learnable:
            assert dft_lr > 0, "dft_lr must be positive if learnable is True"

            self.register("R_tilde", R_tilde, self.dft_lr)
            self.register("L_tilde", L_tilde, self.dft_lr)
        else:
            # Fixed zero perturbations. Registered as non-persistent buffers so
            # they move with .to()/.cuda() like the masks they are multiplied
            # with, while staying out of the state dict (as the previous plain
            # attributes did).
            self.register_buffer("R_tilde", R_tilde_mask * 0.0, persistent=False)
            self.register_buffer("L_tilde", L_tilde_mask * 0.0, persistent=False)

        self.register("R_tilde_mask", R_tilde_mask, 0.0)
        self.register("L_tilde_mask", L_tilde_mask, 0.0)
        self.register("R_tilde_base", R_tilde_base, 0.0)
        self.register("L_tilde_base", L_tilde_base, 0.0)
        self.register("Cheb_sqrt_N", Cheb_sqrt_N, 0.0)
        self.register("Cheb_N", Cheb_N, 0.0)
        self.register("x_mod", x_mod, 0.0)
        self.register("x_even_mask", x_mask_even, 0.0)
        self.register("x_odd_mask", x_mask_odd, 0.0)

    def _monarch_blocks(self):
        """Return unsliced (block_R, block_L), each of shape (sqrt_N, sqrt_N, sqrt_N)."""
        R_tilde = self.R_tilde * self.R_tilde_mask
        L_tilde = self.L_tilde * self.L_tilde_mask

        block_R = monarch_R(
            self.N,
            self.sqrt_N,
            self.R_tilde_base,
            R_tilde,
            self.Cheb_sqrt_N,
        )
        block_L = monarch_L(
            self.N,
            self.sqrt_N,
            self.L_tilde_base,
            L_tilde,
            self.Cheb_N,
        )
        return block_R, block_L

    def _mod_matrix(self, device=None, dtype=None):
        """Return the N x N sign matrix multiplied elementwise into M.

        Before the final left-right flip, block (i, j) (each sqrt_N x sqrt_N)
        is -1 where i * j is odd and +1 otherwise; after the flip, block
        (i, j) is -1 exactly when i * (sqrt_N - 1 - j) is odd.
        """
        idx = torch.arange(self.sqrt_N, device=device)
        odd = (idx[:, None] * idx[None, :]) % 2 == 1
        signs = torch.ones(self.sqrt_N, self.sqrt_N, device=device, dtype=dtype)
        signs[odd] = -1
        mod = signs.repeat_interleave(self.sqrt_N, dim=0).repeat_interleave(self.sqrt_N, dim=1)
        return torch.flip(mod, dims=(1,))

    def forward(self, x):
        """Apply the causal block transform along the last axis of ``x``.

        ``x`` is padded with zeros on the right to length N // 2 (a longer
        input is cropped, since ``F.pad`` crops for negative amounts); the
        first N // 2 positions of the length-N signal are implicit zeros.
        """
        pad = self.N // 2 - x.shape[-1]
        if pad != 0:
            x = F.pad(x, (0, pad))

        block_R, block_L = self._monarch_blocks()
        # The implicit zeros sit in front of x, so after the first butterfly
        # permutation they meet only the first sqrt_N // 2 columns of each R
        # block; drop those columns.
        block_R = block_R[..., self.sqrt_N // 2:]

        return block_fft(x, block_R, block_L, self.x_mod, self.x_even_mask, self.x_odd_mask)

    def get_M_mat(self):
        """Materialize the sign-modified N x N transform M * mod."""
        block_R, block_L = self._monarch_blocks()

        # Block-diagonal matrices built on the parameters' device/dtype.
        L_mat = torch.block_diag(*block_L)
        R_mat = torch.block_diag(*block_R)
        P = get_P_matrix(self.N, self.sqrt_N).to(device=L_mat.device, dtype=L_mat.dtype)

        M = P @ L_mat @ P @ R_mat @ P

        return M * self._mod_matrix(device=M.device, dtype=M.dtype)

    def get_conv_map(self, L, k):
        """Materialize the L x L linear map from an input u to the output y.

        u (length L) is placed at positions N // 2 .. N // 2 + L - 1 of a
        length-N vector and transformed by M_prime = get_M_mat(). The result
        is multiplied elementwise by M_prime applied to the equally padded
        kernel, mapped back with inverse_mat(N), and the first L outputs are
        kept.

        Args:
            L: input length (at most N // 2 so the padding slice fits).
            k: kernel tensor; only ``k[0][0]`` (a length-L vector) is used.
                Presumably k is (B, H, L), as in the ``__main__`` demo.

        Returns:
            (L, L) tensor.
        """
        M_prime = self.get_M_mat()
        device, dtype = M_prime.device, M_prime.dtype

        # Identity sized by the input length L.
        eye_L = torch.eye(L, device=device, dtype=dtype)
        # Embed the length-L input in the second half of a length-N vector.
        pad_mat = torch.zeros(self.N, L, device=device, dtype=dtype)
        pad_mat[self.N // 2:self.N // 2 + L, :] = eye_L
        # Keep the first L outputs.
        unpad_mat = torch.zeros(L, self.N, device=device, dtype=dtype)
        unpad_mat[:, :L] = eye_L

        inv = inverse_mat(self.N).to(device=device, dtype=dtype)
        kernel = k[0][0].to(device=device, dtype=dtype)
        k_f_diag = torch.diag(M_prime @ pad_mat @ kernel)

        conv_map = unpad_mat @ inv @ k_f_diag @ M_prime @ pad_mat
        return conv_map


if __name__ == "__main__":
    # Demo / sanity check: build a 16-point CausalMonarch (sqrt_N = 4) and
    # check that the induced convolution on length-8 inputs is causal.

    import matplotlib.pyplot as plt

    B = 32
    H = 13
    # NOTE(review): keep N at module scope -- an unqualified `N` inside
    # CausalMonarch.get_conv_map may resolve to this global; confirm before
    # renaming or moving this into a function.
    N = 8
    k = torch.randn(B, H, N)

    print(f"(B, H, N) = ({B}, {H}, {N})")

    # test causality
    cbf_learnable = CausalMonarch(16, 4, learnable=True)

    u = torch.randn(B, H, N)
    u_new = torch.randn(B, H, N)
    
    # Inverse of the column-flipped 16 x 16 Chebyshev matrix (same expression
    # as inverse_mat(16)).
    M_inv = torch.linalg.inv(torch.flip(ref_cheb_matrix(16), dims=(1,)))
    
    # Baseline: inverse transform of the elementwise product of the
    # transformed u and k, truncated to the first N outputs.
    y = torch.einsum("nm,...m->...n", M_inv, cbf_learnable(u) * cbf_learnable(k))[..., :N]

    # Overwrite u from the end, one more position per iteration (cumulative),
    # recomputing the output each time.
    new_ys = []
    for i in range(1, max(8, N + 1)):
        u[..., -i] = u_new[..., -i]
        new_ys.append(torch.einsum("nm,...m->...n", M_inv, cbf_learnable(u) * cbf_learnable(k))[..., :N])

    # After the flip (with N = 8), row r is the run in which positions
    # r .. N-1 of u were replaced.
    new_ys = torch.flip(torch.stack(new_ys), (0,))

    diff = y - new_ys

    plt.imshow(diff[:, B // 2, H // 2, :].detach().numpy(), vmin=-1, vmax=1)
    plt.colorbar()
    plt.savefig("causality.png")
    plt.show()

    # For a causal map, outputs before position r are unaffected in row r, so
    # the strictly-lower-triangular part of diff should be ~0.
    print(torch.max(torch.tril(torch.abs(diff[:, B // 2, H // 2, :]), -1)))

    plt.imshow(cbf_learnable.get_conv_map(N, k).detach().numpy(), vmin=-0.1, vmax=0.1)
    plt.colorbar()
    plt.savefig('conv_map.png')
    plt.show()

