"""
Causal Monarch, complex numbers version.

Parameterized by R-tilde, and L-tilde:

R_tilde is a sqrt_N x sqrt_N x sqrt_N tensor (sqrt_N blocks), where each sqrt_N x sqrt_N block looks like this (example is for sqrt_N = 8):

x 0 0 0 0 0 0 0
x x 0 0 0 0 0 0
x x x 0 0 0 0 0
x x x x 0 0 0 0
0 0 0 0 x 0 0 0
0 0 0 0 x x 0 0
0 0 0 0 x x x 0
0 0 0 0 x x x x

The x's can be learned.

L_tilde is conceptually an N x sqrt_N matrix.

The *top* sqrt_N x sqrt_N block is lower triangular, and the rest are zeros, so only the top block is stored (example is for sqrt_N = 8):

x 0 0 0 0 0 0 0
x x 0 0 0 0 0 0
x x x 0 0 0 0 0
x x x x 0 0 0 0
x x x x x 0 0 0
x x x x x x 0 0
x x x x x x x 0
x x x x x x x x

These are converted into block-diagonal R and L matrices, which are used to replace the FFT:

output = P @ L @ P @ R @ P @ input

To get R from R_tilde, you just multiply each block by the sqrt_N DFT matrix.
To get L from L_tilde, you multiply P @ DFT_N @ L_tilde.

P is the butterfly permutation.
"""

import torch
from einops import rearrange, repeat, einsum
import math
from torch import nn
from torch.nn import functional as F
from src.models.nn import Activation
from src.utils.train import OptimModule


def ref_dft_matrix(N):
    """Compute the DFT matrix of size N x N.

    Entry (k, n) is exp(-2j * pi * k * n / N).

    The product k * n is reduced modulo N in exact integer arithmetic before
    the phase is formed. exp(-2j * pi * m / N) has period N in m, so the
    matrix is mathematically unchanged, but the phase stays within a single
    period instead of being rounded at magnitudes of order N (which costs
    several digits in single precision for N in the hundreds).

    This is where we could add extra compute for free.

    Args:
        N: size of the transform.

    Returns:
        An (N, N) complex tensor created on the default device, in the
        complex dtype that matches torch's default floating dtype.
    """
    n = torch.arange(N)
    # n = torch.arange(N).cuda()
    k = n.view(-1, 1)

    # Integer outer product via broadcasting: kn[i, j] = i * j (exact).
    kn = n * k
    if N > 0:
        # Wrap into one period; skipped for N == 0 to avoid a modulo by zero
        # on the (empty) result.
        kn = kn % N

    return torch.exp(-2j * torch.pi * kn / N)


def get_P_matrix(N, sqrt_N):
    """Return the N x N permutation matrix P as a complex (cfloat) tensor.

    The index vector 0..N-1 is reordered with the einops pattern
    "(m n) -> (n m)" (m = sqrt_N), and the rows of the identity matrix are
    gathered in that order, so ``P @ v`` reorders ``v`` the same way the
    pattern reorders the indices.
    """
    order = rearrange(torch.arange(N), "(m n) -> (n m)", m=sqrt_N)
    identity = torch.diag(torch.ones(N, dtype=torch.cfloat))
    return identity[order]


def monarch_R(N, sqrt_N, R_tilde_base, R_tilde, ref_mat_block):
    """Convert R_tilde into the block-diagonal R factor.

    R_tilde is reshaped to the shape of R_tilde_base, the base is added, and
    every block is left-multiplied by the matching block of ref_mat_block
    (a batched matmul). N and sqrt_N are accepted but not used here.

    Returns:
        Tensor of batched blocks, ``ref_mat_block[b] @ (R_tilde[b] + R_tilde_base[b])``.
    """
    offsets = R_tilde.reshape(R_tilde_base.shape)
    blocks = offsets + R_tilde_base
    return torch.bmm(ref_mat_block, blocks)


