from dataclasses import dataclass
import numpy as np
from swimpde.Domain import Domain
from swimpde.Ansatz import Ansatz
from typing import Callable
import scipy 


@dataclass
class HelmholtzSolver:
    """
    Solver for the Helmholtz-type equation
        Δu(x) - u(x) = f(x)   in Ω
                u(x) = g(x)   on ∂Ω

    The solution is represented as u(x) = Φ(x) Vᵀ c + b, where Φ are the ansatz
    basis functions, V holds right singular vectors (as rows) of Φ evaluated on
    the interior points, c are the fitted coefficients and b is a scalar bias.

    Attributes:
    -----------
    domain: Domain
        provides the interior and boundary collocation points
    ansatz: Ansatz
        basis functions from which the solution will be built by linear combination
        (use BasicAnsatz for this solver)
    forcing: Callable
        forcing function f, evaluated on the interior points
    regularization_scale: float
        NOTE(review): currently unused -- ``fit`` passes the ansatz's own
        ``regularization_scale`` as ``rcond`` to the least-squares solve.
    """
    domain: Domain
    ansatz: Ansatz
    forcing: Callable
    regularization_scale: float = 1e-8

    def __post_init__(self):
        # last-layer weights; the final entry is the scalar bias b
        self._coefficients: np.ndarray = None
        # basis values (A) and basis Laplacians (B) on the interior points,
        # both projected onto the right singular vectors stored in _V_a
        self._A: np.ndarray = None
        self._B: np.ndarray = None
        self._V_a: np.ndarray = None

    def _init_matrices(self, num_components=None):
        '''
        Evaluate the basis on the interior points and project it onto right
        singular vectors of the resulting feature matrix.

        Parameters:
        num_components: int or None
            number of singular vectors to keep. ``None`` keeps all of them
            (dense SVD, and prints the condition number of the raw feature
            matrix); otherwise the ``num_components`` largest are computed with
            ``scipy.sparse.linalg.svds``.

        Sets ``self._V_a`` (rows are right singular vectors), ``self._A``
        (projected basis values) and ``self._B`` (projected basis Laplacians).
        Returns self.
        '''
        interior = self.domain.interior_points
        features = self.ansatz.evaluate_model(interior)
        if num_components is None:
            _, singular_values, V_a = np.linalg.svd(features, full_matrices=False)
            # reuse the singular values instead of a second SVD in np.linalg.cond;
            # a singular feature matrix yields inf, as np.linalg.cond would
            with np.errstate(divide="ignore", invalid="ignore"):
                condition_number = singular_values[0] / singular_values[-1]
            print("Condition number of the non-orthogonal basis function: ", condition_number)
        else:
            # explicit submodule import: a bare ``import scipy`` does not load
            # scipy.sparse.linalg on older SciPy releases
            from scipy.sparse.linalg import svds
            _, _, V_a = svds(features, k=num_components)

        self._V_a = V_a
        self._A = features @ V_a.T
        self._B = self.ansatz.evaluate_model_laplace(interior) @ V_a.T
        return self

    def fit(self, num_svd=None, approx_sol=None, analytical_sol: Callable = None):
        '''
        Approximate the solution of the Helmholtz problem by choosing the
        model parameters accordingly.

        Parameters:
        num_svd: int or None
            number of singular vectors kept when orthogonalizing the basis
            (``None`` keeps all of them).
        approx_sol: np.ndarray or None
            optional approximate solution on the interior points; passed to
            ``ansatz.init_model`` as initial condition and used (flattened) as
            the fit target of the ansatz. When ``None`` the target is zero.
        analytical_sol: Callable
            returns the Dirichlet boundary values g on the boundary points.
            Required despite the ``None`` default.

        Raises:
        TypeError if ``analytical_sol`` is not given.

        Returns self.
        '''
        if analytical_sol is None:
            raise TypeError("fit() requires analytical_sol to impose the boundary condition")

        interior = self.domain.interior_points
        boundary = self.domain.boundary_points

        # initialize the model
        self.ansatz.init_model(self.domain, initial_condition=approx_sol)

        # evaluate the forcing once: it is the PDE right-hand side and also the
        # shape template for the default all-zero ansatz target
        forcing_values = self.forcing(interior)
        if approx_sol is None:
            target = np.zeros_like(forcing_values)
        else:
            target = approx_sol.reshape(-1)
        self.ansatz.fit_model(interior, target)

        # compute the projected basis matrices
        self._init_matrices(num_components=num_svd)

        # With u = Φ Vᵀ c + b and Δb = 0:
        #   interior:  Δu - u = (B - A) c - b  -> bias column is -1
        #   boundary:       u = Φ_bc Vᵀ c + b  -> bias column is +1
        pde_features = self._B - self._A
        pde_rows = np.column_stack([pde_features, -np.ones((pde_features.shape[0], 1))])
        bc_features = self.ansatz.evaluate_model(boundary) @ self._V_a.T
        bc_rows = np.column_stack([bc_features, np.ones((bc_features.shape[0], 1))])
        matrix_in = np.vstack([pde_rows, bc_rows])

        matrix_out = np.hstack([
            forcing_values.T,
            analytical_sol(boundary),
        ])

        # NOTE(review): rcond comes from the ansatz, not self.regularization_scale
        self._coefficients = np.linalg.lstsq(
            matrix_in, matrix_out, rcond=self.ansatz.regularization_scale
        )[0]
        return self

    def evaluate(self, x):
        '''
        Evaluate the learned solution u(x) = Φ(x) Vᵀ c + b at the points x.
        Requires ``fit`` to have been called.
        '''
        return self.ansatz.evaluate_model(x) @ (self._V_a).T @ self._coefficients[:-1] + self._coefficients[-1]

    def evaluate_laplace(self, x):
        '''
        Evaluate the Laplacian Δu of the learned solution at the points x.
        The constant bias b has zero Laplacian, so it is omitted.
        Requires ``fit`` to have been called.
        '''
        return self.ansatz.evaluate_model_laplace(x) @ (self._V_a).T @ self._coefficients[:-1]
