import numpy as np
from scipy.spatial.distance import euclidean
from numpy.linalg import norm
import torch
import warnings

def spherical_to_cartesian(r, theta, phi):
    """Convert spherical coordinates to a Cartesian 3-vector.

    ``theta`` is the polar angle measured from the +z axis and ``phi`` the
    azimuth in the x-y plane (z = r*cos(theta)).

    Returns:
        ``np.array([x, y, z])``.
    """
    sin_theta = np.sin(theta)
    return np.array([
        r * sin_theta * np.cos(phi),
        r * sin_theta * np.sin(phi),
        r * np.cos(theta),
    ])

def geodesic_distance_on_sphere(p1, p2):
    """Return the angle (in radians) between vectors ``p1`` and ``p2``.

    For unit vectors this is the great-circle (geodesic) distance on the unit
    sphere. The inputs need not be normalised: both are divided by their norms.

    The cosine is clipped to [-1, 1] before ``arccos``. Floating-point rounding
    can push it just outside that range, e.g. for identical or antipodal
    vectors, and ``arccos`` would then return NaN.

    A zero-length input still yields NaN (0/0).
    """
    cos_angle = np.dot(p1, p2) / (norm(p1) * norm(p2))
    return np.arccos(np.clip(cos_angle, -1.0, 1.0))

def normalize_to_sphere(point):
    """Project ``point`` onto the unit sphere by dividing by its Euclidean norm.

    A zero vector is not special-cased (the division yields NaN/Inf).
    """
    length = norm(point)
    return point / length

def point_on_geodesic_circle(center, radius, angle):
    """Return the unit vector at geodesic distance ``radius`` from ``center``.

    The point lies on the geodesic circle of angular radius ``radius`` (radians)
    around ``center`` on the unit sphere. ``angle`` selects the direction around
    the circle:
      - ``angle = 0`` heads toward increasing polar angle (theta, measured from +z);
      - ``angle = pi/2`` heads toward increasing azimuth (phi).

    The point is obtained by moving along a great circle,
    ``cos(radius)*c + sin(radius)*(cos(angle)*e_theta + sin(angle)*e_phi)``.
    This places it exactly ``radius`` away from the centre, unlike adding
    offsets to theta/phi directly, which distorts distances (badly near the
    poles).

    ``center`` is normalised first, so a non-unit vector is accepted.
    """
    c = np.asarray(center, dtype=float)
    c = c / norm(c)
    # Spherical angles of the centre, used only to build an orthonormal
    # tangent frame (e_theta, e_phi) at c. The clip guards arccos against
    # rounding just past +/-1.
    center_theta = np.arccos(np.clip(c[2], -1.0, 1.0))
    center_phi = np.arctan2(c[1], c[0])
    e_theta = np.array([
        np.cos(center_theta) * np.cos(center_phi),
        np.cos(center_theta) * np.sin(center_phi),
        -np.sin(center_theta),
    ])
    e_phi = np.array([-np.sin(center_phi), np.cos(center_phi), 0.0])
    direction = np.cos(angle) * e_theta + np.sin(angle) * e_phi
    point = np.cos(radius) * c + np.sin(radius) * direction
    # Re-normalise to absorb rounding error.
    return point / norm(point)

def random_action_on_circle(center, radius):
    """Sample a point on the geodesic circle of ``radius`` around ``center``.

    The direction angle around the circle is drawn uniformly from [0, 2*pi)
    with NumPy's global random generator (``np.random.uniform``).
    """
    direction_angle = np.random.uniform(low=0, high=2 * np.pi)
    return point_on_geodesic_circle(center, radius, direction_angle)

def action_minimizing_geodesic_distance(theta_A_cart, radius_A, theta_B_cart, num_points=100):
    """Pick the sampled point on A's geodesic circle closest to ``theta_B_cart``.

    ``num_points`` equally spaced angles in [0, 2*pi) are evaluated with
    ``point_on_geodesic_circle(theta_A_cart, radius_A, angle)``. The candidate
    with the smallest ``geodesic_distance_on_sphere`` to ``theta_B_cart`` is
    returned; ties go to the first one found.

    Returns:
        The best candidate point, or None if ``num_points`` is 0 or no
        candidate has a distance below infinity (e.g. all distances are NaN).
    """
    best_action = None
    min_distance = np.inf
    # endpoint=False: angle 2*pi is the same point as angle 0, so including
    # both would waste one of the num_points samples on a duplicate.
    for angle in np.linspace(0.0, 2 * np.pi, num_points, endpoint=False):
        candidate = point_on_geodesic_circle(theta_A_cart, radius_A, angle)
        distance = geodesic_distance_on_sphere(candidate, theta_B_cart)
        # Strict '<' keeps the first minimiser on ties and skips NaN distances.
        if distance < min_distance:
            min_distance = distance
            best_action = candidate
    return best_action

