import numpy as np
from scipy.stats import entropy
from scipy.stats import beta
import scipy.integrate as integrate


### Statistics routines for longitudinal runs.
### Each routine takes a single longitudinal run (history): a list of length T
### whose items have the form (action, reward, [all_rewards]).
### Each returns a single scalar, suitable for a scatter plot.

def last_opt(item):
    """Return the index of the last round that played action 0, as a fraction of T.

    Action 0 is the one compared against (the "opt" arm). If it never appears,
    the index defaults to 0, so the result is 0 -- the same value as "last
    played in round 0". An empty run returns 0.
    """
    horizon = len(item)
    if horizon == 0:
        return 0
    last_idx = 0
    for round_idx, record in enumerate(item):
        if record[0] == 0:
            last_idx = round_idx
    return last_idx / horizon

def suffix_fail(item, t=50):
    """Flag a run whose last play of action 0 occurred before round ``t``.

    Returns 1 when the last index at which action 0 was played is below ``t``.
    This includes runs where action 0 never appears (index defaults to 0) and
    empty runs. Returns 0 otherwise.
    """
    if len(item) == 0:
        return 1
    opt_rounds = [idx for idx, record in enumerate(item) if record[0] == 0]
    last_idx = opt_rounds[-1] if opt_rounds else 0
    return int(last_idx < t)

def ave_reward(item):
    """Mean of the reward field (index 1) over every round of the run."""
    rewards = [record[1] for record in item]
    return np.mean(rewards)

def ave_reward_last_half(item):
    """Mean reward over rounds ``T // 2`` through ``T - 1`` of the run."""
    rewards = np.array([record[1] for record in item])
    midpoint = len(rewards) // 2
    return np.mean(rewards[midpoint:])

def opt_count(item):
    """Fraction of rounds in which action 0 was played."""
    return np.mean([int(record[0] == 0) for record in item])

