import time
import itertools
import numpy as np
import jax
import jax.numpy as jnp
from scipy.sparse import diags
import logging
from tqdm import tqdm
jax.config.update("jax_enable_x64", True)




class LegendreSystemJax:
    def __init__(self, args):
        self.args = args
        self.N = args.N
        # self.M = args.M

        BOUNDARY_CONDITION = args.boundary_condition

        self.dtype = jnp.float64
        self.xx = self.LGLpoints() # Legendre-Gauss-Lobatto (LGL) quadrature

        self.lepolys = self.LegendrePolys()
        self.ww = self.LGLweights()
        self.ak, self.bk = self.ak_bk( self.N, BOUNDARY_CONDITION)
        self.phi = self.basis_1D()

        self.S_1D = self.S_Legendre(self.N, self.bk)
        self.R_1D = self.R_Legendre(self.N, self.ak, self.bk)
        self.M_1D = self.M_Legendre(self.N, self.ak, self.bk)
        self.AA = self.A_matrix_1D(args)

        self.Gram_matrix = self.M_1D

        # Pure functions
        self.ForwardTransform_1D = self.make_forward_transform_1d()
        # print(f'Done: ForwardTransform_1D')
        self.BackwardTransform_1D = self.make_backward_transform_1d()
        # print(f'Done: BackwardTransform_1D')
        self.reconstruct_1D = self.make_reconstruct_1d()
        # print(f'Done: reconstruct_1D')
        self.LGL_quadrature = self.make_LGL_quadrature()
        # print(f'Done: LGL_quadrature')
        self.Relative_L2_1D = self.make_relative_l2_1d()
        # print(f'Done: Relative_L2_1D')
        self.Relative_Linf_1D = self.make_relative_linf_1d()
        # print(f'Done: Relative_Linf_1D')

        if args.dimension == '1D':
            pass
        elif args.dimension == '2D':
            self.X_2D, self.Y_2D = self.LGLpoints_2D()
            self.ww_2D = self.LGLweights_2D()
            self.phi_2D = self.basis_2D()

            self.MM_2D = self.AotimesB(self.M_1D , self.M_1D)
            self.SM_2D = self.AotimesB(self.S_1D , self.M_1D)
            self.MS_2D = self.AotimesB(self.M_1D , self.S_1D)
            self.RM_2D = self.AotimesB(self.R_1D , self.M_1D)
            self.MR_2D = self.AotimesB(self.M_1D , self.R_1D)

        elif args.dimension == '3D':
            self.X_3D, self.Y_3D, self.Z_3D = self.LGLpoints_3D()
            self.phi_3D = self.basis_3D()
        else:
            raise NotImplementedError('Wrong Dimension')

    def A_matrix_1D(self, args):
        if args.equation == 'RD_1D':
            # - eps * u_xx + u = f
            AA = - args.pde_parameter * self.S_1D + self.M_1D
        elif args.equation == 'CD_1D':
            # -eps * u_xx - u_x = f
            AA =  - args.pde_parameter * self.S_1D - self.R_1D
        elif args.equation == 'Helmholtz_1D':
            '''
                u_xx + ku * u_x = f
            '''
            AA =  self.S_1D + args.pde_parameter * self.M_1D
            # ku = 3.5
            # AA =  self.S_1D + ku * self.M_1D
        else:
            raise NotImplementedError('Wrong Error')
        logging.info(f'PDE : {args.equation} with Parameter : {args.pde_parameter}')
        return AA

    def LGLpoints(self):
        NN = self.N
        dtype = self.dtype
        if NN <= 0:
            #print("n should be bigger than 1")
            return jnp.array([[]])
        elif NN == 1:
            return jnp.array([[-1], [1]])
        elif NN == 2:
            return jnp.array([[-1], [0], [1]])
        else:
            j = jnp.arange(1, NN - 1)
            bv = j * (j + 2) / ((2 * j + 1) * (2 * j + 3))
            A = jnp.diag(jnp.sqrt(bv), k=-1) + jnp.diag(jnp.sqrt(bv), k=1)
            # z = jnp.sort(jnp.linalg.eigvals(A)) 

            import scipy.linalg
            z = jnp.sort(jnp.array(scipy.linalg.eigvals(np.array(A))))

            z_1 = jnp.concatenate([jnp.array([-1]), z, jnp.array([1])])
            return z_1[:, None].astype(dtype)


    def LGLweights(self):
        NN = self.N
        denom = jnp.square(self.lepolys[NN])
        ww = 2/(NN *(NN+1))/denom
        return jnp.reshape(ww, (1,NN+1))

    def LGLweights_2D(self):
        NN = self.args.N
        w_x = self.ww.view(NN+1, 1)
        w_y = self.ww.view(1, NN+1)
        ww_2D = w_x * w_y
        return ww_2D.view(NN+1, NN+1)
    
    def make_LGL_quadrature(self,):
        NN = self.N
        ww = self.ww.reshape(1, 1, NN+1)
        @jax.jit
        def LGL_quadrature(ff):
            ff = ff.reshape(-1, 1, NN+1)
            batchsize = ff.shape[0]
            sum = jnp.sum(ff * ww, axis=2)
            return sum.reshape(batchsize, 1, 1)
        return LGL_quadrature
       
    def LegendrePoly(self, n:int, x):
        dtype = self.dtype
        if n == 0:
            return jnp.ones_like(x).astype(dtype=jnp.float64)
        elif n == 1:
            return x.astype(dtype=jnp.float64)
        else:
            polylst = jnp.ones_like(x)
            poly = x                  
            for k in range(2,n+1):
                polyn = ((2*k-1)*x*poly-(k-1)*polylst)/k
                polylst, poly = poly, polyn
            return polyn.astype(dtype=dtype)


    def LegendrePolys(self):
        polys = [self.LegendrePoly(i, self.xx) for i in range(self.args.N + 1)]
        return jnp.stack(polys)  # shape: (N+1, N+1)
    

    def make_reconstruct_1d(self,):
        NN = self.N
        dtype=self.dtype 
        phi = self.phi  # shape: (N-1, N+1)
        def reconstruct_1D(alphas):
            '''
                # Input Dim (1, 1, N-1)
                # Output Dim (1, N+1)
            '''
            coeff = alphas.reshape(-1, 1, NN - 1)  # (B, 1, N-1)
            result = jnp.matmul(coeff, phi)  # (B, 1, N+1)
            # reconstruct_1D - alphas.shape: (10, 32)
            # reconstruct_1D - coeff.shape: (10, 1, 32)
            # reconstruct_1D - result.shape: (10, 1, 34)
            return result.astype(dtype)
        return reconstruct_1D

    def ak_bk(self, N, BOUNDARY_CONDITION):
        if BOUNDARY_CONDITION == 'Dirichlet':
            ak, bk = jnp.zeros((N-1,)), -jnp.ones((N-1,))
            return ak, bk
        
        elif BOUNDARY_CONDITION == 'Neumann':
            kk = jnp.arange(N-1)
            ak = jnp.zeros((N-1,))
            bk = - kk * (kk +1) / ((kk + 2) * (kk+3))
            return ak, bk
        else:
            raise NotImplementedError('Wrong BDD')

    def basis_1D(self):
        N = self.N
        ak = self.ak
        bk = self.bk
        lepoly = self.lepolys
        dtype = self.dtype
        lepoly = jnp.stack(self.lepolys)  # 🔸 리스트 → JAX 배열

        phi = jnp.zeros((N - 1, N + 1), dtype=dtype)
        def update(i, uu):
            row = lepoly[i] + ak[i] * lepoly[i + 1] + bk[i] * lepoly[i + 2]
            row = row.reshape((N+1,))
            return uu.at[i, :].set(row)
        phi = jax.lax.fori_loop(0, N - 1, update, phi)
        # return jax.device_put(phi, self.args.device)
        return phi

    def S_Legendre(self, NN, bk):
        dtype = self.dtype

        kk = jnp.arange(0,NN-1, dtype=dtype)
        s_diag= (4*kk+6)*bk
        S_1D = s_diag * jnp.eye(NN-1)
        self.s_diag = s_diag
        return S_1D
    
    def R_Legendre(self, NN, ak, bk):
        dtype = self.dtype
        RR = np.zeros((NN-1, NN-1), dtype=dtype)
        for j in range(NN -1):
            for k in range(NN -1):
                if k == j+1:
                    RR[k, j] = 2 * bk[j]
                elif k == j:
                    RR[k, j] = 2 * ak[j] + 2 * ak[k] * bk[j]
                elif k+1 == j:
                    RR[k, j] = 2 * (1 + ak[k] * ak[j] + bk[j] + bk[k] * bk[j])
                elif (k+1 < j) and (k + j) % 2 == 0:
                    RR[k, j] = 2 * ak[k] * (1 + bk[j]) + 2 * ak[j] * (1 + bk[k])
                elif (k+1 < j) and (k + j) % 2 == 1:
                    RR[k, j] = 2 * (1 + ak[k] * ak[j] + bk[k] + bk[j] + bk[k] * bk[j])
        return jnp.array(RR, dtype=dtype)
    

    def M_Legendre(self, NN, ak, bk):
        dtype = self.dtype
        
        M = np.zeros((NN-1, NN-1), dtype=dtype)
        for k in range(NN-1):
            # j = k
            M[k, k] = 2 / (2 * k + 1) + ak[k] ** 2 * 2 / (2 * k + 3) + bk[k] ** 2 * 2 / (2 * k + 5)
            if k + 1 < NN -1:
                # j = k + 1
                M[k, k + 1] = ak[k] * 2 / (2 * k + 3) + ak[k + 1] * bk[k] * 2 / (2 * k + 5)
                M[k + 1, k] = M[k, k + 1]
            if k + 2 < NN -1:
                # j = k + 2
                M[k, k + 2] = bk[k] * 2 / (2 * k + 5)
                M[k + 2, k] = M[k, k + 2]
        return jnp.array(M, dtype=dtype)



    ##### Weak Form #####
    def make_forward_transform_1d(self,):
        NN = self.args.N
        ww = self.ww.reshape(1, 1, self.N+1)
        phi = self.phi.reshape((1, self.N-1, self.N+1))

        @jax.jit
        def ForwardTransform_1D(ff):
            '''
                ff: (batch, N+1)
            '''
            ff = ff.reshape(-1, 1, NN+1)
            # ss = jnp.sum( (ff*ww)* phi, axis=2)
            ss = jnp.einsum('bij,bkj->bki', ff * ww, phi)
            return ss.reshape(-1, NN-1, 1)
        return ForwardTransform_1D
    
    def make_backward_transform_1d(self):
        '''
            inv_Gram_matrix: (N-1, N-1)
            phi: (N-1, N+1)
        '''
        phi = self.phi
        Gram_matrix = self.Gram_matrix
        @jax.jit
        def BackwardTransform_1D(bar_f):
            '''
                input: bar_f (batch, N-1, 1)
                output: f_interpol (batch, N+1, 1)
            '''
            # Solve: Gram_matrix @ x = bar_f  ⇒ x = inv_coeff
            inv_coeff = jax.vmap(lambda b: jnp.linalg.solve(Gram_matrix, b.squeeze(-1)))(bar_f)  # (batch, N-1)
            f_interpol = jnp.einsum("ij,bj->bi", phi.T, inv_coeff)  # (batch, N+1)
            return f_interpol[..., None]  # (batch, N+1, 1)
        return BackwardTransform_1D



    ##### Loss #####
    def make_relative_l2_1d(self,):
        NN = self.N
        LGL_quadrature = self.make_LGL_quadrature()
        @jax.jit
        def Relative_L2_1D(u_pred, u_true):
            BATCH_SIZE = u_pred.shape[0]
            u_pred = u_pred.reshape(-1, NN+1)
            u_true = u_true.reshape(-1, NN+1)
            q1 = LGL_quadrature((u_pred-u_true)**2)
            q2 = LGL_quadrature(u_true**2)
            loss = jnp.sqrt(q1 / q2)
            return loss.reshape(BATCH_SIZE, 1)
        return Relative_L2_1D


    def make_relative_linf_1d(self,):
        NN = self.N
        
        @jax.jit
        def Relative_Linf_1D(u_pred, u_true):
            u_pred = u_pred.reshape(-1, NN+1)
            u_true = u_true.reshape(-1, NN+1)
            BATCH_SIZE = u_pred.shape[0]
            loss1 = jnp.amax(jnp.abs(u_true - u_pred), axis=(-1))
            loss2 = jnp.amax(jnp.abs(u_true), axis=(-1))
            loss = loss1/loss2
            return loss.reshape(BATCH_SIZE, 1)
        return Relative_Linf_1D
    


    def total_variation(self, f: jnp.ndarray) -> jnp.ndarray:
        NN = self.N
        """
        Args:
            f (jnp.ndarray): shape (batch, N+1)
        Returns:
            jnp.ndarray: shape (batch,), total variation per batch
        """
        f = f.reshape(-1, NN+1)
        # tv = jnp.sum(jnp.abs(f[:, 1:] - f[:, :-1]), axis=1)
        tv = jnp.sum(jnp.abs(f[:, 2:-1] - f[:, 1:-2]), axis=1)
        return tv

    def relative_total_variation(self, f: jnp.ndarray) -> jnp.ndarray:
        NN = self.N
        f = f.reshape(-1, NN+1)
        f_wo_bdd = f[:, 1:-1]

        tv = self.total_variation(f)
        total_mag = jnp.sum(jnp.abs(f_wo_bdd), axis=1)
        rtv = tv / jnp.where(total_mag > 0, total_mag, 1e-12)
        return rtv

