import numpy as np
import scipy.sparse as sp
from dolfin import *

import os
os.environ["DDE_BACKEND"] = "pytorch"
import deepxde as dde

from src.invpdedata import InversePDEData


class HInv(InversePDEData):
    """
    Prepare data for the heat inverse problem (HInv).

    Points are rows ``(x, y, t)``; the third coordinate is treated as time.
    The reference fields are

        a(x, y)    = 2 + sin(pi x) sin(pi y)
        u(x, y, t) = exp(-t) sin(pi x) sin(pi y)

    and the residual assembled in ``compute_loss_grad`` is the weak form
    ``u_t*v + a*grad(u).grad(v) - f*v`` integrated over the space-time box,
    with spatial gradients taken in x and y only.

    ``self.config``, ``self.u_data`` and ``self.a_data`` are read here but
    not set in this class -- presumably provided by ``InversePDEData`` or
    the training loop; verify against the caller.
    """
    def __init__(self):
        super().__init__()

        def a_fun(xyt):
            # Column slices keep the result as an (N, 1) array.
            x, y = xyt[:, 0:1], xyt[:, 1:2]
            return 2 + np.sin(np.pi * x) * np.sin(np.pi * y)

        def u_fun(xyt):
            x, y, t = xyt[:, 0:1], xyt[:, 1:2], xyt[:, 2:3]
            return np.exp(-t) * np.sin(np.pi * x) * np.sin(np.pi * y)

        self.ref_sol_a = a_fun
        self.ref_sol_u = u_fun

    @staticmethod
    def _petsc_to_csr(matrix):
        """Convert an assembled dolfin matrix into a SciPy CSR matrix."""
        # PETSc hands back (indptr, indices, data), whereas SciPy's
        # constructor expects (data, indices, indptr).
        indptr, indices, data = as_backend_type(matrix).mat().getValuesCSR()
        return sp.csr_matrix((data, indices, indptr))

    def _weak_operator(self, u):
        """Return the form ``u_t*v*dx + a*inner(grad_xy u, grad_xy v)*dx``."""
        v = self.v
        # Only the first two coordinates are spatial; index 2 is time.
        grad_u = as_vector((u.dx(0), u.dx(1)))
        grad_v = as_vector((v.dx(0), v.dx(1)))
        u_t = u.dx(2)
        return u_t*v*dx + self.a*inner(grad_u, grad_v)*dx

    def prepare_train_data(self):
        """Load the observation points and build noisy samples of ``u``."""
        self.train_X = np.loadtxt("data/ref/heatinv_points.dat")
        self.train_Y = self.ref_sol_u(self.train_X)
        # Add Gaussian noise with zero mean and standard deviation 0.1
        self.train_Y += np.random.normal(0, 0.1, self.train_Y.shape)

    def prepare_boundary_data(self):
        """Sample boundary points of the (x, y, t) box and evaluate ``a``."""
        # Bounding box layout: [xmin, xmax, ymin, ymax, tmin, tmax]
        bbox = self.config["Bounding Box"]
        geom = dde.geometry.Cuboid(
            xmin=[bbox[0], bbox[2], bbox[4]],
            xmax=[bbox[1], bbox[3], bbox[5]])
        # Randomly sample 2500 points on the boundary of the box
        self.train_X_b = geom.random_boundary_points(2500)
        self.train_Y_b = self.ref_sol_a(self.train_X_b)

    def init_problem(self):
        """Build the mesh, the P1 function space and the FEM functions."""
        def f_fun(xyt):
            # Source term for which the reference u and a satisfy
            # u_t - div(a grad u) = f.
            x, y, t = xyt[:, 0:1], xyt[:, 1:2], xyt[:, 2:3]
            s, c, p = np.sin, np.cos, np.pi
            return np.exp(-t) * (
                (4 * p**2 - 1) * s(p * x) * s(p * y)
                + p**2 * (
                    2 * s(p * x) ** 2 * s(p * y) ** 2
                    - c(p * x) ** 2 * s(p * y) ** 2
                    - s(p * x) ** 2 * c(p * y) ** 2
                )
            )

        # Create mesh and define function space
        bbox = self.config["Bounding Box"]
        mesh = BoxMesh(
            Point(bbox[0], bbox[2], bbox[4]),
            Point(bbox[1], bbox[3], bbox[5]),
            self.config["Grid Size (x-direction)"],
            self.config["Grid Size (y-direction)"],
            self.config["Grid Size (t-direction)"])
        self.V = V = FunctionSpace(mesh, "Lagrange", 1)

        # Get the dof coordinates
        self.coords = V.tabulate_dof_coordinates()

        # Source term evaluated at the dof coordinates
        self.f = Function(V)
        self.f.vector()[:] = f_fun(self.coords).reshape(-1)

        self.u = Function(V)
        self.a = Function(V)
        self.v = TestFunction(V)

    def update_precond(self):
        """Rebuild the diagonal (Jacobi-style) preconditioner ``self.P_inv``.

        The operator is assembled with whatever values ``self.a`` currently
        holds; this method does not copy ``self.a_data`` into it.
        """
        A = assemble(self._weak_operator(TrialFunction(self.V)))

        # Convert PETSc matrix to a scipy sparse matrix
        A = self._petsc_to_csr(A)

        # Compute the preconditioner
        diag = A.diagonal()
        # Near-zero diagonal entries are replaced by 1 so the inverse stays
        # bounded (those rows are left unscaled).
        diag[np.abs(diag) < 1e-3] = 1
        self.P_inv = sp.diags(1 / diag)

    def compute_loss_grad(self):
        """Assemble the residual and its gradients w.r.t. ``u`` and ``a``.

        Returns:
            A tuple ``(loss_val, grad_u, grad_a)`` where ``loss_val`` is the
            mean squared value of the un-preconditioned residual vector and
            ``grad_u`` / ``grad_a`` are column vectors ``J^T r / n`` with the
            respective Jacobians ``dr/du`` and ``dr/da`` (``r`` being the
            preconditioned residual when the preconditioner is active).
        """
        self.u.vector()[:] = self.u_data
        self.a.vector()[:] = self.a_data

        # Define variational problem
        residual_form = self._weak_operator(self.u) - self.f*self.v*dx
        jac_u_form = derivative(residual_form, self.u)
        jac_a_form = derivative(residual_form, self.a)

        # Convert the assembled PETSc vector to an (n, 1) numpy array
        r = assemble(residual_form).get_local().reshape(-1, 1)

        loss_val = np.mean(r**2)

        # Convert the assembled PETSc matrices to scipy sparse matrices
        drdu = self._petsc_to_csr(assemble(jac_u_form))
        drda = self._petsc_to_csr(assemble(jac_a_form))

        # P_inv only exists after update_precond() has run; treat a missing
        # attribute the same as None instead of raising AttributeError.
        p_inv = getattr(self, "P_inv", None)
        if self.config["Use Preconditioner"] and p_inv is not None:
            # Applied twice: for the diagonal P_inv this is P^-T P^-1 r,
            # which makes the products below proportional to the gradient of
            # ||P^-1 r||^2.
            # NOTE(review): confirm the double application is intended.
            r = p_inv.dot(p_inv.dot(r))

        n_dofs = r.shape[0]
        grad_u = drdu.transpose().dot(r) / n_dofs
        grad_a = drda.transpose().dot(r) / n_dofs

        return loss_val, grad_u, grad_a
