import os
import pickle
import decoupled_methods 
import initialize
import evaluation
import hawkins_methods

import numpy as np 
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from whittle_binary_search_v2 import binary_search_all_arms, adjust_indexes


# ---- Problem dimensions ----
N, S, A = 5, 3, 2  # arms, states per arm, actions
# NOTE(review): the transition model below defines K + 1 = 3 actions
# (no-action plus K action types) and A is not used in the visible script;
# confirm what A is meant to count.
K = 2  # number of action types (excluding "no action")

# ---- Budget bounds ----
# Indexes should be the same regardless of budget, but HB must be non-negative.
HB, LB = 1, 0

# In the binary-action case the acting action index is always 1.
a_index = 1

gamma = 0.95  # presumably the discount factor; passed as `gamma` to the solvers

# Every arm starts in state 0. (For random start states, sample with
# np.random.choice(list(range(S)), size=N, replace=True) instead.)
start_state = np.zeros(N, dtype=int)

# Reward depends only on the current state: 1 in the last state, 0 otherwise.
# One identical row per arm.
R = np.tile([0, 0, 1], (N, 1))

# Per-arm cost of each action, indexed costs[arm, action]:
# column 0 is "no action" (free); columns 1 and 2 are action types 1 and 2.
c1 = [1] * N
c2 = [1] * N
costs = np.column_stack([np.zeros(N, dtype=int), c1, c2])

# Fixed transition matrices, one per action. In each matrix, row = current
# state and column = next state; every row sums to 1.
# P0: no action.
P0 = np.array([
    [0.95, 0.05, 0.0],
    [0.75, 0.25, 0.0],
    [0.0, 0.25, 0.75],
])

# P1: action type 1.
P1 = np.array([
    [0.05, 0.95, 0.0],
    [0.75, 0.25, 0.0],
    [0.0, 0.24, 0.76],
])

# P2: action type 2.
P2 = np.array([
    [0.90, 0.10, 0.0],
    [0.0, 0.05, 0.95],
    [0.0, 0.25, 0.75],
])

# T_i[action][state][next_state]; action 0 is "no action", 1..K the action types.
T_i = np.array([P0, P1, P2])
# Give every arm its own copy of the same model:
# T[arm][action][state][next_state].
T = np.repeat(T_i[np.newaxis, ...], N, axis=0)

def jitter_T(T, eps=0.01):
    """Return a randomly perturbed, re-normalized copy of a transition tensor.

    Uniform noise in ``[0, eps)`` is added to every entry. Each distribution
    along the last axis is then rescaled so it sums to 1 again.

    Args:
        T: Array-like of non-negative transition weights. The last axis is
            treated as the next-state distribution that gets re-normalized.
        eps: Upper bound of the uniform noise added to each entry.

    Returns:
        A new float array with the same shape as ``T``. The input is never
        modified.
    """
    T = np.asarray(T, dtype=float)
    # Draw the noise exactly as before (np.random.rand over T's shape) so
    # seeded runs reproduce the same perturbation.
    noise_val = np.random.rand(*T.shape) * eps
    # Add out of place: the former ``T += noise_val`` silently mutated the
    # caller's array, and raised a casting error for integer-typed input.
    jittered = T + noise_val
    return jittered / jittered.sum(axis=-1, keepdims=True)


# Add small per-entry noise. Every arm was built from the same T_i copy,
# so this makes the arms' transition models differ slightly.
T = jitter_T(T, eps=0.01)

# Precompute per-arm binary-search (BS) indexes for every (state, action type).
# per_arm_indexes[arm, s, k - 1] holds the index of action type k when every
# arm is in state s.
index_lb = -1     # lower bound passed to the index binary search
index_ub = 1      # upper bound passed to the index binary search
tolerance = 1e-4  # stopping tolerance passed to the index binary search
per_arm_indexes = np.zeros((N, S, K))
for s in range(S):
    current_state = np.full(N, s)
    # Dedicated loop variable: the previous version reused the module-level
    # ``a_index`` here and left it set to K instead of its documented value 1.
    for action_type in range(1, K + 1):
        per_arm_indexes[:, s, action_type - 1] = binary_search_all_arms(
            T, R, costs, current_state, action_type,
            index_lb=index_lb, index_ub=index_ub, gamma=gamma, tolerance=tolerance,
        )


EPOCHS = 30  # number of independent simulation runs
STEPS = 25   # steps per run
algos = ['bsadj_separate', 'bsadj_joint', 'bs_separate', 'bs_joint', 'hawkins', 'random', 'no_action']

