#from symbol import argument
import numpy as np

class RankingAgent:
    '''
    Dynamical Ranking Exploration-Exploitation (DREE) Algorithm instantiated with f(t):=log(t)**(1+delta)

    The first ``n_arms`` rounds play every arm once, in index order. Afterwards the
    lowest-indexed arm whose pull count is <= log(t+1)**(1+delta) (t = 0-based round)
    is explored; when no arm is that under-explored, the agent plays the arm most
    recently passed to ``update`` as ``first``.
    '''

    def __init__(self, n_arms, time_horizon, delta):
        # Static configuration.
        self.n_arms = n_arms
        self.delta = delta
        # Per-run state: pull count of each arm and the arm played at each round.
        self.pulls = np.zeros(self.n_arms)
        self.hist = np.zeros(time_horizon)
        self.t = 0
        self.last_pull = None
        self.first = None

    def reset(self):
        '''Clear all per-run state so the agent can be reused for a new run.'''
        self.t = 0
        self.last_pull = None
        self.first = None
        self.pulls = np.zeros(self.n_arms)
        self.hist = np.zeros_like(self.hist)

    def pull_arm(self):
        '''
        Choose the arm for the current round, record it in ``hist`` and advance ``t``.

        Returns:
            int: index of the arm to play.
        '''
        round_idx = self.t
        # Exploration threshold f = log(t+1)**(1+delta), evaluated every round.
        threshold = np.log(round_idx + 1)**(1 + self.delta)

        if round_idx < self.n_arms:
            # Initial sweep: arm i is played at round i.
            chosen = round_idx
        else:
            under_explored = np.flatnonzero(self.pulls <= threshold)
            if under_explored.size:
                # Exploration: lowest-indexed arm below the threshold.
                chosen = int(under_explored[0])
            else:
                # Exploitation: play the latest arm reported as ``first``.
                chosen = int(self.first)

        self.last_pull = chosen
        self.hist[round_idx] = chosen
        self.t = round_idx + 1
        return int(chosen)

    def update(self, first):
        '''
        Record feedback for the arm returned by the last ``pull_arm`` call.

        Args:
            first: arm index reported as ``first`` this round (presumably the arm the
                environment ranked first — confirm with the caller); it becomes the
                exploitation choice of later rounds.
        '''
        self.pulls[self.last_pull] += 1
        self.first = first



