import torch
import torch.nn as nn

from .LearnableMembrane import LearnableMembrane

class LearnablePolynomialMembrane(LearnableMembrane):
    """Learnable polynomial membrane update function for SNNs.

    Computes ``bias + w[0]*x + w[1]*x**2 + ... + w[d-1]*x**d`` on the input
    clamped to ``[-1, 1]``, where ``d`` is ``poly_degree``.
    """

    def __init__(self, poly_degree: int = 3, bias: bool = False):
        """
        A learnable polynomial membrane update function for SNNs.
        Args:
            poly_degree (int): Highest degree of the polynomial (default: 3).
                Must be at least 1.
            bias (bool): If True, the constant term is trainable; otherwise it
                is created with ``requires_grad=False`` and initialized to zero
                (default: False).
        Raises:
            ValueError: If ``poly_degree`` is less than 1.
        """
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if poly_degree < 1:
            raise ValueError("Polynomial degree must be at least 1")

        self.poly_degree = poly_degree
        self.weight, self.bias = self.init_weight(poly_degree, bias)

    def init_weight(self, poly_degree: int, use_bias: bool):
        """Create the coefficient and constant-term parameters.

        The linear coefficient starts at -0.25 and all higher-order
        coefficients at zero, so with the zero bias the initial polynomial
        is ``-0.25 * x``.

        Args:
            poly_degree (int): Number of coefficients (degrees ``1..poly_degree``).
            use_bias (bool): Whether the constant term should require gradients.

        Returns:
            tuple[nn.Parameter, nn.Parameter]: ``weight`` of shape
            ``(poly_degree,)`` (always trainable) and ``bias`` of shape ``(1,)``
            (trainable only if ``use_bias``).
        """
        weight = torch.zeros(poly_degree, dtype=torch.float32)
        weight[0] = -0.25
        bias = torch.zeros(1, dtype=torch.float32)
        return (
            nn.Parameter(weight, requires_grad=True),
            nn.Parameter(bias, requires_grad=use_bias),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the polynomial elementwise on ``x`` clamped to ``[-1, 1]``."""
        x = x.clamp(-1, 1)  # Keep within stable range: bounds every power of x
        return poly_eval(x, self.weight, self.bias)

    @torch.no_grad()
    def plot(self, x: torch.Tensor | None = None):
        """Sample the membrane function for plotting.

        Args:
            x: Points to evaluate at. Defaults to 100 evenly spaced points on
                ``[-1, 1]`` on the parameters' device and dtype.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(x, y)`` as CPU column vectors
            of shape ``(N, 1)``. ``x`` is returned as given (unclamped), while
            ``y`` is computed on the clamped input.
        """
        if x is None:
            x = torch.linspace(-1, 1, 100, device=self.weight.device, dtype=self.weight.dtype)
        y = self.forward(x)
        # reshape (not view) so non-contiguous caller-supplied tensors also work
        return x.reshape(-1, 1).cpu(), y.reshape(-1, 1).cpu()

    def print_parameters(self):
        """Print the polynomial coefficients, constant term first."""
        print("Polynomial Coefficients:")
        print(self.get_parameters())

    def get_parameters(self):
        """Return the polynomial as a string, constant term first.

        Example output for the default initialization with degree 3:
        ``"+0.0000 -0.2500x^1 +0.0000x^2 +0.0000x^3"``.
        """
        terms = [f"{self.bias.item():+.4f}"]  # constant term
        terms.extend(
            f"{coeff.item():+.4f}x^{power}"
            for power, coeff in enumerate(self.weight, start=1)
        )
        return " ".join(terms)


@torch.compile
def poly_eval(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Evaluate ``bias + weight[0]*x + weight[1]*x**2 + ... + weight[n-1]*x**n``.

    Note the indexing: ``weight[i]`` is the coefficient of ``x**(i + 1)``;
    the constant term is supplied separately as ``bias``.

    Uses Horner's scheme, so only multiplies and adds are performed (no
    explicit powers of ``x``).

    Args:
        x: Input tensor, evaluated elementwise.
        weight: 1-D tensor of ``n >= 1`` coefficients for degrees ``1..n``
            (an empty ``weight`` raises on the first index).
        bias: Constant term, broadcast against ``x``.

    Returns:
        Tensor with the broadcast shape of ``x`` and ``bias``.
    """
    n = weight.shape[0]
    out = weight[n - 1]  # start from the highest-degree coefficient
    for i in range(n - 2, -1, -1):
        out = out * x + weight[i]

    # Final Horner step: one more factor of x, then add the constant term.
    return out * x + bias

# Alternative (unused) implementation via explicit powers, kept for reference:
# def poly_eval(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
#     powers = torch.arange(1, weight.shape[0] + 1, device=x.device, dtype=x.dtype).view(1, -1)
#     x_powers = x.unsqueeze(-1) ** powers  # Shape: (N, poly_degree)
#     out = (x_powers * weight).sum(dim=-1) + bias  # Shape: (N,)
#     return out
        
