import numpy as np
import pandas as pd
from scipy.spatial.distance import jensenshannon

def JS(n1, n2):
    """Compare the y- and a-marginals of two 4-cell count vectors.

    Cells 0 and 1 are pooled into the first "y" bin and cells 2 and 3 into
    the second; cells 0 and 2 are pooled into the first "a" bin and cells 1
    and 3 into the second. NOTE(review): this grouping suggests a layout of
    ``[(y0,a0), (y0,a1), (y1,a0), (y1,a1)]`` — confirm against callers.

    Each pooled pair is normalized to a probability vector and compared with
    ``scipy.spatial.distance.jensenshannon`` using ``base=2``. That function
    returns the Jensen-Shannon *distance* (square root of the divergence),
    which lies in ``[0, 1]`` for base 2.

    Args:
        n1: indexable sequence of at least 4 counts (first distribution).
        n2: indexable sequence of at least 4 counts (second distribution).

    Returns:
        Tuple ``(js_y, js_a)``: JS distance between the y-marginals and
        between the a-marginals. A marginal whose counts sum to zero yields
        NaN (from the 0/0 normalization).
    """
    def _pooled_distance(pairs):
        # Collapse each count vector into one bin per (i, j) pair, normalize,
        # then measure the base-2 JS distance between the two results.
        normalized = []
        for counts in (n1, n2):
            pooled = np.array([counts[i] + counts[j] for i, j in pairs])
            normalized.append(pooled / pooled.sum())
        return jensenshannon(normalized[0], normalized[1], base=2)

    y_pairs = ((0, 1), (2, 3))
    a_pairs = ((0, 2), (1, 3))
    js_y = _pooled_distance(y_pairs)
    js_a = _pooled_distance(a_pairs)
    return js_y, js_a

def cond_JS(n1, n2):
    """Weighted JS distance between the per-a conditionals of two count vectors.

    Cells 0 and 2 form the "a=0" group and cells 1 and 3 form the "a=1"
    group. NOTE(review): this suggests a layout of
    ``[(y0,a0), (y0,a1), (y1,a0), (y1,a1)]``, i.e. each group is P(y | a) —
    confirm against callers.

    For each group, both count pairs are normalized and compared with
    ``jensenshannon(..., base=2)``. That function returns the JS *distance*
    (square root of the divergence), not the divergence itself. The
    per-group distances are averaged, weighted by the group sizes in ``n2``
    only: ``(d0 * (n2[0] + n2[2]) + d1 * (n2[1] + n2[3])) / sum(n2)``.

    Args:
        n1: indexable sequence of 4 counts (first distribution).
        n2: indexable sequence of 4 counts (second distribution; also
            supplies the weights).

    Returns:
        The weighted average distance. Groups with zero weight in ``n2`` are
        skipped. NaN is returned if ``n2`` has no counts at all, or if ``n1``
        has an empty group that ``n2`` does not (its distance is undefined).
    """
    total_weight = sum(n2)
    if total_weight == 0:
        # Nothing to weight by; the average is undefined.
        return np.nan

    weighted_sum = 0.0
    for cells in ((0, 2), (1, 3)):
        weight = n2[cells[0]] + n2[cells[1]]
        if weight == 0:
            # Group absent from n2: it would contribute 0 to the average, but
            # its distance is NaN (0/0 normalization) and NaN * 0 is NaN,
            # which would poison the whole result. Skip it instead.
            continue
        counts1 = np.array([n1[i] for i in cells])
        counts2 = np.array([n2[i] for i in cells])
        p1 = counts1 / counts1.sum()
        p2 = counts2 / counts2.sum()
        weighted_sum += jensenshannon(p1, p2, base=2) * weight

    return weighted_sum / total_weight