import torch
import torch.nn as nn
import numpy as np
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
import torch.nn.functional as F


########################### 3DGSPCI_GNet ########################### 

def get_activation(activation):
    if activation.lower() == 'relu':
        return nn.ReLU(inplace=True)
    elif activation.lower() == 'leakyrelu':
        return nn.LeakyReLU(inplace=True)
    elif activation.lower() == 'sigmoid':
        return nn.Sigmoid()
    elif activation.lower() == 'softplus':
        return nn.Softplus()
    elif activation.lower() == 'gelu':
        return nn.GELU()
    elif activation.lower() == 'selu':
        return nn.SELU(inplace=True)
    elif activation.lower() == 'mish':
        return nn.Mish(inplace=True)
    else:
        raise Exception("Activation Function Error")


def get_norm(norm, width):
    if norm == 'LN':
        return nn.LayerNorm(width)
    elif norm == 'BN':
        return nn.BatchNorm1d(width)
    elif norm == 'IN':
        return nn.InstanceNorm1d(width)
    elif norm == 'GN':
        return nn.GroupNorm(width)
    else:
        raise Exception("Normalization Layer Error")


class NeuralPCI_Layer(torch.nn.Module):
    def __init__(self, 
                 dim_in,
                 dim_out,
                 norm=None, 
                 act_fn=None
                 ):
        super().__init__()
        layer_list = []
        layer_list.append(nn.Linear(dim_in, dim_out))
        if norm:
            layer_list.append(get_norm(norm, dim_out))
        if act_fn:
            layer_list.append(get_activation(act_fn))
        self.layer = nn.Sequential(*layer_list)

    def forward(self, x):
        x = self.layer(x)
        return x


class NeuralPCI_Block(torch.nn.Module):
    def __init__(self, 
                 depth, 
                 width,
                 norm=None, 
                 act_fn=None
                 ):
        super().__init__()
        layer_list = []
        for _ in range(depth):
            layer_list.append(nn.Linear(width, width))
            if norm:
                layer_list.append(get_norm(norm, width))
            if act_fn:
                layer_list.append(get_activation(act_fn))
        self.mlp = nn.Sequential(*layer_list)

    def forward(self, x):
        x = self.mlp(x)
        return x
    
class LocalGraphConv(nn.Module):
    def __init__(self, in_channels, out_channels, k=16):
        super(LocalGraphConv, self).__init__()
        self.k = k
        self.conv = nn.Sequential(
            nn.Linear(in_channels*2, out_channels),
            nn.ReLU(inplace=True),
            nn.Linear(out_channels, out_channels)
        )
    
    def forward(self, x):
        n, d = x.size()
        dist_mat = torch.cdist(x, x)
        _, indices = torch.topk(dist_mat, k=self.k, dim=1, largest=False)
        neighbors = self.index_select(x, indices)
        center = x.unsqueeze(1).expand(-1,self.k,-1)
        
        x = torch.cat([center, neighbors-center], dim=-1).contiguous()  # [N, K, 2*C]  # .permute(2,0,1)
        x = self.conv(x.view(-1, 2*d)).view(n, self.k, -1)
        x = x.max(dim=1)[0]
        
        return x

    def index_select(self, x, idx):
        """
        :param  x: input points data, [N, C]
        :param  idx: knn index data, [N, K]
        :return: indexed points data, [N, K, C]
        """
        N, C = x.shape
        _, K = idx.shape
        
        edge_index = knn_graph(x, k=K)
        
        neighbors = x[edge_index[1]]
        
        return neighbors.view(N, K, C)
    
class GaussianMixtureModel(nn.Module):
    def __init__(self, in_channels, n_kernels):
        super().__init__()
        self.n_kernels = n_kernels
        self.mu = nn.Parameter(torch.randn(n_kernels, in_channels) / np.sqrt(n_kernels))
        self.sigma = nn.Parameter(torch.ones(n_kernels, in_channels))
        self.p = nn.Parameter(torch.ones(n_kernels) / n_kernels)

    def forward(self, x):
        x = x.unsqueeze(1) - self.mu.unsqueeze(0)
        z = torch.sum(x**2 / (self.sigma.unsqueeze(0)**2), dim=-1)
        p = torch.softmax(torch.log(self.p) - 0.5 * z, dim=-1)
        mu_x = torch.sum(p.unsqueeze(-1) * self.mu.unsqueeze(0), dim=1)
        return mu_x
    


