import torch
import torch.nn as nn

def pseudo_log2(neo):
    """Return ``log2(neo)`` as float32, with every value below 1 treated as 1.

    The result is therefore always >= 0; non-positive inputs map to 0.
    """
    as_float = neo.float()
    return as_float.clamp(min=1.0).log2()

########## STE (straight-through estimator) fake-quantization functions ##########
def fake_quant_4bit(x):
    """Fake-quantize ``x`` onto the 16-level grid {0, 1/16, ..., 15/16}.

    Forward value: ``round(x * 16)`` clamped to ``[0, 15]``, divided by 16.
    Backward: the gradient passes through unchanged (straight-through estimator).
    """
    levels = torch.round(x * 16.0).clamp(0.0, 15.0)
    quantized = levels / 16.0
    # Value equals `quantized`; only the `x` term carries gradient.
    return x + (quantized - x).detach()

def fake_quant_signed_int(x, min_val=-8, max_val=7):
    """Fake-quantize ``x`` to integers in ``[min_val, max_val]`` (default: signed 4-bit).

    Forward value: ``x`` clamped to the range, then rounded (torch rounds half to even).
    Backward: the gradient passes through unchanged (straight-through estimator).
    """
    clipped = x.clamp(min_val, max_val)
    return x + (clipped.round() - x).detach()

def fake_floor(x):
    """Floor ``x`` in the forward value while passing the gradient straight through."""
    return x + (x.floor() - x).detach()
########## ste functions ##########

