import os
from pathlib import Path
import trimesh
import numpy as np
import einops
import matplotlib.pyplot as plt

from src.utils.plotting import get_neighbor_plot_indices_for_plane

def get_mask_path(root, slice_axis, axis_value, resolution):
    """Build the cache file path for a slice's inside-mask.

    The filename encodes the slice axis, the axis value and the grid
    resolution, and is joined onto ``root`` with ``os.path.join``, so the
    result is a plain string path.
    """
    mask_filename = 'inside_mask_{}={}_resolution_{}.npy'.format(slice_axis, axis_value, resolution)
    return os.path.join(root, mask_filename)

class Plotter_3d:
    """Plot planar slices of a field, masked to the interior of a boundary mesh.

    The boundary surface is read from ``<root_path>/boundary.stl``. Inside/outside
    masks for each slice are cached as ``.npy`` files in the same directory.
    """

    def __init__(self, root_path):
        """Load the boundary mesh.

        Args:
            root_path: Directory containing ``boundary.stl``; also used to store
                cached inside-masks.
        """
        self.root = Path(root_path)
        boundary_mesh_path = self.root / 'boundary.stl'
        # NOTE(review): trimesh.load can return a Scene for some inputs, which has
        # no ``contains``; confirm boundary.stl always loads as a single mesh.
        self.boundary_mesh = trimesh.load(boundary_mesh_path)

    def _load_or_compute_inside_mask(self, plane_points, slice_axis, axis_value, resolution):
        """Return a boolean (h, w) mask of which plane points lie inside the boundary mesh.

        A cached mask is reused when its file exists and its shape matches the
        plane grid; otherwise the mask is recomputed with
        ``boundary_mesh.contains`` and written back to the cache file.

        Args:
            plane_points: Sampled plane coordinates of shape (h, w, d).
            slice_axis, axis_value, resolution: Identify the cache file.
        """
        height, width = plane_points.shape[:2]
        # get_mask_path returns a str; wrap it so .exists() is available.
        mask_path = Path(get_mask_path(self.root, slice_axis, axis_value, resolution))
        if mask_path.exists():
            inside_mask = np.load(mask_path).astype(bool)
            # The cache key does not include the sampled points themselves, so
            # guard against a stale file produced for a different plane grid.
            if inside_mask.shape == (height, width):
                return inside_mask
        print('recompute inside points')
        flat_points = einops.rearrange(plane_points, 'h w d -> (h w) d')
        inside_flat = np.asarray(self.boundary_mesh.contains(flat_points), dtype=bool)
        inside_mask = einops.rearrange(inside_flat, '(h w) -> h w', h=height)
        np.save(mask_path, inside_mask)
        return inside_mask

    def plot(self, mesh_centers, field, slice_axis, axis_value, resolution=100):
        """Show the norm of ``field`` on a plane as an image, zeroed outside the boundary.

        Args:
            mesh_centers: Point coordinates forwarded to
                ``get_neighbor_plot_indices_for_plane``.
            field: Per-point vector values; the norm is taken along axis 1.
            slice_axis: Axis selecting the slicing plane (forwarded to the helper
                and used in the mask cache filename).
            axis_value: Coordinate of the plane along ``slice_axis``.
            resolution: Grid resolution of the sampled plane.
        """
        plane_points, indices, extent = get_neighbor_plot_indices_for_plane(
            mesh_centers, slice_axis=slice_axis, axis_value=axis_value, grid_resolution=resolution
        )
        inside_mask = self._load_or_compute_inside_mask(plane_points, slice_axis, axis_value, resolution)

        # Per-point magnitude, gathered onto the plane grid via ``indices``
        # (presumably an (h, w) array of neighbor indices — TODO confirm).
        norm_u = np.linalg.norm(field, axis=1)
        norm_u_filtered = norm_u[indices]
        norm_u_filtered[~inside_mask] = 0

        plt.figure(figsize=(8, 6))
        sc = plt.imshow(norm_u_filtered, cmap='viridis', origin='lower', extent=extent)
        plt.colorbar(sc, label='norm(u)')
        plt.show()