# TorchDiffPC.py

import torch
import torch.nn as nn
from dataclasses import dataclass

# NOTE(review): import-time side effect -- this banner is printed every time the
# module is imported. Consider moving it behind `if __name__ == "__main__":` or
# into a logger if the output is unwanted for library use.
print('Running TorchDiffPC.py with Quantized FORWARD PASS and Quantized INFERENCE logic.')

#  DiffPC Components


def _mod_idx(k: int, period: int) -> int:
    if period <= 0: return 0
    return (k % period + period) % period

@dataclass
class _LTCyclicPhase:
    """Scheduler for generating the adaptive threshold l_t for quantization."""
    m: int; n: int; a: float
    def get_l_t(self, sample_step: int) -> float:
        k = _mod_idx(sample_step, self.n)
        return self.a * (2.0 ** self.m) / (2.0 ** k)

@torch.no_grad()
def _tensor_fake_quantize(
    target_tensor: torch.Tensor, lt_m: int, lt_n: int, lt_a: float
) -> "tuple[torch.Tensor, int]":
    """
    Approximate ``target_tensor`` by simulating one full DiffPC update cycle.

    The approximation starts at zero and is refined over ``lt_n`` steps. At
    step ``t`` the threshold is ``l_t = lt_a * 2**lt_m / 2**t`` (from
    ``_LTCyclicPhase``; ``t`` never wraps because it stays below ``lt_n``).
    Every element whose residual magnitude strictly exceeds ``l_t`` "spikes"
    and moves by ``sign(residual) * l_t``.

    Args:
        target_tensor: Tensor to approximate (assumed floating point -- TODO
            confirm for integer inputs).
        lt_m: Threshold exponent for the first step.
        lt_n: Number of steps in the cycle.
        lt_a: Linear threshold scale.

    Returns:
        ``(approximation, spikes)``. ``approximation`` is shaped like
        ``target_tensor``. ``spikes`` is the total number of element-wise
        transmissions over the cycle, as a Python ``int``. If ``lt_n <= 0``,
        no step runs and ``(zeros, 0)`` is returned.
    """
    scheduler = _LTCyclicPhase(m=lt_m, n=lt_n, a=lt_a)
    approximated_tensor = torch.zeros_like(target_tensor)
    # The spike count is kept on the tensor's device and converted to a Python
    # int once at the end; calling ``.item()`` every step forces a
    # host/device synchronization inside the loop.
    total_spikes = 0

    for t in range(lt_n):
        l_t = scheduler.get_l_t(t)
        difference = target_tensor - approximated_tensor

        spike_mask = torch.abs(difference) > l_t
        transmitted_quantum = torch.sign(difference) * l_t

        total_spikes += torch.sum(spike_mask)

        # Multiplying by the boolean mask zeroes the update for elements that
        # did not spike this step.
        approximated_tensor.add_(transmitted_quantum * spike_mask)

    return approximated_tensor, int(total_spikes)

#  TorchSeq2PC Core Functions

@torch.no_grad()
def QuantizedFwdPass(model, X, n, lt_m, lt_n, lt_a):
    """
    Run a forward pass whose layer outputs are built up by quantized,
    iterative (spike-based) communication.

    For each layer, the exact output ``model[i](vhat[i])`` is computed from the
    previous layer's *approximated* activation. That output is then
    approximated from zero over ``n`` steps. At step ``t``, every element whose
    residual magnitude strictly exceeds the threshold ``l_t`` moves by
    ``sign(residual) * l_t``. Thresholds come from
    ``_LTCyclicPhase(lt_m, lt_n, lt_a)`` and cycle with period ``lt_n``, so
    ``n`` may be larger than ``lt_n``.

    Args:
        model: Indexable sequence of layers (e.g. ``nn.Sequential``).
        X: Network input; stored detached as ``vhat[0]``.
        n: Number of refinement steps per layer.
        lt_m, lt_n, lt_a: Threshold-scheduler parameters.

    Returns:
        ``(vhat, total_ff_spikes)``. ``vhat`` is a list of ``len(model) + 1``
        tensors: the input followed by each layer's approximated activation.
        ``total_ff_spikes`` is the total spike count over all layers and
        steps, as a Python ``int``.
    """
    vhat = [X.detach()]
    scheduler = _LTCyclicPhase(m=lt_m, n=lt_n, a=lt_a)
    # Thresholds depend only on the step index, not on the layer, so they are
    # computed once and reused for every layer.
    thresholds = [scheduler.get_l_t(t) for t in range(n)]
    total_ff_spikes = 0

    for layer_idx in range(len(model)):
        target_activation = model[layer_idx](vhat[layer_idx])
        approximated_activation = torch.zeros_like(target_activation)
        # Spikes are counted on-device and synced once per layer instead of
        # once per step. Per-layer syncing also stays correct if layers live
        # on different devices.
        layer_spikes = 0

        for l_t in thresholds:  # iterative build-up over n steps
            difference = target_activation - approximated_activation

            spike_mask = torch.abs(difference) > l_t
            transmitted_quantum = torch.sign(difference) * l_t

            layer_spikes += torch.sum(spike_mask)

            approximated_activation.add_(transmitted_quantum * spike_mask)

        total_ff_spikes += int(layer_spikes)
        vhat.append(approximated_activation)

    return vhat, total_ff_spikes

