#!/usr/bin/env python


import numpy as np
import ctypes

def _to_uint64_array(array):
    """Return *array* as a C-contiguous, native-endian ``uint64`` NumPy array.

    The result is handed to native code as a raw ``uint64_t *`` together with
    its element count, so the data must be one dense block: a strided view
    such as ``a[::2]`` or ``a[::-1]`` would otherwise make the C side read the
    wrong elements (or, for negative strides, run past the buffer).

    Arrays that are already contiguous native ``uint64`` are returned
    unchanged (no copy); anything else is cast and/or copied.

    Raises:
        AssertionError: if *array* is not one-dimensional.
    """
    assert array.ndim == 1, "Array must be one-dimensional."
    # ascontiguousarray returns the input itself when it already satisfies
    # dtype and layout, and a dense converted copy otherwise.
    return np.ascontiguousarray(array, dtype=np.uint64)


class ClutContext(ctypes.Structure):
    """ctypes view of the native cLUT sampling context.

    NOTE(review): the field order and types presumably mirror the C struct
    that liblut's cLUT functions operate on; keep them in sync with the header.
    """
    _fields_ = [
        ('r', ctypes.c_size_t),
        ('c', ctypes.c_size_t),
        ('cLUT', ctypes.POINTER(ctypes.c_uint64)),
        ('cLUT_size', ctypes.c_size_t),
        ('cLUT_linelen', ctypes.c_size_t),
    ]

    def cLUT_toNumpy(self) -> np.ndarray:
        """Return the ``cLUT_size`` entries behind ``cLUT`` as a 1-D uint64 array.

        The array shares memory with whatever ``cLUT`` points to (no copy), so
        it must not be used after that memory is released.
        """
        return np.ctypeslib.as_array(self.cLUT, shape=(self.cLUT_size,))

    @classmethod
    def fromNumpy(cls, r: int, c: int, clut: np.ndarray):
        """Build a context that points directly into the NumPy array *clut*.

        Keep the source array around as long as the returned object lives.
        Do NEVER call free_cLUT on the created context.

        ``cLUT_linelen`` is left at 0. NOTE(review): confirm the native sampler
        does not depend on it for contexts built this way.

        Raises:
            AssertionError: if *clut* is not a 1-D ``uint64`` array.
            ValueError: if *clut* is not C-contiguous. Only a raw pointer and an
                element count are stored, so a strided view would make native
                code read the wrong (or out-of-bounds) memory.
        """
        assert clut.ndim == 1 and clut.dtype == np.uint64
        # A real exception rather than an assert: this guards memory safety
        # and must survive ``python -O``.
        if not clut.flags['C_CONTIGUOUS']:
            raise ValueError("clut must be a C-contiguous array")

        return cls(r=r, c=c, cLUT=clut.ctypes.data_as(ctypes.POINTER(ctypes.c_uint64)), cLUT_size=clut.size)


def _make_build_func(func, res_type=ctypes.c_void_p):
    """Declare a native builder's signature to ctypes and wrap it for NumPy input.

    The native function is declared as taking ``(uint64_t *, size_t)`` and
    returning ``res_type``. The returned wrapper converts its argument with
    ``_to_uint64_array`` and passes the data pointer plus the element count.

    Args:
        func: ctypes function from the native library; its ``argtypes`` and
            ``restype`` are set in place.
        res_type: ctypes type of the native result (``c_void_p`` by default).

    Returns:
        ``call_f(array)``, which returns whatever the native builder returns.
    """
    uint64_ptr = ctypes.POINTER(ctypes.c_uint64)
    func.restype = res_type
    func.argtypes = [uint64_ptr, ctypes.c_size_t]

    def call_f(array: np.array) -> res_type:
        values = _to_uint64_array(array)
        # ``values`` stays referenced until the native call returns, so the
        # buffer behind the pointer is valid for the whole call.
        data_ptr = values.ctypes.data_as(uint64_ptr)
        return func(data_ptr, values.size)

    return call_f


def _make_sample_func(func):
    """Declare to ctypes that a native sampler takes one ``void *`` and returns ``uint64_t``.

    Args:
        func: ctypes function; its ``argtypes`` and ``restype`` are set in place.

    Returns:
        ``call_f(ctx)``, which forwards the context handle and returns the
        sampled value as a Python ``int``.
    """
    func.restype = ctypes.c_uint64
    func.argtypes = [ctypes.c_void_p]

    def call_f(ctx: ctypes.c_void_p) -> int:
        drawn = func(ctx)
        return drawn

    return call_f


def _make_free_func(func):
    """Declare to ctypes that a native destructor takes one ``void *`` and returns nothing.

    Args:
        func: ctypes function; its ``argtypes`` and ``restype`` are set in place.

    Returns:
        ``call_f(ctx)``, which passes the context handle to *func* and returns ``None``.
    """
    func.restype = None
    func.argtypes = [ctypes.c_void_p]

    def call_f(ctx: ctypes.c_void_p) -> None:
        func(ctx)

    return call_f


