import torch
import math

# def convert_psk(tensor, categories):
#     device = tensor.device
#     tensor = tensor.int()
#     cat_spacings = 2 * math.pi / torch.tensor(categories).to(device)
    
#     phases = tensor * cat_spacings
#     psk_enc = torch.cat((torch.sin(phases), torch.cos(phases)), dim=-1)
#     return psk_enc

# def psk2cat(psk_tensor, categories):
#     device = psk_tensor.device
#     cat_len = categories.shape[0]
#     cat_spacings = 2 * math.pi / torch.tensor(categories).to(device)
    
#     sin_phases = psk_tensor[:, :cat_len]
#     cos_phases = psk_tensor[:, cat_len:2*cat_len]
    
#     phases = torch.atan2(sin_phases, cos_phases)
#     phases = (phases + 2 * math.pi) % (2 * math.pi)
    
#     feature_cat = torch.round(phases / cat_spacings).long()

#     total = feature_cat.numel()
#     corrected = 0

#     for i, cat in enumerate(categories):
#         corrected += (feature_cat[:, i] >= cat).sum().item()
#         feature_cat[:, i] = torch.where(feature_cat[:, i] >= cat, 0, feature_cat[:, i])

#     if total > 0:
#         casting_rate = corrected / total
#         print(f"Casting rate in psk2cat: {round(casting_rate, 3)}")
    
#     return feature_cat






# def convert_psk(tensor, categories):
#     """PSK Encoder: Maps categorical tensor to phase-shifted sin/cos."""
#     device = tensor.device
#     tensor = tensor.int()
#     cat_spacings = 2 * math.pi / torch.tensor(categories, device=device)
#     phases = tensor * cat_spacings
#     return torch.cat((torch.sin(phases), torch.cos(phases)), dim=-1)

# def psk2cat(psk_tensor, categories, missing_val=-1, eps=1e-3):
#     """Convert PSK-encoded tensor back to categorical (with missing value support)."""
#     device = psk_tensor.device
#     categories = torch.tensor(categories, device=device)
#     cat_len = len(categories)
#     cat_spacings = 2 * math.pi / categories

#     sin_phases, cos_phases = psk_tensor[:, :cat_len], psk_tensor[:, cat_len:]
    
#     # Detect missing values (low amplitude)
#     missing_mask = (sin_phases.abs() < eps) & (cos_phases.abs() < eps)
    
#     # Compute phase (auto-normalized to [0, 2π))
#     phases = torch.atan2(sin_phases, cos_phases)
#     phases = (phases + 2 * math.pi) % (2 * math.pi)
    
#     # Vectorized category decoding + overflow handling
#     feature_cat = torch.round(phases / cat_spacings).long()
#     feature_cat = feature_cat % categories  # Handles overflow in one step
    
#     # Apply missing values
#     feature_cat[missing_mask] = missing_val
    
#     return feature_cat








# import numpy as np
# import torch
# import math
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import Dataset, DataLoader


# def generalized_spiral(K, device='cpu'):
#     """More uniform spherical distribution than Fibonacci"""
#     indices = torch.arange(K, dtype=torch.float32, device=device)
#     h = -1 + 2 * (indices) / (K - 1)
#     theta = torch.acos(h)
#     phi = torch.zeros(K, device=device)
    
#     # Golden angle increment
#     golden_angle = math.pi * (3 - math.sqrt(5))
#     phi[1:] = (phi[1:] + golden_angle * indices[1:]) % (2*math.pi)
    
#     # Convert to Cartesian
#     x = torch.sin(theta) * torch.cos(phi)
#     y = torch.sin(theta) * torch.sin(phi)
#     z = torch.cos(theta)
    
#     return torch.stack([x, y, z], dim=1)




# def fibonacci_sphere(K, device='cpu'):
#     """
#     Generate K nearly-uniform points on a sphere using Fibonacci sphere method.
    
#     Args:
#         K: Number of points (categories)
#         device: Target device for the tensor
        
#     Returns:
#         Tensor of shape (K, 3) with Cartesian coordinates on the unit sphere
#     """
#     indices = torch.arange(K, dtype=torch.float32, device=device)
#     phi = math.pi * (3. - math.sqrt(5.))  # Golden angle

