import gmsh
from dolfinx.io import gmshio
from mpi4py import MPI
import numpy as np
from dolfinx import fem, io, mesh, plot
from petsc4py.PETSc import ScalarType
from ufl import ds, dx, grad, inner
import ufl
import matplotlib.pyplot as plt
from dolfinx.fem import FunctionSpace
import random
from tqdm import tqdm
import pickle
from dolfinx.io import XDMFFile
from dolfinx.fem.petsc import assemble_matrix, assemble_vector
from dolfinx.fem import apply_lifting, set_bc
from dolfinx.fem.petsc import NonlinearProblem
from dolfinx.nls.petsc import NewtonSolver
from dolfinx import mesh, fem, io, nls, log
from petsc4py import PETSc
from scipy import sparse
from polygenerator import (
    random_polygon,
    random_star_shaped_polygon,
    random_convex_polygon,
)
import random
from itertools import pairwise
import logging
# Set the root logger's level to ERROR. Loggers that do not define a level of
# their own inherit this threshold, so their INFO/WARNING records are dropped.
logger = logging.getLogger(name=None)
logger.setLevel("ERROR")
import sys
# Append the current working directory to the import path so project-local
# packages (e.g. ``utils``) can be imported below.
sys.path.extend(["./"])
from utils.polygon import plot_polygon, generate_polygon_mesh, get_counterclockwise_dolphinx_boundary
from dolfinx.fem.petsc import assemble_vector, assemble_matrix, create_vector, apply_lifting, set_bc

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.animation import FuncAnimation
import argparse
# Uncomment the seed below to make generated datasets reproducible: the random
# initial and boundary values in generate_solution are drawn from the stdlib
# `random` module.
#random.seed(5)
# Initialise the gmsh API once at import time; generate_solution() calls
# gmsh.open and reads gmsh.model, which require an initialised gmsh session.
gmsh.initialize()

# Mesh file and the physical-group tags of the Dirichlet boundary facets for
# each supported domain type.
_DOMAIN_CONFIG = {
    "A": ("./data/mesh/A-schwarz.msh", (21, 22)),
    "B": ("./data/mesh/B-holes.msh", (11, 12, 13)),
    "C": ("./data/mesh/C-bosch.msh", (12, 13)),
}


