import math
import torch 
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from opt_einsum import contract


class Sin(nn.Module):
    """Sine activation with a per-channel frequency: ``sin(freq * x)``.

    Args:
        dim: Number of channels; ``freq`` has shape ``(1, dim)``.
        w: Initial value given to every entry of ``freq``.
        train_freq: If True, ``freq`` is a learnable ``nn.Parameter``.
            Otherwise it is a fixed, non-persistent buffer.
    """

    def __init__(self, dim, w=1, train_freq=True):
        super().__init__()
        freq = w * torch.ones(1, dim)
        if train_freq:
            self.freq = nn.Parameter(freq)
        else:
            # A buffer follows the module through .to()/.cuda()/.double(),
            # which a plain tensor attribute does not. persistent=False keeps
            # the state_dict keys identical to the previous behaviour.
            self.register_buffer("freq", freq, persistent=False)

    def forward(self, x):
        """Return ``sin(freq * x)``; ``x`` must broadcast against ``(1, dim)``."""
        return torch.sin(self.freq * x)
    
    
class PositionalEncoding(nn.Module):
    """Encode a 1-D sequence of time stamps into ``dim`` features.

    The output columns are, in order: ``t``, ``t`` reversed along the
    sequence axis, the cosines and then the sines of ``f * t / len(t)``,
    where ``f`` holds ``dim / 2 - 1`` frequencies initialised to
    ``2 * pi * k`` for ``k = 0, 1, ...``.

    Args:
        dim: Number of output features; must be even and at least 4.
        train: If True, the frequencies ``f`` are a learnable
            ``nn.Parameter``. Otherwise they are a fixed, non-persistent buffer.
    """

    def __init__(self, dim, train=True):
        super().__init__()
        # dim must be even and >= 4
        assert dim % 2 == 0 and dim >= 4, "'dim' must be even and >= 4"
        freqs = 2 * torch.pi * torch.arange((dim - 1) // 2)
        if train:
            self.f = nn.Parameter(freqs)
        else:
            # A buffer follows the module through .to()/.cuda()/.double(),
            # which a plain tensor attribute does not. persistent=False keeps
            # the state_dict keys identical to the previous behaviour.
            self.register_buffer("f", freqs, persistent=False)

    def forward(self, t):
        """Return the ``(len(t), dim)`` encoding of the 1-D tensor ``t``."""
        length = t.shape[-1]
        # Complex exponential: real part gives the cosines, imaginary the sines.
        z = torch.exp(1j * torch.outer(t / length, self.f))
        column = t[:, None]
        return torch.cat([column, column.flip(0), z.real, z.imag], dim=-1)


class ExponentialWindow(nn.Module):
    """Apply a per-filter exponential decay along the first axis of the input.

    The input is multiplied by ``exp(-|decay_rate| * t)`` with
    ``t = 0, 1, ..., x.shape[0] - 1``; the window has shape
    ``(x.shape[0], num_filters)``.

    Args:
        num_filters: Number of decay rates; they are initialised to
            ``1e-2 * k / num_filters`` for ``k = 0 .. num_filters - 1``.
        train_decay: If True, ``decay_rate`` is a learnable ``nn.Parameter``.
            Otherwise it is a fixed, non-persistent buffer.
    """

    def __init__(self, num_filters, train_decay=True):
        super().__init__()
        rates = 1e-2 * torch.arange(num_filters).float() / num_filters
        if train_decay:
            self.decay_rate = nn.Parameter(rates)
        else:
            # A buffer follows the module through .to()/.cuda()/.double(),
            # which a plain tensor attribute does not. persistent=False keeps
            # the state_dict keys identical to the previous behaviour.
            self.register_buffer("decay_rate", rates, persistent=False)

    def forward(self, x):
        """Return ``x`` scaled by the window; ``x`` must broadcast against it."""
        rates = self.decay_rate.abs()
        # Build the time index in the decay dtype so the outer product does
        # not rely on integer/float type promotion.
        t = torch.arange(x.shape[0], device=x.device, dtype=rates.dtype)
        window = torch.exp(-torch.outer(t, rates))
        return x * window


class ImplicitTransferFunction(nn.Module):
    """Rational transfer function evaluated on an FFT grid.

    The module holds numerator coefficients ``num`` of shape
    ``(num_order + 1, num_filters, heads)`` and denominator coefficients
    ``den`` of shape ``(den_order, num_filters, heads)``. ``forward(L)``
    returns ``FFT(num) / FFT([1, den_normalised])`` summed over heads with
    the weights in ``heads_mixer``.

    Args:
        num_order: Numerator order; ``num_order + 1`` coefficients are stored.
        den_order: Number of learnable denominator coefficients (the leading
            coefficient is fixed to one).
        num_filters: Number of independent filters.
        heads: Number of heads mixed into each filter.
        decay_rate: Initial value of the ``decay_rate`` parameter. Within
            this class only its device is read.
        pos_enc_dim: Accepted but not used by this class.
        filter_order: Accepted but not used by this class.
        sin_freq: Accepted but not used by this class.
        real_fft: If True use ``torch.fft.rfft``, otherwise ``torch.fft.fft``.
        train_mixer: If True, ``heads_mixer`` is a learnable parameter drawn
            from a standard normal. Otherwise it is a fixed buffer of ones.
    """

    def __init__(self, num_order : int,
                       den_order : int,
                       num_filters : int=1,
                       heads : int=1,
                       decay_rate : float=1e-2,
                       pos_enc_dim : int=32,
                       filter_order : int=64,
                       sin_freq : float=1,
                       real_fft : bool=False,
                       train_mixer : bool=False):
        super().__init__()
        self.num_order = num_order + 1
        self.den_order = den_order
        self.num_filters = num_filters
        self.heads = heads

        # NOTE: the order of the random draws below is kept as it was so that
        # seeded initialisation is reproducible.
        if train_mixer:
            self.heads_mixer = nn.Parameter(torch.randn(heads))
        else:
            # A buffer follows the module through .to()/.cuda()/.double(),
            # which a plain tensor attribute does not. persistent=False keeps
            # the state_dict keys identical to the previous behaviour.
            self.register_buffer("heads_mixer", torch.ones(heads), persistent=False)
        self.decay_rate = nn.Parameter(
            torch.tensor([decay_rate], dtype=torch.get_default_dtype())
        )
        self.fft = torch.fft.rfft if real_fft else torch.fft.fft
        self.norm_factors = nn.Parameter(torch.randn(den_order))
        self.eps = nn.Parameter(torch.tensor([1e-2], dtype=torch.get_default_dtype()))
        self.num = nn.Parameter(torch.randn(self.num_order, num_filters, heads))
        self.den = nn.Parameter(torch.randn(self.den_order, num_filters, heads))

    def forward(self, L):
        """Evaluate the transfer function with an ``L``-point transform.

        Returns:
            Tensor of shape ``(L, num_filters)``, or
            ``(L // 2 + 1, num_filters)`` when ``real_fft`` is set.
        """
        den, num = self._eval()
        den_f = self.fft(den, dim=0, n=L)
        num_f = self.fft(num, dim=0, n=L)
        response = num_f / den_f
        # Weighted sum over the heads axis.
        return torch.sum(response * self.heads_mixer, dim=-1)

    def _eval(self):
        """Return ``(den, num)`` with ``den`` normalised and a leading one added.

        Each denominator coefficient is divided by the L1 norm of ``den``
        over the order axis (plus ``relu(eps)``) and scaled by a factor
        clamped to ``[-0.99, 0.99]``, so the absolute values of the
        non-leading coefficients sum to at most 0.99.
        """
        self.device = self.decay_rate.device
        num, den = self.num, self.den
        l1_norm = torch.sum(torch.abs(den), dim=0, keepdim=True) + F.relu(self.eps)
        norm_factors = torch.clamp(self.norm_factors, -0.99, 0.99)
        den = den / l1_norm * norm_factors[:, None, None]
        # Leading coefficient fixed to one, created with the dtype of den so
        # the concatenation does not depend on type promotion.
        lead = torch.ones(
            1, self.num_filters, self.heads, device=self.device, dtype=den.dtype
        )
        den = torch.cat([lead, den], dim=0)
        return den, num
    
    
def polyroots(p, return_companion=False):
    """
    Return the roots of a batch of polynomials with coefficients given in p.
    The implementation is based on the numpy.roots function.
    More info at: https://numpy.org/doc/stable/reference/generated/numpy.roots.html
    Args:
        p (Tensor): (heads, order) Coefficients of the polynomials, highest
            degree first. Real and complex floating dtypes are used as-is;
            any other dtype is cast to the default floating dtype. The
            leading coefficient of every row must be non-zero, since the
            remaining coefficients are divided by it.
        return_companion (bool): If True, the companion matrix is returned as well.
    Returns:
        Tensor: (heads, order - 1) Complex roots of each polynomial. When
        return_companion is True, a tuple (roots, companion) where companion
        has shape (heads, order - 1, order - 1).
    Raises:
        ValueError: If p is not 2-D or has fewer than 2 coefficients per row.
    """
    # check if the polynomial has the right shape
    if p.dim() != 2:
        raise ValueError("The polynomial must be a second order tensor (heads, order).")
    # check if the polynomial is valid
    if p.shape[-1] < 2:
        raise ValueError("A polynomial must have at least 2 coefficients.")
    # casting: only integer/bool input is promoted. Complex tensors are not
    # "floating point" for torch.is_floating_point, so they must be excluded
    # explicitly or their imaginary part would be discarded by the cast.
    if not (torch.is_floating_point(p) or torch.is_complex(p)):
        p = p.to(torch.get_default_dtype())
    # build companion matrix and find its eigenvalues (the roots)
    # The companion matrix is a square matrix with the normalised, negated
    # polynomial coefficients as its first row and ones below the main diagonal.
    # The eigenvalues of the companion matrix are the roots of the polynomial.
    # More info at: https://en.wikipedia.org/wiki/Companion_matrix
    heads, order = p.shape
    c = torch.zeros((heads, order - 1, order - 1), dtype=p.dtype, device=p.device)
    # ones on the sub-diagonal (an empty view when order == 2)
    torch.diagonal(c, offset=-1, dim1=-2, dim2=-1).fill_(1)
    c[:, 0] = -p[:, 1:] / p[:, :1]
    # compute the eigenvalues of the companion matrix (last two dimensions of c)
    roots = torch.linalg.eigvals(c)
    if return_companion:
        return roots, c
    return roots


def get_direct_path(a, b):
    """Split coefficients ``b`` into a leading term and a corrected remainder.

    Args:
        a: Coefficients indexed along the first axis; only ``a[1:]`` is read.
        b: Coefficients indexed along the first axis.

    Returns:
        Tuple ``(b0, beta)`` with ``b0 = b[0]`` and
        ``beta = b[1:] - b0 * a[1:]``.
    """
    leading, tail = b[0], b[1:]
    remainder = tail - leading * a[1:]
    return leading, remainder


def step(x, u, a, beta, b0, w):
    """Advance the recurrent state by one input sample and emit one output.

    Shapes (b: batch, d: channels, n: state size, h: heads):
        x: b d n h   current state
        u: b d       current input
        a: n+1 d h   only ``a[1:]`` is read
        beta: n d h
        b0: d h
        w: h         head mixing weights

    Returns:
        Tuple ``(y, x)`` with output ``y`` of shape ``b d`` and the updated
        state of shape ``b d n h``.
    """
    # Output: projection of the current state plus the direct input term,
    # then mixed across heads.
    from_state = contract('n d h, b d n h -> b d h', beta, x)
    from_input = contract('d h, b d -> b d h', b0, u)
    y = contract('b d h, h -> b d', from_state + from_input, w)

    # Feedback is computed from the state *before* it is shifted.
    feedback = contract('n d h, b d n h-> b d h', a[1:], x)

    # Shift the state by one slot along n and write the new entry at slot 0.
    next_state = torch.roll(x, 1, dims=2)
    next_state[:, :, 0, :] = u[..., None] - feedback
    return y, next_state


def step_iir(x, u, a, beta, b0, w):
    """Advance the recurrent state by one sample, for a rank-3 input.

    Same update as a single recurrent step, except that the input carries an
    extra second axis (of which only index 0 is used) and the output gets a
    singleton second axis.

    Shapes (b: batch, d: channels, n: state size, h: heads):
        u: b, h, d   only ``u[:, 0]`` is read
        x: b, d, n, h
        a: n+1 d h   only ``a[1:]`` is read
        beta: n d h
        b0: d h
        w: h

    Returns:
        Tuple ``(y, x)`` with ``y`` of shape ``b 1 d`` and the updated state
        of shape ``b d n h``.
    """
    # TODO add support for u heads
    sample = u[:, 0]

    # Output: projection of the current state plus the direct input term,
    # then mixed across heads.
    from_state = contract('n d h, b d n h -> b d h', beta, x)
    from_input = contract('d h, b d -> b d h', b0, sample)
    y = contract('b d h, h -> b d', from_state + from_input, w)

    # Feedback is computed from the state *before* it is shifted.
    feedback = contract('n d h, b d n h -> b d h', a[1:], x)

    # Shift the state by one slot along n and write the new entry at slot 0.
    next_state = torch.roll(x, 1, dims=2)
    next_state[..., 0, :] = sample[..., None] - feedback
    return y[:, None], next_state



def step_fir(x, u, b):
    """Advance a tapped-delay state by one sample and emit one output.

    Args:
        x: Current state, shape (B, D, N).
        u: Current input, shape (B, D).
        b: Coefficients, shape (N + 1, D); ``b[0]`` weights the current
            input and ``b[1:]`` weight the stored state.

    Returns:
        Tuple ``(y, x)`` with output ``y`` of shape (B, D) and the updated
        state of shape (B, D, N), whose slot 0 holds ``u``.
    """
    direct, taps = b[0], b[1:]
    y = contract('n d, b d n -> b d', taps, x) + contract('d, b d -> b d', direct, u)

    # Shift the state by one slot along the last axis and store the new input.
    shifted = torch.roll(x, 1, dims=2)
    shifted[..., 0] = u
    return y, shifted