def monarch_L(N, sqrt_N, L_tilde_base, L_tilde, permuted_ref_mat):
    """Convert L_tilde into the block-diagonal L factor.

    The base is added to L_tilde, the sum is left-multiplied by
    permuted_ref_mat, and the product is reshaped into
    (sqrt_N, sqrt_N, sqrt_N) blocks. N is accepted but not used here.
    """
    combined = L_tilde + L_tilde_base
    product = permuted_ref_mat @ combined
    return product.reshape(sqrt_N, sqrt_N, sqrt_N)

def apply_permutation(x, sqrt_N):
    """Permute the last axis of x: view it as (m, n) with n = sqrt_N, swap
    the two factors, and flatten again. Leading axes are left untouched."""
    pattern = '... (m n) -> ... (n m)'
    return rearrange(x, pattern, n=sqrt_N)

def block_fft(x, block_R, block_L):
    """Apply permute -> R blocks -> permute -> L blocks -> permute to x.

    sqrt_N is read from ``block_R.shape[0]``. The last axis of x is viewed as
    (sqrt_N, sqrt_N // 2) for the first block matmul, so it must hold
    sqrt_N * (sqrt_N // 2) entries; the columns of R that would multiply the
    zero padding are expected to have been dropped by the caller. The result
    has sqrt_N ** 2 entries on its last axis.
    """
    sqrt_N = block_R.shape[0]
    half = sqrt_N // 2
    contraction = "bnm,...bm->...bn"

    # First permutation, then one R block per group of `half` inputs.
    stage = apply_permutation(x, sqrt_N)
    lead = stage.shape[:-1]
    r_in = stage.view(*lead, sqrt_N, half)
    r_out = torch.einsum(contraction, block_R, r_in)
    stage = r_out.reshape(*lead, sqrt_N ** 2)

    # Second permutation, then the full-size L blocks.
    stage = apply_permutation(stage, sqrt_N)
    l_in = stage.view(*stage.shape[:-1], sqrt_N, sqrt_N)
    l_out = torch.einsum(contraction, block_L, l_in)
    stage = l_out.reshape(stage.shape)

    # Final permutation.
    return apply_permutation(stage, sqrt_N)


