

import gmsh
from dolfinx.io import gmshio
from mpi4py import MPI
import numpy as np
from dolfinx import fem, io, mesh, plot
from petsc4py.PETSc import ScalarType
from ufl import ds, dx, grad, inner
import ufl
import matplotlib.pyplot as plt
from dolfinx.fem import FunctionSpace
import random
from tqdm import tqdm
import pickle
from dolfinx.io import XDMFFile
from dolfinx.fem.petsc import assemble_matrix, assemble_vector
from dolfinx.fem import apply_lifting, set_bc
from petsc4py import PETSc
from scipy import sparse
from polygenerator import (
    random_polygon,
    random_star_shaped_polygon,
    random_convex_polygon,
)
import random
from itertools import pairwise
import logging
# getLogger() with no name returns the ROOT logger; raising it to ERROR hides
# warning/info output from every logger that does not set its own level
# (they inherit the root's effective level).
logger = logging.getLogger()
logger.setLevel(logging.ERROR)
import sys
# Make the project-local `utils` package (imported below) resolvable. "./" is
# interpreted relative to the current working directory, not to this file, so
# the script is presumably meant to be launched from the repository root.
sys.path.append("./")
import argparse
from utils.polygon import plot_polygon, generate_polygon_mesh, get_counterclockwise_dolphinx_boundary, generate_polygon

# For reproducible datasets, seed Python's `random` module here (e.g. uncomment
# the line below): it drives the polygon sizes, the Dirichlet/Neumann boundary
# splits and the random boundary values generated by this script.
#random.seed(5)
# Start the gmsh API once at import time; generate_solution later selects the
# gmsh model named 'p1' through gmsh.model.
gmsh.initialize()

