import numpy as np
import scipy
import scipy.integrate
import scipy.interpolate
import scipy.io
import scipy.linalg
import scipy.special


"""TODO: code needs improvement
    - don't put everything in solve()
    - find a better way to handle time stepping and evaluation; some solvers require a fixed time step
    - find a better way to store parameters
    """



class IGA:
    """Base class for 1-D isogeometric (B-spline Galerkin) solvers.

    Args:
        n: number of uniformly spaced nodes (element boundaries).
        k: B-spline degree.
    """

    def __init__(
            self,
            n: int,
            k: int
        ):
        self.n = n
        self.k = k

    def _get_knot_vector(self, n: int, k: int, x_lim) -> np.ndarray:
        """Return a clamped uniform knot vector of length ``n + 2k``.

        ``n`` nodes are spaced evenly on ``[x_lim[0], x_lim[1]]``, and each end
        value is repeated ``k`` extra times (``k + 1`` times in total). The
        vector therefore supports ``n + k - 1`` B-splines of degree ``k``.
        """
        knot_vector = np.zeros((n + 2 * k,))
        knot_vector[:k] = x_lim[0]
        knot_vector[k:n + k] = np.linspace(x_lim[0], x_lim[1], n)
        # Explicit end index rather than ``[-k:]``: for k == 0 that slice
        # would select the whole array instead of nothing.
        knot_vector[n + k:] = x_lim[1]

        return knot_vector
    

