import decoupled_methods 
import initialize
import evaluation
import hawkins_methods
import combinatorial_mdp
import random

import numpy as np 
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from whittle_binary_search_v3 import binary_search_all_arms, adjust_indexes


# ---- Default problem dimensions and constants ----
N = 5 # number of arms
S = 2 # number of states
A = 2 # number of actions
HB = 1 # indexes should be the same regardless of budget, but make sure it is non-negative
LB = 0
K = 2 # number of action types
a_index = 1 # for binary-action case, a_index is always 1
gamma = 0.95

# Randomly sample one start state per arm. States are drawn from range(S) so
# that every sampled state is a valid index into the S-column reward rows
# below (sampling from a hard-coded [0,1,2] could yield state 2 with S = 2).
start_state = np.random.choice(S, size=N, replace=True)

# Make the reward function something simple, only based on current state:
# reward 1 in the last state and 0 in every other state, one row per arm.
R = np.array([[0] * (S - 1) + [1] for _ in range(N)])

# Costs for each (action type, arm). Column 0 has cost 0; columns 1 and 2
# hold the per-arm costs c1 and c2.
c1 = [1] * N
c2 = [1] * N
costs = np.array([[0, cost_1, cost_2] for cost_1, cost_2 in zip(c1, c2)])

# Sample a transition probability matrix for each arm
# T[0][0] holds the transition probabilities (TP) for arm 0 when in state 0. Action index 0 means not acting; the remaining indices are the K action types.
#T = initialize.actionTypeTransitions(N,S,K)
#T_i = np.array([P0, P1, P2])
#T_i = np.swapaxes(T_i, 0, 1)
#T = np.array([T_i for _ in range(N)])

# ---- Experiment configuration ----
# Each decoupled algorithm maps to [jointAllocation, adjusted]; the two flags
# are forwarded to decoupled_methods.getDecoupledActions.
params = {
    'MWRMAB_adj': [False,True],
    'MWRMAB': [False, False]
}

EPOCHS = 25
STEPS = 50
# The number of arms is re-drawn every epoch from [MIN_ARMS, MAX_ARMS].
MIN_ARMS = 3
MAX_ARMS = 8
algos = list(params.keys()) + ['OPT_fair','OPT','hawkins','random','no_action']

# Result buffers, one entry per algorithm. They are NaN-filled rather than
# left uninitialised so that unwritten cells are recognisable. allActions is
# MAX_ARMS wide because the arm count changes per epoch; for an epoch with
# N arms only the first N columns are written and the rest stay NaN.
rewards = {}
allActions = {}
allBudget = {}
for algo in algos:
    rewards[algo] = np.full((EPOCHS, STEPS), np.nan)
    allActions[algo] = np.full((EPOCHS, STEPS, MAX_ARMS), np.nan)
    allBudget[algo] = np.full((EPOCHS, STEPS, K), np.nan)

for epoch in tqdm(range(EPOCHS)):
    # Draw a fresh problem instance for this epoch.
    N = random.randint(MIN_ARMS, MAX_ARMS)
    # Integer budget: at least 1, otherwise a quarter of the arms (rounded down).
    HB = max(1, N // 4)
    c1 = [1] * N
    c2 = [1] * N
    costs = np.array([[0, c1[i], c2[i]] for i in range(N)])
    R = np.array([[0, 1] for _ in range(N)])

    T = initialize.actionTypeTransitions(N, S, K)  # N,S,A,S

    # Start states must match this epoch's arm count and state space, so they
    # are sampled here instead of reusing a vector sampled for a different N.
    start_state = np.random.choice(S, size=N, replace=True)

    # Compute all per-arm indexes by binary search: one value for every
    # (arm, state, action type) triple.
    index_lb = -1
    index_ub = 1
    tolerance = 1e-4
    per_arm_indexes = np.zeros((N, S, K))
    for s in range(S):
        # Evaluate every arm as if it were currently in state s.
        current_state = np.array([s] * N)
        # NOTE: this loop rebinds the module-level name a_index.
        for a_index in range(1, K + 1):
            # be sure T is in N,S,A,S
            per_arm_indexes[:, s, a_index - 1] = binary_search_all_arms(
                T, R, costs, current_state, a_index,
                index_lb=index_lb, index_ub=index_ub,
                gamma=gamma, tolerance=tolerance,
            )

    # A new CombMDP object is needed at the beginning of every epoch.
    # NOTE(review): the original comment here asked for T in shape NxAxSxS,
    # while T is described as N,S,A,S above — confirm which layout
    # CombMDP.make_mdp expects.
    comb_mdp_fair = combinatorial_mdp.CombMDP()
    comb_mdp_fair.make_mdp(T, R, costs, HB, LB, fairness_epsilon=None)
    comb_mdp_fair.value_iteration(gamma)
    comb_mdp_fair.enumerate_policy()

    comb_mdp = combinatorial_mdp.CombMDP()
    x = comb_mdp.make_mdp(T, R, costs, HB, LB, fairness_epsilon=99999999)
    comb_mdp.value_iteration(gamma)

    # Every algorithm starts the epoch from its own copy of the same states.
    current_states = {algo: start_state.copy() for algo in algos}

    for i in range(STEPS):
        # get actions
        actions = {}

        # compute actions for decoupled algorithms
        for algo, (joint_allocation, adjusted) in params.items():
            actions[algo] = decoupled_methods.getDecoupledActions(
                per_arm_indexes, T, R, HB, costs, current_states[algo],
                jointAllocation=joint_allocation, adjusted=adjusted,
            )

        # compute actions for baselines
        actions['OPT_fair'] = comb_mdp_fair.get_action(current_states['OPT_fair'])
        actions['OPT'] = comb_mdp.get_action(current_states['OPT'])
        actions['hawkins'] = hawkins_methods.get_hawkins_actions(N, T, R, costs, HB, LB, current_states['hawkins'], gamma)
        actions['random'] = decoupled_methods.randomUntilBudget(costs, HB, K, N)
        actions['no_action'] = np.zeros(N, dtype=int)

        # go to next state and get reward
        for algo, action in actions.items():
            allBudget[algo][epoch, i] = decoupled_methods.usedBudget(costs, action, K)
            # Only the first N columns are meaningful for this epoch.
            allActions[algo][epoch, i, :N] = action
            current_states[algo] = evaluation.nextState(action, current_states[algo], S, T)
            # Reward is normalised by the number of arms.
            rewards[algo][epoch, i] = evaluation.getReward(current_states[algo], R) / N

