from meshpy.triangle import build, MeshInfo
import meep as mp
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as mtri
import pickle, os
import random

# Total number of simulations in the dataset (upper bound of the simulation
# index range used by the main generation loop).
num_of_simulations = 100
# Feeds the duration of the recorded (step-function) run in the main loop:
# until=(num_of_steps-1)*0.005*resolution.
num_of_steps = 100
# Output directory for MeshData_<i>.pkl and MeshAttributes_<i>.pkl; the name
# presumably encodes the resolution range (40-80) and frequency (1) used in
# the main loop.
folder_path = './dataset/data_re40-80_fr1/'

# Plot the Ez field on the mesh at every recorded time step (calls plt.show()
# once per step).
step_plot = False
# Plot the dielectric map and the final Ez field after each simulation.
simulation_plot = False
# Write the pickled mesh and attribute files to folder_path.
file_generation = True
# Default `resolution` of GenerateRandomMesh: the target maximum triangle area
# is area / mesh_resolution (with Gaussian jitter), so larger means finer.
mesh_resolution = 1000

class Triangle:
    """A mesh triangle described by three vertex indices.

    Attributes:
        vertices: the three vertex indices, in the order given.
        edges: the six directed edges between the vertices (see get_edges).
    """

    def __init__(self, a, b, c):
        self.vertices = (a, b, c)
        self.edges = self.get_edges()

    def get_edges(self):
        """Return the triangle's edges in both directions.

        With the vertices sorted as lo <= mid <= hi, the list is the three
        "forward" edges (lo, mid), (mid, hi), (lo, hi) followed by the same
        edges reversed, in the same order. Self-loops (v, v) are not included.
        """
        lo, mid, hi = sorted(self.vertices)
        forward = [(lo, mid), (mid, hi), (lo, hi)]
        backward = [(tail, head) for head, tail in forward]
        return forward + backward

def get_triangle_edges(triangles):
    """Collect the distinct edges of all given triangles.

    Args:
        triangles: iterable of objects exposing an ``edges`` sequence
            (e.g. Triangle instances).

    Returns:
        A list of the unique edges, each appearing once.
    """
    unique_edges = {edge for tri in triangles for edge in tri.edges}
    return list(unique_edges)

# Generate a random triangular mesh of a rectangle, seeded with n random vertices
def GenerateRandomMesh(bottom_left, top_right, n, resolution=mesh_resolution):
    """Triangulate the rectangle [bottom_left, top_right] around n random points.

    The n seed points are drawn uniformly inside the rectangle with
    ``np.random``. The four corners are appended after them (indices
    n .. n+3), and the rectangle's sides are passed to Triangle as boundary
    facets.

    Args:
        bottom_left: (min_x, min_y) corner of the rectangle.
        top_right: (max_x, max_y) corner of the rectangle.
        n: number of random seed points.
        resolution: controls mesh density. The maximum triangle area is drawn
            from N(area / resolution, (area / (5 * resolution))**2) and
            redrawn until positive, so larger values give finer meshes.

    Returns:
        The mesh object returned by ``meshpy.triangle.build``.

    Raises:
        ValueError: if area / resolution is not positive (degenerate
            rectangle or non-positive resolution). Without this check, no
            positive area bound could ever be drawn.
    """
    min_x, min_y = bottom_left
    max_x, max_y = top_right
    area = (max_x - min_x) * (max_y - min_y)
    mean_area = area/resolution
    if mean_area <= 0:
        raise ValueError("rectangle area and resolution must be positive")

    x_values = np.random.uniform(min_x, max_x, n)
    y_values = np.random.uniform(min_y, max_y, n)

    # Seed points first, then the corners, so the corners get indices
    # n (bottom-left), n+1 (top-left), n+2 (top-right), n+3 (bottom-right).
    points = list(zip(x_values, y_values))
    points += [bottom_left, (min_x, max_y), top_right, (max_x, min_y)]

    mesh_info = MeshInfo()
    mesh_info.set_points(points)
    # Two per-vertex attribute slots; triang_step writes Ez and epsilon there.
    mesh_info.number_of_point_attributes = 2
    # Boundary segments: left, top, right and bottom side of the rectangle.
    mesh_info.set_facets([[n, n+1], [n+1, n+2], [n+2, n+3], [n, n+3]])

    # Jitter the area constraint so meshes vary in density. A Gaussian draw
    # can (rarely) be <= 0, which is not a valid area bound for Triangle, so
    # redraw in that case. When the first draw is positive, the result (and
    # RNG consumption) is the same as a single draw.
    maxvolume = mean_area + np.random.normal(0, area/(5*resolution))
    while maxvolume <= 0:
        maxvolume = mean_area + np.random.normal(0, area/(5*resolution))

    mesh = build(mesh_info, max_volume=maxvolume)
    return mesh

