import numpy as np
import torch
from mpl_toolkits.mplot3d import Axes3D
import torch
from tqdm import tqdm
from scipy.spatial import cKDTree
from scipy.optimize import linear_sum_assignment
from torch import cdist

import matplotlib.pyplot as plt
import torch.nn.functional as F

def bounded_randn(num_samples, input_dim, mean, std, lower_bound, upper_bound):
    """Sample a clipped Gaussian as a float32 tensor that requires grad.

    Draws a ``(num_samples, input_dim)`` block of standard-normal values from
    NumPy's global RNG and rescales them to ``mean + std * z``. It then clips
    the result into ``[lower_bound, upper_bound]``. Out-of-range draws are
    clipped, not rejected, so they pile up exactly on the bounds.

    Returns:
        Leaf ``torch.Tensor`` of shape (num_samples, input_dim), dtype
        float32, with ``requires_grad=True``.
    """
    gaussian = mean + std * np.random.randn(num_samples, input_dim)
    clipped = np.clip(gaussian, lower_bound, upper_bound)
    return torch.tensor(clipped, dtype=torch.float32, requires_grad=True)

def bounded_randu(num_samples, input_dim, mean, std, lower_bound, upper_bound):
    """Sample uniform values between ``lower_bound`` and ``upper_bound``.

    A ``(num_samples, input_dim)`` block of ``np.random.uniform`` draws
    (NumPy's global RNG, unit interval) is mapped linearly onto the interval
    from ``lower_bound`` to ``upper_bound``.

    ``mean`` and ``std`` are accepted but not used. They are presumably kept
    for signature parity with ``bounded_randn``.

    Returns:
        Leaf ``torch.Tensor`` of shape (num_samples, input_dim), dtype
        float32, with ``requires_grad=True``.
    """
    unit = np.random.uniform(size=(num_samples, input_dim))
    span = upper_bound - lower_bound
    scaled = unit * span + lower_bound
    return torch.tensor(scaled, dtype=torch.float32, requires_grad=True)

def stereographic_projection(x):
    """Stereographically project from the pole where the last coordinate is 1.

    For a point ``(x_1, ..., x_n, x_{n+1})`` this returns
    ``(x_1, ..., x_n) / (1 - x_{n+1})``. The trailing axis therefore shrinks
    by one and leading (batch) axes are kept. Points whose last coordinate
    equals 1 divide by zero.
    """
    pole_axis = x.shape[-1] - 1
    head = x[..., :pole_axis]
    tail = x[..., pole_axis:]
    return head / (1 - tail)

def inverse_stereographic_projection(y):
    """Lift points from R^n onto the unit sphere in R^(n+1).

    With ``s = |y|^2`` the result is ``(2 y / (1 + s), (s - 1) / (1 + s))``.
    This is the inverse of ``stereographic_projection`` on the unit sphere
    minus the pole. The trailing axis grows by one; ``y`` must be a torch
    tensor (``torch.sum`` / ``torch.cat`` are used).
    """
    sq_norm = (y ** 2).sum(dim=-1, keepdim=True)
    scale = 1 + sq_norm
    return torch.cat([2 * y / scale, (sq_norm - 1) / scale], dim=-1)

def tunnel_stereographic_proj(x, N=1):
    """Apply ``stereographic_projection`` ``N`` times in succession.

    Each pass drops the trailing coordinate, so the result's last axis is
    ``N`` entries shorter than ``x``'s. ``N <= 0`` returns ``x`` unchanged.
    """
    projected = x
    for _depth in range(N, 0, -1):
        projected = stereographic_projection(projected)
    return projected

def tunnel_stereographic_inv_proj(x, N=1):
    """Apply ``inverse_stereographic_projection`` ``N`` times in succession.

    Each pass appends one coordinate, so the result's last axis is ``N``
    entries longer than ``x``'s. ``N <= 0`` returns ``x`` unchanged.
    """
    lifted = x
    for _depth in range(N, 0, -1):
        lifted = inverse_stereographic_projection(lifted)
    return lifted

def plot_3D_x1_x2(x1):
    """Show a 3-D scatter plot of the first three columns of ``x1``.

    Args:
        x1: array-like or ``torch.Tensor`` indexable as ``x1[:, 0..2]``.
            Tensors are detached and moved to CPU first. matplotlib converts
            its inputs through NumPy, and that conversion raises for tensors
            created with ``requires_grad=True`` or living on a GPU.

    Returns:
        True once ``plt.show()`` returns.
    """
    if isinstance(x1, torch.Tensor):
        x1 = x1.detach().cpu().numpy()
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(x1[:, 0], x1[:, 1], x1[:, 2], c='r', label='x1_x2_concat')
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('Z axis')
    ax.legend()
    plt.show()
    return True

