import numpy as np
import scipy.sparse as sp
from dolfin import *

from src.invpdedata import InversePDEData


class PInv(InversePDEData):
    """
    Prepare data for the Poisson inverse problem (PInv).

    Forward model (its weak form is assembled in ``compute_loss_grad``):

        -div(a * grad(u)) = f

    with the manufactured reference pair

        u(x, y) = sin(pi x) sin(pi y)
        a(x, y) = 1 / (1 + x^2 + y^2 + (x - 1)^2 + (y - 1)^2)

    The source term ``f`` is derived analytically from this pair in
    ``init_problem``, so (u, a) satisfies the PDE exactly.
    """

    def __init__(self):
        super().__init__()

        def a_fun(xy):
            # Reference coefficient a(x, y), evaluated row-wise on an (N, 2) array.
            x, y = xy[:, 0:1], xy[:, 1:2]
            return 1 / (1 + x**2 + y**2 + (x - 1)**2 + (y - 1)**2)

        def u_fun(xy):
            # Reference solution u(x, y) = sin(pi x) sin(pi y).
            x, y = xy[:, 0:1], xy[:, 1:2]
            return np.sin(np.pi * x) * np.sin(np.pi * y)

        self.ref_sol_a = a_fun
        self.ref_sol_u = u_fun

    @staticmethod
    def _petsc_to_csr(matrix):
        """Return an assembled dolfin matrix as a ``scipy.sparse.csr_matrix``."""
        petsc_mat = as_backend_type(matrix).mat()
        # petsc4py yields (indptr, indices, data); scipy expects the reverse order.
        indptr, indices, data = petsc_mat.getValuesCSR()
        # Pass the shape explicitly: scipy would otherwise infer the column
        # count from the largest stored index and silently drop empty
        # trailing columns.
        shape = (len(indptr) - 1, petsc_mat.getSize()[1])
        return sp.csr_matrix((data, indices, indptr), shape=shape)

    def prepare_train_data(self, sample_points=50, noise_std=0.1):
        """
        Sample noisy observations of ``u`` on a uniform tensor-product grid.

        Args:
            sample_points: number of points per axis (default 50, i.e. a
                50x50 grid in 2D).
            noise_std: standard deviation of the additive Gaussian noise
                (default 0.1).

        Sets ``self.train_X`` with shape (sample_points**dim, dim) and
        ``self.train_Y`` = reference ``u`` at those points plus noise.
        """
        bbox = self.config["Bounding Box"]
        dim_s_t = self.config["Spatial Temporal Dimension"]
        axes = [np.linspace(bbox[2 * i], bbox[2 * i + 1], sample_points)
                for i in range(dim_s_t)]
        self.train_X = np.stack(np.meshgrid(*axes), axis=-1).reshape(-1, dim_s_t)
        self.train_Y = self.ref_sol_u(self.train_X)
        # Additive N(0, noise_std**2) observation noise.
        self.train_Y += np.random.normal(0, noise_std, self.train_Y.shape)

    def prepare_boundary_data(self, n_per_side=256):
        """
        Sample the reference coefficient ``a`` on the rectangle's boundary.

        ``n_per_side`` evenly spaced points are placed on each of the four
        edges, in the order top, bottom, left, right. Corner points therefore
        appear twice.

        Sets ``self.train_X_b`` with shape (4 * n_per_side, 2) and
        ``self.train_Y_b`` = a(train_X_b).
        """
        bbox = self.config["Bounding Box"]
        x_min, x_max, y_min, y_max = bbox[0], bbox[1], bbox[2], bbox[3]
        xs = np.linspace(x_min, x_max, n_per_side)
        ys = np.linspace(y_min, y_max, n_per_side)

        top = np.stack([xs, np.full((n_per_side,), y_max)], axis=-1)
        bottom = np.stack([xs, np.full((n_per_side,), y_min)], axis=-1)
        left = np.stack([np.full((n_per_side,), x_min), ys], axis=-1)
        right = np.stack([np.full((n_per_side,), x_max), ys], axis=-1)

        self.train_X_b = np.concatenate([top, bottom, left, right], axis=0)
        self.train_Y_b = self.ref_sol_a(self.train_X_b)

    def init_problem(self):
        """
        Build the P1 Lagrange finite-element setup.

        Creates the rectangle mesh from the config, the function space
        ``self.V``, its dof coordinates ``self.coords``, the interpolated
        source term ``self.f``, and the working functions ``self.u``,
        ``self.a`` and the test function ``self.v``.
        """
        def f_fun(xy):
            # Source term f = -div(a grad u) for the reference (u, a).
            # With D = 1 + x^2 + y^2 + (x-1)^2 + (y-1)^2 and a = 1/D:
            #   grad D = 2 (2x - 1, 2y - 1)
            #   grad a = -a^2 grad D
            #   lap u  = -2 pi^2 u
            #   f = -a lap(u) - grad(a) . grad(u)
            #     = 2 pi^2 a u
            #       + 2 pi a^2 [(2x-1) cos(pi x) sin(pi y)
            #                   + (2y-1) sin(pi x) cos(pi y)]
            x, y = xy[:, 0:1], xy[:, 1:2]
            a = self.ref_sol_a(xy)
            u = np.sin(np.pi * x) * np.sin(np.pi * y)
            grad_term = ((2 * x - 1) * np.cos(np.pi * x) * np.sin(np.pi * y)
                         + (2 * y - 1) * np.sin(np.pi * x) * np.cos(np.pi * y))
            return 2 * np.pi**2 * a * u + 2 * np.pi * a**2 * grad_term

        # Create mesh and define function space
        bbox = self.config["Bounding Box"]
        mesh = RectangleMesh(Point(bbox[0], bbox[2]),
                             Point(bbox[1], bbox[3]),
                             self.config["Grid Size (x-direction)"],
                             self.config["Grid Size (y-direction)"])
        self.V = V = FunctionSpace(mesh, "Lagrange", 1)

        # Dof coordinates; f_fun indexes them as an (n_dofs, 2) array.
        self.coords = V.tabulate_dof_coordinates()

        # f is represented by its nodal values at the P1 dofs.
        self.f = Function(V)
        self.f.vector()[:] = f_fun(self.coords).reshape(-1)

        self.u = Function(V)
        self.a = Function(V)
        self.v = TestFunction(V)

    def update_precond(self):
        """
        Build a diagonal (Jacobi-style) preconditioner ``self.P_inv``.

        Assembles the stiffness matrix K_ij = int a grad(phi_j).grad(phi_i) dx
        with the values currently stored in ``self.a``. It then sets
        ``self.P_inv = diag(1 / K_ii)``.

        Diagonal entries with magnitude below 1e-3 are replaced by 1, so they
        do not produce huge scale factors. Requires ``init_problem`` to have
        run (uses ``self.V``, ``self.a`` and ``self.v``).
        """
        u_trial = TrialFunction(self.V)
        stiffness = self._petsc_to_csr(
            assemble(self.a * inner(grad(u_trial), grad(self.v)) * dx))

        diag = stiffness.diagonal()
        # Guard against (near-)zero diagonal entries, e.g. if self.a is ~0.
        diag[np.abs(diag) < 1e-3] = 1
        self.P_inv = sp.diags(1 / diag)

    def compute_loss_grad(self):
        """
        Evaluate the weak-form residual and its sensitivities w.r.t. u and a.

        Loads ``self.u_data`` and ``self.a_data`` into ``self.u`` and
        ``self.a``. These are presumably dof-ordered nodal values supplied by
        the caller — TODO confirm. It then assembles, for every dof i (no
        boundary conditions are applied),

            r_i = int a grad(u).grad(phi_i) dx - int f phi_i dx

        Returns:
            loss_val: mean(r**2) of the *unpreconditioned* residual.
            grad_u: (dr/du)^T r~ / n, shape (n_dofs, 1).
            grad_a: (dr/da)^T r~ / n, shape (n_dofs, 1).

        Here r~ = r, or r~ = P_inv**2 r when "Use Preconditioner" is set and
        ``update_precond`` has been called. In both cases the returned
        gradients are half the exact gradient of mean(r**2), respectively
        mean((P_inv r)**2). The preconditioned objective differs from
        ``loss_val``.
        """
        self.u.vector()[:] = self.u_data
        self.a.vector()[:] = self.a_data

        # Define the variational residual and its Gateaux derivatives.
        residual_form = (self.a * inner(grad(self.u), grad(self.v)) * dx
                         - self.f * self.v * dx)
        drdu_form = derivative(residual_form, self.u)
        drda_form = derivative(residual_form, self.a)

        residual = assemble(residual_form).get_local().reshape(-1, 1)
        loss_val = np.mean(residual**2)

        drdu = self._petsc_to_csr(assemble(drdu_form))
        drda = self._petsc_to_csr(assemble(drda_form))

        # P_inv only exists after update_precond(); treat "never built" as
        # "no preconditioner" instead of raising AttributeError.
        P_inv = getattr(self, "P_inv", None)
        if self.config["Use Preconditioner"] and P_inv is not None:
            # Gradient of ||P r||^2 is 2 J^T P^T P r; P is diagonal, so P^T P = P^2.
            residual = P_inv.dot(P_inv.dot(residual))

        n = residual.shape[0]
        grad_u = drdu.transpose().dot(residual) / n
        grad_a = drda.transpose().dot(residual) / n

        return loss_val, grad_u, grad_a