class NeuralSignalCodecFunction(torch.autograd.Function):
    """Windowed, optionally fake-quantized signal codec with a hand-written backward.

    Inputs of ``forward``:
        x: (B, L) signal.  neo: sliced alongside ``x`` (expected (B, L)).
        alpha, beta: (num_windows, D).  gamma: (num_windows,).
        boundary_params: logits of the learnable window split (expected
            (num_windows - 1,)); only read when ``required_lr_window`` is True.

    The L axis is cut into ``num_windows`` contiguous windows.  In window i every
    sample w gets, per channel d,
    ``scaling = 2 ** (log2(max(neo_w, 1)) * alpha[i, d] + beta[i, d] - 8)`` and the
    window output is ``gamma[i] * sum_w x_w * scaling``.  The (B, D) result is the sum
    of the window outputs.

    With ``required_quant``:
    - alpha and gamma are fake-quantized to 4-bit fractions and beta to integers
      in [-8, 7];
    - log2(neo), the weight ``log2(neo) * alpha + beta``, every ``x * scaling`` term
      and each window output are floored;
    - the exponent ``weight - 8`` is clamped to [-8, 7];
    - the summed output is clamped to [-128, 127].

    ``backward`` treats those roundings and clamps as identity (straight-through
    estimator). It evaluates derivatives at the quantized values that forward
    actually used.
    """

    @staticmethod
    def _window_bounds(boundary_params, L, num_windows, required_lr_window, device):
        """Return the ``num_windows + 1`` integer window edges over ``[0, L]``.

        Learnable edges: each ``sigma = sigmoid(param)`` is clamped to [0.05, 0.95].
        It is the fraction of the still-unassigned length that goes to the next window.
        Fixed edges split ``[0, L]`` evenly.
        """
        if not required_lr_window:
            return [int(i * L / num_windows) for i in range(num_windows + 1)]

        sigma = torch.sigmoid(boundary_params.clamp(-10.0, 10.0))
        sigma = torch.clamp(sigma, 0.05, 0.95)
        fractions = [0.0]
        for s in sigma:
            fractions.append(fractions[-1] + (1.0 - fractions[-1]) * s)
        fractions.append(1.0)
        edges = torch.tensor(fractions, device=device) * L
        return [int(e.item()) for e in edges]

    @staticmethod
    def forward(ctx, x, neo, alpha, beta, gamma, boundary_params, required_quant=True, required_lr_window=True):
        _, L = x.shape
        num_windows, _ = alpha.shape

        bound_list = NeuralSignalCodecFunction._window_bounds(
            boundary_params, L, num_windows, required_lr_window, x.device
        )

        # --- quantize params ---
        if required_quant:
            q_alpha = fake_quant_4bit(alpha)
            q_beta = fake_quant_signed_int(beta)
            q_gamma = fake_quant_4bit(gamma)
        else:
            q_alpha, q_beta, q_gamma = alpha, beta, gamma

        window_outputs, unit_outputs, window_meta, actual_bounds = [], [], [], []
        for i in range(num_windows):
            start, end = bound_list[i], bound_list[i + 1]
            actual_bounds.append((start, end))
            if end > start:
                x_win = x[:, start:end]
                neo_win = neo[:, start:end]
            else:
                # Empty window: a single all-zero column stands in, so x * scaling is zero.
                x_win = torch.zeros_like(x[:, :1])
                neo_win = torch.zeros_like(neo[:, :1])

            # gamma is applied here rather than inside compute_window, so that backward
            # can reuse the gamma-free output as d out / d gamma_i.  compute_window
            # multiplies by gamma last and "* 1.0" is exact, so the values are identical
            # to compute_window(..., q_gamma[i], ...).
            unit_out, meta = NeuralSignalCodecFunction.compute_window(
                x_win, neo_win, q_alpha[i], q_beta[i], 1.0, required_quant
            )
            unit_outputs.append(unit_out)
            window_outputs.append(unit_out * q_gamma[i])
            window_meta.append(meta)

        ctx.save_for_backward(x, alpha, beta, gamma, boundary_params)
        ctx.scalings = [m['scaling'] for m in window_meta]
        ctx.x_wins = [m['x_win'] for m in window_meta]
        ctx.neo_wins = [m['neo_win'] for m in window_meta]
        ctx.unit_outputs = unit_outputs
        ctx.bounds = actual_bounds
        ctx.required_quant = required_quant
        ctx.required_lr_window = required_lr_window
        ctx.num_windows = num_windows

        stacked = torch.stack(window_outputs, dim=1)  # (B, num_windows, D)
        if required_quant:
            return fake_floor(stacked).sum(dim=1).clamp(-128, 127)
        return stacked.sum(dim=1)

    @staticmethod
    def compute_window(x_win, neo_win, alpha, beta, gamma_i, required_quant):
        """Encode one window.

        Args:
            x_win: (B, W) samples of the window.
            neo_win: values whose log2 drives the per-sample scaling (same layout as x_win).
            alpha, beta: (D,) parameters of this window.
            gamma_i: gain multiplied onto the window sum (tensor or Python number).
            required_quant: floor intermediates and clamp the exponent to [-8, 7].

        Returns:
            (out, meta).  ``out`` is (B, D).  ``meta`` holds:
            - ``'scaling'`` permuted to (B, D, W);
            - the inputs ``'x_win'`` and ``'neo_win'``.
        """
        B, W = x_win.shape
        D = alpha.shape[0]

        a = alpha.view(1, 1, D)
        b = beta.view(1, 1, D)

        neo_log = pseudo_log2(neo_win).unsqueeze(2)  # (B, W, 1)
        if required_quant:
            neo_log = fake_floor(neo_log)

        weights = neo_log * a + b  # (B, W, D)
        if required_quant:
            weights = fake_floor(weights)
            scaling = torch.exp2(torch.clamp(weights - 8.0, -8.0, 7.0))
        else:
            scaling = torch.exp2(weights - 8.0)
        contrib = x_win.unsqueeze(2) * scaling
        if required_quant:
            contrib = fake_floor(contrib)

        out = contrib.sum(dim=1) * gamma_i
        return out, {
            'scaling': scaling.permute(0, 2, 1),  # [B,D,W]
            'x_win': x_win,
            'neo_win': neo_win,
        }

    @staticmethod
    def backward(ctx, grad_output):
        """Hand-written gradients for ``forward`` (grad_output is (B, D)).

        Quantizers, floors and clamps are treated as identity (STE). Derivatives are
        taken at the quantized values forward used.
        Returns one entry per forward input. ``neo`` and the two flags get None.
        ``boundary_params`` gets None when the windows are fixed.
        """
        x, alpha, beta, gamma, boundary_params = ctx.saved_tensors
        required_quant = ctx.required_quant
        _, L = x.shape

        # Same gamma quantizer as forward: the output was scaled by q_gamma, not gamma.
        q_gamma = fake_quant_4bit(gamma) if required_quant else gamma

        grad_x = torch.zeros_like(x)
        grad_alpha = torch.zeros_like(alpha)
        grad_beta = torch.zeros_like(beta)
        grad_gamma = torch.zeros_like(gamma)

        ln2 = torch.log(torch.tensor(2.0, device=x.device, dtype=x.dtype))

        for i, (start, end) in enumerate(ctx.bounds):
            if end <= start:
                continue  # empty window: it saw no samples of x
            x_win = ctx.x_wins[i]      # (B, W)
            scaling = ctx.scalings[i]  # (B, D, W)

            grad_out_i = grad_output * q_gamma[i]  # (B, D)

            # out = q_gamma_i * sum_w x_w * scaling_w  =>  d out / d x_w = q_gamma_i * scaling_w
            grad_x[:, start:end] += (grad_out_i.unsqueeze(2) * scaling).sum(dim=1)

            # scaling = 2 ** (weight - 8)  =>  d scaling / d weight = scaling * ln 2
            grad_weight = grad_out_i.unsqueeze(2) * x_win.unsqueeze(1) * scaling * ln2  # (B, D, W)

            # weight = neo_log * alpha + beta, with neo_log computed (and, in quant mode,
            # floored) exactly as in compute_window.
            neo_log = pseudo_log2(ctx.neo_wins[i])  # (B, W)
            if required_quant:
                neo_log = torch.floor(neo_log)
            grad_alpha[i] = (grad_weight * neo_log.unsqueeze(1)).sum(dim=(0, 2))
            grad_beta[i] = grad_weight.sum(dim=(0, 2))

            # d out / d q_gamma_i is the gamma-free window output saved by forward.
            grad_gamma[i] = (grad_output * ctx.unit_outputs[i]).sum()

        grad_boundary_params = None
        if ctx.required_lr_window:
            # Window edges are integers, so this is a heuristic rather than an exact
            # derivative. g_k = sum(grad_output * gamma-free output of window k), and
            # window_aggr = [0, g_0, ..., g_{num_windows-1}].
            window_aggr = [x.new_tensor(0.0)] + [
                (grad_output * unit_out).sum() for unit_out in ctx.unit_outputs
            ]

            # Same sigma as forward: sigmoid clamped to [0.05, 0.95]; ds is the sigmoid
            # derivative evaluated at the clamped sigma.
            sigma = torch.clamp(torch.sigmoid(boundary_params), 0.05, 0.95)
            ds = sigma * (1 - sigma)

            grad_boundary_params = torch.zeros_like(sigma)
            for i in range(len(sigma)):
                # Fractions follow f_{i+1} = f_i + (1 - f_i) * sigma_i and edges are L * f.
                # So d edge_{i+1} / d sigma_i = L * prod_{j<i} (1 - sigma_j).
                # The effect on later edges is ignored.
                d_edge = L * torch.prod(1 - sigma[:i]) * ds[i]
                # NOTE(review): sigma_i moves the edge between windows i and i+1, but this
                # difference pairs window i with window i-1 (g_{-1} = 0). The pairing is
                # kept as originally written; confirm it is intended.
                grad_boundary_params[i] = (window_aggr[i + 1] - window_aggr[i]) * d_edge

        return grad_x, None, grad_alpha, grad_beta, grad_gamma, grad_boundary_params, None, None