def QuantizedPredErrs(model, vhat, dLdy, eta, n, lt_m, lt_n, lt_a, e_mult):
    """
    Predictive Coding using quantized states for all communication.
    Both beliefs (v_A) and errors (epsilon_A) are quantized before use.

    Each of the ``n`` iterations has two phases.

    1. Communication. For every slot ``l``, the changes ``v[l] - v_A[l]`` and
       ``epsilon[l] - epsilon_A[l]`` are passed through
       ``_tensor_fake_quantize`` and added to the communicated copies
       ``v_A[l]`` / ``epsilon_A[l]``. Beliefs use threshold scale ``lt_a``;
       errors use ``lt_a * e_mult``.
    2. Inference: a top-down sweep over layers ``len(model)-1 .. 0``.
       ``epsdfdv`` is the vector-Jacobian product of ``model[layer]``,
       evaluated at the communicated belief ``v_A[layer]``, with the
       communicated error ``epsilon_A[layer+1]``. Then
       ``epsilon[layer] = vhat[layer] - v[layer]`` and
       ``v[layer] += eta * (epsilon[layer] - epsdfdv)``.

    Phase 1 runs before phase 2, so the communicated copies used in a sweep
    are quantized approximations of the values left by the previous
    iteration's sweep.

    Args:
        model: Indexable sequence of layers. ``len(vhat)`` is assumed to equal
            ``len(model) + 1`` -- TODO confirm at call sites.
        vhat: Feedforward activations, input first (e.g. from
            QuantizedFwdPass). Used as the initial beliefs and as the
            reference for ``epsilon[layer]``.
        dLdy: Loss gradient w.r.t. the output. It is copied into
            ``epsilon[-1]``, which the sweep never writes.
        eta: Step size for belief updates.
        n: Number of inference iterations.
        lt_m, lt_n, lt_a: Threshold-scheduler parameters for quantization.
        e_mult: Multiplier on ``lt_a`` for the error thresholds.

    Returns:
        ``(v, epsilon, total_lrn_spikes)``: the final beliefs, the final
        unquantized prediction errors, and the summed spike count from every
        quantization call.
    """
    DepthPlusOne = len(model) + 1
    
    # Working beliefs start at the feedforward values. Errors start at zero,
    # except the output slot, which holds the loss gradient.
    v = [vh.clone().detach() for vh in vhat]
    epsilon = [torch.zeros_like(vh) for vh in vhat]
    epsilon[-1].copy_(dLdy)
    
    # Communicated (quantized) copies. v_A starts equal to v, so the first
    # belief deltas are zero. epsilon_A starts at zero, so dLdy is transmitted
    # on the first iteration.
    v_A = [vh.clone().detach() for vh in vhat]
    epsilon_A = [torch.zeros_like(vh) for vh in vhat]
    
    lt_a_error = lt_a * e_mult
    total_lrn_spikes = 0
    
    for i in range(n):  # i itself is unused; this just repeats n iterations
        # Phase 1: transmit only the change since the last transmission.
        # Components no larger than the smallest threshold are never sent.
        for l in range(DepthPlusOne):
            delta_v = v[l] - v_A[l]
            quantized_delta_v, spikes_v = _tensor_fake_quantize(delta_v, lt_m, lt_n, lt_a)
            v_A[l].add_(quantized_delta_v)
            total_lrn_spikes += spikes_v
            
            delta_eps = epsilon[l] - epsilon_A[l]
            quantized_delta_eps, spikes_eps = _tensor_fake_quantize(delta_eps, lt_m, lt_n, lt_a_error)
            epsilon_A[l].add_(quantized_delta_eps)
            total_lrn_spikes += spikes_eps

        # Phase 2: top-down sweep using only the communicated copies for the
        # Jacobian term. NOTE(review): this also updates v[0]/epsilon[0] (the
        # input slot); SetPCGrads only reads epsilon[layer+1], so epsilon[0]
        # appears unused there.
        for layer in reversed(range(DepthPlusOne - 1)):
            # vjp returns (model output, vector-Jacobian product); only the VJP is used.
            _, epsdfdv = torch.autograd.functional.vjp(model[layer], v_A[layer], epsilon_A[layer+1])
            
            # The local error is computed from v *before* this step's update.
            epsilon[layer] = vhat[layer] - v[layer]
            
            dv = epsilon[layer] - epsdfdv
            v[layer].add_(eta * dv)

    return v, epsilon, total_lrn_spikes

