"""Implementation of LaughingHyena mixers. 

The main idea is to distill (post-training) Hyena filters into a particular form of transfer function,
which admits a recurrence and enables fast generation."""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
from opt_einsum import contract

import src.utils.registry as registry
from src.utils.train import OptimModule
from src.utils.config import instantiate, auto_assign_attrs
from src.models.nn import Activation

try:
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    FusedDense = None

from src.models.sequence.hyena import mul_sum


class LaughingHyenaFilter(nn.Module):
    """Bank of rational transfer functions ``H(z) = b(z) / a(z)`` used to distill Hyena filters.

    Every one of the ``num_filters`` channels owns ``heads`` rational filters whose
    frequency responses are summed with the weights ``heads_mixer``.

    Args:
        num_order: degree of the numerator; ``num_order + 1`` coefficients are learned.
        den_order: number of learned denominator coefficients (a leading 1 is prepended
            in ``_eval``, giving the canonical form).
        num_filters: number of independent channels.
        heads: number of rational filters mixed per channel.
        decay_rate: stored as a parameter; within this class it is only read to
            determine the module's device.
        real_fft: if True, use ``rfft``/``irfft`` (half spectrum) instead of ``fft``/``ifft``.
        train_mixer: if True the head-mixing weights are a trainable random parameter,
            otherwise fixed ones.
    """

    def __init__(
        self,
        num_order: int,
        den_order: int,
        num_filters: int = 1,
        heads: int = 1,
        decay_rate: float = 1e-2,
        real_fft: bool = False,
        train_mixer: bool = False,
    ):
        super().__init__()
        self.num_order = num_order + 1
        self.den_order = den_order
        self.num_filters = num_filters
        self.heads = heads

        if train_mixer:
            self.heads_mixer = torch.nn.Parameter(torch.randn(heads))
        else:
            # A buffer (rather than a plain tensor attribute) follows the module on
            # `.to(device)` / `.double()`. It is non-persistent so state_dict keys are
            # the same as before.
            self.register_buffer("heads_mixer", torch.ones(heads), persistent=False)
        self.decay_rate = torch.nn.Parameter(torch.Tensor([decay_rate]))
        self.fft = torch.fft.rfft if real_fft else torch.fft.fft
        self.norm_factors = nn.Parameter(torch.randn(den_order))
        self.eps = nn.Parameter(torch.Tensor([1e-2]))
        self.num = nn.Parameter(torch.randn(self.num_order, num_filters, heads))
        self.den = nn.Parameter(torch.randn(self.den_order, num_filters, heads))

    def filter(self, L, *args, **kwargs):
        """Return the length-``L`` impulse response, shaped (1, L, num_filters)."""
        H = self(L)[None]
        if self.fft is torch.fft.rfft:
            # Without an explicit `n`, irfft returns 2 * (m - 1) samples for m input
            # bins, i.e. L - 1 samples when L is odd.
            return torch.fft.irfft(H, n=L, dim=1)
        return torch.fft.ifft(H, dim=1).real

    def forward(self, L):
        """Evaluate the head-mixed frequency response on an ``L``-point FFT grid.

        Returns:
            Complex tensor of shape (L, num_filters), or (L // 2 + 1, num_filters)
            when ``real_fft`` is set.
        """
        a, b = self._eval()
        P = self.fft(a, dim=0, n=L)  # denominator spectrum
        Q = self.fft(b, dim=0, n=L)  # numerator spectrum
        H = Q / P
        H = torch.sum(H * self.heads_mixer, dim=-1)  # mix the heads
        return H

    def get_params(self):
        """Return ``(a, b, w)``: denominator (den_order + 1, num_filters, heads),
        numerator (num_order + 1, num_filters, heads) and head-mixing weights (heads,)."""
        a, b = self._eval()
        w = self.heads_mixer
        return a, b, w

    def _eval(self):
        """Build the canonical-form coefficients ``(den, num)``, with ``den[0] == 1``."""
        self.device = self.decay_rate.device
        num, den = self.num, self.den
        # After this scaling the sum of |den[k]| over the learned coefficients is
        # at most 0.99. So |a(z)| >= 0.01 on the unit circle, which keeps Q / P finite.
        l1_norm = torch.sum(torch.abs(den), dim=0, keepdim=True) + F.relu(self.eps)
        norm_factors = torch.clamp(self.norm_factors, -0.99, 0.99)
        den = den / l1_norm * norm_factors[:, None, None]
        # Leading coefficient 1 (canonical form); matches den's device and dtype.
        leading_one = den.new_ones(1, self.num_filters, self.heads)
        den = torch.cat([leading_one, den], dim=0)
        return den, num


