import gmsh
from dolfinx.io import gmshio
from mpi4py import MPI
import numpy as np
from dolfinx import fem, io, mesh, plot
from petsc4py.PETSc import ScalarType
from ufl import ds, dx, grad, inner
import ufl
import matplotlib.pyplot as plt
from dolfinx.fem import FunctionSpace
import random
from tqdm import tqdm
import pickle
from dolfinx.io import XDMFFile
from dolfinx.fem.petsc import assemble_matrix, assemble_vector
from dolfinx.fem import apply_lifting, set_bc
from petsc4py import PETSc
from scipy import sparse
from polygenerator import (
    random_polygon,
    random_star_shaped_polygon,
    random_convex_polygon,
)
import random
from itertools import pairwise
import logging
logger = logging.getLogger()
logger.setLevel(logging.ERROR)
import sys
sys.path.append("./")
from utils.polygon import plot_polygon, generate_polygon_mesh, get_counterclockwise_dolphinx_boundary
import plotly.express as px
import plotly.offline as pyo
import plotly.io as pio
pio.renderers.default = 'iframe'
import plotly.graph_objs as go

import pickle
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import LinearNDInterpolator
import scipy.interpolate as interp
import scipy
from shapely.geometry import Point
from shapely.geometry.polygon import Polygon
import argparse

# Boundary data is drawn from Python's global `random` module; calling e.g.
# random.seed(5) here makes the generated dataset reproducible across runs.
# The gmsh API is initialised once for the whole script, before any mesh file
# is loaded with gmsh.open().
gmsh.initialize()

# Per supported domain type: the gmsh mesh file, then the physical-group tags
# of the Dirichlet facets, then the physical-group tags of the Neumann facets.
_DOMAINS = {
    "A": ("./data/mesh/A-schwarz.msh", (21,), (22,)),
    "B": ("./data/mesh/B-holes.msh", (12, 13), (11,)),
    "C": ("./data/mesh/C-bosch.msh", (12,), (13,)),
}


def _facets_with_tags(facet_markers, tags):
    """Return the int32 indices of all facets marked with any tag in ``tags``."""
    return np.concatenate([facet_markers.find(tag) for tag in tags]).astype(np.int32)


def generate_solution(domain_type, n_samples=100):
    """Solve the 2D Laplace problem for random boundary data on a gmsh mesh.

    The mesh for ``domain_type`` is loaded into the current gmsh session and
    converted to a dolfinx mesh. The weak form
    ``inner(grad(u), grad(v)) * dx == 0 * v * dx - g * v * ds`` is then solved
    with P1 Lagrange elements ``n_samples`` times. Each solve draws fresh
    uniform [0, 1) values from ``random`` for the Dirichlet DOFs (``uD``) and
    the Neumann DOFs (``g``).

    Args:
        domain_type: one of ``"A"``, ``"B"`` or ``"C"`` (keys of ``_DOMAINS``).
        n_samples: number of random boundary-data samples to solve.

    Returns:
        A list of ``(sol, [bc])`` tuples, one per sample.

        - ``sol`` has one row per DOF with columns ``(x, y, u)``.
        - ``bc`` stacks the Dirichlet rows ``(x, y, uD, 0)`` followed by the
          Neumann rows ``(x, y, g, 1)``. A DOF lying on both a Dirichlet and
          a Neumann facet appears in both parts.

    Raises:
        ValueError: if ``domain_type`` is not a known domain type.
    """
    try:
        mesh_path, dirichlet_tags, neumann_tags = _DOMAINS[domain_type]
    except KeyError:
        raise ValueError(
            f"unknown domain_type {domain_type!r}; expected one of {sorted(_DOMAINS)}"
        ) from None

    gmsh.open(mesh_path)

    mesh_comm = MPI.COMM_WORLD
    gmsh_model_rank = 0
    # Convert the gmsh model to a dolfinx mesh (cell markers are not needed).
    domain, _cell_markers, facet_markers = gmshio.model_to_mesh(
        gmsh.model, mesh_comm, gmsh_model_rank, gdim=2
    )

    V = fem.FunctionSpace(domain, ("Lagrange", 1))
    fdim = domain.topology.dim - 1

    db_facets = _facets_with_tags(facet_markers, dirichlet_tags)
    nb_facets = _facets_with_tags(facet_markers, neumann_tags)

    # dolfinx numbers topology vertices, geometry nodes and DOFs independently.
    # Index everything by DOF and take positions from the DOF coordinates, so
    # values and coordinates always refer to the same point.
    db_dofs = np.unique(fem.locate_dofs_topological(V, fdim, db_facets))
    nb_dofs = np.unique(fem.locate_dofs_topological(V, fdim, nb_facets))
    dof_xy = V.tabulate_dof_coordinates()[:, :2]
    db_points = dof_xy[db_dofs]
    nb_points = dof_xy[nb_dofs]

    uD = fem.Function(V)
    g = fem.Function(V)
    dirichlet_bc = fem.dirichletbc(uD, db_dofs)

    u = ufl.TrialFunction(V)
    v = ufl.TestFunction(V)
    f = fem.Constant(domain, ScalarType(0))
    a = inner(grad(u), grad(v)) * dx
    L = inner(f, v) * dx - g * v * ds

    # uD and g are updated in place each iteration, so one problem object is
    # reused for every sample.
    problem = fem.petsc.LinearProblem(
        a, L, bcs=[dirichlet_bc], petsc_options={"ksp_type": "preonly", "pc_type": "lu"}
    )

    datalist = []
    for _ in tqdm(range(n_samples)):
        uD.x.array[db_dofs] = [random.random() for _ in range(len(db_dofs))]
        g.x.array[nb_dofs] = [random.random() for _ in range(len(nb_dofs))]
        uh = problem.solve()

        sol = np.concatenate([dof_xy, uh.x.array[:, np.newaxis]], axis=1)
        dbc = np.concatenate(
            [db_points, uD.x.array[db_dofs, np.newaxis], np.zeros((len(db_dofs), 1))],
            axis=1,
        )
        nbc = np.concatenate(
            [nb_points, g.x.array[nb_dofs, np.newaxis], np.ones((len(nb_dofs), 1))],
            axis=1,
        )
        boundary_data = np.concatenate([dbc, nbc], axis=0)

        datalist.append((sol, [boundary_data]))
    return datalist

if __name__ == "__main__":
    # Command-line entry point: solve 100 samples on the chosen domain and
    # pickle them to data/2d/laplace2d_<domain>_100_test.pkl.
    parser = argparse.ArgumentParser(description='training data generation')
    parser.add_argument(
        '--domain-type',
        type=str,
        required=True,
        choices=["A", "B", "C"],
        help="mesh to solve on: A (schwarz), B (holes) or C (bosch)",
    )
    args = parser.parse_args()

    try:
        datalist = generate_solution(args.domain_type)
    finally:
        # Release the gmsh session opened by gmsh.initialize() at import time;
        # the dolfinx mesh and numpy results do not depend on it afterwards.
        gmsh.finalize()

    with open(f"data/2d/laplace2d_{args.domain_type}_100_test.pkl", 'wb') as file:
        pickle.dump(datalist, file)