#     y = 1 - (indices / (K - 1)) * 2  # y goes from 1 to -1
#     radius = torch.sqrt(1 - y ** 2)  # Radius at height y

#     theta = phi * indices

#     x = torch.cos(theta) * radius
#     z = torch.sin(theta) * radius

#     return torch.stack([x, y, z], dim=1)  # shape (K, 3)












# import math
# import torch

# import torch
# import math

# def cake_slice_sphere(K, device='cpu'):
#     """
#     Uniform spherical distribution with fixed longitudinal spacing.
#     Like slicing the Earth into equal-angle wedges.
    
#     Args:
#         K: Number of points
#         device: Target device
        
#     Returns:
#         (K, 3) tensor of Cartesian coordinates
#     """
#     indices = torch.arange(K, dtype=torch.float32, device=device)
    
#     # Latitude adjustment (compensates for sphere curvature)
#     h = -1 + 2 * indices / (K - 1)  # Uniform in height
#     theta = torch.acos(h)  # Non-uniform in angle (more points near equator)
    
#     # Fixed angular spacing in longitude (like cake slices)
#     phi = (2 * math.pi * indices / K) % (2 * math.pi)
    
#     # Convert to Cartesian
#     x = torch.sin(theta) * torch.cos(phi)
#     y = torch.sin(theta) * torch.sin(phi)
#     z = torch.cos(theta)
    
#     return torch.stack([x, y, z], dim=1)




# def convert_psk(tensor, categories, precomputed_refs=None):
#     device = tensor.device
#     batch_size, num_features = tensor.shape
#     encoded = torch.zeros((batch_size, 3*num_features), device=device)
    
#     # Generate or reuse reference points (now latitude-based)
#     if precomputed_refs is None:
#         precomputed_refs = [cake_slice_sphere(K, device) for K in categories]
    
#     for i in range(num_features):
#         start, end = 3*i, 3*(i+1)
#         encoded[:, start:end] = precomputed_refs[i][tensor[:, i].long()]
    
#     return encoded, precomputed_refs

# def psk2cat(psk_tensor, categories, precomputed_refs=None, missing_val=-1, eps=1e-6):
#     device = psk_tensor.device
#     batch_size = psk_tensor.shape[0]
#     decoded = torch.full((batch_size, len(categories)), missing_val, 
#                         dtype=torch.long, device=device)
    
#     psk_3d = psk_tensor.view(batch_size, -1, 3)
#     norms = torch.norm(psk_3d, dim=2)
#     valid_mask = norms >= eps
    
#     if precomputed_refs is None:
#         precomputed_refs = [cake_slice_sphere(K, device) for K in categories]
    
#     for i, (K, refs) in enumerate(zip(categories, precomputed_refs)):
#         mask = valid_mask[:, i]
#         if mask.any():
#             dists = torch.cdist(psk_3d[mask, i], refs)
#             decoded[mask, i] = torch.argmin(dists, dim=1)
    
#     return decoded





















# def generalized_spiral(K, device='cpu'):
#     """Optimized uniform spherical distribution"""
#     indices = torch.arange(K, dtype=torch.float32, device=device)
#     h = -1 + 2 * (indices) / (K - 1)
#     theta = torch.acos(h)

#     phi = torch.zeros(K, device=device)
    
#     # Golden angle increment
#     golden_angle = math.pi * (3 - math.sqrt(5))
#     phi[1:] = (phi[1:] + golden_angle * indices[1:]) % (2*math.pi)



#     # Generate points
#     points = torch.stack([
#         torch.sin(theta) * torch.cos(phi),
#         torch.sin(theta) * torch.sin(phi),
#         torch.cos(theta)
#     ], dim=1)
    
    
#     return points


# def convert_psk(tensor, categories, precomputed_refs=None):
#     """
#     Optimized encoder with reference point caching.
    
#     Args:
#         tensor: (B, num_features) int tensor of category indices
#         categories: List[int] of category counts
#         precomputed_refs: Optional cached reference points
        
#     Returns:
#         encoded: (B, 3*num_features) tensor
#         ref_points: Cached reference points for decoding
#     """
#     device = tensor.device
#     batch_size, num_features = tensor.shape
#     encoded = torch.zeros((batch_size, 3*num_features), device=device)

    
#     # Generate or reuse reference points
#     if precomputed_refs is None:
#         precomputed_refs = [generalized_spiral(K, device) for K in categories]
    
