import os
import sys
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.linalg import logm

def knn(x, k):
    """Return, for every point, the indices of its k nearest neighbours.

    Args:
        x: Tensor indexed as (batch, channels, num_points); dim 1 is summed
            over when computing squared norms, i.e. it is treated as the
            coordinate axis of each point.
        k: Number of neighbours to keep per point.

    Returns:
        Index tensor of shape (batch, num_points, k), ordered from nearest to
        farthest. A point is a candidate neighbour of itself.
    """
    gram = torch.matmul(x.transpose(2, 1), x)          # x_i . x_j
    sq_norm = torch.sum(x**2, dim=1, keepdim=True)     # |x_j|^2, shape (B, 1, N)

    # Negated squared Euclidean distance, -(|x_i|^2 - 2 x_i.x_j + |x_j|^2),
    # so that the *largest* entries selected by topk are the *closest* points.
    neg_sq_dist = -sq_norm - (-2 * gram) - sq_norm.transpose(2, 1)

    _, neighbour_idx = neg_sq_dist.topk(k=k, dim=-1)   # (batch_size, num_points, k)
    return neighbour_idx


def get_graph_feature(x, k=20, idx=None, x_coord=None):
    """Build per-edge features ``[neighbour - centre, centre]`` on a k-NN graph.

    Args:
        x: 4-D tensor whose dim 0 is the batch and dim 3 the points. Dims 1
            and 2 are flattened into one channel axis, which is then regrouped
            into 3-vectors, so their product must be a multiple of 3.
        k: Number of neighbours per point. When ``idx`` is supplied, ``k``
            must equal ``idx.size(-1)``.
        idx: Optional precomputed neighbour indices of shape
            (batch, num_points, k). When ``None`` they are computed with
            ``knn``.
        x_coord: Optional tensor passed to ``knn`` instead of ``x`` (fixed
            graph built from input coordinates). Only used when ``idx`` is
            ``None``.

    Returns:
        Contiguous tensor of shape (batch, 2 * num_dims, 3, num_points, k),
        where ``num_dims`` is the flattened channel count divided by 3. The
        first ``num_dims`` channels hold ``neighbour - centre``; the remaining
        ``num_dims`` channels hold the centre feature repeated for each
        neighbour.
    """
    batch_size = x.size(0)
    num_points = x.size(3)
    x = x.reshape(batch_size, -1, num_points)

    if idx is None:
        # Dynamic graph on the features themselves, or a fixed graph on the
        # supplied coordinates.
        idx = knn(x if x_coord is None else x_coord, k=k)

    # Per-batch offsets turn per-sample point indices into indices along the
    # flattened (batch * num_points) axis. They are created on the same device
    # as idx rather than a hard-coded 'cuda', so the function also works on
    # CPU and on non-default GPUs.
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    num_dims = x.size(1) // 3

    x = x.transpose(2, 1).contiguous()                      # (B, N, C)
    feature = x.view(batch_size * num_points, -1)[idx, :]   # gather neighbours
    feature = feature.view(batch_size, num_points, k, num_dims, 3)
    x = x.view(batch_size, num_points, 1, num_dims, 3).repeat(1, 1, k, 1, 1)

    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 4, 1, 2).contiguous()

    return feature


