import os
import random
import torch
import h5py
import numpy as np
from enum import Enum
from datetime import datetime

from tools.utils import io
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle
import torch.nn.functional as F

class NetworkType(Enum):
    """Enumerates the network types known to this module.

    Each member's value is the string form of its own name, so
    ``NetworkType('GEN') is NetworkType.GEN``.
    """
    VQVAE = 'VQVAE'
    GEN = 'GEN'
    POINTBERT = 'POINTBERT'
    HOIGPT = 'HOIGPT'

def decode_angle(anchorsList, pred_anchor_index, residual_pred):
    """Reconstruct axis-angle rotations from per-axis anchor indices and residuals.

    This is the inverse of ``get_gt_anchor_idx_and_residual``. That encoder
    reads the rotated Z, Y and X axes from columns 2, 1 and 0 of the rotation
    matrix, snaps each to the nearest (unit-normalized) anchor in its own
    group, and stores ``axis - anchor`` as the residual.

    Args:
        anchorsList: ``(3 * N, 3)`` tensor of anchor directions laid out as
            ``[Z anchors | Y anchors | X anchors]``.
        pred_anchor_index: ``(B, 3)`` integer tensor. Columns 0/1/2 index the
            Z/Y/X anchor groups respectively.
        residual_pred: ``(B, 3, 3)`` tensor. ``residual_pred[:, k]`` is added
            to the selected anchor of axis ``k`` (Z, Y, X order).

    Returns:
        ``(B, 3)`` axis-angle tensor.
    """
    n_per_axis = anchorsList.shape[0] // 3
    # The encoder normalizes anchors before computing residuals; do the same
    # here so non-unit anchor banks decode consistently.
    anchors = F.normalize(anchorsList, dim=-1)
    anchor_z = anchors[0:n_per_axis][pred_anchor_index[:, 0]]                   # (B, 3)
    anchor_y = anchors[n_per_axis:2 * n_per_axis][pred_anchor_index[:, 1]]      # (B, 3)
    anchor_x = anchors[2 * n_per_axis:][pred_anchor_index[:, 2]]                # (B, 3)
    selected_anchors = torch.stack([anchor_z, anchor_y, anchor_x], dim=1)      # (B, 3, 3), Z/Y/X
    axis_dirs = F.normalize(selected_anchors + residual_pred, dim=-1)          # (B, 3, 3), Z/Y/X
    # Put the axis directions back as COLUMNS in X, Y, Z order
    # (R[:, :, 0] = X, R[:, :, 1] = Y, R[:, :, 2] = Z), mirroring how the
    # encoder extracted them. Stacking them as rows in Z/Y/X order yields a
    # reflection (det = -1), not the original rotation.
    # NOTE: each direction is normalized independently and not
    # re-orthogonalized, so the matrix is only approximately a rotation.
    rot = torch.stack([axis_dirs[:, 2], axis_dirs[:, 1], axis_dirs[:, 0]], dim=-1)  # (B, 3, 3)
    return matrix_to_axis_angle(rot)

def get_gt_anchor_idx_and_residual(axis_angle, anchorsList):
    """Encode rotations as (nearest anchor index, residual) per rotated axis.

    The rotation's Z, Y and X axis directions are columns 2, 1 and 0 of its
    matrix. Each axis is matched, by cosine similarity, against its own third
    of ``anchorsList``, laid out as ``[Z anchors | Y anchors | X anchors]``.
    The residual is the axis direction minus the unit-normalized winning
    anchor.

    Args:
        axis_angle: ``(B, 3)`` axis-angle rotations.
        anchorsList: ``(3 * N, 3)`` tensor of anchor directions.

    Returns:
        ``(indices, residuals)``:
        - ``indices`` is a ``(B, 3)`` tensor of anchor indices in Z, Y, X order.
        - ``residuals`` is ``(B, 3, 3)``; ``residuals[:, k]`` belongs to axis ``k``.
    """
    rot = axis_angle_to_matrix(axis_angle)  # (B, 3, 3)

    per_axis = anchorsList.shape[0] // 3
    anchor_groups = (
        anchorsList[0:per_axis],               # Z group
        anchorsList[per_axis:2 * per_axis],    # Y group
        anchorsList[2 * per_axis:],            # X group
    )

    # Extract the rotated Z, Y, X axis direction vectors (matrix columns 2, 1, 0).
    axis_dirs = torch.stack([rot[:, :, 2], rot[:, :, 1], rot[:, :, 0]], dim=1)  # (B, 3, 3)

    best_indices = []
    offsets = []
    for direction, group in zip(axis_dirs.unbind(dim=1), anchor_groups):
        unit_anchors = F.normalize(group, dim=-1)                          # (N, 3)
        similarity = F.cosine_similarity(
            direction.unsqueeze(1), unit_anchors.unsqueeze(0), dim=-1)     # (B, N)
        nearest = similarity.argmax(dim=-1)                                # (B,)
        best_indices.append(nearest)
        offsets.append(direction - unit_anchors[nearest])                  # (B, 3)

    return torch.stack(best_indices, dim=1), torch.stack(offsets, dim=1)

