import torch
import math
import torch.nn as nn
from SNN.spike_neuron import LMHTNeuron, DTIFNeuron
from typing import Iterable
def dummy(x: object) -> object:
    """Identity placeholder: hand the argument back untouched.

    Passed where a callable is required but no transformation is wanted
    (in this file it is given to ``DTIFNeuron`` as its second argument).
    """
    return x
class ExpNegPWL32Full(nn.Module):
    """Piecewise-linear lookup approximation of ``exp(-x)``.

    The interval ``[xmin, xmax]`` is cut into ``nseg`` equal-width segments
    and ``exp(-x)`` is replaced on each segment by the chord through the
    segment's end points.  Outside the interval the output saturates:
    inputs below ``xmin`` yield ``exp(-xmin)`` and inputs above ``xmax``
    yield ``0``.

    Tables are built once per distinct ``(xmin, xmax, nseg)`` and shared by
    all instances using that configuration, so an instance always evaluates
    the table that matches its own constructor arguments.

    Args:
        xmin: Lower edge of the tabulated interval.
        xmax: Upper edge of the tabulated interval.
        nseg: Number of linear segments.

    Raises:
        ValueError: If ``nseg < 1`` or ``xmin`` is not below ``xmax``.
    """

    # Legacy class-level handles, kept for code that reads them directly.
    # They hold the table of the first configuration ever constructed.
    xbnd_ref = None
    slope_ref = None
    bias_ref = None
    y_hi_ref = None
    y_lo_ref = None

    # (xmin, xmax, nseg) -> (xbnd, slope, bias, y_hi, y_lo), all CPU tensors.
    _table_cache = {}

    def __init__(self, xmin=-5.0, xmax=5.0, nseg=64):
        super().__init__()
        if nseg < 1:
            raise ValueError(f"nseg must be >= 1, got {nseg}")
        if not xmin < xmax:
            raise ValueError(
                f"xmin must be smaller than xmax, got {xmin} and {xmax}")
        self.xmin, self.xmax = xmin, xmax
        self.nseg = nseg

        cls = ExpNegPWL32Full
        key = (float(xmin), float(xmax), int(nseg))
        table = cls._table_cache.get(key)
        if table is None:
            table = cls._build_table(xmin, xmax, nseg)
            cls._table_cache[key] = table
        # Stored as a plain tuple (not buffers) so state_dict is unchanged.
        self._table = table
        # Copies of the table already moved to a device, keyed by device.
        self._device_tables = {}

        if cls.xbnd_ref is None:
            (cls.xbnd_ref, cls.slope_ref, cls.bias_ref,
             cls.y_hi_ref, cls.y_lo_ref) = table

    @staticmethod
    def _build_table(xmin, xmax, nseg):
        """Return ``(xbnd, slope, bias, y_hi, y_lo)`` for one configuration."""
        xbnd = torch.linspace(xmin, xmax, nseg + 1)
        ybnd = torch.exp(-xbnd)
        width = (xmax - xmin) / nseg
        slope = (ybnd[1:] - ybnd[:-1]) / width
        bias = ybnd[:-1] - slope * xbnd[:-1]
        y_hi = torch.tensor(0.0)
        # Saturation value below the table follows xmin instead of a
        # hard-coded exp(5.0); identical for the default xmin = -5.
        y_lo = torch.exp(torch.tensor(-float(xmin)))
        return xbnd, slope, bias, y_hi, y_lo

    def _tables_on(self, device):
        """Return this instance's table on ``device``, copying at most once."""
        tables = self._device_tables.get(device)
        if tables is None:
            tables = tuple(t.to(device) for t in self._table)
            self._device_tables[device] = tables
        return tables

    def forward(self, x):
        """Evaluate the approximation element-wise on ``x``.

        The tables are float32; the result dtype follows PyTorch's type
        promotion between ``x`` and the tables.
        """
        xbnd, slope, bias, y_hi, y_lo = self._tables_on(x.device)

        # Only interior knots are used as boundaries, so out-of-range inputs
        # still map to a valid segment index (0 or nseg - 1).
        idx = torch.bucketize(x, xbnd[1:-1])
        y_mid = slope[idx] * x + bias[idx]
        return torch.where(
            x < self.xmin, y_lo,
            torch.where(x > self.xmax, y_hi, y_mid)
        )

