import numpy as np
import torch
from scipy.interpolate import CloughTocher2DInterpolator, NearestNDInterpolator
from scipy import linalg as la
from scipy import sparse as sp
from utils import decompress_time_data, plot_heatmap
import deepxde as dde

def matrixInvNorm(A, dense=False): # calculate ||A^{-1}||
    """Return the spectral norm of the inverse, ``||A^{-1}||_2``.

    Uses the identity ``||A^{-1}||_2 = 1 / sigma_min(A)``, where ``sigma_min``
    is the smallest singular value of ``A``.

    Args:
        A: square matrix. A scipy sparse matrix; on the dense path any
            array-like is also accepted.
        dense: if True, densify ``A`` and compute ``sigma_min`` with
            ``scipy.linalg.norm(..., ord=-2)``; otherwise estimate it
            iteratively with ``scipy.sparse.linalg.svds(which='SM')``.

    Returns:
        float: ``1 / sigma_min(A)`` (a scalar on both paths).
    """
    if dense:
        M = A.toarray() if sp.issparse(A) else np.asarray(A)
        sigma_min = la.norm(M, ord=-2)
    else:
        # Explicit import: ``import scipy.sparse`` does not guarantee that the
        # ``scipy.sparse.linalg`` submodule is loaded.
        from scipy.sparse.linalg import svds
        sigma_min = svds(A, k=1, which='SM', return_singular_vectors=False)[0]
    return float(1 / sigma_min)

class CaseBase():
    """Interface shared by the PDE test cases in this module.

    Subclasses override:
      * ``cond()`` -- a condition-number estimate of the discretized problem;
      * ``model(...)`` -- build and return a DeepXDE ``Model``;
      * ``output_transform(x, y)`` -- transform the raw network output so it
        satisfies the boundary conditions.
    Every method here only raises ``NotImplementedError``.
    """

    def __init__(self):
        pass

    def cond(self):
        """
        Return the condition number of this case
        """
        raise NotImplementedError("cond() method not implemented")

    def model(self, *args, **kwargs):
        """
        Return the DeepXde Model of this case

        Subclasses are called as ``model(net, ...)``; the base accepts any
        arguments so an unimplemented subclass raises NotImplementedError
        rather than a TypeError about the signature.
        """
        raise NotImplementedError("model() method not implemented")

    def output_transform(self, x, y):
        """
        Enforce the network to fit boundary conditions
        """
        raise NotImplementedError("output_transform() method not implemented")

