import os
from turtle import circle
import einops
import torch
import numpy as np
import trimesh

from GINN.geometry.constraints import BoundingBox2DConstraint, CircleObstacle2D, CompositeInterface2D, Envelope2D, LineInterface2D, CompositeConstraint, SampleConstraint, SampleConstraintWithNormals, SampleEnvelope
from models.model_utils import tensor_product_xz
from models.point_wrapper import PointWrapper
from visualization.utils_mesh import get_watertight_mesh_for_latent
from utils import get_is_out_mask, inflate_bounds
from visualization.utils_mesh import get_2d_contour_for_grid, get_meshgrid_for_marching_squares, get_meshgrid_in_domain, get_mesh

def t_(x):
    """Return a newly allocated ``torch.float32`` tensor built from *x* (scalar, list or array)."""
    float_dtype = torch.float32
    return torch.tensor(x, dtype=float_dtype)

class ProblemSampler():
    """Build the constraints of a named design problem and draw training samples from them.

    ``config['problem']`` selects one of ``'simple_2d'``, ``'double_obstacle'``,
    ``'pipes'`` or ``'simjeb'``. Construction (over)writes ``config['bounds']``
    (the passed-in dict is mutated), stores the constraint objects, and samples one
    fixed set of points per constraint into ``self.constr_pts_dict`` for plotting.
    For the 2D problems it also precomputes the evaluation grids read by
    ``get_mesh_or_contour`` and ``recalc_output``.
    """

    def __init__(self, config) -> None:
        """
        :param config: dict-like config. Always reads ``'device'`` and ``'problem'``;
            the problem setups additionally read ``'envelope_sample_from'``, the
            ``'n_points_*'`` counts and (for simjeb) ``'simjeb_root_dir'``.
            The key ``'bounds'`` is written here.
        :raises NotImplementedError: if ``config['problem']`` is not a known problem.
        """
        self.config = config
        device = self.config['device']

        self._envelope_constr = None
        self._interface_constraints = []
        self._normal_constraints = []
        self._obstacle_constraints = []

        problem = self.config['problem']
        if problem == 'simple_2d':
            self._setup_simple_2d(device)
        elif problem == 'double_obstacle':
            self._setup_double_obstacle(device)
        elif problem == 'pipes':
            self._setup_pipes(device)
        elif problem == 'simjeb':
            self._setup_simjeb(device)
        else:
            raise NotImplementedError(f'Problem {self.config["problem"]} not implemented')

    # ------------------------------------------------------------------ #
    # Problem setups
    # ------------------------------------------------------------------ #

    def _setup_simple_2d(self, device):
        """Box envelope with left/right line interfaces and one circular obstacle at the origin."""
        self.config['bounds'] = t_([[-1, 1], [-0.5, 0.5]])  # [[x_min, x_max], [y_min, y_max]]
        self.envelope = np.array([[-.9, 0.9], [-0.4, 0.4]])
        self.obst_1_center = [0, 0]
        self.obst_1_radius = 0.1
        self.interface_left = np.array([[-0.9, -0.4], [-0.9, 0.4]])
        self.interface_right = np.array([[0.9, -0.4], [0.9, 0.4]])
        self._setup_box_problem_2d(device, obstacles=[(self.obst_1_center, self.obst_1_radius)])

    def _setup_double_obstacle(self, device):
        """Wider box envelope with left/right line interfaces and two circular obstacles."""
        self.config['bounds'] = t_([[-1.5, 1.5], [-0.5, 0.5]])  # [[x_min, x_max], [y_min, y_max]]
        self.envelope = np.array([[-1.4, 1.4], [-0.4, 0.4]])
        self.interface_left = np.array([[-1.4, -0.4], [-1.4, 0.4]])
        self.interface_right = np.array([[1.4, -0.4], [1.4, 0.4]])
        self.obst_1_center = [-0.5, 0]
        self.obst_1_radius = 0.1
        self.obst_2_center = [0.5, 0]
        self.obst_2_radius = 0.1
        self._setup_box_problem_2d(device, obstacles=[(self.obst_1_center, self.obst_1_radius),
                                                      (self.obst_2_center, self.obst_2_radius)])

    def _setup_box_problem_2d(self, device, obstacles):
        """Build the constraints shared by 'simple_2d' and 'double_obstacle'.

        Expects the caller to have set ``self.config['bounds']``, ``self.envelope``,
        ``self.interface_left`` and ``self.interface_right``.

        :param device: device forwarded to ``Envelope2D`` and used for the grids
        :param obstacles: list of ``(center, radius)`` pairs, one ``CircleObstacle2D`` each
        """
        envelope = Envelope2D(env_bbox=t_(self.envelope), bounds=self.config['bounds'], device=device, sample_from=self.config['envelope_sample_from'])
        domain = BoundingBox2DConstraint(bbox=self.config['bounds'])

        # TODO: normals should be computed from the interface definition
        l_target_normal = t_([-1.0, 0.0])
        r_target_normal = t_([1.0, 0.0])
        l_bc = LineInterface2D(start=t_(self.interface_left[0]),
                               end=t_(self.interface_left[1]),
                               target_normal=l_target_normal)
        r_bc = LineInterface2D(start=t_(self.interface_right[0]),
                               end=t_(self.interface_right[1]),
                               target_normal=r_target_normal)
        all_interfaces = CompositeInterface2D([l_bc, r_bc])

        circle_obstacles = [CircleObstacle2D(center=t_(center), radius=t_(radius)) for center, radius in obstacles]
        # the obstacle point budget is split evenly; the remainder is dropped
        n_per_obstacle = self.config['n_points_obstacles'] // len(circle_obstacles)

        # sample once and keep the points; these are used for plotting (stored as (dim, N))
        self.constr_pts_dict = {
            'envelope': envelope.get_sampled_points(N=self.config['n_points_envelope']).cpu().numpy().T,
            'interface': all_interfaces.get_sampled_points(N=self.config['n_points_interfaces'])[0].cpu().numpy().T,
            'obstacles': np.concatenate([obst.get_sampled_points(N=n_per_obstacle).cpu().numpy().T
                                         for obst in circle_obstacles], axis=1),
            'domain': domain.get_sampled_points(N=self.config['n_points_domain']).cpu().numpy().T,
        }

        # save the constraints
        self._envelope_constr = [envelope]
        self._interface_constraints = [l_bc, r_bc]
        self._obstacle_constraints = circle_obstacles
        self._domain = domain

        self._init_2d_grids(device)

    def _setup_pipes(self, device):
        """Pipes benchmark: two line interfaces on each side of the envelope and five holes.

        See paper page 15 - https://arxiv.org/pdf/2004.11797.pdf
        """
        self.config['bounds'] = t_([[-0.1, 1.6], [-0.1, 1.1]])  # [[x_min, x_max], [y_min, y_max]]
        envelope = Envelope2D(env_bbox=t_([[0, 1.5], [0, 1]]), bounds=self.config['bounds'], device=device, sample_from=self.config['envelope_sample_from'])
        domain = BoundingBox2DConstraint(bbox=self.config['bounds'])

        l_target_normal = t_([-1.0, 0.0])
        r_target_normal = t_([1.0, 0.0])
        l_bc_1 = LineInterface2D(start=t_([0, 0.25 - 1/12]), end=t_([0, 0.25 + 1/12]), target_normal=l_target_normal)
        l_bc_2 = LineInterface2D(start=t_([0, 0.75 - 1/12]), end=t_([0, 0.75 + 1/12]), target_normal=l_target_normal)
        r_bc_1 = LineInterface2D(start=t_([1.5, 0.25 - 1/12]), end=t_([1.5, 0.25 + 1/12]), target_normal=r_target_normal)
        r_bc_2 = LineInterface2D(start=t_([1.5, 0.75 - 1/12]), end=t_([1.5, 0.75 + 1/12]), target_normal=r_target_normal)

        # short horizontal segments of length `edge_in` at the top/bottom of each interface
        edge_in = 0.05
        upper_target_normal = t_([0.0, 1.0])
        lower_target_normal = t_([0.0, -1.0])
        l_bc_1_upper = LineInterface2D(start=t_([0, 0.25 + 1/12]), end=t_([edge_in, 0.25 + 1/12]), target_normal=upper_target_normal)
        l_bc_1_lower = LineInterface2D(start=t_([0, 0.25 - 1/12]), end=t_([edge_in, 0.25 - 1/12]), target_normal=lower_target_normal)
        l_bc_2_upper = LineInterface2D(start=t_([0, 0.75 + 1/12]), end=t_([edge_in, 0.75 + 1/12]), target_normal=upper_target_normal)
        l_bc_2_lower = LineInterface2D(start=t_([0, 0.75 - 1/12]), end=t_([edge_in, 0.75 - 1/12]), target_normal=lower_target_normal)

        r_bc_1_upper = LineInterface2D(start=t_([1.5, 0.25 + 1/12]), end=t_([1.5 - edge_in, 0.25 + 1/12]), target_normal=upper_target_normal)
        r_bc_1_lower = LineInterface2D(start=t_([1.5, 0.25 - 1/12]), end=t_([1.5 - edge_in, 0.25 - 1/12]), target_normal=lower_target_normal)
        r_bc_2_upper = LineInterface2D(start=t_([1.5, 0.75 + 1/12]), end=t_([1.5 - edge_in, 0.75 + 1/12]), target_normal=upper_target_normal)
        r_bc_2_lower = LineInterface2D(start=t_([1.5, 0.75 - 1/12]), end=t_([1.5 - edge_in, 0.75 - 1/12]), target_normal=lower_target_normal)

        all_interfaces = CompositeInterface2D([l_bc_1, l_bc_2, r_bc_1, r_bc_2,
                                               l_bc_1_upper, l_bc_1_lower, l_bc_2_upper, l_bc_2_lower,
                                               r_bc_1_upper, r_bc_1_lower, r_bc_2_upper, r_bc_2_lower,
                                               ])

        # TODO: the obstacles are Decagons, not circles; probably not worth the effort though
        # the holes are described in the paper page 19, - https://arxiv.org/pdf/2004.11797.pdf
        circle_obstacle_1 = CircleObstacle2D(center=t_([0.5, 1.0/3]), radius=t_(0.05))
        circle_obstacle_2 = CircleObstacle2D(center=t_([0.5, 2.0/3]), radius=t_(0.05))
        circle_obstacle_3 = CircleObstacle2D(center=t_([1.0, 1.0/4]), radius=t_(0.05))
        circle_obstacle_4 = CircleObstacle2D(center=t_([1.0, 2.0/4]), radius=t_(0.05))
        circle_obstacle_5 = CircleObstacle2D(center=t_([1.0, 3.0/4]), radius=t_(0.05))
        all_obstacles = CompositeConstraint([circle_obstacle_1, circle_obstacle_2,
                                             circle_obstacle_3, circle_obstacle_4, circle_obstacle_5])

        # sample once and keep the points; these are used for plotting (stored as (dim, N))
        self.constr_pts_dict = {
            'envelope': envelope.get_sampled_points(N=self.config['n_points_envelope']).cpu().numpy().T,
            'interface': all_interfaces.get_sampled_points(N=self.config['n_points_interfaces'])[0].cpu().numpy().T,
            'obstacles': all_obstacles.get_sampled_points(N=self.config['n_points_obstacles']).cpu().numpy().T,
            'domain': domain.get_sampled_points(N=self.config['n_points_domain']).cpu().numpy().T,
        }

        # save the constraints
        self._envelope_constr = [envelope]
        # NOTE(review): only the four vertical interfaces are used for training; the
        # upper/lower edge segments only appear in the plotted points — confirm intended.
        self._interface_constraints = [l_bc_1, l_bc_2, r_bc_1, r_bc_2]
        self._obstacle_constraints = [all_obstacles]
        self._domain = domain

        self._init_2d_grids(device)

    def _setup_simjeb(self, device):
        """SimJEB bracket: constraints built from precomputed point clouds and meshes on disk.

        See paper page 5 - https://arxiv.org/pdf/2105.03534.pdf
        Measurements are given in 100s of millimeters.
        """
        root_dir = self.config['simjeb_root_dir']

        def load_tensor(file_name):
            """Load ``root_dir/file_name`` (.npy) as a float32 tensor on ``device``."""
            return torch.from_numpy(np.load(os.path.join(root_dir, file_name))).to(device).float()

        bounds = load_tensor('bounds.npy')

        # scale_factor and translation_vector
        scale_factor = np.load(os.path.join(root_dir, 'scale_factor.npy'))
        center_for_translation = np.load(os.path.join(root_dir, 'center_for_translation.npy'))

        # load meshes, then translate and scale them into the normalized frame
        self.mesh_if = trimesh.load(os.path.join(root_dir, 'interfaces.stl'))
        self.mesh_env = trimesh.load(os.path.join(root_dir, '411_for_envelope.obj'))
        for mesh in (self.mesh_if, self.mesh_env):
            mesh.apply_translation(-center_for_translation)
            mesh.apply_scale(1. / scale_factor)

        ## load points
        pts_far_outside_env = load_tensor('pts_far_outside.npy')
        pts_on_envelope = load_tensor('pts_on_env.npy')
        pts_inside_envelope = load_tensor('pts_inside.npy')
        pts_outside_envelope = load_tensor('pts_outside.npy')
        interface_pts = load_tensor('interface_points.npy')
        interface_normals = load_tensor('interface_normals.npy')
        pts_around_interface = load_tensor('pts_around_interface_outside_env_10mm.npy')

        # sanity checks on the precomputed data: all boundary points must lie within bounds
        assert not get_is_out_mask(pts_on_envelope, bounds).any(), 'envelope points lie outside bounds'
        assert not get_is_out_mask(interface_pts, bounds).any(), 'interface points lie outside bounds'

        self.config['bounds'] = bounds  # [[x_min, x_max], [y_min, y_max], [z_min, z_max]]
        envelope = SampleEnvelope(pts_on_envelope=pts_on_envelope, pts_outside_envelope=pts_outside_envelope, sample_from=self.config['envelope_sample_from'])
        envelope_around_interface = SampleConstraint(sample_pts=pts_around_interface)
        pts_far_from_env_constraint = SampleConstraint(sample_pts=pts_far_outside_env)
        inside_envelope = SampleConstraint(sample_pts=pts_inside_envelope)
        domain = CompositeConstraint([inside_envelope])  ## TODO: test also with including outside envelope
        interface = SampleConstraintWithNormals(sample_pts=interface_pts, normals=interface_normals)

        n_env_subset = self.config['n_points_envelope'] // 3
        # unlike the 2D problems, these are stored as (N, dim) — no transpose
        self.constr_pts_dict = {
            # the envelope points are sampled uniformly from the 3 subsets
            'far_outside_envelope': pts_far_from_env_constraint.get_sampled_points(N=n_env_subset).cpu().numpy(),
            'envelope': envelope.get_sampled_points(N=n_env_subset).cpu().numpy(),
            'envelope_around_interface': envelope_around_interface.get_sampled_points(N=n_env_subset).cpu().numpy(),
            # other constraints
            'interface': interface.get_sampled_points(N=self.config['n_points_interfaces'])[0].cpu().numpy(),
            'domain': domain.get_sampled_points(N=self.config['n_points_domain']).cpu().numpy(),
        }

        self._envelope_constr = [envelope, envelope_around_interface, pts_far_from_env_constraint]
        self._interface_constraints = [interface]
        self._obstacle_constraints = None
        self._domain = domain

        self.bounds = self.config['bounds'].cpu()
        ## For plotting: inflate bounds for better visualization
        self.X0, self.X1, self.xs = get_meshgrid_in_domain(inflate_bounds(self.bounds, amount=0.2))
        self.xs = torch.tensor(self.xs, dtype=torch.float32, device=device)

    def _init_2d_grids(self, device):
        """Precompute the 2D grids read by ``get_mesh_or_contour`` and ``recalc_output``.

        Sets ``X0_ms``/``xs_ms`` (marching-squares grid), ``bounds`` (CPU copy of
        ``config['bounds']``) and ``X0``/``X1``/``xs`` (plotting grid over the
        un-inflated bounds).
        """
        self.X0_ms, _, xs_ms = get_meshgrid_for_marching_squares(self.config['bounds'].cpu().numpy())
        self.xs_ms = torch.tensor(xs_ms, dtype=torch.float32, device=device)

        self.bounds = self.config['bounds'].cpu()
        self.X0, self.X1, self.xs = get_meshgrid_in_domain(self.bounds)
        self.xs = torch.tensor(self.xs, dtype=torch.float32, device=device)

    # ------------------------------------------------------------------ #
    # Sampling
    # ------------------------------------------------------------------ #

    def sample_from_envelope(self):
        """Draw ``n_points_envelope // len(constraints)`` points from each envelope constraint.

        The remainder of the integer division is dropped. Results are concatenated along dim 0.
        """
        pts_per_constraint = self.config['n_points_envelope'] // len(self._envelope_constr)
        return torch.cat([c.get_sampled_points(pts_per_constraint) for c in self._envelope_constr], dim=0)

    def sample_from_interface(self):
        """Draw points and their target normals from each interface constraint.

        :returns: ``(points, normals)``, each concatenated along dim 0 across constraints.
        """
        pts_per_constraint = self.config['n_points_interfaces'] // len(self._interface_constraints)
        pts = []
        normals = []
        for c in self._interface_constraints:
            pts_i, normals_i = c.get_sampled_points(pts_per_constraint)
            pts.append(pts_i)
            normals.append(normals_i)
        return torch.cat(pts, dim=0), torch.cat(normals, dim=0)

    def sample_from_obstacles(self):
        """Draw ``n_points_obstacles // len(constraints)`` points from each obstacle constraint.

        NOTE: for 'simjeb' ``_obstacle_constraints`` is None, so this raises TypeError;
        presumably callers skip obstacle sampling there — confirm.
        """
        pts_per_constraint = self.config['n_points_obstacles'] // len(self._obstacle_constraints)
        return torch.vstack([c.get_sampled_points(pts_per_constraint) for c in self._obstacle_constraints])

    def sample_from_domain(self):
        """Draw ``n_points_domain`` points from the domain constraint."""
        return self._domain.get_sampled_points(self.config['n_points_domain'])

    # ------------------------------------------------------------------ #
    # Visualization helpers
    # ------------------------------------------------------------------ #

    def get_mesh_or_contour(self, f, params, z_latents):
        """Extract a 2D contour (nx == 2) or a watertight 3D mesh (nx == 3) per latent code.

        Best effort: any exception is reported and ``None`` is returned, so plotting
        never interrupts training.

        :param f: model callable, invoked as ``f(params, *tensor_product_xz(x, z))``
        :param params: model parameters forwarded to ``f``
        :param z_latents: latent codes, one shape per entry
        :returns: list with one contour (2D) or one verts/faces result (3D) per latent,
            or ``None`` on failure / unsupported ``nx``.
        """
        try:
            if self.config['nx'] == 2:
                ## get contour on the marching squares grid for 2d obstacle
                contour_list = []
                with torch.no_grad():
                    y_ms = f(params, *tensor_product_xz(self.xs_ms, z_latents)).detach().cpu().numpy()
                Y_ms = einops.rearrange(y_ms, '(bz h w) 1 -> bz h w', h=self.X0_ms.shape[0], w=self.X0_ms.shape[1])
                for i, _ in enumerate(z_latents):
                    contour = get_2d_contour_for_grid(self.X0_ms, Y_ms[i], self.bounds)
                    contour_list.append(contour)
                return contour_list

            elif self.config['nx'] == 3:
                ## get watertight verts, faces for simjeb
                verts_faces_list = []
                for z_ in z_latents:  ## do marching cubes for every z
                    verts_faces = get_watertight_mesh_for_latent(f, params, z_, bounds=self.config["bounds"],
                                                                 mc_resolution=self.config["mc_resolution"],
                                                                 device=z_latents.device, chunks=self.config['mc_chunks'])
                    verts_faces_list.append(verts_faces)
                return verts_faces_list
        except Exception as e:
            # deliberate best-effort: plotting failures must not abort the caller
            print(f'WARNING: Could not compute mesh_or_contour for plotting: {e}')
            return None

    def recalc_output(self, f, params, z_latents):
        """Evaluate the model on the plotting grid (2D) or run marching cubes per latent (3D).

        :param f: model callable, invoked as ``f(params, *tensor_product_xz(x, z))``
        :param params: model parameters forwarded to ``f``
        :param z_latents: latent codes, one shape per entry
        :returns: for nx == 2, ``(y, Y)`` — the flat numpy output of shape
            ``(bz*h*w, 1)`` and the same values rearranged to ``(bz, h, w)``;
            for nx == 3, a list of ``(verts, faces)`` from ``get_mesh``, one per latent;
            ``None`` for any other ``nx``.
        """
        if self.config['nx'] == 2:
            ## just return the function values on the standard grid (used for visualization)
            with torch.no_grad():
                y = f(params, *tensor_product_xz(self.xs, z_latents)).detach().cpu().numpy()
            Y = einops.rearrange(y, '(bz h w) 1 -> bz h w', h=self.X0.shape[0], w=self.X0.shape[1])
            return y, Y

        elif self.config['nx'] == 3:
            verts_faces_list = []
            for z_ in z_latents:  ## do marching cubes for every z

                def f_fixed_z(x):
                    """A wrapper for calling the model with a single fixed latent code."""
                    # z_ is bound late, but f_fixed_z is only used within this iteration
                    with torch.no_grad():
                        return f(params, *tensor_product_xz(x, z_.unsqueeze(0))).squeeze(0)

                verts_, faces_ = get_mesh(f_fixed_z,
                                          N=self.config["mc_resolution"],
                                          device=z_latents.device,
                                          bbox_min=self.config["bounds"][:, 0],
                                          bbox_max=self.config["bounds"][:, 1],
                                          chunks=1,
                                          return_normals=0)
                verts_faces_list.append((verts_, faces_))
            return verts_faces_list

    def is_inside_envelope(self, p_np: 'PointWrapper'):
        """Return the ``mesh_env.contains`` mask for the points in ``p_np.data``.

        Only available for the 'simjeb' problem, which is the only one that loads
        ``self.mesh_env``. Presumably ``p_np.data`` is an (N, 3) array in the
        normalized frame — confirm.

        :raises NotImplementedError: for any problem other than 'simjeb'.
        """
        if self.config['problem'] != 'simjeb':
            raise NotImplementedError('This function is only implemented for the simjeb problem')

        is_inside_mask = self.mesh_env.contains(p_np.data)
        return is_inside_mask