# Per-policy result buffers, filled in by the simulation loop:
#   rewards[algo][epoch, step]        reward recorded for that step
#   allActions[algo][epoch, step, :]  action chosen for each of the N arms
#   allBudget[algo][epoch, step, :]   budget used (K entries, from usedBudget)
# np.zeros is used instead of the low-level np.ndarray constructor, which
# returns uninitialized memory, so no garbage survives if a run is cut short.
rewards = {algo: np.zeros((EPOCHS, STEPS)) for algo in algos}
allActions = {algo: np.zeros((EPOCHS, STEPS, N)) for algo in algos}
allBudget = {algo: np.zeros((EPOCHS, STEPS, K)) for algo in algos}

for epoch in tqdm(range(EPOCHS)):
    # Every policy restarts from the same initial states at each epoch.
    current_states = {name: start_state.copy() for name in algos}

    for step in range(STEPS):
        # ---- 1. choose this step's actions for every policy ----
        # Helper calls stay in the original order. Some of them presumably
        # draw from NumPy's global RNG, and reordering would change the stream.
        sep_adj_idx = adjust_indexes(
            T, R, costs, [HB] * K, current_states['bsadj_separate'], per_arm_indexes,
            index_lb=index_lb, index_ub=index_ub, gamma=gamma, tolerance=tolerance,
        )
        bsadj_sep_choice = decoupled_methods.greedyAllocation(sep_adj_idx.T, costs, N, HB, K)

        joint_adj_idx = adjust_indexes(
            T, R, costs, [HB] * K, current_states['bsadj_joint'], per_arm_indexes,
            index_lb=index_lb, index_ub=index_ub, gamma=gamma, tolerance=tolerance,
        )
        bsadj_joint_choice = decoupled_methods.jointGreedyAllocation(joint_adj_idx.T, costs, N, HB, K)

        # Dict displays evaluate left to right, so the calls below run in
        # the same order as separate statements would.
        step_actions = {
            'bsadj_separate': bsadj_sep_choice,
            'bsadj_joint': bsadj_joint_choice,
            'bs_separate': decoupled_methods.getDecoupledActions(
                N, T, R, HB, K, costs, current_states['bs_separate'],
                binarySearch=True, jointAllocation=False),
            'bs_joint': decoupled_methods.getDecoupledActions(
                N, T, R, HB, K, costs, current_states['bs_joint'],
                binarySearch=True, jointAllocation=True),
            'hawkins': hawkins_methods.get_hawkins_actions(
                N, T, R, costs, HB, LB, current_states['hawkins'], gamma),
            'random': decoupled_methods.randomUntilBudget(costs, HB, K, N),
            'no_action': np.zeros(N, dtype=int),
        }

        # ---- 2. record spend and actions, advance each policy, score it ----
        for name, chosen in step_actions.items():
            allBudget[name][epoch, step] = decoupled_methods.usedBudget(costs, chosen, K)
            allActions[name][epoch, step] = chosen
            new_states = evaluation.nextState(chosen, current_states[name], S, T)
            current_states[name] = new_states
            rewards[name][epoch, step] = evaluation.getReward(new_states, R)

    
# Total reward per epoch for each policy, then its mean and spread across epochs.
# np.std uses its default ddof=0 (population standard deviation).
epoch_totals = [rewards[algo].sum(axis=1) for algo in algos]
means = [totals.mean() for totals in epoch_totals]
std = [totals.std() for totals in epoch_totals]


fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(15, 5))

# Left panel: average cumulative reward per step, one line per policy.
# Column 0 of the ACR output is dropped before plotting.
for algo in algos:
    acr = evaluation.ACR(rewards[algo], start_state, R)
    acr_long = pd.DataFrame(acr[:, 1:]).melt()
    sns.lineplot(data=acr_long, x='variable', y='value', label=algo, ax=ax1)
ax1.legend(loc='lower right')  # same location as legend code 4
ax1.set_xlabel('Step')
ax1.set_ylabel('Average cumulative reward')

# Right panel: mean total reward per policy, with std error bars.
ax2.bar(x=algos, height=means, yerr=std)
ax2.set_ylabel(f'Sum of Rewards ({STEPS} Steps)')

fig.suptitle(f'N={N} HB={HB} LB={LB} K={K} \n Cost 1: {c1[0]}, Cost 2: {c2[0]}')
plt.show()