import math

import torch
import torch.nn.functional as F
from torch.autograd import grad
import pytest
import os
import sys

# Optional extra import root, read from the SAFARI_PATH environment variable.
# NOTE(review): presumably this points at the checkout that contains the `src`
# package imported below -- confirm against the project setup.
SAFARI_PATH = os.environ.get("SAFARI_PATH", None)
# Only extend the import path when the variable is actually set; appending
# None would leave a non-string entry on sys.path.
if SAFARI_PATH is not None:
    sys.path.append(SAFARI_PATH)

from src.models.sequence.monarch_conv import MonarchFilter


def test_monarch_conv_causality():
    """Check that a MonarchFilter output never depends on later input positions.

    For every output position ``i`` of the first (batch, head, channel, block)
    slice, the gradient with respect to the input is computed and its entries
    at positions ``i + 1`` onward must be below ``1e-4`` in absolute value.
    """
    # Fixed seed so the random filter parameters and input are reproducible.
    torch.random.manual_seed(0)

    seq_len, batch_size, dim = 128, 2, 4

    conv = MonarchFilter(
        dim, seq_len, monarch_L=256, monarch_sqrt_L=16, learnable=True, real=True
    )
    # Input layout: batch size, num heads, head_dim, num_blocks, block_dim.
    inputs = torch.randn(batch_size, 1, dim, 1, seq_len, requires_grad=True)
    kernel = conv.filter(seq_len).permute(0, 2, 1)
    outputs = conv(inputs, seq_len, k=kernel)

    # The last position has no "future", so it is skipped (the slice would be empty).
    for pos in range(seq_len - 1):
        # retain_graph keeps the autograd graph alive for the next position.
        (input_grad,) = grad(
            outputs[0, 0, 0, 0, pos], inputs, retain_graph=True, allow_unused=True
        )
        future_grad = input_grad[0, 0, 0, 0, pos + 1 :]
        # Gradients flowing from future inputs must be approximately zero.
        assert future_grad.abs().max() < 1e-4, "function is not causal"