def polyroots(p, return_companion=False):
    """
    Return the roots of each polynomial (one per row) with coefficients given in p.

    Coefficients are ordered from highest to lowest degree, as in numpy.roots:
    ``p[h, 0] * x**(order - 1) + ... + p[h, -1]``. Unlike numpy.roots, leading and
    trailing zero coefficients are not stripped. Because every row is divided by
    ``p[:, 0]``, that coefficient must be non-zero.
    More info at: https://numpy.org/doc/stable/reference/generated/numpy.roots.html

    Args:
        p (Tensor): (heads, order) Coefficients of the polynomials (real or complex).
        return_companion (bool): If True, the companion matrix is returned as well.

    Returns:
        Tensor: (heads, order - 1) complex roots of the polynomials.
        Tensor (only if return_companion): (heads, order - 1, order - 1) companion matrices.

    Raises:
        ValueError: if p is not 2-D or has fewer than 2 coefficients per row.
    """
    if p.dim() != 2:
        raise ValueError("The polynomial must be a second order tensor (heads, order).")
    if p.shape[-1] < 2:
        raise ValueError("A polynomial must have at least 2 coefficients.")
    # Promote integer/bool inputs to floating point. Complex inputs are kept as-is;
    # casting them to a real dtype would silently drop the imaginary parts.
    if not (torch.is_floating_point(p) or torch.is_complex(p)):
        p = p.to(torch.get_default_dtype())

    # Companion matrix: the first row holds -p[1:] / p[0], with ones on the
    # subdiagonal. Its eigenvalues are the roots of the polynomial.
    # More info at: https://en.wikipedia.org/wiki/Companion_matrix
    heads, order = p.shape
    c = torch.zeros((heads, order - 1, order - 1), dtype=p.dtype, device=p.device)
    # Fill the subdiagonal through a view (in place), before c depends on p.
    c.diagonal(offset=-1, dim1=-2, dim2=-1).fill_(1)
    c[:, 0] = -p[:, 1:] / p[:, :1]

    roots = torch.linalg.eigvals(c)
    if return_companion:
        return roots, c
    return roots


def get_direct_path(a, b):
    """Split ``b(z) / a(z)`` into its delay-free term and a strictly causal remainder.

    With ``a[0] == 1`` (canonical form) the transfer function can be rewritten as
    ``b[0] + (b - b[0] * a)(z) / a(z)``. The remainder numerator loses its
    zeroth coefficient.

    Args:
        a: denominator coefficients along dim 0 (``a[0]`` is not read).
        b: numerator coefficients along dim 0; ``b[1:]`` must broadcast with ``a[1:]``.

    Returns:
        ``(b0, beta)`` with ``b0 = b[0]`` and ``beta = b[1:] - b0 * a[1:]``.
    """
    direct = b[0]
    remainder = b[1:] - direct * a[1:]
    return direct, remainder


