import warnings

from acgp.abstract_blas_wrapper import AbstractBlasWrapper
from acgp.hooks.abstract_hook import AbstractHook

import numpy as np


class MetaPivotedCholesky:
    """
    Blocked, in-place lower Cholesky decomposition that only looks at a sliding window of rows.

    Kernel entries are requested from the ``kernel_evaluator`` window by window, the linear system
    ``L x = err`` is solved alongside the factorisation, and a hook may stop the computation early.
    Pivoting is not implemented: ``self.permutation`` is always the identity.
    """

    def __init__(self, initial_window_size: int = 256, window_increase: int = 2, initial_selection: int = 1,
                 add: int = 1, blaswrapper: "AbstractBlasWrapper | None" = None):
        """
        :param initial_window_size: number of leading rows/columns evaluated before the first step
        :param window_increase: number of rows by which the window grows after each block
        :param initial_selection: size of the first diagonal block that is factorised
        :param add: size of every subsequent diagonal block
        :param blaswrapper: provider of the low level routines; a NumPy based wrapper is used if None
        :raises ValueError: if the sizes are not positive or the window cannot keep up with the blocks
        """
        if initial_selection <= 0:
            raise ValueError("initial_selection must be positive")
        if add <= 0:
            raise ValueError("add must be positive")
        if initial_window_size < initial_selection:
            raise ValueError("initial_window_size must be at least initial_selection")
        if window_increase < add:
            # otherwise the factorisation would overtake the window
            raise ValueError("window_increase must be at least add")

        self.initial_window_size = initial_window_size
        self.window_increase = window_increase
        self.initial_selection = initial_selection
        self.add = add
        # the blocks of the main loop have size ``add``; this is what hooks and the signature see
        self.block_size = add
        self.window_size = initial_window_size
        self.parameters = {
            "block_size": self.block_size,
            "initial_window_size": initial_window_size,
            "window_increase": window_increase,
            "initial_selection": initial_selection,
            "add": add,
        }
        self.permutation = None  # set by run_configuration

        if blaswrapper is None:
            warnings.warn("Going to use extremely slow numpy wrappers for low level BLAS operations!")
            from acgp.blas_wrappers.numpy_blas_wrapper import NumpyBlasWrapper
            blaswrapper = NumpyBlasWrapper()

        self.in_place_chol = blaswrapper.in_place_chol
        self.solve_triangular_inplace = blaswrapper.solve_triangular_inplace
        self.symmetric_down_date = blaswrapper.symmetric_down_date
        self.dgemm = blaswrapper.dgemm

    def get_signature(self):
        """Return an identifier built from the class name and the block size."""
        return type(self).__name__ + str(self.block_size)

    def run_configuration(self, A, err, kernel_evaluator=lambda *args: None, hook=None) -> None:
        """
        Estimates the log-marginal likelihood of a Gaussian process from a subset.

        On normal completion the lower triangle of ``A`` holds the Cholesky factor ``L`` and ``err``
        holds the solution of ``L x = err``. Entries above the diagonal are not meaningful.

        :param A: allocated memory where the Cholesky can be stored into
        :param err: targets - mu(X), a column vector with as many rows as A
        :param kernel_evaluator: function that writes kernel entries and noise into given indices; it is
            called as ``kernel_evaluator(A, row_start, n_rows, 0, n_cols)``. The default does nothing,
            so presumably A must then already contain the kernel matrix.
        :param hook: callback that can decide to stop the Cholesky if desired; a fresh AbstractHook is
            created when None
        :return: None
        :raises ValueError: if A is not square or err is not a matching column vector
        """
        lower = True  # we compute a lower triangular Cholesky decomposition
        if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
            raise ValueError("A must be a square matrix")
        N = A.shape[0]  # matrix size -- can be different from the total dataset size!
        if len(err.shape) != 2 or err.shape[0] != N or err.shape[1] != 1:
            raise ValueError("err must have shape (N, 1)")

        if hook is None:
            # created per call so that no hook instance is shared between invocations
            hook = AbstractHook()

        if hook.prepare(A, err, self.block_size):
            return

        # no pivoting is performed (yet), hence the permutation is the identity
        self.permutation = np.arange(N)
        if N == 0:
            return

        block_size = self.add
        initial_selection = min(self.initial_selection, N)
        # The window has to contain the first block of the main loop and must not exceed the matrix.
        window_pointer = min(max(self.initial_window_size, initial_selection + block_size), N)

        # first iteration of the Cholesky
        kernel_evaluator(A, 0, window_pointer, 0, window_pointer)
        K_ = A[:initial_selection, :initial_selection]
        y_ = err[:initial_selection, :]
        self.in_place_chol(K_)
        # first iteration for solving the linear equation system
        self.solve_triangular_inplace(K_, y_, transpose_b=False, lower=lower)

        if initial_selection >= N:
            # NOTE(review): hook.finalize() is not called on this path (unchanged) -- confirm intent
            return

        # prepare window: solve and down-date the rows behind the first block
        self._down_date_window(A, err, 0, initial_selection, window_pointer, lower)

        # main loop of the Cholesky
        for idi in range(initial_selection, N, block_size):
            # make sure we never go beyond the size of the matrix
            advance = min(block_size, N - idi)
            stop = idi + advance

            K_ = A[idi:stop, idi:stop]
            # perform Cholesky of the down-dated part
            self.in_place_chol(K_)  # O(B^3)
            # finish solving the linear equation system
            y_ = err[idi:stop, :]
            self.solve_triangular_inplace(K_, y_, transpose_b=False, lower=lower)

            # The hook sees the rows that are in the window but not factorised yet. At this point
            # they are not down-dated with respect to the block that has just been factorised.
            K_ = A[stop:window_pointer, stop:window_pointer]
            y_ = err[stop:window_pointer, :]
            if hook.post_chol(idi=stop, K_=K_, y_=y_):
                return

            # solve the panel below the new block, down-date the remaining window and its targets
            self._down_date_window(A, err, idi, stop, window_pointer, lower)

            # make sure we never go beyond the size of the matrix
            window_increase = min(self.window_increase, N - window_pointer)
            if window_increase > 0:
                new_pointer = window_pointer + window_increase
                self._extend_window(A, err, kernel_evaluator, stop, window_pointer, new_pointer, lower)
                # NOTE(review): the hook receives all pending rows of the extended window, fully
                # down-dated in the lower triangle -- confirm this is what pre_chol should see
                K_ = A[stop:new_pointer, stop:new_pointer]
                y_ = err[stop:new_pointer, :]
                if hook.pre_chol(idi=stop, K_=K_, y_=y_):
                    return
                window_pointer = new_pointer

        hook.finalize()
        return

    def _down_date_window(self, A, err, start, stop, window_pointer, lower):
        """Propagate the factorised block [start, stop) to the window rows [stop, window_pointer)."""
        if window_pointer <= stop:
            return  # nothing left in the window
        to = A[stop:window_pointer, start:stop]  # the next part that we are going to write to
        self.solve_triangular_inplace(A[start:stop, start:stop], to, transpose_b=True, lower=lower)
        # K_ -= to @ to.T
        self.symmetric_down_date(A[stop:window_pointer, stop:window_pointer], to)
        # y_ -= to @ y[start:stop, :]
        self.dgemm(to, err[start:stop, :], err[stop:window_pointer, :])

    def _extend_window(self, A, err, kernel_evaluator, computed, window_pointer, new_pointer, lower):
        """Evaluate the rows [window_pointer, new_pointer) and down-date them with the first ``computed`` columns."""
        kernel_evaluator(A, window_pointer, new_pointer - window_pointer, 0, new_pointer)

        # solve block off-diagonal part of the Cholesky for the already computed Cholesky
        to = A[window_pointer:new_pointer, :computed]  # the next part that we are going to write to
        self.solve_triangular_inplace(A[:computed, :computed], to, transpose_b=True, lower=lower)  # O(i^2 * B)
        if window_pointer > computed:
            # entries between the new rows and the rows that were already waiting in the window
            # NOTE(review): passes a transposed view to dgemm -- confirm all BLAS wrappers accept that
            pending = A[computed:window_pointer, :computed]
            self.dgemm(to, pending.T, A[window_pointer:new_pointer, computed:window_pointer])
        # apply symmetric down-date: K_ -= to @ to.T
        self.symmetric_down_date(A[window_pointer:new_pointer, window_pointer:new_pointer], to)
        # start solving the next part of the linear equation system for the quadratic form
        # y_ -= to @ y[:computed, :]
        self.dgemm(to, err[:computed, :], err[window_pointer:new_pointer, :])

    def select_next_points(self, already_selected_points, number_of_next_points, window_pointer):
        """
        Choose which points of the window to process next (pivoting). Not implemented yet.

        Planned steps: extend the window, down-date it, rank the candidates by their diagonal
        entries, apply the resulting permutation and update all points with respect to the
        last selected ones.

        :param already_selected_points: number of points that have been selected so far
        :param number_of_next_points: number of points to select
        :param window_pointer: end of the current window
        :raises NotImplementedError: always
        """
        raise NotImplementedError("pivoting is not implemented yet")


