import time
import itertools
import numpy as np
import jax
import jax.numpy as jnp
from scipy.sparse import diags
import logging
from tqdm import tqdm
# Turn on JAX 64-bit mode at import time so that the jnp.float64 dtype used by
# LegendreSystemJax (self.dtype) is honoured; JAX defaults to 32-bit types.
jax.config.update("jax_enable_x64", True)




class LegendreSystemJax:
    """Legendre-Galerkin spectral system on Legendre-Gauss-Lobatto (LGL) nodes.

    For polynomial degree ``N`` this builds:

    * ``xx`` (N+1, 1): LGL nodes on [-1, 1]; ``ww`` (1, N+1): LGL weights;
    * ``lepolys`` (N+1, N+1, 1): Legendre polynomials L_0..L_N at the nodes;
    * ``phi`` (N-1, N+1): boundary-adapted basis
      ``phi_k = L_k + a_k L_{k+1} + b_k L_{k+2}`` sampled at the nodes;
    * ``S_1D``, ``R_1D``, ``M_1D`` (N-1, N-1): 1D Galerkin matrices
      (``M_1D`` is also exposed as ``Gram_matrix``);
    * dimension-specific pure functions (forward transform, reconstruction,
      LGL quadrature, relative errors) and the system matrix ``AA``.

    All arrays use ``jnp.float64``, so JAX x64 mode must be enabled.

    Args:
        args: configuration object. Attributes read: ``N``,
            ``boundary_condition`` ('Dirichlet' | 'Neumann'),
            ``dimension`` ('1D' | '2D' | '3D'), ``equation`` and
            ``pde_parameter``.

    Raises:
        NotImplementedError: for an unknown dimension, boundary condition or
            equation.
    """

    def __init__(self, args):
        self.args = args
        self.N = args.N

        BOUNDARY_CONDITION = args.boundary_condition

        self.dtype = jnp.float64
        self.xx = self.LGLpoints()  # LGL nodes, shape (N+1, 1)
        self.lepolys = self.LegendrePolys()  # L_0..L_N at nodes, (N+1, N+1, 1)
        self.ww = self.LGLweights()  # LGL weights, shape (1, N+1)
        self.ak, self.bk = self.ak_bk(self.N, BOUNDARY_CONDITION)
        self.phi = self.basis_1D()  # basis at the nodes, shape (N-1, N+1)

        self.S_1D = self.S_Legendre(self.N, self.bk)
        self.R_1D = self.R_Legendre(self.N, self.ak, self.bk)
        self.M_1D = self.M_Legendre(self.N, self.ak, self.bk)

        self.Gram_matrix = self.M_1D

        if args.dimension == '1D':
            print('LegendreSystemJax - 1D')
            # Pure functions
            self.ForwardTransform = self.make_forward_transform_1d()
            self.BackwardTransform = self.make_backward_transform_1d()
            self.reconstruct = self.make_reconstruct_1d()
            self.LGL_quadrature = self.make_LGL_quadrature_1d()
            self.relative_l2 = self.make_relative_l2_1d()
            self.relative_linf = self.make_relative_linf_1d()
        elif args.dimension == '2D':
            print('LegendreSystemJax - 2D')
            self.X_2D, self.Y_2D = self.LGLpoints_2D()
            self.ww_2D = self.LGLweights_2D()
            self.phi_2D = self.basis_2D()

            # Tensor-product (Kronecker) system matrices, ((N-1)^2, (N-1)^2).
            self.MM_2D = self.AotimesB(self.N, self.M_1D, self.M_1D)
            self.SM_2D = self.AotimesB(self.N, self.S_1D, self.M_1D)
            self.MS_2D = self.AotimesB(self.N, self.M_1D, self.S_1D)
            self.RM_2D = self.AotimesB(self.N, self.R_1D, self.M_1D)
            self.MR_2D = self.AotimesB(self.N, self.M_1D, self.R_1D)

            self.ForwardTransform = self.make_forward_transform_2d()
            self.reconstruct = self.make_reconstruct_2d()
            self.LGL_quadrature = self.make_LGL_quadrature_2d()
            self.relative_l2 = self.make_relative_l2_2d()
            self.relative_linf = self.make_relative_linf_2d()
        elif args.dimension == '3D':
            self.X_3D, self.Y_3D, self.Z_3D = self.LGLpoints_3D()
            # basis_3D is built on top of phi_2D, so that must exist first.
            self.phi_2D = self.basis_2D()
            self.phi_3D = self.basis_3D()
        else:
            raise NotImplementedError(
                f'Unsupported dimension: {args.dimension!r}')
        self.AA = self.A_matrix(args)

    def A_matrix(self, args):
        """Assemble the Galerkin matrix of the PDE named by ``args.equation``.

        ``args.pde_parameter`` is the scalar PDE coefficient (eps or k).
        1D equations combine ``S_1D``/``R_1D``/``M_1D``. 2D equations combine
        the Kronecker-product matrices that ``__init__`` builds only for
        ``dimension == '2D'``.

        Returns:
            jnp.ndarray: (N-1, N-1) for 1D equations, ((N-1)^2, (N-1)^2) for 2D.

        Raises:
            NotImplementedError: if ``args.equation`` is not supported.
        """
        eq = args.equation
        p = args.pde_parameter
        if eq == 'RD_1D':
            # - eps * u_xx + u = f
            AA = - p * self.S_1D + self.M_1D
        elif eq == 'CD_1D':
            # - eps * u_xx - u_x = f
            AA = - p * self.S_1D - self.R_1D
        elif eq == 'Helmholtz_1D':
            # u_xx + k * u = f
            AA = self.S_1D + p * self.M_1D
        elif eq == 'RD_2D':
            # - eps * (u_xx + u_yy) + u = f
            AA = - p * (self.SM_2D + self.MS_2D) + self.MM_2D
        elif eq == 'Helmholtz_2D':
            # (u_xx + u_yy) + k * u = f
            AA = (self.SM_2D + self.MS_2D) + p * self.MM_2D
        elif eq == 'CD_2D':
            # - eps * (u_xx + u_yy) - (u_x + u_y) = f
            AA = - p * (self.SM_2D + self.MS_2D) - (self.RM_2D + self.MR_2D)
        else:
            raise NotImplementedError(f'Unsupported equation: {eq!r}')
        logging.info('PDE : %s with Parameter : %s', eq, p)
        return AA

    def LGLpoints(self):
        """Return the N+1 LGL nodes on [-1, 1] as an ascending (N+1, 1) column.

        The interior nodes are the eigenvalues of the symmetric tridiagonal
        matrix with off-diagonal ``sqrt(j (j+2) / ((2j+1)(2j+3)))``,
        j = 1..N-2 (Golub-Welsch style); the endpoints -1 and 1 are appended.
        Returns an empty (1, 0) array when N <= 0.
        """
        NN = self.N
        dtype = self.dtype
        if NN <= 0:
            return jnp.array([[]])
        if NN == 1:
            return jnp.array([[-1], [1]], dtype=dtype)
        if NN == 2:
            return jnp.array([[-1], [0], [1]], dtype=dtype)

        j = np.arange(1, NN - 1)
        off_diag = np.sqrt(j * (j + 2) / ((2 * j + 1) * (2 * j + 3)))
        jacobi = np.diag(off_diag, k=-1) + np.diag(off_diag, k=1)
        # Symmetric matrix -> real eigenvalues, already in ascending order.
        interior = np.linalg.eigvalsh(jacobi)
        nodes = np.concatenate([[-1.0], interior, [1.0]])
        return jnp.asarray(nodes[:, None], dtype=dtype)

    def LGLpoints_2D(self):
        """Return the (N+1, N+1) tensor grid of LGL nodes (``indexing="ij"``)."""
        NN = self.N
        xx = jnp.reshape(self.xx, (NN + 1,))
        X, Y = jnp.meshgrid(xx, xx, indexing="ij")  # same layout as numpy 'ij'
        return X, Y

    def LGLpoints_3D(self):
        """Return the (N+1, N+1, N+1) tensor grid of LGL nodes (``indexing="ij"``)."""
        NN = self.N
        # NOTE: jax.Array.view reinterprets the dtype; reshape is what is needed.
        xx = jnp.reshape(self.xx, (NN + 1,))
        X, Y, Z = jnp.meshgrid(xx, xx, xx, indexing="ij")
        return X, Y, Z

    def LGLweights(self):
        """Return LGL weights ``2 / (N (N+1) L_N(x_j)^2)`` with shape (1, N+1)."""
        NN = self.N
        denom = jnp.square(self.lepolys[NN])
        ww = 2 / (NN * (NN + 1)) / denom
        return jnp.reshape(ww, (1, NN + 1))

    def LGLweights_2D(self):
        """Return the (N+1, N+1) tensor-product weights ``w_i * w_j``."""
        NN = self.N
        w_x = self.ww.reshape(NN + 1, 1)
        w_y = self.ww.reshape(1, NN + 1)
        return (w_x * w_y).reshape(NN + 1, NN + 1)

    def make_LGL_quadrature_1d(self):
        """Build a jitted 1D LGL quadrature.

        The returned function reshapes its input to (B, 1, N+1) and returns
        the weighted sums with shape (B, 1, 1).
        """
        NN = self.N
        ww = self.ww.reshape(1, 1, NN + 1)

        @jax.jit
        def LGL_quadrature(ff):
            ff = ff.reshape(-1, 1, NN + 1)
            batchsize = ff.shape[0]
            total = jnp.sum(ff * ww, axis=2)
            return total.reshape(batchsize, 1, 1)
        return LGL_quadrature

    def make_LGL_quadrature_2d(self):
        """Build a jitted 2D LGL quadrature.

        The returned function reshapes its input to (B, N+1, N+1) and returns
        the weighted sums with shape (B,).
        """
        NN = self.N
        ww = self.ww_2D.reshape(1, NN + 1, NN + 1)

        @jax.jit
        def LGL_quadrature_2D(ff):
            ff = ff.reshape(-1, NN + 1, NN + 1)
            return jnp.sum(ff * ww, axis=(-2, -1))
        return LGL_quadrature_2D

    def LegendrePoly(self, n: int, x):
        """Evaluate L_n at ``x`` with the three-term (Bonnet) recurrence.

        Returns an array with the shape of ``x``, cast to ``self.dtype``.
        """
        dtype = self.dtype
        if n == 0:
            return jnp.ones_like(x).astype(dtype)
        if n == 1:
            return x.astype(dtype)
        polylst = jnp.ones_like(x)
        poly = x
        for k in range(2, n + 1):
            polyn = ((2 * k - 1) * x * poly - (k - 1) * polylst) / k
            polylst, poly = poly, polyn
        return polyn.astype(dtype)

    def LegendrePolys(self):
        """Stack L_0..L_N at ``self.xx``; shape (N+1, N+1, 1) for (N+1, 1) nodes."""
        polys = [self.LegendrePoly(i, self.xx) for i in range(self.N + 1)]
        return jnp.stack(polys)

    def make_reconstruct_1d(self):
        """Build a function mapping coefficients to nodal values (1D).

        The returned function reshapes ``alphas`` to (B, 1, N-1) and returns
        ``coeff @ phi`` with shape (B, 1, N+1) in ``self.dtype``.
        """
        NN = self.N
        dtype = self.dtype
        phi = self.phi  # (N-1, N+1)

        def reconstruct_1D(alphas):
            coeff = alphas.reshape(-1, 1, NN - 1)  # (B, 1, N-1)
            result = jnp.matmul(coeff, phi)  # (B, 1, N+1)
            return result.astype(dtype)
        return reconstruct_1D

    def make_reconstruct_2d(self):
        """Build a function mapping coefficients to nodal values (2D).

        Input reshaped to (B, N-1, N-1); output shape (B, N+1, N+1).
        """
        NN = self.N
        phi_2D = self.phi_2D  # (N-1, N-1, N+1, N+1)

        def reconstruct_2d(alphas):
            # (B, N-1, N-1) -> (B, N-1, N-1, 1, 1) so it broadcasts with phi_2D.
            coeff = jnp.reshape(alphas, (-1, NN - 1, NN - 1, 1, 1))
            # (B, N-1, N-1, N+1, N+1), then sum over the two basis axes.
            result = phi_2D * coeff
            return jnp.sum(result, axis=(1, 2))  # (B, N+1, N+1)
        return reconstruct_2d

    def ak_bk(self, N, BOUNDARY_CONDITION):
        """Return the basis coefficients (a_k, b_k), each of shape (N-1,).

        Args:
            N: polynomial degree.
            BOUNDARY_CONDITION: 'Dirichlet' (a_k = 0, b_k = -1) or
                'Neumann' (a_k = 0, b_k = -k(k+1) / ((k+2)(k+3))).

        Raises:
            NotImplementedError: for any other boundary condition.
        """
        if BOUNDARY_CONDITION == 'Dirichlet':
            return jnp.zeros((N - 1,)), -jnp.ones((N - 1,))
        if BOUNDARY_CONDITION == 'Neumann':
            kk = jnp.arange(N - 1)
            ak = jnp.zeros((N - 1,))
            bk = - kk * (kk + 1) / ((kk + 2) * (kk + 3))
            return ak, bk
        raise NotImplementedError(
            f'Unsupported boundary condition: {BOUNDARY_CONDITION!r}')

    def basis_1D(self):
        """Sample ``phi_k = L_k + a_k L_{k+1} + b_k L_{k+2}`` at the LGL nodes.

        Returns:
            jnp.ndarray: shape (N-1, N+1) — row k is phi_k, column j is node j.
        """
        N = self.N
        lep = jnp.reshape(self.lepolys, (N + 1, N + 1))  # (degree, node)
        ak = self.ak[:, None]
        bk = self.bk[:, None]
        phi = lep[:N - 1] + ak * lep[1:N] + bk * lep[2:N + 1]
        return phi.astype(self.dtype)

    def basis_2D(self):
        """Tensor-product basis ``phi_2D[i, k, j, l] = phi[i, j] * phi[k, l]``."""
        NN = self.N
        phi_2D = jnp.einsum("ij,kl->ikjl", self.phi, self.phi)
        return phi_2D.reshape(NN - 1, NN - 1, NN + 1, NN + 1)

    def basis_3D(self):
        """Tensor-product 3D basis of shape (N-1,)*3 + (N+1,)*3 (needs ``phi_2D``)."""
        NN = self.N
        phi_3D = jnp.einsum("ij,kl->ikjl",
                            self.phi_2D.reshape((NN - 1) ** 2, (NN + 1) ** 2),
                            self.phi)
        return phi_3D.reshape(NN - 1, NN - 1, NN - 1, NN + 1, NN + 1, NN + 1)

    def S_Legendre(self, NN, bk):
        """Diagonal matrix with entries ``(4k + 6) b_k``; also stores ``self.s_diag``.

        For the Dirichlet basis (b_k = -1) this equals (phi_k'', phi_k).
        """
        kk = jnp.arange(0, NN - 1, dtype=self.dtype)
        s_diag = (4 * kk + 6) * bk
        self.s_diag = s_diag
        return s_diag * jnp.eye(NN - 1)

    def R_Legendre(self, NN, ak, bk):
        """Dense (N-1, N-1) first-derivative Galerkin matrix from (a_k, b_k).

        For the Dirichlet basis the entries agree with R[k, j] = (phi_j', phi_k).
        """
        dtype = self.dtype
        # Move coefficients to host once: each element access on a JAX array
        # is a separately dispatched op, which dominates this O(N^2) loop.
        ak = np.asarray(ak, dtype=np.float64)
        bk = np.asarray(bk, dtype=np.float64)
        RR = np.zeros((NN - 1, NN - 1), dtype=dtype)
        for j in range(NN - 1):
            for k in range(NN - 1):
                if k == j + 1:
                    RR[k, j] = 2 * bk[j]
                elif k == j:
                    RR[k, j] = 2 * ak[j] + 2 * ak[k] * bk[j]
                elif k + 1 == j:
                    RR[k, j] = 2 * (1 + ak[k] * ak[j] + bk[j] + bk[k] * bk[j])
                elif (k + 1 < j) and (k + j) % 2 == 0:
                    RR[k, j] = 2 * ak[k] * (1 + bk[j]) + 2 * ak[j] * (1 + bk[k])
                elif (k + 1 < j) and (k + j) % 2 == 1:
                    RR[k, j] = 2 * (1 + ak[k] * ak[j] + bk[k] + bk[j] + bk[k] * bk[j])
        return jnp.array(RR, dtype=dtype)

    def M_Legendre(self, NN, ak, bk):
        """Symmetric pentadiagonal mass matrix M[k, j] = (phi_j, phi_k), (N-1, N-1).

        Uses the Legendre orthogonality relation (L_n, L_n) = 2 / (2n + 1).
        """
        dtype = self.dtype
        ak = np.asarray(ak, dtype=np.float64)  # host copies, see R_Legendre
        bk = np.asarray(bk, dtype=np.float64)
        M = np.zeros((NN - 1, NN - 1), dtype=dtype)
        for k in range(NN - 1):
            M[k, k] = (2 / (2 * k + 1) + ak[k] ** 2 * 2 / (2 * k + 3)
                       + bk[k] ** 2 * 2 / (2 * k + 5))
            if k + 1 < NN - 1:
                M[k, k + 1] = ak[k] * 2 / (2 * k + 3) + ak[k + 1] * bk[k] * 2 / (2 * k + 5)
                M[k + 1, k] = M[k, k + 1]
            if k + 2 < NN - 1:
                M[k, k + 2] = bk[k] * 2 / (2 * k + 5)
                M[k + 2, k] = M[k, k + 2]
        return jnp.array(M, dtype=dtype)

    # 2D
    def AotimesB(self, NN, A, B):
        """Kronecker product A ⊗ B of two (N-1, N-1) matrices -> ((N-1)^2, (N-1)^2)."""
        outer = jnp.einsum("ij,kl->ikjl", A, B)  # (N-1, N-1, N-1, N-1)
        return outer.reshape((NN - 1) ** 2, (NN - 1) ** 2)

    ##### Weak Form #####
    def make_forward_transform_1d(self):
        """Build a jitted map from nodal values to Galerkin load vectors (1D).

        ``ff`` is reshaped to (B, N+1); returns ``sum_j f_j w_j phi_k(x_j)``
        with shape (B, N-1, 1).
        """
        NN = self.N
        ww = self.ww.reshape(1, NN + 1)
        phi_t = self.phi.T  # (N+1, N-1)

        @jax.jit
        def ForwardTransform_1D(ff):
            ff = ff.reshape(-1, NN + 1)
            ss = (ff * ww) @ phi_t  # (B, N-1)
            return ss.reshape(-1, NN - 1, 1)
        return ForwardTransform_1D

    def make_forward_transform_2d(self):
        """Build a jitted map from nodal values to Galerkin load vectors (2D).

        ``ff`` is reshaped to (B, 1, 1, N+1, N+1); the weighted projection onto
        ``phi_2D`` has shape (B, N-1, N-1) before the final reshape.
        """
        NN = self.N
        ww_2D = self.ww_2D.reshape(1, 1, 1, NN + 1, NN + 1)
        phi_2D = self.phi_2D.reshape((1, NN - 1, NN - 1, NN + 1, NN + 1))

        @jax.jit
        def ForwardTransform_2D(ff):
            ff = ff.reshape(-1, 1, 1, NN + 1, NN + 1)
            sum_result = jnp.sum((ff * ww_2D) * phi_2D, axis=(-2, -1))
            # NOTE(review): this yields (B*(N-1), N-1, 1), folding the batch
            # axis into a coefficient axis; (B, (N-1)**2, 1) would match AA.
            # Kept unchanged for caller compatibility (same memory layout).
            return sum_result.reshape(-1, NN - 1, 1)
        return ForwardTransform_2D

    def make_backward_transform_1d(self):
        """Build a jitted map from load vectors back to nodal values (1D).

        Solves ``Gram_matrix @ c = bar_f`` per batch item (Gram_matrix is
        (N-1, N-1)) and evaluates ``sum_k c_k phi_k`` at the nodes.
        Input (B, N-1, 1); output (B, N+1, 1).
        """
        phi = self.phi
        Gram_matrix = self.Gram_matrix

        @jax.jit
        def BackwardTransform_1D(bar_f):
            inv_coeff = jax.vmap(
                lambda b: jnp.linalg.solve(Gram_matrix, b.squeeze(-1)))(bar_f)  # (B, N-1)
            f_interpol = jnp.einsum("ij,bj->bi", phi.T, inv_coeff)  # (B, N+1)
            return f_interpol[..., None]
        return BackwardTransform_1D

    ##### Loss #####
    def make_relative_l2_1d(self):
        """Build a jitted relative L2 error (LGL quadrature), output shape (B, 1).

        Inputs are reshaped to (B, N+1), so unbatched (N+1,) inputs also work.
        """
        NN = self.N
        LGL_quadrature = self.make_LGL_quadrature_1d()

        @jax.jit
        def Relative_L2_1D(u_pred, u_true):
            u_pred = u_pred.reshape(-1, NN + 1)
            u_true = u_true.reshape(-1, NN + 1)
            q1 = LGL_quadrature((u_pred - u_true) ** 2)
            q2 = LGL_quadrature(u_true ** 2)
            return jnp.sqrt(q1 / q2).reshape(-1, 1)
        return Relative_L2_1D

    def make_relative_l2_2d(self):
        """Build a jitted relative L2 error on the 2D grid, output shape (B, 1)."""
        NN = self.N
        LGL_quadrature = self.make_LGL_quadrature_2d()

        @jax.jit
        def Relative_L2_2D(u_pred, u_true):
            u_pred = u_pred.reshape(-1, NN + 1, NN + 1)
            u_true = u_true.reshape(-1, NN + 1, NN + 1)
            q1 = LGL_quadrature((u_pred - u_true) ** 2)
            q2 = LGL_quadrature(u_true ** 2)
            return jnp.sqrt(q1 / q2).reshape(-1, 1)
        return Relative_L2_2D

    def make_relative_linf_1d(self):
        """Build a jitted relative max-norm error over the nodes, shape (B, 1)."""
        NN = self.N

        @jax.jit
        def Relative_Linf_1D(u_pred, u_true):
            u_pred = u_pred.reshape(-1, NN + 1)
            u_true = u_true.reshape(-1, NN + 1)
            err = jnp.amax(jnp.abs(u_true - u_pred), axis=-1)
            ref = jnp.amax(jnp.abs(u_true), axis=-1)
            return (err / ref).reshape(-1, 1)
        return Relative_Linf_1D

    def make_relative_linf_2d(self):
        """Build a jitted relative max-norm error on the 2D grid, shape (B, 1)."""
        NN = self.N

        @jax.jit
        def Relative_Linf_2D(u_pred, u_true):
            u_pred = u_pred.reshape(-1, NN + 1, NN + 1)
            u_true = u_true.reshape(-1, NN + 1, NN + 1)
            err = jnp.amax(jnp.abs(u_true - u_pred), axis=(-2, -1))
            ref = jnp.amax(jnp.abs(u_true), axis=(-2, -1))
            return (err / ref).reshape(-1, 1)
        return Relative_Linf_2D

    def total_variation_2d(self, f: jnp.ndarray) -> jnp.ndarray:
        """Anisotropic total variation over interior grid points only.

        Args:
            f (jnp.ndarray): reshaped to (batch, N+1, N+1).
        Returns:
            jnp.ndarray: shape (batch,), sum of absolute neighbour differences
            along both grid axes, boundary rows/columns excluded.
        """
        NN = self.N
        f = f.reshape(-1, NN + 1, NN + 1)

        # Interior only: indices 1..N-1 along both axes.
        f_internal = f[:, 1:-1, 1:-1]

        # Differences along the last axis.
        tv_x = jnp.sum(jnp.abs(f_internal[:, :, 1:] - f_internal[:, :, :-1]), axis=(1, 2))
        # Differences along the middle axis.
        tv_y = jnp.sum(jnp.abs(f_internal[:, 1:, :] - f_internal[:, :-1, :]), axis=(1, 2))
        return tv_x + tv_y

    def relative_total_variation_2d(self, f: jnp.ndarray) -> jnp.ndarray:
        """Interior total variation divided by the interior L1 magnitude, shape (batch,).

        A zero interior magnitude is replaced by 1e-12 to avoid division by zero.
        """
        NN = self.N
        f = f.reshape(-1, NN + 1, NN + 1)
        f_internal = f[:, 1:-1, 1:-1]

        tv = self.total_variation_2d(f)
        total_mag = jnp.sum(jnp.abs(f_internal), axis=(1, 2))
        return tv / jnp.where(total_mag > 0, total_mag, 1e-12)