import numpy as np
import scipy as sp
from typing import Union

def J(n, r, t1, t2, k=10000):
    """
    Surface area of a hyperspherical cap cut by a hyperplane.

    Numerically evaluates

        pi^((n-1)/2) / Gamma((n-1)/2) * r^(n-1)
            * integral_{t1}^{t2} sin(t)^(n-2) * I_x((n-2)/2, 1/2) dt,
        x = 1 - (tan(t1) / tan(t))^2,

    where I_x is the regularized incomplete beta function, using the
    midpoint rule on ``k`` equal slices.

    n: dimension
    r: radius
    t1: theta_1 (minimum angle from the axis), within [0, pi/2]
    t2: theta_2 (maximum angle from the axis), within [t1, pi/2]
    k: number of slices for the numerical integration (default 10000)

    Returns 0 when t1 == t2.
    Raises ValueError if t1 or t2 lies outside [0, pi/2] (or is NaN),
    if t1 > t2, or if k < 1.
    """
    # Written as `not (a <= x <= b)` so that NaN angles are rejected as well.
    if not (0 <= t1 <= np.pi/2 and 0 <= t2 <= np.pi/2):
        raise ValueError(f'Arguments t1 and t2 should lie within [0,pi/2] but are {t1} and {t2}')
    if t1 > t2:
        # The beta-function argument would be negative here and produce NaN.
        raise ValueError(f'Argument t1 should not exceed t2 but they are {t1} and {t2}')
    if k < 1:
        raise ValueError(f'Number of slices k should be at least 1 but is {k}')
    if t1 == t2:
        return 0

    # Midpoint rule: exactly k slice centres, all strictly inside (t1, t2).
    # Unlike np.arange with a float step, the number of points is exact, and
    # tan(t) is never evaluated at 0 or pi/2.
    dt = (t2 - t1) / k
    t = t1 + (np.arange(k) + 0.5) * dt

    # Clip guards against tiny excursions outside [0, 1] from rounding when
    # t is extremely close to t1; betainc would return NaN there.
    x = np.clip(1 - (np.tan(t1) / np.tan(t))**2, 0.0, 1.0)
    I = sp.special.betainc((n-2)/2, 1/2, x)

    integral = np.sum(np.sin(t)**(n-2) * I) * dt

    return integral * np.pi**((n-1)/2) / sp.special.gamma((n-1)/2) * r**(n-1)

def A(n, r, t):
    """
    Surface area of a hyperspherical cap.

    Based on the regularized incomplete beta function
    I_{sin^2 t}((n-1)/2, 1/2). For t <= pi/2 the cap area is half the full
    sphere area times I. For larger t it is the full sphere area minus the
    cap of angle pi - t, which has the same sin^2.

    n: dimension
    r: radius
    t: theta (colatitude angle)
    """
    # pi^(n/2) / Gamma(n/2) * r^(n-1) is half the surface area of the full sphere.
    half_sphere = np.pi**(n/2) / sp.special.gamma(n/2) * r**(n-1)
    ratio = sp.special.betainc((n-1)/2, 0.5, np.sin(t)**2)
    if t <= np.pi/2:
        return half_sphere * ratio
    return 2 * half_sphere * (1 - 0.5 * ratio)

def calculate_intersection(n, theta_v):
    '''
    Surface area of the intersection of two unit hyperspherical caps with
    colatitude angle pi/2 (i.e. two unit hemispheres).

    n: dimension
    theta_v: angle between the two axes, within [0, pi]

    Raises ValueError if theta_v is greater than pi, negative, or NaN.
    '''
    if np.pi < theta_v:
        raise ValueError(f'theta_v={theta_v} > pi')
    # `not (x >= 0)` also rejects NaN. Report the problem in terms of theta_v
    # instead of letting J fail on its derived t1/t2 arguments.
    if not theta_v >= 0:
        raise ValueError(f'theta_v={theta_v} should lie within [0, pi]')
    if theta_v <= np.pi/2:
        # Acute axes: full cap area minus the J region starting at pi/2 - theta_v.
        return A(n, 1, np.pi/2) - J(n, 1, np.pi/2-theta_v, np.pi/2)
    # Obtuse axes: the J region starting at theta_v - pi/2.
    return J(n, 1, theta_v-np.pi/2, np.pi/2)
    
def percentage_intersection(n, theta_v: Union[list, np.ndarray]) -> list:
    '''
    Shared fraction of a unit hemisphere (cap with colatitude pi/2) for each
    angle between the two axes.

    Despite the name, values are fractions, not multiplied by 100.

    n: dimension
    theta_v: iterable of angles between the two axes, each within [0, pi]

    Returns a list of calculate_intersection(n, tv) / A(n, 1, pi/2), one
    entry per element of theta_v, in input order.
    '''
    # The cap area does not depend on the angle, so compute it once.
    cap_area = A(n, 1, np.pi/2)
    return [calculate_intersection(n, tv) / cap_area for tv in theta_v]


def posterior_prob(ij, ik, jk, lmbda = 0.1):
    '''
    Calculates P(u.j>u.k|u.i>u.k) from pairwise dot products.

    The formula assumes i, j, k are unit vectors. In that case
    (ij - ik - jk + 1) / sqrt(4 (1-jk)(1-ik)) is the cosine of the angle
    between i-k and j-k, and the result is 1 - angle/pi. Presumably u is a
    uniformly random direction; confirm with callers.

    ij: i.j
    ik: i.k
    jk: j.k
    lmbda: regularizer added to the denominator. It keeps the quotient
        finite when i == k or j == k. With lmbda=0 that degenerate case
        gives NaN.
    '''
    numerator = ij - ik - jk + 1
    # Dot products that overshoot 1 through rounding would make the radicand
    # negative (complex / NaN square root), so clamp it at 0.
    norm_product = np.sqrt(np.maximum(4*(1-jk)*(1-ik), 0.0))
    cos_angle = numerator / (norm_product + lmbda)
    # With small lmbda, rounding can push |cos_angle| slightly above 1,
    # which is outside arccos's domain.
    cos_angle = np.clip(cos_angle, -1.0, 1.0)
    return 1 - np.arccos(cos_angle)/np.pi