def get_graph_feature_cross(x, k=20, idx=None):
    """Build k-NN edge features with cross products for 6-channel input.

    The flattened channels are split into two 3-vector streams (channels 0:3
    and 3:; named ``acc`` and ``ang`` in this code, presumably accelerometer
    and gyroscope readings, to be confirmed against the data loader). For each
    stream and each (point, neighbour) pair the feature is
    ``[neighbour - centre, centre, neighbour x centre]``.

    Args:
        x: 4-D tensor whose dim 0 is the batch and dim 3 the points, e.g.
            (B, 1, 6, T). Dims 1 and 2 are flattened into the channel axis.
        k: Number of neighbours per point. When ``idx`` is supplied, ``k``
            must equal ``idx.size(-1)``.
        idx: Optional precomputed neighbour indices of shape
            (batch, num_points, k). When ``None`` they are computed by ``knn``
            over all flattened channels jointly.

    Returns:
        Contiguous tensor of shape (batch, 6 * num_dims, 3, num_points, k)
        with ``num_dims = channels // 6``; e.g. (B, 6, 3, T, k) for 6-channel
        input. The first half of dim 1 belongs to the ``acc`` stream, the
        second half to the ``ang`` stream.
    """
    batch_size = x.size(0)
    num_points = x.size(3)
    x = x.reshape(batch_size, -1, num_points)

    if idx is None:
        idx = knn(x, k=k)

    # Offsets into the flattened (batch * num_points) axis, created on idx's
    # device rather than a hard-coded 'cuda', so CPU and non-default GPUs work.
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    num_dims = x.size(1) // 6  # 6 channels = one pair of 3-vectors

    def _edge_features(channels):
        """Edge features for one 3-vector stream given as (B, C, N)."""
        points = channels.transpose(2, 1).contiguous()
        neighbours = points.view(batch_size * num_points, -1)[idx, :]
        neighbours = neighbours.view(batch_size, num_points, k, num_dims, 3)
        centre = points.view(batch_size, num_points, 1, num_dims, 3).repeat(1, 1, k, 1, 1)
        cross = torch.cross(neighbours, centre, dim=-1)
        stacked = torch.cat((neighbours - centre, centre, cross), dim=3)
        return stacked.permute(0, 3, 4, 1, 2).contiguous()

    feature_acc = _edge_features(x[:, :3, :])
    feature_ang = _edge_features(x[:, 3:, :])

    return torch.cat((feature_acc, feature_ang), dim=1).contiguous()


def get_vector_feature(x):
    """Regroup per-timestep channels into a stack of 3-vectors.

    Args:
        x: 4-D tensor whose dim 0 is the batch and dim 3 the timesteps, e.g.
            (B, 1, C, T). Dims 1 and 2 are flattened into C channels, and C
            must be 6, 9 or 18. Channels are consumed in this order (names
            follow the variables used in this code; the physical meaning is
            presumed from those names and should be confirmed against the
            data loader):
              * 0:3   ``acc``
              * 3:6   ``ang``
              * 6:9   ``vel_body``  (C >= 9)
              * 9:18  ``ori``, reshaped to three consecutive 3-vectors (C == 18)

    Returns:
        Contiguous tensor of shape (B, T, C // 3, 3): one row per 3-vector,
        in channel order.

    Raises:
        ValueError: If the flattened channel count is not 6, 9 or 18.
    """
    batch_size = x.size(0)
    num_points = x.size(3)

    x = x.reshape(batch_size, -1, num_points)

    # int() also unwraps a 0-dim tensor, which sizes can be under tracing.
    dims = int(x.size(1))

    # Validate explicitly instead of asserting in the fall-through branch: an
    # assert is stripped under ``python -O`` and would leave the result unset.
    if dims not in (6, 9, 18):
        raise ValueError(
            f"Unsupported feature dimension: {dims}. Expected 6, 9 or 18."
        )

    per_step = x.transpose(2, 1)  # (B, T, C)

    vectors = [
        per_step[:, :, 0:3].unsqueeze(2),   # acc      -> (B, T, 1, 3)
        per_step[:, :, 3:6].unsqueeze(2),   # ang      -> (B, T, 1, 3)
    ]
    if dims >= 9:
        vectors.append(per_step[:, :, 6:9].unsqueeze(2))  # vel_body
    if dims == 18:
        # Nine orientation values become three consecutive 3-vectors.
        vectors.append(per_step[:, :, 9:18].reshape(batch_size, num_points, 3, 3))

    return torch.cat(vectors, dim=2).contiguous()

