

import gmsh
from dolfinx.io import gmshio
from mpi4py import MPI
import numpy as np
from dolfinx import fem, io, mesh, plot
from petsc4py.PETSc import ScalarType
from ufl import ds, dx, grad, inner
import ufl
import matplotlib.pyplot as plt
from dolfinx.fem import FunctionSpace
import random
from tqdm import tqdm
import pickle
from dolfinx.io import XDMFFile
from dolfinx.fem.petsc import assemble_matrix, assemble_vector
from dolfinx.fem import apply_lifting, set_bc
from petsc4py import PETSc
from scipy import sparse
from polygenerator import (
    random_polygon,
    random_star_shaped_polygon,
    random_convex_polygon,
)
import random
from itertools import pairwise
import logging
logger = logging.getLogger()
logger.setLevel(logging.ERROR)
import sys
sys.path.append("./")
import argparse
from utils.polygon import plot_polygon, generate_polygon_mesh, get_counterclockwise_dolphinx_boundary, generate_polygon

# Reproducibility: every random draw made in this file goes through Python's
# `random` module, so uncommenting the seed below makes those draws repeatable
# (whether the polygon helpers in utils.polygon are covered by the same seed is
# unverified -- check before relying on bit-identical datasets).
#random.seed(5)
# gmsh keeps global state; it is initialized once at import time because
# generate_solution() reads the current gmsh model through gmsh.model.
gmsh.initialize()

def generate_solution(batch_pp):
    """Solve ``batch_pp`` random Darcy-type problems on the current gmsh polygon.

    The gmsh model ``'p1'`` (built beforehand, presumably by
    ``generate_polygon_mesh``) is converted to a dolfinx mesh, and the weak problem

        find u with u = uD on the boundary such that
        integral(q * grad(u) . grad(v)) = integral(f * v)  for all test functions v,

    i.e. ``-div(q grad u) = f``, is solved with P1 Lagrange elements and a direct
    LU solve. The ``LinearProblem`` is built once and re-solved for every sample;
    each sample redraws the nodal values of ``q`` and ``f`` and the boundary
    values of ``uD`` using Python's ``random`` module.

    Args:
        batch_pp: number of solutions to generate on this mesh.

    Returns:
        ``(sols, qfs, bcs)``: three lists holding ``batch_pp`` numpy arrays each.

        * ``sols[i]`` -- columns ``x, y, u`` (one row per mesh node),
        * ``qfs[i]``  -- columns ``x, y, q, f`` (one row per mesh node),
        * ``bcs[i]``  -- columns ``x, y, uD, 0`` (one row per boundary point).

        All coordinates are shifted by ``(-0.5, -0.5)``.
    """
    gmsh.model.setCurrent('p1')
    mesh_comm = MPI.COMM_WORLD
    gmsh_model_rank = 0  # MPI rank that holds the gmsh model
    model = gmsh.model

    # Convert the gmsh model to a dolfinx mesh (cell markers are not needed).
    domain, _, facet_markers = gmshio.model_to_mesh(
        model, mesh_comm, gmsh_model_rank, gdim=2
    )
    V = fem.FunctionSpace(domain, ("Lagrange", 1))

    # Boundary facets are the ones tagged with physical group 1.
    boundary_elems = facet_markers.find(1)
    _, boundary_index, boundary_points = get_counterclockwise_dolphinx_boundary(
        domain, boundary_elems.tolist()
    )
    # Normalize once: a list is a safe 1-D fancy index for numpy (a tuple would
    # be interpreted as a multi-dimensional index).
    # NOTE(review): boundary_index presumably holds node indices of the boundary
    # points in counter-clockwise order -- verify against utils.polygon.
    boundary_index = list(boundary_index)

    boundary_dofs = fem.locate_dofs_topological(V, 1, boundary_elems)
    uD = fem.Function(V)
    bc = fem.dirichletbc(uD, boundary_dofs)

    u = ufl.TrialFunction(V)
    v = ufl.TestFunction(V)
    q = fem.Function(V)  # coefficient multiplying grad(u) (Darcy permeability)
    f = fem.Function(V)  # right-hand-side source term
    a = q * inner(grad(u), grad(v)) * dx
    L = inner(f, v) * dx
    # q, f and uD are Functions, so overwriting their arrays in the loop below
    # changes the problem without rebuilding it.
    problem = fem.petsc.LinearProblem(
        a, L, bcs=[bc], petsc_options={"ksp_type": "preonly", "pc_type": "lu"}
    )

    # Loop-invariant parts of the output arrays.
    space_shift = np.array([0.5, 0.5])
    # NOTE(review): uh/q/f arrays are DOF-ordered while domain.geometry.x is
    # node-ordered, and uD is written at boundary_index but the BC is applied at
    # boundary_dofs. Pairing them assumes the two orderings coincide for P1 on this
    # mesh; dolfinx does not guarantee that in general
    # (V.tabulate_dof_coordinates() gives DOF-ordered coordinates) -- confirm.
    node_xy = domain.geometry.x[:, [0, 1]] - space_shift
    boundary_xy = boundary_points[:, [0, 1]] - space_shift
    boundary_zeros = np.zeros((len(boundary_index), 1))

    sols, qfs, bcs = [], [], []
    for _ in range(batch_pp):
        # For Darcy flow uD is kept in a bounded range: each sample scales its
        # boundary values (uniform in [0, 1)) by a random factor in [0.3, 1.0].
        random_range = random.uniform(0.3, 1.0)
        uD.x.array[boundary_index] = [random.random() * random_range for _ in boundary_index]

        f.x.array[:] = [random.random() for _ in range(f.x.array.shape[0])]
        # NOTE(review): q is drawn from [0, 1), so it can come arbitrarily close
        # to 0 and make the system ill-conditioned.
        q.x.array[:] = [random.random() for _ in range(q.x.array.shape[0])]
        uh = problem.solve()

        # np.concatenate copies, so later solves do not alter stored samples.
        sols.append(np.concatenate([node_xy, uh.x.array[:, np.newaxis]], axis=1))
        qfs.append(np.concatenate(
            [node_xy, q.x.array[:, np.newaxis], f.x.array[:, np.newaxis]], axis=1
        ))
        bcs.append(np.concatenate(
            [boundary_xy, uD.x.array[boundary_index, np.newaxis], boundary_zeros], axis=1
        ))
    return sols, qfs, bcs