# plot a given mesh
def plot_triangular_mesh(mesh, mesh_data, color):
    """Show a filled contour plot of per-vertex data over a triangular mesh.

    Args:
        mesh: meshpy mesh providing ``points`` (x, y pairs) and ``elements``
            (vertex-index triples).
        mesh_data: one value per mesh vertex, used for the contour colours.
        color: matplotlib colormap name passed as ``cmap``.
    """
    xs = [point[0] for point in mesh.points]
    ys = [point[1] for point in mesh.points]
    triangulation = mtri.Triangulation(xs, ys, mesh.elements)
    plt.tricontourf(triangulation, mesh_data, cmap=color)
    # Overlay the mesh edges on top of the contours.
    plt.triplot(triangulation)
    plt.axis('off')
    plt.show()

# Custom Meep step function: at every time step it samples vertex and edge
# (face) attributes on the current mesh and appends them to AttributeData
def triang_step(sim):
    """Meep step function: sample the current fields on the global mesh.

    Reads the module-level ``mesh``, ``mesh_faces``, ``mesh_face_num`` and
    ``AttributeData`` set up by the main loop. Appends one record to
    ``AttributeData`` with these keys:

    * ``'PointAttributes'``: array with one row per mesh vertex,
      ``[real(Ez), real(epsilon)]`` sampled at that vertex. These values are
      also written into ``mesh.point_attributes``.
    * ``'Points'``: list of the vertex coordinates.
    * ``'FaceAttributes'``: ``(mesh_face_num, 4)`` array with one row per
      edge ``(u, w)`` of ``mesh_faces``: ``[dx, dy, real(Hx), real(Hy)]``.
      Here ``(dx, dy) = p_u - p_w``, and the H components are sampled at
      the edge midpoint. Self-loop edges (``u == w``) keep an all-zero row.
    """
    # Per-vertex attributes.
    vertex_coords = []
    num_vertices = len(mesh.points)
    for v in range(num_vertices):
        coord = mesh.points[v]
        mesh.point_attributes[v, 0] = sim.get_field_point(c=mp.Ez, pt=coord).real
        mesh.point_attributes[v, 1] = sim.get_epsilon_point(pt=coord).real
        vertex_coords.append(coord)

    point_attr = np.array(mesh.point_attributes)
    if step_plot:
        plot_triangular_mesh(mesh, point_attr[:, 0], 'RdBu')

    # Per-edge attributes.
    face_attr = np.zeros((mesh_face_num, 4))
    for row in range(mesh_face_num):
        u, w = mesh_faces[row]
        if u == w:
            continue  # self-loop: leave the row at zero
        p_u = np.array(mesh.points[u])
        p_w = np.array(mesh.points[w])
        midpoint = (p_u + p_w)/2
        face_attr[row, :2] = p_u - p_w
        face_attr[row, 2] = sim.get_field_point(c=mp.Hx, pt=midpoint).real
        face_attr[row, 3] = sim.get_field_point(c=mp.Hy, pt=midpoint).real

    AttributeData.append({
        'PointAttributes': point_attr,
        'Points': vertex_coords,
        'FaceAttributes': face_attr,
    })

