import torch
import sklearn
import numpy as np

def plot_3d(scatter_matrix, save_path='3d_temp.png', show=True):
    """Draw a 3-D scatter plot of the points in ``scatter_matrix``.

    The input is flattened to ``(-1, 3)`` and its three columns are used as
    the x, y and z coordinates of the scatter plot.

    Args:
        scatter_matrix: Array-like object exposing ``.shape`` and
            ``.reshape`` whose last dimension has size 3.
        save_path: File the figure is written to. Defaults to
            ``'3d_temp.png'`` (the location that used to be hard-coded);
            pass ``None`` to skip saving.
        show: When true (the default) the figure is displayed with
            ``plt.show()``. When false the figure is closed instead, so that
            repeated calls do not accumulate open figures.

    Raises:
        AssertionError: If the last dimension of ``scatter_matrix`` is not 3.
    """
    assert scatter_matrix.shape[-1] == 3, \
        'scatter_matrix must have a last dimension of size 3'

    # Function-local import: matplotlib is only loaded when a plot is requested.
    import matplotlib.pyplot as plt

    points = scatter_matrix.reshape(-1, 3)

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.scatter(points[:, 0], points[:, 1], points[:, 2])
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')

    if save_path is not None:
        fig.savefig(save_path)

    if show:
        plt.show()
    else:
        # Nobody will look at this figure; release it right away.
        plt.close(fig)
    
def clamp_preserve_gradients(x: torch.Tensor, min: float, max: float) -> torch.Tensor:
    """Clamp ``x`` to ``[min, max]`` in the forward pass only.

    The returned values equal ``x.clamp(min, max)``, but the clamping offset
    is detached from the autograd graph, so the gradient with respect to
    ``x`` is the identity even for elements that were clamped.
    """
    clamped = torch.clamp(x, min, max)
    # Detached correction: shifts the value, contributes nothing to backward.
    correction = (clamped - x).detach()
    return x + correction

