import cupy as cp
import os

from utils.compile_raw_cuda import compile_kernel_with_replacements
from .kernels import Kernel

# Helper functions for calling raw cuda kernels that evaluate the RfK and PSPK contributions

def _load_preamble(kernel_name: str, replacements: list[tuple[str, str]]):
    """Read ``cupy_cuda/preamble_<kernel_name>.cu`` and fill in its placeholders.

    The file is looked up relative to the directory that contains this
    module. Each ``(placeholder, replacement)`` pair is applied in the order
    given, substituting every occurrence of ``__REPL_<placeholder>`` in the
    text. Tokens without a matching pair are left untouched.

    Substitution is plain text replacement, so it is order-sensitive: if one
    placeholder is a prefix of another, list the longer one first.

    Args:
        kernel_name: Suffix that selects the preamble file.
        replacements: Ordered ``(placeholder, replacement)`` string pairs.

    Returns:
        The preamble source with all substitutions applied.

    Raises:
        OSError: If the preamble file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    preamble_path = os.path.join(module_dir, 'cupy_cuda', f'preamble_{kernel_name}.cu')
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding (e.g. non-ASCII characters in comments).
    with open(preamble_path, 'r', encoding='utf-8') as f:
        source = f.read()
    # The file handle is no longer needed once the text is in memory.
    for placeholder, replacement in replacements:
        source = source.replace(f"__REPL_{placeholder}", replacement)
    return source


def _inject_array(xs: cp.ndarray):
    """Render the elements of ``xs`` as a brace-enclosed list of float literals.

    Each element is formatted with eight decimal places and an ``f`` suffix,
    the elements are joined with ``', '`` and the result is wrapped in
    ``{ }``, e.g. ``{1.00000000f, 2.50000000f}``. An empty input yields
    ``{}``.

    Args:
        xs: Iterable of values that support the ``.8f`` format spec.

    Returns:
        The initializer-list text as a ``str``.
    """
    literals = [f'{value:.8f}f' for value in xs]
    body = ', '.join(literals)
    return f'{{{body}}}'

class RawCuda:
    """Base class that compiles a raw CUDA kernel and exposes it as ``self.kernel``.

    Subclasses supply the placeholder replacements and the kernel name; this
    class appends a ``DEBUG`` replacement, has the kernel compiled, and loads
    the resulting function through ``cp.RawModule``.
    """

    def __init__(self, replacements, kernel_name, debug):
        """Compile ``kernel_name`` and load its entry point.

        Args:
            replacements: List of ``(placeholder, replacement)`` pairs. It is
                not modified; a new list with the ``DEBUG`` entry is built.
            kernel_name: Name passed to the compile helper and used to look
                up the function in the loaded module.
            debug: Stored on the instance and forwarded to the compile
                helper. Its lower-cased ``str()`` becomes the ``DEBUG``
                replacement (``'true'`` / ``'false'`` for a bool).
        """
        self.debug = debug
        debug_entry = ("DEBUG", str(debug).lower())
        all_replacements = replacements + [debug_entry]
        # The helper returns the path that RawModule loads below.
        module_path = compile_kernel_with_replacements(kernel_name, all_replacements, debug=debug)
        module = cp.RawModule(path=module_path, backend='nvcc')
        self.kernel = module.get_function(kernel_name)


class InVoltFixed(RawCuda):
    """Launcher for the ``in_volt_fixed`` raw CUDA kernel."""

    def __init__(self, phi: Kernel, n_neurons: int, n_in: int, debug=False):
        """Compile the kernel for a given kernel function and layer size.

        Args:
            phi: Provides ``c_kernel_str``, substituted for ``KERNEL_EVAL``.
            n_neurons: Substituted for ``TRAINS_OUT``; also the size of the
                last axis of the array returned by ``__call__``.
            n_in: Substituted for ``TRAINS_IN``.
            debug: Forwarded to ``RawCuda``; when truthy every call
                synchronizes the null stream after the launch.
        """
        self.n_neurons = n_neurons
        self.n_in = n_in
        super().__init__(
            [
                ('KERNEL_EVAL', phi.c_kernel_str),
                ('TRAINS_IN', str(n_in)),
                ('TRAINS_OUT', str(n_neurons)),
            ],
            'in_volt_fixed',
            debug=debug,
        )

    def __call__(self, eval_ts, s_in, in_weights, neg_delays_inv_widths, inv_widths, n_threads=32):
        """Launch the kernel and return its float32 output buffer.

        Args:
            eval_ts: Sequence whose length sets the second grid dimension.
            s_in: Rank-3 array; its first axis sets the first grid dimension.
            in_weights, neg_delays_inv_widths, inv_widths: Passed through to
                the kernel unchanged.
            n_threads: Threads per block.

        Returns:
            Array of shape ``(s_in.shape[0], len(eval_ts), n_neurons)``,
            zero-initialised before the launch.
        """
        batch, _, t_in = s_in.shape
        t_out = len(eval_ts)
        out = cp.zeros((batch, t_out, self.n_neurons), dtype=cp.float32)
        # Dynamic shared memory: one float per entry of s_in's last axis plus
        # one per neuron.
        float_size = cp.dtype(cp.float32).itemsize
        # NOTE: an earlier revision rounded n_threads up to a multiple of 32
        # above n_neurons because results were wrong for n_threads < n_out;
        # per the original author's note this seems fixed.
        self.kernel(
            grid=(batch, t_out),
            block=(n_threads,),
            args=(s_in, in_weights, neg_delays_inv_widths, inv_widths, eval_ts, out, t_in, t_out),
            shared_mem=float_size * (t_in + self.n_neurons),
        )
        if self.debug:
            cp.cuda.Stream.null.synchronize()
        return out

class TauCorrFixed(RawCuda):
    """Launcher for the ``tau_corr_fixed`` raw CUDA kernel."""

    def __init__(self, n_neurons: int, n_in: int, debug=False, accumulate_absolute=False):
        """Compile the kernel for the given layer sizes.

        Args:
            n_neurons: Substituted for ``TRAINS_OUT``; also the second grid
                dimension and second output axis in ``__call__``.
            n_in: Substituted for ``TRAINS_IN``.
            debug: Forwarded to ``RawCuda``; when truthy every call
                synchronizes the null stream after the launch.
            accumulate_absolute: Its lower-cased ``str()`` is substituted for
                ``ACCUMULATE_ABSOLUTE``.
        """
        self.n_neurons = n_neurons
        self.n_in = n_in
        super().__init__(
            [
                ('TRAINS_IN', str(n_in)),
                ('TRAINS_OUT', str(n_neurons)),
                ('ACCUMULATE_ABSOLUTE', str(accumulate_absolute).lower()),
            ],
            'tau_corr_fixed',
            debug=debug,
        )

    def compute_shared_bytes(self, Ti: int, H: int, T: int) -> int:
        """Return the dynamic shared-memory size, in bytes, for one launch.

        The layout is ``Ti`` 2-byte (uint16) entries, padded up to a multiple
        of four bytes so the following 4-byte float region is aligned, then
        ``T + H`` floats.
        """
        u16_bytes = 2 * Ti
        # Round the uint16 region up to the next multiple of 4 bytes.
        aligned_u16_bytes = u16_bytes + (-u16_bytes) % 4
        return aligned_u16_bytes + 4 * (T + H)

    def __call__(self, tmax, dt, s_in, f, d_max: int, n_threads=32):
        """Launch the kernel and return its float32 output buffer.

        Args:
            tmax, dt: ``int(tmax / dt)`` is passed to the kernel.
            s_in: Rank-3 array; divided by ``dt`` and cast to uint16 before
                being passed to the kernel.
            f: Rank-3 array passed through unchanged; its last axis length is
                passed to the kernel.
            d_max: Size of the last output axis; also used as the ``T``
                argument of ``compute_shared_bytes``.
            n_threads: Threads per block.

        Returns:
            Array of shape ``(s_in.shape[0], n_neurons, d_max)``,
            zero-initialised before the launch.
        """
        batch, _, t_in = s_in.shape
        # NOTE(review): int() truncates, so a quotient that lands just below
        # an integer because of float rounding (e.g. 0.3 / 0.1) loses a step.
        # Confirm this is intended.
        n_steps = int(tmax / dt)
        _, _, h_len = f.shape
        out = cp.zeros((batch, self.n_neurons, d_max), dtype=cp.float32)
        # NOTE(review): the uint16 cast also truncates and cannot represent
        # values above 65535 or below 0. Confirm the value range of s_in / dt.
        s_ticks = (s_in / dt).astype(cp.uint16)
        self.kernel(
            grid=(batch, self.n_neurons),
            block=(n_threads,),
            args=(s_ticks, f, out, t_in, n_steps, h_len, d_max),
            shared_mem=self.compute_shared_bytes(t_in, h_len, d_max),
        )
        if self.debug:
            cp.cuda.Stream.null.synchronize()
        return out


class OutVoltFixed(RawCuda):
    """Launcher for the ``out_volt_fixed`` raw CUDA kernel."""

    def __init__(self, phi: Kernel, n_neurons: int, debug=False):
        """Compile the kernel for a given kernel function and layer size.

        Args:
            phi: Provides ``c_kernel_str``, substituted for ``KERNEL_EVAL``.
            n_neurons: Substituted for ``TRAINS_OUT``; also the second grid
                dimension and last output axis in ``__call__``.
            debug: Forwarded to ``RawCuda``; when truthy every call
                synchronizes the null stream after the launch.
        """
        self.n_neurons = n_neurons
        super().__init__(
            [
                ('KERNEL_EVAL', phi.c_kernel_str),
                ('TRAINS_OUT', str(n_neurons)),
            ],
            'out_volt_fixed',
            debug=debug,
        )

    def __call__(self, eval_ts, s_out, out_weights, neg_delays_inv_widths, inv_widths, n_threads=32):
        """Launch the kernel and return its float32 output buffer.

        Args:
            eval_ts: Sequence whose length sets the third grid dimension.
            s_out: Rank-3 array; its first axis sets the first grid dimension.
            out_weights, neg_delays_inv_widths, inv_widths: Passed through to
                the kernel unchanged.
            n_threads: Threads per block.

        Returns:
            Array of shape ``(s_out.shape[0], len(eval_ts), n_neurons)``,
            zero-initialised before the launch.
        """
        batch, _, t_in = s_out.shape
        t_out = len(eval_ts)
        out = cp.zeros((batch, t_out, self.n_neurons), dtype=cp.float32)
        # Grid is (batch, neuron, evaluation index); no dynamic shared memory
        # is requested for this launch.
        self.kernel(
            grid=(batch, self.n_neurons, t_out),
            block=(n_threads,),
            args=(s_out, out_weights, neg_delays_inv_widths, inv_widths, eval_ts, out, t_in, t_out),
        )
        if self.debug:
            cp.cuda.Stream.null.synchronize()
        return out

class OutVoltBasis(RawCuda):
    """Launcher for the ``out_volt_basis`` raw CUDA kernel."""

    def __init__(self, phi: Kernel, n_neurons: int, num_basis_q: int, debug=False):
        """Compile the kernel for a given kernel function, layer size and basis count.

        Args:
            phi: Provides ``c_kernel_str``, substituted for ``KERNEL_EVAL``.
            n_neurons: Substituted for ``TRAINS_OUT``; also the second grid
                dimension and last output axis in ``__call__``.
            num_basis_q: Substituted for ``NUM_BASIS`` and stored as
                ``self.num_basis``.
            debug: Forwarded to ``RawCuda``; when truthy every call
                synchronizes the null stream after the launch.
        """
        self.n_neurons = n_neurons
        self.num_basis = num_basis_q
        super().__init__(
            [
                ('KERNEL_EVAL', phi.c_kernel_str),
                ('TRAINS_OUT', str(n_neurons)),
                ('NUM_BASIS', str(num_basis_q)),
            ],
            'out_volt_basis',
            debug=debug,
        )

    def __call__(self, eval_ts, s_out, out_weights, neg_delays_inv_widths, inv_widths):
        """Launch the kernel and return its float32 output buffer.

        Args:
            eval_ts: Sequence whose length sets the third grid dimension.
            s_out: Rank-3 array; its first axis sets the first grid dimension.
            out_weights, neg_delays_inv_widths, inv_widths: Passed through to
                the kernel unchanged.

        Returns:
            Array of shape ``(s_out.shape[0], len(eval_ts), n_neurons)``,
            zero-initialised before the launch.
        """
        batch, _, t_in = s_out.shape
        t_out = len(eval_ts)
        out = cp.zeros((batch, t_out, self.n_neurons), dtype=cp.float32)
        # Unlike the sibling launchers, the block size is fixed at 32 threads
        # here. Presumably the kernel relies on it; verify before changing.
        self.kernel(
            grid=(batch, self.n_neurons, t_out),
            block=(32,),
            args=(s_out, out_weights, neg_delays_inv_widths, inv_widths, eval_ts, out, t_in, t_out),
        )
        if self.debug:
            cp.cuda.Stream.null.synchronize()
        return out

class FitContributionsFixed(RawCuda):
    """Launcher for the ``fit_contrib_fixed`` raw CUDA kernel."""

    def __init__(self, phi: Kernel, n_neurons: int, n_in: int, debug=False):
        """Compile the kernel for a given kernel function and input size.

        Args:
            phi: Provides ``c_kernel_str``, substituted for ``KERNEL_EVAL``.
            n_neurons: Stored on the instance; not part of the replacements.
            n_in: Substituted for ``TRAINS_IN``.
            debug: Forwarded to ``RawCuda``; when truthy every call
                synchronizes the null stream after the launch.
        """
        self.n_neurons = n_neurons
        self.n_in = n_in
        super().__init__(
            [
                ('KERNEL_EVAL', phi.c_kernel_str),
                ('TRAINS_IN', str(n_in)),
            ],
            'fit_contrib_fixed',
            debug=debug,
        )

    def __call__(self, A, eval_ts, s_in, neg_delay_inv_width, inv_width, n_threads=128):
        """Launch the kernel on caller-provided arrays.

        No output array is allocated here; ``A`` is handed to the kernel as
        is (presumably the buffer it fills; confirm against the kernel
        source).

        Args:
            A: Array passed to the kernel as its third argument.
            eval_ts: Sequence whose length sets the second grid dimension.
            s_in: Rank-3 array; its first axis sets the first grid dimension.
            neg_delay_inv_width, inv_width: Passed through unchanged.
            n_threads: Threads per block.
        """
        batch, _, t_in = s_in.shape
        t_out = len(eval_ts)
        self.kernel(
            grid=(batch, t_out),
            block=(n_threads,),
            args=(s_in, eval_ts, A, neg_delay_inv_width, inv_width, t_in, t_out),
        )
        if self.debug:
            cp.cuda.Stream.null.synchronize()