import torch
import numpy as np
import matplotlib.pyplot as plt
from src.utils.general import inverse_sigmoid, get_expon_lr_func, build_rotation, strip_symmetric, \
    build_scaling_rotation
from einops import rearrange, repeat

def split_corrected_v2(xyz, scales, rotation, M=2):
    """Draw ``M`` split samples for each input Gaussian.

    Each Gaussian contributes ``M`` points. Each point is a zero-mean normal
    sample scaled per-axis by ``scales``, rotated by the Gaussian's rotation
    matrix, and then translated to the Gaussian's centre ``xyz``.

    Args:
        xyz: (N, D) centres of the Gaussians (D is 3 in the demo below).
        scales: (N, D) per-axis standard deviations.
        rotation: (N, 4) quaternions, in whatever component order
            ``build_rotation`` expects (not visible here — confirm).
        M: number of samples to draw per Gaussian.

    Returns:
        (M * N, D) tensor of sampled positions. Rows are tiled copy-major, so
        row ``k * N + i`` is the k-th sample of Gaussian ``i``.
    """
    # Tile every per-Gaussian tensor M times along the first axis:
    # [g0 .. gN-1, g0 .. gN-1, ...].
    centers = repeat(xyz, 'n d -> (m n) d', m=M)
    stds = repeat(scales, 'n d -> (m n) d', m=M)
    rots = repeat(build_rotation(rotation), 'n d1 d2 -> (m n) d1 d2', m=M)  # (M*N, 3, 3)

    # Sample in each Gaussian's local frame, rotate into world space, then
    # translate to the Gaussian's centre. The previous version added a zero
    # tensor here, so the samples were never moved off the origin.
    local = torch.randn_like(centers) * stds
    rotated = torch.einsum('nij,nj->ni', rots, local)
    return rotated + centers

# Demo: split a single Gaussian at the origin into M samples and plot them
# next to the original centre.
xyz = torch.tensor([[0.0, 0.0, 0.0]])
scales = torch.tensor([[0.01, 0.01, 0.01]])
# NOTE(review): if build_rotation uses the (w, x, y, z) component order, this
# quaternion is a 180-degree turn about z rather than the identity. Assuming
# build_rotation returns an orthonormal matrix, the isotropic scales above make
# the orientation irrelevant to the spread of the samples. Confirm the intended
# layout anyway.
rotation = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
M = 2
print(xyz.shape, scales.shape, rotation.shape)
split_points_corrected_v2 = split_corrected_v2(xyz, scales, rotation, M)
print(split_points_corrected_v2.shape)

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')

# Original Gaussian centre.
ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], color='red', label='Original Point')

# Sampled split points.
ax.scatter(
    split_points_corrected_v2[:, 0],
    split_points_corrected_v2[:, 1],
    split_points_corrected_v2[:, 2],
    color='blue',
    label='Split Points',
)

ax.set_xlabel('X Axis')
ax.set_ylabel('Y Axis')
ax.set_zlabel('Z Axis')
ax.legend()

# Title the 3-D axes explicitly rather than through pyplot's implicit
# "current axes" state.
ax.set_title('Corrected Visualization of Point Splitting')
fig.savefig('corrected_splitting.png')
# Release the figure so repeated runs or imports don't accumulate open figures.
plt.close(fig)