def spherical_repulsion(points, num_iterations=10, lr=0.01, epsilon=1e-6):
    """Spread points over the unit sphere with an inverse-square repulsion.

    Every pair (i, j) pushes apart with magnitude ``1 / d**2`` along
    ``(p_i - p_j) / d``, where ``d = ||p_i - p_j|| + epsilon``. Each point
    steps by ``lr * force`` and is re-projected onto the unit sphere.
    Pairwise forces are computed with broadcasting rather than a Python
    double loop over all pairs. The formula is unchanged; only the
    floating-point summation order differs.

    Args:
        points: tensor of shape (N, D); rows are L2-normalised first.
        num_iterations: maximum number of update steps.
        lr: step size applied to the summed forces.
        epsilon: added to every pairwise distance to avoid division by zero.
            It is also the early-stop threshold on the largest per-point
            force norm.

    Returns:
        Tensor of shape (N, D) with unit-norm rows. Updates are out-of-place,
        so the result stays connected to ``points`` in the autograd graph.
    """
    points = F.normalize(points, p=2, dim=1)
    num_points = points.size(0)
    if num_points < 2:
        # No pairs means no forces. Returning early also avoids torch.max on
        # an empty force tensor when N == 0, which raises.
        return points
    self_pairs = torch.eye(num_points, dtype=torch.bool, device=points.device)
    for _ in tqdm(range(num_iterations)):
        # diff[i, j] = p_i - p_j, shape (N, N, D).
        diff = points.unsqueeze(1) - points.unsqueeze(0)
        # Put a harmless 1 on the diagonal so pow(-3) stays finite there
        # (also in backward); those terms are zeroed right after.
        distance = (diff.norm(dim=-1) + epsilon).masked_fill(self_pairs, 1.0)
        # (1 / d^2) * (diff / d) == diff / d^3.
        weight = distance.pow(-3).masked_fill(self_pairs, 0.0)
        forces = (weight.unsqueeze(-1) * diff).sum(dim=1)
        points = F.normalize(points + lr * forces, p=2, dim=1)
        if torch.max(torch.norm(forces, dim=1)) < epsilon:
            break
    return points

def fibonacci_sphere(samples=100):
    """Return ``samples`` near-uniform points on the unit sphere (Fibonacci lattice).

    Point ``i`` has height ``y_i = 1 - 2 i / (samples - 1)`` (from +1 down to
    -1), ring radius ``sqrt(1 - y_i**2)`` and azimuth ``i * golden_angle``,
    where ``golden_angle = pi * (3 - sqrt(5))``.

    This replaces the original per-point Python loop with one vectorised
    computation. The loop wrapped coordinates in grad-requiring scalar
    tensors, but ``torch.tensor`` copied only their values, so those graphs
    were discarded anyway. ``samples == 1`` (previously ZeroDivisionError)
    now yields the single point ``(0, 1, 0)``. ``samples == 0`` yields a
    (0, 3) tensor.

    Returns:
        Leaf tensor of shape (samples, 3) in the default float dtype, with
        ``requires_grad=True``.
    """
    import math

    index = torch.arange(samples, dtype=torch.float64)
    # max(..., 1) keeps samples == 1 from dividing by zero (y becomes 1).
    y = 1.0 - 2.0 * index / max(samples - 1, 1)
    # clamp absorbs round-off that could push 1 - y^2 slightly below zero.
    radius = torch.sqrt(torch.clamp(1.0 - y * y, min=0.0))
    theta = math.pi * (3.0 - math.sqrt(5.0)) * index
    points = torch.stack(
        (torch.cos(theta) * radius, y, torch.sin(theta) * radius), dim=1
    )
    return points.to(torch.get_default_dtype()).requires_grad_(True)

def fibonacci_sphere_grad(samples=100):
    """Fibonacci-lattice points on the unit sphere, differentiable w.r.t. the indices.

    Same lattice as ``fibonacci_sphere``: height ``y_i = 1 - 2 i / (samples - 1)``,
    radius ``sqrt(1 - y_i**2)``, azimuth ``i * pi * (3 - sqrt(5))``. The
    float32 index vector is a leaf with ``requires_grad=True``, so the result
    carries an autograd graph back to it.

    ``samples == 1`` previously computed 0/0 and returned NaNs. It now
    returns the single point ``(0, 1, 0)``.

    Returns:
        float32 tensor of shape (samples, 3).
    """
    indices = torch.arange(0, samples, dtype=torch.float32, requires_grad=True)
    phi = torch.pi * (3.0 - torch.sqrt(torch.tensor(5.0)))
    # max(..., 1) avoids 0/0 -> NaN when samples == 1.
    y = 1 - (2 * indices / max(samples - 1, 1))
    # clamp absorbs round-off that could make 1 - y^2 slightly negative.
    radius = torch.sqrt(torch.clamp(1 - y ** 2, min=0.0))
    theta = phi * indices
    x = torch.cos(theta) * radius
    z = torch.sin(theta) * radius
    points = torch.stack((x, y, z), dim=1)
    return points