class CausalBlockFFT(OptimModule):
    """
    Learnable Causal Block FFT module.

    The input's last axis is right-padded with zeros to N // 2 samples and
    mapped to N complex outputs by ``block_fft`` using two factors:

        R = DFT_sqrt_N          @ (I + R_tilde * R_tilde_mask)   (per block)
        L = (P @ DFT_N[:, :sqrt_N]) @ (I + L_tilde * L_tilde_mask)

    The masks are *strictly* lower triangular, so the diagonal always comes
    from the identity base and only the entries below it can be learned.
    With ``learnable=False`` the offsets are all zero; the ``__main__`` check
    in this file compares that configuration against ``torch.fft.fft(x, n=N)``.

    Every complex tensor is stored as a real tensor via ``torch.view_as_real``
    and turned back into a complex one in ``forward``.

    Args:
        N: transform size; must equal ``sqrt_N ** 2``.
        sqrt_N: block size.
        learnable: if True, R_tilde / L_tilde are registered with learning
            rate ``dft_lr``; otherwise they are fixed zeros.
        dft_lr: learning rate passed to ``register`` for the offsets.
        dropout: dropout probability applied once to the random initial
            offsets (not applied in ``forward``).
        init_scale: multiplier applied to the random initial offsets.
    """

    def __init__(self, N, sqrt_N, learnable=True, dft_lr=0.001, dropout=0.0, init_scale=0.02):
        super().__init__()
        self.N = N
        self.sqrt_N = sqrt_N
        self.learnable = learnable
        self.dft_lr = dft_lr
        self.dropout = dropout

        assert sqrt_N**2 == N, "N must equal sqrt_N ** 2"

        # Identity bases that the (masked) offsets are added to.
        R_tilde_base = torch.view_as_real(
            repeat(torch.diag(torch.ones(sqrt_N)), "n m -> b n m", b=sqrt_N).to(
                torch.cfloat
            )
        )
        L_tilde_base = torch.view_as_real(
            torch.diag(torch.ones(sqrt_N)).to(torch.cfloat)
        )

        # Fixed DFT factors: one sqrt_N-point DFT per R block, and the first
        # sqrt_N columns of the permuted N-point DFT for L (the remaining
        # columns would only multiply the zero rows of L_tilde).
        FFT_sqrt_N = torch.view_as_real(
            repeat(ref_dft_matrix(sqrt_N), "n m -> b n m", b=sqrt_N)
        )
        FFT_N = torch.view_as_real(get_P_matrix(N, sqrt_N) @ ref_dft_matrix(N)[:, :sqrt_N])

        # R_tilde is a sqrt_N x sqrt_N x sqrt_N tensor, where each
        # sqrt_N x sqrt_N block looks like this (example is for sqrt_N = 8):
        #
        # x 0 0 0 0 0 0 0
        # x x 0 0 0 0 0 0
        # x x x 0 0 0 0 0
        # x x x x 0 0 0 0
        # 0 0 0 0 x 0 0 0
        # 0 0 0 0 x x 0 0
        # 0 0 0 0 x x x 0
        # 0 0 0 0 x x x x
        #
        # The x's can be learned. The mask below excludes the diagonal, so the
        # diagonal x's of the random init never reach the output.
        half = sqrt_N // 2
        strict_lower = torch.tril(torch.ones(half, half), diagonal=-1)

        R_tilde = torch.zeros(sqrt_N, sqrt_N, sqrt_N)
        R_tilde_mask = torch.zeros(sqrt_N, sqrt_N, sqrt_N)
        # The mask is identical for every block, so fill all blocks at once.
        R_tilde_mask[:, :half, :half] = strict_lower
        R_tilde_mask[:, half:, half:] = strict_lower
        # Random draws stay in a per-block loop so the order in which the
        # random generator is consumed does not change.
        for i in range(sqrt_N):
            R_tilde[i, :half, :half] = torch.tril(torch.randn(half, half))
            R_tilde[i, half:, half:] = torch.tril(torch.randn(half, half))

        # L_tilde is a lower triangular sqrt_N x sqrt_N tensor.
        L_tilde = torch.tril(torch.randn(sqrt_N, sqrt_N))
        L_tilde_mask = torch.tril(torch.ones(sqrt_N, sqrt_N), diagonal=-1)

        R_tilde = torch.view_as_real(F.dropout(R_tilde, p=self.dropout).to(torch.cfloat)) * init_scale
        L_tilde = torch.view_as_real(F.dropout(L_tilde, p=self.dropout).to(torch.cfloat)) * init_scale
        R_tilde_mask = torch.view_as_real(R_tilde_mask.to(torch.cfloat))
        L_tilde_mask = torch.view_as_real(L_tilde_mask.to(torch.cfloat))

        if learnable:
            assert dft_lr > 0, "dft_lr must be positive if learnable is True"

            self.register("R_tilde", R_tilde, self.dft_lr)
            self.register("L_tilde", L_tilde, self.dft_lr)
        else:
            # Fixed zero offsets. Plain tensor attributes are not moved by
            # .to()/.cuda(), which would leave them on a different device from
            # the registered masks they are multiplied with in forward().
            # Non-persistent buffers move with the module and add no
            # state_dict keys.
            # NOTE(review): assumes OptimModule derives from torch.nn.Module
            # (not visible in this file) - confirm.
            self.register_buffer("R_tilde", torch.zeros_like(R_tilde_mask), persistent=False)
            self.register_buffer("L_tilde", torch.zeros_like(L_tilde_mask), persistent=False)

        self.register("R_tilde_mask", R_tilde_mask, 0.0)
        self.register("L_tilde_mask", L_tilde_mask, 0.0)
        self.register("R_tilde_base", R_tilde_base, 0.0)
        self.register("L_tilde_base", L_tilde_base, 0.0)
        self.register("FFT_sqrt_N", FFT_sqrt_N, 0.0)
        self.register("FFT_N", FFT_N, 0.0)

    def forward(self, x):
        """Transform the last axis of x and return N outputs per sequence.

        Args:
            x: tensor whose last axis holds the signal. It is right-padded
                with zeros to N // 2 samples; the length is not validated
                beyond that.

        Returns:
            The result of ``block_fft`` with the current R and L factors.
        """
        # only pad up to N / 2, the rest is implicitly zeros
        half_len = self.N // 2
        if x.shape[-1] != half_len:
            x = F.pad(x, (0, half_len - x.shape[-1]), "constant", 0)

        R_tilde = torch.view_as_complex(self.R_tilde * self.R_tilde_mask)
        L_tilde = torch.view_as_complex(self.L_tilde * self.L_tilde_mask)

        # ignore the zeros: only the first sqrt_N // 2 columns of each R block
        # meet non-padding input, so the remaining columns are dropped.
        block_R = monarch_R(
            self.N,
            self.sqrt_N,
            torch.view_as_complex(self.R_tilde_base),
            R_tilde,
            torch.view_as_complex(self.FFT_sqrt_N),
        )[..., :self.sqrt_N // 2]

        block_L = monarch_L(
            self.N,
            self.sqrt_N,
            torch.view_as_complex(self.L_tilde_base),
            L_tilde,
            torch.view_as_complex(self.FFT_N),
        )

        return block_fft(x, block_R, block_L)


if __name__ == "__main__":
    # Self-check script: (1) the non-learnable module should match
    # torch.fft.fft on the zero-padded input, (2) the learnable module should
    # stay causal when used for FFT convolution. Writes "causality.png".

    import matplotlib.pyplot as plt

    B = 32
    H = 13
    N = 128
    # The module transforms the length-2N zero-padded signal.
    fft_len = 2 * N
    # Block size; the module asserts sqrt_len ** 2 == fft_len.
    sqrt_len = 16

    k = torch.randn(B, H, N).to(torch.cfloat)

    print(f"(B, H, N) = ({B}, {H}, {N})")

    # test FFT
    cbf = CausalBlockFFT(fft_len, sqrt_len, learnable=False)
    k_f = cbf(k)

    k_f_ref = torch.fft.fft(k, n=fft_len, dim=-1)

    print(torch.max(torch.abs(k_f - k_f_ref)))

    # test causality
    cbf_learnable = CausalBlockFFT(fft_len, sqrt_len, learnable=True)

    u = torch.randn(B, H, N).to(torch.cfloat)
    u_new = torch.randn(B, H, N).to(torch.cfloat)

    y = torch.fft.ifft(cbf_learnable(u) * cbf_learnable(k)).real[..., :N]

    # Overwrite the input one position at a time, starting from the end, and
    # recompute the output after every change. The loop stops at N so that
    # u[..., -i] never indexes past the start of the sequence.
    new_ys = []
    for i in range(1, N + 1):
        u[..., -i] = u_new[..., -i]
        new_ys.append(torch.fft.ifft(cbf_learnable(u) * cbf_learnable(k)).real[..., :N])

    # After the flip, row r holds the output with positions r..N-1 replaced.
    new_ys = torch.flip(torch.stack(new_ys), (0,))

    diff = y - new_ys

    plt.imshow(diff[:, B // 2, H // 2, :].detach().numpy())
    plt.colorbar()
    plt.savefig("causality.png")

    # Below the diagonal (output position earlier than the first changed
    # input) the difference should vanish if the transform is causal.
    print(torch.max(torch.tril(torch.abs(diff[:, B // 2, H // 2, :]), -1)))