# The five tables below are positionally aligned: entry i of every list
# describes the same sampler (key, Python class name, and the native build /
# sample / free symbol names). The aldr and fldr variants share their sample
# and free symbols.

SAMPLER_NAMES = [
    'aldr.flat',
    'fldr.flat',
    'aldr.enc',
    'fldr.enc',
    'alias.c',
    'cLUT',
    'nLUT',
    'numpy',
]


SAMPLER_CLASSES = [
    'ALDRFlatSampler',
    'FLDRFlatSampler',
    'ALDREncSampler',
    'FLDREncSampler',
    'AliasCSampler',
    'CLUTSampler',
    'NLUTSampler',
    'NumPySampler',
]


BUILD_FUNCTIONS = [
    'preprocess_aldr_flat',
    'preprocess_fldr_flat',
    'preprocess_aldr_enc',
    'preprocess_fldr_enc',
    'preprocess_weighted_alias',
    'build_cLUT',
    'build_nLUT',
    'build_numpy',
]


SAMPLE_FUNCTIONS = [
    'sample_aldr_flat',  # aldr.flat
    'sample_aldr_flat',  # fldr.flat (shared)
    'sample_aldr_enc',   # aldr.enc
    'sample_aldr_enc',   # fldr.enc (shared)
    'sample_weighted_alias_index',
    'cLUT_sampling',
    'nLUT_sampling',
    'numpy_sampling',
]


FREE_FUNCTIONS = [
    'free_aldr_flat_s',  # aldr.flat
    'free_aldr_flat_s',  # fldr.flat (shared)
    'free_array_s',      # aldr.enc
    'free_array_s',      # fldr.enc (shared)
    'free_sample_weighted_alias_index',
    'free_cLUT',
    'free_nLUT',
    'free_numpy',
]


# Sampler key -> sampler class; populated by _setup().
Samplers = {}

def _create_functions(library, factory, names):
    """Publish ``factory(getattr(library, name))`` as a module global for each name.

    Args:
        library: object the symbols are looked up on with ``getattr``.
        factory: wrapper builder applied to each looked-up symbol.
        names: symbol names; each one is also used as the global's name.
    """
    namespace = globals()
    for symbol in names:
        wrapped = factory(getattr(library, symbol))
        namespace[symbol] = wrapped


def _sampler_class_factory(key, className, buildFunc, sampleFunc, freeFunc):
    """Create a sampler class bound to one build / sample / free triple.

    Args:
        key: Sampler key stored on every instance as ``self.key``.
        className: Name given to the generated class (``__name__`` and
            ``__qualname__``), so reprs and tracebacks identify the sampler.
        buildFunc: Called once with the input array; its result becomes
            ``self.ctx``.
        sampleFunc: Called with ``self.ctx`` by ``sample()``.
        freeFunc: Called with ``self.ctx`` when the instance is finalized.

    Returns:
        The new sampler class.
    """

    class Sampler(object):

        def __init__(self, array: np.ndarray):
            self.key = key
            self.ctx = buildFunc(array)

        def sample(self):
            return sampleFunc(self.ctx)

        def __del__(self):
            # ``ctx`` is missing if ``buildFunc`` raised inside ``__init__``.
            # It is falsy if the builder produced a NULL handle: ``None`` for a
            # ``c_void_p`` result, or a NULL ``POINTER`` (which is false).
            # In both cases there is nothing to release.
            ctx = getattr(self, 'ctx', None)
            if ctx:
                freeFunc(ctx)
                # Drop the handle so a second finalization cannot free it twice.
                self.ctx = None

    Sampler.__name__ = className
    Sampler.__qualname__ = className
    return Sampler


def build_preinit_lut_sampler(key, sampling_function):
    """Create a cLUT sampler class that samples from a caller-supplied table.

    The table and the ``ClutContext`` describing it are both Python objects
    owned by the sampler instance; no native build or free function is used.

    Args:
        key: Sampler key stored on every instance as ``self.key``.
        sampling_function: Callable invoked with a ``ctypes`` pointer to the
            instance's ``ClutContext``; ``sample()`` returns its result.

    Returns:
        The ``PreinitializedCLUTSampler`` class.
    """

    class PreinitializedCLUTSampler(object):

        def __init__(self, r: int, c: int, clut: np.array):
            self.key = key
            # Private copy: the context stores only a raw pointer into it, so
            # it is kept on the instance for as long as ``self.ctx`` lives.
            self.cLUT = table = _to_uint64_array(clut).copy()
            self.ctx = context = ClutContext.fromNumpy(r, c, table)
            self.ctx_p = ctypes.pointer(context)

        def sample(self):
            return sampling_function(self.ctx_p)

        @property
        def r(self) -> int:
            """The ``r`` value stored in the context."""
            return self.ctx.r

        @property
        def c(self) -> int:
            """The ``c`` value stored in the context."""
            return self.ctx.c

        def __del__(self):
            # Intentionally a no-op: self.ctx and its backing array self.cLUT
            # are both Python-owned, so there is no native memory to release
            # (free_cLUT must never be called on a context from fromNumpy).
            pass

    return PreinitializedCLUTSampler