class BurgerDirichlet(IGA):
    """Viscous Burgers' equation on [-1, 1] with homogeneous Dirichlet BCs.

    Solves ``u_t + u u_x = nu u_xx`` with ``u(x, 0) = -sin(pi x)`` using a
    Galerkin method on degree-``k`` B-splines over ``n`` uniform nodes. The
    first and last splines are dropped, so the discrete solution vanishes at
    x = -1 and x = 1. Time integration uses ``scipy.integrate.solve_ivp``
    (LSODA).

    Args:
        n: number of uniform nodes.
        k: B-spline degree.
        nu: viscosity. The default is 0.01 / pi; presumably this matches the
            reference data loaded by ``compute_gt`` (confirm before changing).
    """

    def __init__(self, n, k, nu=0.01 / np.pi):
        self.xMin = -1
        self.xMax = 1
        self.x_eval = np.linspace(self.xMin, self.xMax, 256)
        self.Tmax = 1
        self.nTimesteps = 100
        # nTimesteps points covering [0, Tmax); Tmax itself is excluded.
        self.t_eval = np.linspace(0, self.Tmax, self.nTimesteps + 1)[:-1]
        self.nu = nu
        super().__init__(n=n, k=k)

    def compute_gt(self):
        """Load the reference solution ``usol`` from ./data/burgers_shock.mat.

        The path is relative to the current working directory.
        """
        gt = scipy.io.loadmat("./data/burgers_shock.mat")["usol"]
        return gt

    def solve(self):
        """Integrate the semi-discrete Burgers system and sample it.

        Returns:
            (n_dofs, sol): the number of interior spline coefficients, and the
            solution evaluated on ``self.x_eval`` x ``self.t_eval`` with shape
            ``(len(self.x_eval), number of returned time samples)``.
        """
        n = self.n
        k = self.k
        splines, splinesDer = self._get_splines_list_dirichlet(n, k, [self.xMin, self.xMax])
        Mg, Kg, A, B = self._assemble(splines, splinesDer)

        # The mass matrix is symmetric positive definite: factor it once
        # instead of forming its explicit inverse.
        mass_factor = scipy.linalg.cho_factor(Mg)

        # Least-squares fit of u0 = -sin(pi x) at reproducible random points.
        n_init_points = (n + k - 1) * 10
        rng = np.random.default_rng(seed=0)
        init_points = rng.uniform(self.xMin, self.xMax, size=(n_init_points,))
        Phi = self._basis_matrix(splines, init_points)
        u0 = -np.sin(init_points * np.pi)
        a0 = np.linalg.lstsq(Phi, u0, rcond=1e-8)[0]

        def rhs(t, a):
            # M a' = -K a - N(a), where
            # N_i = sum_q phi_i(x_q) * u(x_q) * (u_x(x_q) * w_q * J).
            nonlinear_term = A.T @ ((A @ a) * (B @ a))
            return scipy.linalg.cho_solve(mass_factor, -(Kg @ a) - nonlinear_term)

        a = scipy.integrate.solve_ivp(
            rhs,
            t_span=[0, self.Tmax],
            t_eval=self.t_eval,
            y0=a0,
            method="LSODA",
            rtol=1e-9,
            atol=1e-9,
        ).y

        sol = self._basis_matrix(splines, self.x_eval) @ a

        return Mg.shape[0], sol

    def _assemble(self, splines, splinesDer):
        """Assemble the Galerkin operators restricted to the interior splines.

        Returns:
            (Mg, Kg, A, B):
            - ``Mg[i, j] = int phi_i phi_j`` (mass matrix).
            - ``Kg[i, j] = nu * int phi_i' phi_j'`` (diffusion matrix).
            - ``A`` maps coefficients to u at every Gauss point.
            - ``B`` maps coefficients to u_x at every Gauss point, pre-multiplied
              by the quadrature weight and the element Jacobian.

            The two placeholder boundary splines are removed from all four.
        """
        n = self.n
        k = self.k
        nodes = np.linspace(self.xMin, self.xMax, n)
        nSplines = n + k - 1
        nSplinesActive = k + 1
        nQuad = k + 1
        quad_roots, quad_weights = scipy.special.roots_legendre(nQuad)

        Mg = np.zeros((nSplines, nSplines))
        Kg = np.zeros((nSplines, nSplines))
        A = np.zeros((nQuad * (n - 1), nSplines))
        B = np.zeros((nQuad * (n - 1), nSplines))

        for m in range(n - 1):
            # Map the Gauss-Legendre points from [-1, 1] onto element m.
            jac = (nodes[m + 1] - nodes[m]) / 2
            roots_x = jac * quad_roots + (nodes[m + 1] + nodes[m]) / 2
            # On element m, only splines m .. m + k are nonzero.
            active = slice(m, m + nSplinesActive)
            rows = slice(m * nQuad, (m + 1) * nQuad)
            phi = np.array([s(roots_x) for s in splines[active]])
            dphi = np.array([s(roots_x) for s in splinesDer[active]])

            Mg[active, active] += (phi * quad_weights) @ phi.T * jac
            Kg[active, active] += (dphi * quad_weights * self.nu) @ dphi.T * jac
            A[rows, active] = phi.T
            B[rows, active] = (dphi * quad_weights).T * jac

        # Drop the placeholder first/last splines (homogeneous Dirichlet BCs).
        interior = slice(1, -1)
        return Mg[interior, interior], Kg[interior, interior], A[:, interior], B[:, interior]

    @staticmethod
    def _basis_matrix(splines, x):
        """Evaluate the interior splines (all but first and last) at points ``x``.

        Returns an array of shape ``(len(x), len(splines) - 2)``.
        """
        Phi = np.zeros((len(x), len(splines) - 2))
        for col, spline in enumerate(splines[1:-1]):
            Phi[:, col] = spline(x)
        return Phi

    def _get_splines_list_dirichlet(self, n, k, x_lim):
        """Build the degree-k B-splines and their first derivatives.

        Both lists have ``n + k - 1`` entries. The first and last entries are
        zero placeholders, because those splines are removed to impose
        u = 0 at the ends.
        """
        n_splines = n + k - 1
        fake_spline = lambda x: x*0
        knot_vector = self._get_knot_vector(
            n,
            k,
            x_lim
        )
        coeffs = np.eye(n_splines)
        splines = []
        splines_der = []

        splines.append(fake_spline)
        splines_der.append(fake_spline)

        for i in range(1, n_splines-1):
            spline_current = scipy.interpolate.BSpline(knot_vector, coeffs[i,:], k)
            splines.append(spline_current)
            splines_der.append(spline_current.derivative(nu=1))

        splines.append(fake_spline)
        splines_der.append(fake_spline)

        return splines, splines_der
    
    

