import numpy as np 
import random

# When acting, an arm is more likely to move to the good state (1) than to the bad state (0):
# for each arm, state s and acting action a >= 1, assignBetter draws T(s,a,1) from [0.5, 1] and sets T(s,a,0) = 1 - T(s,a,1).
def assignBetter(T,N,S,K):
    """Give every acting action a probability of at least 0.5 of reaching state 1.

    For each arm in range(N), state in range(S) and action in 1..K, the entry
    T[arm][state][action][1] is drawn with random.randint(50000, 100000) / 100000,
    i.e. from the grid 0.5, 0.50001, ..., 1.0. Entry [0] is then set to
    1 minus the stored value. Action 0 is not touched.

    T is modified in place and also returned.

    NOTE(review): only next-states 0 and 1 are rewritten. For S > 2 the other
    next-state entries keep their previous values, so rows generally stop
    summing to 1. Confirm callers rely on this only with S == 2.
    """
    for arm in range(N):
        for state in range(S):
            for action in range(1, K + 1):
                row = T[arm][state][action]
                row[1] = random.randint(50000, 100000) / 100000
                # Read the stored value back so the complement matches T's dtype.
                row[0] = 1 - row[1]
    return T

# Acting has a better chance of leading to the good state (1) than not acting (action 0):
# for each arm, state s and acting action a >= 1: T(s,a,1) > T(s,0,1)
def satisfiesBetterAct(T,N,S,K):
    """Return whether every acting action beats the passive action for every arm.

    The condition checked for each arm in range(N), state in range(S) and
    action in 1..K is the strict inequality
    T[arm][state][action][1] > T[arm][state][0][1].

    The result is np.all over all comparisons. It is vacuously True when
    there are no comparisons (e.g. K == 0).
    """
    return np.all([
        T[arm][state][action][1] > T[arm][state][0][1]
        for arm in range(N)
        for state in range(S)
        for action in range(1, K + 1)
    ])

def satisfiesBetterAct_indiv(T,S,K):
    """Single-arm version of satisfiesBetterAct, checking the arm stored at T[0].

    Returns np.all of T[0][state][action][1] > T[0][state][0][1] over every
    state in range(S) and acting action in 1..K. The result is vacuously
    True when there are no acting actions or no states.

    Previously the loop overwrote its result on every iteration, so only the
    last (state, action) pair was actually checked. It also raised
    UnboundLocalError when S == 0 or K == 0.
    """
    return np.all([
        T[0][state][action][1] > T[0][state][0][1]
        for state in range(S)
        for action in range(1, K + 1)
    ])

def actionTypeTransitions(N,S,K):
    """Sample a transition tensor of shape (N, S, K+1, S) by rejection sampling.

    Each attempt does three things:
    - draws uniform weights with np.random.rand;
    - normalizes them over the last axis so each (arm, state, action) row
      sums to 1;
    - applies assignBetter to the acting actions.

    The first attempt accepted by satisfiesBetterAct is returned. The result
    depends on the global np.random and random states (assignBetter draws
    from random).
    """
    while True:
        weights = np.random.rand(N,S,K+1,S)
        candidate = weights / weights.sum(axis=3, keepdims=True)
        candidate = assignBetter(candidate,N,S,K)
        if satisfiesBetterAct(candidate,N,S,K):
            return candidate

def actionTypeTransitions_indiv(S,K):
    """Sample a single-arm transition tensor of shape (1, S, K+1, S).

    Each attempt does three things:
    - draws uniform weights;
    - normalizes them over the last (next-state) axis;
    - applies assignBetter with N=1.

    Attempts repeat until satisfiesBetterAct_indiv accepts the tensor, which
    is then returned.
    """
    while True:
        weights = np.random.rand(1,S,K+1,S)
        candidate = weights / weights.sum(axis=3, keepdims=True)
        candidate = assignBetter(candidate,1,S,K)
        if satisfiesBetterAct_indiv(candidate,S,K):
            return candidate

def actionTypeTransitions_all(N,S,K):
    """Sample N arms independently with actionTypeTransitions_indiv.

    The N per-arm (1, S, K+1, S) tensors are combined into one array of
    shape (N, S, K+1, S). Going through np.array(...).reshape keeps N == 0
    valid.
    """
    per_arm = [actionTypeTransitions_indiv(S,K) for _ in range(N)]
    return np.array(per_arm).reshape(N,S,K+1,S)

def ordered_worker_T_2state(N, A, always_positive_index=True):
    """Sample ordered 2-state transition tensors for N arms with A actions.

    Returns an array of shape (N, 2, A, 2) indexed [arm, state, action, next_state].

    Passive action 0:
        The probability of moving to next-state 0 is drawn from [0.5, 1).
        It is sorted in descending order across the current-state axis, so
        state 0 gets the largest value.

    Actions 1..A-1:
        Probabilities of moving to next-state 0 are drawn from [0, 1).
        If always_positive_index is True, they are first scaled by the
        passive probability for the same (arm, state), so they never exceed
        it. They are then sorted in ascending order along the action axis.

    In both cases the next-state 1 entry is the complement, so every row
    sums to 1. The draws come from the global np.random state.
    """
    S = 2
    T = np.zeros((N, S, A, S))

    # Passive action: P(next = 0) in [0.5, 1), largest for current state 0.
    passive = np.flip(np.sort(0.5 + 0.5 * np.random.rand(N, S), axis=-1), axis=-1)
    T[:, :, 0, 0] = passive
    T[:, :, 0, 1] = 1 - passive

    # Active actions: P(next = 0) drawn per (arm, state, action).
    active = np.random.rand(N, S, A - 1)
    if always_positive_index:
        # Use the passive probability as an upper bound for every active action.
        active = active * passive[:, :, np.newaxis]
    # Ascending along the action axis: higher-index actions get larger P(next = 0).
    active = np.sort(active, axis=-1)

    T[:, :, 1:, 0] = active
    T[:, :, 1:, 1] = 1 - active
    return T

def counter_example_T(N):
    """Return a fixed 3-state, 3-action transition tensor replicated for N arms.

    P0, P1 and P2 are the 3x3 matrices for actions 0, 1 and 2. Entry
    P_a[s, s'] is the probability of moving from state s to state s'.
    The per-arm tensor is laid out as [state, action, next_state], so the
    result has shape (N, 3, 3, 3).

    NOTE(review): for N == 0 this returns an empty array of shape (0,),
    not (0, 3, 3, 3).
    """
    P0 = np.array([
        [0.95, 0.05, 0.0],
        [0.75, 0.25, 0.0],
        [0.0, 0.25, 0.75],
    ])
    P1 = np.array([
        [0.05, 0.95, 0.0],
        [0.75, 0.25, 0.0],
        [0.0, 0.24, 0.76],
    ])
    P2 = np.array([
        [0.90, 0.10, 0.0],
        [0.0, 0.05, 0.95],
        [0.0, 0.25, 0.75],
    ])

    # Stack along axis 1 so that per_arm[s, a, s'] == P_a[s, s'].
    per_arm = np.stack([P0, P1, P2], axis=1)
    return np.array([per_arm] * N)