class Burgers1D(CaseBase):
    """Viscous Burgers' equation with a source term on (x, t) in [-1, 1] x [0, 1].

    The residual used by both ``calc_F`` and ``model`` is::

        u_t + u * u_x - nu * u_xx - sin(pi * x)

    A tabulated reference solution is interpolated onto a uniform grid with
    ``xcoords + 2`` points in x and ``tcoords + 2`` points in t (edges
    included). That grid function ``self.Us`` is the Newton starting guess and
    supplies the edge values of the finite-difference operator.
    """

    def __init__(self, mesh=(499, 20), nu=0.01/np.pi) -> None:
        """
        Args:
            mesh: ``(xcoords, tcoords)``, number of interior grid points in x and t.
            nu: viscosity; must (up to ``np.isclose``) equal one of
                ``np.logspace(-2, 0, 21) / np.pi``, the values in the reference file.

        Raises:
            ValueError: if ``nu`` is not one of the tabulated values.
        """
        self.bbox = [-1, 1, 0, 1]
        self.xcoords, self.tcoords = mesh
        self.nu = nu

        self.Xs = np.linspace(self.bbox[0], self.bbox[1], self.xcoords+2)
        self.Ts = np.linspace(self.bbox[2], self.bbox[3], self.tcoords+2)
        # Rows are (x, t), ordered time-major: flat index = t * (xcoords + 2) + x.
        self.Pts = np.stack(np.meshgrid(self.Xs, self.Ts)).reshape(2, -1).T
        self.dt = (self.bbox[3] - self.bbox[2]) / (self.tcoords+1)
        self.dx = (self.bbox[1] - self.bbox[0]) / (self.xcoords+1)
        self.SrcFunc = (lambda x: np.sin(np.pi * x[:, 0:1]))

        nu_list = np.logspace(-2, 0, 21, base=10) / np.pi
        nu_id = np.argmin(np.abs(nu_list - nu))
        if not np.isclose(nu_list[nu_id], nu):
            raise ValueError(f"Given nu({nu}) not listed in reference data")

        dim_s, dim_out = 1, 1
        self.ref_data = np.loadtxt("../data/ref/burgers1d_src.dat", comments="%")
        # Keep column 0 plus the 11-column block belonging to this nu
        # (presumably one column per stored time snapshot -- see decompress_time_data).
        self.ref_data = self.ref_data[:, [0] + list(range(nu_id*11 + 1, nu_id*11 + 12))]
        self.ref_data = decompress_time_data(
            self.ref_data,
            dim_s,
            dim_out,
            self.bbox[dim_s*2:])
        interp = CloughTocher2DInterpolator(self.ref_data[:, :2], self.ref_data[:, 2:])
        self.Us = interp(self.Pts)
        # CloughTocher yields NaN outside the convex hull of the data; fall back to
        # nearest-neighbour values there. ``x == np.nan`` is always False, so the
        # mask must come from np.isnan.
        nan_mask = np.isnan(self.Us)
        if nan_mask.any():
            interp_near = NearestNDInterpolator(self.ref_data[:, :2], self.ref_data[:, 2:])
            self.Us[nan_mask] = interp_near(self.Pts)[nan_mask]

    def get_AB(self): # Get Matrix Representation of the Non Linear Operator
        """Assemble the linear pieces of the finite-difference Burgers operator.

        Unknowns are the interior values u[t, x] (t in 1..tcoords,
        x in 1..xcoords), flattened time-major. Values on the outer ring of the
        grid (including the final time row) come from ``self.Us`` and are folded
        into the affine offsets.

        Returns:
            (A, B, a, b) such that ``A @ u + a`` approximates ``u_t - nu * u_xx``
            and ``B @ u + b`` approximates ``u_x`` (central differences).
            A and B are CSR matrices of shape (size, size); a and b have length size.
        """
        size = self.tcoords * self.xcoords
        A, B = ([], ([], [])), ([], ([], []))
        a, b = np.zeros((size,)), np.zeros((size,))
        flat_Us = self.Us.reshape(-1)

        def get_id(t, x):
            # Row/column index of interior node (t, x); -1 for edge nodes.
            if t <= 0 or t > self.tcoords or x <= 0 or x > self.xcoords:
                return -1
            return (t-1) * self.xcoords + (x-1)

        def get_value(t, x):
            # Scalar reference value at grid node (t, x).
            return flat_Us[t*(self.xcoords+2) + x]

        def put_elem(config, ind1, ind2, d):
            # Add coefficient d at (ind1, ind2); if ind2 is an edge node, return
            # its known contribution instead, to be added to the offset vector.
            i = get_id(*ind1)
            j = get_id(*ind2)
            assert i != -1
            if j != -1:
                config[1][0].append(i)
                config[1][1].append(j)
                config[0].append(d)
                return 0
            else:
                return get_value(ind2[0], ind2[1]) * d # Boundary Value

        for t in range(1, self.tcoords+1):
            for x in range(1, self.xcoords+1):
                # A: u_t - nu * u_xx  =>   Au + a
                i = (t, x)
                l = (t-1, x)
                r = (t+1, x)
                a[get_id(*i)] += put_elem(A, i, l, -1/self.dt/2)
                a[get_id(*i)] += put_elem(A, i, r, 1/self.dt/2)

                l = (t, x-1)
                r = (t, x+1)
                a[get_id(*i)] += put_elem(A, i, l, -self.nu/self.dx**2)
                a[get_id(*i)] += put_elem(A, i, i, 2*self.nu/self.dx**2)
                a[get_id(*i)] += put_elem(A, i, r, -self.nu/self.dx**2)

                # B: u_x  =>  Bu + b
                b[get_id(*i)] += put_elem(B, i, l, -1/self.dx/2)
                b[get_id(*i)] += put_elem(B, i, r, 1/self.dx/2)

        return sp.csr_matrix(A, shape=(size, size)), sp.csr_matrix(B, shape=(size, size)), a, b

    def calc_F(self, A, B, a, b, u, src): # PDE: u_t + u*u_x - nu*u_xx - src
        """Discrete residual ``A u + a + u * (B u + b) - src``."""
        return A@u + a + u * (B@u + b) - src

    def calc_dF(self, A, B, a, b, u):
        """Jacobian of ``calc_F`` w.r.t. u: ``A + diag(B u + b) + diag(u) B`` (CSR)."""
        def diag(array):
            return sp.dia_matrix((array[None, :], [0]), A.shape)
        return sp.csr_matrix(A + diag(B@u) + diag(u)@B + diag(b))

    def newton_iter(self, A, B, a, b, n_iter=10): # Get the solution of the PDE
        """Solve ``calc_F(u) = 0`` by Newton's method from the reference solution.

        Args:
            A, B, a, b: output of ``get_AB``.
            n_iter: number of Newton steps (fixed count, no convergence test).

        Returns:
            1D array of interior values u[t, x], flattened time-major.
        """
        grid = (self.tcoords+2, self.xcoords+2)
        # .copy(): when tcoords == 1 the interior slice is contiguous, reshape
        # returns a view, and the in-place update below would overwrite self.Us.
        u = self.Us.reshape(grid)[1:-1, 1:-1].reshape(-1).copy()
        pt = self.Pts.reshape(*grid, 2)[1:-1, 1:-1, :].reshape(-1, 2)
        src = self.SrcFunc(pt).reshape(-1)
        F = self.calc_F(A, B, a, b, u, src)

        for _ in range(n_iter):
            # Linearize F(u + du) ~ F(u) + dF(u) du = 0  =>  du = -dF(u)^{-1} F(u)
            du = sp.linalg.spsolve(self.calc_dF(A, B, a, b, u), -F)
            u += du
            F = self.calc_F(A, B, a, b, u, src)

        return u

    def cond(self):
        """``||dF(u*)^{-1}||_2 * ||f||_2 / ||u_ref||_2`` at the Newton solution u*."""
        A, B, a, b = self.get_AB()
        u = self.newton_iter(A, B, a, b)
        return matrixInvNorm(self.calc_dF(A, B, a, b, u), dense=True) * la.norm(self.SrcFunc(self.Pts), ord=2) / la.norm(self.Us, ord=2)

    def model(self, net):
        """Build the DeepXDE model: Burgers residual, initial and boundary conditions."""
        self.geom = dde.geometry.Interval(-1, 1)
        timedomain = dde.geometry.TimeDomain(0, 1)
        self.geomtime = dde.geometry.GeometryXTime(self.geom, timedomain)

        def burger_pde(x, u):
            u_x = dde.grad.jacobian(u, x, i=0, j=0)
            u_t = dde.grad.jacobian(u, x, i=0, j=1)
            u_xx = dde.grad.hessian(u, x, i=0, j=0)
            return u_t + u * u_x - self.nu * u_xx - torch.sin(np.pi * x[:, 0:1])
        self.pde = burger_pde

        def ic_func(x):
            return np.sin(-np.pi * x[:, 0:1])
        self.bcs = [
            # u(x, 0) = -sin(pi x). This must be an IC: a DirichletBC's predicate
            # only receives GeometryXTime.on_boundary, i.e. spatial-boundary points.
            dde.icbc.IC(self.geomtime, ic_func, (lambda _, on_initial: on_initial), component=0),
            dde.DirichletBC(self.geomtime, (lambda _: 0), (lambda _, on_boundary: on_boundary), component=0),
        ]

        self.net = net
        self.data = dde.data.TimePDE(
            self.geomtime,
            self.pde,
            self.bcs,
            num_domain=8192,
            num_boundary=2048,
            num_initial=2048,
            num_test=8192
        )
        return dde.Model(self.data, net)

    def output_transform(self, x, y):
        """Hard-constrain the output: equals -sin(pi x) at t = 0 and 0 at x = +-1."""
        wt = torch.exp(-x[:, 1:2])  # 1 at t = 0
        wx = (x[:, 0:1] + 1) * (1 - x[:, 0:1])  # 0 at x = -1 and x = 1
        return wt * -torch.sin(np.pi * x[:, 0:1]) + (1-wt) * wx * y


