import time

import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg  # noqa: F401  (guarantees `sp.linalg` is loaded)


class PDEDataBase:
    """
    Base class for preparing PDE data.

    Subclasses override `init_problem()` to populate the discrete system
    (`A`, `b`) and a reference (`ref_sol` or `ref_data`), and
    `rearrange_u_data()`. Calling `setup(config)` then optionally builds an
    ILU preconditioner (or only logs the condition estimate of `A`) and
    prepares the test data.
    """

    def __init__(self) -> None:
        self.config = None         # configuration dict given to `setup()`
        self.A = None              # sparse matrix of the system `Ax=b`
        self.b = None              # right-hand side of `Ax=b`
        self.precond = []          # [Pr, L, U] set by `init_precond()`
        self.precond_time = None   # elapsed seconds spent in `spilu`
        self.precond_error = None  # ||P^{-1}b - A^{-1}b|| / ||A^{-1}b||
        self.after_cond = None     # condition estimate after preconditioning
        self.original_cond = None  # condition estimate of `A`
        self.coords = None         # coordinates of the unknowns `x`
        self.ref_sol = None        # callable reference solution: X -> Y
        self.ref_data = None       # tabulated reference, columns [X | Y]
        self.test_X = None         # test inputs, see `prepare_test_data()`
        self.test_Y = None         # test targets, see `prepare_test_data()`

    def init_problem(self):
        """
        Initialize the discrete PDE problem `Ax=b`.

        Must be overridden by subclasses.
        """
        raise NotImplementedError("init_problem() is not implemented.")

    def init_precond(self):
        """
        Initialize the preconditioner for the matrix `A`.

        Runs `scipy.sparse.linalg.spilu` with the configured
        "Drop Tolerance (in ILU)" and "Fill Factor (in ILU)", records the
        elapsed time in `precond_time` and stores `[Pr, L, U]` in
        `precond`, where `A ~= Pr.T @ L @ U @ Pc`. When
        "Log Precondition Quality" is enabled, also computes
        `precond_error`, `original_cond` and `after_cond`.
        """
        t_start = time.perf_counter()
        ilu = sp.linalg.spilu(
            self.A,
            drop_tol=self.config["Drop Tolerance (in ILU)"],
            fill_factor=self.config["Fill Factor (in ILU)"],
        )
        elapsed = time.perf_counter() - t_start
        print("Time for computing preconditioner: {:.2f}s".format(elapsed))
        self.precond_time = elapsed

        n = self.A.shape[0]
        ones, idx = np.ones(n), np.arange(n)
        # Permutation matrices arranged so that A ~= Pr.T @ L @ U @ Pc.
        Pr = sp.csr_matrix((ones, (ilu.perm_r, idx)), shape=(n, n))
        Pc = sp.csr_matrix((ones, (ilu.perm_c, idx)), shape=(n, n))
        L, U = ilu.L, ilu.U
        # NOTE(review): `Pc` is not part of the stored preconditioner, so
        # consumers applying only `[Pr, L, U]` ignore the column permutation;
        # that is only correct when `perm_c` is the identity. Confirm with
        # the callers of `get_precond()`.
        self.precond = [Pr, L, U]

        if self.config["Log Precondition Quality"]:
            self._log_precond_quality(Pr, Pc, L, U)

    def _log_precond_quality(self, Pr, Pc, L, U):
        """
        Compare the ILU solve against a direct solve and log condition
        estimates, setting `precond_error`, `original_cond` and `after_cond`.

        Densifies and inverts `A`, so it is only practical for small systems.
        """
        # x := P^{-1} b, applying the permuted triangular factors in turn.
        # `spsolve_triangular` works on CSR; convert explicitly.
        y = sp.linalg.spsolve_triangular(L.tocsr(), Pr @ self.b, lower=True)
        x = sp.linalg.spsolve_triangular(U.tocsr(), y, lower=False)
        x = (Pc.T @ x).reshape(-1)
        # x_ref := A^{-1} b
        x_ref = self._reference_solution()
        x_ref_norm = np.linalg.norm(x_ref, ord=2)
        self.precond_error = np.linalg.norm(x - x_ref, ord=2) / x_ref_norm
        # Kept as a concatenation: the literal "{-1}" must not reach
        # str.format, which would treat it as a replacement field.
        print("L2 relative error between P^{-1}b and A^{-1}b" +
              ": {:.2e}".format(self.precond_error))

        A_inv = np.linalg.inv(self.A.toarray())
        self._report_original_cond(x_ref, A_inv)

        # A^{-1} P with P = Pr.T @ L @ U @ Pc, the ILU approximation of A.
        A_inv_P = (A_inv @ Pr.T.toarray() @ L.toarray()
                   @ U.toarray() @ Pc.toarray())
        self.after_cond = (np.linalg.norm(x, ord=2) / x_ref_norm
                           * np.linalg.norm(A_inv_P, ord=2))
        print("Condition number after preconditioning: {:.2e}".format(
            self.after_cond))

    def _reference_solution(self):
        """Return `A^{-1} b` from a sparse direct solve, flattened to 1D."""
        return sp.linalg.spsolve(self.A, self.b).reshape(-1)

    def _report_original_cond(self, x_ref, A_inv):
        """
        Store and print the condition estimate of `A`.

        The estimate is ||b|| / ||x_ref|| * ||A^{-1}|| (2-norms). Since
        ||b|| <= ||A|| ||x_ref||, in exact arithmetic it never exceeds
        cond_2(A).
        """
        self.original_cond = (np.linalg.norm(self.b, ord=2)
                              / np.linalg.norm(x_ref, ord=2)
                              * np.linalg.norm(A_inv, ord=2))
        print("Original condition number: {:.2e}".format(self.original_cond))

    def compute_original_cond(self):
        """
        Compute and print the condition estimate of `A` without building a
        preconditioner. Densifies and inverts `A`.
        """
        x_ref = self._reference_solution()
        A_inv = np.linalg.inv(self.A.toarray())
        self._report_original_cond(x_ref, A_inv)

    def prepare_test_data(self):
        """
        Prepare test data in `test_X` / `test_Y`.

        - If `ref_sol` is set, it is evaluated on a uniform grid over the
          "Bounding Box".
        - Otherwise rows of `ref_data` containing NaN are dropped. For
          time-stepping problems (spatial dimension differs from the
          spatial-temporal one and "Use Time Stepping" is set), the data are
          split per sub-time interval. Otherwise the first
          "Spatial Temporal Dimension" columns become `test_X` and the rest
          `test_Y`.

        Raises:
            RuntimeError: if neither `ref_sol` nor `ref_data` is set.
        """
        dim_s_t = self.config["Spatial Temporal Dimension"]
        if self.ref_sol is not None:
            self._sample_ref_sol(dim_s_t)
            return
        if self.ref_data is None:
            raise RuntimeError("ref_data has not been initialized.")
        nan_rows = np.isnan(self.ref_data).any(axis=1)
        self.ref_data = self.ref_data[~nan_rows]
        cfg = self.config
        if cfg["Spatial Dimension"] != cfg["Spatial Temporal Dimension"] \
                and cfg["Use Time Stepping"]:
            self._split_ref_data_by_time()
        else:
            self.test_X = self.ref_data[:, :dim_s_t]
            self.test_Y = self.ref_data[:, dim_s_t:]

    def _sample_ref_sol(self, dim_s_t):
        """Evaluate `ref_sol` on a uniform tensor grid over the bounding box."""
        bbox = self.config["Bounding Box"]
        if dim_s_t == 2:
            n_per_axis = 50
        elif dim_s_t == 3:
            n_per_axis = 30
        else:
            # Aim for roughly 1e4 grid points in total.
            n_per_axis = int(1e4 ** (1 / dim_s_t))
        axes = [np.linspace(bbox[2 * i], bbox[2 * i + 1], n_per_axis)
                for i in range(dim_s_t)]
        # Default ('xy') meshgrid indexing keeps the original point order.
        self.test_X = np.stack(np.meshgrid(*axes), axis=-1).reshape(-1, dim_s_t)
        self.test_Y = self.ref_sol(self.test_X)

    def _split_ref_data_by_time(self):
        """
        Split `ref_data` into per-sub-interval test sets for time stepping.

        Assumes each row is `[x (dim_s) | y(t_0) | y(t_1) | ...]`, each `y`
        block holding "Output Dimension" values, with times uniformly spaced
        over the last bounding-box interval (inferred from the reshape below;
        TODO confirm against the data files).

        Sets:
            test_X = [ref_X, [times in interval i, or None if empty]]
            test_Y = [Y at t_0 of shape (n, 1, output_dim),
                      [Y at those times (n, k, output_dim), or None]]
        """
        cfg = self.config
        dim_s = cfg["Spatial Dimension"]
        output_dim = cfg["Output Dimension"]
        bbox = cfg["Bounding Box"]
        time_start, time_end = bbox[-2], bbox[-1]
        n_intervals = cfg["Sub-Time Intervals"]
        interval_len = (time_end - time_start) / n_intervals
        # Shifting both edges up by EPSILON makes each interval exclude its
        # start time and include its end time, so boundary samples land in
        # the earlier interval and the initial time is in no interval.
        EPSILON = 1e-8

        ref_X = self.ref_data[:, :dim_s]
        n_steps = (self.ref_data.shape[1] - dim_s) // output_dim
        ref_t = np.linspace(time_start, time_end, n_steps)
        ref_Y = self.ref_data[:, dim_s:].reshape(-1, ref_t.shape[0], output_dim)

        self.test_X = [ref_X, []]
        self.test_Y = [ref_Y[:, 0:1, :], []]
        for i in range(n_intervals):
            start = i * interval_len + time_start + EPSILON
            end = (i + 1) * interval_len + time_start + EPSILON
            in_interval = (start < ref_t) & (ref_t < end)
            if not in_interval.any():
                self.test_X[1].append(None)
                self.test_Y[1].append(None)
                continue
            self.test_X[1].append(ref_t[in_interval])
            self.test_Y[1].append(ref_Y[:, in_interval, :])

    def rearrange_u_data(self, u_data):
        """
        Rearrange the solution vector `u_data` to be a 1D array if the
        output dimension is greater than 1.
            u_data: (n_points, output_dim)

        Note:
            For the time-dependent problems,
            `u_data` is of shape (n_points, n_timesteps, output_dim).

        Must be overridden by subclasses.
        """
        raise NotImplementedError("rearrange_u_data() is not implemented.")

    def setup(self, config: dict):
        """
        Setup the PDE data.

        Stores `config`, calls `init_problem()`, then either builds the ILU
        preconditioner ("Use Preconditioner") or, if only
        "Log Precondition Quality" is set, logs the condition estimate of
        `A`. Finally prepares the test data.
        """
        self.config = config
        self.init_problem()
        if self.config["Use Preconditioner"]:
            self.init_precond()
        elif self.config["Log Precondition Quality"]:
            self.compute_original_cond()
        self.prepare_test_data()

    def get_matrix(self):
        """
        Return the sparse matrix `A`
        of the discretized PDE problem `Ax=b`.
        """
        return self.A

    def get_rhs(self):
        """
        Return the right-hand side `b`
        of the discretized PDE problem `Ax=b`.
        """
        return self.b

    def get_precond(self):
        """
        Return the preconditioner `[Pr, L, U]` for the matrix `A`,
        using the incomplete LU factorization (empty list if
        `init_precond()` has not run).
        """
        return self.precond

    def get_coords(self):
        """
        Return the coordinates of `x` in the discretized PDE problem `Ax=b`.
        """
        return self.coords

    def get_test_data(self):
        """
        Return the test data as `(test_X, test_Y)`.
        """
        return self.test_X, self.test_Y
