import numpy as np
from scipy.linalg import cholesky, solve_triangular
from pgpr_type import pgpr_matrix, pgpr_vector

class pgpr_chol:
    """
    Cholesky factorization A = L @ L.T with helpers built on top of it:
    solving linear systems, forming the inverse, and the log-determinant.

    The factor is computed once with ``scipy.linalg.cholesky`` and kept as
    dense NumPy arrays; every query reuses it through triangular solves.
    """

    def __init__(self, A):
        """
        Factorize ``A`` into its lower-triangular Cholesky factor.

        :param A: The matrix to factorize (must be square and symmetric positive
            definite). Either a ``pgpr_matrix``-like object exposing ``nrows()``,
            ``ncols()`` and ``A[i][j]`` indexing, or any 2-D array-like accepted
            by ``numpy.array``.
        :raises ValueError: If ``A`` is not square.
        :raises numpy.linalg.LinAlgError: Propagated from SciPy when the
            factorization fails (e.g. ``A`` is not positive definite).
        """
        if hasattr(A, "nrows"):
            # pgpr_matrix path: validate the shape before the element-wise copy.
            if A.nrows() != A.ncols():
                raise ValueError("Matrix must be square for Cholesky decomposition.")
            n = A.nrows()
            A_np = np.zeros((n, n), dtype=np.float64)
            for i in range(n):
                for j in range(n):
                    A_np[i, j] = A[i][j]
        else:
            # Plain array-like path (NumPy array, nested lists, ...).
            A_np = np.array(A, dtype=np.float64)
            if A_np.ndim != 2 or A_np.shape[0] != A_np.shape[1]:
                raise ValueError("Matrix must be square for Cholesky decomposition.")

        # Dimension of the (square) factorized matrix.
        self.n = A_np.shape[0]
        # Lower-triangular factor L with A = L @ L.T. Symmetry of A is not
        # checked here, and _regularize is not called before factorizing.
        self.el = cholesky(A_np, lower=True)
        # Upper-triangular L.T (a transposed view of self.el, not a copy).
        self.elt = self.el.T

    def _regularize(self, A, epsilon=1e-3):
        """
        Regularize the matrix in place by adding ``epsilon`` to its diagonal.

        Not invoked by ``__init__``; callers that want jitter must apply it to
        their matrix before constructing ``pgpr_chol``.

        :param A: The matrix to regularize (needs ``nrows()`` and ``A[i][i]``).
        :param epsilon: The value added to each diagonal entry.
        """
        for i in range(A.nrows()):
            A[i][i] += epsilon

    def _vector_to_array(self, b, caller):
        """
        Validate the length of ``b`` and copy it into a float64 ndarray.

        :param b: A ``pgpr_vector`` or any 1-D sequence supporting ``len`` and
            integer indexing.
        :param caller: Name of the public method, used in the error message.
        :raises ValueError: If ``len(b)`` differs from the matrix dimension.
        """
        if len(b) != self.n:
            raise ValueError(f"Dimension mismatch in {caller}.")
        return np.array([float(b[i]) for i in range(self.n)], dtype=np.float64)

    def _solve_array(self, rhs):
        """
        Solve A @ X = rhs, where ``rhs`` is a float64 vector or a matrix whose
        columns are independent right-hand sides.
        """
        # Forward substitution: el @ Y = rhs
        y = solve_triangular(self.el, rhs, lower=True)
        # Backward substitution: elt @ X = Y
        return solve_triangular(self.elt, y, lower=False)

    def _inverse_array(self):
        """Return A^{-1} as a dense float64 ndarray (one multi-RHS solve against I)."""
        if self.n == 0:
            # Nothing to solve; avoid handing empty arrays to LAPACK wrappers.
            return np.zeros((0, 0), dtype=np.float64)
        return self._solve_array(np.eye(self.n))

    def solve(self, b):
        """
        Solve the linear system A * x = b using the Cholesky factorization.

        :param b: The right-hand side vector (length ``n``).
        :return: The solution vector x as a ``pgpr_vector``.
        :raises ValueError: If the length of ``b`` does not match ``n``.
        """
        b_np = self._vector_to_array(b, "solve")
        x = self._solve_array(b_np)
        result = pgpr_vector(self.n)
        for i in range(self.n):
            result[i] = x[i]
        return result

    def elsolve(self, b):
        """
        Solve el * y = b (forward substitution with the lower factor only).

        :param b: The right-hand side vector (length ``n``).
        :return: The solution y as a NumPy ndarray (unlike ``solve``, which
            returns a ``pgpr_vector``).
        :raises ValueError: If the length of ``b`` does not match ``n``.
        """
        b_np = self._vector_to_array(b, "elsolve")
        return solve_triangular(self.el, b_np, lower=True)

    def inverse(self):
        """
        Compute the inverse of the matrix using the Cholesky factorization.

        All columns are obtained in a single pair of triangular solves against
        the identity, rather than one full ``solve`` call per column.

        :return: The inverse matrix as a ``pgpr_matrix``.
        """
        inv_np = self._inverse_array()
        ainv = pgpr_matrix(self.n, self.n)
        ainv.assign(self.n, self.n, 0.0)
        for i in range(self.n):
            for j in range(self.n):
                ainv[i][j] = inv_np[i, j]
        return ainv

    def logdet(self):
        """
        Compute the log-determinant of the matrix.

        Uses log det(A) = 2 * sum(log(diag(L))) for A = L @ L.T.

        :return: The log-determinant value.
        """
        return 2.0 * np.sum(np.log(np.diag(self.el)))