import decoupled_methods 
import initialize
import evaluation
import hawkins_methods
import combinatorial_mdp
import random

import numpy as np 
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from whittle_binary_search_v3 import binary_search_all_arms, adjust_indexes


# ---------------------------------------------------------------------------
# Problem definition
# ---------------------------------------------------------------------------
N = 3  # number of arms
S = 3  # number of states
# NOTE(review): the transition tensor built at the end of this section has
# K + 1 = 3 action slices (no action plus K action types), so the value 2
# looks stale -- confirm before relying on A.
A = 2  # number of actions
HB = 1  # indexes should be the same regardless of budget, but make sure it is non-negative
LB = 0  # passed next to HB to the baseline solvers; presumably the lower budget bound -- TODO confirm
K = 2  # number of action types
a_index = 1  # for binary-action case, a_index is always 1
gamma = 0.95  # forwarded as `gamma` to the solvers; presumably the discount factor

# Every arm starts in state 0 or 1, sampled with replacement.
start_state = np.random.choice([0, 1], size=N, replace=True)

# Keep the reward simple: it depends only on the current state and is 1 in
# state 2, 0 in states 0 and 1 (identical for every arm).
R = np.tile([0, 0, 1], (N, 1))

# Cost of each action per arm: column 0 is "no action" (free), columns 1 and 2
# hold the costs of action types 1 and 2.
c1 = [1] * N
c2 = [1] * N
costs = np.array([[0, cost_1, cost_2] for cost_1, cost_2 in zip(c1, c2)])

# Transition probabilities, shared by all arms.
# T[0][0] is the transition block for arm 0 when in state 0; action index 0 is
# "not acting", the remaining indexes are the K action types.
# Alternative sources for T, kept for reference:
# T = initialize.actionTypeTransitions(N,S,K)
# np.save('./logs/T/T_2.npy', T)
# T = np.load('./logs/T/T_1.npy')

# No action.
P0 = np.array([
    [0.95, 0.05, 0.0],
    [0.75, 0.25, 0.0],
    [0.0, 0.25, 0.75],
])

# Action type 1.
P1 = np.array([
    [0.05, 0.95, 0.0],
    [0.75, 0.25, 0.0],
    [0.0, 0.24, 0.76],
])

# Action type 2.
P2 = np.array([
    [0.90, 0.10, 0.0],
    [0.0, 0.05, 0.95],
    [0.0, 0.25, 0.75],
])

# Stack along axis 1 so that T_i[s, a] is row s of P<a>, then replicate the
# single-arm model for every arm: T has shape (N, S, K + 1, S).
T_i = np.stack([P0, P1, P2], axis=1)
T = np.repeat(T_i[np.newaxis, ...], N, axis=0)

# Decoupled algorithms to evaluate. Each value is a pair of flags
# [jointAllocation, adjusted] that the simulation loop forwards to
# decoupled_methods.getDecoupledActions; commented-out entries are variants
# that are currently switched off.
params = {
    'MWRMAB_adj': [False,True],
    #'MWRMAB_adj_joi': [True,True],
    'MWRMAB': [False, False],
    #'MWRMAB_joi': [True, False] 
}
# Compute all per-arm indexes by binary search, once, before the simulation.
# index_lb / index_ub / tolerance are passed straight to the search helper
# (presumably the search interval and stopping tolerance -- TODO confirm).
index_lb = -1
index_ub = 1
tolerance = 1e-4
# per_arm_indexes[arm, state, action_type - 1]
per_arm_indexes = np.zeros((N,S,K))
for s in range(S):
        # Evaluate every arm as if it were in state s.
        current_state = np.array([s]*N)
        # NOTE: this loop rebinds the module-level a_index defined with the
        # other constants; after the loop a_index == K.
        for a_index in range(1,K+1):
                # be sure T is in N,S,A,S
                per_arm_indexes[:,s,a_index-1] = binary_search_all_arms(T, R, costs, current_state, a_index, index_lb=index_lb, index_ub=index_ub, gamma=gamma, tolerance=tolerance)