def optimal_bijective_mapping(data_points, num_target_points=None):
    """Assign data points to Fibonacci-sphere points with minimum total distance.

    Rows of ``data_points`` are L2-normalised. A Euclidean cost matrix to
    ``num_target_points`` Fibonacci targets (default: N) is solved with the
    Hungarian algorithm (``scipy.optimize.linear_sum_assignment``).

    The targets are now built on ``data_points``' device and dtype.
    Previously they were always CPU float32, so ``torch.cdist`` failed for
    float64 or GPU inputs.

    Returns:
        The assigned target rows. When ``num_target_points >= N``, row ``i``
        is the target matched to data row ``i``. When ``num_target_points < N``,
        only ``num_target_points`` rows come back and they are not aligned
        with the data rows — NOTE(review): confirm callers never do this.
        Data points only choose the indices (the cost matrix is detached), so
        no gradient reaches ``data_points``.
    """
    N = data_points.size(0)
    if num_target_points is None:
        num_target_points = N
    target_points = fibonacci_sphere_grad(samples=num_target_points).to(
        device=data_points.device, dtype=data_points.dtype
    )
    data_points = F.normalize(data_points, p=2, dim=1)
    distances = torch.cdist(data_points, target_points)
    distance_matrix = distances.detach().cpu().numpy()
    _, col_indices = linear_sum_assignment(distance_matrix)
    index = torch.as_tensor(col_indices, dtype=torch.long, device=target_points.device)
    return target_points[index]

def approximate_inverse_sphere_mapping(target_points, num_original_points):
    """Map sphere points back onto a Fibonacci lattice of ``num_original_points``.

    Both sets are L2-normalised. Each row of ``target_points`` is then matched
    to a distinct lattice point with the Hungarian algorithm on their
    Euclidean distances.

    The lattice is now built on ``target_points``' device and dtype, so
    ``torch.cdist`` no longer fails for float64 or GPU inputs. The unused
    ``M`` local was removed.

    Returns:
        The matched lattice rows, aligned with ``target_points`` rows when
        ``num_original_points >= len(target_points)``. If it is smaller, only
        ``num_original_points`` rows come back, unaligned — NOTE(review):
        confirm callers never do this. No gradient reaches ``target_points``
        (the cost matrix is detached).
    """
    approx_data_points = fibonacci_sphere_grad(samples=num_original_points).to(
        device=target_points.device, dtype=target_points.dtype
    )
    target_points = F.normalize(target_points, p=2, dim=1)
    approx_data_points = F.normalize(approx_data_points, p=2, dim=1)
    distances = torch.cdist(target_points, approx_data_points)
    distance_matrix = distances.detach().cpu().numpy()
    _, col_indices = linear_sum_assignment(distance_matrix)
    index = torch.as_tensor(
        col_indices, dtype=torch.long, device=approx_data_points.device
    )
    return approx_data_points[index]

def optimal_bijective_mapping_with_gradients(data_points, num_target_points=None):
    """Hungarian assignment of data points to Fibonacci targets, keeping the target graph.

    Same matching as a minimum-total-distance assignment. The returned rows
    are obtained by tensor indexing into the (grad-carrying) output of
    ``fibonacci_sphere_grad``, so they stay in that autograd graph.

    NOTE(review): the cost matrix is detached before solving, so despite the
    name no gradient reaches ``data_points``. Only the Fibonacci targets'
    graph is preserved.

    The targets are now built on ``data_points``' device and dtype, so
    ``torch.cdist`` no longer fails for float64 or GPU inputs.

    Returns:
        Assigned target rows, aligned with data rows when
        ``num_target_points >= N``. Otherwise only ``num_target_points`` rows
        come back, unaligned.
    """
    N = data_points.size(0)
    if num_target_points is None:
        num_target_points = N
    target_points = fibonacci_sphere_grad(samples=num_target_points).to(
        device=data_points.device, dtype=data_points.dtype
    )
    data_points = F.normalize(data_points, p=2, dim=1)
    distances = torch.cdist(data_points, target_points)
    with torch.no_grad():
        distance_matrix = distances.detach().cpu().numpy()
        _, col_indices = linear_sum_assignment(distance_matrix)
    col_indices_tensor = torch.tensor(col_indices, dtype=torch.long, device=target_points.device)
    assigned_points = target_points[col_indices_tensor]
    return assigned_points

