import math
import gpytorch.kernels
import logging
import mlflow
import numpy as np
import torch
from typing import Callable
from utils.result_management.constants import U_DET, L_DET, U_QUAD, L_QUAD, BLOCK_SIZE

try:
    from acgp.blas_wrappers.openblas.openblas_wrapper import OpenBlasWrapper as CPUBlasWrapper
except Exception as e:
    logging.exception(e)
    from acgp.blas_wrappers.numpy_blas_wrapper import NumpyBlasWrapper as CPUBlasWrapper

from acgp.blas_wrappers.torch_wrapper import TorchWrapper as GPUBlasWrapper
from acgp.meta_cholesky import MetaCholesky
from acgp.hooks.stop_hook import StopHook
from hyperparameter_tuning.utils.abstract_hyper_parameter_tuning_algorithm import AbstractHyperParameterTuningAlgorithm
from acgp.backends.numpy_backend import NumpyBackend
from acgp.backends.torch_backend import TorchBackend
from acgp.bound_computation import Bounds


# Supported choices for the --ESTIMATOR option (the first entry is the parser default).
DEFAULT = "default"
ALL_POINTS = "all_points"
estimators = [DEFAULT, ALL_POINTS]

# Keys of the parsed argument mapping / option names.
ESTIMATOR = "ESTIMATOR"
MAX_N = "max_n"