# Build the combinatorial-MDP baselines up front: a new CombMDP object is
# needed for each one. Each object is set up with make_mdp, solved with
# value_iteration, and then has its policy enumerated.
#
# NOTE(review): the original note here asked for T in shape NxAxSxS, while the
# note in the binary-search loop asks for N,S,A,S and T is actually built as
# (N, S, K+1, S) -- confirm which layout CombMDP.make_mdp expects.
# NOTE(review): the "fair" baseline passes fairness_epsilon=None and the plain
# one passes a very large epsilon; the meaning of both values is defined in
# combinatorial_mdp, not here -- verify they are not swapped.
comb_mdp_fair = combinatorial_mdp.CombMDP()
comb_mdp_fair.make_mdp(T, R, costs, HB, LB, fairness_epsilon=None)
comb_mdp_fair.value_iteration(gamma)
comb_mdp_fair.enumerate_policy()

comb_mdp = combinatorial_mdp.CombMDP()
# The return value is stored in x but not read before x is reassigned later.
x = comb_mdp.make_mdp(T, R, costs, HB, LB, fairness_epsilon=99999999)
comb_mdp.value_iteration(gamma)
comb_mdp.enumerate_policy()


# Note: only need to run value iteration once for each new T!
# NOTE(review): both objects already ran value_iteration above and T has not
# changed, so these two calls look redundant -- confirm before removing.
comb_mdp_fair.value_iteration(gamma)
comb_mdp.value_iteration(gamma)


# Experiment size: EPOCHS rollouts of STEPS decision steps each.
EPOCHS = 20
STEPS = 20

# Algorithms under comparison: the decoupled variants from `params`, followed
# by the baselines. 'no_action' is deliberately last so that later slices such
# as algos[:-1] can drop it.
algos = list(params.keys()) + ['OPT_fair','OPT','hawkins','random','no_action']
# algos = list(params.keys()) + ['OPT','random','no_action']

# Per-algorithm result buffers, filled by the simulation loop:
#   rewards[algo][epoch, step]            reward collected at that step
#   allActions[algo][epoch, step, arm]    action chosen for each arm
#   allBudget[algo][epoch, step, type]    budget spent per action type
# Allocated with np.zeros instead of the low-level np.ndarray constructor so
# the buffers never contain uninitialized memory.
rewards = {}
allActions = {}
allBudget = {}
for algo in algos:
    rewards[algo] = np.zeros((EPOCHS, STEPS))
    allActions[algo] = np.zeros((EPOCHS, STEPS, N))
    allBudget[algo] = np.zeros((EPOCHS, STEPS, K))

# Main simulation: for every epoch, roll each algorithm forward STEPS steps
# from the same start state and record actions, spent budget and reward.
for epoch in tqdm(range(EPOCHS)):

    #T = initialize.actionTypeTransitions(N,S,K)

    # Each algorithm tracks its own copy of the arm states, so the
    # trajectories evolve independently of one another.
    current_states = {}
    #start_state = np.random.choice([0,1], size=N, replace=True)
    for algo in algos:
        current_states[algo] = start_state.copy()

    for i in range(STEPS):
        # get actions
        actions = {}

        # compute actions for decoupled algorithms
        # (params[algo] holds the [jointAllocation, adjusted] flags)
        for algo in params.keys():
            actions[algo] = decoupled_methods.getDecoupledActions(per_arm_indexes, T, R, HB, costs, current_states[algo], jointAllocation = params[algo][0], adjusted = params[algo][1])
            # if algo=='MWRMAB_adj':
            #     print(actions)
        
        # compute actions for baselines
        actions['OPT_fair'] = comb_mdp_fair.get_action(current_states['OPT_fair'])
        actions['OPT'] = comb_mdp.get_action(current_states['OPT'])
        actions['hawkins'] = hawkins_methods.get_hawkins_actions(N, T, R, costs, HB, LB, current_states['hawkins'], gamma)
        actions['random'] = decoupled_methods.randomUntilBudget(costs, HB, K, N)
        actions['no_action'] = np.zeros(N,dtype=int)

        # go to next state and get reward
        for algo in actions.keys():
            allBudget[algo][epoch,i] = decoupled_methods.usedBudget(costs,actions[algo], K)
            allActions[algo][epoch,i] = actions[algo]
            # The state is advanced first, so the recorded reward is the one
            # of the state reached after acting.
            current_states[algo] = evaluation.nextState(actions[algo], current_states[algo], S, T)
            rewards[algo][epoch,i] = evaluation.getReward(current_states[algo], R)

    