def step_iir(x, u, a, beta, b0, w):
    """Advance the state of a bank of IIR filters by one time step.

    The filters are in canonical form and are evaluated as the direct path ``b0``
    plus the remainder ``beta`` (see ``get_direct_path``). The h filters of each
    channel all receive the same input, and their outputs are mixed with ``w``.

    Args:
        x: (b, d, n, h) past internal states; ``x[:, :, k]`` is the state k + 1 steps back.
        u: (b, *, d) input at the current step; only ``u[:, 0]`` is read.
        a: (n+1, d, h) denominator coefficients; ``a[0]`` is not read (assumed 1).
        beta: (n, d, h) remainder coefficients, ``b[1:] - b0 * a[1:]``.
        b0: (d, h) delay-free (direct path) coefficients.
        w: (h,) weights mixing the h filters of each channel.

    Returns:
        y: (b, 1, d) output at the current step.
        x: (b, d, n, h) updated states. This is a new tensor; the input ``x`` is
            not modified.
    """
    u = u[:, 0]
    a_ = a[1:]

    # TODO add support for u heads
    # y_t = b0 * u_t + sum_k beta_k * s_{t-k}, evaluated with the states *before* the update
    y = contract("n d h, b d n h -> b d h", beta, x) + contract(
        "d h, b d -> b d h", b0, u
    )
    y = contract("b d h, h -> b d", y, w)

    # new internal state s_t = u_t - sum_k a_k * s_{t-k}
    low_rank = contract("n d h, b d n h -> b d h", a_, x)
    # shift the history one slot back (roll returns a copy) and store s_t in slot 0
    x = torch.roll(x, 1, dims=2)
    x[..., 0, :] = u[..., None] - low_rank
    return y[:, None], x


def step_fir(x, u, b):
    """Advance a bank of per-channel FIR filters by a single time step.

    Args:
        x: (b, d, n) Tensor of past inputs; ``x[..., k]`` is the input k + 1 steps back.
        u: (b, d) Tensor containing the current input.
        b: (n+1, d) Tensor of taps; ``b[0]`` multiplies the current input.

    Returns:
        A pair ``(y, x_next)``: the (b, d) output and the shifted (b, d, n) history,
        with ``u`` stored in slot 0. ``x`` itself is left untouched because
        ``torch.roll`` returns a copy.
    """
    head_tap, past_taps = b[0], b[1:]
    history_term = contract("n d, b d n -> b d", past_taps, x)
    current_term = contract("d, b d -> b d", head_tap, u)
    y = history_term + current_term

    x_next = torch.roll(x, 1, dims=2)
    x_next[..., 0] = u
    return y, x_next