class _StoppedCholesky(AbstractHyperParameterTuningAlgorithm):
    """
    This class is private and just meant to share code between the CPU and the GPU implementation classes below.

    The loss is computed from a blocked Cholesky decomposition (``MetaCholesky``) that may stop early (``StopHook``).
    The processed part of the factor, stored in the buffer ``self.A``, is turned into an estimate of
    0.5 * (log det + quadratic form) + N/2 * log(2 pi). A custom autograd function provides the gradient w.r.t. the
    kernel matrix.
    """
    @staticmethod
    def add_parameters_to_parser(parser):
        """
        Register the command-line options of this algorithm.

        :param parser:
            an argparse-style parser
        """
        default_block_size = 1024
        parser.add_argument("-r", "--relative_precision", type=float, default=0.1)
        parser.add_argument("-bs", f"--{BLOCK_SIZE}", type=int, default=default_block_size)
        parser.add_argument("-e", "--" + ESTIMATOR, type=str, choices=estimators, default=estimators[0])
        # max_n is compared with and used to slice by the number of data points, hence it must be an int
        # (type=str turned every user-supplied value into a string and broke __init__).
        parser.add_argument("-mn", f"--{MAX_N}", type=int, default=default_block_size * 40)

    def __init__(self, X: torch.Tensor, y: torch.Tensor, k: gpytorch.kernels.Kernel, sn2: Callable, mu: Callable, args, device="cpu"):
        """
        :param X:
            inputs; rows index the data points (sliced as ``X[:n, :]``)
        :param y:
            targets; sliced as ``y[:n, :]`` (presumably a column vector -- TODO confirm)
        :param k:
            the kernel
        :param sn2:
            a function(!) that returns the noise
        :param mu:
            prior mean function
        :param args:
            parser arguments (mapping with the keys registered in ``add_parameters_to_parser``)
        :param device:
            torch device passed to the base class and to the bound backend
        """
        super().__init__(X, y, k, sn2, mu, args, device=device)
        self.estimator = args[ESTIMATOR]
        self.set_tag(ESTIMATOR, self.estimator)
        # TODO: refactor!! exchange strings for constants!
        block_size = args[BLOCK_SIZE]
        r = args["relative_precision"]
        max_n = args[MAX_N]
        self.set_tag(BLOCK_SIZE, block_size)
        self.set_tag("r", r)
        self.set_tag(MAX_N, max_n)
        self.r = r
        self.block_size = block_size
        self.max_n = max_n
        # at most max_n data points are ever processed by the Cholesky
        if max_n < X.shape[0]:
            self.Xsub = X[:max_n, :]
            self.ysub = y[:max_n, :]
        else:
            self.Xsub = X
            self.ysub = y

        self.last_iter = None  # number of fully processed data points in the last loss evaluation
        self.last_advance = None  # size of the partially processed block in the last loss evaluation
        self.last_logged_bounds = {}
        self.require_ground_truth = block_size < X.shape[0]
        self.alpha0 = None  # L^{-1} (y - mu) of the last loss evaluation, reused by get_posterior
        self.subset_size = 0
        self.blaswrapper = CPUBlasWrapper()
        self.backend = NumpyBackend()
        self.bound_backend = TorchBackend(device=device)
        # N/2 * log(2 pi): constant part of the loss
        self.const = self.X.shape[0] / 2 * torch.log(2 * torch.tensor(math.pi, requires_grad=False, device=self.device, dtype=self.y.dtype))
        self.A = None  # buffer for the Cholesky factor
        self.K = None  # buffer for the kernel matrix
        self._setup_buffers()  # in a separate method for easier exchange

    def _setup_buffers(self):
        """Allocate the Fortran-ordered numpy buffers ``A`` and ``K`` of size min(N, max_n) squared."""
        N = min(self.X.shape[0], self.max_n)
        self.A = np.zeros([N, N], order='F')
        self.K = np.zeros_like(self.A)

    def requires_ground_truth_recording(self):
        """Return whether ground-truth recording is required (set in ``__init__`` as block_size < N)."""
        return self.require_ground_truth

    def create_loss_closure(self):
        """
        Create the closure that is evaluated by the optimizer.

        Each call runs the (possibly early-stopped) blocked Cholesky for the current hyper-parameters and returns
        est / 2 + N/2 * log(2 pi), where est estimates log det(K) + (y - mu)^T K^{-1} (y - mu).
        """
        N = self.X.shape[0]
        maxN = self.A.shape[0]

        @torch.no_grad()
        def Phi(A):
            """Return the lower triangle of ``A`` with its diagonal halved."""
            # TODO: Use in-place operation?
            return torch.tril(A - torch.diag(torch.diag(A)) / 2)

        class MyChol(torch.autograd.Function):
            """
            Pretends to compute the Cholesky factor of its input.

            The forward pass returns the factor already stored in ``self.A`` (filled by ``MetaCholesky``). The
            backward pass implements the reverse-mode derivative of the Cholesky decomposition.
            """
            @staticmethod
            def forward(ctx, K):
                L = torch.as_tensor(self.A[:K.shape[0], :K.shape[0]])
                return L

            @staticmethod
            def backward(ctx, grad_outputs):
                with torch.no_grad():
                    n = self.subset_size
                    L = torch.as_tensor(self.A[:n, :n])
                    # reverse mode
                    # TODO: maybe use blas wrapper? could be faster
                    S = Phi(L.T @ grad_outputs)
                    # TODO: make below operation in-place
                    # note: torch.triangular_solve also returns a clone of the coefficient matrix
                    S = torch.triangular_solve(S, L, upper=False, transpose=True).solution.T

                    # TODO: need to find a generic way to make this a numpy array if necessary
                    # L is passed as the second ``out`` argument, presumably to avoid allocating another clone of it
                    torch.triangular_solve(S, L, upper=False, transpose=True, out=(S, L))
                    grad = Phi(S + S.T)
                    return grad

        my_chol = MyChol.apply

        chol = MetaCholesky(block_size=self.block_size, blaswrapper=self.blaswrapper)

        def loss_closure():
            # the step below is necessary to convince torch that we compute the gradient of a new matrix
            K = self._get_K()
            err = self._get_y_copy()
            hook = self._get_hook()

            def k_func(*inputs):
                return self.k(*inputs).evaluate()

            chol.run_configuration(self.A, err, kernel_evaluator=self._get_kernel_evaluator(X=self.X, k=k_func,
                                                                                            sn2=self.sn2(), K=K),
                                   hook=hook)

            # t0: number of fully processed data points; adv: size of the following, partially processed block
            t0 = hook.iteration
            adv = min(self.block_size, maxN - t0)
            subset_size = t0 + adv
            self.last_iter = t0
            self.last_advance = adv
            self.subset_size = subset_size

            if adv > 0:
                # finish the last step of the Cholesky
                self.blaswrapper.in_place_chol(self.A[t0:t0 + adv, t0:t0 + adv])
            # pretend to apply Cholesky but actually don't for gradient computation
            L = my_chol(K[:subset_size, :subset_size])
            log_sub_det = 2 * torch.sum(torch.log(torch.diag(L[:t0, :t0])))  # NOT until subset_size!
            # TODO: we solve here the linear equation system again
            # maybe there is a way around that but for now this is necessary to get a gradient
            alpha, _ = torch.triangular_solve(self.y[:subset_size, :] - self.mu(self.X[:subset_size, :]), L, upper=False)
            self.alpha0 = alpha
            sub_quad = torch.sum(torch.square(alpha[:t0, :]))  # NOT until subset_size!
            if adv > 0:
                # TODO: is beta correct?!
                beta = self.y[t0:t0 + adv] - self.mu(self.X[t0:t0 + adv, :]) - L[t0:t0 + adv, :t0] @ alpha[:t0, :]
                bounds = Bounds(delta=1.0, N=N, min_noise=self.sn2(), backend=self.bound_backend)
                # TODO: I don't need the whole of K_
                K_ = L[t0:t0 + adv, t0:t0 + adv] @ L[t0:t0 + adv, t0:t0 + adv].T
                diagK_ = torch.diag(K_).reshape(-1, 1)
                off_diagK_ = torch.diag(K_, diagonal=-1).reshape(-1, 1)
                U_det, L_det, U_quad, L_quad, auxilary_variables = bounds.get_bound_estimators_and_auxilary_quantities(
                    t0=t0, log_sub_det=log_sub_det, sub_quad=sub_quad, A_diag=diagK_, A_diag_off=off_diagK_, y=beta,
                    noise_diag=self.sn2() * torch.ones([1], device=self.device))
                # use the shared key constants (as the branch below does) so both branches log under the same names
                self.last_logged_bounds[U_DET] = U_det.item()
                self.last_logged_bounds[L_DET] = L_det.item()
                self.last_logged_bounds[U_QUAD] = U_quad.item()
                self.last_logged_bounds[L_QUAD] = L_quad.item()
                # TODO: when using the torch wrapper, some of the variables are torch tensors and mlflow does not like that

                if self.estimator == DEFAULT:
                    # midpoint between upper and lower bound
                    log_det = U_det / 2 + L_det / 2
                    quad = U_quad / 2 + L_quad / 2
                elif self.estimator == ALL_POINTS:
                    # I suspect that for the gradient the bound estimator favors too much the last processed datapoints.
                    # Hence, let's experiment with the following estimators which incorporate ALL datapoints equally.
                    # However, it also appears that this estimator can give the linesearch a bit more trouble.
                    factor = 1 + (N - subset_size) / subset_size
                    log_det = factor * 2 * torch.sum(torch.log(torch.diag(L[:subset_size, :subset_size])))
                    quad = factor * torch.sum(torch.square(alpha[:subset_size, :]))
                else:
                    raise RuntimeError(f"Unknown estimator: {self.estimator}")
                self.last_logged_bounds["DET"] = log_det.item()
                self.last_logged_bounds["QUAD"] = quad.item()
                est = log_det + quad
            else:
                # everything was processed: the exact values serve as both bounds
                self.last_logged_bounds[U_DET] = log_sub_det.item()
                self.last_logged_bounds[L_DET] = log_sub_det.item()
                self.last_logged_bounds[U_QUAD] = sub_quad.item()
                self.last_logged_bounds[L_QUAD] = sub_quad.item()
                self.last_logged_bounds["DET"] = log_sub_det.item()
                self.last_logged_bounds["QUAD"] = sub_quad.item()

                est = log_sub_det + sub_quad
            return est / 2 + self.const
        return loss_closure

    def get_posterior(self, X_star, full_posterior=False):
        """
        Predictive posterior at ``X_star`` using the data points of the last loss evaluation.

        Predictive-posterior computation from GP Book / Rasmussen et al. 2006 (pp. 19). Requires that the loss closure
        has been evaluated at least once (it sets ``subset_size`` and ``alpha0``).

        :param X_star:
            test inputs
        :param full_posterior:
            if True return the full predictive covariance, otherwise only the variances as a column
        :return:
            tuple (predictive mean, predictive (co)variance)
        """
        with torch.no_grad():
            L = torch.as_tensor(self.A[:self.subset_size, :self.subset_size])
            # the alpha0 is NOT the alpha from the Rasmussen book; it is L^{-1} (y - mu)
            v = self.k(self.X[:self.subset_size, :], X_star).evaluate()
            # v <- L^{-1} v (L reused as output buffer for the coefficient clone)
            torch.triangular_solve(v, L, upper=False, out=(v, L))
            f_m_star = self.mu(X_star) + v.T @ self.alpha0
            if full_posterior:
                f_v_star = self.k(X_star).evaluate() - v.T @ v
            else:
                # it appears that when diag=True then the returned tensor is not lazy...
                f_v_star = self.k(X_star, diag=True) - torch.sum(torch.square(v), dim=[0])
                f_v_star = torch.reshape(f_v_star, [-1, 1])
            return f_m_star, f_v_star

    def get_name(self):
        """Return the registry key suffixed with the block size."""
        return self.get_registry_key() + str(self.block_size)

    def log_metrics(self, step: int):
        """Log the progress of the last loss evaluation and the recorded bounds to mlflow."""
        mlflow.log_metric("FULLY_PROCESSED_DATAPOINTS", self.last_iter, step=step)
        mlflow.log_metric("PARTIALLY_PROCESSED_POINTS", self.last_advance, step=step)
        mlflow.log_metrics(self.last_logged_bounds, step=step)

    @classmethod
    def _get_kernel_evaluator(cls, X, k, sn2, K):
        """
        Build the callback that MetaCholesky uses to fill blocks of ``A``.

        The callback writes the kernel block rows [i0, i0 + i1) x columns [j0, j0 + j1) into the torch tensor ``K``
        (which keeps the autograd history) and a detached numpy copy into ``A``. Diagonal blocks get the noise added
        and only their lower triangle is copied.
        """
        def kernel_evaluator(A, i0, i1, j0, j1):
            if i0 == j0 and i1 == j1:
                # TODO: is there a better way to fill the diagonal? Allocating a whole identity matrix seems expensive
                K[i0:i0 + i1, i0:i0 + i1] = k(X[i0:i0 + i1, :]) + sn2 * torch.eye(i1, device=K.device, dtype=X.dtype)
                # copy values into designated array
                # we use tril here, so we don't have to call it on the large matrix later
                A[i0:i0 + i1, j0:j0 + j1] = np.tril(K[i0:i0 + i1, j0:j0 + j1].detach().numpy())
            elif j1 <= i0:
                K[i0:i0 + i1, j0:j0 + j1] = k(X[i0:i0 + i1, :], X[j0:j0 + j1, :])
                # copy values into designated array
                # the Fortran order of A appears to be preserved (test cases show no difference)
                A[i0:i0 + i1, j0:j0 + j1] = K[i0:i0 + i1, j0:j0 + j1].detach().numpy()
            else:
                raise RuntimeError("This case should not occur")

        return kernel_evaluator

    def _get_y_copy(self):
        """Return a numpy copy of ``ysub - mu(Xsub)`` for the Cholesky to work on in place."""
        y = self.ysub.numpy().copy()
        y -= self.mu(self.Xsub).detach().numpy()
        return y

    def _get_K(self):
        """Wrap the numpy kernel buffer as a torch tensor."""
        # this will avoid a copy
        return torch.as_tensor(self.K)

    def _get_hook(self):
        """Create the StopHook that decides when the Cholesky may stop early."""
        return StopHook(N=self.X.shape[0], min_noise=self.sn2().item(), relative_tolerance=self.r, backend=self.backend)


