import logging
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.spatial import cKDTree

def process_data_local(data_dir, objname, base_sample=10000, supervision=False, outname=None, sample_radius=1000, sample_param=1.0, classname=None, num_patches=420):
    """Sample local patches of noisy query points around a mesh and save them as ``.npz``.

    The mesh vertices are centred on their axis-aligned bounding box and scaled
    so the longest box edge has length 1. Then, ``num_patches`` times:

    1. ``base_sample`` vertices are drawn without replacement;
    2. the first ``base_sample // 2`` of them are jittered twice with Gaussian
       noise (std ``0.5 * sample_param`` and ``0.1 * sample_param``), giving
       ``2 * (base_sample // 2)`` query samples;
    3. every query is paired with its nearest drawn vertex;
    4. one of those nearest vertices is picked at random and the
       ``sample_radius`` pairs whose nearest vertex is closest to it form the patch.

    Args:
        data_dir: Directory the archive is written into.
        objname: Mesh path passed to ``trimesh.load``.
        base_sample: Vertices drawn per patch; must not exceed the vertex count.
        supervision: Unused; kept for call compatibility.
        outname: Output file name inside ``data_dir``; defaults to
            ``'sample_locals.npz'``.
        sample_radius: Number of (query, nearest-vertex) pairs kept per patch.
        sample_param: Multiplier applied to both noise standard deviations.
        classname: Unused; kept for call compatibility.
        num_patches: Number of patches to generate (previously fixed at 420).

    The archive holds ``sample`` (query points) and ``sample_near`` (their
    nearest vertices), each of shape ``(num_patches, 1, sample_radius, 3)``
    when ``sample_radius > 1``.

    Raises:
        ValueError: If ``sample_radius`` is not in ``[1, 2 * (base_sample // 2)]``,
            if the mesh has fewer than ``base_sample`` vertices, or if all
            vertices coincide (zero bounding-box extent).
    """
    n_queries = 2 * (base_sample // 2)
    # cKDTree.query pads missing neighbours with index n, which would only
    # surface later as an IndexError; reject the configuration up front.
    if not 1 <= sample_radius <= n_queries:
        raise ValueError(
            f"sample_radius must be in [1, {n_queries}] for base_sample={base_sample}, got {sample_radius}"
        )

    # NOTE(review): `trimesh` is not imported in this module's visible header;
    # it must be importable at module scope for this call to work.
    vertices = np.asarray(trimesh.load(objname).vertices)
    if vertices.shape[0] < base_sample:
        raise ValueError(
            f"mesh {objname!r} has {vertices.shape[0]} vertices, fewer than base_sample={base_sample}"
        )

    points = _normalize_to_unit_box(vertices)
    all_samples, all_nears = _sample_local_patches(points, base_sample, sample_radius, sample_param, num_patches)

    filename = outname if outname is not None else 'sample_locals.npz'
    np.savez(os.path.join(data_dir, filename), sample=all_samples, sample_near=all_nears)


def _normalize_to_unit_box(points):
    """Centre ``(N, 3)`` points on their bounding box and scale its longest edge to 1."""
    lo = points.min(axis=0)
    hi = points.max(axis=0)
    scale = np.max(hi - lo)
    if scale == 0:
        # Dividing by zero would silently fill the output with NaN/inf.
        raise ValueError("cannot normalise a point set with zero extent")
    return (points - (hi + lo) / 2) / scale


def _sample_local_patches(points, base_sample, sample_radius, sample_param, num_patches):
    """Return ``(samples, nears)`` arrays holding ``num_patches`` local patches drawn from ``points``."""
    half = base_sample // 2
    patch_samples = []
    patch_nears = []
    for _ in range(num_patches):
        subset = points[np.random.choice(points.shape[0], base_sample, replace=False), :3]
        anchors = subset[:half]
        # Two noise scales around the same anchors: coarse (0.5) then fine (0.1).
        coarse = anchors + 0.5 * sample_param * np.random.normal(0.0, 1.0, size=(half, 3))
        fine = anchors + 0.1 * sample_param * np.random.normal(0.0, 1.0, size=(half, 3))
        samples = np.concatenate([coarse, fine], 0)

        # Nearest drawn vertex for every query sample (row-aligned with samples).
        _, nearest_idx = cKDTree(subset).query(samples)
        nearest = subset[nearest_idx]

        # Random centre among the nearest vertices; keep the sample_radius
        # pairs whose nearest vertex lies closest to that centre.
        center = nearest[np.random.choice(nearest.shape[0], 1, replace=False), :3]
        _, patch_idx = cKDTree(nearest).query(center, k=sample_radius)
        patch_samples.append(samples[patch_idx])
        patch_nears.append(nearest[patch_idx])
    return np.asarray(patch_samples), np.asarray(patch_nears)


class Decoder(nn.Module):
    """Fully connected network mapping each input row to one scalar.

    ``forward1`` treats ``input[:, :3]`` as xyz and any remaining columns as a
    latent code; the first layer expects ``latent_size + 3`` input columns. The
    final output always passes through ``tanh`` (``self.th``), preceded by an
    extra ``tanh`` when ``use_tanh`` is set.

    Args:
        latent_size: Width of the latent part of each input row.
        dims: Hidden-layer widths (any sequence of ints).
        dropout: Hidden-layer indices followed by dropout, or ``None``.
        dropout_prob: Dropout probability for those layers.
        norm_layers: Hidden-layer indices normalised with weight norm (when
            ``weight_norm``) or ``LayerNorm`` (otherwise).
        latent_in: Layer indices whose input is re-concatenated with the raw
            network input.
        skip_in: Stored as ``self.skip_in``; not used by the forward pass.
        weight_norm: Wrap the ``norm_layers`` linears in ``nn.utils.weight_norm``.
        xyz_in_all: Re-concatenate xyz before every layer after the first.
        use_tanh: Apply an extra tanh on the last layer's output.
        latent_dropout: Apply dropout (p=0.2) to the non-xyz input columns.
    """

    def __init__(
        self,
        latent_size,
        dims,
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(),
        latent_in=(),
        skip_in=(4,),
        weight_norm=False,
        xyz_in_all=None,
        use_tanh=False,
        latent_dropout=False,
    ):
        super(Decoder, self).__init__()

        # Not used by the forward pass, but it is a registered submodule: kept
        # so state_dict keys (and therefore saved checkpoints) stay compatible.
        self.__hidden__ = torch.nn.Linear(3, 1, bias=False)

        self.scale = 1.0
        self.skip_in = skip_in  # stored for callers; forward1 does not read it
        self.embed_fn_fine = None
        self.latent_size = latent_size
        # Layer widths: input (xyz + latent), hidden layers, scalar output.
        dims = [latent_size + 3] + list(dims) + [1]

        self.num_layers = len(dims)
        self.norm_layers = norm_layers
        self.latent_in = latent_in
        self.latent_dropout = latent_dropout
        if self.latent_dropout:
            # Registered but unused: forward1 calls F.dropout(p=0.2) directly.
            self.lat_dp = nn.Dropout(0.2)

        self.xyz_in_all = xyz_in_all
        self.weight_norm = weight_norm

        # latent_lin0..2 are not used by the forward pass either; kept for
        # checkpoint (state_dict) compatibility.
        for i in range(3):
            setattr(self, "latent_lin" + str(i), nn.Linear(latent_size // 2, latent_size // 2))

        for layer in range(0, self.num_layers - 1):
            if layer + 1 in latent_in:
                # The next layer re-concatenates the raw input; leave room for it.
                out_dim = dims[layer + 1] - dims[0]
            else:
                out_dim = dims[layer + 1]
                if self.xyz_in_all and layer != self.num_layers - 2:
                    # The next layer re-concatenates the 3 xyz columns.
                    out_dim -= 3

            if weight_norm and layer in self.norm_layers:
                setattr(
                    self,
                    "lin" + str(layer),
                    nn.utils.weight_norm(nn.Linear(dims[layer], out_dim)),
                )
            else:
                setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim))

            if (
                (not weight_norm)
                and self.norm_layers is not None
                and layer in self.norm_layers
            ):
                setattr(self, "bn" + str(layer), nn.LayerNorm(out_dim))

        self.use_tanh = use_tanh
        if use_tanh:
            self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

        self.dropout_prob = dropout_prob
        self.dropout = dropout
        self.th = nn.Tanh()

    def forward1(self, input):
        """Run the MLP on a 2-D ``input`` and return one value per row.

        ``input[:, :3]`` is treated as xyz: it is what ``xyz_in_all``
        re-concatenates and what ``latent_dropout`` leaves untouched.
        """
        xyz = input[:, :3]

        if input.shape[1] > 3 and self.latent_dropout:
            latent_vecs = input[:, 3:]
            latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training)
            x = torch.cat([xyz, latent_vecs], 1)
        else:
            x = input

        for layer in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(layer))
            if layer in self.latent_in:
                # Skip connection re-injects the raw input (no latent dropout).
                x = torch.cat([input, x], 1)
            elif layer != 0 and self.xyz_in_all:
                x = torch.cat([xyz, x], 1)
            x = lin(x)
            if layer == self.num_layers - 2 and self.use_tanh:
                x = self.tanh(x)
            if layer < self.num_layers - 2:
                if (
                    self.norm_layers is not None
                    and layer in self.norm_layers
                    and not self.weight_norm
                ):
                    bn = getattr(self, "bn" + str(layer))
                    x = bn(x)
                x = self.relu(x)
                if self.dropout is not None and layer in self.dropout:
                    x = F.dropout(x, p=self.dropout_prob, training=self.training)

        if hasattr(self, "th"):
            x = self.th(x)

        return x

    def sdf(self, x):
        """Return the network output for ``x`` (alias of ``forward1``)."""
        return self.forward1(x)

    @staticmethod
    def _input_gradient(y, inputs):
        """Return d(sum of y)/d(inputs), keeping the graph so it can be differentiated again."""
        return torch.autograd.grad(
            outputs=y,
            inputs=inputs,
            grad_outputs=torch.ones_like(y),
            create_graph=True,
            retain_graph=True,
        )[0]

    def gradient(self, input):
        """Return d(output)/d(input) with shape ``(N, 1, D)`` for an ``(N, D)`` input.

        ``input`` is marked as requiring grad in place.
        """
        input.requires_grad_(True)
        y = self.sdf(input)
        return self._input_gradient(y, input).unsqueeze(1)

    def forward(self, input, require_grad=False):
        """Evaluate the network and, optionally, its gradient w.r.t. ``input``.

        Args:
            input: 2-D tensor of rows to evaluate, or ``None``.
            require_grad: Also return d(output)/d(input); ``input`` is then
                marked as requiring grad in place.

        Returns:
            ``(output, gradients)``. ``output`` is ``forward1(input)``, or
            ``None`` when ``input`` is ``None``. ``gradients`` has the same
            shape as ``input`` when ``require_grad`` is set and ``input`` is
            not ``None``; otherwise it is ``None``.
        """
        if input is None:
            return None, None
        if not require_grad:
            return self.forward1(input), None

        input.requires_grad_(True)
        output = self.forward1(input)
        # Differentiate the very pass that is returned: the gradient then
        # matches `output` even with dropout active, and the MLP runs once.
        # The (N, D) shape is kept as-is, so a batch of one keeps its batch dim.
        gradients = self._input_gradient(output, input)
        return output, gradients
        
        