class ProbabilisticGraphLayer(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(ProbabilisticGraphLayer, self).__init__(aggr='add') 
        self.lin = nn.Linear(in_channels, out_channels)
        self.score_net = nn.Sequential(
            nn.Linear(2 * out_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x, edge_index):
        x = self.lin(x)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_i, x_j, edge_index, size):
        score = self.score_net(torch.cat([x_i, x_j], dim=-1))
        return score * x_j

class ProbabilisticGraphNetwork(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(ProbabilisticGraphNetwork, self).__init__()
        self.conv1 = ProbabilisticGraphLayer(in_channels, hidden_channels)
        self.conv2 = ProbabilisticGraphLayer(hidden_channels, hidden_channels)
        self.conv3 = ProbabilisticGraphLayer(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        return x

class NeuralPCI_diff(torch.nn.Module):
    def __init__(self, args=None):
        super().__init__()
        self.args = args
        dim_pc = args.dim_pc
        dim_time = args.dim_time
        layer_width = args.layer_width 
        act_fn = args.act_fn
        norm = args.norm
        depth_encode = args.depth_encode
        depth_pred = args.depth_pred
        pe_mul = args.pe_mul
        self.n_gaussians = args.n_gaussians
        self.gmm = GaussianMixtureModel(args.layer_width, self.n_gaussians)

        if args.use_rrf:
            dim_rrf = args.dim_rrf
            self.transform = 0.1 * torch.normal(0, 1, size=[dim_pc, dim_rrf]).cuda()
        else:
            dim_rrf = dim_pc

        # input layer
        self.layer_input = NeuralPCI_Layer(dim_in = (dim_rrf + dim_time) * pe_mul, 
                                           dim_out = layer_width, 
                                           norm = norm,
                                           act_fn = act_fn
                                           )

        # hidden layers
        self.hidden_encode = NeuralPCI_Block(depth = depth_encode, 
                                             width = layer_width, 
                                             norm = norm,
                                             act_fn = act_fn
                                             )

        # insert interpolation time
        self.layer_time = NeuralPCI_Layer(dim_in = layer_width + dim_time * pe_mul, 
                                          dim_out = layer_width, 
                                          norm = norm,
                                          act_fn = act_fn
                                          )

        # hidden layers
        self.hidden_pred = NeuralPCI_Block(depth = depth_pred, 
                                           width = layer_width, 
                                           norm = norm,
                                           act_fn = act_fn
                                           )

        # output layer
        self.layer_output = NeuralPCI_Layer(dim_in = layer_width, 
                                          dim_out = dim_pc, 
                                          norm = norm,
                                          act_fn = None
                                          )
        self.layer_gradient = NeuralPCI_Layer(dim_in=layer_width, 
                                      dim_out=dim_pc, 
                                      norm=norm,
                                      act_fn=None)
        # zero init for last layer
        if args.zero_init:
            for m in self.layer_output.layer:
                if isinstance(m, nn.Linear):
                    # torch.nn.init.normal_(m.weight.data, 0, 0.01)
                    m.weight.data.zero_()
                    m.bias.data.zero_()

        self.pred_splatted = args.pred_splatted
        self.sigma = args.sigma
    
        self.num_timesteps = args.num_timesteps
        self.beta_start = 1e-4 # args.beta_start
        self.beta_end = 0.02 # args.beta_end
        self.betas = self.get_beta_schedule().cuda()
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
        self.alphas_cumprod_prev = torch.cat([torch.ones(1).to(self.alphas_cumprod.device), self.alphas_cumprod[:-1]])
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
    
    # def get_beta_schedule(self):
    #     return torch.linspace(self.beta_start, self.beta_end, self.num_timesteps)
    def get_beta_schedule(self):
        steps = self.num_timesteps
        t = torch.linspace(0, 1, steps)
        start = torch.tensor(self.beta_start).log()
        end = torch.tensor(self.beta_end).log()
        alpha = (-((t * (end - start) + start))).exp()
        alpha_hat = alpha / alpha[0]
        betas = 1 - alpha_hat[1:] / alpha_hat[:-1] 
        return torch.cat((torch.tensor([0]), betas))
    
    def posenc(self, x):
        sinx = torch.sin(x)
        cosx = torch.cos(x)
        x = torch.cat((x, sinx, cosx), dim=1)
        return x
    
    def q_sample(self, pc_start, t, noise_factor=0.1):
        noise = torch.randn_like(pc_start) * noise_factor
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1)
        return sqrt_alphas_cumprod_t * pc_start + sqrt_one_minus_alphas_cumprod_t * noise, noise
    
    def p_sample(self, pc_t, t, time_pred, time_current):
        betas_t = self.betas[t].view(-1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1)
        sqrt_recip_alphas_t = torch.sqrt(1.0 / self.alphas[t]).view(-1, 1)
        
        model_output = self.model(pc_t, time_current, time_pred)
        model_mean, model_gradient = model_output['mean'], model_output['gradient']
        
        # DDIM sampling
        x_t_minus_1 = sqrt_recip_alphas_t * (pc_t - betas_t * model_mean / sqrt_one_minus_alphas_cumprod_t)
        x_t_minus_1 += torch.sqrt(self.posterior_variance[t]) * model_gradient
        
        return x_t_minus_1
    

    def forward(self, pc_current, time_current, time_pred, train=True, num_denoise_steps=10):
        # Convert time_current and time_pred to PyTorch tensors with the same data type and device as pc_current
        time_current = torch.tensor(time_current).repeat(pc_current.shape[0], 1).to(pc_current.device).float().detach()
        time_pred = torch.tensor(time_pred).repeat(pc_current.shape[0], 1).to(pc_current.device).float().detach()
        
        if train:
            t = torch.randint(0, self.num_timesteps, (1,)).to(pc_current.device)
            pc_noisy, _ = self.q_sample(pc_current, t)
            model_output = self.model(pc_noisy, time_current, time_pred)
            model_mean, model_gradient = model_output['mean'], model_output['gradient']
            return model_mean
        else:
            for i in range(num_denoise_steps):
                t = torch.tensor([i]).to(pc_current.device)
                pc_current = self.p_sample(pc_current, t, time_pred, time_current)
            model_output = self.model(pc_current, time_current, time_pred)
            pc_pred, model_gradient = model_output['mean'], model_output['gradient']
            return pc_pred
        
    def model(self, pc_current, t, time_pred):
        if self.args.use_rrf:
            pc_current = torch.matmul(2. * torch.pi * pc_current, self.transform)
        if t.shape[0] != pc_current.shape[0]:
            time_t = t.view(-1, 1).repeat(1, pc_current.shape[0]).view(-1, 1)
        else:
            time_t = t
        x = torch.cat((pc_current, time_t), dim=1)
        x = self.posenc(x)
        x = self.layer_input(x)

        time_pred = self.posenc(time_pred)
        x = torch.cat((x, time_pred), dim=1)
        x = self.layer_time(x)

        x = self.hidden_pred(x)
        x = self.gmm(x)
        pc_gradient = self.layer_gradient(x)
        x = self.layer_output(x)
        pc_pred = x
        
        # return x
        if self.pred_splatted:
            pc_pred_min = pc_pred.min(dim=0, keepdim=True)[0]
            pc_pred_max = pc_pred.max(dim=0, keepdim=True)[0]
            pc_pred_normalized = (pc_pred - pc_pred_min) / (pc_pred_max - pc_pred_min + 1e-6)

            kernel_size = int(self.sigma * 6)
            kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
            gaussian_kernel = self._gaussian_kernel(kernel_size, self.sigma).cuda()

            pc_pred_normalized = pc_pred_normalized.transpose(0, 1).unsqueeze(0).unsqueeze(0).unsqueeze(0)
            pc_pred_splatted = nn.functional.conv3d(pc_pred_normalized, gaussian_kernel, padding=kernel_size//2)
            pc_pred_splatted = pc_pred_splatted.squeeze(0).squeeze(0).squeeze(0).transpose(0, 1)

            pc_pred = pc_pred_splatted * (pc_pred_max - pc_pred_min + 1e-6) + pc_pred_min
        
        
        return {'mean': pc_pred, 'gradient': pc_gradient}
        # return pc_pred

    def _gaussian_kernel(self, kernel_size, sigma):
        x_coord = torch.arange(kernel_size)
        x_grid = x_coord.repeat(kernel_size, kernel_size).view(kernel_size, kernel_size, kernel_size)
        y_grid = x_grid.transpose(1, 0)
        z_grid = x_grid.transpose(2, 0)
        xyz_grid = torch.stack([x_grid, y_grid, z_grid], dim=-1).float()

        mean = (kernel_size - 1)/2.
        variance = sigma**2.

        gaussian_kernel = (1./(2.*np.pi*variance)**(3/2)) * torch.exp(-torch.sum((xyz_grid - mean)**2., dim=-1) / (2*variance))
        gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

        return gaussian_kernel.view(1, 1, kernel_size, kernel_size, kernel_size)
    


class NeuralPCI(torch.nn.Module):
    def __init__(self, args=None):
        super().__init__()
        self.args = args
        dim_pc = args.dim_pc
        dim_time = args.dim_time
        layer_width = args.layer_width 
        act_fn = args.act_fn
        norm = args.norm
        depth_encode = args.depth_encode
        depth_pred = args.depth_pred
        pe_mul = args.pe_mul

        if args.use_rrf:
            dim_rrf = args.dim_rrf
            self.transform = 0.1 * torch.normal(0, 1, size=[dim_pc, dim_rrf]).cuda()
        else:
            dim_rrf = dim_pc

        # input layer
        self.layer_input = NeuralPCI_Layer(dim_in = (dim_rrf + dim_time) * pe_mul, 
                                           dim_out = layer_width, 
                                           norm = norm,
                                           act_fn = act_fn
                                           )

        # hidden layers
        self.hidden_encode = NeuralPCI_Block(depth = depth_encode, 
                                             width = layer_width, 
                                             norm = norm,
                                             act_fn = act_fn
                                             )

        # insert interpolation time
        self.layer_time = NeuralPCI_Layer(dim_in = layer_width + dim_time * pe_mul, 
                                          dim_out = layer_width, 
                                          norm = norm,
                                          act_fn = act_fn
                                          )

        # hidden layers
        self.hidden_pred = NeuralPCI_Block(depth = depth_pred, 
                                           width = layer_width, 
                                           norm = norm,
                                           act_fn = act_fn
                                           )

        # self.hidden_encode = nn.Sequential(
        #     LocalGraphConv(layer_width, layer_width),
        #     LocalGraphConv(layer_width, layer_width),
        #     LocalGraphConv(layer_width, layer_width)
        # )

        # self.hidden_pred = nn.Sequential( 
        #     LocalGraphConv(layer_width, layer_width),
        #     LocalGraphConv(layer_width, layer_width),
        #     LocalGraphConv(layer_width, layer_width)
        # )

        # output layer
        self.layer_output = NeuralPCI_Layer(dim_in = layer_width, 
                                          dim_out = dim_pc, 
                                          norm = norm,
                                          act_fn = None
                                          )
        
        # zero init for last layer
        if args.zero_init:
            for m in self.layer_output.layer:
                if isinstance(m, nn.Linear):
                    # torch.nn.init.normal_(m.weight.data, 0, 0.01)
                    m.weight.data.zero_()
                    m.bias.data.zero_()

        ############## NNNNNNNNNNNNNNNNEEEEEEEEEEEEEEEEEEEWWWWWWWWWWWWWWWWWWWW ##############
        if args.GMM:
            self.gmm = GaussianMixtureModel(args.layer_width, args.n_gaussians)
            self.GMM = True
        else:
            self.GMM = False
        if args.ProGraphNet:
            self.ProbabilisticGraphNet = ProbabilisticGraphNetwork(layer_width, layer_width//4, layer_width)
            self.ProGraphNet = True
        else:
            self.ProGraphNet = False

        self.pred_splatted = args.pred_splatted
        self.sigma = args.sigma
    
    def posenc(self, x):
        """
        sinusoidal positional encoding : N ——> 3 * N
        [x] ——> [x, sin(x), cos(x)]
        """
        sinx = torch.sin(x)
        cosx = torch.cos(x)
        x = torch.cat((x, sinx, cosx), dim=1)
        return x

    def forward(self, pc_current, time_current, time_pred, train=False):
        """
        pc_current: tensor, [N, 3]
        time_current: float, [1]
        time_pred: float, [1]
        output: tensor, [N, 3]
        """
        
        time_current = torch.tensor(time_current).repeat(pc_current.shape[0], 1).cuda().float().detach()
        time_pred = torch.tensor(time_pred).repeat(pc_current.shape[0], 1).cuda().float().detach()
        
        if self.args.use_rrf:
            pc_current = torch.matmul(2. * torch.pi * pc_current, self.transform)

        x = torch.cat((pc_current, time_current), dim=1)
        x = self.posenc(x)
        x = self.layer_input(x)
        x = self.hidden_encode(x)

        if self.ProGraphNet:
            data = Data(x=x, pos=pc_current)
            edge_index = knn_graph(data.pos, k=16, loop=False)
            data.edge_index = edge_index
            x = self.ProbabilisticGraphNet(data)
        
        time_pred = self.posenc(time_pred)
        x = torch.cat((x, time_pred), dim=1)
        x = self.layer_time(x)

        if self.GMM:
            x = self.gmm(x)

        x = self.hidden_pred(x)

        x = self.layer_output(x)
        
        # return x
        pc_pred = x
        if self.pred_splatted:
            pc_pred_min = pc_pred.min(dim=0, keepdim=True)[0]
            pc_pred_max = pc_pred.max(dim=0, keepdim=True)[0]
            pc_pred_normalized = (pc_pred - pc_pred_min) / (pc_pred_max - pc_pred_min + 1e-6)

            kernel_size = int(self.sigma * 6)
            kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
            gaussian_kernel = self._gaussian_kernel(kernel_size, self.sigma).cuda()

            pc_pred_normalized = pc_pred_normalized.transpose(0, 1).unsqueeze(0).unsqueeze(0).unsqueeze(0)
            pc_pred_splatted = nn.functional.conv3d(pc_pred_normalized, gaussian_kernel, padding=kernel_size//2)
            pc_pred_splatted = pc_pred_splatted.squeeze(0).squeeze(0).squeeze(0).transpose(0, 1)

            pc_pred = pc_pred_splatted * (pc_pred_max - pc_pred_min + 1e-6) + pc_pred_min

        return pc_pred

    def _gaussian_kernel(self, kernel_size, sigma):
        x_coord = torch.arange(kernel_size)
        x_grid = x_coord.repeat(kernel_size, kernel_size).view(kernel_size, kernel_size, kernel_size)
        y_grid = x_grid.transpose(1, 0)
        z_grid = x_grid.transpose(2, 0)
        xyz_grid = torch.stack([x_grid, y_grid, z_grid], dim=-1).float()

        mean = (kernel_size - 1)/2.
        variance = sigma**2.

        gaussian_kernel = (1./(2.*np.pi*variance)**(3/2)) * torch.exp(-torch.sum((xyz_grid - mean)**2., dim=-1) / (2*variance))
        gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

        return gaussian_kernel.view(1, 1, kernel_size, kernel_size, kernel_size)
    
    