class Poisson1D(CaseBase):
    """1D Poisson problem ``u_xx = P**2 sin(P x)`` on ``[0, 2 pi / P]``, u = 0 at both ends.

    With the source enabled, the exact solution is ``u = -sin(P x)``.
    """
    def __init__(self, mesh=100, P=1):
        """
        Args:
            mesh: number N of interior finite-difference nodes used by ``cond``.
            P: source frequency; the domain length is ``2 pi / P``.
        """
        self.N = mesh
        self.P = P
        self.test_x = np.linspace(0, 2 * np.pi / P, 1000)[:, None]
        # Source f(x) = P^2 sin(P x) (torch version, used in the PDE residual).
        self.src_term = (lambda x: P**2 * torch.sin(P * x))

    def cond(self, include_ratio=False):
        """Estimate ``||A^{-1}||_2`` of the second-difference (Dirichlet) matrix.

        A is the N x N tridiagonal ``[1, -2, 1] / h**2`` with
        ``h = (2 pi / P) / (N + 1)``.

        Args:
            include_ratio: if True, additionally multiply by ``||f|| / ||u||``,
                which equals ``P**2`` exactly for ``f = P^2 sin(Px)`` and
                ``u = -sin(Px)``. This gives the ``||A^{-1}|| ||f|| / ||u||`` form
                used by the other cases. It defaults to False so that the
                historical return value is preserved.
        """
        h = 2*np.pi / self.P / (self.N+1)
        ones = np.ones(self.N)
        A = sp.dia_matrix(([ones, -2*ones, ones], [-1, 0, 1]), shape=(self.N, self.N)) / h**2
        inv_norm = matrixInvNorm(sp.csr_matrix(A))
        if include_ratio:
            return inv_norm * self.P**2
        return inv_norm

    def model(self, net, src=True):
        """Build the DeepXDE model; ``src=False`` solves the homogeneous ``u_xx = 0``."""
        self.geom = dde.geometry.Interval(0, 2 * np.pi / self.P)

        def poisson_pde(x, u):
            u_xx = dde.grad.hessian(u, x, i=0, j=0)

            return u_xx - (self.src_term(x) if src else 0)

        self.pde = poisson_pde

        self.bcs = [
            dde.DirichletBC(self.geom, (lambda _:0), (lambda _, on_bnd: on_bnd))
        ]

        self.net = net
        self.data = dde.data.PDE(
            self.geom,
            self.pde,
            self.bcs,
            num_domain=2048,
            num_boundary=128,
        )

        return dde.Model(self.data, self.net)

    def output_transform(self, x, y):
        """Scale y by a weight that is 0 at both ends of the domain and 1 at its midpoint."""
        return x * (2 * np.pi / self.P - x) / (np.pi / self.P)**2 * y