# Reward summary per algorithm: average each epoch over its steps, then take
# the mean and the standard deviation of those per-epoch averages.
algos_plot = algos[:-1]  # every algorithm except the last one ('no_action')
means = []
std = []
for algo in algos_plot:
    epoch_avg = rewards[algo].mean(axis=1)
    means.append(epoch_avg.mean())
    std.append(epoch_avg.std())

# Bar chart of the summary, with the across-epoch std as error bars.
fig, ax2 = plt.subplots(figsize=(8, 4))
ax2.bar(x=algos_plot, height=means, yerr=std)
ax2.set_ylabel('Mean Stepwise Reward')
fig.suptitle(f'N={N} B={HB} M={K} \n Cost 1: {c1}, Cost 2: {c2}')
plt.show()


# For each epoch, the share of steps in which the budgets spent on the two
# action types differ by at most HB; summarised across epochs as mean and std.
# The last plotted algorithm is left out of this chart.
fair_algos = algos_plot[:-1]
fraction = []
fraction_std = []
for algo in fair_algos:
    spent = allBudget[algo]
    within_gap = np.absolute(spent[:, :, 0] - spent[:, :, 1]) <= HB
    # Each qualifying step contributes 1/STEPS to its epoch's total.
    x = (within_gap / STEPS).sum(axis=1)
    fraction.append(x.mean())
    fraction_std.append(x.std())

fig, ax2 = plt.subplots(figsize=(8, 4))
ax2.bar(x=fair_algos, height=fraction, yerr=fraction_std)
ax2.set_ylabel('Fraction of timesteps with fair allocation')
#ax2.set_ylim(0.97,1)
plt.show()


# Two-panel figure: reward curves over the steps on the left, the reward
# summary bars (`means` / `std`) on the right.
fig, (ax1, ax2) = plt.subplots(figsize=(18,5),ncols=2,gridspec_kw={'width_ratios': [1, 2.5]})
for algo in algos_plot:
    # evaluation.ACR presumably returns one average-cumulative-reward curve
    # per epoch; its first column is skipped when plotting -- TODO confirm why.
    acr= evaluation.ACR(rewards[algo], start_state, R)
    sns.lineplot(data=pd.DataFrame(acr[:,1:]).melt(), x='variable',y='value',label=algo,ax=ax1)
ax1.legend(loc=4)
ax1.set_ylabel('Average cumulative reward')
ax1.set_xlabel('Step')
ax2.bar(x=algos_plot, height=means, yerr=std)
# `means` holds the mean reward per step (averaged over steps and epochs), not
# a sum over the horizon, so the axis is labelled accordingly.
ax2.set_ylabel(f'Mean Stepwise Reward ({STEPS} Steps)')
fig.suptitle(f'N={N} HB={HB} LB={LB} K={K} \n Cost 1: {c1}, Cost 2: {c2}')
plt.show()