class NeuralSignalCodec(nn.Module):
    """Module wrapper holding the parameters of ``NeuralSignalCodecFunction``.

    Args:
        input_size: Nominal input length. It is used only to scale the boundaries
            returned by the ``get_quantized_params*`` exporters. ``forward`` uses the
            actual ``x.shape[1]``.
        num_windows: Number of windows the input is split into.
        latent_dim: Number of output channels.
        required_quant: Whether ``forward`` fake-quantizes by default. Typical modes
            (flag during training / flag during evaluation):
                True  / True  -- QAT
                False / True  -- PTQ: train in float, then call
                                 ``forward(..., required_quant_force=True)``
                False / False -- FP
        required_lr_window: True -> learnable window boundaries (``boundary_params``);
            False -> fixed, evenly spaced windows.
    """

    def __init__(self, input_size=128, num_windows=3, latent_dim=4, required_quant=True, required_lr_window=True):
        super().__init__()
        self.input_size = input_size
        self.num_windows = num_windows
        self.latent_dim = latent_dim
        self.required_quant = required_quant
        self.required_lr_window = required_lr_window

        # Boundary k gives a fraction sigmoid(param_k) of the remaining length to the
        # next window. Initialising those fractions to 1/n, 1/(n-1), ..., 1/2 splits the
        # input into n equal windows. The parameters store the logits of the fractions.
        p = torch.tensor([1 / (num_windows - k) for k in range(num_windows - 1)])
        self.boundary_params = nn.Parameter(torch.log(p / (1 - p)))
        self.alpha = nn.Parameter(torch.rand(num_windows, latent_dim) * 0.1 + 0.25)
        self.beta = nn.Parameter(torch.zeros(num_windows, latent_dim))
        self.gamma = nn.Parameter(torch.ones(num_windows) * 0.5)

    def forward(self, x, neo, required_quant_force = None):
        """Encode ``x`` (B, L) using ``neo`` (sliced alongside ``x``); returns (B, latent_dim).

        ``required_quant_force``, when not None, overrides ``self.required_quant``.
        """
        if required_quant_force is None:
            required_quant = self.required_quant
        else:
            required_quant = required_quant_force
        return NeuralSignalCodecFunction.apply(
            x, neo, self.alpha, self.beta, self.gamma, self.boundary_params, required_quant, self.required_lr_window
        )

    def _boundary_fractions(self):
        """Return the ``num_windows + 1`` window edges as fractions of the input length.

        This mirrors the boundary computation in ``NeuralSignalCodecFunction.forward``:
        - fixed windows are evenly spaced;
        - learnable edges use ``sigmoid(boundary_params)`` clamped to [0.05, 0.95].

        The clamp matters: without it a saturated parameter would export an edge that
        the model never actually used. Entries may be Python floats or 0-d tensors.
        """
        if not self.required_lr_window:
            return [k / self.num_windows for k in range(self.num_windows + 1)]
        sigma = torch.sigmoid(self.boundary_params.clamp(-10.0, 10.0))
        sigma = torch.clamp(sigma, 0.05, 0.95)
        fractions = [0.0]
        for s in sigma:
            fractions.append(fractions[-1] + (1 - fractions[-1]) * s)
        fractions.append(1.0)
        return fractions

    def get_quantized_params_for_ptq(self):
        """Export boundaries and hard-quantized parameters (e.g. after float training).

        Returns:
            boundaries: list of ``num_windows + 1`` floats, i.e. edge fractions times
                ``input_size``.
            alpha_q: CPU tensor of alpha on the grid {0, 1/16, ..., 15/16}.
            beta_q: CPU float tensor of beta rounded to integers in [-8, 7].
            gamma_q: CPU tensor of gamma on the grid {0, 1/16, ..., 15/16}.
        """
        with torch.no_grad():
            boundaries = torch.tensor(self._boundary_fractions(), device=self.alpha.device) * self.input_size
            # Same grids as fake_quant_4bit / fake_quant_signed_int, without the STE.
            alpha_q = torch.clamp((self.alpha * 16).round() / 16.0, 0.0, 15.0 / 16.0)
            beta_q = torch.clamp(self.beta.round(), -8, 7).int().float()
            gamma_q = torch.clamp((self.gamma * 16).round() / 16.0, 0.0, 15.0 / 16.0)
        return boundaries.tolist(), alpha_q.cpu(), beta_q.cpu(), gamma_q.cpu()

    def get_quantized_params(self):  # deploy for the hardware or just extract from the code
        """Export boundaries and the parameters as ``forward`` uses them by default.

        Returns:
            ``(boundaries, alpha, beta, gamma)``.
            - ``boundaries`` is a CPU tensor of the ``num_windows + 1`` edge fractions
              times ``input_size``.
            - When ``self.required_quant`` is True, the parameters are fake-quantized.
              Otherwise the parameter objects themselves are returned.
        """
        with torch.no_grad():
            boundaries = torch.tensor(self._boundary_fractions()) * self.input_size

            if self.required_quant:
                alpha = fake_quant_4bit(self.alpha)       # 4-bit code 0..15, value = code / 16
                beta = fake_quant_signed_int(self.beta)   # signed 4-bit integer in [-8, 7]
                gamma = fake_quant_4bit(self.gamma)       # 4-bit code 0..15, value = code / 16
            else:
                alpha, beta, gamma = self.alpha, self.beta, self.gamma

        return boundaries, alpha, beta, gamma