def select_res_by_index(anchor_tensor, index_tensor):
    """Pick one candidate per (batch, direction) pair from a candidate bank.

    Args:
        anchor_tensor: ``(B, D, N, C)`` tensor holding ``N`` candidates of
            size ``C`` for each of the ``D`` directions (the original comment
            notes D = 3 for the x/y/z directions).
        index_tensor: ``(B, D)`` tensor of indices into the ``N`` dimension.
            Non-int64 indices (e.g. float-typed ones) are cast to int64,
            because ``torch.gather`` requires an int64 index.

    Returns:
        ``(B, D, C)`` tensor where
        ``out[b, d] = anchor_tensor[b, d, index_tensor[b, d]]``.

    Raises:
        ValueError: if ``anchor_tensor`` is not 4-dimensional.
    """
    # Unpacking enforces the 4-D layout; C is taken from the data rather than
    # hard-coded to 3.
    _, _, _, feat_dim = anchor_tensor.shape

    # gather needs an index with the same rank as the source: (B, D, 1, C).
    index = index_tensor.long().unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, feat_dim)
    selected = torch.gather(anchor_tensor, dim=2, index=index)  # (B, D, 1, C)
    return selected.squeeze(2)                                   # (B, D, C)

def generate_xyz_directional_anchors(num_per_axis=8):
    """Build unit anchor directions on three coordinate great circles.

    ``num_per_axis`` evenly spaced angles ``2*pi*k/num_per_axis`` are swept
    in each of three planes, in this order:

    1. the xy-plane as ``(cos, sin, 0)``;
    2. the xz-plane as ``(cos, 0, sin)``;
    3. the yz-plane as ``(0, cos, sin)``.

    The original code labels this ordering "z y x".

    Returns:
        ``np.ndarray`` of shape ``(3 * num_per_axis, 3)`` with unit-norm rows.
    """
    plane_builders = (
        lambda c, s: [c, s, 0],   # xy-plane
        lambda c, s: [c, 0, s],   # xz-plane
        lambda c, s: [0, c, s],   # yz-plane
    )
    sweep = [2 * np.pi * k / num_per_axis for k in range(num_per_axis)]
    directions = np.array([
        build(np.cos(theta), np.sin(theta))
        for build in plane_builders
        for theta in sweep
    ])
    lengths = np.linalg.norm(directions, axis=1, keepdims=True)
    return directions / lengths  # z y x

def get_anchor_and_residual(gt_axis_angle, anchor_set):
    """Find the closest anchor rotation and the residual rotation relative to it.

    For every anchor the residual is ``R_anchor^T @ R_gt``, expressed as an
    axis-angle vector, so that ``R_gt = R_anchor @ R_residual``. The anchor
    whose residual has the smallest angle wins.

    Args:
        gt_axis_angle: ``(B, 3)`` ground-truth axis-angle rotations.
        anchor_set: ``(N, 3)`` anchor rotations in axis-angle form.

    Returns:
        ``(anchor_idx, residual)``:
        - ``anchor_idx`` is ``(B,)``, cast to float.
        - ``residual`` is ``(B, 3)``, cast to float.
    """
    gt_rot = axis_angle_to_matrix(gt_axis_angle)[:, None]            # [B, 1, 3, 3]
    anchor_rot = axis_angle_to_matrix(anchor_set[None])              # [1, N, 3, 3]

    delta_rot = anchor_rot.transpose(-1, -2) @ gt_rot                # [B, N, 3, 3]
    candidate_residuals = matrix_to_axis_angle(delta_rot)            # [B, N, 3]

    nearest = candidate_residuals.norm(dim=-1).argmin(dim=1)         # [B]
    rows = torch.arange(gt_axis_angle.shape[0], device=gt_axis_angle.device)
    chosen = candidate_residuals[rows, nearest]                      # [B, 3]

    return nearest.float(), chosen.float()