class Wave1D(CaseBase):
    """1D wave equation ``u_tt - C**2 * u_xx + src = 0`` on ``bbox`` (x, t).

    ``ref_sol`` is a manufactured solution. ``src_term`` and ``src_term_numpy``
    are the matching source term, in torch and numpy form respectively.
    ``ref_sol`` satisfies ``u_t(x, 0) = 0``, the initial condition used by
    ``get_Ab`` and ``model``.
    """
    def __init__(self, mesh=(50, 50), C=1, scale=8):
        """
        Args:
            mesh: ``(xcoords, tcoords)`` interior grid points in x and t.
            C: wave speed.
            scale: length scale of the manufactured solution (base frequency ``pi / scale``).
        """
        self.xcoords, self.tcoords = mesh
        self.bbox = [0, 1, 0, 1]
        self.dx = (self.bbox[1] - self.bbox[0]) / (self.xcoords + 1)
        self.dt = (self.bbox[3] - self.bbox[2]) / (self.tcoords + 1)
        self.C = C
        self.scale = scale

        self.ref_sol = \
            (lambda x:(np.sin(np.pi / scale * x[:, 0:1]) * np.cos(np.pi / scale * x[:, 1:2]) + \
                       0.5 * np.sin(4 * np.pi / scale * x[:, 0:1]) * np.cos(4 * C * np.pi / scale * x[:, 1:2])))
        self.ref_sol_t0 = \
            (lambda x:(torch.sin(np.pi / scale * x[:, 0:1]) + \
                       0.5 * torch.sin(4 * np.pi / scale * x[:, 0:1])))
        self.src_term = \
            (lambda x:(1 - C**2) * (np.pi / scale)**2 * \
                       torch.sin(np.pi / scale * x[:, 0:1]) * torch.cos(np.pi / scale * x[:, 1:2]))
        self.src_term_numpy = \
            (lambda x:(1 - C**2) * (np.pi / scale)**2 * \
                       np.sin(np.pi / scale * x[:, 0:1]) * np.cos(np.pi / scale * x[:, 1:2]))

        # 50 x 50 evaluation grid over bbox (np.linspace default num=50).
        self.test_x = np.stack(np.meshgrid(np.linspace(self.bbox[0], self.bbox[1]), np.linspace(self.bbox[2], self.bbox[3])), axis=2).reshape(-1, 2)
        self.test_y = self.ref_sol(self.test_x)

    def get_Ab(self): # Get Matrix Representation of the Operator (u_tt - C**2 * u_xx + src(x))
        """Assemble the finite-difference system ``A u = b``.

        Unknowns are u[t, x] for t in 1..tcoords+1 (the final time row is
        unknown) and x in 1..xcoords, flattened time-major. The t = 0 row and
        the x = 0 / x = xcoords+1 columns are Dirichlet data from ``ref_sol``;
        their contributions and the source term are moved into ``b``.

        u_tt is discretized as follows:
          * central differences for 1 < t <= tcoords;
          * a one-sided stencil on the last row;
          * at t = 1, ``(5 u0 - 8 u1 + 3 u2) / (2 dt^2)``. This is a first-order
            approximation of u_tt only because ``u_t(x, 0) = 0`` is assumed.

        Returns:
            (A, b): CSR matrix of shape (size, size) and a length-size vector,
            ``size = (tcoords + 1) * xcoords``.
        """
        size = (self.tcoords + 1) * self.xcoords
        A = ([], ([], []))
        b = np.zeros(size)

        def get_id(t, x):
            # Row/column index of unknown node (t, x); -1 for Dirichlet nodes.
            if t <= 0 or t > self.tcoords+1 or x <= 0 or x > self.xcoords:
                return -1
            return (t-1) * self.xcoords + (x-1)

        def node(t, x):
            # Physical (x, t) of grid node (t, x), shaped (1, 2) for ref_sol/src.
            return np.array([[self.bbox[0] + x * self.dx, self.bbox[2] + t * self.dt]])

        def get_boundary_value(t, x):
            return self.ref_sol(node(t, x)).item()

        def put_elem(config, ind1, ind2, d):
            i = get_id(*ind1)
            j = get_id(*ind2)
            assert i != -1
            if j != -1:
                config[1][0].append(i)
                config[1][1].append(j)
                config[0].append(d)
            else:
                # Known Dirichlet value: move its contribution to the right-hand side.
                b[i] -= get_boundary_value(*ind2) * d

        for t in range(1, self.tcoords+2):
            for x in range(1, self.xcoords+1):
                # A u + (boundary terms) + src = 0  =>  src also moves to b.
                b[get_id(t, x)] -= self.src_term_numpy(node(t, x)).item()

                # -C^2 u_xx
                put_elem(A, (t, x), (t, x-1), -self.C**2/self.dx**2)
                put_elem(A, (t, x), (t, x), 2*self.C**2/self.dx**2)
                put_elem(A, (t, x), (t, x+1), -self.C**2/self.dx**2)

                if 1 < t and t <= self.tcoords: # u_tt, central
                    put_elem(A, (t, x), (t-1, x), 1/self.dt**2)
                    put_elem(A, (t, x), (t, x), -2/self.dt**2)
                    put_elem(A, (t, x), (t+1, x), 1/self.dt**2)

                elif t == self.tcoords + 1: # u_tt, one-sided on the last row
                    put_elem(A, (t, x), (t-2, x), 1/self.dt**2)
                    put_elem(A, (t, x), (t-1, x), -2/self.dt**2)
                    put_elem(A, (t, x), (t, x), 1/self.dt**2)

                elif t == 1: # u_tt using u_t(x, 0) = 0
                    put_elem(A, (t, x), (t-1, x), 5/(2 * self.dt**2))
                    put_elem(A, (t, x), (t, x), -8/(2 * self.dt**2))
                    put_elem(A, (t, x), (t+1, x), 3/(2 * self.dt**2))

        return sp.csr_matrix(A, shape=(size, size)), b

    def check_A(self): # Check whether it is correct
        """Solve ``A u = b`` and compare the result with ``ref_sol`` on the grid.

        Side effects: sets ``Xs``, ``Ts``, ``Pts``, ``Us``, prints the relative
        error, and writes a heatmap to ``results/wave_error_fig/``.

        Returns:
            (relative mean-square error, mean square of the reference solution).
        """
        A, b = self.get_Ab()
        self.Xs = np.linspace(self.bbox[0], self.bbox[1], self.xcoords+2)
        self.Ts = np.linspace(self.bbox[2], self.bbox[3], self.tcoords+2)
        self.Pts = np.stack(np.meshgrid(self.Xs, self.Ts)).reshape(2, -1).T
        self.Us = self.ref_sol(self.Pts)
        u = self.Us.reshape(self.tcoords+2, self.xcoords+2)[1:, 1:-1].reshape(-1)
        x = self.Pts.reshape(self.tcoords+2, self.xcoords+2, 2)[1:, 1:-1, :].reshape(-1, 2)

        usolve = sp.linalg.spsolve(A, b)
        relerr = np.square(u - usolve).mean() / np.square(u).mean()
        print(relerr, np.square(u).mean())
        plot_heatmap(x[:, 0], x[:, 1], (u-usolve).reshape(-1), f"results/wave_error_fig/error_{self.C:.2f}.png")
        return relerr, np.square(u).mean()

    def cond(self):
        """Relative condition number ``||A^{-1}||_2 * ||f|| / ||u||``.

        ``||f|| / ||u||`` is estimated as the ratio of the RMS values of the source
        and of the reference solution on ``self.test_x``. The RMS ratio is used,
        not the mean-square ratio, so that the result is dimensionless.
        """
        f_rms = np.sqrt(np.square(self.src_term_numpy(self.test_x)).mean())
        u_rms = np.sqrt(np.square(self.test_y).mean())
        return matrixInvNorm(self.get_Ab()[0], dense=True) * f_rms / u_rms

    def model(self, net, src=True):
        """Build the DeepXDE model.

        Boundary and initial conditions:
          * zero normal derivative on the t = t0 edge, i.e. ``u_t(x, 0) = 0``;
          * ``u = ref_sol`` on the t = t0 edge;
          * ``u = ref_sol`` on both x edges.
        """
        self.geom = dde.geometry.Rectangle(xmin=[self.bbox[0], self.bbox[2]], xmax=[self.bbox[1], self.bbox[3]])

        def wave_pde(x, u):
            u_xx = dde.grad.hessian(u, x, i=0, j=0)
            u_tt = dde.grad.hessian(u, x, i=1, j=1)

            return u_tt - self.C**2 * u_xx + (self.src_term(x) if src else 0)

        self.pde = wave_pde

        def boundary_x(x, on_boundary):
            return on_boundary and (np.isclose(x[0], self.bbox[0]) or np.isclose(x[0], self.bbox[1]))
        def boundary_t0(x, on_boundary):
            return on_boundary and np.isclose(x[1], self.bbox[2])

        self.bcs = [
            dde.NeumannBC(self.geom, (lambda _:0), boundary_t0),
            dde.DirichletBC(self.geom, self.ref_sol, boundary_t0),
            dde.DirichletBC(self.geom, self.ref_sol, boundary_x)
        ]

        self.net = net
        self.data = dde.data.PDE(
            self.geom,
            self.pde,
            self.bcs,
            num_domain=8192,
            num_boundary=2048,
            num_test=8192
        )
        return dde.Model(self.data, net)

    def output_transform(self, xt, y):
        """Hard-constrain the initial data: ``u(x, 0) = ref_sol_t0`` and ``u_t(x, 0) = 0``.

        w_init has a double root at t = 0, so the correction term and its
        t-derivative both vanish there.
        NOTE(review): w_bond vanishes at x = 0 and x = self.scale, but the
        geometry is bbox = [0, 1] x [0, 1]. With scale != 1, the x = 1 edge is
        therefore only enforced softly, through the DirichletBC in model().
        Confirm this is intended.
        """
        w_bond = xt[:, 0:1] * (self.scale - xt[:, 0:1]) / (self.scale**2/4)
        w_init = (xt[:, 1:2] * (self.scale * 1.5 - xt[:, 1:2])) ** 2 / (self.scale**4/16)
        return self.ref_sol_t0(xt) + w_init * w_bond * y
    
