from dataclasses import dataclass
import numpy as np
from swimpde.Domain import Domain
from swimpde.Ansatz import Ansatz
from typing import Callable
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import scipy
from dataclasses import field
import copy

@dataclass
class AllenCahnSolver:
    """
    Solver for the Allen-Cahn equation
    u_t - (u_xx + u_yy + u_zz) + u^3 - u = f
    with initial condition u(t0, x) = u0(x).

    The solution is represented as u(t, x) = Phi(x) @ V_a.T @ c(t), where Phi(x) are the
    ansatz basis functions evaluated at x, V_a holds right-singular vectors of the feature
    matrix at the domain points (all of them, or only the leading ones when an SVD
    truncation is requested), and c(t) are time-dependent coefficients obtained by
    integrating the semi-discrete ODE with scipy.integrate.solve_ivp.

    Attributes:
    -----------
    domain: Domain
        provides ``all_points``, the points at which the ODE matrices are assembled
    ansatz: Ansatz
        basis functions from which the solution will be built by linear combination
        (use BoundaryCompliantAnsatz for this solver to ensure the boundary conditions are fulfilled)
    u0: Callable
        solution at time t0, called as u0(points)
    boundary_condition: str
        boundary condition, one of "zero neumann"/"zero derivative" or "zero dirichlet"/"zero"
    forcing: Callable
        forcing f, called as forcing(points, t)
    regularization_scale: float
        cutoff (rcond) for the pseudo-inverse and for the least-squares problems
    ode_solver: str
        ODE solver (used as 'method' in scipy.integrate.solve_ivp)
    ansatz_collection, c_collection, svd_collecion: list
        per-time-block copies of the ansatz, dense ODE solutions and SVD bases, filled by
        fit_time_blocks (the spelling "svd_collecion" is kept for backward compatibility)
    """
    domain: Domain
    ansatz: Ansatz

    u0: Callable
    boundary_condition: str
    forcing: Callable

    regularization_scale: float = 1e-8
    ode_solver: str = 'DOP853' #'RK45'

    # The following is required for the re-sampling procedure!
    ansatz_collection: list = field(default_factory=list)
    c_collection: list = field(default_factory=list)
    svd_collecion: list = field(default_factory=list)

    def __post_init__(self):
        # Internal state, filled by fit / fit_time_blocks.
        self._coefficients_c: Callable = None  # dense ODE solution c(t) of the most recent (block) fit
        self._B: np.ndarray = None       # Laplacian of the basis at the domain points, projected onto V_a
        self._A: np.ndarray = None       # basis at the domain points, projected onto V_a
        self._A_inv: np.ndarray = None   # regularized pseudo-inverse of _A
        self._V_a: np.ndarray = None     # right-singular vectors of the raw feature matrix
        self._t_block_edges: np.ndarray = None  # time-block boundaries used by fit_time_blocks

    def evaluate(self, x_eval, t_eval):
        '''
        Evaluate the solution of the most recent fit at given space and time points.

        Parameters:
            x_eval: (n_eval, d), n_eval is the number of points, d is the dimension
            t_eval: array-like of times (scalar or any shape), flattened to (n_t, )

        Returns:
            sol: (n_eval, n_t)

        Raises:
            RuntimeError: if neither fit nor fit_time_blocks has been called yet.
        '''
        if self._coefficients_c is None:
            raise RuntimeError("The solver has not been fitted yet; call fit() or fit_time_blocks() first.")
        times = np.asarray(t_eval).reshape(-1)
        # OdeSolution returns the coefficients as (n_coefficients, n_t).
        coefficients = self._coefficients_c(times)
        return self.ansatz.evaluate_model(x_eval) @ self._V_a.T @ coefficients

    def _init_matrices(self, num_components=None):
        '''
        Assemble the matrices of the semi-discrete ODE at the domain points.

        The raw feature matrix Phi = ansatz.evaluate_model(domain.all_points) is
        orthogonalized with an SVD Phi = U S V_a. The ODE then works with
        _A = Phi V_a^T, _B = Laplacian(Phi) V_a^T and _A_inv = pinv(_A).

        Parameters:
            num_components: None or int
                None: use the thin SVD (all singular triplets) and print the condition
                number of Phi. int: compute only that many of the largest singular
                triplets with scipy.sparse.linalg.svds.

        Side effects: sets _A, _B, _A_inv and _V_a, plots the singular values with
        matplotlib and calls plt.show().

        Returns:
            self
        '''
        features = self.ansatz.evaluate_model(self.domain.all_points)
        if num_components is None:
            _, singular_values, V_a = np.linalg.svd(features, full_matrices=False)
            # The 2-norm condition number is s_max / s_min; reuse this SVD instead of
            # letting np.linalg.cond compute a second one.
            with np.errstate(divide="ignore", invalid="ignore"):
                condition_number = singular_values[0] / singular_values[-1]
            print("Condition number of the non-orthogonal basis function: ", condition_number)
        else:
            # Import the submodule explicitly: a bare `import scipy` does not load
            # scipy.sparse.linalg on older SciPy versions.
            from scipy.sparse.linalg import svds
            _, singular_values, V_a = svds(features, k=num_components)
            # svds returns the singular values in ascending order; flip them for the plot.
            singular_values = np.flip(singular_values)

        self._V_a = V_a
        self._A = features @ V_a.T
        laplacian = self.ansatz.evaluate_model_laplace(self.domain.all_points)
        # Collapse any trailing axes, e.g. (n_points, n_neurons, 1) -> (n_points, n_neurons).
        self._B = laplacian.reshape((laplacian.shape[0], -1)) @ V_a.T
        # Regularized pseudo-inverse: singular values <= rcond * s_max are discarded.
        self._A_inv = np.linalg.pinv(self._A, rcond=self.regularization_scale)

        plt.semilogy(np.arange(len(singular_values)), singular_values, label="feature matrix")
        plt.ylabel('Singular Values')
        plt.title("Singular Values of data matrix and its derivative")
        plt.legend()
        plt.show()
        return self

    def _get_c0(self, initial_sol=None):
        '''
        Initial coefficients c(t0) for the ODE solver, fitted by regularized least squares.

        Parameters:
            initial_sol: None or array
                solution at the end of the previous time block, evaluated at all domain
                points. If None, u0(domain.all_points) is used instead.

        Returns:
            c0: least-squares solution of _A c0 = target (same trailing shape as target)
        '''
        target = self.u0(self.domain.all_points) if initial_sol is None else initial_sol
        return np.linalg.lstsq(self._A, target, rcond=self.regularization_scale)[0]

    def _rhs(self, t, c):
        '''
        Right-hand side dc/dt of the semi-discrete Allen-Cahn ODE.

        With u = _A c and Laplacian(u) = _B c, the PDE u_t = Laplacian(u) - (u^3 - u) + f
        becomes _A c_t = f + _B c - (u^3 - u). This is solved for c_t with the
        precomputed pseudo-inverse _A_inv.
        '''
        # Flatten so an (n, 1) forcing cannot broadcast against (n,) vectors into (n, n).
        f = self.forcing(self.domain.all_points, t).reshape(-1)
        u = self._A @ c
        diffusion = self._B @ c
        reaction = u ** 3 - u
        return (self._A_inv @ (f + diffusion - reaction)).reshape(-1)

    @staticmethod
    def _make_blowup_event(threshold=1e10):
        '''
        Build a terminal solve_ivp event that fires when max|c| reaches `threshold`,
        stopping the integration of a diverging coefficient trajectory.
        '''
        def blowup(t, y):
            return np.max(np.abs(y)) - threshold

        blowup.terminal = True
        return blowup

    def fit(self, t_span, rtol=1e-8, atol=1e-8, num_svd=None, time_blocks = 1):
        '''
        Solve the Allen-Cahn problem on t_span with a single ansatz.

        Parameters:
            t_span: (t0, t_end), passed to solve_ivp
            rtol, atol: tolerances passed to solve_ivp
            num_svd: None (full SVD) or number of singular triplets, see _init_matrices
            time_blocks: unused; kept for signature compatibility with fit_time_blocks

        Returns:
            self (prints whether the integration reached t_end)
        '''
        # set up the model for the ansatz function
        self.ansatz.init_model(self.domain, self.boundary_condition)
        self._init_matrices(num_components=num_svd)
        c_0 = self._get_c0().reshape(-1)

        solver = solve_ivp(fun=self._rhs, t_span=t_span, y0=c_0, dense_output=True,
                           method=self.ode_solver, rtol=rtol, atol=atol,
                           events=self._make_blowup_event())
        self._coefficients_c = solver.sol

        if solver.status == 0:
            print("Integration successful.")
        else:
            print("Integration failed or terminated due to exceeding the maximum absolute value.")
        return self

    def fit_time_blocks(self, t_span, rtol=1e-8, atol=1e-8, num_svd=None, time_blocks = 1):
        '''
        Solve the Allen-Cahn problem on `time_blocks` equal consecutive sub-intervals of
        t_span. The ansatz is re-initialized in every block, using the previous block's
        final solution as its initial condition.

        Parameters:
            t_span: (t0, t_end)
            rtol, atol: tolerances passed to solve_ivp
            num_svd: None (full SVD) or number of singular triplets, see _init_matrices
            time_blocks: number of blocks, >= 1

        Returns:
            (self, status): status is 0 if every block reached its end time, otherwise the
            solve_ivp status of the first block that did not (-1: solver failure,
            1: terminated by the blow-up event).

        Raises:
            ValueError: if time_blocks < 1.
        '''
        if time_blocks < 1:
            raise ValueError(f"time_blocks must be at least 1, got {time_blocks}")

        self.ansatz_collection = []
        self.c_collection = []
        self.svd_collecion = []

        t_start, t_end = t_span[0], t_span[-1]
        block_size = (t_end - t_start) / time_blocks
        # Use t_end itself as the last edge so rounding cannot shorten the final block.
        edges = [t_start + i * block_size for i in range(time_blocks)] + [t_end]
        self._t_block_edges = np.asarray(edges, dtype=float)

        blowup_event = self._make_blowup_event()
        status = 0
        initial_sol = None
        for i in range(time_blocks):
            if i == 0:
                self.ansatz.init_model(self.domain, self.boundary_condition)
            else:
                # Re-sample the ansatz with the previous block's final solution as target.
                self.ansatz.init_model(self.domain, self.boundary_condition, initial_condition=initial_sol)
            self.ansatz_collection.append(copy.deepcopy(self.ansatz))

            self._init_matrices(num_components=num_svd)
            # The SVD basis is needed to evaluate this block later (evaluate_blocks).
            self.svd_collecion.append(self._V_a)

            c_0 = self._get_c0(initial_sol=initial_sol).reshape(-1)
            t_block = [edges[i], edges[i + 1]]
            solver = solve_ivp(fun=self._rhs, t_span=t_block, y0=c_0, dense_output=True,
                               method=self.ode_solver, rtol=rtol, atol=atol,
                               events=blowup_event)
            self._coefficients_c = solver.sol
            self.c_collection.append(self._coefficients_c)

            if solver.status != 0 and status == 0:
                status = solver.status

            if i < time_blocks - 1:
                initial_sol = self.evaluate(self.domain.all_points, np.array([t_block[1]]))
        return self, status

    def evaluate_blocks(self, x_eval, t_eval, time_blocks = 1, solver_status=0):
        '''
        Evaluate the time-blocked solution produced by fit_time_blocks.

        Every time in t_eval is evaluated with the ansatz, SVD basis and coefficient
        interpolant of the block that contains it. Block i covers [edge_i, edge_{i+1});
        the last block also includes its right edge. Times outside [edge_0, edge_last]
        are dropped. The edges recorded by fit_time_blocks are used when they match
        `time_blocks`; otherwise [t_eval[0], t_eval[-1]] is split into equal blocks.

        Parameters:
            x_eval: (n_eval, d), n_eval is the number of points, d is the dimension
            t_eval: (n_t, ), sorted ascending
            time_blocks: number of blocks used in fit_time_blocks
            solver_status: unused; kept for backward compatibility

        Returns:
            sol: (n_eval, n_kept), columns in block order (i.e. t_eval order when sorted)
        '''
        t_eval = np.asarray(t_eval).reshape(-1)
        edges = getattr(self, "_t_block_edges", None)
        if edges is None or len(edges) != time_blocks + 1:
            block_size = (t_eval[-1] - t_eval[0]) / time_blocks
            edges = [t_eval[0] + i * block_size for i in range(time_blocks)] + [t_eval[-1]]

        block_solutions = []
        for i in range(time_blocks):
            in_block = t_eval >= edges[i]
            if i < time_blocks - 1:
                in_block &= t_eval < edges[i + 1]
            else:
                in_block &= t_eval <= edges[i + 1]
            if not np.any(in_block):
                # OdeSolution cannot be evaluated on an empty time array.
                continue

            coefficients = self.c_collection[i](t_eval[in_block])
            basis = self.ansatz_collection[i].evaluate_model(x_eval) @ self.svd_collecion[i].T
            block_solutions.append(basis @ coefficients)

        if not block_solutions:
            return np.empty((np.shape(x_eval)[0], 0))
        return np.hstack(block_solutions)