class EulerBernoulli(IGA):
    """Forced Euler-Bernoulli beam ``u_tt + u_xxxx = f`` on [0, pi].

    The forcing ``f = (1 - 16 pi^2) sin(x) cos(4 pi t)`` makes the exact
    solution ``sin(x) cos(4 pi t)`` (see ``compute_gt``). Space is
    discretized with degree-``k`` B-splines over ``n`` uniform nodes, with the
    first and last splines removed. Time is integrated with the Newmark-beta
    method on the fine grid ``t_solve`` and sub-sampled onto ``t_eval``.
    """

    def __init__(self, n: int, k: int):
        self.xMin = 0
        self.xMax = np.pi
        self.x_eval = np.linspace(self.xMin, self.xMax, 256)
        self.Tmax = 1
        self.nTimesteps_eval = 100
        self.t_eval = np.linspace(0, self.Tmax, self.nTimesteps_eval)
        # 100 Newmark steps per evaluation interval: dt = Tmax / 9900 (~1e-4),
        # the same spacing as t_solve below.
        self.dt = self.Tmax / (self.nTimesteps_eval-1) / 100
        self.nTimesteps_solve = int(1e4)-100+1
        self.t_solve = np.linspace(0, self.Tmax, self.nTimesteps_solve)
        super().__init__(n, k)

    def solve(self):
        """Assemble the beam system, integrate it with Newmark, and sample it.

        Returns:
            (n_dofs, sol): the number of interior spline coefficients, and the
            solution on ``self.x_eval`` sampled every ``stride`` solver steps.
            For the default settings this gives shape
            ``(len(self.x_eval), len(self.t_eval))``.
        """
        xMin = self.xMin
        xMax = self.xMax
        n = self.n
        k = self.k
        nodes = np.linspace(xMin, xMax, n)
        nSplines = n + k - 1
        nSplinesActive = k + 1
        # splinesDer holds SECOND derivatives (see _get_splines_list_dirichlet).
        splines, splinesDer = self._get_splines_list_dirichlet(n, k, [xMin, xMax])
        quad_roots, quad_weights = scipy.special.roots_legendre(k+1)

        # f(x, t) = (1 - 16 pi^2) sin(x) * cos(4 pi t); the time factor is
        # the same for every element, so compute it once.
        forcing_coeff = 1 - 16 * np.pi**2
        forcing_time = np.cos(4 * np.pi * self.t_solve)

        Mg = np.zeros((nSplines, nSplines))
        Kg = np.zeros((nSplines, nSplines))
        Fg = np.zeros((nSplines, len(self.t_solve)))

        for m in range(n-1):
            # Map the Gauss-Legendre points from [-1, 1] onto element m.
            jac = (nodes[m+1] - nodes[m]) / 2
            roots_x = jac * quad_roots + (nodes[m+1] + nodes[m]) / 2
            # On element m, only splines m .. m + k are nonzero.
            active = slice(m, m + nSplinesActive)
            phi = np.array([s(roots_x) for s in splines[active]])
            d2phi = np.array([s(roots_x) for s in splinesDer[active]])
            weighted_phi = phi * quad_weights

            # Volume forcing at the quadrature points for every t in t_solve.
            forcing = np.outer(forcing_coeff * np.sin(roots_x), forcing_time)
            Fg[active, :] += weighted_phi @ forcing * jac
            # Mass matrix: int phi_i phi_j.
            Mg[active, active] += weighted_phi @ phi.T * jac
            # Stiffness matrix: int phi_i'' phi_j''.
            Kg[active, active] += (d2phi * quad_weights) @ d2phi.T * jac

        # Drop the placeholder first/last splines. The remaining interior
        # splines of the clamped knot vector vanish at both ends.
        Mg = Mg[1:-1, 1:-1]
        Kg = Kg[1:-1, 1:-1]
        Fg = Fg[1:-1, :]

        n_init_points = (n+k-1)*5
        rng = np.random.default_rng(seed=0)
        init_points = rng.uniform(xMin, xMax, size=(n_init_points,))
        Phi = np.column_stack([s(init_points) for s in splines[1:-1]])

        # Rows requesting u_xx = 0 at both boundary points.
        boundary_points = np.array([xMin, xMax])
        normal_vector = np.array([-1, 1])
        Phi_bc = np.column_stack(
            [s(boundary_points) * normal_vector for s in splinesDer[1:-1]]
        )

        # Fit u0 = sin(x) at the sample points in the least-squares sense,
        # together with the boundary rows and (Kg - Mg) a = 0. The exact
        # profile sin(x) equals its own fourth derivative.
        u0 = np.sin(init_points)
        lhs = np.vstack([Phi, Phi_bc, Kg - Mg])
        rhs = np.concatenate([u0, np.zeros(boundary_points.shape[0] + Mg.shape[0])])
        a0 = np.linalg.lstsq(lhs, rhs, rcond=1e-8)[0]

        a, _, _ = self._newmark(
            M=Mg,
            C=np.zeros_like(Mg),
            K=Kg,
            F=Fg,
            u0=a0,
            ut0=np.zeros_like(a0),
            nt=self.nTimesteps_solve,
            dt=self.dt
        )

        # Keep every `stride`-th solver step so samples line up with t_eval
        # (stride = 100 for the default settings).
        stride = (self.nTimesteps_solve - 1) // (self.nTimesteps_eval - 1)
        Phi_eval = np.column_stack([s(self.x_eval) for s in splines[1:-1]])
        sol = Phi_eval @ a[:, ::stride]

        return Mg.shape[0], sol

    def compute_gt(self):
        """Exact solution sin(x) cos(4 pi t) on x_eval x t_eval."""
        return np.outer(np.sin(self.x_eval), np.cos(4 * np.pi * self.t_eval))

    def _newmark(
            self,
            M,
            C,
            K,
            F,
            u0,
            ut0,
            nt,
            dt,
            gaama=1/2,
            beta=1/4
        ):
        """Integrate ``M u'' + C u' + K u = F(t)`` with the Newmark-beta method.

        Args:
            M, C, K: (n, n) mass, damping, and stiffness matrices.
            F: load history; column ``i`` is the load at step ``i``. It needs at
                least ``nt`` columns.
            u0, ut0: initial displacement and velocity, each of length n.
            nt: number of time levels, including the initial one.
            dt: time step.
            gaama: Newmark gamma parameter (name kept for compatibility).
            beta: Newmark beta parameter. The defaults (1/2, 1/4) give the
                average-acceleration scheme.

        Returns:
            (depl, vel, accl): arrays of shape (n, nt). Column ``i``
            corresponds to time ``i * dt``.
        """
        n = M.shape[0]

        a0 = 1 / (beta * dt**2)
        a1 = gaama / (beta * dt)
        a2 = 1 / (beta * dt)
        a3 = 1 / (2 * beta) - 1
        a4 = gaama / beta - 1
        a5 = (dt / 2) * (gaama / beta - 2)
        a6 = dt * (1 - gaama)
        a7 = gaama * dt

        # Exactly nt time levels. An extra column here would never be filled
        # and would stay as a spurious all-zero final state.
        depl = np.zeros((n, nt))
        vel = np.zeros((n, nt))
        accl = np.zeros((n, nt))

        depl[:, 0] = u0
        vel[:, 0] = ut0
        # Initial acceleration from the equation of motion at t = 0.
        accl[:, 0] = np.linalg.solve(M, F[:, 0] - C @ vel[:, 0] - K @ depl[:, 0])

        # The effective stiffness is constant, so factor it once.
        lu_piv = scipy.linalg.lu_factor(K + a0 * M + a1 * C)

        a = a1 * C + a0 * M
        b = a4 * C + a2 * M
        c = a5 * C + a3 * M

        for i in range(1, nt):
            Fcap = F[:, i] + a @ depl[:, i-1] + b @ vel[:, i-1] + c @ accl[:, i-1]
            depl[:, i] = scipy.linalg.lu_solve(lu_piv, Fcap)
            accl[:, i] = a0 * (depl[:, i] - depl[:, i-1]) - a2 * vel[:, i-1] - a3 * accl[:, i-1]
            vel[:, i] = vel[:, i-1] + a6 * accl[:, i-1] + a7 * accl[:, i]

        return depl, vel, accl

    def _get_splines_list_dirichlet(self, n, k, x_lim):
        """Build the degree-k B-splines and their SECOND derivatives.

        Both lists have ``n + k - 1`` entries. The first and last entries are
        zero placeholders, because those splines are removed to impose u = 0
        at the ends.
        """
        n_splines = n + k - 1
        fake_spline = lambda x: x*0
        knot_vector = self._get_knot_vector(
            n,
            k,
            x_lim
        )
        coeffs = np.eye(n_splines)
        splines = []
        splines_der = []

        splines.append(fake_spline)
        splines_der.append(fake_spline)

        for i in range(1, n_splines-1):
            spline_current = scipy.interpolate.BSpline(knot_vector, coeffs[i,:], k)
            splines.append(spline_current)
            splines_der.append(spline_current.derivative(nu=2))

        splines.append(fake_spline)
        splines_der.append(fake_spline)

        return splines, splines_der
    