# NOTE(review): module-level excerpt of a training step. batch_split, batch_vecs,
# xyz, decoder, unsuper, loss_l1, sdf_gt, num_sdf_samples, surface and
# earth_mover_distance do not appear to be defined earlier in this file, so
# importing the module as-is would raise NameError — TODO confirm / move into a function.
# NOTE(review): `loss` is reassigned on every iteration and not consumed in the
# visible lines; if it should accumulate across splits, that must happen inside the loop.
for i in range(batch_split):
    # Columns are [batch_vecs | xyz[i]], so xyz occupies the LAST three columns.
    # NOTE(review): Decoder.forward1 treats input[:, :3] as xyz (for its
    # xyz_in_all / latent_dropout options); with xyz last, those options would
    # act on latent columns instead — confirm the intended column order.
    input = torch.cat([batch_vecs, xyz[i]], dim=1)
    # Assuming `decoder` is a Decoder, the second argument is require_grad:
    # grads is only computed (not None) in the unsupervised branch.
    pred_sdf, grads = decoder(input, unsuper)
    if not unsuper:
        # Supervised: compare against sdf_gt[i], scaled by 1 / num_sdf_samples
        # (loss_l1 is presumably an L1 loss — defined elsewhere).
        loss = loss_l1(pred_sdf, sdf_gt[i].cuda()) / num_sdf_samples
    else:
        # Keep only the gradient w.r.t. the xyz columns (the last three).
        grads = grads[:, -3:]
        grad_norm = F.normalize(grads, dim=1)
        # Move each query point against its unit gradient by the predicted
        # value — presumably projecting it onto the predicted zero level set.
        sample_moved = xyz[i].cuda() - grad_norm * pred_sdf
        # NOTE(review): txt is unused in the visible lines but forces a
        # device-to-host copy every iteration; remove if nothing later reads it.
        txt = surface[i].detach().cpu().numpy()
        # Unsupervised: mean of earth_mover_distance(surface, moved samples);
        # presumably an EMD between the two point sets — defined elsewhere.
        loss = earth_mover_distance(surface[i].cuda(), sample_moved, transpose=False).mean()