def SetPCGrads(model, epsilon, X, v=None):
    """
    Set ``p.grad`` for every trainable parameter from predictive-coding errors.

    Each layer is re-evaluated on a detached copy of its input ``v[layer]``.
    Parameter gradients are the vector-Jacobian product of that output with
    ``epsilon[layer+1]``, i.e. the parameter gradient of
    ``sum(epsilon[layer+1] * model[layer](v[layer]))``. Existing ``p.grad``
    values are overwritten, not accumulated.

    Args:
        model: Indexable sequence of layers.
        epsilon: Prediction errors. ``epsilon[layer+1]`` must match the shape
            of ``model[layer]``'s output.
        X: Network input; only used when ``v`` is ``None``.
        v: Per-layer inputs/beliefs (``len(model) + 1`` entries). If ``None``,
            they are recomputed with a plain forward pass from ``X``.

    Parameters that do not require grad, or that get no gradient from the
    layer output, keep their existing ``.grad``.
    """
    depth_plus_one = len(model) + 1
    if v is None:
        # Only the values are needed (each entry is re-detached below), so no
        # autograd graph is built for this forward pass.
        with torch.no_grad():
            v = [X]
            for layer in range(1, depth_plus_one):
                v.append(model[layer - 1](v[layer - 1]))
    for layer in range(depth_plus_one - 1):
        vtemp0 = v[layer].clone().detach().requires_grad_(True)
        # The forward pass runs for every layer, including parameter-free
        # ones, so any forward side effects (e.g. RNG use) are unchanged.
        vtemp1 = model[layer](vtemp0)
        params = [p for p in model[layer].parameters() if p.requires_grad]
        if not params:
            continue
        # One backward pass for all of the layer's parameters, instead of one
        # per parameter with retain_graph=True.
        grads = torch.autograd.grad(
            vtemp1,
            params,
            grad_outputs=epsilon[layer + 1],
            allow_unused=True,
        )
        for p, g in zip(params, grads):
            if g is not None:
                p.grad = g
              
#  Main PCInfer Function

def PCInfer(model, LossFun, X, Y, ErrType, eta=.1, n=20, vinit=None, lt_m=3, lt_n=4, lt_a=1.0, e_mult=1.0):
    """
    Run quantized predictive-coding inference and set parameter gradients.

    Steps:
      1. QuantizedFwdPass produces the approximated activations ``vhat``.
      2. The loss is evaluated on the approximated output, and its gradient
         ``dLdy`` w.r.t. that output is computed.
      3. QuantizedPredErrs runs ``n`` quantized inference iterations.
      4. SetPCGrads writes ``p.grad`` from the resulting errors.

    Args:
        model: Indexable sequence of layers.
        LossFun: Callable ``LossFun(output, Y)`` returning a scalar loss.
        X: Network input.
        Y: Target passed to ``LossFun``.
        ErrType: Must be ``"QuantizedPred"``; anything else raises ValueError.
        eta: Belief-update step size for inference.
        n: Forward refinement steps per layer, and also the number of
            inference iterations.
        vinit: Currently unused (kept in the signature).
        lt_m, lt_n, lt_a: Threshold-scheduler parameters for quantization.
        e_mult: Multiplier on ``lt_a`` for the error thresholds.

    Returns:
        ``(vhat, Loss, dLdy, v, epsilon, ff_spikes, lrn_spikes)``.

    Raises:
        ValueError: If ``ErrType`` is not ``"QuantizedPred"``.
    """
    # Reject unsupported modes before doing any work. Previously the forward
    # pass and loss ran (including any layer side effects) before the error.
    if ErrType != "QuantizedPred":
        raise ValueError('For this version, please use ErrType "QuantizedPred"')

    vhat, ff_spikes = QuantizedFwdPass(model, X, n, lt_m, lt_n, lt_a)

    # Differentiate w.r.t. a detached leaf copy of the output (same storage),
    # so the tensor kept in vhat never has requires_grad toggled.
    yhat = vhat[-1].detach().requires_grad_(True)
    Loss = LossFun(yhat, Y)
    dLdy = torch.autograd.grad(Loss, yhat, retain_graph=False)[0].detach()

    v, epsilon, lrn_spikes = QuantizedPredErrs(model, vhat, dLdy, eta, n, lt_m, lt_n, lt_a, e_mult)
    SetPCGrads(model, epsilon, X, v)

    return vhat, Loss, dLdy, v, epsilon, ff_spikes, lrn_spikes