# Absolute difference between the budgets spent on the two action types.
# The difference is averaged over epochs for every step; the bars show the
# mean of those per-step averages and the error bars their std across steps.
# From here on all algorithms (including 'no_action') are plotted.
algos_plot = algos
means = []
std = []
for algo in algos_plot:
    spent = allBudget[algo]
    gap_per_step = np.absolute(spent[:, :, 0] - spent[:, :, 1]).mean(axis=0)
    means.append(gap_per_step.mean())
    std.append(gap_per_step.std())

fig, ax2 = plt.subplots(figsize=(10, 5))
ax2.bar(x=algos_plot, height=means, yerr=std)
ax2.set_ylabel('(max spent budget) - (min spent budget)')
#fig.suptitle(f'N={N} B={HB} M={K} \n Cost 1: {c1}, Cost 2: {c2}')
plt.show()

# Difference in used budget: one line per algorithm showing, for every step,
# the absolute gap between the budgets spent on action types 2 and 1
# (seaborn aggregates the epochs at each step).
plots = algos_plot[:4]
fig, ax = plt.subplots(figsize=(6,5))
for algo_name in plots:
    spent = allBudget[algo_name]
    # Wide (epoch x step) table -> long form with 'variable' = step.
    df = (pd.DataFrame(spent[:,:,1])-pd.DataFrame(spent[:,:,0])).melt()
    df['value'] = df.value.abs()
    sns.lineplot(data=df, x='variable',y='value', label=algo_name, ax=ax)
ax.set_xlabel('Step')
ax.set_ylabel('Max Used Budget - Min Used Budget')
ax.legend(loc=3)
fig.suptitle(f'N={N} HB={HB} LB={LB} K={K} \n Cost 1: {c1}, Cost 2: {c2}', fontsize=10)
# subplots_adjust takes fractions of the figure size and rejects
# left >= right, so the former left=5/right=5/top=20 call was invalid.
# tight_layout fits the axes instead; `rect` keeps room for the suptitle.
fig.tight_layout(rect=(0, 0, 1, 0.93))
plt.show()

# Used budget: one panel per algorithm, with one line per action type showing
# the budget spent at every step (seaborn aggregates the epochs).
plots = algos_plot[:4]
fig, ax = plt.subplots(figsize=(20,5), ncols=len(plots))
for panel, algo_name in zip(ax, plots):
    spent = allBudget[algo_name]
    sns.lineplot(data=pd.DataFrame(spent[:,:,0]).melt(), x='variable',y='value', label='Action type 1', ax=panel)
    sns.lineplot(data=pd.DataFrame(spent[:,:,1]).melt(), x='variable',y='value', label='Action type 2', ax=panel)
    panel.set_xlabel('Step')
    panel.set_ylabel('Used Budget')
    panel.set_title(algo_name, fontsize=10, fontweight="bold")
    panel.legend(loc=3)
fig.suptitle(f'N={N} HB={HB} LB={LB} K={K} \n Cost 1: {c1}, Cost 2: {c2}', fontsize=10)
# subplots_adjust takes fractions of the figure size and rejects
# left >= right, so the former left=5/right=5/top=20 call was invalid.
# tight_layout fits the axes instead; `rect` keeps room for the suptitle.
fig.tight_layout(rect=(0, 0, 1, 0.93))
plt.show()


# Count pulls
algo = 'hawkins'


def countActions(x, n_action_types=None):
    """Count how often each action id occurs in every column of ``x``.

    Args:
        x: 2-D array-like of action ids. Each column is counted separately;
            the script passes one epoch of ``allActions[algo]``, i.e. rows
            are steps and columns are arms.
        n_action_types: Number of action types. Ids ``0..n_action_types``
            (0 is "no action") always appear in the result. Defaults to the
            module-level ``K``.

    Returns:
        A DataFrame indexed by action id in ascending order, with one column
        per column of ``x``. Actions that were never chosen have a count of 0.
    """
    if n_action_types is None:
        n_action_types = K
    counts = pd.DataFrame(x).apply(lambda col: col.value_counts()).fillna(0)
    # Action ids run from 0 (no action) up to and including K. Padding with
    # range(K) used to skip the last action type, so an epoch in which it was
    # never used contributed no zero row to the later mean/std.
    expected_ids = counts.index.union(pd.RangeIndex(n_action_types + 1))
    counts = counts.reindex(expected_ids, fill_value=0)
    counts.sort_index(ignore_index=False, inplace=True)
    return counts
