from dataclasses import dataclass
import numpy as np
from swimpde.Domain import Domain
from swimpde.Ansatz import Ansatz
from typing import Callable
from scipy.integrate import solve_ivp



@dataclass
class WaveSolver:
    """
    Solver for the wave equation
    ∂²u(x,t)/∂t² = c²∇∙(γ(x)∇u(x,t)) + 1/ρ * f(x,t)

    with initial condition for u(x,t) and ∂u(x,t)/∂t and boundary condition (only zero dirichlet/neumann supported)

    The solution is represented as a linear combination of the ansatz basis
    functions with time-dependent coefficients c(t). ``fit`` integrates the
    first-order system for (c, d = dc/dt) with ``scipy.integrate.solve_ivp``;
    ``evaluate`` then reconstructs u from the dense ODE solution.

    Attributes:
    -----------
    domain: Domain
        must expose ``all_points``, the points at which all matrices,
        initial conditions and the forcing are evaluated
    ansatz: Ansatz
        basis functions from which the solution will be built by linear combination
        (use BoundaryCompliantAnsatz for this solver to ensure the boundary conditions are fulfilled)
    u0: Callable
        solution at time t0, called as ``u0(x)``
    ut0: Callable
        time derivative of the solution at t0, called as ``ut0(x)``
    boundary_condition: str
        boundary condition, one of "zero neumann"/"zero derivative" or "zero dirichlet"/"zero"
        (passed unchanged to ``ansatz.init_model``)
    forcing: Callable
        forcing, called as ``forcing(x, t)``; the result is flattened, so one
        value per point in either (n,) or (n, 1) layout is accepted
    gamma: Callable
        gamma, a function of x (default: ones, one column per point)
    gamma_x: Callable
        derivative of gamma with respect to x (default: zeros)
    c: float = 1
        wave speed, constant
    rho: float = 1
        density, constant
    regularization_scale: float
        regularization scale (``rcond``) for computing the pseudo-inverse and solving least squares problems
    ode_solver: str
        ode solver (to be used as 'method' in scipy.integrate.solve_ivp)

    """
    domain: Domain
    ansatz: Ansatz

    u0: Callable
    ut0: Callable
    boundary_condition: str
    forcing: Callable
    gamma: Callable = lambda x: np.ones_like(x[:, :1])
    gamma_x: Callable = lambda x: np.zeros_like(x[:, :1])
    c: float = 1
    rho: float = 1

    regularization_scale: float = 1e-8
    ode_solver: str = 'RK45'

    def __post_init__(self):
        """Initialise the internal state that is filled in by ``fit``."""
        # dense ODE solution for the stacked coefficients (c, d); None until fit() ran
        self._coefficients_cd: Callable = None

        # matrices for the ODE (all set by _init_matrices)
        self._gamma_x_B: np.ndarray = None
        self._gamma_D: np.ndarray = None
        self._gamma_A_inv: np.ndarray = None
        self._A: np.ndarray = None
        self._D: np.ndarray = None
        self._A_one: np.ndarray = None  # not used by this class; kept for compatibility

    def fit(self, t_span) -> "WaveSolver":
        """
        Approximate the solution of the wave problem by choosing the model parameters and time-dependent coefficients accordingly

        Parameters:
            t_span: integration interval, forwarded as ``t_span`` to scipy.integrate.solve_ivp

        Returns:
            self, so calls can be chained
        """
        # set up the model for the ansatz function
        self.ansatz.init_model(self.domain, self.boundary_condition)

        # compute the matrices needed in the ODE
        self._init_matrices()

        points = self.domain.all_points
        inv_rho = 1 / self.rho

        # right-hand side of the first-order system for the stacked state (c, d = c_t)
        def cd_t(t, cd):
            half = len(cd) // 2
            c, d = cd[:half], cd[half:]
            # Flatten the forcing: an (n, 1) column would otherwise broadcast
            # against the (n,) operator terms into an (n, n) array.
            f = np.asarray(self.forcing(points, t)).reshape(-1) * inv_rho
            rhs = self._gamma_x_B @ c + self._gamma_D @ c + f
            d_t = self._gamma_A_inv @ rhs
            return np.concatenate([d, d_t])

        # get the initial value for the ODE
        cd_0 = self._get_cd0()

        # solve the ODE and keep the dense (continuous in t) solution
        self._coefficients_cd = solve_ivp(
            fun=cd_t,
            t_span=t_span,
            y0=cd_0,
            dense_output=True,
            method=self.ode_solver,
        ).sol

        return self

    def evaluate(self, x_eval, t_eval):
        """
        Evaluate the solution at given time and space points

        Parameters:
            x_eval: (n_eval, d), n_eval is the number of points, d is the dimension
            t_eval: (t, ) array of times; a scalar is treated as a single time

        Returns:
            sol_wave: (n_eval, t)

        Raises:
            RuntimeError: if called before ``fit``
        """
        if self._coefficients_cd is None:
            raise RuntimeError(
                "WaveSolver.evaluate() was called before fit(); call fit(t_span) first."
            )

        # promote scalars so the dense solution is always 2-D: (2 * n_coeff, t)
        sol_cd = self._coefficients_cd(np.atleast_1d(t_eval))
        n_coeff = sol_cd.shape[0] // 2

        # the first half of the state holds the coefficients c(t); the second half is c_t
        return self.ansatz.evaluate_model(x_eval) @ sol_cd[:n_coeff]

    def _init_matrices(self) -> "WaveSolver":
        """
        Set all matrices that occur in the ODE, evaluated at ``domain.all_points``

        Uses ``regularization_scale`` as ``rcond`` of the pseudo-inverse.

        Returns:
            self
        """
        points = self.domain.all_points
        gamma = self.gamma(points)
        wave_speed_sq = self.c ** 2

        # c² (∇γ · ∇φ_i): contract the last axis of the basis gradient with gamma_x
        B = self.ansatz.evaluate_model_gradient(points)
        self._gamma_x_B = wave_speed_sq * np.einsum("nid,nd->ni", B, self.gamma_x(points))

        # c² γ Δφ_i
        D = self.ansatz.evaluate_model_laplace(points)
        self._gamma_D = wave_speed_sq * gamma * D
        self._D = D

        # NOTE(review): the pseudo-inverse is taken of gamma * A rather than A;
        # for gamma != 1 confirm this matches the intended formulation.
        A = self.ansatz.evaluate_model(points)
        self._gamma_A_inv = np.linalg.pinv(gamma * A, rcond=self.regularization_scale)
        self._A = A

        return self

    def _get_cd0(self):
        """
        Initial condition of the time-dependent coefficients for the ODE solver

        The coefficients are the least-squares fits of ``u0`` and ``ut0`` at
        ``domain.all_points`` in the ansatz basis (requires ``_init_matrices``
        to have been called so that ``_A`` is set).

        Returns:
            cd0: 1-D array; initial condition for c (first half) and d (= c_t, second half)
        """
        points = self.domain.all_points
        rcond = self.regularization_scale
        c0 = np.linalg.lstsq(self._A, self.u0(points), rcond=rcond)[0]
        d0 = np.linalg.lstsq(self._A, self.ut0(points), rcond=rcond)[0]
        # flatten each part separately so (n,) and (n, 1) initial data can be mixed
        return np.concatenate([np.ravel(c0), np.ravel(d0)])