class Helmholtz2d(CaseBase):
    """2D Helmholtz problem ``u_xx + u_yy + k**2 u = src`` on ``[0, scale]^2``.

    The manufactured solution is
    ``ref_sol = sin(A0 pi x / scale) * sin(A1 pi y / scale)``, with
    ``src = (k**2 - pi**2 (A0**2 + A1**2) / scale**2) * ref_sol``.
    For integer ``A``, the solution vanishes on the boundary, which matches the
    homogeneous Dirichlet BC used in ``model``.
    """
    def __init__(self, mesh=(50, 50), A=(4, 4), scale=1, k=1):
        """
        Args:
            mesh: ``(xcoords, ycoords)`` interior grid points.
            A: mode numbers ``(A0, A1)`` of the manufactured solution.
            scale: side length of the square domain.
            k: wavenumber.
        """
        self.A = A
        self.k = k
        self.scale = scale
        self.bbox = [0, scale, 0, scale]
        self.xcoords, self.ycoords = mesh

        self.dx = (self.bbox[1] - self.bbox[0]) / (self.xcoords + 1)
        self.dy = (self.bbox[3] - self.bbox[2]) / (self.ycoords + 1)

        self.src_term = (lambda x: torch.sin(A[0] * np.pi * x[:, 0:1] / scale) * \
                                   torch.sin(A[1] * np.pi * x[:, 1:2] / scale) * \
                                   (k**2 - np.pi**2 * (A[0]**2 + A[1]**2) / scale**2))

        self.ref_sol = (lambda x: np.sin(A[0] * np.pi * x[:, 0:1] / scale) * \
                                  np.sin(A[1] * np.pi * x[:, 1:2] / scale))

        self.test_x = np.stack(np.meshgrid(np.linspace(self.bbox[0], self.bbox[1]), np.linspace(self.bbox[2], self.bbox[3])), axis=2).reshape(-1, 2)
        self.test_y = self.ref_sol(self.test_x)

    def get_Ab(self): # Get Matrix Representation of the Operator (u_xx + u_yy + k**2*u - src(x))
        """Assemble the 5-point finite-difference system ``A u = b``.

        Unknowns are u[x, y] at interior nodes x in 1..xcoords, y in 1..ycoords,
        flattened x-major. Boundary nodes take their values from ``ref_sol``.
        Their contributions and the source term are moved into ``b``.

        Returns:
            (A, b): CSR matrix of shape (size, size) and a length-size vector,
            ``size = xcoords * ycoords``.
        """
        size = self.xcoords * self.ycoords
        A = ([], ([], []))
        b = np.zeros(size)
        # src = src_coef * ref_sol. It is evaluated in float64 numpy; the previous
        # torch.tensor path computed it in float32.
        src_coef = self.k**2 - np.pi**2 * (self.A[0]**2 + self.A[1]**2) / self.scale**2

        def get_id(x, y):
            # Row/column index of interior node (x, y); -1 for boundary nodes.
            if x <= 0 or x > self.xcoords or y <= 0 or y > self.ycoords:
                return -1
            return (x-1) * self.ycoords + (y-1)

        def node(x, y):
            # Physical coordinates of grid node (x, y), shaped (1, 2). Derived
            # from bbox so they are correct for any `scale`.
            return np.array([[self.bbox[0] + x * self.dx, self.bbox[2] + y * self.dy]])

        def get_boundary_value(x, y):
            return self.ref_sol(node(x, y)).item()

        def put_elem(config, ind1, ind2, d):
            i = get_id(*ind1)
            j = get_id(*ind2)
            assert i != -1
            if j != -1:
                config[1][0].append(i)
                config[1][1].append(j)
                config[0].append(d)
            else:
                # Known Dirichlet value: move its contribution to the right-hand side.
                b[i] -= get_boundary_value(*ind2) * d

        for x in range(1, self.xcoords+1):
            for y in range(1, self.ycoords+1):
                b[get_id(x, y)] += src_coef * self.ref_sol(node(x, y)).item()
                put_elem(A, (x, y), (x-1, y), 1 / self.dx**2)
                put_elem(A, (x, y), (x+1, y), 1 / self.dx**2)
                put_elem(A, (x, y), (x, y-1), 1 / self.dy**2)
                put_elem(A, (x, y), (x, y+1), 1 / self.dy**2)
                put_elem(A, (x, y), (x, y), self.k**2-(2 / self.dx**2 + 2 / self.dy**2))

        return sp.csr_matrix(A, shape=(size, size)), b

    def cond(self):
        """``||A^{-1}||_2 * ||f|| / ||u||``, where ``||f|| / ||u|| = |src_coef|`` exactly."""
        return matrixInvNorm(self.get_Ab()[0], dense=True) * np.abs(self.k**2 - np.pi**2 * (self.A[0]**2 + self.A[1]**2) / self.scale**2)

    def model(self, net, src=True):
        """Build the DeepXDE model; ``src=False`` drops the source term."""
        self.geom = dde.geometry.Rectangle(xmin=[self.bbox[0], self.bbox[2]], xmax=[self.bbox[1], self.bbox[3]])

        def helmholtz_pde(x, u):
            u_xx = dde.grad.hessian(u, x, i=0, j=0)
            u_yy = dde.grad.hessian(u, x, i=1, j=1)

            return u_xx + u_yy + self.k**2 * u - (self.src_term(x) if src else 0)

        self.pde = helmholtz_pde

        self.bcs = [
            dde.DirichletBC(self.geom, (lambda _:0), (lambda _,on_bnd:on_bnd))
        ]

        self.net = net
        self.data = dde.data.PDE(
            self.geom,
            self.pde,
            self.bcs,
            num_domain=8192,
            num_boundary=2048,
            num_test=8192
        )
        return dde.Model(self.data, net)

    def output_transform(self, xy, y):
        """Blend the network output with the manufactured solution.

        w is 0 on the boundary of ``[0, scale]^2`` and 1 at its centre, so on the
        boundary the output equals the particular term. For integer ``A``, that
        term vanishes there because it is evaluated with the same ``/ scale`` as
        ``ref_sol``.
        """
        s = self.scale
        w = xy[:, 0:1] * (s - xy[:, 0:1]) * xy[:, 1:2] * (s - xy[:, 1:2]) / (s**4) * 16
        particular = torch.sin(self.A[0] * np.pi * xy[:, 0:1] / s) * torch.sin(self.A[1] * np.pi * xy[:, 1:2] / s)
        return w * y + (1-w) * particular
        
