from math import ceil
import numpy as np
from mwrmab.algos.whittle_binary_search import adjust_indexes

def workersOrdering(indexes, tau, unallocated, M, breakties=False):
    """Pick each worker's next candidate arm and rank the workers by it.

    Parameters
    ----------
    indexes : array indexed as ``indexes[worker, arm]``.
    tau : one preference list per worker; the last entry of ``tau[j]`` that
        is still in ``unallocated`` becomes worker ``j``'s candidate.
    unallocated : collection of arms that have not been assigned yet.
    M : number of workers.
    breakties : when true (and there are at least ``M`` unallocated arms),
        workers competing for the same arm are separated: the worker with
        the lower index for that arm falls back to its next candidate.

    Returns
    -------
    sigma : numpy array of worker ids ordered by decreasing index of their
        candidate arm; workers without a candidate are left out.
    bestNext : list of length ``M`` with each worker's candidate arm, or
        ``None`` when the worker has none.
    """
    def last_or_none(candidates):
        return candidates[-1] if candidates else None

    # Next best arm still available for every worker.
    bestNext = [
        last_or_none([arm for arm in tau[worker] if arm in unallocated])
        for worker in range(M)
    ]

    # Break worker-arm ties: if the next best arm is the same for more than
    # one worker, the arm stays with the worker that has the higher index.
    #
    #                w1      w2     w3
    # best arm      1(5)    1(4)    2(3)
    # 2nd best arm          2(2)
    # 3rd                   None
    #
    # sigma = (w1,w3)
    if breakties and len(unallocated) >= M:
        contested = []
        claimed = [arm for arm in bestNext if arm is not None]
        while len(claimed) != len(set(claimed)):
            for first in range(M):
                for second in range(first + 1, M):
                    # Re-read on every pair: `first` may just have lost a tie.
                    arm = bestNext[first]
                    if arm is None or not (arm == bestNext[second]):
                        continue
                    contested.append(arm)
                    if indexes[first, arm] < indexes[second, arm]:
                        loser = first
                    else:
                        loser = second
                    fallback = [
                        other for other in tau[loser]
                        if (other in unallocated)
                        and (other not in contested)
                        and other != bestNext[loser]
                    ]
                    bestNext[loser] = last_or_none(fallback)
            claimed = [arm for arm in bestNext if arm is not None]

    # Order workers by the index of their best next unallocated arm.
    # TODO: break ties with lower cost
    sort_keys = [
        0 if bestNext[worker] is None else -1 * indexes[worker, bestNext[worker]]
        for worker in range(M)
    ]
    ranked = np.argsort(sort_keys)

    # Workers with nothing left to assign are dropped from the ordering.
    sigma = np.array([worker for worker in ranked if bestNext[worker] is not None])

    return sigma, bestNext

def greedyAllocation(indexes, costs, N, B, M, breakties=False):
    """Greedily assign arms to workers in rounds under a per-worker budget.

    In every round the workers are ranked by ``workersOrdering`` and, in that
    order, each worker takes its best still-unallocated arm that fits in its
    remaining budget. Arms a worker cannot afford are dropped from that
    worker's preference list. The procedure ends after the round in which
    some worker has no candidate left or cannot afford any remaining arm, or
    as soon as every arm has been allocated.

    Parameters
    ----------
    indexes : array indexed as ``indexes[worker, arm]``; only arms with a
        non-negative index are candidates for a worker.
    costs : array indexed as ``costs[arm, worker + 1]``.
        NOTE(review): column 0 is presumably the "no action" cost -- confirm.
    N : number of arms.
    B : budget available to each worker.
    M : number of workers.
    breakties : forwarded to ``workersOrdering``.

    Returns
    -------
    numpy array of length ``N``; entry ``n`` is ``j + 1`` when arm ``n`` was
    given to worker ``j`` and ``0`` when the arm was left unassigned.
    """
    # Per-worker preference lists: arms with a non-negative index, sorted by
    # increasing index, so the most preferred arm is the last element.
    tau = [list(np.argsort(x)[np.sort(x) >= 0]) for x in indexes]
    unallocated = list(range(N))
    allocation = [[] for _ in range(M)]
    stop = False
    while not stop and unallocated:
        # order workers by best next arm not yet allocated
        sigma, bestNext = workersOrdering(indexes, tau, unallocated, M, breakties)

        # At least one worker has nothing left to pick: this is the last round.
        if len(sigma) < M:
            stop = True

        for j in sigma:
            candidates = [arm for arm in tau[j] if arm in unallocated]
            if not candidates:
                continue

            x = bestNext[j]
            # bestNext was computed before this round started, so a worker
            # earlier in sigma may already have taken x. Fall back to the best
            # arm that is still free instead of allocating x a second time.
            if x not in unallocated:
                x = candidates[-1]

            # Budget already committed by worker j; constant during the search.
            spent = costs[:, j + 1][allocation[j]].sum()
            affordable = True
            while costs[x][j + 1] + spent > B:
                # The budget only shrinks, so x can never become affordable.
                if x in tau[j]:
                    tau[j].remove(x)
                candidates = [arm for arm in tau[j] if arm in unallocated]
                if not candidates:
                    stop = True
                    affordable = False
                    break
                x = candidates[-1]

            if affordable:
                allocation[j] = allocation[j] + [x]
                unallocated.remove(x)
                tau[j].remove(x)

    actions = [0] * N
    for j, alloc in enumerate(allocation):
        for n in alloc:
            actions[n] = j + 1

    return np.array(actions)