if __name__ == "__main__":
    # Script entry point: build random simple polygons, solve a batch of random
    # Darcy problems on each, and pickle the list of (sol, [qf, bc]) samples.
    from pathlib import Path

    parser = argparse.ArgumentParser(description='training data generation')
    # --count only tags the output file name. It is required so that a missing
    # value fails here, instead of raising TypeError in the f-string below
    # *after* every mesh has been solved.
    parser.add_argument('--count', type=int, required=True,
                        help='index used in the output file name (zero-padded to 3 digits)')
    parser.add_argument('--seed', type=int, default=None,
                        help="optional seed for Python's random module")
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)

    num_polygons = 250  # random polygons (meshes) to generate
    num_batch = 10      # solutions generated per polygon

    # Create the output directory before the expensive loop so a bad path fails fast.
    out_dir = Path("data/2d")
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"darcy2d_simple_{num_polygons * num_batch}_test_{args.count:03}.pkl"

    datalist = []
    try:
        for _ in tqdm(range(num_polygons)):
            # Random vertex count in [3, 15].
            num_points = random.randrange(3, 16)
            generate_polygon_mesh(num_points, 'simple')
            sols, qfs, bcs = generate_solution(num_batch)
            datalist.extend((sol, [qf, bc]) for sol, qf, bc in zip(sols, qfs, bcs))
    finally:
        # Release the gmsh session opened by gmsh.initialize() at import time.
        gmsh.finalize()

    with open(out_path, 'wb') as file:
        pickle.dump(datalist, file)