

import gmsh
from dolfinx.io import gmshio
from mpi4py import MPI
import numpy as np
from dolfinx import fem, io, mesh, plot
from petsc4py.PETSc import ScalarType
from ufl import ds, dx, grad, inner
import ufl
import matplotlib.pyplot as plt
from dolfinx.fem import FunctionSpace
import random
from tqdm import tqdm
import pickle
from dolfinx.io import XDMFFile
from dolfinx.fem.petsc import assemble_matrix, assemble_vector
from dolfinx.fem import apply_lifting, set_bc
from petsc4py import PETSc
from scipy import sparse
from polygenerator import (
    random_polygon,
    random_star_shaped_polygon,
    random_convex_polygon,
)
import random
from itertools import pairwise
import logging
logger = logging.getLogger()
logger.setLevel(logging.ERROR)
import sys
sys.path.append("./")
import argparse
from utils.polygon import plot_polygon, generate_polygon_mesh, get_counterclockwise_dolphinx_boundary, generate_polygon
from dolfinx.fem.petsc import assemble_vector, assemble_matrix, create_vector, apply_lifting, set_bc
# Uncomment to make the Python `random`-module draws (polygon vertex counts,
# diffusivities, initial and boundary values) reproducible across runs.
#random.seed(5)
# Start the gmsh API once for the whole script; generate_solution() later
# selects the gmsh model named 'p1' (presumably built by generate_polygon_mesh
# — TODO confirm in utils.polygon).
gmsh.initialize()

