import time
import json
import numpy as np
from dolfin import *
import scipy.sparse as sp
from utils.fenics_utils import generate_complex_mesh3d


class Poisson3DSolver:
    """
    Benchmark a 3D diffusion-reaction ("screened Poisson") solve with FEniCS.

    The problem is the weak form of

        -div(mu * grad(u)) + k**2 * u = f

    on a box with spherical voids, discretized with P1 Lagrange elements.
    No Dirichlet condition is imposed, so the natural (zero-flux) condition
    ``mu * du/dn = 0`` holds on the whole boundary. The coefficients ``mu``
    and ``k`` are piecewise constant and switch value at ``z = 0.5``.

    Attributes set by ``setup``:
        V       -- the P1 function space.
        coords  -- dof coordinates of ``V``, one (x, y, z) row per dof.
        a, L    -- bilinear and linear forms.
        A, b    -- assembled system as a SciPy CSR matrix and an (n, 1) array.
    Attributes set by ``solve``:
        u          -- the solution from the last timed trial.
        time, std  -- mean and standard deviation of the per-trial wall time.
    """

    def __init__(self) -> None:
        self.load_config()
        self.msh_len = self.config["Grid Size (mesh length)"]
        # NOTE(review): this 2D heat-equation coefficient file is loaded but
        # never used by this class; confirm whether anything else reads
        # ``coef_data`` before removing it (a missing file makes __init__ fail).
        self.coef_data = np.loadtxt("data/ref/heat_2d_coef_256.dat")

    def load_config(self) -> None:
        """Read solver settings from the JSON config file into ``self.config``."""
        with open("src/poisson/poisson3d_cg/conf.json", "r", encoding="utf-8") as f:
            self.config = json.load(f)

    def set_mesh_length(self, length) -> None:
        """Override the mesh length passed to the mesh generator in ``setup``."""
        self.msh_len = length

    @staticmethod
    def _piecewise_in_z(coords, values, interface_z) -> np.ndarray:
        """Return ``values[0]`` where z < ``interface_z`` and ``values[1]`` elsewhere.

        ``coords`` has one row per point with z in column 2. The result is a
        float array with one entry per row.
        """
        below = np.asarray(coords)[:, 2] < interface_z
        return np.where(below, float(values[0]), float(values[1]))

    def setup(self) -> None:
        """Build mesh, function space and forms, then assemble the linear system.

        The assembled matrix and right-hand side are also exported as
        ``self.A`` (SciPy CSR) and ``self.b`` (NumPy column) for use outside
        FEniCS.
        """
        # Box domain with spherical voids. Each entry is presumably
        # [cx, cy, cz, radius] -- TODO confirm against generate_complex_mesh3d.
        bbox = self.config["Bounding Box"]
        voids = {
            'sphere': [[0.4, 0.3, 0.6, 0.2],
                       [0.6, 0.7, 0.6, 0.2],
                       [0.2, 0.8, 0.7, 0.1],
                       [0.6, 0.2, 0.3, 0.1]]
        }
        mesh = generate_complex_mesh3d(bbox, voids, self.msh_len)
        self.V = V = FunctionSpace(mesh, "Lagrange", 1)

        # Dof coordinates, used below to assign the piecewise coefficients.
        self.coords = V.tabulate_dof_coordinates()

        u = TrialFunction(V)
        v = TestFunction(V)

        # Problem parameters.
        amplitudes = (20, 100)    # A0, A1 in the source-term expression
        modes = (1, 10, 5)        # m0, m1, m2 wave numbers per axis
        k_values = (8, 10)        # k below / above the interface
        mu_values = (1, 1)        # mu below / above the interface
        interface_z = 0.5

        # Source term f, interpolated into the P1 space.
        f_expr = Expression(
            "A0 * exp(sin(m0 * pi * x[0]) + sin(m1 * pi * x[1]) + sin(m2 * pi * x[2])) * (pow(x[0], 2) + pow(x[1], 2) + pow(x[2], 2) - 1) / (pow(x[0], 2) + pow(x[1], 2) + pow(x[2], 2) + 1)"
            "+ A1 * sin(m0 * pi * x[0]) * sin(m1 * pi * x[1]) * sin(m2 * pi * x[2])",
            degree=2,
            A0=amplitudes[0],
            A1=amplitudes[1],
            m0=modes[0],
            m1=modes[1],
            m2=modes[2],
        )
        f = interpolate(f_expr, V)

        # Piecewise-constant coefficients written straight into the dof
        # vector: row i of self.coords is assumed to be the location of dof i.
        mu = Function(V)
        k = Function(V)
        mu.vector()[:] = self._piecewise_in_z(self.coords, mu_values, interface_z)
        k.vector()[:] = self._piecewise_in_z(self.coords, k_values, interface_z)

        # Weak form; no boundary term => natural zero-flux boundary condition.
        self.a = mu * inner(grad(u), grad(v)) * dx + k**2 * u * v * dx
        self.L = f * v * dx

        A_petsc = assemble(self.a)
        b_vec = assemble(self.L)

        # petsc4py's getValuesCSR() returns (indptr, indices, data) while
        # SciPy expects (data, indices, indptr). Give the shape explicitly so
        # SciPy does not have to infer the column count from the indices.
        petsc_mat = as_backend_type(A_petsc).mat()
        indptr, indices, data = petsc_mat.getValuesCSR()
        n_cols = petsc_mat.getSize()[1]
        self.A = sp.csr_matrix((data, indices, indptr),
                               shape=(len(indptr) - 1, n_cols))
        self.b = b_vec.get_local().reshape(-1, 1)

    def solve(self, n_trials: int = 5) -> None:
        """Solve the problem ``n_trials`` times and report the timing.

        Each trial calls FEniCS ``solve`` on the variational form
        ``self.a == self.L`` with a CG Krylov solver and ILU preconditioner.
        It does not use the pre-assembled ``self.A`` / ``self.b``. The mean
        and standard deviation of the trial times are stored in ``self.time``
        and ``self.std`` and printed. The last solution is kept in ``self.u``.

        Raises:
            ValueError: if ``n_trials`` is less than 1.
        """
        if n_trials < 1:
            raise ValueError(f"n_trials must be >= 1, got {n_trials}")
        ts = []
        for _ in range(n_trials):
            # perf_counter is monotonic and high-resolution, unlike time.time.
            t0 = time.perf_counter()
            u = Function(self.V)
            solve(self.a == self.L, u,
                  solver_parameters={"linear_solver": "cg", "preconditioner": "ilu"})
            ts.append(time.perf_counter() - t0)
        self.u = u
        self.time = np.mean(ts)
        self.std = np.std(ts)
        print("Average time: %.2e s" % self.time)
        print("Std: %.2e s" % self.std)


# Run the benchmark only when executed as a script, so importing this module
# does not trigger mesh generation and five timed solves.
if __name__ == "__main__":
    solver = Poisson3DSolver()
    solver.set_mesh_length(0.05)
    solver.setup()
    solver.solve()