# NOTE(review): the three triple-quoted strings below are never executed; they only keep
# earlier drafts for reference: a `fit_resample` method and two loop fragments for the
# time-block approach. The drafted `c_t` still reads `self.c` and `self._C`, which
# AllenCahnSolver never defines. The fragments also index and append to
# `self._coefficients_c` as if it were a list, which AllenCahnSolver does not store.
# So none of this could run as written; delete it or port it before reuse.
"""

    def fit_resample(self, t_span, rtol=1e-8, atol=1e-8, num_svd=None, time_blocks = 1):
        '''
        Approximate the solution of the advection problem by choosing the model parameters and time-dependent coefficients accordingly
        '''
        ### CD: Loop over the time blocks
        # set up the model for the ansatz function
        self.ansatz.init_model(self.domain, self.boundary_condition) 

        # define the ODE to be solved for the time-dependent coefficients
        def c_t(t, c):
            f = self.forcing(self.domain.all_points,t)
            diff_term = (self.c * c @ self._C.T)
            non_lin_term = (c @ self._A.T) * (c @ self._B.T)
            rhs = (f + diff_term - non_lin_term).T
            c_t = self._A_inv @ rhs
            return c_t.reshape(-1)

        # compute the matrices needed in the ODE
        self._init_matrices(num_components=num_svd)

        # get the initial value for the ODE
        c_0 = self._get_c0().reshape(-1)
        
        def event_func(t, y):
            # Define the event function to trigger when the absolute value of the solution exceeds a particular value
            return max(np.abs(y)) - 1e10

        event_func.terminal = True

        # solve the ODE
        solver = solve_ivp(fun=c_t, t_span=t_span, y0=c_0, dense_output=True, method=self.ode_solver, rtol=rtol, atol=atol,events=event_func)
        self._coefficients_c = solver.sol 
        
        # Check if the integration was successful and the event was triggered
        if solver.status == 0:
            print("Integration successful.")
        else:
            print("Integration failed or terminated due to exceeding the maximum absolute value.")

        return self

"""

# Draft: initialize each block from the previous block's final coefficients.
"""
            if i == 0:
                # get the initial value for the ODE
                c_0 = self._get_c0().reshape(-1)
            else:
                # Initialize coefficients for the current time block as the final values of the previous time block
                c_0 = self._coefficients_c[i-1](t_block[0]).reshape(-1)
"""
# Draft: on a failed block, pad the remaining blocks with placeholder interpolants and stop.
"""
            if solver.status == 0:
                self._coefficients_c.append(solver.sol)
            else:
                print("Integration failed or terminated due to exceeding the maximum absolute value.")
                for j in range(i, time_blocks):
                    self._coefficients_c.append(lambda t : t * 100)
                break
"""