def _vector_to_so3(vector_3d: torch.Tensor) -> torch.Tensor:
    """Embed 3-D vectors as flattened skew-symmetric (so(3)) matrices.

    For ``v = (vx, vy, vz)`` the result is the row-major flattening of

        [[  0, -vz,  vy],
         [ vz,   0, -vx],
         [-vy,  vx,   0]]

    i.e. the matrix ``M`` with ``M @ w == cross(v, w)``.

    Args:
        vector_3d: Tensor of shape (..., 3); only the first three entries of
            the last dimension are read.

    Returns:
        Tensor of shape (..., 9) on the same device and with the same dtype
        as ``vector_3d``.
    """
    vx, vy, vz = vector_3d[..., 0], vector_3d[..., 1], vector_3d[..., 2]

    # Derive the zeros from the input so dtype and device follow the input
    # (a bare torch.zeros would silently fall back to the default dtype).
    zero = torch.zeros_like(vx)

    return torch.stack(
        (
            zero, -vz, vy,
            vz, zero, -vx,
            -vy, vx, zero,
        ),
        dim=-1,
    )


def get_lie_algebra_feature(x: torch.Tensor) -> torch.Tensor:
    """Convert per-timestep features into a stack of 9-D matrix features.

    3-D vectors are embedded as flattened so(3) (skew-symmetric) matrices via
    ``_vector_to_so3``. The output lives on the same device as ``x``.

    Args:
        x: Input tensor of shape (B, 1, C_feat, T), where C_feat is 12 or 15.
           - If C_feat=12: 3D acceleration + 9D covariance.
           - If C_feat=15: 3D velocity + 3D acceleration + 9D covariance.

    Returns:
        - If C_feat=12: shape (B, T, 2, 9) holding
          [so(3)(acc), matrix logarithm of the covariance].
        - If C_feat=15: shape (B, T, 3, 9) holding
          [so(3)(vel), so(3)(acc), covariance]. In this branch the covariance
          is passed through unchanged (no matrix logarithm is applied).

    Raises:
        ValueError: If C_feat is neither 12 nor 15.
    """
    batch_size = x.size(0)
    num_features = x.size(2)  # C_feat
    num_points = x.size(3)    # T

    # (B, 1, C_feat, T) -> (B, T, C_feat) for per-timestep slicing.
    features = x.view(batch_size, num_features, num_points).permute(0, 2, 1)

    if num_features == 15:
        # Layout: [vel (3), acc (3), cov (9)]
        vel_vector = features[:, :, :3]
        acc_vector = features[:, :, 3:6]
        covariance_vector = features[:, :, 6:]

        parts = (
            _vector_to_so3(vel_vector),
            _vector_to_so3(acc_vector),
            covariance_vector,
        )  # stacked below to (B, T, 3, 9)

    elif num_features == 12:
        # Layout: [acc (3), cov (9)]
        acc_vector = features[:, :, :3]
        covariance_matrix = features[:, :, 3:].reshape(batch_size * num_points, 3, 3)

        # Matrix logarithm through the eigendecomposition:
        #   log(C) = V diag(log(lambda)) V^T.
        # NOTE: torch.linalg.eigh assumes a symmetric matrix and by default
        # reads only its lower triangle. Eigenvalues are not clamped, so a
        # covariance that is not strictly positive definite yields NaN/-inf.
        eigvals, eigvecs = torch.linalg.eigh(covariance_matrix)  # (B*T, 3), (B*T, 3, 3)
        log_eigvals = torch.log(eigvals)                          # (B*T, 3)
        log_covariance = eigvecs @ torch.diag_embed(log_eigvals) @ eigvecs.transpose(-2, -1)

        # Stay on the input's device: a forced .cuda() here crashes on
        # CPU-only hosts and mismatches the device of the so(3) feature.
        log_covariance_vector = log_covariance.reshape(batch_size, num_points, 9)

        parts = (
            _vector_to_so3(acc_vector),
            log_covariance_vector,
        )  # stacked below to (B, T, 2, 9)

    else:
        raise ValueError(f"Unsupported feature dimension: {num_features}. Expected 12 or 15.")

    output_feature = torch.cat([part.unsqueeze(2) for part in parts], dim=2)

    return output_feature
