R"""Interface to the bionmf-gpu library.

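Note that to interface with recent GPUs, I had to make a small change to the library:
calls of the form __[something]_xor(...) became __[something]_xor_sync(0xFFFFFF, ...
(presumably CUDA's __shfl_xor -> __shfl_xor_sync; TODO confirm the exact name and mask).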

User guide link:
    https://github.com/bioinfo-cnb/bionmf-gpu/blob/master/doc/user_guide.txt.md



Script to move my local version to banana:

    rsync -ra -e ssh \
        --exclude "*/__pycache__" \
        --exclude "*/.git" \
        "$HOME/Desktop/other_code/bionmf-gpu/" \
        "m@banana.cs.unc.edu:/fruitbasket/users/m/other_code/bionmf-gpu/"


"""
import dataclasses
import os
import subprocess
import tempfile
from typing import Optional, Union

import numpy as np
import tensorflow as tf
# TODO: Get multigpu support working


# Default directory for temporary files. /tmp is deliberately avoided: I think
# /tmp is on a local machine on the server and will easily run out of disk
# space.
TEMP_DIR_PATH = '/fruitbasket/users/m/tmp'

# Default location of the NMF_GPU executable on the server.
NMF_GPU_BINARY_PATH = '/fruitbasket/users/m/other_code/bionmf-gpu/bin/NMF_GPU'


@dataclasses.dataclass
class NmfGpu:
    """Utility class for interacting with the NMF_GPU binary.

    The input matrix is written to a temporary file in NMF_GPU's "native"
    binary format, the binary is run as a subprocess, and the factor matrices
    it writes next to the input file are read back as NumPy arrays.

    Currently only supports the single-GPU version, but might support more later.

    Attributes:
        binary_path: Path to the NMF_GPU executable ("~" is expanded).
        single_precision: Whether matrix data is exchanged as float32 (True)
            or float64 (False). Must match how the binary was compiled.
        unsigned_matrix_dims: Whether matrix dimensions are exchanged as
            uint32 (True) or int32 (False). Must match how the binary was
            compiled.
        tmp_dir_path: Directory in which the temporary input/output files are
            created ("~" is expanded).
    """
    binary_path: str = NMF_GPU_BINARY_PATH

    # The following correspond to makefile parameters. Data will automatically be
    # converted to these formats when interacting with NMF_GPU; they don't affect
    # what gets passed to this (Python) class. See the following for details on
    # the makefile parameters:
    # https://github.com/bioinfo-cnb/bionmf-gpu/blob/master/doc/installation_guide.txt.md#mkparams
    #
    # Corresponds to the SINGLE makefile parameter.
    single_precision: bool = True
    #
    # Corresponds to the UNSIGNED makefile parameter.
    unsigned_matrix_dims: bool = True

    # TODO: Possibly add a default `gpu_device: Optional[int] = None` field.

    tmp_dir_path: str = TEMP_DIR_PATH

    def __post_init__(self):
        """Expands a leading "~" in the configured paths."""
        self.binary_path = os.path.expanduser(self.binary_path)
        self.tmp_dir_path = os.path.expanduser(self.tmp_dir_path)

    @property
    def _tf_float_dtype(self):
        """TensorFlow dtype used for matrix data."""
        return tf.float32 if self.single_precision else tf.float64

    @property
    def _tf_int_dtype(self):
        """TensorFlow dtype used for matrix dimensions."""
        return tf.uint32 if self.unsigned_matrix_dims else tf.int32

    @property
    def _np_float_dtype(self):
        """NumPy dtype used for matrix data."""
        return np.float32 if self.single_precision else np.float64

    @property
    def _np_int_dtype(self):
        """NumPy dtype used for matrix dimensions."""
        return np.uint32 if self.unsigned_matrix_dims else np.int32

    def _array_to_bytes(self, x: Union[tf.Tensor, np.ndarray]) -> bytes:
        """Casts `x` to the configured float dtype and serializes it row-major."""
        if isinstance(x, tf.Tensor):
            x = tf.cast(x, self._tf_float_dtype).numpy()
        else:
            x = x.astype(self._np_float_dtype)
        # Flatten the array, ensuring we do so in row-major order.
        return x.tobytes('C')

    def _shape_to_bytes(self, shape) -> bytes:
        """Serializes a shape as consecutive integers of the configured int dtype."""
        # tuple(...) makes sure NumPy sees a plain sequence even when `shape`
        # is a tf.TensorShape.
        return np.array(tuple(shape), dtype=self._np_int_dtype).tobytes()

    def _write_matrix_to_native_io_file(self, x: Union[tf.Tensor, np.ndarray], file):
        """Writes a matrix to a native binary io file.

        The file consists of the two matrix dimensions followed by the matrix
        data in row-major order. See
        https://github.com/bioinfo-cnb/bionmf-gpu/blob/master/doc/user_guide.txt.md#fileformat
        for details on this file.

        Args:
            x: The 2-D matrix to write.
            file: A writable binary file object.

        Raises:
            ValueError: If `x` is not 2-dimensional.
        """
        shape = tuple(x.shape)
        # The reader (and the header layout used here) assumes exactly two
        # dimensions; anything else would produce a malformed file.
        if len(shape) != 2:
            raise ValueError(
                f'Expected a 2-dimensional matrix, got shape {shape}.')
        file.write(self._shape_to_bytes(shape))
        file.write(self._array_to_bytes(x))

    def _read_from_native_io_file(self, filepath: str) -> np.ndarray:
        """Reads a matrix from a native binary io file.

        Args:
            filepath: Path to a file laid out as two dimension integers
                followed by row-major matrix data.

        Returns:
            A read-only 2-D array of the configured float dtype.

        Raises:
            ValueError: If the file is too short for the shape in its header.
        """
        with open(filepath, mode='rb') as f:
            content = f.read()
        # Derive the header size from the dtype rather than hard-coding it.
        header_nbytes = 2 * np.dtype(self._np_int_dtype).itemsize
        header = np.frombuffer(content, dtype=self._np_int_dtype, count=2)
        # Convert to Python ints so the element count below cannot overflow
        # the (32-bit) header dtype.
        n_rows, n_cols = int(header[0]), int(header[1])
        # Read exactly rows * cols values so that any bytes following the
        # matrix data do not break the read.
        flat_matrix = np.frombuffer(
            content,
            dtype=self._np_float_dtype,
            count=n_rows * n_cols,
            offset=header_nbytes,
        )
        return flat_matrix.reshape((n_rows, n_cols), order='C')

    def run_nmf(
        self,
        x: Union[tf.Tensor, np.ndarray],
        n_components: int,
        max_iters: int = 2000,
        n_iter_test_conv: int = 10,
        stop_threshold: int = 40,
        gpu_device: Optional[int] = None,
    ):
        """Runs NMF_GPU on `x` and returns the resulting factor matrices.

        Args:
            x: The 2-D matrix to factorize.
            n_components: Factorization rank (NMF_GPU's `-k` option).
            max_iters: Maximum number of iterations if the algorithm does not
                converge (`-i`).
            n_iter_test_conv: A convergence test is performed every this many
                iterations (`-j`).
            stop_threshold: Number of consecutive convergence tests without a
                change in H after which the algorithm stops (`-t`).
            gpu_device: Device ID to attach on (`-z`). When None, the option
                is not passed and NMF_GPU picks its own default.

        Returns:
            A `(W, Ht)` tuple of NumPy arrays read from the `_W.native.dat`
            and `_H.native.dat` output files. NOTE(review): the `Ht` name
            suggests H is stored transposed -- confirm against the user guide.

        Raises:
            ValueError: If `x` is not 2-dimensional.
            subprocess.CalledProcessError: If NMF_GPU exits with a non-zero
                status.
        """
        # TODO: See how long file-io stuff takes, maybe try to read from memory
        # filesystem rather than disk.
        # TODO: Optionally return tf.Tensors when `x` is a tf.Tensor.
        with tempfile.NamedTemporaryFile('w+b', dir=self.tmp_dir_path) as f:
            self._write_matrix_to_native_io_file(x, f)
            # Make sure the data is on disk before the subprocess reads it.
            f.flush()

            # NMF_GPU writes its results next to the input file.
            W_filepath = f'{f.name}_W.native.dat'
            Ht_filepath = f'{f.name}_H.native.dat'

            script_args = [
                self.binary_path,
                f.name,
                # Presumably selects the native binary format for the input
                # (-b) and output (-e) files -- confirm against the user guide.
                '-b', 1,
                '-e', 1,
                '-k', n_components,
                '-i', max_iters,
                '-j', n_iter_test_conv,
                '-t', stop_threshold,
            ]
            # TODO: Maybe default to something based on the CUDA_VISIBLE_DEVICES
            # environment variable.
            if gpu_device is not None:
                script_args.extend(['-z', gpu_device])
            script_args = [str(s) for s in script_args]

            try:
                script_output = subprocess.run(script_args)
                # TODO: Process the script output instead of just printing it.
                print(script_output)
                # Fail loudly here rather than with a confusing error when the
                # output files turn out to be missing.
                script_output.check_returncode()

                W = self._read_from_native_io_file(W_filepath)
                Ht = self._read_from_native_io_file(Ht_filepath)
            finally:
                # The output files are not managed by NamedTemporaryFile, so
                # remove them explicitly. They may not exist if the run failed.
                for path in (W_filepath, Ht_filepath):
                    try:
                        os.remove(path)
                    except FileNotFoundError:
                        pass

        return W, Ht


"""
-b 1 -e 1


-K,-k <factorization_rank>: Factorization rank. It must be at least 2, but no greater than either of the matrix dimensions. Default value: 2.

-I,-i <nIters>: Maximum number of iterations if the algorithm does not converge to a stable solution. Default: 2000.

-J,-j <niter_test_conv>: Performs a convergence test every <niter_test_conv> iterations. Default: 10.
If this value is greater than <nIters> (-i option), no test is performed. See "Test of Convergence" for details.

-T,-t <stop_threshold>: Stopping threshold. Default value: 40.
If matrix H has not changed on the last <stop_threshold> times that the convergence test has been performed, it is considered that the algorithm has converged to a solution and stops it. See "Test of Convergence" for details.

2.3. Other options
-h,-H: Prints a help message with all arguments.

-Z,-z <GPU_device>: Device ID to attach on (default: 0).
On the multi-GPU version, devices will be selected from this value.
For instance,
"""