import numpy as np
from scipy.spatial.distance import euclidean
from numpy.linalg import norm
import torch
import warnings
from common.variables import OUTPUT_DIM

def spherical_to_cartesian(r, theta, phi):
    """Convert 3-D spherical coordinates to a Cartesian vector.

    Args:
        r: Radius.
        theta: Polar angle, measured from the z axis (z = r * cos(theta)).
        phi: Azimuthal angle in the x-y plane.

    Returns:
        np.ndarray ``[x, y, z]``.
    """
    sin_theta = np.sin(theta)
    return np.array([
        r * sin_theta * np.cos(phi),
        r * sin_theta * np.sin(phi),
        r * np.cos(theta),
    ])

def random_action_on_circle_N(center, radius, n_dims):
    """Sample a point around an N-D *center* using random per-angle directions.

    Draws ``n_dims`` angles uniformly from [0, 2*pi) and forwards them to
    ``point_on_geodesic_circle_N``.

    Args:
        center: Cartesian center point.
        radius: Offset magnitude passed through to the circle helper.
        n_dims: Number of random angles to draw; presumably this should equal
            ``len(center) - 1`` so the offsets line up with the spherical angles
            of ``center`` — TODO confirm.
    """
    return point_on_geodesic_circle_N(
        center,
        radius,
        np.random.uniform(0, 2 * np.pi, size=n_dims),
    )

def point_on_geodesic_circle_N(center, radius, angles):
    """Offset the spherical angles of *center* and return the resulting unit vector.

    Each spherical angle of ``center`` (from ``cartesian_to_spherical_N``) is
    shifted by ``radius * cos(angles)``. ``angles`` may be a scalar (every
    angle gets the same shift) or an array that broadcasts against the
    ``len(center) - 1`` spherical angles.

    NOTE(review): the shift is applied in angle space, so the returned points
    only approximate a circle of geodesic radius ``radius``; confirm intended.

    Returns:
        Unit-norm Cartesian vector of length ``len(center)``.
    """
    base_angles = np.array(cartesian_to_spherical_N(center))
    shifted_angles = base_angles + radius * np.cos(angles)
    cartesian = spherical_to_cartesian_N(1, *shifted_angles)
    return normalize_to_sphere(cartesian)

def cartesian_to_spherical_N(cartesian_coords):
    """Convert an N-D Cartesian vector to N-1 hyperspherical angles.

    This is the inverse of ``spherical_to_cartesian_N``, which uses
        x_0     = r cos(a_0)
        x_k     = r sin(a_0) ... sin(a_{k-1}) cos(a_k)     (0 < k < N-1)
        x_{N-1} = r sin(a_0) ... sin(a_{N-2})
    Hence a_k = arccos(x_k / ||x_{k:}||) for k < N-2, and the final angle is
    atan2(x_{N-1}, x_{N-2}). The radius is not returned.

    Args:
        cartesian_coords: 1-D sequence/array of N coordinates.

    Returns:
        List of N-1 angles (empty for N < 2). When a trailing sub-vector is
        all zeros its angle is undefined; 0.0 is used instead of NaN.
    """
    x = np.asarray(cartesian_coords, dtype=float)
    n = x.shape[0]
    spherical_coords = []
    for i in range(n - 2):
        tail_norm = np.linalg.norm(x[i:])
        if tail_norm > 0:
            # Clip: round-off can push the ratio marginally outside [-1, 1].
            theta = np.arccos(np.clip(x[i] / tail_norm, -1.0, 1.0))
        else:
            theta = 0.0
        spherical_coords.append(theta)
    if n >= 2:
        # The last angle spans the full circle, so atan2 keeps its sign.
        spherical_coords.append(np.arctan2(x[-1], x[-2]))
    return spherical_coords