class RLPE:
    '''
    Ranking agent that plays round-robin over a shrinking set of candidate arms.

    Every round the least-pulled arm in ``new_arms`` is played. Whenever the
    candidates are balanced (the least-pulled candidate has as many pulls as the
    most-pulled arm overall), the reported ``first`` arm is tallied in
    ``hist_first``. Each time the least-pulled candidate's count exceeds the next
    checkpoint of ``loggrid`` (geometrically spaced from sqrt(T) up to T),
    ``new_arms`` is replaced by the arms whose tally is at least int(T**(2*alpha)),
    with alpha = log(pulls)/log(T) - 1/2.
    '''

    def __init__(self, n_arms, time_horizon):
        self.n_arms = n_arms
        self.pulls = np.zeros(self.n_arms)
        self.hist = np.zeros(time_horizon)
        self.hist_first = np.zeros(self.n_arms)
        self.time_horizon = time_horizon

        # define loggrid: about log(T) checkpoints, geometrically spaced between
        # sqrt(T) and T. Keep at least one point so update() never indexes an
        # empty grid when T < e (int(log(T)) == 0).
        logt = max(int(np.log(time_horizon)), 1)
        self.loggrid = self.time_horizon**np.linspace(1/2, 1, logt)
        print('Loggrid defined')
        print(self.loggrid)
        self.logindex = 0

        # 0-based round counter, as in reset(): it indexes self.hist, which has
        # exactly time_horizon slots (starting at 1 skipped hist[0] and overflowed
        # on the last pull).
        self.t = 0
        self.last_pull = None
        self.new_arms = range(n_arms)

    def reset(self):
        '''Clear all per-run state so the agent can be reused for a new run.'''
        self.pulls = np.zeros(self.n_arms)
        self.hist = np.zeros_like(self.hist)
        self.hist_first = np.zeros(self.n_arms)
        self.logindex = 0
        self.t = 0
        self.last_pull = None
        self.new_arms = range(self.n_arms)

    def get_least_pulled(self):
        '''
        Return ``(arm, pulls)`` for the least-pulled arm in ``self.new_arms``.

        Ties go to the arm listed first in ``new_arms``. If ``new_arms`` is empty, or
        no candidate has fewer than ``time_horizon`` pulls, the fallback
        ``(0, time_horizon)`` is returned.
        '''
        pulls = self.time_horizon
        candidate_arm = 0
        for arm in self.new_arms:
            if self.pulls[arm] < pulls:
                candidate_arm = arm
                pulls = self.pulls[arm]
        return candidate_arm, pulls

    def pull_arm(self):
        '''
        Play the least-pulled candidate arm, record it in ``hist`` and advance ``t``.

        Returns:
            int: index of the arm to play.
        '''
        candidate_arm, _ = self.get_least_pulled()

        self.hist[self.t] = candidate_arm
        self.last_pull = candidate_arm
        self.t += 1
        return int(candidate_arm)

    def update(self, first):
        '''
        Record feedback for the last pull and possibly shrink the candidate set.

        Args:
            first: arm index reported as ``first`` this round (presumably the arm
                ranked first by the environment — confirm with the caller).
        '''
        self.pulls[self.last_pull] += 1
        _, pulls = self.get_least_pulled()

        # Tally the feedback only when all candidates have the same number of pulls.
        if pulls == np.max(self.pulls):
            self.hist_first[int(first)] += 1

        if pulls > self.loggrid[self.logindex]:
            # Critical exponent; T**(2*alpha) is (up to rounding) pulls**2 / T.
            alpha = np.log(pulls)/np.log(self.time_horizon) - 1/2
            threshold = int(self.time_horizon**(2*alpha))
            self.new_arms = np.flatnonzero(self.hist_first >= threshold)

            # Advance to the next checkpoint; the last one (T) is kept afterwards.
            if self.logindex < len(self.loggrid)-1:
                self.logindex += 1



class ECagent:
    '''
    Explore then Commit (EC) Algorithm instantiated with m:=T**(2/3)

    For the first int(T**(2/3) * n_arms) rounds arms are played round-robin
    (arm ``t mod n_arms``). At that round the arm with the largest ``hist_first``
    tally is frozen into ``new_arms`` and played for every remaining round.
    '''

    def __init__(self, n_arms, time_horizon):
        self.n_arms = n_arms
        self.time_horizon = time_horizon
        # Per-run state.
        self.t = 0
        self.last_pull = None
        self.pulls = np.zeros(self.n_arms)
        self.hist_first = np.zeros(self.n_arms)  # times each arm was reported as ``first``
        self.hist = np.zeros(time_horizon)
        # Empty until the commit round; then holds the single committed arm index.
        self.new_arms = np.array([])

    def reset(self):
        '''Clear all per-run state so the agent can be reused for a new run.'''
        self.t = 0
        self.last_pull = None
        self.pulls = np.zeros(self.n_arms)
        self.hist_first = np.zeros(self.n_arms)
        self.hist = np.zeros_like(self.hist)
        self.new_arms = np.array([])

    def pull_arm(self):
        '''
        Choose the arm for the current round, record it in ``hist`` and advance ``t``.

        Returns:
            int: index of the arm to play.
        '''
        n_explore = int(self.time_horizon**(2/3)*self.n_arms)
        step = self.t

        if step == n_explore:
            # Commit round: freeze the arm reported first most often so far.
            self.new_arms = np.argmax(self.hist_first)

        exploring = step < n_explore
        self.last_pull = np.mod(step, self.n_arms) if exploring else self.new_arms

        self.hist[step] = self.last_pull
        self.t = step + 1
        return int(self.last_pull)

    def update(self, first):
        '''
        Record feedback for the arm returned by the last ``pull_arm`` call.

        Args:
            first: arm index reported as ``first`` this round; tallied in
                ``hist_first`` to pick the committed arm.
        '''
        self.pulls[self.last_pull] += 1
        self.hist_first[int(first)] += 1