def gisa_algorithm(t, theta_A, theta_B, C_t):
    """Choose an action on A's geodesic circle at step ``t``.

    Both circle radii come from ``C_t(t)``. If the circles around A and B
    overlap, a random point on A's circle is returned. Otherwise the sampled
    point on A's circle closest to B is returned.

    Args:
        t: Step value passed to ``C_t``.
        theta_A: ``(theta, phi)`` spherical angles of A on the unit sphere.
        theta_B: ``(theta, phi)`` spherical angles of B on the unit sphere.
        C_t: Callable mapping ``t`` to a geodesic radius (radians).

    Returns:
        ``(action, radius_A, radius_B)``. ``action`` is a Cartesian point, or
        None if no candidate could be selected (a warning is emitted).
    """
    # C_t is evaluated once per radius, as before, in case it is stateful.
    radius_A = C_t(t)
    radius_B = C_t(t)
    theta_A_cart = spherical_to_cartesian(1, theta_A[0], theta_A[1])
    theta_B_cart = spherical_to_cartesian(1, theta_B[0], theta_B[1])
    geodesic_dist = geodesic_distance_on_sphere(theta_A_cart, theta_B_cart)
    # The two geodesic circles overlap when their centres are closer than
    # the sum of their radii.
    overlap = geodesic_dist < (radius_A + radius_B)
    if overlap:
        action = random_action_on_circle(theta_A_cart, radius_A)
    else:
        action = action_minimizing_geodesic_distance(theta_A_cart, radius_A, theta_B_cart)
    if action is None:
        # Previously swallowed silently; surface it so callers notice.
        warnings.warn("gisa_algorithm: no candidate action was found; returning None.")
    return action, radius_A, radius_B

def lin_reg(manifold_points, rewards, reg_lambda=1e-6):
    """Fit ridge-regularised least-squares weights ``w`` with ``X @ w ~= rewards``.

    Rows where any feature or any reward entry is NaN/Inf are dropped, with a
    warning, before solving ``(X^T X + reg_lambda * I) w = X^T y`` through
    ``torch.pinverse``.

    Args:
        manifold_points: 2-D tensor ``X`` of shape ``(n, d)``.
        rewards: Tensor ``y`` with ``n`` rows, e.g. ``(n,)`` or ``(n, 1)``.
        reg_lambda: Ridge (Tikhonov) regularisation strength.

    Returns:
        The solved weights after ``.squeeze()``, e.g. shape ``(d,)`` for
        ``(n,)`` or ``(n, 1)`` rewards.
    """
    n_rows = manifold_points.size(0)
    # Reduce the reward check over every non-row axis so the mask is always
    # (n,). A column-vector (n, 1) reward would otherwise broadcast against
    # the (n,) feature mask into an (n, n) mask and break the row indexing.
    reward_ok = torch.isfinite(rewards)
    if reward_ok.dim() > 1:
        reward_ok = reward_ok.flatten(1).all(dim=1)
    mask = torch.isfinite(manifold_points).all(dim=1) & reward_ok
    if mask.sum() < manifold_points.size(0):
        warnings.warn(f"Removed {manifold_points.size(0) - mask.sum().item()} rows due to NaNs or Infs.")
    manifold_points_clean = manifold_points[mask]
    rewards_clean = rewards[mask]
    XtX = manifold_points_clean.T @ manifold_points_clean
    # Build the identity on XtX's dtype/device so CUDA and float64 inputs work.
    eye = torch.eye(XtX.size(0), dtype=XtX.dtype, device=XtX.device)
    XtX_inv = torch.pinverse(XtX + reg_lambda * eye)
    XtY = manifold_points_clean.T @ rewards_clean
    weights = XtX_inv @ XtY
    return weights.squeeze()