def opt_count_last_half(item):
    """Fraction of rounds ``T // 2`` through ``T - 1`` in which action 0 was played."""
    flags = [int(record[0] == 0) for record in item]
    return np.mean(flags[len(flags) // 2:])

def min_count(item, K=5):
    """Play count of the least-played arm, divided by the run length.

    Args:
        item: one run; a sequence of ``(action, reward, ...)`` records.
        K: number of arms; each action is used as an index into a length-K
            count array.

    Returns:
        ``min(counts) / len(item)``, or 0 for an empty run.
    """
    if len(item) == 0:
        return 0
    counts = np.zeros(K)
    # Unbuffered scatter-add: repeated actions are each counted once per occurrence.
    np.add.at(counts, [record[0] for record in item], 1)
    return counts.min() / len(item)

def min_count_last_half(item, K=5):
    """``min_count`` applied to rounds ``T // 2`` through ``T - 1`` of the run.

    The denominator is therefore the length of that second half. An empty run
    returns 0.
    """
    horizon = len(item)
    if not horizon:
        return 0
    second_half = item[horizon // 2:]
    return min_count(second_half, K)

def greedy_frac(item,K=5):
    """Fraction of rounds in which the chosen action was greedy.

    An action counts as greedy in a round if, given the history before that
    round, either it has never been played, or its empirical mean reward is
    maximal among the arms that have been played.

    Args:
        item: one run; a sequence of ``(action, reward, ...)`` records.
        K: number of arms; each action is used as an index into length-K arrays.

    Returns:
        Score in [0, 1], or 0 for an empty run.
    """
    if len(item) == 0:
        return 0
    score = 0
    cts = np.zeros(K)
    sums = np.zeros(K)
    for x in item:
        act = x[0]
        rew = x[1]
        played = cts > 0
        if not played[act]:
            # Unplayed arms always count as greedy.
            score += 1
        else:
            # Take the best mean over played arms only. Giving unplayed arms a
            # placeholder mean of 0 would let that placeholder beat every real
            # mean when rewards are negative.
            played_means = sums[played] / cts[played]
            if sums[act] / cts[act] == played_means.max():
                score += 1
        sums[act] += rew
        cts[act] += 1
    return score/len(item)

def ts_frac(item,K=5):
    """Average posterior probability that each chosen arm is the best arm.

    Keeps an independent Beta posterior per arm, starting from Beta(1, 1). In
    each round, before updating, it computes the probability under the current
    posteriors that the played arm has the largest mean:

        P(best = act) = integral_0^1 pdf_act(p) * prod_{j != act} cdf_j(p) dp

    This is the probability that Thompson sampling would pick ``act``. The
    result is the mean of these probabilities over all rounds.

    Args:
        item: one run; a sequence of ``(action, reward, ...)`` records. The
            update ``alpha += reward``, ``beta += 1 - reward`` presumably
            assumes rewards in [0, 1] (Bernoulli) -- confirm. Int and float
            rewards are both accepted.
        K: number of arms; each action is used as an index into length-K arrays.

    Returns:
        Mean probability over rounds. Each integral uses ``epsabs=0.001``.
        Returns 0 for an empty run.
    """
    if len(item) == 0:
        return 0

    def integrand(p, a_act, b_act, a_rest, b_rest):
        # Probability that every other arm's mean lies below p, times the
        # density of the played arm's mean at p.
        return np.prod(beta.cdf(p, a_rest, b_rest)) * beta.pdf(p, a_act, b_act)

    # Float dtype is required: adding a float reward (e.g. 1.0) in place to an
    # integer array raises a numpy casting error.
    alphas = np.ones(K)
    betas = np.ones(K)
    arm_ids = np.arange(K)
    score = 0
    for record in item:
        act = record[0]
        rew = record[1]
        rest = arm_ids != act
        score += integrate.quad(
            integrand, 0, 1,
            args=(alphas[act], betas[act], alphas[rest], betas[rest]),
            epsabs=0.001)[0]
        alphas[act] += rew
        betas[act] += 1 - rew
    return score / len(item)

### Aggregated statistics across all runs for an alg
def get_entropy(dl, alg):
    """Mean entropy of the action-frequency rows for ``alg``, normalized by log(K).

    Each row of ``dl.get_action_freqs(alg)`` is normalized to sum to 1, and
    its natural-log entropy is computed. The mean over rows is divided by
    ``log(dl.K)``, so a uniform row over K arms scores 1. Rows are presumably
    one per run -- confirm against the loader.

    Returns:
        The normalized mean entropy, or None if no frequencies are available.
    """
    freqs = dl.get_action_freqs(alg)
    if freqs is None:
        return None
    per_row = []
    for row_idx in range(freqs.shape[0]):
        row = freqs[row_idx, :]
        per_row.append(entropy(row / np.sum(row)))
    max_entropy = np.log(dl.K)
    return np.mean(per_row) / max_entropy

def get_greedy_frac(dl, alg):
    """Mean ``greedy_frac`` over every run in ``dl.all_results[alg]``.

    Uses ``dl.K`` as the arm count. Returns None when there are no runs.
    """
    runs = dl.all_results[alg]
    if not len(runs):
        return None
    fracs = [greedy_frac(run, K=dl.K) for run in runs]
    return np.mean(fracs)
            
def get_ts_frac(dl, alg, num_reps=40):
    """Mean ``ts_frac`` over the first ``num_reps`` runs in ``dl.all_results[alg]``.

    Only a prefix of the runs is used because ``ts_frac`` does one numerical
    integration per round. Uses ``dl.K`` as the arm count. Returns None when
    there are no runs.
    """
    runs = dl.all_results[alg]
    if len(runs) == 0:
        return None
    subset = runs[0:min(num_reps, len(runs))]
    per_run = [ts_frac(run, K=dl.K) for run in subset]
    return np.mean(per_run)

def get_min_frac(dl,alg):
    """Average share of the T rounds given to each run's least-played arm.

    Args:
        dl: results container exposing ``get_action_tensor(alg)`` and ``T``.
            The tensor is 3-D: axis 1 is summed and axis 2 is minimized over.
            It is presumably (runs, T, K) per-round action indicators --
            confirm against the loader.
        alg: key identifying the algorithm.

    Returns:
        Mean over runs of (min over arms of total plays) / ``dl.T``, or None
        when no action tensor is available.
    """
    tens = dl.get_action_tensor(alg)
    if tens is None:
        # Same convention as get_entropy: None signals missing data.
        return None
    # Total per arm. A plain sum over axis 1 equals the last entry of a
    # cumulative sum, without materializing every prefix.
    plays_per_arm = np.sum(tens, axis=1)
    return np.mean(np.min(plays_per_arm, axis=1)) / dl.T

#### Statistics routines for puzzles

def puzzle_greedy_frac(data):
    """Fraction of puzzles whose choice ties for the maximum of the record's first entry.

    Each record ``x`` is indexed as ``x[0]`` (an array, presumably empirical
    means -- confirm) and ``x[2]`` (the chosen arm). Ties all count as hits.
    """
    hits = []
    for record in data:
        scores = record[0]
        best_arms = np.flatnonzero(scores == scores.max())
        hits.append(1 if record[2] in best_arms else 0)
    return np.mean(hits)

def puzzle_max_count(data):
    """Fraction of puzzles whose choice ties for the maximum of the record's second entry.

    Each record ``x`` is indexed as ``x[1]`` (an array, presumably per-arm
    pull counts -- confirm) and ``x[2]`` (the chosen arm). Ties all count as hits.
    """
    hits = []
    for record in data:
        counts = record[1]
        top_arms = np.flatnonzero(counts == counts.max())
        hits.append(1 if record[2] in top_arms else 0)
    return np.mean(hits)
    
def puzzle_opt_count(data):
    """Fraction of puzzles in which arm 0 was chosen.

    Args:
        data: sequence of puzzle records; the chosen arm is at index 2.

    Returns:
        Fraction of records whose choice equals 0. Empty ``data`` returns NaN
        instead of raising ZeroDivisionError, consistent with the NaN that
        ``np.mean`` yields for the other ``puzzle_*`` statistics.
    """
    arr = np.array([x[2] for x in data])
    if len(arr) == 0:
        return float("nan")
    return np.count_nonzero(arr == 0) / len(arr)

def puzzle_min_count(data):
    """Fraction of puzzles whose choice ties for the minimum of the record's second entry.

    Each record ``x`` is indexed as ``x[1]`` (an array, presumably per-arm
    pull counts -- confirm) and ``x[2]`` (the chosen arm). Ties all count as hits.
    """
    hits = [
        1 if record[2] in np.flatnonzero(record[1] == record[1].min()) else 0
        for record in data
    ]
    return np.mean(hits)