def spherical_to_cartesian_N(r, *angles):
    """Convert a radius and N-1 hyperspherical angles to an N-D Cartesian vector.

    Uses
        x_0     = r cos(a_0)
        x_k     = r sin(a_0) ... sin(a_{k-1}) cos(a_k)
        x_{N-2} = r sin(a_0) ... sin(a_{N-3}) cos(a_{N-2})
        x_{N-1} = r sin(a_0) ... sin(a_{N-3}) sin(a_{N-2})

    Args:
        r: Radius.
        *angles: At least one angle; with none, ``angles[-1]`` raises IndexError.

    Returns:
        np.ndarray of length ``len(angles) + 1``.
    """
    # sine_prefix[k] == r * sin(a_0) * ... * sin(a_{k-1}), built left to right.
    sine_prefix = [r]
    for angle in angles[:-1]:
        sine_prefix.append(sine_prefix[-1] * np.sin(angle))

    coords = [scale * np.cos(angle) for scale, angle in zip(sine_prefix, angles[:-1])]

    # The last two coordinates share every sine factor except the final angle.
    shared = sine_prefix[-1]
    coords.append(shared * np.cos(angles[-1]))
    coords.append(shared * np.sin(angles[-1]))
    return np.array(coords)

def geodesic_distance_on_sphere(p1, p2):
    """Return the angle (radians) between vectors *p1* and *p2*.

    For unit vectors this is the great-circle distance on the unit sphere.
    The cosine is clipped to [-1, 1] because floating-point round-off can
    push it slightly outside that range (e.g. when p1 == p2), where
    ``np.arccos`` would return NaN.
    """
    cos_angle = np.dot(p1, p2) / (norm(p1) * norm(p2))
    return np.arccos(np.clip(cos_angle, -1.0, 1.0))

def normalize_to_sphere(point):
    """Scale *point* to unit Euclidean length (projects it onto the unit sphere)."""
    length = norm(point)
    return point / length

def point_on_geodesic_circle(center, radius, angle):
    """Return a unit vector offset from a 3-D *center* in the given direction.

    The center's polar angle (from the z axis) is shifted by
    ``radius * cos(angle)`` and its azimuth by ``radius * sin(angle)``.

    NOTE(review): the azimuth shift is not scaled by 1/sin(theta), so this
    only approximates a geodesic circle of radius ``radius``; confirm intended.

    Args:
        center: Cartesian 3-vector. It need not be unit length; only its
            direction is used.
        radius: Angular offset magnitude.
        angle: Direction of the offset around the center, in radians.

    Returns:
        Unit-norm Cartesian 3-vector.
    """
    # Normalize and clip so non-unit centers or round-off do not make arccos NaN.
    center_theta = np.arccos(np.clip(center[2] / norm(center), -1.0, 1.0))
    center_phi = np.arctan2(center[1], center[0])
    new_theta = center_theta + radius * np.cos(angle)
    new_phi = center_phi + radius * np.sin(angle)
    point = spherical_to_cartesian(1, new_theta, new_phi)
    return normalize_to_sphere(point)

def random_action_on_circle(center, radius):
    """Return a point around a 3-D *center* in a uniformly random direction.

    Draws one direction angle uniformly from [0, 2*pi) and delegates to
    ``point_on_geodesic_circle``.
    """
    direction = np.random.uniform(0, 2 * np.pi)
    return point_on_geodesic_circle(center, radius, direction)

def action_minimizing_geodesic_distance(theta_A_cart, radius_A, theta_B_cart, num_points=100):
    """Grid-search the circle around A for the point closest to B.

    Samples ``num_points`` evenly spaced directions in [0, 2*pi) around
    ``theta_A_cart`` and returns the candidate with the smallest angular
    distance to ``theta_B_cart``. Ties keep the first candidate found.

    Args:
        theta_A_cart: Cartesian center of the circle.
        radius_A: Circle radius forwarded to the circle helper.
        theta_B_cart: Cartesian target point.
        num_points: Number of directions to sample.

    Returns:
        The best Cartesian candidate. Returns None when ``num_points`` is 0
        or every distance is NaN.
    """
    # OUTPUT_DIM is a module-level setting; the dispatch is loop-invariant.
    sample_circle = point_on_geodesic_circle if OUTPUT_DIM <= 3 else point_on_geodesic_circle_N
    best_action = None
    min_distance = np.inf
    # endpoint=False: the angle 2*pi is the same direction as 0, so including
    # it would spend one of the num_points samples on a duplicate.
    for angle in np.linspace(0, 2 * np.pi, num_points, endpoint=False):
        candidate = sample_circle(theta_A_cart, radius_A, angle)
        distance = geodesic_distance_on_sphere(candidate, theta_B_cart)
        if distance < min_distance:
            min_distance = distance
            best_action = candidate
    return best_action