# One count table per epoch (rows: action id, columns: arm), stacked; rows
# sharing an action id are then aggregated across epochs.
allCounts = pd.concat([countActions(epoch_actions) for epoch_actions in allActions[algo]])
counts_by_action = allCounts.reset_index().groupby('index')
mean = counts_by_action.mean()
std = counts_by_action.std()

# Grouped bars: one group per arm, one bar per action, std as error bars.
fig, ax = plt.subplots()
mean.T.plot.bar(color=['gray','tab:blue','tab:orange'],ax=ax,yerr=std.T)
ax.legend(['No action', 'Action type 1', 'Action type 2'],fontsize=8)
ax.set_ylabel(f'Average Count ({STEPS} trials, {EPOCHS} epochs)',fontsize=10)
ax.set_xlabel('Arm')
plt.suptitle(f'{algo}',fontsize=10,fontweight="bold")
plt.title(f'N={N} HB={HB} LB={LB} K={K} \n Cost 1: {c1[0]}, Cost 2: {c2[0]}',fontsize=10)
plt.xticks(rotation=0, horizontalalignment="center")
# subplots_adjust takes fractions of the figure size and rejects
# left >= right, so the former left=5/right=5/top=20 call was invalid.
# tight_layout fits the axes instead; `rect` keeps room for the suptitle.
fig.tight_layout(rect=(0, 0, 1, 0.93))
plt.show()


# action type - state indexes
# Decoupled indexes with every arm placed in state 0, 1 and 2 respectively.
# Element [0] of each result is used for action type 1, element [1] for
# action type 2.
results0 = decoupled_methods.decoupledIndex(costs, T, R, HB, [0]*N, binarySearch=True)
results1 = decoupled_methods.decoupledIndex(costs, T, R, HB, [1]*N, binarySearch=True)
results2 = decoupled_methods.decoupledIndex(costs, T, R, HB, [2]*N, binarySearch=True)

# Rows are states, columns are arms. Both tables must cover all three states:
# the action-type-2 table used to omit state 2, which left it with a different
# row index than the other tables and made the comparison below fail.
index_action_type_1 = pd.DataFrame(list(map(np.ravel, [results0[0],results1[0],results2[0]])))
index_action_type_2 = pd.DataFrame(list(map(np.ravel, [results0[1],results1[1],results2[1]])))

# The two Hawkins outputs come from the same call with the same arguments, so
# it is evaluated once and unpacked.
hawkins_result = hawkins_methods.hawkins_fairness(T, R, costs, HB, LB, [0]*N, gamma=gamma)
index_hawkins = pd.DataFrame(hawkins_result[0]).T
lambdas_hawkins = pd.DataFrame(hawkins_result[1])

# For every (state, arm): does action type 1 have a higher probability of
# moving to state 1 than action type 2, and does it also have the higher index?
# NOTE(review): the comparison targets next-state 1 although the reward is paid
# in state 2 -- confirm this is the intended target state.
compareT = pd.DataFrame(T[:,:,1,1]).T>pd.DataFrame(T[:,:,2,1]).T
compareI = index_action_type_1>index_action_type_2
# Report every (state, arm) pair where the two orderings disagree.
where = np.where(compareT != compareI)
for state, arm in zip(*where):
    print(f'Arm {arm} in state {state} \n T[1]:{T[arm,state,1,1]:.3f}, T[2]:{T[arm,state,2,1]:.3f} \n I[1]:{index_action_type_1.loc[state,arm]:.3f}, I[2]:{index_action_type_2.loc[state,arm]:.3f}')
    print('\n')