class AdvectionPeriodic(IGA):
    """Linear advection ``u_t + beta u_x = 0`` on [0, 2 pi], periodic BCs.

    The initial condition is sin(x), so the exact solution is
    sin(x - beta t) (see ``compute_gt``). Periodicity is imposed by giving
    B-splines that are one-period translates of each other a shared
    coefficient.
    """

    def __init__(self, beta, n, k):
        self.beta = beta
        self.xMin = 0
        self.xMax = 2 * np.pi
        self.x_eval = np.linspace(self.xMin, self.xMax, 256)
        self.Tmax = 1
        self.nTimesteps = 100
        self.t_eval = np.linspace(0, self.Tmax, self.nTimesteps)
        super().__init__(n=n, k=k)

    def compute_gt(self):
        """Exact solution sin(x - beta t) on x_eval x t_eval."""
        return np.sin(self.x_eval[:, None] - self.beta * self.t_eval[None, :])

    def solve(self):
        """Integrate the semi-discrete periodic advection system and sample it.

        Returns:
            (n_dofs, sol): the number of independent periodic coefficients
            (n - 1), and the solution on x_eval x t_eval with shape
            ``(len(self.x_eval), number of returned time samples)``.
        """
        n = self.n
        k = self.k
        beta = self.beta
        nodes = np.linspace(self.xMin, self.xMax, n)
        nSplines = n + k - 1
        nSplinesActive = k + 1
        n_dofs = nSplines - k  # = n - 1 elements = one period
        splines, splinesDer = self._get_splines_list_periodic(n, k, [self.xMin, self.xMax])
        quad_roots, quad_weights = scipy.special.roots_legendre(k+1)

        # Splines j and j + n_dofs are uniform translates by one period, so
        # spline j carries reduced coefficient (j - k) mod n_dofs.
        # P (nSplines x n_dofs) expands reduced coefficients to all splines.
        spline_ids = np.arange(nSplines)
        P = np.zeros((nSplines, n_dofs))
        P[spline_ids, (spline_ids - k) % n_dofs] = 1.0

        M_full = np.zeros((nSplines, nSplines))
        K_full = np.zeros((nSplines, nSplines))
        for m in range(n-1):
            # Map the Gauss-Legendre points from [-1, 1] onto element m.
            jac = (nodes[m+1] - nodes[m]) / 2
            roots_x = jac * quad_roots + (nodes[m+1] + nodes[m]) / 2
            # On element m, only list splines m .. m + k are nonzero.
            active = slice(m, m + nSplinesActive)
            phi = np.array([s(roots_x) for s in splines[active]])
            dphi = np.array([s(roots_x) for s in splinesDer[active]])
            M_full[active, active] += (phi * quad_weights) @ phi.T * jac
            # Advection matrix: beta * int phi_i phi_j'.
            K_full[active, active] += (phi * quad_weights * beta) @ dphi.T * jac

        # Sum rows/columns of identified splines to get the periodic system.
        Mg = P.T @ M_full @ P
        Kg = P.T @ K_full @ P
        # Solve with the mass matrix instead of forming its inverse.
        MinvK = np.linalg.solve(Mg, Kg)

        n_init_points = (n+k-1)*5
        rng = np.random.default_rng(seed=0)
        init_points = rng.uniform(self.xMin, self.xMax, size=(n_init_points,))
        Phi = np.column_stack([s(init_points) for s in splines])
        u0 = np.sin(init_points)

        # Least-squares fit in the reduced periodic basis. Only COLUMNS are
        # folded (periodic spline copies). Every row is an independent sample
        # point and is kept as-is.
        a0 = np.linalg.lstsq(Phi @ P, u0, rcond=1e-8)[0]

        def rhs(t, a):
            return -(MinvK @ a)

        a_reduce = scipy.integrate.solve_ivp(
            rhs,
            t_span=[0, self.Tmax],
            t_eval=self.t_eval,
            y0=a0,
            method="LSODA",
            rtol=1e-12,
            atol=1e-12,
        ).y
        # Expand back to one coefficient per spline.
        a = P @ a_reduce

        Phi_eval = np.column_stack([s(self.x_eval) for s in splines])
        sol = Phi_eval @ a

        return Mg.shape[0], sol

    def _get_splines_list_periodic(self, n, k, x_lim):
        """Build the uniform degree-k B-splines overlapping the domain.

        The knot vector is extended by k elements on each side of ``x_lim``.
        The returned ``n + k - 1`` splines (and their first derivatives) use
        only the uniformly spaced knots, so they are translates of one another.
        """
        n_splines = n + k - 1
        knot_vector = self._get_knot_vector(
            n+2*k,
            k,
            [x_lim[0]-k*(x_lim[1]-x_lim[0])/(n-1), x_lim[1]+k*(x_lim[1]-x_lim[0])/(n-1)]
        )
        coeffs = np.eye(n_splines+2*k)
        splines = []
        splines_der = []
        for i in range(k, n_splines+k):
            spline_current = scipy.interpolate.BSpline(knot_vector, coeffs[i,:], k)
            splines.append(spline_current)
            splines_der.append(spline_current.derivative(nu=1))

        return splines, splines_der



    