def gisa_algorithm(t, theta_A, theta_B, C_t, D = 3):
    """Choose an action for A given A's and B's angular positions at step *t*.

    Both circle radii are obtained by calling ``C_t(t)``, once for A and once
    for B. A and B are converted to Cartesian form. If the angular distance
    between them is less than the sum of the radii, the circles overlap and a
    random point on A's circle is returned. Otherwise the point on A's circle
    closest to B is returned.

    Args:
        t: Step value passed to ``C_t``.
        theta_A, theta_B: Spherical angles. When ``D <= 3`` only the first two
            entries (polar, azimuth) are used; otherwise all entries are
            passed to ``spherical_to_cartesian_N``.
        C_t: Callable mapping ``t`` to a radius.
        D: Selects the 3-D or N-D coordinate conversion.

    Returns:
        ``(action, radius_A, radius_B)``. ``action`` is converted to a list of
        angles via ``cartesian_to_spherical_N`` when its length is
        ``OUTPUT_DIM + 1``; otherwise it is returned as a Cartesian array.

    NOTE(review): the coordinate conversion is chosen by ``D``, but action
    generation is chosen by the module-level ``OUTPUT_DIM``. Callers appear to
    need to keep these consistent; confirm.
    """
    radius_A = C_t(t)
    radius_B = C_t(t)

    if D <= 3:
        cart_A, cart_B = (spherical_to_cartesian(1, angles[0], angles[1]) for angles in (theta_A, theta_B))
    else:
        cart_A, cart_B = (spherical_to_cartesian_N(1, *angles) for angles in (theta_A, theta_B))

    separation = geodesic_distance_on_sphere(cart_A, cart_B)
    circles_overlap = separation < (radius_A + radius_B)

    if not circles_overlap:
        action = action_minimizing_geodesic_distance(cart_A, radius_A, cart_B)
    elif OUTPUT_DIM <= 3:
        action = random_action_on_circle(cart_A, radius_A)
    else:
        action = random_action_on_circle_N(cart_A, radius_A, OUTPUT_DIM)

    if action.shape[0] == OUTPUT_DIM + 1:
        action = cartesian_to_spherical_N(action)
    return action, radius_A, radius_B

def lin_reg(manifold_points, rewards, reg_lambda=1e-6):
    """Fit ridge-regularised least-squares weights, ignoring non-finite rows.

    Solves ``(X^T X + reg_lambda * I) w = X^T y`` via the pseudo-inverse.
    Rows where either ``X`` or ``y`` contains NaN or Inf are dropped first,
    and a warning is emitted if any are removed.

    Args:
        manifold_points: 2-D tensor ``X`` of shape (N, D).
        rewards: Tensor ``y`` whose first dimension is N, e.g. (N,) or (N, 1).
            Any trailing dimensions are reduced per row when building the
            validity mask.
        reg_lambda: Ridge penalty added to the diagonal of ``X^T X``.

    Returns:
        Weight tensor, squeezed (shape (D,) for (N,) or (N, 1) rewards).
    """
    n_rows = manifold_points.size(0)
    row_ok = torch.isfinite(manifold_points).all(dim=1)
    reward_ok = torch.isfinite(rewards)
    if reward_ok.dim() > 1:
        # Reduce per row: an (N, 1) mask must not broadcast with (N,) into (N, N).
        reward_ok = reward_ok.flatten(1).all(dim=1)
    mask = row_ok & reward_ok
    if mask.sum() < n_rows:
        warnings.warn(f"Removed {n_rows - mask.sum().item()} rows due to NaNs or Infs.")
    manifold_points_clean = manifold_points[mask]
    rewards_clean = rewards[mask]
    XtX = manifold_points_clean.T @ manifold_points_clean
    # Match dtype/device so CUDA and float64 inputs work.
    identity = torch.eye(XtX.size(0), dtype=XtX.dtype, device=XtX.device)
    XtX_reg = XtX + reg_lambda * identity
    XtX_inv = torch.pinverse(XtX_reg)
    XtY = manifold_points_clean.T @ rewards_clean
    weights = XtX_inv @ XtY
    return weights.squeeze()
