import os
import sys
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import *
from utils import preprocess
from scipy.spatial.transform import Rotation as R
from torch_geometric.transforms import SamplePoints

# Small constant added to denominators (vector norms, squared direction norms)
# throughout this file to avoid division by zero.
EPS = 1e-6

class VNLinear(nn.Module):
    """Vector-neuron linear layer.

    Mixes the feature channels (dim 1) with a bias-free linear map and leaves
    the 3-vector axis untouched, so it commutes with rotations of the vectors.

    Args:
        in_channels: number of input vector channels.
        out_channels: number of output vector channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False)

    def forward(self, x):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]
        Returns the same shape with N_feat replaced by out_channels.
        '''
        # nn.Linear acts on the last axis, so move channels there and back.
        channels_last = x.transpose(1, -1)
        mixed = self.map_to_feat(channels_last)
        return mixed.transpose(1, -1)


class VNBilinear(nn.Module):
    """Vector-neuron bilinear layer conditioned on a per-sample label vector.

    Each output channel is x^T A_o labels, which is linear in the vectors x.

    Args:
        in_channels1: number of vector channels in x.
        in_channels2: length of the label vector.
        out_channels: number of output vector channels.
    """

    def __init__(self, in_channels1, in_channels2, out_channels):
        super().__init__()
        self.map_to_feat = nn.Bilinear(in_channels1, in_channels2, out_channels, bias=False)

    def forward(self, x, labels):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]
        labels: 3-D tensor, presumably [B, 1, in_channels2] (e.g. one-hot) -- TODO confirm;
            it is repeated x.shape[2] times along dim 1 and cast to float.
        '''
        n_vec = x.shape[2]
        tiled_labels = labels.repeat(1, n_vec, 1).float()
        projected = self.map_to_feat(x.transpose(1, -1), tiled_labels)
        return projected.transpose(1, -1)


class VNSoftplus(nn.Module):
    """Smooth vector-neuron nonlinearity (a soft version of VNLeakyReLU).

    A direction d is predicted per channel from x. The input and its
    projection onto the plane orthogonal to d are blended with weight
    cos^2(theta / 2) = (1 + cos theta) / 2, where theta is the angle between
    x and d. Vectors aligned with d pass through, and vectors opposite to d
    lose their d-component. The result is mixed with the identity by
    `negative_slope`.

    Args:
        in_channels: number of vector channels.
        share_nonlinearity: if True, one direction is shared by all channels.
        negative_slope: weight of the identity path in the output.
    """

    def __init__(self, in_channels, share_nonlinearity=False, negative_slope=0.0):
        super(VNSoftplus, self).__init__()
        if share_nonlinearity == True:
            self.map_to_dir = nn.Linear(in_channels, 1, bias=False)
        else:
            self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False)
        self.negative_slope = negative_slope

    def forward(self, x):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]
        Returns a tensor of the same shape.
        '''
        d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)
        dotprod = (x * d).sum(2, keepdim=True)
        norms = torch.norm(x, dim=2, keepdim=True) * torch.norm(d, dim=2, keepdim=True)
        # Cosine of the angle between x and d. Float rounding can push it to or
        # slightly past +-1 when x is parallel to d, so clamp it into range.
        cos_angle = (dotprod / (norms + EPS)).clamp(-1.0, 1.0)
        # cos^2(acos(c) / 2) == (1 + c) / 2. Computing it directly avoids acos,
        # which is NaN outside [-1, 1] and has an infinite derivative at +-1.
        mask = (1 + cos_angle) / 2
        d_norm_sq = (d * d).sum(2, keepdim=True)
        x_out = self.negative_slope * x + (1 - self.negative_slope) * (
            mask * x + (1 - mask) * (x - (dotprod / (d_norm_sq + EPS)) * d)
        )
        return x_out


class VNLeakyReLU(nn.Module):
    """Vector-neuron leaky ReLU.

    A direction d is predicted per channel from x. Where <x, d> >= 0 the
    vector passes unchanged; otherwise its component along d is removed. The
    output is negative_slope * x + (1 - negative_slope) * (that result).

    Args:
        in_channels: number of vector channels.
        share_nonlinearity: if True, one direction is shared by all channels.
        negative_slope: weight of the identity path in the output.
    """

    def __init__(self, in_channels, share_nonlinearity=False, negative_slope=0.2):
        super(VNLeakyReLU, self).__init__()
        n_directions = 1 if share_nonlinearity == True else in_channels
        self.map_to_dir = nn.Linear(in_channels, n_directions, bias=False)
        self.negative_slope = negative_slope

    def forward(self, x):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]
        Returns a tensor of the same shape.
        '''
        direction = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)
        proj = (x * direction).sum(2, keepdim=True)
        keep = (proj >= 0).float()
        dir_sq = (direction * direction).sum(2, keepdim=True)
        # Remove the component along `direction` only where the projection is negative.
        rectified = keep * x + (1 - keep) * (x - (proj / (dir_sq + EPS)) * direction)
        return self.negative_slope * x + (1 - self.negative_slope) * rectified


class VNLinearLeakyReLU(nn.Module):
    """Fused vector-neuron block: linear -> VNBatchNorm -> leaky ReLU.

    The nonlinearity direction is predicted from the block *input* x, while it
    is applied to the normalized features p.

    Args:
        in_channels: number of input vector channels.
        out_channels: number of output vector channels.
        dim: input tensor rank, forwarded to VNBatchNorm.
        share_nonlinearity: if True, one direction is shared by all channels.
        negative_slope: weight of the identity path in the output.
    """

    def __init__(self, in_channels, out_channels, dim=5, share_nonlinearity=False, negative_slope=0.2):
        super(VNLinearLeakyReLU, self).__init__()
        self.dim = dim
        self.negative_slope = negative_slope

        self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False)
        self.batchnorm = VNBatchNorm(out_channels, dim=dim)

        dir_channels = 1 if share_nonlinearity == True else out_channels
        self.map_to_dir = nn.Linear(in_channels, dir_channels, bias=False)

    def forward(self, x):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]
        Returns features with N_feat replaced by out_channels.
        '''
        feats = self.batchnorm(self.map_to_feat(x.transpose(1, -1)).transpose(1, -1))
        direction = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)

        proj = (feats * direction).sum(2, keepdim=True)
        keep = (proj >= 0).float()
        dir_sq = (direction * direction).sum(2, keepdim=True)
        rectified = keep * feats + (1 - keep) * (feats - (proj / (dir_sq + EPS)) * direction)
        return self.negative_slope * feats + (1 - self.negative_slope) * rectified


class VNLinearAndLeakyReLU(nn.Module):
    """VNLinear -> optional VNBatchNorm -> VNLeakyReLU, built from separate modules.

    Args:
        in_channels: number of input vector channels.
        out_channels: number of output vector channels.
        dim: input tensor rank, forwarded to VNBatchNorm.
        share_nonlinearity: if True, VNLeakyReLU shares one direction across channels.
        use_batchnorm: 'none' disables normalization; any other value enables
            VNBatchNorm. VNBatchNorm(num_features, dim) has no mode switch, so
            e.g. 'norm' is just "enabled".
        negative_slope: leaky slope of the nonlinearity.
    """

    def __init__(self, in_channels, out_channels, dim=5, share_nonlinearity=False, use_batchnorm='norm', negative_slope=0.2):
        # Zero-argument super(): this class does not inherit from VNLinearLeakyReLU.
        super().__init__()
        self.dim = dim
        self.share_nonlinearity = share_nonlinearity
        self.use_batchnorm = use_batchnorm
        self.negative_slope = negative_slope

        self.linear = VNLinear(in_channels, out_channels)
        self.leaky_relu = VNLeakyReLU(out_channels, share_nonlinearity=share_nonlinearity, negative_slope=negative_slope)

        if use_batchnorm != 'none':
            # VNBatchNorm accepts only (num_features, dim); it has no `mode` argument.
            self.batchnorm = VNBatchNorm(out_channels, dim=dim)

    def forward(self, x):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]
        Returns features with N_feat replaced by out_channels.
        '''
        # Linear channel mixing
        x = self.linear(x)
        # Norm-based batch normalization
        if self.use_batchnorm != 'none':
            x = self.batchnorm(x)
        # Vector-neuron leaky ReLU
        x_out = self.leaky_relu(x)
        return x_out


class VNBatchNorm(nn.Module):
    """Batch normalization applied to the norms of vector features.

    Each 3-vector is rescaled by bn(|v|) / |v|. The norms, not the raw
    coordinates, go through BatchNorm, which keeps the layer
    rotation-equivariant.

    Args:
        num_features: number of vector channels (dim 1 of the input).
        dim: rank of the input tensor. 3 or 4 uses BatchNorm1d; 5 uses BatchNorm2d.

    Raises:
        ValueError: if `dim` is not 3, 4 or 5.
    """

    def __init__(self, num_features, dim):
        super(VNBatchNorm, self).__init__()
        self.dim = dim
        if dim in (3, 4):
            self.bn = nn.BatchNorm1d(num_features)
        elif dim == 5:
            self.bn = nn.BatchNorm2d(num_features)
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError(f"VNBatchNorm supports dim 3, 4 or 5, got {dim!r}")

    def forward(self, x, mask=None):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]
        mask: optional bool tensor that `expand_as` can broadcast to
            [B, N_feat, N_samples, ...]. True marks entries to normalize;
            entries where it is False are returned unchanged.
        '''
        norm = torch.norm(x, dim=2) + EPS  # [B, N_feat, N_samples, ...]

        if mask is not None:
            expanded_mask = mask.expand_as(norm)
            # Masked entries get norm 1 so that x / 1 * 1 leaves them unchanged.
            # NOTE(review): these placeholder 1s still enter the BatchNorm batch
            # statistics (and its running stats in training mode).
            norm = torch.where(expanded_mask, norm, torch.ones_like(norm))

        norm_bn = self.bn(norm)

        if mask is not None:
            norm_bn = torch.where(expanded_mask, norm_bn, norm)

        norm = norm.unsqueeze(2)
        norm_bn = norm_bn.unsqueeze(2)
        x = x / norm * norm_bn

        return x


class VNMaxPool(nn.Module):
    """Vector-neuron max pooling over the last axis.

    For every channel a direction d is predicted from x. Along the last axis,
    the 3-vector with the largest projection <x, d> is selected.

    Args:
        in_channels: number of vector channels.
    """

    def __init__(self, in_channels):
        super(VNMaxPool, self).__init__()
        self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False)

    def forward(self, x):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]
        Returns x with its last axis removed: for every leading index, the
        3-vector at the argmax of <x, d> over that axis.
        '''
        d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)
        dotprod = (x * d).sum(2, keepdim=True)      # size 1 on the vector axis
        idx = dotprod.max(dim=-1, keepdim=False)[1]  # argmax over the last axis
        # Broadcast the argmax over the 3 vector components and gather along the
        # last axis. gather stays on x's device and needs no meshgrid index grid.
        gather_idx = idx.unsqueeze(-1).expand(*x.shape[:-1], 1)
        x_max = x.gather(-1, gather_idx).squeeze(-1)
        return x_max


def mean_pool(x, dim=-1, keepdim=False):
    """Average `x` over `dim` (the last axis by default).

    Used by VNSmall as its "mean" pooling option.
    """
    return torch.mean(x, dim, keepdim=keepdim)


def knn(x, k, mask=None):
    """Return the indices of the k nearest neighbours of every point.

    Args:
        x: [B, C, N] tensor; column n holds the coordinates of point n.
        k: number of neighbours. A point normally appears among its own
            neighbours, because its self-distance is 0.
        mask: optional bool tensor that becomes [B, N] after `squeeze(1)`;
            True marks real points. Padded points get a -1e9 penalty, so they
            rank last for every real point, and each padded point ranks itself
            first.

    Returns:
        [B, N, k] tensor of neighbour indices, nearest first.
    """
    neg_two_gram = -2 * torch.matmul(x.transpose(2, 1), x)
    sq_norms = torch.sum(x ** 2, dim=1, keepdim=True)  # [B, 1, N]
    # Negative squared distance, so topk (largest values) selects the nearest points.
    neg_dist = -sq_norms - neg_two_gram - sq_norms.transpose(2, 1)  # [B, N, N]

    if mask is not None:
        padded = ~mask.squeeze(1)  # [B, N], True for padding
        penalty = (padded.float() * -1e9).unsqueeze(1)  # [B, 1, N]
        neg_dist = neg_dist + penalty + penalty.transpose(1, 2)

        # Zero the diagonal of padded rows so a padded point selects itself first.
        self_pick = torch.eye(padded.size(1), device=x.device).unsqueeze(0)
        self_pick = self_pick * padded.float().unsqueeze(1)  # [B, N, N]
        neg_dist = neg_dist * (1 - self_pick) + self_pick * 0

    return neg_dist.topk(k=k, dim=-1)[1]


def get_graph_feature_cross(x, k=20, idx=None, mask=None):
    """Build vector-neuron edge features over a kNN graph.

    Args:
        x: [B, C, 3, N] tensor of C vector channels per point (N is read from dim 3).
        k: neighbours per point. It must match idx.shape[-1] when idx is given.
        idx: optional precomputed [B, N, k] neighbour indices. If None, they are
            computed with knn on the flattened [B, C*3, N] features.
        mask: optional bool tensor that becomes [B, N] after `squeeze(1)`; True
            marks real points. Padded points use themselves as every neighbour,
            which zeroes their difference and cross-product features.

    Returns:
        [B, 3*C, 3, N, k] tensor. The channel blocks are, in order:
        (neighbour - centre), centre, and cross(neighbour, centre).
    """
    bsz, n_pts = x.size(0), x.size(3)
    flat = x.view(bsz, -1, n_pts)  # [B, C*3, N]
    if idx is None:
        idx = knn(flat, k=k, mask=mask)

    # Offset each batch's indices so they address the flattened [B*N, C*3] table.
    offsets = torch.arange(0, bsz).type_as(idx).view(-1, 1, 1) * n_pts
    flat_idx = (idx + offsets).view(-1)

    n_vec = flat.size(1) // 3
    per_point = flat.transpose(2, 1).contiguous()  # [B, N, C*3]
    neighbours = per_point.view(bsz * n_pts, -1)[flat_idx, :]
    neighbours = neighbours.view(bsz, n_pts, k, n_vec, 3)
    centre = per_point.view(bsz, n_pts, 1, n_vec, 3).repeat(1, 1, k, 1, 1)

    if mask is not None:
        padded = (~mask.squeeze(1))[:, :, None, None, None]  # [B, N, 1, 1, 1]
        neighbours = torch.where(padded, centre, neighbours)

    cross = torch.cross(neighbours, centre, dim=-1)
    edge = torch.cat((neighbours - centre, centre, cross), dim=3)
    return edge.permute(0, 3, 4, 1, 2).contiguous()

class VNSmall(nn.Module):
    """Small vector-neuron network mapping a point cloud to 4 equivariant 3-vectors.

    Pipeline:
        kNN edge features -> VN linear/ReLU (21 channels) -> pool over neighbours
        -> VN linear/ReLU + VNBatchNorm -> VN linear/ReLU (4 channels)
        -> average over points.

    Args:
        n_knn: number of neighbours used to build the edge features.
        pooling: "max" (VNMaxPool) or "mean" (mean_pool) over the neighbour axis.

    Raises:
        ValueError: if `pooling` is neither "max" nor "mean".
    """

    def __init__(self, n_knn=20, pooling="max"):
        super().__init__()
        self.n_knn = n_knn
        self.pooling = pooling
        self.conv_pos = VNLinearLeakyReLU(3, 64 // 3, dim=5, negative_slope=0.0)
        self.conv1 = VNLinearLeakyReLU(64 // 3, 64 // 3, dim=4, negative_slope=0.0)
        self.bn1 = VNBatchNorm(64 // 3, dim=4)
        self.conv2 = VNLinearLeakyReLU(64 // 3, 12 // 3, dim=4, negative_slope=0.0)
        # Not used in forward(); kept so the module layout (and checkpoints) stay unchanged.
        self.dropout = nn.Dropout(p=0.5)

        if self.pooling == "max":
            self.pool = VNMaxPool(64 // 3)
        elif self.pooling == "mean":
            self.pool = mean_pool
        else:
            raise ValueError(f"unknown pooling {pooling!r}; expected 'max' or 'mean'")

        # Not used in forward(); kept so existing state_dicts still load.
        self.conv = VNLinear(3, 12 // 3)

    def forward(self, point_cloud, labels=None, mask=None):
        """
        Args:
            point_cloud: [B, 3, N] tensor. It is unsqueezed to [B, 1, 3, N] for
                get_graph_feature_cross.
            labels: unused; accepted for interface compatibility.
            mask: optional bool mask, presumably [B, 1, N] with True for real
                points (it must broadcast to [B, 21, N] inside VNBatchNorm) --
                TODO confirm.

        Returns:
            [B, 4, 3] tensor: the per-point output vectors averaged over the
            points (over unmasked points only when a mask is given).
        """
        point_cloud = point_cloud.unsqueeze(1)
        feat = get_graph_feature_cross(point_cloud, k=self.n_knn, mask=mask)
        point_cloud = self.conv_pos(feat)
        point_cloud = self.pool(point_cloud)
        out = self.bn1(self.conv1(point_cloud), mask)
        out = self.conv2(out)  # [B, 4, 3, N]

        if mask is None:
            return out.mean(dim=-1)
        # Average only over real points so padding does not contribute.
        weights = mask.reshape(out.shape[0], 1, 1, -1).to(out.dtype)
        return (out * weights).sum(dim=-1) / weights.sum(dim=-1).clamp(min=1.0)


class PointcloudCanonFunction(nn.Module):
    """Predict a rotation from a point cloud and express the cloud in that frame.

    VNSmall produces equivariant vectors. The first three are orthonormalized
    with Gram-Schmidt into the rows of the rotation matrix.

    Args:
        n_knn: neighbours used by VNSmall's kNN graph.
        pooling: "max" or "mean", forwarded to VNSmall.
    """

    def __init__(self, n_knn=5, pooling="max"):
        super().__init__()
        self.model = VNSmall(n_knn, pooling)

    def forward(self, points, mask=None):
        """
        Args:
            points: [B, 3, N] point cloud.
            mask: optional point mask, forwarded to VNSmall.

        Returns:
            (rotation_matrix, canonical_point_cloud), where rotation_matrix is
            [B, 3, 3] from gram_schmidt and canonical_point_cloud is
            rotation_matrix @ points, shape [B, 3, N].
        """
        # Pass mask by keyword: VNSmall.forward's second positional parameter is `labels`.
        vectors = self.model(points, mask=mask)

        rotation_vectors = vectors[:, :3]
        rotation_matrix = self.gram_schmidt(rotation_vectors)
        # (points^T @ R^T)^T == R @ points
        canonical_point_cloud = torch.bmm(points.transpose(1, 2), rotation_matrix.transpose(1, 2))
        canonical_point_cloud = canonical_point_cloud.transpose(1, 2)
        return rotation_matrix, canonical_point_cloud

    def gram_schmidt(self, vectors):
        """Orthonormalize vectors[:, 0..2] (each [B, 3]) and stack them as rows.

        Degenerate (zero-length) vectors produce zero rows instead of NaN,
        because F.normalize clamps the denominator. The determinant is not
        forced to +1: the third row comes from the third input vector, not
        from a cross product.
        """
        v1 = F.normalize(vectors[:, 0], dim=1)
        v2 = vectors[:, 1] - torch.sum(vectors[:, 1] * v1, dim=1, keepdim=True) * v1
        v2 = F.normalize(v2, dim=1)
        v3 = (vectors[:, 2] - torch.sum(vectors[:, 2] * v1, dim=1, keepdim=True) * v1 -
              torch.sum(vectors[:, 2] * v2, dim=1, keepdim=True) * v2)
        v3 = F.normalize(v3, dim=1)
        return torch.stack([v1, v2, v3], dim=1)


# Datasets are loaded eagerly when this module is imported (module-level side effect).
qm9dir = 'data/qm9'
QM9_dataset = QM9(root=qm9dir)

name = 'revised aspirin'
MD17dir = 'data/md17/' + name
base_dataset = MD17(root=MD17dir, name=name)

def test_rotation_invariance():

    canon_function = PointcloudCanonFunction(n_knn=5, pooling="max")
    canon_function.eval()
    
    MD17dir = 'data/md17/revised aspirin'
    dataset = MD17(root=MD17dir, name='revised aspirin')
    sample1 = dataset[0]
    sample2 = dataset[1]

    points1 = sample1['pos'].clone().float()  # [N_atoms, 3]
    points1 = points1.unsqueeze(0)  # [1, N_atoms, 3]
    points1 = points1.transpose(1, 2)  # [1, 3, N_atoms]

    points2 = sample2['pos'].clone().float()  # [N_atoms, 3]
    points2 = points2.unsqueeze(0)  # [1, N_atoms, 3]
    points2 = points2.transpose(1, 2)  # [1, 3, N_atoms]
    
    # print(points1)
    points = torch.cat([points1, points2], dim=0)

    with torch.no_grad():
        rotation_matrix, canonical_points1 = canon_function(points)

    

    num_samples = 1000
    rot_diffs = [] 

    for _ in range(num_samples):

        random_rotation = torch.tensor(R.random().as_matrix()).float()


        rotated_points = torch.matmul(random_rotation, points)


        with torch.no_grad():
            rotation_matrix_rotated, canonical_points_rotated = canon_function(rotated_points, mask=None)

        rot_diff = torch.norm(torch.matmul(rotation_matrix_rotated, random_rotation) - rotation_matrix)
        rot_diffs.append(rot_diff.item())

    print("\nRotation Invariance Test:")
    print("Original point cloud shape:", points.shape)
    print("Average difference:", np.mean(rot_diffs))
    print("Maximum difference:", np.max(rot_diffs))
    print("Minimum difference:", np.min(rot_diffs))

    QM9_dataset = QM9(root=qm9dir)


    pos, mask = preprocess(QM9_dataset[0], dataset_type="qm9_point_clouds")
    pos = pos.transpose(1, 2)
    # print(pos.shape, mask.shape)
    pos = pos / (torch.norm(pos, dim=1, keepdim=True) + 1e-6)

    canon_function = PointcloudCanonFunction(n_knn=5, pooling="max")
    canon_function.eval()

    with torch.no_grad():
        rotation_matrix, canonical_points = canon_function(pos, mask=mask)

    num_samples = 1000
    rot_diffs = [] 

    for _ in range(num_samples):
        random_rotation = torch.tensor(R.random().as_matrix()).float()
        rotated_points = torch.matmul(random_rotation, pos)

        with torch.no_grad():
            rotation_matrix_rotated, canonical_points_rotated = canon_function(rotated_points, mask=mask)

        rot_diff = torch.norm(torch.matmul(rotation_matrix_rotated, random_rotation) - rotation_matrix)
        rot_diffs.append(rot_diff.item())

    print("\nRotation Invariance Test:")
    print("Original point cloud shape:", pos.shape)
    print("Average difference:", np.mean(rot_diffs))
    print("Maximum difference:", np.max(rot_diffs))
    print("Minimum difference:", np.min(rot_diffs))


def test_rotation_invariance_modelnet40():
    print("\nTesting Rotation Invariance on ModelNet40:")
    
    from torch_geometric.datasets import ModelNet
    dataset = ModelNet(root='data/ModelNet40', name='40', transform=SamplePoints(100, include_normals=False))
    
    canon_function = PointcloudCanonFunction(n_knn=5, pooling="max")
    canon_function.eval()
    

    sample = dataset[0]
    points = sample.pos.clone().float()  # [N_points, 3]
    points = points.unsqueeze(0)  # [1, N_points, 3]
    points = points.transpose(1, 2)  # [1, 3, N_points]
    # print(points)
    # points = points - points.mean(dim=2, keepdim=True)
    points = points / (torch.norm(points, dim=1, keepdim=True) + 1e-6)

    with torch.no_grad():
        rotation_matrix, canonical_points = canon_function(points, mask=None)
    

    num_samples = 1000
    rot_diffs = []
    
    for _ in range(num_samples):
        random_rotation = torch.tensor(R.random().as_matrix()).float()
        
        rotated_points = torch.matmul(random_rotation, points)
        
        with torch.no_grad():
            rotation_matrix_rotated, canonical_points_rotated = canon_function(rotated_points, mask=None)
        
        rot_diff = torch.norm(torch.matmul(rotation_matrix_rotated, random_rotation) - rotation_matrix)
        rot_diffs.append(rot_diff.item())
    
    print("Original point cloud shape:", points.shape)
    print("Average difference:", np.mean(rot_diffs))
    print("Maximum difference:", np.max(rot_diffs))
    print("Minimum difference:", np.min(rot_diffs))


def test_random_point_cloud():
    sample = torch.randn(1, 3, 100) * 1000 
    canon_function = PointcloudCanonFunction(n_knn=5, pooling="max")
    canon_function.eval()
    with torch.no_grad():
        rotation_matrix, canonical_points = canon_function(sample, mask=None)
    num_samples = 1000
    rot_diffs = [] 
    for _ in range(num_samples):
        random_rotation = torch.tensor(R.random().as_matrix()).float()
        rotated_points = torch.matmul(random_rotation, sample)
        with torch.no_grad():
            rotation_matrix_rotated, canonical_points_rotated = canon_function(rotated_points, mask=None)
        rot_diff = torch.norm(torch.matmul(rotation_matrix_rotated, random_rotation) - rotation_matrix)
        rot_diffs.append(rot_diff.item())
    print("test_random_point_cloud")
    print("Average difference:", np.mean(rot_diffs))
    print("Maximum difference:", np.max(rot_diffs))
    print("Minimum difference:", np.min(rot_diffs))


if __name__ == "__main__":
    # Run every equivariance check, in order.
    for run_check in (test_rotation_invariance, test_rotation_invariance_modelnet40, test_random_point_cloud):
        run_check()