class LaughingHyenaOperator(nn.Module):
    def __init__(
        self,
        d_model,
        l_max,
        order=2,
        filter_order=64,
        num_heads=1,
        inner_factor=1,
        num_blocks=1,
        fused_bias_fc=False,
        outer_mixing=False,
        dropout=0.0,
        filter_dropout=0.0,
        filter_cls="hyena-filter",
        post_order_ffn=False,
        jit_filter=False,
        short_filter_order=3,
        activation="id",
        return_state=False,
        **filter_args,
    ):
        r"""
        Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf,
        extended with a single-step recurrent mode for LaughingHyena-distilled filters.

        ``forward`` runs the convolutional path (``forward_conv``) until
        ``setup_recurrence`` is called. From then on it runs ``forward_recurrent``,
        which advances one step at a time.

        Args:
            d_model (int): Dimension of the input and output embeddings (width of the layer)
            l_max (int): Maximum input sequence length. Required.
            order (int): Depth of the Hyena recurrence. Defaults to 2
            filter_order (int): Width of the FFN parametrizing the implicit filter. Defaults to 64
            num_heads (int): Number of heads. Defaults to 1
            inner_factor (int): Width multiplier. Defaults to 1
            num_blocks (int): Number of blocks in sequence length. Defaults to 1
            fused_bias_fc (bool): Whether to use fused bias FC. Defaults to False
            outer_mixing (bool): If True, combine value and gate with an outer product over
                the channel dimension (summed over the gate axis) instead of an
                elementwise product. Defaults to False
            dropout (float): Dropout probability. Defaults to 0.0
            filter_dropout (float): Dropout probability for the filter. Defaults to 0.0
            filter_cls (str): Registry name of the implicit filter class. Defaults to "hyena-filter"
            post_order_ffn (bool): Apply a dense layer between steps of the recurrence. Defaults to False
            jit_filter (bool): Whether to JIT the implicit filter function. Defaults to False
            short_filter_order (int): Length of the explicit input convolutional filter. Defaults to 3
            activation (str): type of act between kernel output and FF (default identity)
            return_state (bool): whether to also return a state (currently always ``None``)
            **filter_args: Extra keyword arguments forwarded to the filter class.
        """
        super().__init__()
        assert (
            d_model % num_heads == 0
        ), f"Model dimension {d_model} must be divisible by num heads {num_heads}"
        assert (
            l_max % num_blocks == 0
        ), f"Maximum signal length {l_max} must be divisible by block dimension {num_blocks}"
        block_dim = l_max // num_blocks
        head_dim = d_model // num_heads

        auto_assign_attrs(
            self,
            d_model=d_model,
            order=order,
            l_max=l_max,
            num_heads=num_heads,
            inner_factor=inner_factor,
            block_dim=block_dim,
            head_dim=head_dim,
            filter_order=filter_order,
            post_order_ffn=post_order_ffn,
            short_filter_order=short_filter_order,
            num_blocks=num_blocks,
            filter_dropout=filter_dropout,
            jit_filter=jit_filter,
            outer_mixing=outer_mixing,
            activation=activation,
            return_state=return_state,
        )
        self.activation = Activation(activation)
        self.dropout = nn.Dropout(dropout)
        self.setup_projections(fused_bias_fc, inner_factor)
        self.setup_filters(filter_cls, filter_args)

    def setup_projections(self, fused_bias_fc, inner_factor):
        "Initializes input and output projections (over the width dimension)"
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        self.out_proj = linear_cls(self.d_model * inner_factor, self.d_model)
        self.in_proj = linear_cls(self.d_model, (self.order + 1) * self.d_model)
        if self.post_order_ffn:
            self.ord_proj_w = nn.Parameter(
                torch.randn(self.order, self.num_heads, self.num_heads)
                / math.sqrt(self.head_dim)
            )

    def setup_filters(self, filter_cls, filter_args):
        "Initializes the explicit (short, depthwise) and implicit (long) filters"
        assert self.order >= 2, f"Order must be at least 2, (got {self.order})"
        total_width = self.d_model * self.inner_factor * (self.order + 1)

        self.short_filter = nn.Conv1d(
            in_channels=total_width,
            out_channels=total_width,
            kernel_size=self.short_filter_order,
            groups=total_width,
            padding=self.short_filter_order - 1,
        )

        filter_cls = instantiate(registry.layer, filter_cls, partial=True)
        self.k = None

        self.filter_fn = filter_cls(
            self.head_dim * self.inner_factor * (self.order - 1),
            order=self.filter_order,
            seq_len=self.l_max,
            channels=1,
            dropout=self.filter_dropout,
            **filter_args,
        )
        if self.jit_filter:
            # `self.L` does not exist on this module; torch.jit.script only needs the module.
            self.filter_fn = torch.jit.script(self.filter_fn)

    def _reset_recurrence(self):
        "Drops the recurrent states and flag so that `forward` falls back to the conv path."
        if hasattr(self, "state"):
            del self.state
        if hasattr(self, "state_short"):
            del self.state_short
        if hasattr(self, "recurrent_mode"):
            del self.recurrent_mode

    def setup_recurrence(self, a, b, w, state=None, state_short=None, bsz=None):
        """Switch the operator to single-step recurrent mode.

        Args:
            a: (d_state + 1, D, H) denominator coefficients, presumably in canonical
                form (``a[0] == 1``) as produced by ``LaughingHyenaFilter.get_params``.
            b: numerator coefficients; ``b[1:]`` must broadcast against ``a[1:]``.
            w: (H,) weights mixing the H rational filters of each channel.
            state: optional initial IIR state, (bsz, D, d_state, H).
            state_short: optional initial short-conv state,
                (bsz, channels, short_filter_order - 1).
            bsz: batch size used to zero-initialize missing states. Exactly one of
                ``bsz`` and ``state`` must be given.
        """
        self._reset_recurrence()
        self.recurrent_mode = True
        b0, beta = get_direct_path(a, b)
        self.register_buffer("b0", b0)
        self.register_buffer("beta", beta)

        # The short conv is depthwise, so its weight is (channels, 1, kernel).
        # Flipping the taps makes b_short[0] multiply the current input and b_short[k]
        # the input k steps back, as step_fir expects. Result: (kernel, channels).
        b_short = self.short_filter.weight[:, 0]
        b_short = b_short.flip(1).permute(1, 0)
        self.register_buffer("b_short", b_short)

        bias_short = self.short_filter.bias[None]
        self.register_buffer("bias_short", bias_short)

        self.register_buffer("a", a)  # d_state + 1, D, H

        state_dim = a.shape
        assert (bsz is None) ^ (state is None), "Either bsz or state must be None"
        if bsz is None:
            # infer the batch size from the provided IIR state
            bsz = state.shape[0]

        if state_short is None:
            # One slot per past input tap (kernel - 1). The zeros are created on the
            # same device and dtype as the taps they are contracted with.
            state_short = torch.zeros(
                bsz,
                b_short.shape[1],
                b_short.shape[0] - 1,
                device=b_short.device,
                dtype=b_short.dtype,
            )
        self.register_buffer("state_short", state_short)
        if state is None:
            state = torch.zeros(
                bsz,
                state_dim[1],
                state_dim[0] - 1,
                state_dim[2],
                device=a.device,
                dtype=a.dtype,
            )
        self.register_buffer("state", state)
        out_head_mix = w
        self.register_buffer("out_head_mix", out_head_mix)

    def forward_recurrent(self, u, *args, **kwargs):
        """Advance the operator by a single time step using the distilled recurrence.

        Args:
            u: (b, *, d_model) input; only ``u[:, 0]`` (one step) is used.

        Returns:
            (b, 1, d_model) output, or ``(output, None)`` if ``return_state``.
        """
        # NOTE(review): a single IIR state/filter (self.state, self.a, ...) is shared by
        # every iteration of the order loop below, so this path is only exact for
        # order == 2. step_iir also reads only head 0 of its input (num_heads == 1).
        u = self.in_proj(u)
        u = u[:, 0]  # b d

        uc, self.state_short = step_fir(self.state_short, u, self.b_short)
        uc = uc + self.bias_short

        uc = rearrange(
            uc,
            "b (ho v) -> b ho v",
            ho=self.num_heads,
            v=self.head_dim * (self.order + 1),
        )

        # order + 1 chunks of head_dim channels: gates x[0..order-1] and value v
        *x, v = uc.split(self.head_dim, dim=2)
        bias = rearrange(
            self.filter_fn.bias, "(v o) -> o v", v=self.head_dim, o=self.order - 1
        )

        for o, x_i in enumerate(reversed(x[1:])):
            if self.outer_mixing:
                # same mixing as forward_conv, without the (z, l) axes
                v = rearrange(v, "b h v -> b h 1 v")
                v = self.dropout(v * rearrange(x_i, "b h v -> b h v 1"))
                v = v.sum(dim=2)
            else:
                v = self.dropout(v * x_i)

            v_pre = v
            v, self.state = step_iir(
                self.state, v, self.a, self.beta, self.b0, self.out_head_mix
            )
            # skip connection, mirroring the `bias` passed to filter_fn in forward_conv
            v = v + v_pre * bias[o]

            if self.post_order_ffn:
                w = self.ord_proj_w[o]
                # same head-mixing layout as forward_conv, without the (z, l) axes
                v = mul_sum(
                    rearrange(w, "h1 h2 -> 1 h1 h2 1"),
                    rearrange(v, "b h v -> b h 1 v"),
                )

        y = self.activation(rearrange(v * x[0], "b h v -> b (h v)", h=self.num_heads))
        y = self.out_proj(y)[:, None]  # b 1 d

        if self.return_state:
            return y, None
        return y

    def forward_conv(self, u, *args, **kwargs):
        """Convolutional (parallel) forward pass over a whole sequence.

        Args:
            u: (b, l, d_model) input sequence.

        Returns:
            (b, min(l, l_max), d_model) output, or ``(output, None)`` if ``return_state``.
        """
        seq_len = u.size(-2)
        l_filter = min(seq_len, self.l_max)

        u = self.in_proj(u)

        u = rearrange(u, "b l d -> b d l")

        # Conv1d pads (kernel - 1) on both sides. Keeping only the first l_filter
        # outputs makes the short convolution causal.
        uc = self.short_filter(u)[..., :l_filter]

        uc = rearrange(
            uc,
            "b (ho v) (z l) -> b ho v z l",
            z=self.num_blocks,
            ho=self.num_heads,
            v=self.head_dim * (self.order + 1),
        )

        # order + 1 chunks of head_dim channels: gates x[0..order-1] and value v
        *x, v = uc.split(self.head_dim, dim=2)
        k = self.k if self.k is not None else self.filter_fn.filter(l_filter)

        # `c` is always 1 by default
        k = rearrange(k, "c l (v o) -> c o v l", v=self.head_dim, o=self.order - 1)[0]

        bias = rearrange(
            self.filter_fn.bias, "(v o) -> o v", v=self.head_dim, o=self.order - 1
        )

        for o, x_i in enumerate(reversed(x[1:])):
            if self.outer_mixing:
                v = rearrange(v, "b h v z l -> b h 1 v z l")
                v = self.dropout(v * rearrange(x_i, "b h v z l -> b h v 1 z l"))
                v = v.sum(dim=2)
            else:
                v = self.dropout(v * x_i)

            # the bias term is broadcasted. Last dimension (l) is handled by fftconv
            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None])

            if self.post_order_ffn:
                w = self.ord_proj_w[o]
                v = mul_sum(
                    rearrange(w, "h1 h2 -> 1 h1 h2 1 1 1"),
                    rearrange(v, "b h v z l -> b h 1 v z l"),
                )

        y = self.activation(
            rearrange(
                v * x[0],
                "b h v z l -> b (z l) (h v)",
                z=self.num_blocks,
                h=self.num_heads,
            )
        )
        y = self.out_proj(y)

        if self.return_state:
            return y, None
        return y

    def forward(self, u, *args, **kwargs):
        "Dispatches to the recurrent path once `setup_recurrence` has been called."
        if hasattr(self, "recurrent_mode"):
            return self.forward_recurrent(u, *args, **kwargs)
        return self.forward_conv(u, *args, **kwargs)

    @property
    def d_output(self):
        return self.d_model


if __name__ == "__main__":
    # Smoke test of the recurrent mode: the layer takes a single step from a zero state.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    d_model = 768
    for batch_size in [1, 16]:
        x = torch.randn(batch_size, 1, d_model).to(device)
        layer = LaughingHyenaOperator(
            d_model=d_model,
            num_heads=1,
            order=2,
            num_blocks=1,
            l_max=1024,
            return_state=False,
        )
        distilled_filter = LaughingHyenaFilter(
            num_order=6,
            den_order=6,
            num_filters=d_model,
            heads=16,
            decay_rate=1e-2,
            real_fft=False,
            train_mixer=True,
        )
        a, b, w = distilled_filter.get_params()
        print(a.shape, b.shape, w.shape)
        layer.setup_recurrence(a, b, w, bsz=batch_size)
        # The layer is built on the CPU. Move it afterwards so that the buffers
        # created by setup_recurrence move with it.
        layer = layer.to(device)

        with torch.no_grad():
            y = layer(x)
        assert y.shape == (batch_size, 1, d_model), y.shape