def _load_liblut(name='liblut.so'):
    """Load the native LUT shared library.

    ``./<name>`` is resolved against the current working directory and is
    tried first, as before. If that fails, the copy sitting next to this
    module is tried, so the module can also be imported from another working
    directory.

    Raises:
        OSError: the error from the ``./<name>`` attempt when neither location
            can be loaded.
    """
    import os

    try:
        return ctypes.cdll.LoadLibrary('./' + name)
    except OSError as cwd_error:
        module_file = globals().get('__file__')
        if not module_file:
            raise
        module_copy = os.path.join(os.path.dirname(os.path.abspath(module_file)), name)
        try:
            return ctypes.cdll.LoadLibrary(module_copy)
        except OSError:
            raise cwd_error


liblut = _load_liblut()

def _setup():
    """Bind the native samplers of ``liblut`` and publish them at module level.

    For every row of the parallel tables SAMPLER_NAMES, SAMPLER_CLASSES,
    BUILD_FUNCTIONS, SAMPLE_FUNCTIONS and FREE_FUNCTIONS this:

    * wraps the build, sample and free symbols from ``liblut`` and stores each
      wrapper as a module global under its C symbol name;
    * creates the sampler class and stores it as a module global (under its
      class name) and in ``Samplers`` (under its key).

    Finally ``PreinitializedCLUTSampler`` is created on top of
    ``liblut.cLUT_sampling``.

    Raises:
        ValueError: if the tables differ in length. ``zip`` would otherwise
            silently drop the surplus entries.
    """
    tables = (SAMPLER_NAMES, SAMPLER_CLASSES, BUILD_FUNCTIONS, SAMPLE_FUNCTIONS, FREE_FUNCTIONS)
    lengths = [len(table) for table in tables]
    if len(set(lengths)) != 1:
        raise ValueError("sampler tables have mismatched lengths: {}".format(lengths))
    liblutinfo = zip(*tables)

    # Special handling for the cLUT sampler class, which should be able to
    # provide r, c and the table itself for debug purposes.
    CLUT_KEY = 'cLUT'
    # The cLUT builder's result is typed as ClutContext* so that ``.contents``
    # can be read; every other builder result stays an opaque c_void_p.
    build_res_types = {CLUT_KEY: ctypes.POINTER(ClutContext)}

    def add_c_and_r_for_cLUT(cl, key):
        """Attach read-only ``r``, ``c`` and ``cLUT`` properties to the cLUT class."""
        if not key == CLUT_KEY:
            return

        def get_r(self):
            return self.ctx.contents.r

        def get_c(self):
            return self.ctx.contents.c

        def get_cLUT(self) -> np.ndarray:
            return self.ctx.contents.cLUT_toNumpy()

        setattr(cl, 'r', property(get_r))
        setattr(cl, 'c', property(get_c))
        setattr(cl, 'cLUT', property(get_cLUT))

    for key, className, buildName, sampleName, freeName in liblutinfo:
        # Some sample/free symbols are shared between samplers; rebinding the
        # same global to an identically configured wrapper is harmless.
        globals()[buildName] = bf = _make_build_func(getattr(liblut, buildName), res_type=build_res_types.get(key, ctypes.c_void_p))
        globals()[sampleName] = sf = _make_sample_func(getattr(liblut, sampleName))
        globals()[freeName] = ff = _make_free_func(getattr(liblut, freeName))
        cl = _sampler_class_factory(key, className, bf, sf, ff)
        add_c_and_r_for_cLUT(cl, key)

        globals()[className] = Samplers[key] = cl

    # This must run after the loop. ctypes caches functions accessed as library
    # attributes, so ``liblut.cLUT_sampling`` is the same object whose
    # argtypes/restype were configured above.
    globals()["PreinitializedCLUTSampler"] = build_preinit_lut_sampler("cLUT", liblut.cLUT_sampling)

_setup()
# SAMPLE_FUNCTIONS and FREE_FUNCTIONS repeat some names (samplers that share
# native symbols). dict.fromkeys drops the repeats while keeping the order.
__all__ = list(dict.fromkeys(SAMPLER_CLASSES + BUILD_FUNCTIONS + SAMPLE_FUNCTIONS + FREE_FUNCTIONS + ["Samplers", "PreinitializedCLUTSampler"]))
