import numpy as np
from pgpr_type import Mdoub, Vdoub
from pgpr_chol import pgpr_chol
from pgpr_cov import pgpr_cov
from pgpr_util import pgpr_timer, load_data, load_data_gs, save_data, get_rmse, get_mnlp, load_vector, A_invB_C

class pgpr_ppic_ls:
    """
    Local-summary and prediction driver for block-wise PIC-style GP regression.

    Holds a ``pgpr_cov`` covariance model plus the outputs of the last
    ``pic_core`` run: predictive means (``pmu``), predictive variances
    (``pvar``), the ``rmse`` / ``mnlp`` metrics and the timer reading
    (``elapsed``).
    """

    def __init__(self, hypf):
        """
        Initialize the pgpr_ppic_ls class.
        :param hypf: Path to the hyperparameter file.
        """
        self.cov = pgpr_cov(hypf)
        # Prior mean: subtracted from targets before regression and added back
        # to the predictive mean.
        self.h_mu = self.cov.mu
        self.pmu = Vdoub()   # predictive means of the last pic_core run
        self.pvar = Vdoub()  # predictive variances of the last pic_core run
        self.elapsed = 0.0   # value returned by pgpr_timer.end() for the last run
        self.rmse = 0.0
        self.mnlp = 0.0

    def chol_cov(self, obs, dnum=None):
        """
        Perform Cholesky decomposition of a covariance matrix.

        Python has no method overloading, so the one-argument form
        (decompose a ready-made covariance) and the two-argument form
        (build the covariance from observations first) share this method
        and are selected by whether ``dnum`` is given.

        :param obs: A precomputed covariance matrix (Mdoub) when ``dnum`` is
            None; otherwise an observation matrix (Mdoub).
        :param dnum: Number of observations. When given, the covariance of
            ``obs`` is computed with ``se_ard_n_matrix`` before decomposing.
        :return: Cholesky decomposition (pgpr_chol).
        """
        if dnum is None:
            return pgpr_chol(obs)
        K_dd = Mdoub()
        self.cov.se_ard_n_matrix(obs, K_dd)
        return pgpr_chol(K_dd)

    def chol_pcov(self, obs, dnum, act, anum, chol_kuu=None):
        """
        Compute the posterior covariance of ``obs`` given the support set
        ``act`` and return its Cholesky decomposition.
        :param obs: Observation matrix (Mdoub).
        :param dnum: Number of observations.
        :param act: Support set matrix (Mdoub).
        :param anum: Number of support points.
        :param chol_kuu: Cholesky decomposition of K_uu (optional). When
            omitted it is built from ``act`` via ``chol_cov(act, anum)``.
        :return: Cholesky decomposition of the posterior covariance.
        """
        if chol_kuu is None:
            chol_kuu = self.chol_cov(act, anum)
        kdd = Mdoub()
        # Support set plays the role of "observations", obs the role of "test".
        self.post_cov(act, anum, chol_kuu, obs, dnum, kdd)
        return self.chol_cov(kdd)

    @staticmethod
    def _row(mat, i, n):
        """Copy the first ``n`` entries of row ``i`` of ``mat`` into a new Vdoub."""
        out = Vdoub(n)
        for j in range(n):
            out[j] = mat[i][j]
        return out

    @staticmethod
    def _check_symmetric_psd(matrix, atol=1e-10):
        """
        Print whether ``matrix`` is symmetric and, if so, positive semi-definite.
        :param matrix: Array-like square matrix.
        :param atol: Absolute tolerance for the symmetry test; eigenvalues
            down to ``-atol`` are treated as zero (floating-point round-off).
        :return: Tuple ``(is_symmetric, is_psd)``; ``is_psd`` is None when the
            matrix is not symmetric (the eigenvalue test is skipped).
        """
        arr = np.array(matrix)
        if not np.allclose(arr, arr.T, atol=atol):
            print("The matrix is NOT symmetric.")
            return False, None
        print("The matrix is symmetric.")
        is_psd = bool(np.all(np.linalg.eigvalsh(arr) >= -atol))
        if is_psd:
            print("The matrix is positive semi-definite.")
        else:
            print("The matrix is NOT positive semi-definite.")
        return True, is_psd

    def post_var(self, obs, ss, chol, xt, ts, t_var):
        """
        Compute the posterior variance at each test point.
        :param obs: Observation matrix (Mdoub).
        :param ss: Number of observations.
        :param chol: Cholesky decomposition of K_dd.
        :param xt: Test set matrix (Mdoub).
        :param ts: Number of test points.
        :param t_var: Output variance vector (Vdoub), resized to ``ts``.
        :return: Success status (0).
        """
        t_var.resize(ts)
        K_td = Mdoub()
        self.cov.se_ard_cross(xt, obs, K_td)
        prior_var = self.cov.nos + self.cov.sig  # noise + signal variance

        for i in range(ts):
            # v = L \ K_ti, so v.v = K_ti^T K_dd^{-1} K_ti
            v = chol.elsolve(self._row(K_td, i, ss))
            var_i = prior_var
            for j in range(ss):
                var_i -= v[j] * v[j]
            t_var[i] = var_i
        return 0  # SUCC

    def post_cov(self, obs, ss, chol, xt, ts, t_cov):
        """
        Compute the full posterior covariance over the test points.
        :param obs: Observation matrix (Mdoub).
        :param ss: Number of observations.
        :param chol: Cholesky decomposition of K_dd.
        :param xt: Test set matrix (Mdoub).
        :param ts: Number of test points.
        :param t_cov: Output covariance matrix (Mdoub), ``ts`` x ``ts``.
        :return: Success status (0).
        """
        t_cov.resize(ts, ts)
        K_td = Mdoub()
        self.cov.se_ard_cross(xt, obs, K_td)
        self.cov.se_ard_n_matrix(xt, t_cov)  # prior covariance of the test set

        for i in range(ts):
            K_ti = self._row(K_td, i, ss)
            v = chol.elsolve(K_ti)
            for j in range(ss):
                t_cov[i][i] -= v[j] * v[j]
            # beta = K_dd^{-1} K_ti (return-style solve, as in pitc_prep)
            beta = chol.solve(K_ti)
            # Fill the lower triangle and mirror it to keep t_cov symmetric.
            for t in range(i + 1, ts):
                for j in range(ss):
                    t_cov[t][i] -= K_td[t][j] * beta[j]
                t_cov[i][t] = t_cov[t][i]
        return 0  # SUCC

    def pitc_prep(self, D, ds, chol_sdd, U, us, chol_kuu, fu, suu):
        """
        Prepare the local summary for a PITC block.
        :param D: Dataset matrix (Mdoub); last column holds the targets.
        :param ds: Number of data points.
        :param chol_sdd: Cholesky decomposition of S_dd.
        :param U: Support set matrix (Mdoub).
        :param us: Number of support points.
        :param chol_kuu: Cholesky decomposition of K_uu (unused; kept for
            interface compatibility).
        :param fu: Output mean vector (Vdoub), filled in place.
        :param suu: Output covariance matrix (Mdoub), filled in place.
        :return: Success status (0).
        """
        K_ud = Mdoub()
        self.cov.se_ard_cross(U, D, K_ud)

        resid = Vdoub(ds)
        for i in range(ds):
            resid[i] = D[i][self.cov.dim] - self.h_mu
        alpha = chol_sdd.solve(resid)

        for i in range(us):
            fu[i] = sum(K_ud[i][j] * alpha[j] for j in range(ds))
            K_ui = self._row(K_ud, i, ds)
            v = chol_sdd.elsolve(K_ui)
            suu[i][i] = sum(v[j] * v[j] for j in range(ds))
            beta = chol_sdd.solve(K_ui)
            for t in range(i + 1, us):
                suu[t][i] = sum(K_ud[t][j] * beta[j] for j in range(ds))
                suu[i][t] = suu[t][i]
        return 0  # SUCC

    def pic_regr_blk_mpi(self, D, ds, aset, as_, ls_zu, ls_kuu, lmu, lcov):
        """
        Compute and save the local summary for one data block.
        :param D: Dataset matrix (Mdoub).
        :param ds: Number of data points.
        :param aset: Active (support) set matrix (Mdoub).
        :param as_: Number of active points.
        :param ls_zu: Local mean vector (Vdoub), filled in place.
        :param ls_kuu: Local covariance matrix (Mdoub), filled in place.
        :param lmu: Path to output mean file.
        :param lcov: Path to output covariance file.
        :return: Success status (0).
        """
        kuu = Mdoub()
        self.cov.se_ard_matrix(aset, kuu)
        chol_kuu = self.chol_cov(kuu)

        chol_sdd = self.chol_pcov(D, ds, aset, as_, chol_kuu)
        self.pitc_prep(D, ds, chol_sdd, aset, as_, chol_kuu, ls_zu, ls_kuu)

        # Diagnostic only: report whether the local summary covariance is
        # symmetric / positive semi-definite.
        self._check_symmetric_psd(ls_kuu.data())

        save_data(lmu, ls_zu)
        save_data(lcov, ls_kuu)
        return 0  # SUCC

    def regress_local(self, train, support, local_mean, local_cov):
        """
        Load a training block and support set, then compute and save the
        local summary (timing stored in ``self.elapsed``).
        :param train: Path to training data file.
        :param support: Path to support set file.
        :param local_mean: Path to output local mean file.
        :param local_cov: Path to output local covariance file.
        """
        ddim = self.cov.dim + 1
        traindata = Mdoub(1, ddim)
        supportset = Mdoub(1, ddim)
        load_data(train, traindata)
        ds = traindata.nrows()
        load_data(support, supportset)
        ss = supportset.nrows()
        ls_zu = Vdoub(ss)
        ls_kuu = Mdoub(ss, ss)

        timer = pgpr_timer()
        timer.start()
        self.pic_regr_blk_mpi(traindata, ds, supportset, ss, ls_zu, ls_kuu, local_mean, local_cov)
        self.elapsed = timer.end()

    def pic_core(self, data, support, tsetk, gs_k, gs_mu, lsum_covariance, lsum_mean):
        """
        Perform the core PIC prediction for one block and score it.
        :param data: Path to the dataset file.
        :param support: Path to the support set file.
        :param tsetk: Path to the test set file.
        :param gs_k: Path to the global summary covariance file.
        :param gs_mu: Path to the global summary mean file.
        :param lsum_covariance: Path to the local summary covariance file.
        :param lsum_mean: Path to the local summary mean file.
        :return: Success status (0). Results are stored in ``self.pmu``,
            ``self.pvar``, ``self.rmse``, ``self.mnlp`` and ``self.elapsed``.
        """
        timer = pgpr_timer()
        timer.start()

        ddim = self.cov.dim + 1
        datak = Mdoub(1, ddim)
        tset = Mdoub(1, ddim)
        aset = Mdoub(1, ddim)

        load_data(data, datak)
        dsk = datak.nrows()
        load_data(support, aset)
        as_ = aset.nrows()
        load_data(tsetk, tset)
        ts = tset.nrows()

        ls_kuu = Mdoub(as_, as_)
        suu = Mdoub(as_, as_)
        gs_zu = Vdoub(as_)
        ls_zu = Vdoub(as_)
        load_vector(lsum_mean, ls_zu)
        load_vector(gs_mu, gs_zu)

        kuu = Mdoub()
        self.cov.se_ard_matrix(aset, kuu)
        es_kuu = Mdoub(kuu)
        es_zu = Vdoub(as_, 0.0)
        gs_kuu = Mdoub(kuu)

        load_data_gs(gs_k, suu)
        load_data_gs(lsum_covariance, ls_kuu)

        # kuu is already the as_ x as_ support covariance: decompose it
        # directly (as pic_regr_blk_mpi does) instead of treating its rows
        # as input points.
        chol_kuu = self.chol_cov(kuu)

        # Diagnostic only: report on the loaded global summary covariance.
        self._check_symmetric_psd(suu.data())

        # gs_kuu = kuu + suu
        for i in range(as_):
            for j in range(as_):
                gs_kuu[i][j] += suu[i][j]

        # Exclude this block's own local summary from the global one.
        for r in range(as_):
            es_zu[r] = gs_zu[r] - ls_zu[r]
            for c in range(as_):
                es_kuu[r][c] = gs_kuu[r][c] - ls_kuu[r][c] - kuu[r][c]

        chol_sddk = self.chol_pcov(datak, dsk, aset, as_, chol_kuu)
        chol_suu = self.chol_cov(gs_kuu)  # gs_kuu is already a covariance

        # Cross-covariances between data, support and test points.
        K_du = Mdoub()
        K_td = Mdoub()
        K_tu = Mdoub()
        self.cov.se_ard_cross(datak, aset, K_du)
        self.cov.se_ard_cross(tset, aset, K_tu)
        self.cov.se_ard_cross(tset, datak, K_td)

        # K_ut is the transpose of K_tu; it must be sized before indexing.
        K_ut = Mdoub(as_, ts)
        for i in range(ts):
            for j in range(as_):
                K_ut[j][i] = K_tu[i][j]

        # Mean-centred targets of this block.
        v = Vdoub(dsk)
        for i in range(dsk):
            v[i] = datak[i][self.cov.dim] - self.h_mu

        # Predictive mean
        if self.pmu.size() != ts:
            self.pmu.resize(ts)

        t1 = Vdoub(ts)
        t2 = Vdoub(ts)
        m_tu = Mdoub(ts, as_)
        tk_tu = Mdoub(ts, as_)

        A_invB_C(K_td, chol_sddk, v, self.pmu)
        A_invB_C(K_tu, chol_kuu, es_zu, t1)
        A_invB_C(K_tu, chol_kuu, es_kuu, m_tu)
        A_invB_C(K_td, chol_sddk, K_du, tk_tu)

        for i in range(ts):
            for j in range(as_):
                tk_tu[i][j] += m_tu[i][j]

        A_invB_C(tk_tu, chol_suu, gs_zu, t2)

        for i in range(ts):
            self.pmu[i] += (self.h_mu + t1[i] - t2[i])

        # Predictive variance
        if self.pvar.size() != ts:
            self.pvar.resize(ts)

        A_invB_C(tk_tu, chol_suu, self.pvar)
        A_invB_C(m_tu, chol_kuu, K_ut, t1)
        A_invB_C(K_td, chol_sddk, t2)

        for i in range(ts):
            self.pvar[i] += self.cov.nos + self.cov.sig - t1[i] - t2[i]

        trueval = Vdoub(ts)
        for i in range(ts):
            trueval[i] = tset[i][ddim - 1]

        self.rmse = get_rmse(trueval, self.pmu)
        self.mnlp = get_mnlp(trueval, self.pmu, self.pvar)
        self.elapsed = timer.end()

        return 0  # SUCC