class GaussianNormalizer(object):
    """Normalize tensors with a single global mean and standard deviation.

    The statistics are computed once, over every element of the tensor passed
    to the constructor, and are then reused by ``encode`` / ``decode``.
    """

    def __init__(self, x, eps=0.00001):
        """Compute the statistics.

        Args:
            x: Tensor the scalar mean and standard deviation are taken from.
            eps: Small constant added to the standard deviation so that the
                division in ``encode`` stays finite when the std is zero.
        """
        super(GaussianNormalizer, self).__init__()

        self.mean = torch.mean(x)
        self.std = torch.std(x)
        self.eps = eps

    def encode(self, x):
        """Return ``(x - mean) / (std + eps)``."""
        return (x - self.mean) / (self.std + self.eps)

    def decode(self, x, sample_idx=None):
        """Invert ``encode``: return ``x * (std + eps) + mean``.

        ``sample_idx`` is accepted for interface compatibility and is ignored.
        """
        return x * (self.std + self.eps) + self.mean

    def cuda(self):
        """Move the stored statistics to the GPU and return ``self``."""
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()
        # Returning self allows ``normalizer = normalizer.cuda()`` chaining.
        return self

    def cpu(self):
        """Move the stored statistics to the CPU and return ``self``."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()
        return self


class RandomMultiMeshGenerator(object):
    """Randomly sampled multi-level point hierarchy with ball-connectivity edges.

    Expected call order (each step consumes state produced by the previous
    one): ``sample()`` -> ``get_grid()`` -> ``ball_connectivity()`` ->
    ``get_edge_index_range()`` / ``attributes()``.

    Points listed in ``zero_idx`` are never drawn at random; they are always
    appended to the end of level 0. Nodes are numbered level by level, i.e.
    node ``k`` of the graph is row ``k`` of ``grid_sample_all``.
    """

    def __init__(self, grid, level, sample_sizes, zero_idx=None, device='cuda'):
        """Store the grid and the sampling configuration.

        Args:
            grid: NumPy array of points; the first axis enumerates points and
                the last axis holds the coordinates.
            level: Number of levels in the hierarchy.
            sample_sizes: Number of random (non-zero) points per level.
            zero_idx: Indices into ``grid`` that are always attached to
                level 0. ``None`` means "no such points".
            device: Torch device all working tensors live on.
        """
        super(RandomMultiMeshGenerator, self).__init__()

        self.device = device

        self.m = sample_sizes
        self.level = level
        self.grid = torch.from_numpy(grid).float().to(self.device)
        self.n = self.grid.shape[0]
        self.d = self.grid.shape[-1]

        self.idx = []
        self.idx_used_nonzero = None
        # Treat a missing zero_idx as an empty index set so that every later
        # step (len, membership tests, torch.cat) works without special cases.
        if zero_idx is None:
            zero_idx_tensor = torch.empty(0, dtype=torch.long)
        else:
            zero_idx_tensor = torch.tensor(zero_idx).long()
        self.zero_idx = zero_idx_tensor.to(self.device)
        self.zero_num = int(self.zero_idx.shape[0])
        self.grid_sample = []
        self.grid_sample_all = None
        self.edge_index = []
        self.edge_index_down = []
        self.edge_index_up = []
        self.edge_attr = []
        self.edge_attr_down = []
        self.edge_attr_up = []
        self.n_edges_inner = []
        self.n_edges_inter = []

    def _cat_or_empty(self, tensors, dim, empty_shape, dtype):
        """Concatenate ``tensors``; return an empty tensor if there are none."""
        if len(tensors) > 0:
            return torch.cat(tensors, dim=dim)
        # torch.cat rejects an empty list, which happens for small ``level``.
        return torch.empty(empty_shape, dtype=dtype, device=self.device)

    @staticmethod
    def _edge_ranges(edge_indices, n_rows):
        """Return an ``(n_rows, 2)`` table of [start, stop) column ranges."""
        ranges = torch.zeros((max(n_rows, 0), 2), dtype=torch.long)
        stop = 0
        for row in range(n_rows):
            start = stop
            stop = start + edge_indices[row].shape[1]
            ranges[row, 0] = start
            ranges[row, 1] = stop
        return ranges

    def sample(self):
        """Draw a fresh random set of points for every level.

        Returns:
            Tuple ``(idx, nonzero_idx_used, zero_idx)`` of CPU tensors: the
            grid indices of all nodes in node order, the randomly drawn
            (non-zero) indices, and the zero indices.
        """
        # Boolean mask instead of a per-element Python membership test.
        keep = torch.ones(self.n, dtype=torch.bool, device=self.device)
        keep[self.zero_idx] = False
        self.idx_non_zero = torch.arange(self.n, device=self.device)[keep]
        self.non_zero_num = self.idx_non_zero.shape[0]

        self.idx = []
        self.grid_sample = []

        perm = torch.randperm(self.non_zero_num).to(self.device)
        index = 0
        for l in range(self.level):
            idx_temp = self.idx_non_zero[perm[index: index + self.m[l]]]
            if l == 0:
                # The zero points always sit at the end of level 0.
                idx_temp = torch.cat([idx_temp, self.zero_idx], dim=0)
            self.idx.append(idx_temp)
            self.grid_sample.append(self.grid[idx_temp])
            index = index + self.m[l]
        self.nonzero_idx_used = self.idx_non_zero[perm[:index]]
        self.nonzero_grid_sample_all = self.grid[self.nonzero_idx_used]
        self.grid_sample_all = self.grid[torch.cat(self.idx)]

        idx = torch.cat(self.idx).detach().cpu()
        nonzero_idx_used = self.nonzero_idx_used.detach().cpu()
        zero_idx = self.zero_idx.detach().cpu()

        return idx, nonzero_idx_used, zero_idx

    def get_grid(self):
        """Publish the sampled coordinates for the connectivity step.

        Sets ``self.grid_out`` and ``self.zero_grid_idx`` (the node positions
        of the zero points inside level 0), both of which
        ``ball_connectivity`` reads.

        Returns:
            Tuple ``(grid_out, grid_sample_all)``: a list with one CPU tensor
            per level, and the coordinates of all nodes (left on
            ``self.device``).
        """
        self.grid_out = list(self.grid_sample)
        if len(self.grid_sample) > 0:
            # Use the real level-0 size rather than m[0]: the two differ when
            # fewer than m[0] non-zero points were available to draw from.
            first_zero = self.grid_sample[0].shape[0] - self.zero_idx.shape[0]
            self.zero_grid_idx = (first_zero + torch.arange(0, self.zero_idx.shape[0])).to(self.device)
        grid_out = [grid.detach().cpu() for grid in self.grid_sample]
        return grid_out, self.grid_sample_all

    def sphere_distance(self, X, Y=None):
        """Return the matrix ``arccos(clip(X @ Y.T, -1, 1))``.

        ``Y`` defaults to ``X``. This equals the pairwise angle only if the
        rows are unit vectors -- presumably guaranteed by the caller; it is
        not checked here.
        """
        if Y is None:
            Y = X
        assert(Y.shape[1] == X.shape[1])
        cos_theta = torch.matmul(X, Y.transpose(0, 1))
        # Guard against round-off pushing the cosine outside arccos' domain.
        cos_theta = torch.clip(cos_theta, min=-1.0, max=1.0)
        return torch.arccos(cos_theta)

    def ball_connectivity(self, radius_inner, radius_inter):
        """Connect all point pairs closer than the given radii.

        Requires a preceding ``get_grid()`` call.

        Args:
            radius_inner: One radius per level for edges within a level.
            radius_inter: One radius per pair of adjacent levels.

        Returns:
            Tuple of CPU ``long`` tensors of shape ``(2, E)``: inner edges,
            "down" edges between adjacent levels, and "up" edges (the
            reversed down edges, omitting the pair that involves level 0).
            A group without edges is returned with ``E == 0``.
        """
        assert len(radius_inner) == self.level
        assert len(radius_inter) == self.level - 1

        self.edge_index = []
        self.edge_index_down = []
        self.edge_index_up = []
        self.n_edges_inner = []
        self.n_edges_inter = []
        edge_index_out = []
        edge_index_down_out = []
        edge_index_up_out = []

        # Per-node flag marking the zero points; an O(E) lookup that stays on
        # the device instead of a broadcast comparison on the CPU.
        n_nodes = sum(grid.shape[0] for grid in self.grid_out)
        is_zero_node = torch.zeros(n_nodes, dtype=torch.bool, device=self.device)
        is_zero_node[self.zero_grid_idx] = True

        index = 0
        for l in range(self.level):
            pwd = self.sphere_distance(self.grid_out[l])
            # Push the diagonal far beyond any radius to remove self-loops.
            pwd += 1e2 * torch.eye(pwd.shape[0], dtype=pwd.dtype, device=pwd.device)
            edge_index = torch.vstack(torch.where(pwd <= radius_inner[l])) + index
            edge_index = edge_index[[1, 0]]
            # Drop every edge whose first row refers to a zero point.
            edge_index = edge_index[:, ~is_zero_node[edge_index[0]]]
            self.edge_index.append(edge_index)
            edge_index_out.append(edge_index.long())
            self.n_edges_inner.append(edge_index.shape[1])
            index = index + self.grid_out[l].shape[0]

        index = 0
        for l in range(self.level - 1):
            pwd = self.sphere_distance(self.grid_out[l], self.grid_out[l + 1])  # n * m
            edge_index = torch.vstack(torch.where(pwd <= radius_inter[l])) + index
            # Columns of pwd belong to level l+1, whose nodes start one
            # level-l block further along.
            edge_index[1, :] = edge_index[1, :] + self.grid_out[l].shape[0]
            edge_index = edge_index[[1, 0]]

            self.edge_index_down.append(edge_index)
            edge_index_down_out.append(edge_index.long())

            if l != 0:  # the message from lowest level should not pass to higher ones
                reversed_edges = edge_index[[1, 0], :]
                self.edge_index_up.append(reversed_edges)
                edge_index_up_out.append(reversed_edges.long())

            self.n_edges_inter.append(edge_index.shape[1])
            index = index + self.grid_out[l].shape[0]

        edge_index_out = self._cat_or_empty(edge_index_out, 1, (2, 0), torch.long)
        edge_index_down_out = self._cat_or_empty(edge_index_down_out, 1, (2, 0), torch.long)
        edge_index_up_out = self._cat_or_empty(edge_index_up_out, 1, (2, 0), torch.long)

        return edge_index_out.detach().cpu(), edge_index_down_out.detach().cpu(), edge_index_up_out.detach().cpu()

    def get_edge_index_range(self):
        """Return the column ranges of each level inside the concatenated edges.

        in order to use graph network's data structure,
        the edge index shall be stored as tensor instead of list
        we concatenate the edge index list and label the range of each level

        Returns:
            Three ``long`` tensors with rows ``[start, stop)``: one row per
            level for the inner edges, ``level - 1`` rows for the down edges
            and ``level - 2`` rows for the up edges (never fewer than zero).
        """
        edge_index_range = self._edge_ranges(self.edge_index, self.level)
        edge_index_down_range = self._edge_ranges(self.edge_index_down, self.level - 1)
        edge_index_up_range = self._edge_ranges(self.edge_index_up, self.level - 2)

        return edge_index_range, edge_index_down_range, edge_index_up_range

    def attributes(self, theta=None):
        """Build edge attributes from the coordinates of both end points.

        Every edge gets the ``2 * d`` coordinates of its two nodes (in
        edge-index row order). ``theta`` is accepted for interface
        compatibility and is ignored by this class.

        Returns:
            Tuple of CPU tensors ``(inner, down, up)``, each of shape
            ``(E, 2 * d)``; a group without edges has ``E == 0``.
        """
        width = 2 * self.d
        nodes = self.grid_sample_all

        self.edge_attr = [
            nodes[self.edge_index[l].T].reshape((self.n_edges_inner[l], width))
            for l in range(self.level)
        ]
        self.edge_attr_down = [
            nodes[self.edge_index_down[l].T].reshape((self.n_edges_inter[l], width))
            for l in range(self.level - 1)
        ]
        # Up edges start at level pair 1, hence the shifted edge count.
        self.edge_attr_up = [
            nodes[self.edge_index_up[l].T].reshape((self.n_edges_inter[l + 1], width))
            for l in range(self.level - 2)
        ]

        edge_attr_out = self._cat_or_empty(self.edge_attr, 0, (0, width), nodes.dtype).detach().cpu()
        edge_attr_down_out = self._cat_or_empty(self.edge_attr_down, 0, (0, width), nodes.dtype).detach().cpu()
        edge_attr_up_out = self._cat_or_empty(self.edge_attr_up, 0, (0, width), nodes.dtype).detach().cpu()
        return edge_attr_out, edge_attr_down_out, edge_attr_up_out

class NonZeroRandomMultiMeshGenerator(object):
    """Randomly sampled multi-level point hierarchy built on a NumPy grid.

    Expected call order: ``sample()`` -> ``ball_connectivity()`` ->
    ``get_edge_index_range()`` / ``attributes()``; ``get_grid()`` may be
    called any time after ``sample()``. Nodes are numbered level by level,
    i.e. node ``k`` of the graph is row ``k`` of ``grid_sample_all``.
    """

    def __init__(self, grid, level, sample_sizes):
        """Store the grid and the sampling configuration.

        Args:
            grid: NumPy array of points; the first axis enumerates points and
                the last axis holds the coordinates.
            level: Number of levels in the hierarchy.
            sample_sizes: Number of random points per level; must contain
                exactly ``level`` entries.
        """
        super(NonZeroRandomMultiMeshGenerator, self).__init__()
        self.grid = grid

        self.n = self.grid.shape[0]
        self.d = self.grid.shape[-1]

        self.m = sample_sizes
        self.level = level

        assert len(sample_sizes) == level

        self.idx = []
        self.idx_all = None
        self.grid_sample = []
        self.grid_sample_all = None
        self.edge_index = []
        self.edge_index_down = []
        self.edge_index_up = []
        self.edge_attr = []
        self.edge_attr_down = []
        self.edge_attr_up = []
        self.n_edges_inner = []
        self.n_edges_inter = []

    @staticmethod
    def _cat_or_empty(tensors, dim, empty_shape, dtype):
        """Concatenate ``tensors``; return an empty tensor if there are none."""
        if len(tensors) > 0:
            return torch.cat(tensors, dim=dim)
        # torch.cat rejects an empty list, which happens when level == 1.
        return torch.empty(empty_shape, dtype=dtype)

    def _edge_features(self, edge_index, n_edges, theta):
        """Return the attribute tensor for one group of edges.

        Without ``theta`` the result holds the ``2 * d`` end-point
        coordinates in the grid's own dtype; with ``theta`` two extra columns
        carry the ``theta`` value of each end point and the result is cast
        to ``torch.float``.
        """
        width = 2 * self.d
        coords = self.grid_sample_all[edge_index.T].reshape((n_edges, width))
        if theta is None:
            return torch.tensor(coords)

        features = np.zeros((n_edges, width + 2))
        features[:, 0:width] = coords
        features[:, width] = theta[edge_index[0]]
        features[:, width + 1] = theta[edge_index[1]]
        return torch.tensor(features, dtype=torch.float)

    def sphere_distance(self, X, Y=None):
        """Return the matrix ``arccos(clip(X @ Y.T, -1, 1))``.

        ``Y`` defaults to ``X``. This equals the pairwise angle only if the
        rows are unit vectors -- presumably guaranteed by the caller; it is
        not checked here.
        """
        if Y is None:
            Y = X
        assert(Y.shape[1] == X.shape[1])
        cos_theta = X.dot(Y.T)
        # Guard against round-off pushing the cosine outside arccos' domain.
        cos_theta = np.clip(cos_theta, a_min=-1.0, a_max=1.0)
        return np.arccos(cos_theta)

    def sample(self):
        """Draw a fresh random, disjoint set of points for every level.

        Returns:
            Tuple ``(idx, idx_all)``: the list of per-level index tensors and
            the tensor of all drawn indices in node order.
        """
        self.idx = []
        self.grid_sample = []

        perm = torch.randperm(self.n)
        index = 0
        for l in range(self.level):
            level_idx = perm[index: index + self.m[l]]
            self.idx.append(level_idx)
            # Index through .numpy(): NumPy treats a one-element integer
            # tensor as a scalar index and would drop the leading axis.
            self.grid_sample.append(self.grid[level_idx.numpy()])
            index = index + self.m[l]
        self.idx_all = perm[:index]
        self.grid_sample_all = self.grid[self.idx_all.numpy()]
        return self.idx, self.idx_all

    def get_grid(self):
        """Return the sampled coordinates as ``torch.float`` tensors.

        Returns:
            Tuple ``(grid_out, grid_sample_all)``: a list with one tensor per
            level, and one tensor with the coordinates of all nodes.
        """
        grid_out = [torch.tensor(grid, dtype=torch.float) for grid in self.grid_sample]
        return grid_out, torch.tensor(self.grid_sample_all, dtype=torch.float)

    def ball_connectivity(self, radius_inner, radius_inter):
        """Connect all point pairs closer than the given radii.

        Args:
            radius_inner: One radius per level for edges within a level.
            radius_inter: One radius per pair of adjacent levels.

        Returns:
            Tuple of ``long`` tensors of shape ``(2, E)``: inner edges,
            "down" edges between adjacent levels, and "up" edges (the
            reversed down edges). A group without edges has ``E == 0``.
        """
        assert len(radius_inner) == self.level
        assert len(radius_inter) == self.level - 1

        self.edge_index = []
        self.edge_index_down = []
        self.edge_index_up = []
        self.n_edges_inner = []
        self.n_edges_inter = []
        edge_index_out = []
        edge_index_down_out = []
        edge_index_up_out = []

        index = 0
        for l in range(self.level):
            # NOTE: the diagonal is not masked here, so a point can end up
            # connected to itself.
            pwd = self.sphere_distance(self.grid_sample[l])
            edge_index = np.vstack(np.where(pwd <= radius_inner[l])) + index
            edge_index = edge_index[[1, 0]]

            self.edge_index.append(edge_index)
            edge_index_out.append(torch.tensor(edge_index, dtype=torch.long))
            self.n_edges_inner.append(edge_index.shape[1])
            index = index + self.grid_sample[l].shape[0]

        index = 0
        for l in range(self.level - 1):
            pwd = self.sphere_distance(self.grid_sample[l], self.grid_sample[l + 1])
            edge_index = np.vstack(np.where(pwd <= radius_inter[l])) + index
            # Columns of pwd belong to level l+1, whose nodes start one
            # level-l block further along.
            edge_index[1, :] = edge_index[1, :] + self.grid_sample[l].shape[0]
            edge_index = edge_index[[1, 0]]
            reversed_edges = edge_index[[1, 0], :]

            self.edge_index_down.append(edge_index)
            edge_index_down_out.append(torch.tensor(edge_index, dtype=torch.long))
            self.edge_index_up.append(reversed_edges)
            edge_index_up_out.append(torch.tensor(reversed_edges, dtype=torch.long))
            self.n_edges_inter.append(edge_index.shape[1])
            index = index + self.grid_sample[l].shape[0]

        edge_index_out = self._cat_or_empty(edge_index_out, 1, (2, 0), torch.long)
        edge_index_down_out = self._cat_or_empty(edge_index_down_out, 1, (2, 0), torch.long)
        edge_index_up_out = self._cat_or_empty(edge_index_up_out, 1, (2, 0), torch.long)

        return edge_index_out, edge_index_down_out, edge_index_up_out

    def get_edge_index_range(self):
        """Return the column ranges of each level inside the concatenated edges.

        in order to use graph network's data structure,
        the edge index shall be stored as tensor instead of list
        we concatenate the edge index list and label the range of each level

        Returns:
            Three ``long`` tensors with rows ``[start, stop)``: one row per
            level for the inner edges and ``level - 1`` rows each for the
            down and up edges (which share identical ranges).
        """
        edge_index_range = torch.zeros((self.level, 2), dtype=torch.long)
        edge_index_down_range = torch.zeros((max(self.level - 1, 0), 2), dtype=torch.long)

        n_edge_index = 0
        for l in range(self.level):
            edge_index_range[l, 0] = n_edge_index
            n_edge_index = n_edge_index + self.edge_index[l].shape[1]
            edge_index_range[l, 1] = n_edge_index

        n_edge_index = 0
        for l in range(self.level - 1):
            edge_index_down_range[l, 0] = n_edge_index
            n_edge_index = n_edge_index + self.edge_index_down[l].shape[1]
            edge_index_down_range[l, 1] = n_edge_index

        # Up edges are the reversed down edges, so their ranges coincide.
        edge_index_up_range = edge_index_down_range.clone()

        return edge_index_range, edge_index_down_range, edge_index_up_range

    def attributes(self, theta=None):
        """Build edge attributes from the end-point coordinates (and ``theta``).

        Args:
            theta: Optional per-grid-point values indexable by the sampled
                indices. When given, the value at each end point is appended
                as two extra attribute columns.

        Returns:
            Tuple ``(inner, down, up)`` of tensors with one row per edge and
            ``2 * d`` columns (``2 * d + 2`` with ``theta``); a group without
            edges has zero rows.
        """
        if theta is not None:
            # Restrict theta to the sampled nodes, in node order.
            theta = theta[self.idx_all]

        self.edge_attr = [
            self._edge_features(self.edge_index[l], self.n_edges_inner[l], theta)
            for l in range(self.level)
        ]
        self.edge_attr_down = [
            self._edge_features(self.edge_index_down[l], self.n_edges_inter[l], theta)
            for l in range(self.level - 1)
        ]
        self.edge_attr_up = [
            self._edge_features(self.edge_index_up[l], self.n_edges_inter[l], theta)
            for l in range(self.level - 1)
        ]

        # Shape and dtype used only when a group has no tensors to concatenate.
        n_features = 2 * self.d + (0 if theta is None else 2)
        fallback_dtype = self.edge_attr[0].dtype if len(self.edge_attr) > 0 else torch.float

        edge_attr_out = self._cat_or_empty(self.edge_attr, 0, (0, n_features), fallback_dtype)
        edge_attr_down_out = self._cat_or_empty(self.edge_attr_down, 0, (0, n_features), fallback_dtype)
        edge_attr_up_out = self._cat_or_empty(self.edge_attr_up, 0, (0, n_features), fallback_dtype)
        return edge_attr_out, edge_attr_down_out, edge_attr_up_out