def generate_solution(batch_pp, final_time=0.1, num_steps=10):
    """Simulate ``batch_pp`` random 2D heat-equation trajectories on the current gmsh polygon.

    The gmsh model named ``'p1'`` is selected and converted to a dolfinx mesh
    (it is assumed to have been built beforehand, e.g. by
    ``generate_polygon_mesh`` -- TODO confirm). On P1 Lagrange elements the
    equation ``du/dt = div(alpha * grad(u)) + f`` with ``f = 0`` is integrated
    with implicit Euler. Each trajectory draws:

    * a spatially constant diffusivity ``alpha ~ U(0.1, 1.0)``,
    * a uniform random initial condition in ``[0, 1)`` at every dof,
    * fresh uniform random Dirichlet values in ``[0, 1)`` on the boundary
      (facets of physical group 1) at every time step.

    Args:
        batch_pp: Number of trajectories to generate on this mesh.
        final_time: Simulated end time (default 0.1, the original value).
        num_steps: Number of implicit Euler steps (default 10, the original value).

    Returns:
        ``(sols, alphas, bcs)``, three lists of length ``batch_pp``:

        * ``sols[k]``: array of shape ``(num_nodes, 2 + num_steps)``. The first
          two columns are the node x/y coordinates shifted by ``-0.5``. They are
          followed by the solution at times ``0, dt, ..., (num_steps - 1) * dt``.
          The result of the last solve is not recorded.
        * ``alphas[k]``: the scalar diffusivity used for trajectory ``k``.
        * ``bcs[k]``: array of shape ``(num_boundary_nodes, 2 + num_steps + 1)``.
          It holds the counterclockwise boundary coordinates shifted by
          ``-0.5``, then the boundary values belonging to each recorded
          solution column, then a trailing column of zeros.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.

    NOTE(review): the code indexes dof arrays (``uh.x.array``) with geometry
    node indices (``boundary_index``, ``domain.geometry.x``). This assumes the
    P1 dof numbering coincides with the geometry node numbering, which holds
    for typical serial first-order meshes -- confirm for your dolfinx version.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    dt = final_time / num_steps  # time step size

    gmsh.model.setCurrent('p1')
    mesh_comm = MPI.COMM_WORLD
    gmsh_model_rank = 0

    # Convert the gmsh model to a dolfinx mesh.
    domain, _, facet_markers = gmshio.model_to_mesh(gmsh.model, mesh_comm, gmsh_model_rank, gdim=2)
    V = fem.FunctionSpace(domain, ("Lagrange", 1))

    # Facets tagged with physical group 1 form the Dirichlet boundary.
    boundary_elems = facet_markers.find(1).tolist()
    _, boundary_index, boundary_points = get_counterclockwise_dolphinx_boundary(domain, boundary_elems)
    boundary_ids = list(boundary_index)  # hoisted: used for indexing on every step

    u_n = fem.Function(V)    # solution at the previous time step
    uh = fem.Function(V)     # solution at the current time step
    uh.name = "uh"
    uD = fem.Function(V)     # Dirichlet data, redrawn every step
    alpha = fem.Function(V)  # diffusivity, constant in space, redrawn per trajectory

    boundary_dofs = fem.locate_dofs_topological(V, 1, boundary_elems)
    bc = fem.dirichletbc(uD, boundary_dofs)

    u = ufl.TrialFunction(V)
    v = ufl.TestFunction(V)
    f = fem.Constant(domain, ScalarType(0))

    # Implicit Euler: (u - u_n) / dt = div(alpha grad u) + f
    a = u * v * ufl.dx + alpha * dt * ufl.dot(ufl.grad(u), ufl.grad(v)) * ufl.dx
    L = (u_n + dt * f) * v * ufl.dx

    # The compiled forms reference the Function objects (not their current
    # values), so they are compiled once. The RHS vector layout is also
    # independent of the trajectory.
    bilinear_form = fem.form(a)
    linear_form = fem.form(L)
    b = create_vector(linear_form)

    space_shift = np.array([0.5, 0.5])
    node_coords = domain.geometry.x[:, [0, 1]] - space_shift
    boundary_coords = boundary_points[:, [0, 1]] - space_shift
    n_dofs = u_n.x.array.shape[0]

    sols, alphas, bcs = [], [], []
    try:
        for _ in range(batch_pp):
            alpha.x.array[:] = random.uniform(0.1, 1.0)

            # The matrix depends on alpha, so it is reassembled (and refactorized)
            # for every trajectory. The PETSc objects are destroyed explicitly so
            # they do not pile up across trajectories.
            A = assemble_matrix(bilinear_form, bcs=[bc])
            A.assemble()
            solver = PETSc.KSP().create(domain.comm)
            try:
                solver.setOperators(A)
                solver.setType(PETSc.KSP.Type.PREONLY)
                solver.getPC().setType(PETSc.PC.Type.LU)

                # Random initial condition.
                u_n.x.array[:] = [random.random() for _ in range(n_dofs)]
                uh.x.array[:] = u_n.x.array

                snapshots = []
                # The first boundary column matches the initial condition.
                # Fancy indexing already returns a copy.
                boundary_snapshots = [uh.x.array[boundary_ids]]

                for _step in range(num_steps):
                    # New random boundary data for this step.
                    uD.x.array[boundary_ids] = [random.random() for _ in boundary_ids]

                    snapshots.append(np.copy(uh.x.array))
                    boundary_snapshots.append(uD.x.array[boundary_ids])

                    # Rebuild the right-hand side in the reused vector.
                    with b.localForm() as loc_b:
                        loc_b.set(0)
                    assemble_vector(b, linear_form)

                    # Apply the Dirichlet condition to the vector.
                    apply_lifting(b, [bilinear_form], [[bc]])
                    b.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)
                    set_bc(b, [bc])

                    solver.solve(b, uh.vector)
                    uh.x.scatter_forward()

                    u_n.x.array[:] = uh.x.array
            finally:
                solver.destroy()
                A.destroy()

            sols.append(np.concatenate([node_coords, np.stack(snapshots, axis=1)], axis=1))
            # The Dirichlet data of the final step belongs to the unrecorded last
            # solve, so it is dropped. A zero column is appended as padding.
            bcs.append(np.concatenate([boundary_coords,
                                       np.stack(boundary_snapshots[:-1], axis=1),
                                       np.zeros((boundary_points.shape[0], 1))], axis=1))
            alphas.append(alpha.x.array[0])
    finally:
        b.destroy()
    return sols, alphas, bcs

if __name__ == "__main__":
    from pathlib import Path

    parser = argparse.ArgumentParser(description='training data generation')
    # Required: the index names the output shard. Without it the filename
    # formatting below would fail only after all the expensive solves had run.
    parser.add_argument('--count', type=int, required=True,
                        help='index of the output shard (zero-padded to 3 digits in the filename)')
    args = parser.parse_args()

    num_meshes = 250  # random polygons per shard
    num_batch = 50    # trajectories per polygon

    datalist = []
    try:
        for _ in tqdm(range(num_meshes)):
            # Random number of polygon vertices in [3, 11].
            num_points = random.randrange(3, 12)
            generate_polygon_mesh(num_points, 'simple')
            sols, alphas, bcs = generate_solution(num_batch)
            datalist.extend((sol, alpha, [bc]) for sol, alpha, bc in zip(sols, alphas, bcs))
    finally:
        gmsh.finalize()

    # Same path as before (e.g. data/2d/heat2d_simple_12500_test_007.pkl).
    # The directory is created up front so the final write cannot fail on a
    # missing folder.
    out_path = Path("data/2d") / f"heat2d_simple_{num_meshes * num_batch}_test_{args.count:03}.pkl"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, 'wb') as file:
        pickle.dump(datalist, file)