import torch
import math

def convert_psk(tensor, categories):
    """Encode integer category indices as phase-shift-keyed sin/cos pairs.

    Feature ``i`` with value ``k`` is mapped to the angle
    ``k * 2*pi / categories[i]``. The output holds the sines of all features
    followed by their cosines.

    Args:
        tensor: Tensor of category indices whose last dimension has one entry
            per feature. Values are cast with ``.int()``, so fractional parts
            are truncated.
        categories: Number of categories per feature (list, tuple, NumPy array
            or tensor), broadcast against the last dimension of ``tensor``.

    Returns:
        Float tensor with the same leading shape as ``tensor`` and a last
        dimension twice as large, laid out as ``[sin(phases), cos(phases)]``.
    """
    device = tensor.device
    tensor = tensor.int()
    # as_tensor avoids the copy-construct warning torch.tensor() emits when
    # `categories` is already a tensor, and places it directly on `device`.
    cat_spacings = 2 * math.pi / torch.as_tensor(categories, device=device)

    phases = tensor * cat_spacings
    return torch.cat((torch.sin(phases), torch.cos(phases)), dim=-1)

def psk2cat(psk_tensor, categories, *, verbose=True):
    """Decode a PSK (sin/cos) encoding back to integer category indices.

    Inverse of the sin/cos layout produced by ``convert_psk``. Feature ``i`` is
    recovered from the angle of the point ``(sin_i, cos_i)``, quantized to the
    nearest multiple of ``2*pi / categories[i]``.

    Args:
        psk_tensor: Tensor whose last dimension holds ``len(categories)`` sine
            values followed by ``len(categories)`` cosine values. Any further
            trailing columns are ignored.
        categories: Number of categories per feature (list, tuple, NumPy array
            or tensor).
        verbose: When True (the default), print the fraction of decoded values
            that rounded up to ``categories[i]`` and were wrapped back to 0.

    Returns:
        Long tensor with the same leading shape as ``psk_tensor`` and a last
        dimension of ``len(categories)``, with values in ``[0, categories[i])``.
    """
    device = psk_tensor.device
    # Normalize first so plain lists work too (the old `.shape[0]` needed an
    # array/tensor), without re-copying tensors via torch.tensor().
    cat_counts = torch.as_tensor(categories, device=device)
    cat_len = cat_counts.shape[0]
    cat_spacings = 2 * math.pi / cat_counts

    sin_phases = psk_tensor[..., :cat_len]
    cos_phases = psk_tensor[..., cat_len:2 * cat_len]

    # atan2 returns angles in (-pi, pi]; shift into [0, 2*pi) so indices are
    # non-negative.
    phases = torch.atan2(sin_phases, cos_phases)
    phases = (phases + 2 * math.pi) % (2 * math.pi)

    feature_cat = torch.round(phases / cat_spacings).long()

    # Angles just below 2*pi round up to categories[i], which is the same
    # point on the circle as category 0, so wrap those back to 0.
    wrapped = feature_cat >= cat_counts
    feature_cat = feature_cat.masked_fill(wrapped, 0)

    total = feature_cat.numel()
    if verbose and total > 0:
        casting_rate = wrapped.sum().item() / total
        print(f"Casting rate in psk2cat: {round(casting_rate, 3)}")

    return feature_cat






# def convert_psk(tensor, categories):
#     """PSK Encoder: Maps categorical tensor to phase-shifted sin/cos."""
#     device = tensor.device
#     tensor = tensor.int()
#     cat_spacings = 2 * math.pi / torch.tensor(categories, device=device)
#     phases = tensor * cat_spacings
#     return torch.cat((torch.sin(phases), torch.cos(phases)), dim=-1)

# def psk2cat(psk_tensor, categories, missing_val=-1, eps=1e-3):
#     """Convert PSK-encoded tensor back to categorical (with missing value support)."""
#     device = psk_tensor.device
#     categories = torch.tensor(categories, device=device)
#     cat_len = len(categories)
#     cat_spacings = 2 * math.pi / categories

#     sin_phases, cos_phases = psk_tensor[:, :cat_len], psk_tensor[:, cat_len:]
    
#     # Detect missing values (low amplitude)
#     missing_mask = (sin_phases.abs() < eps) & (cos_phases.abs() < eps)
    
#     # Compute phase (auto-normalized to [0, 2π))
#     phases = torch.atan2(sin_phases, cos_phases)
#     phases = (phases + 2 * math.pi) % (2 * math.pi)
    