def generate_solution(domain_type, num_samples=100):
    """Simulate random heat-equation trajectories on a gmsh domain.

    Loads the mesh for ``domain_type`` and, for each sample, draws a random
    initial condition (``random.random()`` at every dof) and random Dirichlet
    data (``random.uniform(0.5, 1.0)`` at every boundary dof). It then advances
    the backward-Euler discretisation of ``u_t = div(alpha * grad(u)) + f``
    with ``alpha = 1`` and ``f = 0`` for ``num_steps`` steps of size
    ``T / num_steps``.

    Args:
        domain_type: One of ``"A"``, ``"B"`` or ``"C"`` (see ``_DOMAIN_CONFIG``).
        num_samples: Number of independent trajectories to generate.

    Returns:
        A list with one ``(series, alpha, [bc_series])`` tuple per sample.

        ``series`` has one row per entry of the solution array. Its columns
        are the dof's x and y coordinates, followed by the solution at time
        levels ``0 .. num_steps - 1``. The state after the final solve is
        not stored.

        ``alpha`` is the (constant) diffusivity value.

        ``bc_series`` has one row per boundary dof. Its columns are the
        dof's x and y coordinates, the initial condition's boundary values,
        the imposed Dirichlet values repeated ``num_steps - 1`` times, and a
        final column of zeros.

    Raises:
        ValueError: If ``domain_type`` is not a known domain.
    """
    try:
        mesh_path, boundary_tags = _DOMAIN_CONFIG[domain_type]
    except KeyError:
        raise ValueError(
            f"unknown domain type {domain_type!r}; "
            f"expected one of {sorted(_DOMAIN_CONFIG)}"
        ) from None
    gmsh.open(mesh_path)

    mesh_comm = MPI.COMM_WORLD
    gmsh_model_rank = 0
    # Convert the currently loaded gmsh model into a dolfinx mesh.
    domain, _cell_markers, facet_markers = gmshio.model_to_mesh(
        gmsh.model, mesh_comm, gmsh_model_rank, gdim=2
    )

    V = fem.FunctionSpace(domain, ("Lagrange", 1))
    fdim = domain.topology.dim - 1

    # Facets carrying any of this domain's boundary physical-group tags.
    boundary_elems = np.concatenate(
        [facet_markers.find(tag) for tag in boundary_tags]
    ).astype(np.int32)

    T = 0.5  # final time
    num_steps = 50
    dt = T / num_steps  # time step size

    # All per-dof arrays below (uh.x.array, uD.x.array, ...) are indexed by
    # *dof* number. In dolfinx this is not guaranteed to coincide with
    # topology-vertex or geometry-node numbering, so both the boundary indices
    # and the coordinates are taken from the function space's dof layout.
    boundary_dofs = fem.locate_dofs_topological(V, fdim, boundary_elems)
    dof_coords = V.tabulate_dof_coordinates()
    boundary_points = dof_coords[boundary_dofs]

    u_n = fem.Function(V)  # solution at the previous time step
    uD = fem.Function(V)  # Dirichlet data (only boundary dofs are used)
    bc = fem.dirichletbc(uD, boundary_dofs)
    uh = fem.Function(V)  # solution at the current time step
    uh.name = "uh"

    u = ufl.TrialFunction(V)
    v = ufl.TestFunction(V)
    f = fem.Constant(domain, ScalarType(0))
    alpha = fem.Function(V)
    alpha.x.array[:] = 1.0
    a = u * v * ufl.dx + alpha * dt * ufl.dot(ufl.grad(u), ufl.grad(v)) * ufl.dx
    L = (u_n + dt * f) * v * ufl.dx

    bilinear_form = fem.form(a)
    linear_form = fem.form(L)
    b = create_vector(linear_form)

    # alpha and dt are the same for every sample, and the BC only affects
    # which rows/columns are replaced (not their values). The matrix and its
    # LU factorisation can therefore be built once and reused. If alpha is
    # ever varied per sample, move this into the sample loop.
    A = assemble_matrix(bilinear_form, bcs=[bc])
    A.assemble()
    solver = PETSc.KSP().create(domain.comm)
    solver.setOperators(A)
    solver.setType(PETSc.KSP.Type.PREONLY)
    solver.getPC().setType(PETSc.PC.Type.LU)

    datalist = []
    for _sample in tqdm(range(num_samples)):
        # Random initial condition at every dof. Its boundary values are
        # recorded before the Dirichlet data for this sample is drawn.
        u_n.x.array[:] = [random.random() for _ in range(u_n.x.array.shape[0])]
        uh.x.array[:] = u_n.x.array
        series = []
        bc_series = [uh.x.array[boundary_dofs].copy()]
        uD.x.array[boundary_dofs] = [random.uniform(0.5, 1.0) for _ in boundary_dofs]

        for _step in range(num_steps):
            series.append(uh.x.array.copy())
            # Fancy indexing already returns a copy.
            bc_series.append(uD.x.array[boundary_dofs])

            # Reassemble the right-hand side into the preallocated vector.
            with b.localForm() as loc_b:
                loc_b.set(0)
            assemble_vector(b, linear_form)

            # Lift the Dirichlet data into the RHS, then impose it.
            apply_lifting(b, [bilinear_form], [[bc]])
            b.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)
            set_bc(b, [bc])

            solver.solve(b, uh.vector)
            uh.x.scatter_forward()

            u_n.x.array[:] = uh.x.array

        # Columns: x, y, then the solution at time levels 0 .. num_steps - 1.
        series = np.concatenate(
            [dof_coords[:, :2], np.stack(series, axis=1)], axis=1
        )
        # Columns: x, y, the initial boundary values, uD repeated
        # num_steps - 1 times (the last recording is dropped), then zeros.
        bc_series = np.concatenate(
            [
                boundary_points[:, :2],
                np.stack(bc_series[:-1], axis=1),
                np.zeros((boundary_points.shape[0], 1)),
            ],
            axis=1,
        )

        datalist.append((series, alpha.x.array[0], [bc_series]))
    return datalist

if __name__ == "__main__":
    from pathlib import Path

    parser = argparse.ArgumentParser(description='training data generation')
    # Required and restricted to the known domains so a missing or mistyped
    # value fails immediately with a usage message instead of deep in meshing.
    parser.add_argument('--domain-type', type=str, required=True, choices=["A", "B", "C"])
    args = parser.parse_args()

    output_path = Path(f"data/2d/laplace2d_{args.domain_type}_100_test" + ".pkl")
    # Create the output directory up front so a long generation run does not
    # fail only when the result is written.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        datalist = generate_solution(args.domain_type)
    finally:
        # gmsh is initialised at import time; release it even on failure.
        gmsh.finalize()

    with output_path.open('wb') as file:
        pickle.dump(datalist, file)