def decode_direction(anchor_axis_angle: torch.Tensor, residual: torch.Tensor):
    """Compose an anchor rotation with its residual rotation.

    This inverts ``get_anchor_and_residual``, which defines the residual as
    ``R_delta = R_anchor^T @ R_gt``. The rotation is therefore recovered as
    ``R_gt = R_anchor @ R_residual``. Rotations do not commute, so the
    previous ``R_residual @ R_anchor`` did not reproduce ``R_gt``.

    Args:
        anchor_axis_angle: ``(..., 3)`` anchor rotations in axis-angle form.
        residual: ``(..., 3)`` residual rotations in axis-angle form.

    Returns:
        ``(..., 3)`` axis-angle tensor of the composed rotation.
    """
    R_anchor = axis_angle_to_matrix(anchor_axis_angle)    # [..., 3, 3]
    R_residual = axis_angle_to_matrix(residual)           # [..., 3, 3]
    # matmul (not bmm) also accepts unbatched or multi-batch-dim inputs.
    R_combined = torch.matmul(R_anchor, R_residual)       # [..., 3, 3]
    return matrix_to_axis_angle(R_combined)

def fibonacci_sphere(samples=100):
    """Return ``samples`` near-uniformly spread points on the unit sphere.

    Uses the golden-angle (Fibonacci) spiral. ``y`` runs linearly from 1 to
    -1, and successive points are rotated by the golden angle about the y
    axis.

    Args:
        samples: number of points.
            - ``samples <= 0`` yields an empty ``(0, 3)`` array.
            - ``samples == 1`` yields the single point ``[0, 1, 0]``; the
              spiral formula would otherwise divide by ``samples - 1 == 0``.

    Returns:
        ``np.ndarray`` of shape ``[samples, 3]`` with unit-norm ``(x, y, z)`` rows.
    """
    if samples <= 0:
        return np.empty((0, 3))

    golden_angle = np.pi * (3. - np.sqrt(5))  # golden angle in radians
    i = np.arange(samples, dtype=np.float64)

    if samples == 1:
        y = np.ones(1)  # i == 0 always maps to the north pole
    else:
        y = 1 - (i / (samples - 1)) * 2  # y goes from 1 to -1

    # Clip guards against tiny negative values from floating-point rounding.
    radius = np.sqrt(np.clip(1 - y * y, 0.0, None))  # radius of the circle at height y
    theta = golden_angle * i

    return np.stack([np.cos(theta) * radius, y, np.sin(theta) * radius], axis=1)