''''
def greedyAllocation(results, costs, N, B, K):
    pref = [list(np.argsort(x)[np.sort(x)>=0]) for x in results]
    budget = [B]*K
    pull = [[]]*K
    stop = [False]*K
    # import pdb;pdb.set_trace()
    while not np.any(stop):
        for i in range(K):
            if not np.any(stop):
                if len(pref[i])>0:
                    assigned = False
                    while (not assigned) and (len(pref[i])>0):
                        best = pref[i][-1]
                        if budget[i] - costs[best][i+1] >= 0:
                            # action type i is assigned to best
                            pull[i] = pull[i] + [best]
                            assigned = True
                            pref[i].pop()
                            #update budget
                            budget[i] -= costs[best][i+1]
                            #remove item from the other pref list
                            # TODO: make this faster
                            for j in range(K):
                                if best in pref[j]:
                                    pref[j].remove(best)
                        else:
                            pref[i].pop()
                            if len(pref[i])==0:
                                stop[i]=True    
                else:
                    stop[i]=True

    #Get actions
    pulled = set(sum([pull[k] for k in range(K)],[]))
    notPulled = list(set(range(N)) - pulled)
    pull = [notPulled] + pull
    actions = [np.where([x in y for y in pull])[0][0]for x in range(N)]
    return np.array(actions)
'''

def jointGreedyAllocation(results, costs, N, B, K):
    """Assign action types to arms by ranking all (action type, arm) pairs jointly.

    All indexes are flattened into one vector and visited from highest to
    lowest. A pair is accepted when its index is strictly positive, the arm
    has no action yet, and the action type's remaining budget covers the cost.

    Parameters
    ----------
    results : sequence of ``K`` index vectors, one per action type.
        Assumes each vector has length ``N`` -- TODO confirm with callers.
    costs : indexed as ``costs[arm][action_type]`` with 1-based action types.
    N : number of arms.
    B : budget available to each action type.
    K : number of action types.

    Returns
    -------
    numpy array of length ``N`` holding the chosen action type (1..K) for
    each arm, or ``0`` when no action is taken on it.
    """
    budget = [B] * K
    actions = [0] * N
    indexes = np.concatenate(results)
    # Flattened positions, best index first.
    pref = np.flip(np.argsort(indexes))
    for i in pref:
        # The vectors are concatenated action type by action type, so flat
        # position i is action type i // N (made 1-based below) on arm i % N.
        # The budget check previously looked up floor(i / K), i.e. the cost
        # of a different arm than the one being assigned.
        type_slot, arm = divmod(int(i), N)
        action_type = type_slot + 1
        if actions[arm] != 0 or not indexes[i] > 0:
            continue
        price = costs[arm][action_type]
        if budget[type_slot] - price >= 0:
            actions[arm] = action_type
            budget[type_slot] -= price
    return np.array(actions)

def printSummary(pull, budget, B, K, N):
    """Print which arms each action type acts on, its cost, and the idle arms.

    ``pull[k]`` holds the arms handled by action type ``k`` and ``budget[k]``
    its remaining budget, so the reported cost is ``B - budget[k]``. Arms in
    ``range(N)`` that appear in no ``pull[k]`` are listed as having no action.
    """
    arms_acted_on = set()
    for k in range(K):
        arms_acted_on.update(pull[k])

    for k in range(K):
        spent = B - budget[k]
        print(f'Action type {k} acts on arms {pull[k]} with cost {spent}')

    idle_arms = list(set(range(N)) - arms_acted_on)
    print(f'No action on arms {idle_arms}')

def randomUntilBudgetOld(costs, B, K, N):
    """Draw uniformly random action vectors until one respects every budget.

    Each draw assigns every arm an action in ``0..K`` (``0`` meaning no
    action). The draw is accepted when, for every action type ``k``, the
    summed ``costs[arm, k]`` over the arms given action ``k`` is at most ``B``.

    Parameters
    ----------
    costs : 2-D array indexed as ``costs[arm, action]``.
    B : budget available to each action type.
    K : number of action types.
    N : number of arms.

    Returns
    -------
    numpy integer array of length ``N`` with the accepted actions.
    """
    # Draw first, then check: with K == 0 there is nothing to verify and the
    # all-zero draw is returned instead of hitting an unbound `actions`.
    while True:
        actions = np.random.randint(0, K + 1, N, dtype=int)
        spent = [costs[actions == k + 1, k + 1].sum() for k in range(K)]
        if all(s <= B for s in spent):
            return actions

def randomUntilBudget(costs, B, K, N):
    """Randomly assign actions to arms, one arm at a time, within budget.

    Arms are visited in a random order. For each arm, one of the action
    types ``1..K`` whose budget still covers ``costs[arm, action]`` is picked
    uniformly at random; an arm with no affordable action keeps action ``0``.
    The sweep ends early once every action type has spent at least ``B``.

    Parameters
    ----------
    costs : 2-D array indexed as ``costs[arm, action]``; ``costs[arm, 1:]``
        must line up with the ``K`` action types.
    B : budget available to each action type.
    K : number of action types.
    N : number of arms.

    Returns
    -------
    numpy integer array of length ``N`` with the action chosen for each arm.
    """
    actions = np.zeros(N, dtype=int)
    limits = np.ones(K) * B
    spent = np.zeros(K)
    action_ids = np.arange(1, K + 1)

    visit_order = np.random.choice(np.arange(N), size=N, replace=False)
    for arm in visit_order:
        affordable = action_ids[costs[arm, 1:] + spent <= limits]
        if len(affordable) == 0:
            continue

        chosen = np.random.choice(affordable)
        actions[arm] = chosen
        spent[chosen - 1] += costs[arm, chosen]

        # Nothing more can be bought once every budget is used up.
        if (spent >= limits).all():
            break
    return actions