def bijective_nearest_mapping_tensor(data_points, num_target_points=None):
    """Greedily give each data point a distinct Fibonacci-sphere target.

    Each L2-normalised data row looks up its nearest target. If that target
    is already taken, it probes forward by target *index* (wrapping around),
    not by distance, until a free one is found. Rows are processed in order,
    so the result depends on row order.

    Changes from the previous version:
    - More data points than targets used to make the probe loop forever;
      that case now raises.
    - ``N == 0`` no longer crashes on ``torch.cat([])``.
    - Taken slots are tracked in a plain list instead of a grad-requiring
      float tensor.
    - Targets are built on the data's device and dtype so ``cdist`` works for
      float64 and GPU inputs.

    Returns:
        Tensor of shape (N, 3); row ``i`` is the target assigned to data
        row ``i``. It indexes into ``fibonacci_sphere_grad``'s output, so that
        graph is kept.

    Raises:
        ValueError: if ``N > num_target_points``.
    """
    N = data_points.size(0)
    if num_target_points is None:
        num_target_points = N
    if N > num_target_points:
        raise ValueError(
            f"cannot assign {N} data points to {num_target_points} distinct target points"
        )
    target_points = fibonacci_sphere_grad(samples=num_target_points).to(
        device=data_points.device, dtype=data_points.dtype
    )
    if N == 0:
        return target_points[:0]
    data_points = F.normalize(data_points, p=2, dim=1)
    distances = cdist(data_points, target_points)
    _, nearest_indices = distances.min(dim=1)
    taken = [False] * num_target_points
    chosen = []
    for nearest_index in nearest_indices.tolist():
        # Linear probe; terminates because N <= num_target_points guarantees a free slot.
        while taken[nearest_index]:
            nearest_index = (nearest_index + 1) % num_target_points
        taken[nearest_index] = True
        chosen.append(nearest_index)
    index = torch.tensor(chosen, dtype=torch.long, device=target_points.device)
    return target_points[index]

def bijective_nearest_mapping(data_points, num_target_points=None):
    """Greedily match each data point to its nearest unused Fibonacci-sphere point.

    Rows of ``data_points`` are L2-normalised and processed in order. Each
    row takes the closest (Euclidean, via ``cKDTree``) target not already
    used. The matching is greedy, so it depends on row order and is not
    globally optimal.

    Changes from the previous version:
    - Used targets are tracked in a set; list membership made the loop O(N^2).
    - Exactly ``len(taken) + 1`` neighbours are queried. That always contains
      a free target and never asks for ``k == 1``, where ``cKDTree.query``
      returns a scalar that the old code tried to iterate.
    - Tensors are moved to CPU before ``.numpy()``.

    Args:
        data_points: tensor of shape (N, D). The targets are 3-D, so D is
            presumably 3.
        num_target_points: number of Fibonacci targets (defaults to N).

    Returns:
        Tensor of shape (N, 3) in ``data_points``' dtype and device; row ``i``
        is the target chosen for data row ``i``. If N exceeds the number of
        targets, once all targets are used the remaining rows fall back to
        their nearest target (duplicates), as before.
    """
    N = data_points.size(0)
    if N == 0:
        return torch.zeros_like(data_points)
    if num_target_points is None:
        num_target_points = N
    target_points = fibonacci_sphere(samples=num_target_points)
    target_points_np = target_points.detach().cpu().numpy()
    num_targets = len(target_points_np)
    target_tree = cKDTree(target_points_np)
    data_points = F.normalize(data_points, p=2, dim=1)
    data_points_np = data_points.detach().cpu().numpy()
    taken = set()
    chosen = []
    for point in data_points_np:
        _, nearest_index = target_tree.query(point)
        nearest_index = int(nearest_index)
        if nearest_index in taken and len(taken) < num_targets:
            # Only len(taken) targets are used, so the len(taken) + 1 nearest
            # neighbours must contain a free one. The first free one in this
            # distance-sorted list is the nearest free target.
            _, candidates = target_tree.query(point, k=len(taken) + 1)
            for candidate in np.atleast_1d(candidates):
                candidate = int(candidate)
                if candidate not in taken:
                    nearest_index = candidate
                    break
        taken.add(nearest_index)
        chosen.append(nearest_index)
    index = torch.tensor(chosen, dtype=torch.long)
    return target_points[index].to(dtype=data_points.dtype, device=data_points.device)