#     # Vectorized encoding
#     for i in range(num_features):
#         start, end = 3*i, 3*(i+1)
#         encoded[:, start:end] = precomputed_refs[i][tensor[:, i].long()]
    
#     return encoded, precomputed_refs





# def psk2cat(psk_tensor, categories, precomputed_refs=None, missing_val=-1, eps=1e-6):
#     """
#     Optimized decoder with batched distance computation.
    
#     Args:
#         psk_tensor: (B, 3*num_features) encoded tensor
#         categories: List[int] of category counts
#         precomputed_refs: Cached reference points from encoding
#         missing_val: Value for missing entries
#         eps: Missing value threshold
        
#     Returns:
#         decoded: (B, num_features) category indices
#     """
#     device = psk_tensor.device
#     batch_size = psk_tensor.shape[0]
#     decoded = torch.full((batch_size, len(categories)), missing_val, 
#                         dtype=torch.long, device=device)
    
#     # Reshape for vectorized processing
#     psk_3d = psk_tensor.view(batch_size, -1, 3)  # (B, num_features, 3)
#     norms = torch.norm(psk_3d, dim=2)
#     valid_mask = norms >= eps
    
#     # Generate references if not provided
#     if precomputed_refs is None:
#         precomputed_refs = [generalized_spiral(K, device) for K in categories]
    
#     # Process all features in single loop
#     for i, (K, refs) in enumerate(zip(categories, precomputed_refs)):
#         mask = valid_mask[:, i]
#         if not mask.any():
#             continue
            
#         # Batched distance computation
#         valid_vectors = psk_3d[mask, i]  # (N_valid, 3)
#         valid_vectors = valid_vectors 
        
#         # Calculate distances
#         dists = torch.cdist(valid_vectors, refs)

#         decoded[mask, i] = torch.argmin(dists, dim=1)
    
#     return decoded




def generalized_spiral_distribution(K, device='cpu'):
    """Return ``K`` roughly uniformly spread points on the unit sphere.

    The points follow a golden-angle spiral. Their heights (the ``z``
    coordinate) are evenly spaced from -1 at index 0 to +1 at index ``K - 1``.
    Each successive point is rotated in longitude by the golden angle.

    Args:
        K: Number of points to generate.
        device: Device on which the result is allocated.

    Returns:
        Float32 tensor of shape ``(K, 3)`` with ``(x, y, z)`` coordinates.
        For ``K == 1`` the single point is the south pole ``(0, 0, -1)``.
    """
    indices = torch.arange(K, dtype=torch.float32, device=device)
    # For K == 1 the denominator K - 1 would be 0, producing 0/0 = NaN
    # coordinates; fall back to 1 so the lone point sits at h = -1.
    h = -1 + 2 * indices / max(K - 1, 1)
    # Clamp protects acos from values nudged just outside [-1, 1] by rounding.
    theta = torch.acos(h.clamp(-1.0, 1.0))

    # Golden angle increment in longitude between consecutive points.
    golden_angle = math.pi * (3 - math.sqrt(5))
    phi = (golden_angle * indices) % (2 * math.pi)

    sin_theta = torch.sin(theta)
    return torch.stack([
        sin_theta * torch.cos(phi),
        sin_theta * torch.sin(phi),
        torch.cos(theta),
    ], dim=1)




