import numpy as np
from pgpr_util import SQR  # Import helper functions
from pgpr_type import Mdoub, Vdoub  # Import matrix and vector types

class pgpr_cov:
    """
    The pgpr_cov class provides a squared-exponential ARD covariance function.

    Instance attributes:
      sig -- signal variance; the raw value from the file/vector is passed through SQR.
      nos -- noise variance; the raw value from the file/vector is passed through SQR.
      mu  -- mean hyperparameter; stored and reported by get_hyperparameters, but not
             used by any of the kernel methods in this class.
      dim -- number of leading input features compared by the kernel.
      lsc -- per-feature length scales (at least ``dim`` entries are required).
    """

    def __init__(self, hypf=None, h=None, d=None):
        """
        Initialize the covariance class with either a hyperparameter file or a hyperparameter vector.

        File layout: only the first line is read.  It holds whitespace-separated values
        ``sig nos mu dim lsc_1 ... lsc_dim``.

        Vector layout: ``nos`` is ``h[0]``, the length scales are ``h[1:d + 1]``,
        ``sig`` is ``h[d + 1]`` and ``mu`` is the LAST entry.  ``h`` therefore needs at
        least ``d + 3`` entries.

        In both cases ``sig`` and ``nos`` are stored after passing through ``SQR``.

        :param hypf: Path to the hyperparameter file (checked first; h and d are then ignored).
        :param h: Hyperparameter vector (list, tuple or NumPy array).
        :param d: Dimension of the features (used with h).
        :raises ValueError: if neither source is provided, or the chosen source holds too few values.
        """
        if hypf is not None:
            # Initialize from a hyperparameter file.
            with open(hypf, 'r') as file:
                tokens = file.readline().split()
            if len(tokens) < 4:
                raise ValueError(
                    f"Hyperparameter file '{hypf}' must start with 'sig nos mu dim', "
                    f"got {len(tokens)} value(s).")
            sig = float(tokens[0])
            nos = float(tokens[1])
            mu = float(tokens[2])
            dim = int(tokens[3])
            lsc = [float(val) for val in tokens[4:]]
            # se_ard indexes lsc[i] for every i < dim, so a short list would only
            # fail much later with an IndexError; reject it here instead.
            if len(lsc) < dim:
                raise ValueError(
                    f"Hyperparameter file '{hypf}' declares dim={dim} but provides "
                    f"only {len(lsc)} length scale(s).")
            self.sig = SQR(sig)
            self.nos = SQR(nos)
            self.mu = mu
            self.dim = dim
            self.lsc = lsc
        elif h is not None and d is not None:
            # Initialize from a hyperparameter vector.
            # With fewer than d + 3 entries, h[-1] would alias the sig entry (or an
            # index would be out of range), so the vector is rejected up front.
            if len(h) < d + 3:
                raise ValueError(
                    f"Hyperparameter vector must have at least d + 3 = {d + 3} "
                    f"entries, got {len(h)}.")
            lsc = h[1:d + 1]
            self.nos = SQR(h[0])
            # A slice of an ndarray is a view; copy it so later in-place edits of h
            # cannot silently change the stored length scales.
            self.lsc = lsc.copy() if isinstance(lsc, np.ndarray) else lsc
            self.sig = SQR(h[d + 1])
            self.mu = h[-1]
            self.dim = d
        else:
            raise ValueError("Either 'hypf' or 'h' and 'd' must be provided.")

    def get_hyperparameters(self):
        """
        Retrieve the current hyperparameters of the covariance function.

        The values are returned exactly as stored (i.e. sig/nos after SQR), so the
        result can be fed straight back into update_hyperparameters.

        :return: List of hyperparameters [signal_variance, noise_variance, mean, length_scales...].
        """
        return [self.sig, self.nos, self.mu] + list(self.lsc)

    def update_hyperparameters(self, hyperparams):
        """
        Update the hyperparameters of the covariance function.

        sig, nos and every length scale are floored at 1e-6; mu is taken as-is.
        Values are stored directly (no SQR is applied).  ``self.dim`` is NOT changed,
        so at least ``dim`` length scales should be supplied.

        :param hyperparams: List or array of hyperparameters [signal_variance, noise_variance, mean, length_scales...].
        """
        self.sig = max(float(hyperparams[0]), 1e-6)  # Ensure signal variance is positive
        self.nos = max(float(hyperparams[1]), 1e-6)  # Ensure noise variance is positive
        self.mu = float(hyperparams[2])  # Mean can be any value
        self.lsc = [max(float(l), 1e-6) for l in hyperparams[3:]]  # Ensure length scales are positive

    def _sq_dist(self, x, y):
        """
        Return the length-scale-weighted squared distance between x and y.

        Computes the sum over i < dim of SQR((x[i] - y[i]) / lsc[i]); features at
        index dim and beyond are ignored.
        """
        lsc = self.lsc
        return sum(SQR((x[i] - y[i]) / lsc[i]) for i in range(self.dim))

    def se_ard_n(self, x, y):
        """
        Compute the squared exponential covariance with noise.

        Note: ``nos`` is added for every pair, whether or not x equals y.  This differs
        from se_ard_n_matrix, which adds it on the diagonal only.

        :param x: First input vector.
        :param y: Second input vector.
        :return: sig * exp(-0.5 * weighted squared distance) + nos.
        """
        return self.sig * np.exp(-0.5 * self._sq_dist(x, y)) + self.nos

    def se_ard(self, x, y):
        """
        Compute the squared exponential covariance without noise.

        :param x: First input vector.
        :param y: Second input vector.
        :return: sig * exp(-0.5 * weighted squared distance).
        """
        return self.sig * np.exp(-0.5 * self._sq_dist(x, y))

    def se_ard_n_matrix(self, a, k):
        """
        Compute the covariance matrix with noise for a dataset.

        Relies on the project matrix API: a.nrows(), k.resize(rows, cols) and k[i][j].
        Diagonal entries are se_ard(a[i], a[i]) + nos; off-diagonal entries are
        se_ard(a[i], a[j]), mirrored so k is symmetric.

        :param a: Input dataset (rows are samples, columns are features).
        :param k: Output covariance matrix, resized to (nrows, nrows) and filled in place.
        """
        ss = a.nrows()
        k.resize(ss, ss)
        for i in range(ss):
            # Only the upper triangle is evaluated; the lower one is copied.
            for j in range(i, ss):
                k[i][j] = self.se_ard(a[i], a[j])
                if i == j:
                    k[i][j] += self.nos
                else:
                    k[j][i] = k[i][j]

    def se_ard_matrix(self, a, k):
        """
        Compute the covariance matrix without noise for a dataset.

        Relies on the project matrix API: a.nrows(), k.resize(rows, cols) and k[i][j].

        :param a: Input dataset (rows are samples, columns are features).
        :param k: Output covariance matrix, resized to (nrows, nrows) and filled in place.
        """
        ss = a.nrows()
        k.resize(ss, ss)
        for i in range(ss):
            # Only the upper triangle is evaluated; the lower one is copied.
            for j in range(i, ss):
                k[i][j] = self.se_ard(a[i], a[j])
                if i != j:
                    k[j][i] = k[i][j]

    def se_ard_cross(self, a, b, k):
        """
        Compute the cross-covariance matrix between two datasets.

        :param a: First dataset (pgpr_matrix, rows are samples, columns are features).
        :param b: Second dataset (pgpr_matrix, rows are samples, columns are features).
        :param k: Output cross-covariance matrix (pgpr_matrix), resized to
                  (a.nrows(), b.nrows()) and filled in place.
        """
        ssa = a.nrows()
        ssb = b.nrows()
        k.resize(ssa, ssb)
        for i in range(ssa):
            for j in range(ssb):
                k[i][j] = self.se_ard(a[i], b[j])

    def se_ard_cross_numpy(self, a, b, k):
        """
        Compute the cross-covariance matrix between two datasets using NumPy arrays.

        :param a: First dataset (NumPy array, rows are samples, columns are features).
        :param b: Second dataset (NumPy array, rows are samples, columns are features).
        :param k: Output cross-covariance matrix (NumPy array), resized in place to
                  (a.shape[0], b.shape[0]) and filled.
        """
        ssa = a.shape[0]
        ssb = b.shape[0]
        # NOTE: ndarray.resize works in place.  NumPy raises ValueError if k's data must
        # be reallocated while k does not own its data or is referenced elsewhere.
        # Passing a k that already has shape (ssa, ssb) avoids any reallocation.
        k.resize((ssa, ssb))
        for i in range(ssa):
            for j in range(ssb):
                k[i, j] = self.se_ard(a[i], b[j])

    def se_ard_cross_vector(self, a, b, k):
        """
        Compute the cross-covariance matrix between a dataset and a list of vectors.

        :param a: First dataset (NumPy array, rows are samples, columns are features).
        :param b: Second dataset (list of NumPy vectors).
        :param k: Output cross-covariance matrix (NumPy array), resized in place to
                  (a.shape[0], len(b)) and filled.
        """
        ssa = a.shape[0]
        ssb = len(b)
        # NOTE: the same in-place ndarray.resize constraints as in se_ard_cross_numpy
        # apply; a k preallocated with shape (ssa, ssb) is never reallocated.
        k.resize((ssa, ssb))
        for i in range(ssa):
            for j in range(ssb):
                k[i, j] = self.se_ard(a[i], b[j])