# create a folder to store the data
if file_generation:
    # Idempotent: an already existing output folder is accepted as-is.
    os.makedirs(name=folder_path, exist_ok=True)

# main generation loop
# NOTE: this loop must stay at module level, because triang_step reads `mesh`,
# `mesh_faces`, `mesh_face_num` and `AttributeData` as module globals.
first_simulation = 60  # first index to generate (resume point); 0 = all
for i in range(first_simulation, num_of_simulations):
    # Random mesh over [-1, 1] x [-1, 1], matching the 2 x 2 Meep cell below.
    mesh = GenerateRandomMesh((-1, -1), (1, 1), 10)
    # Unpack the triangulation's edge list from its triangles.
    mesh_elements = np.array(mesh.elements)
    mesh_triangles = [Triangle(*element) for element in mesh_elements]
    mesh_faces = get_triangle_edges(mesh_triangles)
    mesh_face_num = len(mesh_faces)

    MeshData = {
        'GraphStructure': mesh_faces,
        'NumberOfVertices': len(mesh.points),
        'NumberOfFaces': mesh_face_num
    }

    # Store the mesh graph once per simulation.
    if file_generation:
        file_path = os.path.join(folder_path, f'MeshData_{i}.pkl')
        with open(file_path, 'wb') as file:
            pickle.dump(MeshData, file)

    # triang_step appends one record per time step to this list.
    AttributeData = []
    cell = mp.Vector3(2, 2, 0)

    # Random problem parameters (keep the draw order for reproducibility).
    epsilon_length = random.uniform(0.2, 2)
    epsilon_height = random.uniform(0.2, 2)
    permittivity = random.uniform(2, 15)
    frequency = 1  # alternative: random.uniform(1, 2)
    time = 50  # warm-up duration; alternative: random.randint(100, 200)
    center_x = random.uniform(-0.5, 0.5)
    center_y = random.uniform(-0.5, 0.5)

    # Set up the simulation: a dielectric block at the origin, a continuous
    # Ez point source, and 0.2-thick PML layers.
    geometry = [mp.Block(mp.Vector3(epsilon_length, epsilon_height, mp.inf),
                         center=mp.Vector3(),
                         material=mp.Medium(epsilon=permittivity))]
    sources = [mp.Source(mp.ContinuousSource(frequency=frequency),
                         component=mp.Ez,
                         center=mp.Vector3(center_x, center_y))]
    pml_layers = [mp.PML(0.2)]
    resolution = random.randint(40, 80)
    sim = mp.Simulation(cell_size=cell, boundary_layers=pml_layers,
                        geometry=geometry, sources=sources,
                        resolution=resolution)
    # Warm-up run with no step functions (nothing is recorded).
    sim.run(until=time)
    # Recorded run: triang_step is called at every Meep time step.
    # NOTE(review): the duration below grows with resolution, while Meep's
    # time step (Courant / resolution, Courant = 0.5 by default) shrinks. So
    # the number of recorded samples is ~ (num_of_steps - 1) * resolution**2
    # / 100, not num_of_steps. If exactly num_of_steps samples are intended,
    # use until=(num_of_steps - 1) * sim.Courant / resolution -- confirm.
    sim.run(triang_step, until=(num_of_steps-1)*0.005*resolution)

    # Store the per-step attributes.
    if file_generation:
        file_path = os.path.join(folder_path, f'MeshAttributes_{i}.pkl')
        with open(file_path, 'wb') as file:
            pickle.dump(AttributeData, file)

    # Plot the result of the simulation.
    if simulation_plot:
        eps_data = sim.get_array(center=mp.Vector3(), size=cell, component=mp.Dielectric)
        plt.figure()
        plt.imshow(eps_data.transpose(), interpolation='spline36', cmap='binary')
        plt.axis('off')
        plt.show()

        ez_data = sim.get_array(center=mp.Vector3(), size=cell, component=mp.Ez)
        plt.figure()
        plt.imshow(ez_data.transpose(), interpolation='spline36', cmap='RdBu', alpha=0.9)
        plt.axis('off')
        plt.show()