def set_random_seed(seed):
    """Seed NumPy's global RNG, PyTorch's default generator and ``random``.

    ``torch.manual_seed`` already seeds the default generator in place. The
    former ``set_rng_state(...get_state())`` round-trip restored that same
    state and was therefore redundant.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

def get_soft_label_from_index(target_idx, anchors, sigma=0.3, topk=5):
    """Build soft classification targets over the anchors most similar to the target.

    The target anchor's row is dot-multiplied with every anchor. The ``topk``
    highest scores receive probability mass ``softmax(score / sigma)``; all
    other anchors get 0.

    Args:
        target_idx: ``(B,)`` integer tensor of ground-truth anchor indices.
        anchors: ``(K, D)`` floating-point tensor. Similarity is a plain dot
            product, so rows are presumably unit vectors (TODO confirm with
            callers).
        sigma: temperature; smaller values concentrate mass on the target.
        topk: number of anchors receiving mass. Clamped to ``K`` so small
            anchor banks do not make ``torch.topk`` fail.

    Returns:
        ``(B, K)`` tensor on ``anchors.device`` with the dtype of the computed
        weights (e.g. float64 anchors give float64 labels). Each row sums to 1.
    """
    num_anchors = anchors.shape[0]
    k = min(topk, num_anchors)

    gt_dirs = anchors[target_idx]                          # [B, D]
    cos_sim = torch.matmul(gt_dirs, anchors.T)             # [B, K]
    topk_sim, topk_idx = torch.topk(cos_sim, k, dim=1)     # [B, k]

    # Mathematically identical to exp(x/sigma) / sum(exp(x/sigma)), but cannot
    # overflow for small sigma.
    weights = torch.softmax(topk_sim / sigma, dim=1)       # [B, k]

    # scatter_ requires destination and source to share a dtype; a default
    # float32 buffer would reject float64 weights.
    soft_labels = torch.zeros_like(cos_sim, dtype=weights.dtype)
    soft_labels.scatter_(1, topk_idx, weights)

    return soft_labels  # [B, K]
    
def duration_in_hours(duration):
    """Format a duration given in seconds as ``HH:MM:SS``.

    Each field is truncated to an int and zero-padded to at least two digits.
    Hours are not wrapped, so 100 hours renders as ``100:00:00``.
    """
    total_minutes, seconds = divmod(duration, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f'{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}'


def get_prediction_vertices(pred_segmentation, pred_coordinates):
    """Select, for each row, the 3 coordinates belonging to its predicted part.

    The predicted part ``k`` is the argmax of ``pred_segmentation`` along
    axis 1. The row's coordinates are then taken from columns
    ``3k, 3k+1, 3k+2`` of ``pred_coordinates``.

    Returns:
        ``(segmentations, coordinates)``: the per-row part ids and a
        ``(rows, 3)`` array of the selected coordinates.
    """
    part_ids = np.argmax(pred_segmentation, axis=1)
    row_idx = np.arange(pred_coordinates.shape[0])[:, None]
    col_idx = 3 * part_ids.reshape(-1, 1) + np.arange(3)  # broadcast to (rows, 3)
    return part_ids, pred_coordinates[row_idx, col_idx]


def get_num_parts(h5_file_path):
    """Return the ``numParts`` attribute shared by every group in an HDF5 file.

    The reference value is read from the first top-level entry. Every group
    reached by ``visititems`` (which recurses into sub-groups) must carry the
    same ``numParts``.

    Raises:
        IOError: if ``io.file_exist`` reports that the file does not exist.
        ValueError: if any group's ``numParts`` differs from the reference.
    """
    if not io.file_exist(h5_file_path):
        raise IOError(f'Cannot open file {h5_file_path}')

    bad_groups = []
    # ``with`` guarantees the file is closed even if a lookup below raises
    # (e.g. an empty file or a group without the attribute).
    with h5py.File(h5_file_path, 'r') as input_h5:
        num_parts = input_h5[list(input_h5.keys())[0]].attrs['numParts']

        def _collect_mismatch(name, node):
            # Must return None: a non-None return stops visititems early.
            if isinstance(node, h5py.Group) and node.attrs['numParts'] != num_parts:
                bad_groups.append(name)

        input_h5.visititems(_collect_mismatch)

    if bad_groups:
        raise ValueError(f'Instances {bad_groups} in {h5_file_path} have different number of parts than {num_parts}')
    return num_parts


def get_latest_file_with_datetime(path, folder_prefix, ext, datetime_pattern='%Y-%m-%d_%H-%M-%S'):
    """Find the newest timestamped sub-folder of ``path`` and its last file.

    Candidate folders are named ``folder_prefix + <timestamp>``, where the
    timestamp follows ``datetime_pattern``, and contain at least one file
    matching ``ext`` according to ``io.get_file_list``. Entries that start
    with the prefix but do not parse as a timestamp are skipped; previously
    they made ``strptime`` raise. The newest folder is chosen by comparing
    the parsed datetimes directly. Within it, the last entry of
    ``io.alphanum_ordered_file_list`` is returned.

    Returns:
        ``(latest_folder, latest_file)``, or ``('', '')`` when no folder
        qualifies.
    """
    folder_pattern = folder_prefix + datetime_pattern
    candidates = []
    for fd in os.listdir(path):
        if not fd.startswith(folder_prefix):
            continue
        try:
            stamp = datetime.strptime(fd, folder_pattern)
        except ValueError:
            # Has the prefix but not the timestamp layout: not a run folder.
            continue
        if len(io.get_file_list(os.path.join(path, fd), ext)):
            candidates.append((stamp, fd))

    if not candidates:
        return '', ''

    # Compare datetimes directly; no local-time/epoch conversion is needed
    # for ordering.
    _, latest_folder = max(candidates)
    files = io.alphanum_ordered_file_list(os.path.join(path, latest_folder), ext=ext)
    latest_file = files[-1]
    return latest_folder, latest_file


class AvgRecorder:
    """Record the latest value and a running weighted mean.

    Attributes:
        val: the value passed to the most recent ``update`` call.
        sum: sum of ``val * n`` over all updates.
        count: sum of ``n`` over all updates.
        avg: ``sum / count`` after the most recent update (0 after reset).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def sample_points_from_mesh(mesh, num_points=10000):
    """Sample ``num_points`` points via ``mesh.sample`` and return only the points.

    ``mesh.sample`` is called with ``return_index=True`` and must return a
    2-tuple. The second element is discarded; it is presumably the per-point
    face index (TODO confirm for the mesh type in use).
    """
    sampled_points, _discarded_index = mesh.sample(num_points, return_index=True)
    return sampled_points