def _split_boundary(boundary_elems_with_vertices, p):
    """Randomly split the closed boundary loop into a Dirichlet and a Neumann arc.

    ``boundary_elems_with_vertices`` is iterated as ``(facet, v1, v2)`` triples
    forming an ordered closed loop (the ``v2`` of one entry is assumed to be the
    ``v1`` of the next -- TODO confirm against
    get_counterclockwise_dolphinx_boundary). It must be a list: slices are
    concatenated with ``+``.

    With probability ``p`` the whole boundary is Dirichlet. Otherwise a
    contiguous arc of ``r - 1`` facets (``r`` uniform in ``[n // 2, n]``)
    starting at a random facet is Dirichlet, and the remaining facets are Neumann.

    Returns:
        (db_elems, db_index, nb_index): the Dirichlet facet ids, the Dirichlet
        vertex indices and the Neumann vertex indices. For a split, both vertex
        lists include the two end vertices of their arc, so the junction
        vertices appear in both.
    """
    if random.random() < p:
        # Pure Dirichlet problem: every facet is Dirichlet and the v1's already
        # cover every vertex of the closed loop, so no closing vertex is added.
        db_elems = [e for e, _, _ in boundary_elems_with_vertices]
        db_index = [v1 for _, v1, _ in boundary_elems_with_vertices]
        return db_elems, db_index, []

    n_elems = len(boundary_elems_with_vertices)
    # Same two draws, in the same order, as the original implementation.
    start = random.randint(0, n_elems - 1)
    n_dirichlet = random.randint(n_elems // 2, n_elems) - 1
    # For tiny loops (n <= 3) the draw can yield an empty Dirichlet arc, which
    # previously left the arc variables unbound/stale. A pure-Neumann Laplace
    # problem is singular, so keep at least one Dirichlet facet.
    n_dirichlet = max(n_dirichlet, 1)

    # Rotate so the Dirichlet arc starts at index 0; this replaces the separate
    # wrap-around / no-wrap-around slicing branches.
    rotated = boundary_elems_with_vertices[start:] + boundary_elems_with_vertices[:start]
    db_part = rotated[:n_dirichlet]
    nb_part = rotated[n_dirichlet:]

    db_elems = [e for e, _, _ in db_part]
    # Append the end vertex of the last facet so the arc's closing vertex is included.
    db_index = [v1 for _, v1, _ in db_part] + [db_part[-1][2]]
    nb_index = ([v1 for _, v1, _ in nb_part] + [nb_part[-1][2]]) if nb_part else []
    return db_elems, db_index, nb_index


def generate_solution(batch_pp, p=0.2):
    """Solve ``batch_pp`` random Laplace problems on the gmsh model ``'p1'``.

    The current gmsh model ``'p1'`` is converted to a dolfinx mesh once. For
    every sample, the boundary facets of physical group 1 are split into a
    Dirichlet part and (optionally) a Neumann part with random nodal values.
    The problem with weak form ``(grad u, grad v) = (f, v) - (g, v)_ds``,
    ``f = 0``, is then solved with P1 Lagrange elements and a direct LU solve.

    Args:
        batch_pp: number of samples to generate on this mesh.
        p: probability that a sample is a pure-Dirichlet problem; with
            probability ``1 - p`` it is a mixed Dirichlet/Neumann problem.
            The default 0.2 matches the value that used to be hard-coded
            inside the loop (which silently ignored this argument).

    Returns:
        (sols, bcs): two lists of length ``batch_pp``.
        ``sols[k]`` has one row ``[x - 0.5, y - 0.5, u_h]`` per mesh geometry node.
        ``bcs[k]`` stacks Dirichlet rows ``[x - 0.5, y - 0.5, u_D, 0]`` followed
        by Neumann rows ``[x - 0.5, y - 0.5, g, 1]``.
    """
    gmsh.model.setCurrent('p1')
    mesh_comm = MPI.COMM_WORLD
    gmsh_model_rank = 0

    # Convert the gmsh model to a dolfinx mesh.
    domain, _cell_markers, facet_markers = gmshio.model_to_mesh(
        gmsh.model, mesh_comm, gmsh_model_rank, gdim=2
    )
    V = fem.FunctionSpace(domain, ("Lagrange", 1))

    # Boundary facets (physical group 1), ordered along the boundary loop.
    boundary_elems = facet_markers.find(1).tolist()
    boundary_elems_with_vertices, _boundary_index, _boundary_points = (
        get_counterclockwise_dolphinx_boundary(domain, boundary_elems)
    )

    coords = domain.geometry.x
    space_shift = np.array([0.5, 0.5])  # shifts the coordinates stored in the outputs
    petsc_options = {"ksp_type": "preonly", "pc_type": "lu"}

    sols, bcs = [], []
    for _ in range(batch_pp):
        db_elems, db_index, nb_index = _split_boundary(boundary_elems_with_vertices, p)
        db_points = coords[db_index]
        nb_points = coords[nb_index]

        boundary_dofs = fem.locate_dofs_topological(V, 1, db_elems)
        uD = fem.Function(V)
        dirichlet_bc = fem.dirichletbc(uD, boundary_dofs)
        g = fem.Function(V)
        u = ufl.TrialFunction(V)
        v = ufl.TestFunction(V)
        f = fem.Constant(domain, ScalarType(0))
        a = inner(grad(u), grad(v)) * dx
        # g is integrated over the whole boundary; it is only non-zero at
        # Neumann vertices, and Dirichlet rows are overwritten by the BC anyway.
        L = inner(f, v) * dx - g * v * ds

        # uD and g are filled in below. They are referenced by the forms, so
        # their values are picked up when solve() assembles the system.
        problem = fem.petsc.LinearProblem(a, L, bcs=[dirichlet_bc], petsc_options=petsc_options)

        # NOTE(review): geometry node indices are used directly as DOF indices
        # into uD/g/uh below; this relies on the P1 DOF ordering matching the
        # geometry ordering. TODO confirm (V.tabulate_dof_coordinates() would
        # make the mapping explicit).
        if nb_index:
            # Mixed problem: scale down one of the two boundary data sets.
            random_range = random.uniform(0.1, 1.0)
            if random.random() >= 0.5:
                uD.x.array[db_index] = [random.random() * random_range for _ in db_index]
                g.x.array[nb_index] = [random.random() for _ in nb_index]
            else:
                uD.x.array[db_index] = [random.random() for _ in db_index]
                g.x.array[nb_index] = [random.random() * random_range for _ in nb_index]
        else:
            uD.x.array[db_index] = [random.random() for _ in db_index]
        uh = problem.solve()

        sol = np.concatenate(
            [coords[:, [0, 1]] - space_shift, uh.x.array[..., np.newaxis]], axis=1
        )
        # The last column flags the boundary type: 0 = Dirichlet, 1 = Neumann.
        dbc = np.concatenate(
            [db_points[:, [0, 1]] - space_shift,
             uD.x.array[db_index, np.newaxis],
             np.zeros((len(db_index), 1))],
            axis=1,
        )
        nbc = np.concatenate(
            [nb_points[:, [0, 1]] - space_shift,
             g.x.array[nb_index, np.newaxis],
             np.ones((len(nb_index), 1))],
            axis=1,
        )
        sols.append(sol)
        bcs.append(np.concatenate([dbc, nbc], axis=0))
    return sols, bcs

if __name__ == "__main__":
    # Script entry: mesh random simple polygons, solve random Laplace problems
    # on each, and pickle the list of (solution, [boundary_conditions]) pairs.
    from pathlib import Path

    parser = argparse.ArgumentParser(description='training data generation')
    # Required: the value is formatted with ':03' into the output file name.
    # With a missing value that formatting raised TypeError, but only after
    # every mesh had been generated and solved.
    parser.add_argument('--count', type=int, required=True,
                        help='shard index, zero-padded to 3 digits in the output file name')
    parser.add_argument('--seed', type=int, default=None,
                        help="seed for Python's random module (default: unseeded)")
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)

    num_meshes = 10  # random polygons per output file
    num_batch = 20   # solutions generated on each polygon mesh

    datalist = []
    try:
        for _ in tqdm(range(num_meshes)):
            # Randomly choose the number of polygon vertices (3..11).
            num_points = random.randrange(3, 12)
            generate_polygon_mesh(num_points, 'simple')
            sols, bcs = generate_solution(num_batch)
            datalist += [(sol, [bc]) for sol, bc in zip(sols, bcs)]
    finally:
        # Pair the gmsh.initialize() done at import time.
        gmsh.finalize()

    out_path = Path("data/2d") / f"laplace2d_n_simple_4000_test_{args.count:03}.pkl"
    # Create the output directory up front so the final write cannot fail on it.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, 'wb') as file:
        pickle.dump(datalist, file)