from dataclasses import dataclass
import numpy as np
from swimpde.Domain import Domain
from swimpde.Ansatz import Ansatz
from typing import Callable
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import scipy


@dataclass
class LaplaceBeltramiSolver:
    """
    Solve the Laplace-Beltrami equation with a zero Neumann boundary condition:

        Δu(x) = f(x)            for x in the interior of the domain,
        < ∇u(x), n(x) > = 0     for x on the boundary (n is the unit normal vector).

    The solution is represented as a linear combination of the ansatz's basis
    functions plus an additive constant. The pure Neumann problem only determines
    u up to a constant. The least-squares solve returns the minimum-norm
    coefficients, which sets that constant to 0.

    Attributes:
    -----------
    domain: Domain
        provides ``interior_points``, ``boundary_points`` and ``normal_vectors``
        (one normal vector per boundary point)
    ansatz: Ansatz
        basis functions from which the solution will be built by linear combination
        (use BoundaryCompliantAnsatz for this solver to ensure the boundary conditions are fulfilled)
    boundary_condition: str
        boundary condition, one of "zero neumann"/"zero derivative" or "zero dirichlet"/"zero".
        NOTE: ``fit`` currently always imposes the zero Neumann condition,
        regardless of this value.
    forcing: Callable
        right-hand side f; called with an array of interior points and expected
        to return one value per point
    regularization_scale: float
        regularization scale for computing the matrix inverse and solving least squares problems.
        NOTE(review): ``fit`` currently passes ``ansatz.regularization_scale``
        (not this field) as ``rcond`` to the least-squares solver.
    ode_solver: str
        ode solver (to be used as 'method' in scipy.integrate.solve_ivp);
        not used by this solver
    """
    domain: Domain
    ansatz: Ansatz

    boundary_condition: str
    forcing: Callable

    regularization_scale: float = 1e-8
    ode_solver: str = 'DOP853'  # alternative: 'RK45'

    def __post_init__(self):
        # Linear (last-layer) coefficients of the solution; the final entry is the
        # additive constant. None until fit() has been called.
        self._coefficients: np.ndarray = None

    def _fitted_coefficients(self):
        """Return the coefficients computed by ``fit``; raise if ``fit`` has not run yet."""
        if self._coefficients is None:
            raise RuntimeError("LaplaceBeltramiSolver has not been fitted; call fit() first.")
        return self._coefficients

    def fit(self, num_svd=None):
        """
        Approximate the solution of the Laplace-Beltrami equation by choosing
        the linear (last-layer) coefficients of the ansatz via least squares.

        Interior rows of the system enforce Δu = f. Boundary rows enforce a zero
        normal derivative.

        Parameters
        ----------
        num_svd:
            currently unused; kept for interface compatibility.

        Returns
        -------
        LaplaceBeltramiSolver
            ``self``, so calls can be chained (``solver.fit().evaluate(x)``).
        """
        interior_points = self.domain.interior_points
        boundary_points = self.domain.boundary_points
        n_interior = interior_points.shape[0]
        n_boundary = boundary_points.shape[0]

        # initialize the model
        self.ansatz.init_model(self.domain)

        # Evaluate the forcing once. It is reused both for the ansatz fit and as
        # the right-hand side of the least-squares system.
        forcing_values = self.forcing(interior_points)

        # Let the ansatz adapt its basis functions to the forcing data
        # (semantics defined by Ansatz.fit_model_laplace).
        self.ansatz.fit_model_laplace(interior_points, forcing_values)

        # Normal derivative of every basis function at every boundary point.
        # Row k equals du_dx[k] @ normal_vectors[k]. du_dx is presumably
        # (n_boundary, n_basis, dim); confirm against Ansatz.evaluate_model_gradient.
        # The vectorized form also copes with an empty boundary, where stacking an
        # empty list of rows would raise.
        du_dx = self.ansatz.evaluate_model_gradient(boundary_points)
        neumann_rows = np.einsum("kbd,kd->kb", du_dx, self.domain.normal_vectors)

        # The trailing zero column holds the coefficient of the additive constant.
        # Both the Laplacian and the normal derivative of a constant vanish, so the
        # column is identically zero. lstsq returns the minimum-norm solution,
        # which sets that coefficient to 0. This fixes the constant that the pure
        # Neumann problem leaves undetermined.
        matrix_in = np.column_stack(
            [
                np.vstack([self.ansatz.evaluate_model_laplace(interior_points), neumann_rows]),
                np.zeros((n_interior + n_boundary, 1)),
            ]
        )
        # Flatten the forcing so a column-shaped (n, 1) result also works.
        matrix_out = np.hstack([np.ravel(forcing_values), np.zeros((n_boundary,))])

        self._coefficients = np.linalg.lstsq(
            matrix_in, matrix_out, rcond=self.ansatz.regularization_scale
        )[0]

        return self

    def evaluate(self, x):
        """
        Evaluate the fitted solution u at the points ``x``.

        Raises
        ------
        RuntimeError
            if called before ``fit``.
        """
        coefficients = self._fitted_coefficients()
        return self.ansatz.evaluate_model(x) @ coefficients[:-1] + coefficients[-1]

    def evaluate_laplace(self, x):
        """
        Evaluate the Laplacian of the fitted solution (the learned forcing) at ``x``.

        The additive constant is omitted because its Laplacian is zero.

        Raises
        ------
        RuntimeError
            if called before ``fit``.
        """
        coefficients = self._fitted_coefficients()
        return self.ansatz.evaluate_model_laplace(x) @ coefficients[:-1]