
import numpy as np

def get_Ps_given():
    """Return the uniform two-element probability vector ``[0.5, 0.5]``.

    Presumably this is P(s) for s in {0, 1}; confirm against callers.
    A fresh list is built on every call, so callers may mutate the result.
    """
    half = 0.5
    return [half, half]


def get_Py1x_given(num_cat=4):
    """Return a (2, num_cat) float64 array whose rows increase left to right.

    Row 0 holds ``num_cat`` evenly spaced values from 0.1 to 0.7, and row 1
    holds ``num_cat`` evenly spaced values from 0.1 to 0.9. Both endpoints
    are included, so every row is non-decreasing and lies inside [0.1, 0.9].
    (Presumably row ``i`` is P(y=1 | x) for some binary condition ``i``;
    confirm against callers.)

    Args:
        num_cat: Number of columns (categories) in each row.

    Returns:
        np.ndarray of shape (2, num_cat).
    """
    lower_row = np.linspace(0.1, 0.7, num_cat)
    upper_row = np.linspace(0.1, 0.9, num_cat)
    return np.vstack((lower_row, upper_row))

def _get_Pxx_given(num_cat=4, seed=6):
    """Return a hand-specified row-stochastic tensor of shape (2, 2, 2, 4, 4).

    The tensor is indexed as ``T[s][d][y]``, and each entry is a 4x4 matrix
    whose rows are probability vectors. For ``d = 0`` the ``y = 0`` and
    ``y = 1`` matrices are identical for both values of ``s``. The original
    author's note on this reads "one-sided feedback these should be same".

    Every row ``T[s, d, y, x, :]`` is checked to sum to 1 after rounding to
    4 decimal places. An offending row is reported with a printed warning
    rather than an exception.

    Args:
        num_cat: Currently ignored; the matrices are fixed at 4x4. Kept so
            the signature matches ``get_Pxx_given``.
        seed: Currently ignored; nothing here is random. Kept so the
            signature matches ``get_Pxx_given``.

    Returns:
        float64 np.ndarray of shape (2, 2, 2, 4, 4).
    """
    # s = 0
    # d = 0: one-sided feedback, so the y = 0 and y = 1 matrices must match.
    T000 = np.array([[0.70, 0.1, 0.1, 0.1], [0.1, 0.7, 0.1, 0.1], [0.1, 0.1, 0.7, 0.1], [0.1, 0.1, 0.1, 0.7]])
    T001 = T000.copy()
    # d = 1
    T010 = np.array([[0.85, 0.05, 0.05, 0.05], [0.85, 0.05, 0.05, 0.05], [0.65, 0.25, 0.05, 0.05], [0.55, 0.25, 0.15, 0.05]])
    T011 = np.array([[0.3, 0.55, 0.1, 0.05], [0.1, 0.35, 0.5, 0.05], [0.05, 0.05, 0.3, 0.6], [0.05, 0.05, 0.15, 0.75]])

    # s = 1
    # d = 0: one-sided feedback, so the y = 0 and y = 1 matrices must match.
    T100 = np.array([[0.85, 0.05, 0.05, 0.05], [0.05, 0.85, 0.05, 0.05], [0.05, 0.05, 0.85, 0.05], [0.05, 0.05, 0.05, 0.85]])
    T101 = T100.copy()
    # d = 1
    T110 = np.array([[0.85, 0.05, 0.05, 0.05], [0.65, 0.25, 0.05, 0.05], [0.3, 0.4, 0.25, 0.05], [0.15, 0.25, 0.35, 0.25]])
    T111 = np.array([[0.1, 0.55, 0.3, 0.05], [0.1, 0.1, 0.5, 0.3], [0.05, 0.05, 0.1, 0.8], [0.05, 0.05, 0.05, 0.85]])

    # Nest as T[s][d][y] -> (2, 2, 2, 4, 4).
    T = np.array([
        [[T000, T001], [T010, T011]],
        [[T100, T101], [T110, T111]],
    ])

    # Sanity check: every row must sum to 1 (to 4 d.p.). The index ranges come
    # from T's own shape instead of a hard-coded 4. np.argwhere yields
    # indices in C order, so warnings print in the same s, d, y, x order as
    # nested loops would.
    bad_rows = np.argwhere(np.round(T.sum(axis=-1), 4) != 1)
    for s, d, y, x in bad_rows:
        print("!!!!! Does not sum up", int(s), int(d), int(y), int(x))
    return T

def get_Pxx_given(num_cat=4, seed=6):
    """Return a random (2, 2, 2, num_cat, num_cat) row-stochastic tensor.

    Each of the 8 leading (i, j, k) slots holds a num_cat x num_cat matrix.
    Every row of that matrix is drawn independently from a flat Dirichlet
    distribution (all concentration parameters equal to 1), so each row is
    non-negative and sums to one.

    Side effect: this reseeds NumPy's *global* legacy RNG with ``seed`` via
    ``np.random.seed``. That makes the output reproducible, but it also
    resets the random state for any later ``np.random`` call elsewhere.

    Args:
        num_cat: Size of each square matrix (number of categories).
        seed: Seed passed to ``np.random.seed``.

    Returns:
        float64 np.ndarray of shape (2, 2, 2, num_cat, num_cat).
    """
    np.random.seed(seed)

    concentration = np.ones(num_cat)
    tensor = np.zeros((2, 2, 2, num_cat, num_cat))

    # np.ndindex walks (i, j, k) in C order, so rows are drawn from the RNG
    # in exactly the same sequence as nested i/j/k/row loops.
    for slot in np.ndindex(2, 2, 2):
        matrix = np.zeros((num_cat, num_cat))
        for row in range(num_cat):
            matrix[row] = np.random.dirichlet(concentration, size=1)

        # Internal sanity check on the sampled rows.
        assert np.allclose(np.sum(matrix, axis=1), 1), "rows doe not sum to one"
        tensor[slot] = matrix
    return tensor



def get_Px_given(num_cat=4, seed=6):
    """Return a random (2, num_cat) array whose rows are probability vectors.

    Each of the two rows is drawn independently from a flat Dirichlet
    distribution (all concentration parameters equal to 1), so each row is
    non-negative and sums to one.

    Side effect: this reseeds NumPy's *global* legacy RNG with ``seed`` via
    ``np.random.seed``.

    Args:
        num_cat: Length of each row (number of categories).
        seed: Seed passed to ``np.random.seed``.

    Returns:
        float64 np.ndarray of shape (2, num_cat).
    """
    np.random.seed(seed)
    concentration = np.ones(num_cat)
    # Each draw has shape (1, num_cat); stacking two of them gives (2, num_cat).
    # Row 0 is drawn before row 1, matching the original draw order.
    draws = [np.random.dirichlet(concentration, size=1) for _ in range(2)]
    return np.vstack(draws)
#%%


