import math

import torch
import torch.nn as nn
from einops import rearrange

import sys
import os

# Optional extra import root taken from the SAFARI_PATH environment variable
# (presumably the checkout that provides the ``src.*`` packages imported below).
SAFARI_PATH = os.environ.get("SAFARI_PATH", None)
# Only extend sys.path when the variable is actually set: sys.path entries are
# expected to be strings, so appending None would pollute the search path.
if SAFARI_PATH is not None:
    sys.path.append(SAFARI_PATH)

from src.models.sequence.hyena import HyenaFilter
from src.models.sequence.causal_monarch import CausalBlockFFT
from src.models.sequence.causal_monarch_real import CausalMonarch, inverse_mat
from src.utils.config import auto_assign_attrs

@torch.jit.script
def _mul_sum(y, q):
    """Multiply ``y`` and ``q`` elementwise (broadcasting) and sum over dim 1."""
    weighted = y * q
    return weighted.sum(dim=1)

class MonarchFilter(HyenaFilter):
    """HyenaFilter subclass whose convolution runs through a learnable Monarch transform.

    The transform is ``CausalBlockFFT`` (complex path, inverted with
    ``torch.fft.ifft``) or ``CausalMonarch`` (real path, inverted with the
    matrix returned by ``inverse_mat``), selected by ``real``.
    """

    def __init__(
        self,
        d_model,
        seq_len,
        dft_lr=1e-3,
        dft_dropout=0.0,
        monarch_L=None,
        monarch_sqrt_L=None,
        real=False,
        head_dim=1,
        **kwargs,
    ):
        """Build the filter and its Monarch transform.

        Args:
            d_model: forwarded to ``HyenaFilter`` and stored on the instance.
            seq_len: forwarded to ``HyenaFilter`` and stored on the instance.
            dft_lr: passed to the transform as ``dft_lr``.
            dft_dropout: passed to the transform as ``dropout``.
            monarch_L: passed to the transform as ``N`` (and to
                ``inverse_mat`` on the real path).
            monarch_sqrt_L: passed to the transform as ``sqrt_N``.
            real: use ``CausalMonarch`` when true, ``CausalBlockFFT`` otherwise.
            head_dim: values above 1 enable the multi-head path in ``forward``.
            **kwargs: forwarded unchanged to ``HyenaFilter.__init__``.
        """
        super().__init__(d_model=d_model, seq_len=seq_len, **kwargs)
        auto_assign_attrs(
            self,
            d_model=d_model,
            seq_len=seq_len,
        )

        self.monarch_L = monarch_L
        self.monarch_sqrt_L = monarch_sqrt_L
        self.real = real
        self.head_dim = head_dim

        # Both transforms take the same constructor arguments; only the class differs.
        transform_cls = CausalMonarch if real else CausalBlockFFT
        self.monarch = transform_cls(
            N=self.monarch_L,
            sqrt_N=self.monarch_sqrt_L,
            learnable=True,
            dft_lr=dft_lr,
            dropout=dft_dropout,
        )
        if real:
            # The real path has no FFT to invert with, so it keeps an explicit
            # inverse matrix.
            # NOTE(review): stored as a plain attribute, so it will not follow
            # module.to(device) if inverse_mat returns a tensor -- confirm.
            self.inv = inverse_mat(self.monarch_L)

    def run_fftconv(self, x, filter, L=None):
        """Multiply ``x`` and ``filter`` in the transform domain and invert.

        Args:
            x: input tensor; its last dimension is the one that gets truncated.
            filter: kernel tensor, transformed with the same ``self.monarch``.
            L: number of leading positions of the last dimension to keep.
                Defaults to ``x.shape[-1]``.

        Returns:
            The inverse-transformed product, truncated to ``L`` on the last axis.
        """
        # Previously ``L`` was read from an undefined (module-global) name,
        # which raised NameError unless the file was run as a script.
        if L is None:
            L = x.shape[-1]

        if not self.real:
            k_m = self.monarch(filter.to(torch.cfloat))

            # Inputs with more than three dims get an extra kernel axis so the
            # product below can broadcast.
            if x.ndim > 3:
                k_m = k_m.unsqueeze(-2)
            u_m = self.monarch(x.to(torch.cfloat))

            y = torch.fft.ifft(u_m * k_m, dim=-1).real[..., :L]
        else:
            k_m = self.monarch(filter)
            u_m = self.monarch(x)

            # Apply the stored inverse matrix along the last axis.
            y = torch.einsum("nm,...m->...n", self.inv, u_m * k_m)[..., :L]

        return y

    def forward(self, x, L, filter=None, bias=None, x1=None, x2=None, *args, **kwargs):
        """Convolve ``x`` with the filter and add a bias-scaled skip term.

        Args:
            x: input tensor. On the multi-head path it is rearranged as
                ``b (h d1) l``.
            L: number of leading positions of the last dimension to keep.
            filter: kernel to use; when ``None`` it is obtained from
                ``self.filter(L)``. A tuple is reduced to its first element.
            bias: skip-connection scale; when ``None``, ``self.bias`` is used.
                It is zeroed when ``self.use_bias`` is false.
            x1: second operand of the outer product on the multi-head path
                (``b (h d2) l``); required when ``head_dim > 1``.
            x2: tensor that gates the convolved outer product on the
                multi-head path; required when ``head_dim > 1``.
            *args, **kwargs: accepted and ignored.

        Returns:
            The convolution output plus ``bias * input``. On the multi-head
            path this is further multiplied by the rearranged ``x2`` and
            summed over dim 1.
        """
        if self.head_dim > 1:
            assert x1 is not None and x2 is not None, (
                "x1 and x2 are required when head_dim > 1"
            )
        if filter is None:
            filter = self.filter(L)

        # Ensure compatibility with filters that return a tuple
        if isinstance(filter, tuple):
            filter = filter[0]

        if bias is None:
            bias = self.bias

        bias = bias if self.use_bias else 0 * bias
        bias = bias.squeeze()

        if self.head_dim == 1:
            y = self.run_fftconv(x, filter, L) + bias[..., None, None] * x[..., :L]
        else:
            # Per-head outer product of x and x1: b d1 d2 h l
            kv = (
                rearrange(x, 'b (h d1) l -> b d1 1 h l', d1=self.head_dim)
                * rearrange(x1, 'b (h d2) l -> b 1 d2 h l', d2=self.head_dim)
            )
            y = self.run_fftconv(kv, filter, L) + bias[..., None, None, None] * kv[..., :L]
            # NOTE(review): this previously rearranged x1 a second time and
            # left the required x2 unused; x2 is taken to be the intended
            # gating tensor -- confirm against the calling operator.
            q = rearrange(x2, 'b (h d1) l -> b d1 1 h l', d1=self.head_dim)

            # Contract over d1 (dim 1).
            y = _mul_sum(y, q)

        return y


if __name__ == "__main__":
    # Smoke test: build the filter with and without heads and print the
    # output shapes. Keyword names must match MonarchFilter.__init__ /
    # forward (monarch_L, monarch_sqrt_L, filter); unknown keywords are
    # swallowed by **kwargs and would silently leave the defaults in place.
    L = 128
    bs = 1
    D = 16

    layer = MonarchFilter(D, L, monarch_L=256, monarch_sqrt_L=16)
    x = torch.randn(bs, 1, D, 1, L)
    y = layer(x, L, filter=layer.filter(L).permute(0, 2, 1))
    print(y.shape)

    layer_heads = MonarchFilter(D, L, monarch_L=256, monarch_sqrt_L=16, head_dim=4)
    x = torch.randn(bs, D, L)
    x1 = torch.randn(bs, D, L)
    x2 = torch.randn(bs, D, L)

    y = layer_heads(
        x, L, x1=x1, x2=x2, filter=layer_heads.filter(L).permute(0, 2, 1)
    )
    print(y.shape)