class StoppedCholesky(_StoppedCholesky):
    """
    Public factory for the stopped-Cholesky algorithm.

    Instantiating this class returns a ``GPUStoppedCholesky`` when ``device == "cuda"``. Otherwise it returns a
    ``_StoppedCholesky``, the CPU implementation. Neither returned object is an instance of this class, so Python does
    not invoke ``__init__`` on it a second time.
    """
    def __new__(cls, X: torch.Tensor, y: torch.Tensor, k: gpytorch.kernels.Kernel, sn2: Callable, mu: Callable, args, device="cpu"):
        implementation = GPUStoppedCholesky if device == "cuda" else _StoppedCholesky
        return implementation(X, y, k, sn2, mu, args, device=device)


class GPUStoppedCholesky(_StoppedCholesky):
    """
    GPU variant of the stopped Cholesky.

    The buffer ``A`` is a torch tensor on ``self.device``. The BLAS wrapper and both backends are the torch-based
    ones.
    """
    def __init__(self, *args, **kwargs):
        # the base constructor installs the CPU helpers; replace them by their torch counterparts
        super().__init__(*args, **kwargs)
        self.blaswrapper = GPUBlasWrapper()
        self.backend = TorchBackend(device="cuda")
        self.bound_backend = self.backend

    def _setup_buffers(self):
        """Allocate the Cholesky buffer ``A`` on the device; ``K`` is allocated per evaluation by ``_get_K``."""
        N = min(self.X.shape[0], self.max_n)
        self.A = torch.zeros([N, N], device=self.device, requires_grad=False, dtype=self.y.dtype)
        self.K = None

    @classmethod
    def _get_kernel_evaluator(cls, X, k, sn2, K):
        """
        Build the callback that MetaCholesky uses to fill blocks of ``A``.

        The callback writes the kernel block rows [i0, i0 + i1) x columns [j0, j0 + j1) into ``K``, which keeps the
        autograd history, and a detached copy into ``A``. Diagonal blocks get the noise added and only their lower
        triangle is copied.
        """
        def kernel_evaluator(A, i0, i1, j0, j1):
            if i0 == j0 and i1 == j1:
                # TODO: is there a better way to fill the diagonal? Allocating a whole identity matrix seems expensive
                K[i0:i0 + i1, i0:i0 + i1] = k(X[i0:i0 + i1, :]) + sn2 * torch.eye(i1, device=K.device, dtype=X.dtype)
                # copy values into designated array
                # we use tril here, so we don't have to call it on the large matrix later
                # detach (as the CPU path does): A is a plain value buffer, and copying a grad-tracking tensor into it
                # would turn A into a non-leaf that keeps the autograd graph of earlier evaluations alive
                A[i0:i0 + i1, j0:j0 + j1] = torch.tril(K[i0:i0 + i1, j0:j0 + j1].detach())
            elif j1 <= i0:
                K[i0:i0 + i1, j0:j0 + j1] = k(X[i0:i0 + i1, :], X[j0:j0 + j1, :])
                # copy values into designated array (slice assignment copies, so no clone is needed)
                A[i0:i0 + i1, j0:j0 + j1] = K[i0:i0 + i1, j0:j0 + j1].detach()
            else:
                raise RuntimeError("This case should not occur")

        return kernel_evaluator

    def _get_y_copy(self):
        """Return a device copy of ``ysub - mu(Xsub)`` without autograd history from ``mu``."""
        y = self.ysub.clone()
        with torch.no_grad():
            y -= self.mu(self.Xsub)
        return y

    def _get_K(self):
        """Allocate a fresh zero kernel buffer shaped like ``A``."""
        # TODO: Can we avoid reallocating this much memory?
        return torch.zeros_like(self.A)

    def _get_hook(self):
        """Create the StopHook; the noise is passed as a tensor for the torch backend."""
        return StopHook(N=self.X.shape[0], min_noise=self.sn2(), relative_tolerance=self.r, backend=self.backend)