class SiLU4bitFromExp(nn.Module):
    """SiLU-style activation built from the ``exp(-x)`` lookup and a divider.

    ``forward`` evaluates ``exp(-x)`` through ``ExpNegPWL32Full`` and hands
    ``x`` and ``1 + exp(-x)`` to a ``DTIFNeuron``; presumably the neuron
    returns their quotient ``x / (1 + exp(-x))`` -- confirm against
    ``DTIFNeuron``.

    Args:
        avg: Forwarded to ``DTIFNeuron`` as its ``avg`` keyword.  The
            original code read ``self.avg`` without ever setting it, so
            ``forward`` raised ``AttributeError``.  The default ``True``
            mirrors the default used by ``snnRMSNorm_new`` in this file
            -- NOTE(review): confirm this is the intended mode.
    """

    # Not read anywhere in this class.
    _shared_expneg = None

    def __init__(self, avg: bool = True):
        super().__init__()
        self.avg = avg
        self.expneg = ExpNegPWL32Full()
        # Registered for state_dict compatibility; not read in this class.
        self.register_buffer("step", torch.tensor(256.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the divider output for ``x`` and ``1 + exp(-x)``."""
        # The lookup tables are float32; cast back to the input dtype.
        e_neg = self.expneg(x).to(x.dtype)
        # A fresh neuron is created on every call, as in the original code,
        # so nothing the neuron may keep internally carries over.
        division = DTIFNeuron(16, dummy, T=4, avg=self.avg)
        return division(x, 1.0 + e_neg)

class Softmax8bitFromExp(nn.Module):
    """Softmax-style normalisation built from the ``exp(-x)`` lookup.

    ``forward`` subtracts the maximum along ``dim``, shifts by ``+5`` and
    evaluates the lookup at the negated value, i.e. it approximates
    ``exp(z + 5)`` with ``z = logits - max``.  The numerator and its sum
    along ``dim`` are then handed to a ``DTIFNeuron``; presumably the neuron
    returns their quotient -- confirm against ``DTIFNeuron``.

    Args:
        dim: Axis along which the maximum and the sum are taken.
        avg: Forwarded to ``DTIFNeuron`` as its ``avg`` keyword.  The
            original code read ``self.avg`` without ever setting it, so
            ``forward`` raised ``AttributeError``.  The default ``True``
            mirrors the default used by ``snnRMSNorm_new`` in this file
            -- NOTE(review): confirm this is the intended mode.
    """

    # Not read anywhere in this class.
    _shared_exp = None

    def __init__(self, dim: int = -1, avg: bool = True):
        super().__init__()
        self.dim = dim
        self.avg = avg
        self.expneg = ExpNegPWL32Full()
        # Registered for state_dict compatibility; not read in this class.
        self.register_buffer("step", torch.tensor(256.0))

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the divider output for ``exp(z + 5)`` and its sum."""
        # z <= 0 after removing the maximum along ``dim``.
        z = logits - logits.max(dim=self.dim, keepdim=True).values
        # With z <= 0 the lookup argument -(z + 5) never drops below -5,
        # the lower table edge of a default-constructed ExpNegPWL32Full.
        x = z + 5.0
        e_pos = self.expneg(-x).to(logits.dtype)
        # A fresh neuron is created on every call, as in the original code.
        division = DTIFNeuron(4, dummy, T=4, avg=self.avg)
        return division(e_pos, e_pos.sum(dim=self.dim, keepdim=True))

def calc_K(n_iter: int = 12) -> float:
    """Return the CORDIC gain accumulated over ``n_iter`` iterations.

    The gain is the product of ``sqrt(1 + 2**(-2*k))`` for
    ``k = 0 .. n_iter - 1``; with ``n_iter == 0`` it is ``1.0``.
    """
    stage_gains = (math.sqrt(1 + 2 ** (-2 * k)) for k in range(n_iter))
    # start=1.0 keeps the result a float even for an empty product.
    return math.prod(stage_gains, start=1.0)
# Gain for the default 12 iterations, computed once at import time.
_K12 = calc_K(12)

def cordic_hypot_pair_no_scale(x: torch.Tensor,
                               y: torch.Tensor,
                               n_iter: int = 12) -> torch.Tensor:
    """Drive ``y`` towards zero with CORDIC steps and return ``|x|``.

    Each of the ``n_iter`` steps adds or subtracts the other component
    divided by ``2**k``, choosing the sign from the current ``y``.  The
    result is NOT divided by the CORDIC gain; callers are expected to do
    that (see ``calc_K``).  The inputs are left unmodified.
    """
    mag, resid = x.clone(), y.clone()
    plus = torch.tensor(1.0, dtype=mag.dtype, device=mag.device)
    minus = torch.tensor(-1.0, dtype=mag.dtype, device=mag.device)
    for shift in range(n_iter):
        scale = 1 << shift
        direction = torch.where(resid >= 0, plus, minus)
        # Both right-hand sides are evaluated from the values of the
        # previous step before either name is rebound.
        mag, resid = (mag + direction * (resid / scale),
                      resid - direction * (mag / scale))
    return mag.abs()

def cordic_l2(v: torch.Tensor,
              eps: float = 0.0,
              n_iter: int = 12) -> torch.Tensor:
    """Approximate ``sqrt(sum(v**2) + eps)`` along the last axis.

    Magnitudes are sorted in ascending order and then merged pairwise with
    ``cordic_hypot_pair_no_scale``, dividing by the CORDIC gain after every
    merge, until a single value per row is left.  When ``eps > 0`` an extra
    element ``sqrt(eps)`` is prepended so that it enters the sum of squares.

    Returns:
        Tensor with the last axis of ``v`` removed.
    """
    gain = calc_K(n_iter) if n_iter != 12 else _K12
    gain_t = v.new_tensor(gain)

    if eps > 0.0:
        pad = torch.full_like(v[..., :1], math.sqrt(eps))
        v = torch.cat([pad, v], dim=-1)

    v = torch.sort(v.abs(), dim=-1).values

    while True:
        width = v.size(-1)
        if width <= 1:
            break
        paired = 2 * (width // 2)
        evens = v[..., 0:paired:2]
        odds = v[..., 1:paired:2]
        merged = cordic_hypot_pair_no_scale(evens, odds, n_iter) / gain_t
        if width % 2 == 0:
            v = merged
        else:
            # An unpaired trailing element is carried into the next round.
            v = torch.cat([merged, v[..., -1:]], dim=-1)

    return v.squeeze(-1)

class snnRMSNorm_new(nn.Module):
    """RMS normalisation evaluated step by step over ``T`` time steps.

    At every step the inputs seen so far are summed, the running sum is
    RMS-normalised over its last axis using ``cordic_l2``, scaled by
    ``weight`` and floored onto a grid of ``1 / 2**q_bits``.  The module
    emits the difference between consecutive normalised results, passed
    through an ``LMHTNeuron`` output quantizer.

    Args:
        ori_norm: Source norm layer; ``weight``, ``variance_epsilon``,
            ``output_bits`` and ``output_quantizer`` are read from it.
        T: Number of time steps folded into the leading axis of the input.
        avg: Forwarded to ``LMHTNeuron`` as its ``avg`` keyword.
        n_iter: CORDIC iteration count passed to ``cordic_l2``.
        q_bits: Number of fractional bits of the intermediate grid.
    """

    def __init__(self, ori_norm, T: int = 2, avg: bool = True,
                 n_iter: int = 8, q_bits: int = 8):
        super().__init__()
        self.register_buffer("weight", ori_norm.weight)
        self.bias = None
        self.variance_epsilon = ori_norm.variance_epsilon
        # When set to True, forward reads ``self.temp_weight`` instead of
        # ``self.weight``; ``temp_weight`` is not assigned in this class.
        self.use_temporary_parameter = False
        self.output_bits = ori_norm.output_bits
        self.out_features = self.weight.shape[-1]
        self.T, self.avg = T, avg
        self.n_iter, self.q_bits = n_iter, q_bits
        self.step = float(1 << q_bits)
        levels = math.ceil((2 ** self.output_bits - 1) / T)
        self.output_quantizer = LMHTNeuron(
            levels, ori_norm.output_quantizer, T=T, avg=avg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` of shape ``(T * B, D, L)``; the shape is kept.

        float16 inputs are processed in float32 and the result is cast back
        to float16.
        """
        total, dim_mid, dim_last = x.shape
        x = x.view(self.T, total // self.T, dim_mid, dim_last)
        weight = self.temp_weight if self.use_temporary_parameter else self.weight
        grid = self.step

        cast_back = x.dtype == torch.float16
        if cast_back:
            x = x.to(torch.float32)

        running = torch.zeros_like(x[0])
        previous = torch.zeros_like(running)
        n_feat = running.size(-1)
        root_n = math.sqrt(n_feat)

        deltas = []
        for frame in x.unbind(0):
            running = running + frame
            # eps is scaled by the feature count because cordic_l2 works on
            # the sum of squares, not on the mean.
            norm = cordic_l2(running, self.variance_epsilon * n_feat, self.n_iter)
            normalised = running * (root_n / norm).unsqueeze(-1)
            quantised = (torch.floor(normalised * weight * grid) / grid).to(weight.dtype)
            deltas.append(quantised - previous)
            previous = quantised

        out = self.output_quantizer(torch.stack(deltas, dim=0))
        if cast_back:
            out = out.to(torch.float16)
        return out.view(-1, dim_mid, dim_last)