def get_segment_representatives(K, device='cpu', points_per_segment=10):
    """Return one representative point on the unit sphere per category.

    A spiral of ``K * points_per_segment`` points is generated with
    ``generalized_spiral_distribution``. The spiral is split into ``K``
    contiguous segments of ``points_per_segment`` points each. The middle
    point of every segment becomes the representative for that category.

    Args:
        K: Number of categories (segments); must be >= 1.
        device: Device on which the result is allocated.
        points_per_segment: Number of spiral points per segment (default 10).
            Must be >= 1.

    Returns:
        Tensor of shape ``(K, 3)``; row ``i`` is the representative of
        category ``i``.

    Raises:
        ValueError: If ``K`` or ``points_per_segment`` is smaller than 1.
    """
    if K < 1:
        raise ValueError(f"K must be a positive integer, got {K}")
    if points_per_segment < 1:
        raise ValueError(
            f"points_per_segment must be a positive integer, got {points_per_segment}"
        )

    full_refs = generalized_spiral_distribution(K * points_per_segment, device)

    # Segment i covers spiral indices [i*s, (i+1)*s); its midpoint is i*s + s//2.
    midpoints = (torch.arange(K, device=device) * points_per_segment
                 + points_per_segment // 2)
    return full_refs[midpoints]

def convert_psk(tensor, categories, precomputed_reps=None):
    """Encode categorical features as representative points on the unit sphere.

    Feature ``i`` has ``categories[i]`` levels. It is replaced by the 3-D
    representative of its category, and the vectors of all features are
    concatenated along the last dimension.

    Args:
        tensor: 2-D tensor ``(batch, num_features)`` of category indices.
            Negative entries are treated as missing and encoded as the zero
            vector. ``psk2cat`` maps near-zero vectors back to its
            ``missing_val``.
        categories: Sequence holding the number of categories of each feature.
        precomputed_reps: Optional list of per-feature ``(K, 3)`` tensors, for
            example from a previous call. When ``None``, they are built with
            ``get_segment_representatives``.

    Returns:
        Tuple ``(encoded, precomputed_reps)``. ``encoded`` has shape
        ``(batch, 3 * num_features)``; ``precomputed_reps`` can be reused for
        further encoding and decoding.
    """
    device = tensor.device
    batch_size, num_features = tensor.shape
    encoded = torch.zeros((batch_size, 3 * num_features), device=device)

    # Generate or reuse segment representatives
    if precomputed_reps is None:
        precomputed_reps = [get_segment_representatives(K, device) for K in categories]

    for i in range(num_features):
        # Masked assignment (index_put_) requires matching device and dtype.
        reps = precomputed_reps[i].to(device=device, dtype=encoded.dtype)
        cat_indices = tensor[:, i].long()

        # Negative indices would wrap around (Python-style indexing) and silently
        # alias the last categories; leave those rows as zero ("missing") vectors.
        present = cat_indices >= 0
        encoded[present, 3 * i:3 * (i + 1)] = reps[cat_indices[present]]

    return encoded, precomputed_reps

def psk2cat(psk_tensor, categories, precomputed_reps=None, missing_val=-1, eps=1e-6):
    """Decode sphere-encoded features back to category indices.

    Every 3-D vector is assigned to the nearest representative in Euclidean
    distance. Vectors whose norm is below ``eps`` are reported as
    ``missing_val``.

    Args:
        psk_tensor: Tensor ``(batch, 3 * num_features)``, e.g. the output of
            ``convert_psk`` or a model's reconstruction of it.
        categories: Sequence holding the number of categories of each feature.
        precomputed_reps: Optional list of per-feature ``(K, 3)`` representative
            tensors. When ``None``, they are built with
            ``get_segment_representatives``.
        missing_val: Value written for (near-)zero vectors.
        eps: Norm threshold below which a vector counts as missing.

    Returns:
        Long tensor of shape ``(batch, len(categories))``.
    """
    device = psk_tensor.device
    batch_size = psk_tensor.shape[0]
    decoded = torch.full((batch_size, len(categories)), missing_val,
                         dtype=torch.long, device=device)

    # An empty batch has nothing to decode, and reshape(0, -1, 3) would raise
    # because the -1 dimension is ambiguous for a tensor with 0 elements.
    if batch_size == 0:
        return decoded

    # reshape (not view) also copes with non-contiguous higher-rank inputs.
    psk_3d = psk_tensor.reshape(batch_size, -1, 3)
    norms = torch.norm(psk_3d, dim=2)
    valid_mask = norms >= eps

    # Generate representatives if not provided
    if precomputed_reps is None:
        precomputed_reps = [get_segment_representatives(K, device) for K in categories]

    for i, (_, reps) in enumerate(zip(categories, precomputed_reps)):
        mask = valid_mask[:, i]
        if not mask.any():
            continue

        valid_vectors = psk_3d[mask, i]
        # cdist requires both operands to share device and dtype.
        reps = reps.to(device=device, dtype=valid_vectors.dtype)

        # Nearest representative wins.
        dists = torch.cdist(valid_vectors, reps)
        decoded[mask, i] = torch.argmin(dists, dim=1)

    return decoded