#     # Vectorized category decoding + overflow handling
#     feature_cat = torch.round(phases / cat_spacings).long()
#     feature_cat = feature_cat % categories  # Handles overflow in one step
    
#     # Apply missing values
#     feature_cat[missing_mask] = missing_val
    
#     return feature_cat








# import numpy as np
# import torch
# import math
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import Dataset, DataLoader


# def generalized_spiral(K, device='cpu'):
#     """More uniform spherical distribution than Fibonacci"""
#     indices = torch.arange(K, dtype=torch.float32, device=device)
#     h = -1 + 2 * (indices) / (K - 1)
#     theta = torch.acos(h)
#     phi = torch.zeros(K, device=device)
    
#     # Golden angle increment
#     golden_angle = math.pi * (3 - math.sqrt(5))
#     phi[1:] = (phi[1:] + golden_angle * indices[1:]) % (2*math.pi)
    
#     # Convert to Cartesian
#     x = torch.sin(theta) * torch.cos(phi)
#     y = torch.sin(theta) * torch.sin(phi)
#     z = torch.cos(theta)
    
#     return torch.stack([x, y, z], dim=1)




# def fibonacci_sphere(K, device='cpu'):
#     """
#     Generate K nearly-uniform points on a sphere using Fibonacci sphere method.
    
#     Args:
#         K: Number of points (categories)
#         device: Target device for the tensor
        
#     Returns:
#         Tensor of shape (K, 3) with Cartesian coordinates on the unit sphere
#     """
#     indices = torch.arange(K, dtype=torch.float32, device=device)
#     phi = math.pi * (3. - math.sqrt(5.))  # Golden angle

#     y = 1 - (indices / (K - 1)) * 2  # y goes from 1 to -1
#     radius = torch.sqrt(1 - y ** 2)  # Radius at height y

#     theta = phi * indices

#     x = torch.cos(theta) * radius
#     z = torch.sin(theta) * radius

#     return torch.stack([x, y, z], dim=1)  # shape (K, 3)












# import math
# import torch

# import torch
# import math

# def cake_slice_sphere(K, device='cpu'):
#     """
#     Uniform spherical distribution with fixed longitudinal spacing.
#     Like slicing the Earth into equal-angle wedges.
    
#     Args:
#         K: Number of points
#         device: Target device
        
#     Returns:
#         (K, 3) tensor of Cartesian coordinates
#     """
#     indices = torch.arange(K, dtype=torch.float32, device=device)
    
#     # Latitude adjustment (compensates for sphere curvature)
#     h = -1 + 2 * indices / (K - 1)  # Uniform in height
#     theta = torch.acos(h)  # Non-uniform in angle (more points near equator)
    
#     # Fixed angular spacing in longitude (like cake slices)
#     phi = (2 * math.pi * indices / K) % (2 * math.pi)
    
#     # Convert to Cartesian
#     x = torch.sin(theta) * torch.cos(phi)
#     y = torch.sin(theta) * torch.sin(phi)
#     z = torch.cos(theta)
    
#     return torch.stack([x, y, z], dim=1)




# def convert_psk(tensor, categories, precomputed_refs=None):
#     device = tensor.device
#     batch_size, num_features = tensor.shape
#     encoded = torch.zeros((batch_size, 3*num_features), device=device)
    
#     # Generate or reuse reference points (now latitude-based)
#     if precomputed_refs is None:
#         precomputed_refs = [cake_slice_sphere(K, device) for K in categories]
    
#     for i in range(num_features):
#         start, end = 3*i, 3*(i+1)
#         encoded[:, start:end] = precomputed_refs[i][tensor[:, i].long()]
    
#     return encoded, precomputed_refs

# def psk2cat(psk_tensor, categories, precomputed_refs=None, missing_val=-1, eps=1e-6):
#     device = psk_tensor.device
#     batch_size = psk_tensor.shape[0]
#     decoded = torch.full((batch_size, len(categories)), missing_val, 
#                         dtype=torch.long, device=device)
    
#     psk_3d = psk_tensor.view(batch_size, -1, 3)
#     norms = torch.norm(psk_3d, dim=2)
#     valid_mask = norms >= eps
    
#     if precomputed_refs is None:
#         precomputed_refs = [cake_slice_sphere(K, device) for K in categories]
    
#     for i, (K, refs) in enumerate(zip(categories, precomputed_refs)):
#         mask = valid_mask[:, i]
#         if mask.any():
#             dists = torch.cdist(psk_3d[mask, i], refs)
#             decoded[mask, i] = torch.argmin(dists, dim=1)
    
#     return decoded


